mirror of https://github.com/google/gemma.cpp.git
Merge branch 'dev' into main
This commit is contained in:
commit
32286f0465
|
|
@ -0,0 +1,37 @@
|
|||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
2b-pt-sfp.sbs filter=lfs diff=lfs merge=lfs -text
|
||||
tokenizer.spm filter=lfs diff=lfs merge=lfs -text
|
||||
|
|
@ -17,12 +17,10 @@ jobs:
|
|||
fail-fast: false
|
||||
matrix:
|
||||
# When adding another, also add to copybara's github_check_runs.
|
||||
os: ['ubuntu-latest', 'macos-latest', 'windows-latest', 'ubuntu-20.04']
|
||||
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
|
||||
build_type: ['Release']
|
||||
preset: ['make', 'windows']
|
||||
exclude:
|
||||
- os: ubuntu-20.04
|
||||
preset: windows
|
||||
- os: ubuntu-latest
|
||||
preset: windows
|
||||
- os: macos-latest
|
||||
|
|
@ -62,44 +60,6 @@ jobs:
|
|||
${{ github.workspace }}/build/gemma
|
||||
${{ github.workspace }}/build/libgemma.a
|
||||
|
||||
- if: matrix.os == 'ubuntu-20.04'
|
||||
name: Upload build artifacts to Kaggle
|
||||
uses: pculliton/push-kaggle-dataset@v1.0.0
|
||||
env:
|
||||
KAGGLE_USERNAME: ${{ secrets.KAGGLE_USERNAME }}
|
||||
KAGGLE_KEY: ${{ secrets.KAGGLE_KEY }}
|
||||
with:
|
||||
id: "phillipculliton/gemma-build-artifacts"
|
||||
files: |
|
||||
build/gemma
|
||||
build/_deps/sentencepiece-build/src/libsentencepiece.so.0
|
||||
|
||||
- if: matrix.os == 'ubuntu-20.04'
|
||||
name: Create code for new test notebook version
|
||||
run: |
|
||||
cat > runner.py << EOF
|
||||
import subprocess
|
||||
subprocess.run(["cp", "/kaggle/input/gemma-build-artifacts/gemma", "/kaggle/working"])
|
||||
subprocess.run(["chmod", "700", "/kaggle/working/gemma"])
|
||||
subprocess.run(["cp", "/kaggle/input/gemma-build-artifacts/_deps/sentencepiece-build/src/libsentencepiece.so.0", "/kaggle/working"])
|
||||
output = subprocess.run(["/kaggle/working/gemma", "--tokenizer", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/tokenizer.spm", "--compressed_weights", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/2b-it-sfp.sbs", "--model", "2b-it", "--verbosity", "0", "--max_generated_tokens", "128"], stdout=subprocess.PIPE, input='Write an email to the moon.', encoding='ascii').stdout
|
||||
assert("write an email to the moon." not in output.lower());
|
||||
assert("moon" in output.lower());
|
||||
EOF
|
||||
|
||||
- if: matrix.os == 'ubuntu-20.04'
|
||||
name: Run kaggle test notebook
|
||||
uses: pculliton/kaggle-action@v1.0.28
|
||||
with:
|
||||
username: ${{ secrets.KAGGLE_USERNAME }}
|
||||
key: ${{ secrets.KAGGLE_KEY }}
|
||||
title: GemmaCPP-CI-2
|
||||
code_file: runner.py
|
||||
dataset_sources: "phillipculliton/gemma-build-artifacts"
|
||||
model_sources: "google/gemma/gemmaCpp/2b-it-sfp/4"
|
||||
enable_gpu: False
|
||||
kernel_type: script
|
||||
|
||||
bazel:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
|
@ -116,4 +76,4 @@ jobs:
|
|||
with:
|
||||
path: ~/.cache/bazel
|
||||
key: bazel-${{ runner.os }}
|
||||
- run: bazel build --cxxopt=-std=c++20 //:all
|
||||
- run: bazel build --cxxopt=-std=c++20 //:gemma --jobs=10 --show_progress_rate_limit=1
|
||||
|
|
|
|||
|
|
@ -1,4 +1,25 @@
|
|||
# Build directories
|
||||
.cache/
|
||||
bazel-*/
|
||||
build-*/
|
||||
build/
|
||||
|
||||
# Python cache
|
||||
python/*/__pycache__
|
||||
|
||||
# Model files
|
||||
*.sbs
|
||||
*.spm
|
||||
*.data
|
||||
*.bin
|
||||
*.weights
|
||||
|
||||
# IDE and editor files
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*~
|
||||
|
||||
# Local development
|
||||
.env
|
||||
.env.local
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
{
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Linux",
|
||||
"includePath": [
|
||||
"${workspaceFolder}/**"
|
||||
],
|
||||
"defines": [],
|
||||
"cStandard": "c17",
|
||||
"cppStandard": "c++17",
|
||||
"intelliSenseMode": "linux-clang-x64"
|
||||
}
|
||||
],
|
||||
"version": 4
|
||||
}
|
||||
630
BUILD.bazel
630
BUILD.bazel
|
|
@ -19,7 +19,10 @@ license(
|
|||
# Dual-licensed Apache 2 and 3-clause BSD.
|
||||
licenses(["notice"])
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
exports_files([
|
||||
"LICENSE",
|
||||
".github/workflows/build.yml",
|
||||
])
|
||||
|
||||
cc_library(
|
||||
name = "basics",
|
||||
|
|
@ -29,6 +32,16 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "args",
|
||||
hdrs = ["util/args.h"],
|
||||
deps = [
|
||||
":basics",
|
||||
"//io", # Path
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
# Split from :threading to break a circular dependency with :allocator.
|
||||
cc_library(
|
||||
name = "topology",
|
||||
|
|
@ -59,6 +72,7 @@ cc_library(
|
|||
hdrs = ["util/threading.h"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":args",
|
||||
":basics",
|
||||
":topology",
|
||||
# Placeholder for container detection, do not remove
|
||||
|
|
@ -68,14 +82,28 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "threading_context",
|
||||
srcs = ["util/threading_context.cc"],
|
||||
hdrs = ["util/threading_context.h"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":args",
|
||||
":basics",
|
||||
":threading",
|
||||
":topology",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "threading_test",
|
||||
srcs = ["util/threading_test.cc"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":threading",
|
||||
"@googletest//:gtest_main",
|
||||
":threading_context",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"@highway//:auto_tune",
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
|
|
@ -97,6 +125,124 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "configs",
|
||||
srcs = ["gemma/configs.cc"],
|
||||
hdrs = ["gemma/configs.h"],
|
||||
deps = [
|
||||
":basics",
|
||||
"//compression:types",
|
||||
"//io",
|
||||
"//io:fields",
|
||||
"@highway//:hwy", # base.h
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "configs_test",
|
||||
srcs = ["gemma/configs_test.cc"],
|
||||
deps = [
|
||||
":configs",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:types",
|
||||
"//io:fields",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_info",
|
||||
srcs = ["gemma/tensor_info.cc"],
|
||||
hdrs = ["gemma/tensor_info.h"],
|
||||
deps = [
|
||||
":basics",
|
||||
":configs",
|
||||
"//compression:types",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mat",
|
||||
srcs = ["util/mat.cc"],
|
||||
hdrs = ["util/mat.h"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":tensor_info",
|
||||
":threading_context",
|
||||
"//compression:types",
|
||||
"//io:fields",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tokenizer",
|
||||
srcs = ["gemma/tokenizer.cc"],
|
||||
hdrs = ["gemma/tokenizer.h"],
|
||||
deps = [
|
||||
":configs",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@com_google_sentencepiece//:sentencepiece_processor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "model_store",
|
||||
srcs = ["gemma/model_store.cc"],
|
||||
hdrs = ["gemma/model_store.h"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":configs",
|
||||
":mat",
|
||||
":tensor_info",
|
||||
":threading_context",
|
||||
":tokenizer",
|
||||
"//compression:types",
|
||||
"//io",
|
||||
"//io:blob_store",
|
||||
"//io:fields",
|
||||
"@highway//:hwy",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "weights",
|
||||
srcs = ["gemma/weights.cc"],
|
||||
hdrs = ["gemma/weights.h"],
|
||||
deps = [
|
||||
":configs",
|
||||
":gemma_args",
|
||||
":mat",
|
||||
":matmul",
|
||||
":model_store",
|
||||
":tensor_info",
|
||||
":threading_context",
|
||||
"//compression:compress",
|
||||
"//io:blob_store",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "tensor_info_test",
|
||||
srcs = ["gemma/tensor_info_test.cc"],
|
||||
deps = [
|
||||
":configs",
|
||||
":mat",
|
||||
":tensor_info",
|
||||
":weights",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:compress",
|
||||
"@highway//:hwy", # aligned_allocator.h
|
||||
],
|
||||
)
|
||||
|
||||
# For building all tests in one command, so we can test several.
|
||||
test_suite(
|
||||
name = "ops_tests",
|
||||
|
|
@ -104,34 +250,75 @@ test_suite(
|
|||
)
|
||||
|
||||
cc_library(
|
||||
name = "ops",
|
||||
name = "matmul",
|
||||
srcs = ["ops/matmul.cc"],
|
||||
hdrs = ["ops/matmul.h"],
|
||||
textual_hdrs = ["ops/matmul-inl.h"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":mat",
|
||||
":threading_context",
|
||||
"//compression:compress",
|
||||
"@highway//:bit_set",
|
||||
"@highway//:hwy",
|
||||
"@highway//:nanobenchmark",
|
||||
"@highway//:profiler",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "matmul_static",
|
||||
srcs = [
|
||||
"ops/matmul.cc",
|
||||
# single-file build time is ~30sec for msan, hence shard.
|
||||
"ops/matmul_static_bf16.cc",
|
||||
"ops/matmul_static_f32.cc",
|
||||
"ops/matmul_static_nuq.cc",
|
||||
"ops/matmul_static_sfp.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"ops/matmul.h",
|
||||
"ops/ops.h",
|
||||
"ops/matmul_static.h",
|
||||
],
|
||||
textual_hdrs = [
|
||||
"ops/matmul_static-inl.h",
|
||||
"ops/matmul-inl.h",
|
||||
],
|
||||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":mat",
|
||||
":matmul",
|
||||
":threading_context",
|
||||
"//compression:compress",
|
||||
"//compression:types",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@highway//:timer",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ops",
|
||||
hdrs = ["ops/ops.h"],
|
||||
textual_hdrs = [
|
||||
"ops/dot-inl.h",
|
||||
"ops/sum-inl.h",
|
||||
"ops/fp_arith-inl.h",
|
||||
"ops/matmul-inl.h",
|
||||
"ops/matvec-inl.h",
|
||||
"ops/ops-inl.h",
|
||||
],
|
||||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":threading",
|
||||
":topology",
|
||||
":mat",
|
||||
":matmul",
|
||||
":matmul_static",
|
||||
":threading_context",
|
||||
"//compression:compress",
|
||||
"@highway//:algo",
|
||||
"@highway//:bit_set",
|
||||
"@highway//:hwy",
|
||||
"@highway//:math",
|
||||
"@highway//:matvec",
|
||||
"@highway//:nanobenchmark",
|
||||
"@highway//:profiler",
|
||||
"@highway//:thread_pool",
|
||||
"@highway//hwy/contrib/sort:vqsort",
|
||||
|
|
@ -143,15 +330,15 @@ cc_test(
|
|||
size = "small",
|
||||
timeout = "long",
|
||||
srcs = ["ops/dot_test.cc"],
|
||||
linkstatic = True,
|
||||
local_defines = ["HWY_IS_TEST"],
|
||||
# for test_suite.
|
||||
tags = ["ops_tests"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":app",
|
||||
":ops",
|
||||
":test_util",
|
||||
":threading",
|
||||
":threading_context",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:compress",
|
||||
"//compression:test_util",
|
||||
|
|
@ -167,20 +354,23 @@ cc_test(
|
|||
cc_test(
|
||||
name = "ops_test",
|
||||
size = "small",
|
||||
timeout = "eternal",
|
||||
timeout = "long",
|
||||
srcs = ["ops/ops_test.cc"],
|
||||
linkstatic = True,
|
||||
local_defines = ["HWY_IS_TEST"],
|
||||
# for test_suite.
|
||||
tags = ["ops_tests"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":app",
|
||||
":common",
|
||||
":basics",
|
||||
":configs",
|
||||
":gemma_lib",
|
||||
":mat",
|
||||
":ops",
|
||||
":test_util",
|
||||
":threading",
|
||||
":threading_context",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:compress",
|
||||
"//compression:types",
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
"@highway//:nanobenchmark", #buildcleaner: keep
|
||||
|
|
@ -192,11 +382,14 @@ cc_test(
|
|||
size = "small",
|
||||
timeout = "long",
|
||||
srcs = ["ops/gemma_matvec_test.cc"],
|
||||
linkstatic = True,
|
||||
local_defines = ["HWY_IS_TEST"],
|
||||
# for test_suite.
|
||||
tags = ["ops_tests"],
|
||||
deps = [
|
||||
":mat",
|
||||
":ops",
|
||||
":threading_context",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:compress",
|
||||
"@highway//:hwy",
|
||||
|
|
@ -210,18 +403,23 @@ cc_test(
|
|||
size = "small",
|
||||
timeout = "long",
|
||||
srcs = ["ops/matmul_test.cc"],
|
||||
linkstatic = True,
|
||||
local_defines = ["HWY_IS_TEST"],
|
||||
# for test_suite.
|
||||
tags = ["ops_tests"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":mat",
|
||||
":matmul",
|
||||
":matmul_static",
|
||||
":ops",
|
||||
":threading",
|
||||
":threading_context",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:compress",
|
||||
"//compression:test_util",
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
"@highway//:nanobenchmark",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
|
@ -231,6 +429,7 @@ cc_test(
|
|||
size = "small",
|
||||
timeout = "long",
|
||||
srcs = ["ops/bench_matmul.cc"],
|
||||
linkstatic = True,
|
||||
local_defines = ["HWY_IS_TEST"],
|
||||
tags = [
|
||||
"manual",
|
||||
|
|
@ -238,12 +437,12 @@ cc_test(
|
|||
"ops_tests", # for test_suite.
|
||||
],
|
||||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":ops",
|
||||
":threading",
|
||||
":matmul",
|
||||
":threading_context",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:compress",
|
||||
"//compression:test_util",
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
"@highway//:nanobenchmark",
|
||||
|
|
@ -252,101 +451,47 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "common",
|
||||
srcs = [
|
||||
"gemma/common.cc",
|
||||
"gemma/configs.cc",
|
||||
"gemma/tensor_index.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"gemma/common.h",
|
||||
"gemma/configs.h",
|
||||
"gemma/tensor_index.h",
|
||||
],
|
||||
deps = [
|
||||
"//compression:fields",
|
||||
"//compression:sfp",
|
||||
"@highway//:hwy", # base.h
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "configs_test",
|
||||
srcs = ["gemma/configs_test.cc"],
|
||||
deps = [
|
||||
":common",
|
||||
"@googletest//:gtest_main",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "tensor_index_test",
|
||||
srcs = ["gemma/tensor_index_test.cc"],
|
||||
deps = [
|
||||
":basics",
|
||||
":common",
|
||||
":weights",
|
||||
"@googletest//:gtest_main",
|
||||
"//compression:compress",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "weights",
|
||||
srcs = ["gemma/weights.cc"],
|
||||
hdrs = ["gemma/weights.h"],
|
||||
deps = [
|
||||
":common",
|
||||
"//compression:blob_store",
|
||||
"//compression:compress",
|
||||
"//compression:io",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@highway//:stats",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tokenizer",
|
||||
srcs = ["gemma/tokenizer.cc"],
|
||||
hdrs = ["gemma/tokenizer.h"],
|
||||
deps = [
|
||||
":common",
|
||||
"//compression:io",
|
||||
"//compression:sfp",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@com_google_sentencepiece//:sentencepiece_processor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "kv_cache",
|
||||
srcs = ["gemma/kv_cache.cc"],
|
||||
hdrs = ["gemma/kv_cache.h"],
|
||||
deps = [
|
||||
":common",
|
||||
":basics",
|
||||
":configs",
|
||||
":gemma_args",
|
||||
":mat",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gemma_args",
|
||||
hdrs = ["gemma/gemma_args.h"],
|
||||
deps = [
|
||||
":args",
|
||||
":basics",
|
||||
":mat",
|
||||
":matmul",
|
||||
"//io",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gemma_lib",
|
||||
srcs = [
|
||||
"gemma/attention.cc",
|
||||
"gemma/gemma.cc",
|
||||
"gemma/instantiations/bf16.cc",
|
||||
"gemma/instantiations/f32.cc",
|
||||
"gemma/instantiations/nuq.cc",
|
||||
"gemma/instantiations/sfp.cc",
|
||||
"gemma/griffin.cc",
|
||||
"gemma/vit.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"gemma/activations.h",
|
||||
"gemma/attention.h",
|
||||
"gemma/gemma.h",
|
||||
"gemma/griffin.h",
|
||||
"gemma/vit.h",
|
||||
],
|
||||
exec_properties = {
|
||||
# Avoid linker OOMs when building with sanitizer instrumentation.
|
||||
|
|
@ -354,23 +499,26 @@ cc_library(
|
|||
},
|
||||
textual_hdrs = [
|
||||
"gemma/gemma-inl.h",
|
||||
# Placeholder for internal file2, do not remove,
|
||||
],
|
||||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":common",
|
||||
":ops",
|
||||
":tokenizer",
|
||||
":configs",
|
||||
":gemma_args",
|
||||
":kv_cache",
|
||||
":weights",
|
||||
":mat",
|
||||
":matmul",
|
||||
":model_store",
|
||||
":ops",
|
||||
":threading",
|
||||
":threading_context",
|
||||
":weights",
|
||||
"//compression:compress",
|
||||
"//compression:io",
|
||||
"//compression:sfp",
|
||||
"//compression:types",
|
||||
"//io:blob_store",
|
||||
"//io",
|
||||
"//paligemma:image",
|
||||
"@highway//:hwy",
|
||||
"@highway//:bit_set",
|
||||
"@highway//:nanobenchmark", # timer
|
||||
"@highway//:profiler",
|
||||
"@highway//:thread_pool",
|
||||
|
|
@ -382,35 +530,9 @@ cc_library(
|
|||
srcs = ["evals/cross_entropy.cc"],
|
||||
hdrs = ["evals/cross_entropy.h"],
|
||||
deps = [
|
||||
":common",
|
||||
":gemma_lib",
|
||||
":ops",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "args",
|
||||
hdrs = ["util/args.h"],
|
||||
deps = [
|
||||
":basics",
|
||||
"//compression:io",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "app",
|
||||
hdrs = ["util/app.h"],
|
||||
deps = [
|
||||
":args",
|
||||
":basics",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
":ops",
|
||||
":threading",
|
||||
"//compression:io",
|
||||
"//compression:sfp",
|
||||
"//compression:types",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
|
@ -420,20 +542,49 @@ cc_library(
|
|||
srcs = ["evals/benchmark_helper.cc"],
|
||||
hdrs = ["evals/benchmark_helper.h"],
|
||||
deps = [
|
||||
":app",
|
||||
":args",
|
||||
":common",
|
||||
":configs",
|
||||
":cross_entropy",
|
||||
":gemma_args",
|
||||
":gemma_lib",
|
||||
":kv_cache",
|
||||
":matmul",
|
||||
":ops",
|
||||
":threading",
|
||||
# Placeholder for internal dep, do not remove.,
|
||||
":threading_context",
|
||||
":tokenizer",
|
||||
"@google_benchmark//:benchmark",
|
||||
"//compression:compress",
|
||||
"@highway//:hwy",
|
||||
"@highway//:nanobenchmark",
|
||||
"@highway//:topology",
|
||||
"@highway//:profiler",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gemma_shared_lib",
|
||||
srcs = [
|
||||
"gemma/bindings/c_api.cc",
|
||||
"gemma/bindings/context.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"gemma/bindings/c_api.h",
|
||||
"gemma/bindings/context.h",
|
||||
],
|
||||
exec_properties = {
|
||||
# Avoid linker OOMs when building with sanitizer instrumentation.
|
||||
"mem": "28g",
|
||||
},
|
||||
deps = [
|
||||
":benchmark_helper",
|
||||
":gemma_args",
|
||||
":gemma_lib",
|
||||
":kv_cache",
|
||||
":matmul",
|
||||
":threading",
|
||||
":threading_context",
|
||||
":tokenizer",
|
||||
"//paligemma:image",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@highway//:timer",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -449,9 +600,10 @@ cc_test(
|
|||
],
|
||||
deps = [
|
||||
":benchmark_helper",
|
||||
":common",
|
||||
":configs",
|
||||
":gemma_lib",
|
||||
"@googletest//:gtest_main",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//io",
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
],
|
||||
|
|
@ -460,6 +612,7 @@ cc_test(
|
|||
cc_test(
|
||||
name = "gemma_batch_bench",
|
||||
srcs = ["evals/gemma_batch_bench.cc"],
|
||||
linkstatic = True,
|
||||
# Requires model files
|
||||
tags = [
|
||||
"local",
|
||||
|
|
@ -468,12 +621,12 @@ cc_test(
|
|||
],
|
||||
deps = [
|
||||
":benchmark_helper",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
":tokenizer",
|
||||
"@googletest//:gtest_main",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
"@highway//:nanobenchmark",
|
||||
"@highway//:profiler",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -481,15 +634,13 @@ cc_binary(
|
|||
name = "gemma",
|
||||
srcs = ["gemma/run.cc"],
|
||||
deps = [
|
||||
":app",
|
||||
":args",
|
||||
":benchmark_helper",
|
||||
":common",
|
||||
":gemma_args",
|
||||
":gemma_lib",
|
||||
":ops",
|
||||
":threading",
|
||||
# Placeholder for internal dep, do not remove.,
|
||||
"//compression:sfp",
|
||||
":matmul",
|
||||
":tokenizer",
|
||||
"//compression:types",
|
||||
"//paligemma:image",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
|
|
@ -502,22 +653,15 @@ cc_binary(
|
|||
deps = [
|
||||
":args",
|
||||
":benchmark_helper",
|
||||
":common",
|
||||
":cross_entropy",
|
||||
":gemma_lib",
|
||||
"//compression:io",
|
||||
"//io",
|
||||
"@highway//:hwy",
|
||||
"@highway//:nanobenchmark",
|
||||
"@nlohmann_json//:json",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "benchmark_prompts",
|
||||
hdrs = ["evals/prompts.h"],
|
||||
deps = ["@highway//:hwy"],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "benchmarks",
|
||||
srcs = [
|
||||
|
|
@ -526,7 +670,6 @@ cc_binary(
|
|||
],
|
||||
deps = [
|
||||
":benchmark_helper",
|
||||
":benchmark_prompts",
|
||||
"@google_benchmark//:benchmark",
|
||||
"@highway//:hwy", # base.h
|
||||
],
|
||||
|
|
@ -534,14 +677,12 @@ cc_binary(
|
|||
|
||||
cc_binary(
|
||||
name = "debug_prompt",
|
||||
srcs = [
|
||||
"evals/debug_prompt.cc",
|
||||
],
|
||||
srcs = ["evals/debug_prompt.cc"],
|
||||
deps = [
|
||||
":args",
|
||||
":benchmark_helper",
|
||||
":gemma_lib",
|
||||
"//compression:io",
|
||||
"//io",
|
||||
"@highway//:hwy",
|
||||
"@nlohmann_json//:json",
|
||||
],
|
||||
|
|
@ -554,158 +695,9 @@ cc_binary(
|
|||
":args",
|
||||
":benchmark_helper",
|
||||
":gemma_lib",
|
||||
"//compression:io",
|
||||
"//io",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@highway//:thread_pool",
|
||||
"@nlohmann_json//:json",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "prompt",
|
||||
hdrs = ["backprop/prompt.h"],
|
||||
deps = [],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "sampler",
|
||||
hdrs = ["backprop/sampler.h"],
|
||||
deps = [
|
||||
":prompt",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "backprop",
|
||||
srcs = [
|
||||
"backprop/backward.cc",
|
||||
"backprop/forward.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"backprop/activations.h",
|
||||
"backprop/backward.h",
|
||||
"backprop/forward.h",
|
||||
],
|
||||
textual_hdrs = [
|
||||
"backprop/backward-inl.h",
|
||||
"backprop/forward-inl.h",
|
||||
],
|
||||
deps = [
|
||||
":allocator",
|
||||
":common",
|
||||
":ops",
|
||||
":prompt",
|
||||
":weights",
|
||||
"//compression:compress",
|
||||
"@highway//:dot",
|
||||
"@highway//:hwy", # base.h
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "backprop_scalar",
|
||||
hdrs = [
|
||||
"backprop/activations.h",
|
||||
"backprop/backward_scalar.h",
|
||||
"backprop/common_scalar.h",
|
||||
"backprop/forward_scalar.h",
|
||||
],
|
||||
deps = [
|
||||
":common",
|
||||
":prompt",
|
||||
":weights",
|
||||
"//compression:compress",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "backward_scalar_test",
|
||||
size = "large",
|
||||
srcs = [
|
||||
"backprop/backward_scalar_test.cc",
|
||||
"backprop/test_util.h",
|
||||
],
|
||||
deps = [
|
||||
":backprop_scalar",
|
||||
":common",
|
||||
":prompt",
|
||||
":sampler",
|
||||
":weights",
|
||||
"@googletest//:gtest_main",
|
||||
"//compression:compress",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "backward_test",
|
||||
size = "large",
|
||||
srcs = [
|
||||
"backprop/backward_test.cc",
|
||||
"backprop/test_util.h",
|
||||
],
|
||||
exec_properties = {
|
||||
# Avoid linker OOMs when building with sanitizer instrumentation.
|
||||
"mem": "28g",
|
||||
},
|
||||
deps = [
|
||||
":allocator",
|
||||
":backprop",
|
||||
":backprop_scalar",
|
||||
":common",
|
||||
":ops",
|
||||
":prompt",
|
||||
":sampler",
|
||||
":threading",
|
||||
":weights",
|
||||
"@googletest//:gtest_main",
|
||||
"//compression:compress",
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "optimizer",
|
||||
srcs = ["backprop/optimizer.cc"],
|
||||
hdrs = ["backprop/optimizer.h"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":common",
|
||||
":weights",
|
||||
"//compression:compress",
|
||||
"@highway//:hwy",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "optimize_test",
|
||||
srcs = [
|
||||
"backprop/optimize_test.cc",
|
||||
],
|
||||
exec_properties = {
|
||||
# Avoid linker OOMs when building with sanitizer instrumentation.
|
||||
"mem": "28g",
|
||||
},
|
||||
deps = [
|
||||
":allocator",
|
||||
":backprop",
|
||||
":basics",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
":ops",
|
||||
":optimizer",
|
||||
":prompt",
|
||||
":sampler",
|
||||
":threading",
|
||||
":weights",
|
||||
"@googletest//:gtest_main",
|
||||
"//compression:sfp",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
114
CMakeLists.txt
114
CMakeLists.txt
|
|
@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
|
|||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06 EXCLUDE_FROM_ALL)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9414b48aeec251b69e6cadbfa42bebb5ddae1c34 EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
|
||||
## Note: absl needs to be installed by sentencepiece. This will only happen if
|
||||
|
|
@ -39,58 +39,54 @@ set(BENCHMARK_ENABLE_GTEST_TESTS OFF)
|
|||
FetchContent_Declare(benchmark GIT_REPOSITORY https://github.com/google/benchmark.git GIT_TAG v1.8.2 EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(benchmark)
|
||||
|
||||
# Base source files
|
||||
set(SOURCES
|
||||
compression/blob_store.cc
|
||||
compression/blob_store.h
|
||||
compression/compress-inl.h
|
||||
compression/compress.cc
|
||||
compression/compress.h
|
||||
compression/compress-inl.h
|
||||
compression/fields.cc
|
||||
compression/fields.h
|
||||
compression/io_win.cc
|
||||
compression/io.cc
|
||||
compression/io.h
|
||||
compression/nuq-inl.h
|
||||
compression/sfp-inl.h
|
||||
compression/shared.h
|
||||
compression/types.h
|
||||
compression/test_util-inl.h
|
||||
backprop/activations.h
|
||||
backprop/backward.cc
|
||||
backprop/backward.h
|
||||
backprop/backward-inl.h
|
||||
backprop/backward_scalar.h
|
||||
backprop/common_scalar.h
|
||||
backprop/forward.cc
|
||||
backprop/forward.h
|
||||
backprop/forward-inl.h
|
||||
backprop/forward_scalar.h
|
||||
backprop/optimizer.cc
|
||||
backprop/optimizer.h
|
||||
evals/benchmark_helper.cc
|
||||
evals/benchmark_helper.h
|
||||
evals/cross_entropy.cc
|
||||
evals/cross_entropy.h
|
||||
gemma/activations.h
|
||||
gemma/common.cc
|
||||
gemma/common.h
|
||||
gemma/attention.cc
|
||||
gemma/attention.h
|
||||
gemma/configs.cc
|
||||
gemma/configs.h
|
||||
gemma/gemma_args.h
|
||||
gemma/gemma-inl.h
|
||||
gemma/gemma.cc
|
||||
gemma/gemma.h
|
||||
gemma/instantiations/bf16.cc
|
||||
gemma/instantiations/f32.cc
|
||||
gemma/instantiations/nuq.cc
|
||||
gemma/instantiations/sfp.cc
|
||||
gemma/griffin.cc
|
||||
gemma/griffin.h
|
||||
gemma/kv_cache.cc
|
||||
gemma/kv_cache.h
|
||||
gemma/tensor_index.cc
|
||||
gemma/tensor_index.h
|
||||
gemma/model_store.cc
|
||||
gemma/model_store.h
|
||||
gemma/tensor_info.cc
|
||||
gemma/tensor_info.h
|
||||
gemma/tokenizer.cc
|
||||
gemma/tokenizer.h
|
||||
gemma/vit.cc
|
||||
gemma/vit.h
|
||||
gemma/weights.cc
|
||||
gemma/weights.h
|
||||
io/blob_store.cc
|
||||
io/blob_store.h
|
||||
io/fields.cc
|
||||
io/fields.h
|
||||
io/io_win.cc
|
||||
io/io.cc
|
||||
io/io.h
|
||||
ops/dot-inl.h
|
||||
ops/matmul_static_bf16.cc
|
||||
ops/matmul_static_f32.cc
|
||||
ops/matmul_static_nuq.cc
|
||||
ops/matmul_static_sfp.cc
|
||||
ops/matmul-inl.h
|
||||
ops/matmul.cc
|
||||
ops/matmul.h
|
||||
|
|
@ -102,15 +98,28 @@ set(SOURCES
|
|||
paligemma/image.h
|
||||
util/allocator.cc
|
||||
util/allocator.h
|
||||
util/app.h
|
||||
util/args.h
|
||||
util/basics.h
|
||||
util/mat.cc
|
||||
util/mat.h
|
||||
util/test_util.h
|
||||
util/threading_context.cc
|
||||
util/threading_context.h
|
||||
util/threading.cc
|
||||
util/threading.h
|
||||
util/topology.cc
|
||||
util/topology.h
|
||||
)
|
||||
|
||||
# Add C API sources only when building DLL
|
||||
if(BUILD_GEMMA_DLL)
|
||||
list(APPEND SOURCES
|
||||
gemma/bindings/context.h
|
||||
gemma/bindings/context.cc
|
||||
gemma/bindings/c_api.h
|
||||
gemma/bindings/c_api.cc
|
||||
)
|
||||
message(STATUS "Including C API files for DLL build")
|
||||
endif()
|
||||
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE "Release")
|
||||
|
|
@ -131,6 +140,33 @@ target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE
|
|||
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
||||
install(TARGETS libgemma DESTINATION lib)
|
||||
|
||||
# Shared library target for C# interop
|
||||
if(BUILD_GEMMA_DLL)
|
||||
add_library(gemma_shared SHARED ${SOURCES})
|
||||
set_property(TARGET gemma_shared PROPERTY CXX_STANDARD 17)
|
||||
set_target_properties(gemma_shared PROPERTIES
|
||||
PREFIX ""
|
||||
OUTPUT_NAME "gemma"
|
||||
)
|
||||
set_property(TARGET gemma_shared PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
target_include_directories(gemma_shared PUBLIC ./)
|
||||
target_link_libraries(gemma_shared PRIVATE
|
||||
$<LINK_LIBRARY:WHOLE_ARCHIVE,hwy>
|
||||
$<LINK_LIBRARY:WHOLE_ARCHIVE,hwy_contrib>
|
||||
$<LINK_LIBRARY:WHOLE_ARCHIVE,sentencepiece-static>
|
||||
)
|
||||
target_include_directories(gemma_shared PUBLIC ${sentencepiece_SOURCE_DIR})
|
||||
target_compile_definitions(gemma_shared
|
||||
PRIVATE
|
||||
GEMMA_EXPORTS
|
||||
$<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>
|
||||
)
|
||||
target_compile_options(gemma_shared PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
||||
install(TARGETS gemma_shared DESTINATION lib)
|
||||
install(FILES gemma/c_api.h DESTINATION include/gemma)
|
||||
install(FILES gemma/GemmaInterop.cs DESTINATION include/gemma)
|
||||
endif()
|
||||
|
||||
# Executable Target
|
||||
|
||||
add_executable(gemma gemma/run.cc)
|
||||
|
|
@ -154,17 +190,14 @@ enable_testing()
|
|||
include(GoogleTest)
|
||||
|
||||
set(GEMMA_TEST_FILES
|
||||
backprop/backward_scalar_test.cc
|
||||
backprop/backward_test.cc
|
||||
backprop/optimize_test.cc
|
||||
compression/blob_store_test.cc
|
||||
compression/compress_test.cc
|
||||
compression/distortion_test.cc
|
||||
compression/fields_test.cc
|
||||
compression/nuq_test.cc
|
||||
compression/sfp_test.cc
|
||||
evals/gemma_test.cc
|
||||
gemma/tensor_index_test.cc
|
||||
gemma/tensor_info_test.cc
|
||||
io/blob_store_test.cc
|
||||
io/fields_test.cc
|
||||
ops/bench_matmul.cc
|
||||
ops/dot_test.cc
|
||||
ops/gemma_matvec_test.cc
|
||||
|
|
@ -197,8 +230,5 @@ endif() # GEMMA_ENABLE_TESTS
|
|||
|
||||
## Tools
|
||||
|
||||
add_executable(compress_weights compression/compress_weights.cc)
|
||||
target_link_libraries(compress_weights libgemma hwy hwy_contrib)
|
||||
|
||||
add_executable(migrate_weights compression/migrate_weights.cc)
|
||||
add_executable(migrate_weights io/migrate_weights.cc)
|
||||
target_link_libraries(migrate_weights libgemma hwy hwy_contrib)
|
||||
|
|
|
|||
|
|
@ -31,6 +31,24 @@
|
|||
"lhs": "${hostSystemName}",
|
||||
"rhs": "Windows"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "windows-dll",
|
||||
"inherits": "__defaults__",
|
||||
"displayName": "Windows DLL",
|
||||
"description": "Visual Studio 2022 with Clang/LLVM frontend (DLL build)",
|
||||
"generator": "Visual Studio 17 2022",
|
||||
"toolset": "ClangCL",
|
||||
"condition": {
|
||||
"type": "equals",
|
||||
"lhs": "${hostSystemName}",
|
||||
"rhs": "Windows"
|
||||
},
|
||||
"cacheVariables": {
|
||||
"BUILD_SHARED_LIBS": "OFF",
|
||||
"CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS": "ON",
|
||||
"BUILD_GEMMA_DLL": "ON"
|
||||
}
|
||||
}
|
||||
],
|
||||
"buildPresets": [
|
||||
|
|
@ -54,6 +72,15 @@
|
|||
"displayName": "Windows",
|
||||
"configuration": "Release",
|
||||
"configurePreset": "windows"
|
||||
},
|
||||
{
|
||||
"name": "windows-dll",
|
||||
"displayName": "Windows DLL",
|
||||
"configuration": "Release",
|
||||
"configurePreset": "windows-dll",
|
||||
"targets": [
|
||||
"gemma_shared"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -96,21 +96,10 @@ https://github.com/keras-team/keras-nlp/blob/master/tools/gemma/export_gemma_to_
|
|||
From Pytorch, use the following script to generate uncompressed weights:
|
||||
https://github.com/google/gemma.cpp/blob/dev/compression/convert_weights.py
|
||||
|
||||
Then run `compression/compress_weights.cc` (Bazel target
|
||||
`compression:compress_weights`), specifying the resulting file as `--weights`
|
||||
and the desired .sbs name as the `--compressed_weights`.
|
||||
For PaliGemma, use `python/convert_from_safetensors` to create an SBS file
|
||||
directly.
|
||||
|
||||
## Compile-Time Flags (Advanced)
|
||||
|
||||
There are several compile-time flags to be aware of (note these may or may not
|
||||
be exposed to the build system):
|
||||
|
||||
- `GEMMA_MAX_SEQ_LEN` : Sets maximum sequence length to preallocate for the KV
|
||||
Cache. The default is 4096 tokens but can be overridden. This is not exposed
|
||||
through `CMakeLists.txt` yet.
|
||||
|
||||
In the medium term this will likely be deprecated in favor of handling options
|
||||
at runtime - dynamically resizing the KV cache as needed.
|
||||
For other models, `gemma_export_main.py` is not yet open sourced.
|
||||
|
||||
## Using gemma.cpp as a Library (Advanced)
|
||||
|
||||
|
|
@ -164,7 +153,7 @@ constrained decoding type of use cases where you want to force the generation to
|
|||
fit a grammar. If you're not doing this, you can send an empty lambda or
|
||||
`std::function` as a no-op which is what `run.cc` does.
|
||||
|
||||
### `Transformer()` implements the inference (i.e. `forward()` method in PyTorch or Jax) computation of the neural network
|
||||
### `Transformer()` implements inference (i.e. `forward()` in PyTorch or Jax)
|
||||
|
||||
For high-level applications, you might only call `model.Generate()` and never
|
||||
interact directly with the neural network, but if you're doing something a bit
|
||||
|
|
@ -172,9 +161,6 @@ more custom you can call transformer which performs a single inference operation
|
|||
on a single token and mutates the Activations and the KVCache through the neural
|
||||
network computation.
|
||||
|
||||
Note that an experimental backward pass is available in backprop/, which may be
|
||||
useful for fine tuning.
|
||||
|
||||
### For low level operations, defining new architectures, call `ops.h` functions directly
|
||||
|
||||
You use `ops.h` if you're writing other NN architectures or modifying the
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
|
|||
# Require a more recent version.
|
||||
git_override(
|
||||
module_name = "highway",
|
||||
commit = "c5bebf84ad01edec97e336f5c97ca4e0df6b4d06",
|
||||
commit = "9414b48aeec251b69e6cadbfa42bebb5ddae1c34",
|
||||
remote = "https://github.com/google/highway",
|
||||
)
|
||||
|
||||
|
|
@ -71,6 +71,7 @@ pip.parse(
|
|||
requirements_lock = "//compression/python:requirements.txt",
|
||||
)
|
||||
use_repo(pip, "compression_deps")
|
||||
|
||||
pip.parse(
|
||||
hub_name = "python_deps",
|
||||
python_version = "3.11",
|
||||
|
|
|
|||
216
README.md
216
README.md
|
|
@ -6,7 +6,7 @@ foundation models from Google.
|
|||
For additional information about Gemma, see
|
||||
[ai.google.dev/gemma](https://ai.google.dev/gemma). Model weights, including
|
||||
gemma.cpp specific artifacts, are
|
||||
[available on kaggle](https://www.kaggle.com/models/google/gemma).
|
||||
[available on kaggle](https://www.kaggle.com/models/google/gemma-2).
|
||||
|
||||
## Who is this project for?
|
||||
|
||||
|
|
@ -18,8 +18,8 @@ deployment-oriented C++ inference runtimes, which are not designed for
|
|||
experimentation, and Python-centric ML research frameworks, which abstract away
|
||||
low-level computation through compilation.
|
||||
|
||||
gemma.cpp provides a minimalist implementation of Gemma-1, Gemma-2, Gemma-3, and
|
||||
PaliGemma models, focusing on simplicity and directness rather than full
|
||||
gemma.cpp provides a minimalist implementation of Gemma-2, Gemma-3, and
|
||||
PaliGemma-2 models, focusing on simplicity and directness rather than full
|
||||
generality. This is inspired by vertically-integrated model implementations such
|
||||
as [ggml](https://github.com/ggerganov/ggml),
|
||||
[llama.c](https://github.com/karpathy/llama2.c), and
|
||||
|
|
@ -45,9 +45,41 @@ this invite link](https://discord.gg/H5jCBAWxAe). This project follows
|
|||
[Google's Open Source Community
|
||||
Guidelines](https://opensource.google.com/conduct/).
|
||||
|
||||
*Active development is currently done on the `dev` branch. Please open pull
|
||||
requests targeting `dev` branch instead of `main`, which is intended to be more
|
||||
stable.*
|
||||
> [!NOTE] Active development is currently done on the `dev` branch. Please open
|
||||
> pull requests targeting `dev` branch instead of `main`, which is intended to
|
||||
> be more stable.
|
||||
|
||||
## What's inside?
|
||||
|
||||
- LLM
|
||||
|
||||
- CPU-only inference for: Gemma 2-3, Griffin(SSM), PaliGemma 2.
|
||||
- Sampling with TopK and temperature.
|
||||
- Backward pass (VJP) and Adam optimizer for Gemma research.
|
||||
|
||||
- Optimizations
|
||||
|
||||
- Mixed-precision (fp8, bf16, fp32, fp64 bit) GEMM:
|
||||
- Designed for BF16 instructions, can efficiently emulate them.
|
||||
- Automatic runtime autotuning 7 parameters per matrix shape.
|
||||
- Weight compression integrated directly into GEMM:
|
||||
- Custom fp8 format with 2..3 mantissa bits; tensor scaling.
|
||||
- Also bf16, f32 and non-uniform 4-bit (NUQ); easy to add new formats.
|
||||
|
||||
- Infrastructure
|
||||
|
||||
- SIMD: single implementation via Highway. Chooses ISA at runtime.
|
||||
- Tensor parallelism: CCX-aware, multi-socket thread pool.
|
||||
- Disk I/O: memory map or parallel read (heuristic with user override).
|
||||
- Custom format with forward/backward-compatible metadata serialization.
|
||||
- Model conversion from Safetensors, not yet open sourced.
|
||||
- Portability: Linux, Windows/OS X supported. CMake/Bazel. 'Any' CPU.
|
||||
|
||||
- Frontends
|
||||
|
||||
- C++ APIs with streaming for single query and batched inference.
|
||||
- Basic interactive command-line app.
|
||||
- Basic Python bindings (pybind11).
|
||||
|
||||
## Quick Start
|
||||
|
||||
|
|
@ -74,57 +106,20 @@ winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "-
|
|||
|
||||
Visit the
|
||||
[Kaggle page for Gemma-2](https://www.kaggle.com/models/google/gemma-2/gemmaCpp)
|
||||
[or Gemma-1](https://www.kaggle.com/models/google/gemma/frameworks/gemmaCpp),
|
||||
and select `Model Variations |> Gemma C++`.
|
||||
|
||||
On this tab, the `Variation` dropdown includes the options below. Note bfloat16
|
||||
weights are higher fidelity, while 8-bit switched floating point weights enable
|
||||
faster inference. In general, we recommend starting with the `-sfp` checkpoints.
|
||||
|
||||
If you are unsure which model to start with, we recommend starting with the
|
||||
smallest Gemma-2 model, i.e. `2.0-2b-it-sfp`.
|
||||
|
||||
Alternatively, visit the
|
||||
[gemma.cpp](https://huggingface.co/models?other=gemma.cpp) models on the Hugging
|
||||
Face Hub. First go the model repository of the model of interest (see
|
||||
recommendations below). Then, click the `Files and versions` tab and download
|
||||
the model and tokenizer files. For programmatic downloading, if you have
|
||||
`huggingface_hub` installed, you can also download by running:
|
||||
|
||||
```
|
||||
huggingface-cli login # Just the first time
|
||||
huggingface-cli download google/gemma-2b-sfp-cpp --local-dir build/
|
||||
```
|
||||
|
||||
Gemma-1 2B instruction-tuned (`it`) and pre-trained (`pt`) models:
|
||||
|
||||
| Model name | Description |
|
||||
| ----------- | ----------- |
|
||||
| `2b-it` | 2 billion parameter instruction-tuned model, bfloat16 |
|
||||
| `2b-it-sfp` | 2 billion parameter instruction-tuned model, 8-bit switched floating point |
|
||||
| `2b-pt` | 2 billion parameter pre-trained model, bfloat16 |
|
||||
| `2b-pt-sfp` | 2 billion parameter pre-trained model, 8-bit switched floating point |
|
||||
|
||||
Gemma-1 7B instruction-tuned (`it`) and pre-trained (`pt`) models:
|
||||
|
||||
| Model name | Description |
|
||||
| ----------- | ----------- |
|
||||
| `7b-it` | 7 billion parameter instruction-tuned model, bfloat16 |
|
||||
| `7b-it-sfp` | 7 billion parameter instruction-tuned model, 8-bit switched floating point |
|
||||
| `7b-pt` | 7 billion parameter pre-trained model, bfloat16 |
|
||||
| `7b-pt-sfp` | 7 billion parameter pre-trained model, 8-bit switched floating point |
|
||||
|
||||
> [!NOTE]
|
||||
> **Important**: We strongly recommend starting off with the `2b-it-sfp` model to
|
||||
> get up and running.
|
||||
> [!NOTE] **Important**: We strongly recommend starting off with the
|
||||
> `gemma2-2b-it-sfp` model to get up and running.
|
||||
|
||||
Gemma 2 models are named `gemma2-2b-it` for 2B and `9b-it` or `27b-it`. See the
|
||||
`kModelFlags` definition in `common.cc`.
|
||||
`ModelPrefix` function in `configs.cc`.
|
||||
|
||||
### Step 2: Extract Files
|
||||
|
||||
If you downloaded the models from Hugging Face, skip to step 3.
|
||||
|
||||
After filling out the consent form, the download should proceed to retrieve a
|
||||
tar archive file `archive.tar.gz`. Extract files from `archive.tar.gz` (this can
|
||||
take a few minutes):
|
||||
|
|
@ -162,10 +157,9 @@ cmake --build --preset make -j [number of parallel threads to use]
|
|||
```
|
||||
|
||||
Replace `[number of parallel threads to use]` with a number - the number of
|
||||
cores available on your system is a reasonable heuristic. For example,
|
||||
`make -j4 gemma` will build using 4 threads. If the `nproc` command is
|
||||
available, you can use `make -j$(nproc) gemma` as a reasonable default
|
||||
for the number of threads.
|
||||
cores available on your system is a reasonable heuristic. For example, `make -j4
|
||||
gemma` will build using 4 threads. If the `nproc` command is available, you can
|
||||
use `make -j$(nproc) gemma` as a reasonable default for the number of threads.
|
||||
|
||||
If you aren't sure of the right value for the `-j` flag, you can simply run
|
||||
`make gemma` instead and it should still build the `./gemma` executable.
|
||||
|
|
@ -174,7 +168,8 @@ If you aren't sure of the right value for the `-j` flag, you can simply run
|
|||
> On Windows Subsystem for Linux (WSL) users should set the number of
|
||||
> parallel threads to 1. Using a larger number may result in errors.
|
||||
|
||||
If the build is successful, you should now have a `gemma` executable in the `build/` directory.
|
||||
If the build is successful, you should now have a `gemma` executable in the
|
||||
`build/` directory.
|
||||
|
||||
#### Windows
|
||||
|
||||
|
|
@ -186,7 +181,8 @@ cmake --preset windows
|
|||
cmake --build --preset windows -j [number of parallel threads to use]
|
||||
```
|
||||
|
||||
If the build is successful, you should now have a `gemma.exe` executable in the `build/` directory.
|
||||
If the build is successful, you should now have a `gemma.exe` executable in the
|
||||
`build/` directory.
|
||||
|
||||
#### Bazel
|
||||
|
||||
|
|
@ -194,7 +190,8 @@ If the build is successful, you should now have a `gemma.exe` executable in the
|
|||
bazel build -c opt --cxxopt=-std=c++20 :gemma
|
||||
```
|
||||
|
||||
If the build is successful, you should now have a `gemma` executable in the `bazel-bin/` directory.
|
||||
If the build is successful, you should now have a `gemma` executable in the
|
||||
`bazel-bin/` directory.
|
||||
|
||||
#### Make
|
||||
|
||||
|
|
@ -208,33 +205,21 @@ You can now run `gemma` from inside the `build/` directory.
|
|||
|
||||
`gemma` has the following required arguments:
|
||||
|
||||
Argument | Description | Example value
|
||||
--------------- | ---------------------------- | -----------------------
|
||||
`--model` | The model type. | `2b-it` ... (see below)
|
||||
`--weights` | The compressed weights file. | `2b-it-sfp.sbs`
|
||||
`--weight_type` | The compressed weight type. | `sfp`
|
||||
`--tokenizer` | The tokenizer file. | `tokenizer.spm`
|
||||
|
||||
`gemma` is invoked as:
|
||||
|
||||
```sh
|
||||
./gemma \
|
||||
--tokenizer [tokenizer file] \
|
||||
--weights [compressed weights file] \
|
||||
--weight_type [f32 or bf16 or sfp (default:sfp)] \
|
||||
--model [2b-it or 2b-pt or 7b-it or 7b-pt or ...]
|
||||
```
|
||||
Argument | Description | Example value
|
||||
------------- | ---------------------------- | ---------------
|
||||
`--weights` | The compressed weights file. | `2b-it-sfp.sbs`
|
||||
`--tokenizer` | The tokenizer file. | `tokenizer.spm`
|
||||
|
||||
Example invocation for the following configuration:
|
||||
|
||||
- Compressed weights file `2b-it-sfp.sbs` (2B instruction-tuned model, 8-bit
|
||||
switched floating point).
|
||||
- Tokenizer file `tokenizer.spm`.
|
||||
- weights file `gemma2-2b-it-sfp.sbs` (Gemma2 2B instruction-tuned model,
|
||||
8-bit switched floating point).
|
||||
- Tokenizer file `tokenizer.spm` (can omit for single-format weights files
|
||||
created after 2025-05-06, or output by migrate_weights.cc).
|
||||
|
||||
```sh
|
||||
./gemma \
|
||||
--tokenizer tokenizer.spm \
|
||||
--weights 2b-it-sfp.sbs --model 2b-it
|
||||
--tokenizer tokenizer.spm --weights gemma2-2b-it-sfp.sbs
|
||||
```
|
||||
|
||||
### RecurrentGemma
|
||||
|
|
@ -256,23 +241,20 @@ Step 1, and run the binary as follows:
|
|||
|
||||
### PaliGemma Vision-Language Model
|
||||
|
||||
This repository includes a version of the PaliGemma VLM
|
||||
([paper](https://arxiv.org/abs/2407.07726),
|
||||
[code](https://github.com/google-research/big_vision/tree/main/big_vision/configs/proj/paligemma))
|
||||
and its successor PaliGemma 2 ([paper](https://arxiv.org/abs/2412.03555)). We
|
||||
provide a C++ implementation of the PaliGemma model family here.
|
||||
This repository includes a version of the PaliGemma 2 VLM
|
||||
([paper](https://arxiv.org/abs/2412.03555)). We provide a C++ implementation of
|
||||
the PaliGemma 2 model here.
|
||||
|
||||
To use the version of PaliGemma included in this repository, build the gemma
|
||||
binary as noted above in Step 3. Download the compressed weights and tokenizer
|
||||
from
|
||||
[Kaggle](https://www.kaggle.com/models/google/paligemma/gemmaCpp/paligemma-3b-mix-224)
|
||||
[Kaggle](https://www.kaggle.com/models/google/paligemma-2/gemmaCpp/paligemma2-3b-mix-224)
|
||||
and run the binary as follows:
|
||||
|
||||
```sh
|
||||
./gemma \
|
||||
--tokenizer paligemma_tokenizer.model \
|
||||
--model paligemma-224 \
|
||||
--weights paligemma-3b-mix-224-sfp.sbs \
|
||||
--weights paligemma2-3b-mix-224-sfp.sbs \
|
||||
--image_file paligemma/testdata/image.ppm
|
||||
```
|
||||
|
||||
|
|
@ -312,12 +294,12 @@ allows to contain the tokenizer (and the model type) directly. A tool to migrate
|
|||
from the multi-file format to the single-file format is available.
|
||||
|
||||
```sh
|
||||
compression/migrate_weights \
|
||||
io/migrate_weights \
|
||||
--tokenizer .../tokenizer.spm --weights .../gemma2-2b-it-sfp.sbs \
|
||||
--model gemma2-2b-it --output_weights .../gemma2-2b-it-sfp-single.sbs
|
||||
--output_weights .../gemma2-2b-it-sfp-single.sbs
|
||||
```
|
||||
|
||||
After migration, you can use the new weights file with gemma.cpp like this:
|
||||
After migration, you can omit the tokenizer argument like this:
|
||||
|
||||
```sh
|
||||
./gemma --weights .../gemma2-2b-it-sfp-single.sbs
|
||||
|
|
@ -325,15 +307,6 @@ After migration, you can use the new weights file with gemma.cpp like this:
|
|||
|
||||
### Troubleshooting and FAQs
|
||||
|
||||
**Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
|
||||
|
||||
The most common problem is that the `--weight_type` argument does not match that
|
||||
of the model file. Revisit step #3 and check which weights you downloaded.
|
||||
|
||||
Note that we have already moved weight type from a compile-time decision to a
|
||||
runtime argument. In a subsequent step, we plan to bake this information into
|
||||
the weights.
|
||||
|
||||
**Problems building in Windows / Visual Studio**
|
||||
|
||||
Currently if you're using Windows, we recommend building in WSL (Windows
|
||||
|
|
@ -344,22 +317,22 @@ configurations, see issues for active discussion.
|
|||
|
||||
A common issue is that you are using a pre-trained model, which is not
|
||||
instruction-tuned and thus does not respond to instructions. Make sure you are
|
||||
using an instruction-tuned model (`2b-it-sfp`, `2b-it`, `7b-it-sfp`, `7b-it`)
|
||||
and not a pre-trained model (any model with a `-pt` suffix).
|
||||
using an instruction-tuned model (`gemma2-2b-it-sfp`) and not a pre-trained
|
||||
model (any model with a `-pt` suffix).
|
||||
|
||||
**What sequence lengths are supported?**
|
||||
|
||||
See `seq_len` in `configs.cc`. For the Gemma 3 models larger than 1B, this is
|
||||
typically 32K but 128K would also work given enough RAM. Note that long
|
||||
sequences will be slow due to the quadratic cost of attention.
|
||||
See `max_seq_len` in `configs.cc` and `InferenceArgs.seq_len`. For the Gemma 3
|
||||
models larger than 1B, this is typically 32K but 128K would also work given
|
||||
enough RAM. Note that long sequences will be slow due to the quadratic cost of
|
||||
attention.
|
||||
|
||||
**How do I convert my fine-tune to a `.sbs` compressed model file?**
|
||||
|
||||
For PaliGemma (1 and 2) checkpoints, you can use
|
||||
python/convert_from_safetensors.py to convert from safetensors format (tested
|
||||
with building via bazel). For an adapter model, you will likely need to call
|
||||
merge_and_unload() to convert the adapter model to a single-file format before
|
||||
converting it.
|
||||
For PaliGemma 2 checkpoints, you can use python/convert_from_safetensors.py to
|
||||
convert from safetensors format (tested with building via bazel). For an adapter
|
||||
model, you will likely need to call merge_and_unload() to convert the adapter
|
||||
model to a single-file format before converting it.
|
||||
|
||||
Here is how to use it using a bazel build of the compression library assuming
|
||||
locally installed (venv) torch, numpy, safetensors, absl-py, etc.:
|
||||
|
|
@ -373,22 +346,18 @@ ln -s $BAZEL_OUTPUT_DIR [...]/site-packages/compression
|
|||
python3 python/convert_from_safetensors.py --load_path [...].safetensors.index.json
|
||||
```
|
||||
|
||||
See also compression/convert_weights.py for a slightly older option to convert a
|
||||
pytorch checkpoint. (The code may need updates to work with Gemma-2 models.)
|
||||
|
||||
**What are some easy ways to make the model run faster?**
|
||||
|
||||
1. Make sure you are using the 8-bit switched floating point `-sfp` models.
|
||||
These are half the size of bf16 and thus use less memory bandwidth and cache
|
||||
space.
|
||||
2. If you're on a laptop, make sure power mode is set to maximize performance
|
||||
2. Due to auto-tuning, the second and especially third query will be faster.
|
||||
3. If you're on a laptop, make sure power mode is set to maximize performance
|
||||
and saving mode is **off**. For most laptops, the power saving modes get
|
||||
activated automatically if the computer is not plugged in.
|
||||
3. Close other unused cpu-intensive applications.
|
||||
4. On macs, anecdotally we observe a "warm-up" ramp-up in speed as performance
|
||||
4. Close other unused cpu-intensive applications.
|
||||
5. On macs, anecdotally we observe a "warm-up" ramp-up in speed as performance
|
||||
cores get engaged.
|
||||
5. Experiment with the `--num_threads` argument value. Depending on the device,
|
||||
larger numbers don't always mean better performance.
|
||||
|
||||
We're also working on algorithmic and optimization approaches for faster
|
||||
inference, stay tuned.
|
||||
|
|
@ -411,7 +380,7 @@ newline input.
|
|||
By default, verbosity is set to 1, bringing up a terminal-based interactive
|
||||
interface when `gemma` is invoked:
|
||||
|
||||
```console
|
||||
```sh
|
||||
$ ./gemma [...]
|
||||
__ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __
|
||||
/ _` |/ _ \ '_ ` _ \| '_ ` _ \ / _` | / __| '_ \| '_ \
|
||||
|
|
@ -420,11 +389,7 @@ $ ./gemma [...]
|
|||
__/ | | | | |
|
||||
|___/ |_| |_|
|
||||
|
||||
tokenizer : tokenizer.spm
|
||||
compressed_weights : 2b-it-sfp.sbs
|
||||
model : 2b-it
|
||||
weights : [no path specified]
|
||||
max_generated_tokens : 2048
|
||||
...
|
||||
|
||||
*Usage*
|
||||
Enter an instruction and press enter (%C reset conversation, %Q quits).
|
||||
|
|
@ -462,7 +427,7 @@ For using the `gemma` executable as a command line tool, it may be useful to
|
|||
create an alias for gemma.cpp with arguments fully specified:
|
||||
|
||||
```sh
|
||||
alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --weights ~/gemma.cpp/build/gemma2-2b-it-sfp.sbs --model gemma2-2b-it --verbosity 0"
|
||||
alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --weights ~/gemma.cpp/build/gemma2-2b-it-sfp.sbs --verbosity 0"
|
||||
```
|
||||
|
||||
Replace the above paths with your own paths to the model and tokenizer paths
|
||||
|
|
@ -481,7 +446,7 @@ cat configs.h | tail -n 35 | tr '\n' ' ' | xargs -0 echo "What does this C++ cod
|
|||
|
||||
The output of the above command should look like:
|
||||
|
||||
```console
|
||||
```sh
|
||||
[ Reading prompt ] [...]
|
||||
This C++ code snippet defines a set of **constants** used in a large language model (LLM) implementation, likely related to the **attention mechanism**.
|
||||
|
||||
|
|
@ -492,8 +457,8 @@ Let's break down the code:
|
|||
### Incorporating gemma.cpp as a Library in your Project
|
||||
|
||||
The easiest way to incorporate gemma.cpp in your own project is to pull in
|
||||
gemma.cpp and dependencies using `FetchContent`. You can add the following to your
|
||||
CMakeLists.txt:
|
||||
gemma.cpp and dependencies using `FetchContent`. You can add the following to
|
||||
your CMakeLists.txt:
|
||||
|
||||
```
|
||||
include(FetchContent)
|
||||
|
|
@ -562,9 +527,10 @@ submit a PR with a `README.md` edit.
|
|||
|
||||
## Acknowledgements and Contacts
|
||||
|
||||
gemma.cpp was started in fall 2023 by [Austin Huang](mailto:austinvhuang@google.com)
|
||||
and [Jan Wassenberg](mailto:janwas@google.com), and subsequently released February 2024
|
||||
thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng.
|
||||
gemma.cpp was started in fall 2023 by
|
||||
[Austin Huang](mailto:austinvhuang@google.com) and
|
||||
[Jan Wassenberg](mailto:janwas@google.com), and subsequently released February
|
||||
2024 thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng.
|
||||
|
||||
Griffin support was implemented in April 2024 thanks to contributions by Andrey
|
||||
Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode
|
||||
|
|
|
|||
|
|
@ -1,81 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "compression/compress.h" // MatStorageT
|
||||
#include "gemma/configs.h" // ModelConfig
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
template <typename T>
|
||||
struct ForwardLayer {
|
||||
ForwardLayer(const LayerConfig& config, size_t seq_len)
|
||||
: input("input", seq_len, config.model_dim),
|
||||
pre_att_rms_out("pre_att_rms_out", seq_len, config.model_dim),
|
||||
qkv("qkv", seq_len * (config.heads + 2), config.qkv_dim),
|
||||
att("att", seq_len * config.heads, seq_len),
|
||||
att_out("att_out", seq_len * config.heads, config.qkv_dim),
|
||||
att_post1("att_post1", seq_len, config.model_dim),
|
||||
attention_out("attention_out", seq_len, config.model_dim),
|
||||
bf_pre_ffw_rms_out("bf_pre_ffw_rms_out", seq_len, config.model_dim),
|
||||
ffw_hidden("ffw_hidden", seq_len, config.ff_hidden_dim * 2),
|
||||
ffw_hidden_gated("ffw_hidden_gated", seq_len, config.ff_hidden_dim),
|
||||
layer_config(config) {}
|
||||
|
||||
MatStorageT<T> input;
|
||||
MatStorageT<T> pre_att_rms_out;
|
||||
MatStorageT<T> qkv;
|
||||
MatStorageT<T> att;
|
||||
MatStorageT<T> att_out;
|
||||
MatStorageT<T> att_post1;
|
||||
MatStorageT<T> attention_out;
|
||||
MatStorageT<T> bf_pre_ffw_rms_out;
|
||||
MatStorageT<T> ffw_hidden;
|
||||
MatStorageT<T> ffw_hidden_gated;
|
||||
const LayerConfig& layer_config;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ForwardPass {
|
||||
ForwardPass(const ModelConfig& config)
|
||||
: final_layer_output("final_layer_output", config.seq_len,
|
||||
config.model_dim),
|
||||
final_norm_output("final_norm_output", config.seq_len,
|
||||
config.model_dim),
|
||||
logits("logits", config.seq_len, config.vocab_size),
|
||||
probs("probs", config.seq_len, config.vocab_size),
|
||||
weights_config(config) {
|
||||
for (const auto& layer_config : config.layer_configs) {
|
||||
layers.emplace_back(layer_config, config.seq_len);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ForwardLayer<T>> layers;
|
||||
MatStorageT<T> final_layer_output;
|
||||
MatStorageT<T> final_norm_output;
|
||||
MatStorageT<T> logits;
|
||||
MatStorageT<T> probs;
|
||||
const ModelConfig& weights_config;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_
|
||||
|
|
@ -1,404 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Implementation of the Vector-Jacobian Products (VJP) of the individual
|
||||
// operations of the forward pass.
|
||||
|
||||
// Include guard for non-SIMD code.
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
|
||||
|
||||
// Include guard for (potentially) SIMD code.
|
||||
#if defined(THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
||||
#ifdef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
|
||||
#undef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
|
||||
#else
|
||||
#define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
|
||||
#endif
|
||||
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "ops/matmul-inl.h"
|
||||
#include "ops/ops-inl.h"
|
||||
#include "hwy/contrib/dot/dot-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
HWY_INLINE void MatMulVJP(const float* HWY_RESTRICT weights, // kRows * kCols,
|
||||
const float* HWY_RESTRICT x, // num_tokens * kCols
|
||||
const float* HWY_RESTRICT v, // num_tokens * kRows
|
||||
size_t cols, size_t rows, size_t num_tokens,
|
||||
float* HWY_RESTRICT grad_w, // kRows * kCols,
|
||||
float* HWY_RESTRICT grad_x, // num_tokens * kCols
|
||||
hwy::ThreadPool& pool) {
|
||||
hwy::ZeroBytes(grad_x, num_tokens * cols * sizeof(grad_x[0]));
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t voffs = pos * rows;
|
||||
const size_t xoffs = pos * cols;
|
||||
for (size_t j = 0; j < rows; ++j) {
|
||||
MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * cols], cols);
|
||||
MulByConstAndAdd(v[voffs + j], &weights[j * cols], &grad_x[xoffs], cols);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// VJP of the multi-head projection y[t] = sum_h W_h * x[t, h]: accumulates
// dL/dW into grad_w and overwrites grad_x with dL/dx.
HWY_INLINE void MultiHeadMatMulVJP(
    const float* HWY_RESTRICT weights,  // heads * kRows * kCols
    const float* HWY_RESTRICT x,        // num_tokens * heads * kCols
    const float* HWY_RESTRICT v,        // num_tokens * kRows
    size_t heads, size_t cols, size_t rows, size_t num_tokens,
    float* HWY_RESTRICT grad_w,  // heads * kRows * kCols
    float* HWY_RESTRICT grad_x,  // num_tokens * heads * kCols
    hwy::ThreadPool& pool) {
  hwy::ZeroBytes(grad_x, num_tokens * heads * cols * sizeof(grad_x[0]));
  for (size_t token = 0; token < num_tokens; ++token) {
    const float* HWY_RESTRICT v_t = v + token * rows;
    for (size_t row = 0; row < rows; ++row) {
      for (size_t head = 0; head < heads; ++head) {
        const size_t woffs = (head * rows + row) * cols;
        const size_t xoffs = (token * heads + head) * cols;
        // Each head's output sums into the same row of y, so the incoming
        // gradient v[row] applies to every head.
        MulByConstAndAdd(v_t[row], x + xoffs, grad_w + woffs, cols);
        MulByConstAndAdd(v_t[row], weights + woffs, grad_x + xoffs, cols);
      }
    }
  }
}
|
||||
|
||||
// Derivative of the tanh-approximated GELU used by the forward Gelu:
// d/dx [0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3)))].
template <class D, HWY_IF_F32_D(D)>
static HWY_INLINE hn::Vec<D> DGelu(D d, hn::Vec<D> v) {
  const hn::Vec<D> kMul = hn::Set(d, 0.044715f);
  const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
  const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
  const hn::Vec<D> kOne = hn::Set(d, 1.0f);
  // kSqrtOverPi*3*kMul
  const hn::Vec<D> kMulv2 = hn::Set(d, 0.1070322244f);

  const hn::Vec<D> v2 = hn::Mul(v, v);
  const hn::Vec<D> v3 = hn::Mul(v2, v);
  // Argument of the tanh: sqrt(2/pi) * (0.044715 * x^3 + x).
  const hn::Vec<D> arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v));
  const hn::Vec<D> tanh = hn::Tanh(d, arg);
  // cdf = 0.5 * (1 + tanh(arg)); the forward Gelu is x * cdf.
  const hn::Vec<D> cdf = hn::MulAdd(kHalf, tanh, kHalf);
  // d(tanh)/d(arg) = 1 - tanh^2; d(arg)/dx = sqrt(2/pi) * (1 + 3*0.044715 x^2).
  const hn::Vec<D> dtanh = hn::Sub(kOne, hn::Mul(tanh, tanh));
  const hn::Vec<D> darg = hn::MulAdd(kMulv2, v2, kSqrt2OverPi);
  // Product rule: d/dx (x * cdf) = cdf + x * 0.5 * dtanh * darg.
  return hn::MulAdd(kHalf, hn::Mul(v, hn::Mul(dtanh, darg)), cdf);
}
|
||||
|
||||
// In-place VJP of softmax: backward[i] = forward[i] * (backward[i] - dot),
// where dot = sum_j forward[j] * backward[j]. `forward` holds the softmax
// output y; `backward` holds dL/dy on entry and dL/dx on exit.
static HWY_NOINLINE void SoftmaxVJP(const float* HWY_RESTRICT forward,
                                    float* HWY_RESTRICT backward,
                                    const size_t size) {
  namespace hn = hwy::HWY_NAMESPACE;
  using D = hn::ScalableTag<float>;
  const D d;

  // Common term of the softmax Jacobian, broadcast to all lanes.
  const auto offset =
      hn::Set(d, hn::Dot::Compute<0>(d, forward, backward, size));
  hn::Transform1(
      d, backward, size, forward,
      [&offset](const auto d, const auto v, const auto y)
          HWY_ATTR { return hn::Mul(y, hn::Sub(v, offset)); });
}
|
||||
|
||||
// VJP of RMSNorm y_i = (1 + w_i) * x_i * ss, where ss = 1/sqrt(mean(x^2)+eps)
// (recomputed here via detail::RMSNormMul). Accumulates into grad_w across
// tokens; overwrites grad_x.
static HWY_NOINLINE void RMSNormVJP(
    const float* HWY_RESTRICT weights, const float* HWY_RESTRICT x,
    const float* HWY_RESTRICT v, size_t model_dim, size_t num_tokens,
    float* HWY_RESTRICT grad_w, float* HWY_RESTRICT grad_x,
    hwy::ThreadPool& pool) {
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    const size_t offset = pos * model_dim;
    const float ss = detail::RMSNormMul(x + offset, model_dim);

    // dL/dw_i = v_i * x_i * ss (ss does not depend on w).
    for (size_t i = 0; i < model_dim; ++i) {
      grad_w[i] += v[offset + i] * x[offset + i] * ss;
    }
    // d(ss)/dx_j = -ss^3 * x_j / model_dim; gather the shared inner product.
    const float ss3 = ss * ss * ss / StaticCast<float>(model_dim);
    float tmp = 0.0f;
    for (size_t i = 0; i < model_dim; ++i) {
      tmp += (1.0f + weights[i]) * v[offset + i] * x[offset + i];
    }
    tmp *= ss3;
    // dL/dx_i = ss*(1+w_i)*v_i - x_i * ss3 * sum_j (1+w_j) v_j x_j.
    for (size_t i = 0; i < model_dim; ++i) {
      grad_x[offset + i] = ss * (1.0f + weights[i]) * v[offset + i] -
                           tmp * x[offset + i];
    }
  }
}
|
||||
|
||||
// Scatters activation gradients back into embedding-table rows:
// grad[row(token)] += scaling * v[pos]. Accumulates (no zeroing) because the
// same table also serves as the output head. `weights` is unused but kept for
// signature symmetry with the forward pass.
static HWY_NOINLINE void InputEmbeddingVJP(
    const float* weights, const std::vector<int>& prompt,
    const float scaling, const float* HWY_RESTRICT v,
    float* HWY_RESTRICT grad, size_t model_dim) {
  HWY_ASSERT(!prompt.empty());
  // The final token produces no prediction, hence size() - 1 positions.
  const size_t num_tokens = prompt.size() - 1;
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    const int token = prompt[pos];
    MulByConstAndAdd(scaling, v + pos * model_dim, grad + token * model_dim,
                     model_dim);
  }
}
|
||||
|
||||
// VJP of one Gemma transformer layer. `forward` holds activations saved by
// the forward pass; `next_layer_grad` is dL/d(output) from the layer above.
// Parameter gradients accumulate into `grad`; activation gradients are
// written to `backward`. Steps run in exact reverse order of the forward
// pass: FFW output -> gated GELU -> FFW input -> pre-FFW norm -> residual ->
// attention projection -> attention mix -> softmax -> QK logits -> RoPE/scale
// -> QKV projection -> pre-attention norm -> residual.
template <typename T>
void LayerVJP(const LayerWeightsPtrs<T>& weights,
              const ForwardLayer<float>& forward,
              const float* HWY_RESTRICT next_layer_grad, size_t num_tokens,
              LayerWeightsPtrs<T>& grad, ForwardLayer<float>& backward,
              const RowVectorBatch<float>& inv_timescale,
              hwy::ThreadPool& pool) {
  const LayerConfig& config = weights.layer_config;
  const size_t model_dim = config.model_dim;
  const size_t qkv_dim = config.qkv_dim;
  const size_t heads = config.heads;
  const size_t seq_len = forward.input.Rows();
  const size_t ff_hidden_dim = config.ff_hidden_dim;
  const float query_scale =
      static_cast<float>(1.0 / sqrt(static_cast<double>(qkv_dim)));
  HWY_ASSERT(num_tokens <= seq_len);

  // FFW output projection.
  MatMulVJP(weights.linear_w.data(), forward.ffw_hidden_gated.data(),
            next_layer_grad, ff_hidden_dim, model_dim, num_tokens,
            grad.linear_w.data(), backward.ffw_hidden_gated.data(), pool);

  // Gated GELU: hidden stores [gate | multiplier] halves per token;
  // out = gelu(gate) * multiplier, differentiated w.r.t. both halves.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    const size_t hidden_offset = pos * ff_hidden_dim * 2;
    const float* HWY_RESTRICT f_out = forward.ffw_hidden.data() + hidden_offset;
    const float* HWY_RESTRICT f_out_mul = f_out + ff_hidden_dim;
    const float* HWY_RESTRICT b_out_gated =
        backward.ffw_hidden_gated.data() + pos * ff_hidden_dim;
    float* HWY_RESTRICT b_out = backward.ffw_hidden.data() + hidden_offset;
    float* HWY_RESTRICT b_out_mul = b_out + ff_hidden_dim;
    namespace hn = hwy::HWY_NAMESPACE;
    using DF = hn::ScalableTag<float>;
    DF df;
    for (size_t i = 0; i < ff_hidden_dim; i += Lanes(df)) {
      const auto y = Load(df, f_out + i);
      const auto x = Load(df, f_out_mul + i);
      const auto v = Load(df, b_out_gated + i);
      hn::Store(hn::Mul(v, Gelu(df, y)), df, b_out_mul + i);
      hn::Store(hn::Mul(v, hn::Mul(x, DGelu(df, y))), df, b_out + i);
    }
  }

  // FFW input projection (produces both halves, hence ff_hidden_dim * 2).
  MatMulVJP(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
            backward.ffw_hidden.data(), model_dim, ff_hidden_dim * 2,
            num_tokens, grad.gating_einsum_w.data(),
            backward.bf_pre_ffw_rms_out.data(), pool);
  RMSNormVJP(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
             backward.bf_pre_ffw_rms_out.data(), model_dim, num_tokens,
             grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
             pool);

  // Residual connection around the FFW block.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    AddFrom(next_layer_grad + pos * model_dim,
            backward.attention_out.data() + pos * model_dim, model_dim);
  }

  backward.qkv.ZeroInit();

  // Attention output projection (per head).
  MultiHeadMatMulVJP(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
                     backward.attention_out.data(), heads, qkv_dim, model_dim,
                     num_tokens, grad.attn_vec_einsum_w.data(),
                     backward.att_out.data(), pool);

  // att_out = sum_{pos2 <= pos} att[pos2] * v_pos2: gradients flow to the
  // attention probabilities and to the v slots of qkv.
  for (size_t head = 0; head < heads; ++head) {
    for (size_t pos = 0; pos < num_tokens; ++pos) {
      const size_t aoffset = head * seq_len + pos * heads * seq_len;
      const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
      const float* HWY_RESTRICT b_att_out =
          backward.att_out.data() + (pos * heads + head) * qkv_dim;
      float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
        const size_t v2offs = (pos2 * (heads + 2) + heads + 1) * qkv_dim;
        const float* HWY_RESTRICT f_v2 = forward.qkv.data() + v2offs;
        float* HWY_RESTRICT b_v2 = backward.qkv.data() + v2offs;
        b_head_att[pos2] = Dot(b_att_out, f_v2, qkv_dim);
        MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, qkv_dim);
      }
    }
  }

  // Causal softmax over the first pos + 1 attention logits.
  for (size_t head = 0; head < heads; ++head) {
    for (size_t pos = 0; pos < num_tokens; ++pos) {
      const size_t aoffset = head * seq_len + pos * heads * seq_len;
      const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
      float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
      SoftmaxVJP(f_head_att, b_head_att, pos + 1);
    }
  }

  // Attention logits att[pos, pos2] = q_pos . k_pos2 (single kv head).
  for (size_t head = 0; head < heads; ++head) {
    for (size_t pos = 0; pos < num_tokens; ++pos) {
      const size_t qoffs = (pos * (heads + 2) + head) * qkv_dim;
      const size_t aoffs = head * seq_len + pos * heads * seq_len;
      const float* HWY_RESTRICT f_q = forward.qkv.data() + qoffs;
      const float* HWY_RESTRICT b_head_att = backward.att.data() + aoffs;
      float* HWY_RESTRICT b_q = backward.qkv.data() + qoffs;
      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
        const size_t k2offs = (pos2 * (heads + 2) + heads) * qkv_dim;
        const float* HWY_RESTRICT f_k2 = forward.qkv.data() + k2offs;
        float* HWY_RESTRICT b_k2 = backward.qkv.data() + k2offs;
        MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, qkv_dim);
        MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, qkv_dim);
      }
    }
  }

  // Undo RoPE on the key gradients by rotating with -pos (pos must be signed).
  for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
    float* HWY_RESTRICT b_kv =
        backward.qkv.data() + (pos * (heads + 2) + heads) * qkv_dim;
    Rope(b_kv, qkv_dim, inv_timescale.Const(), -pos);
  }

  // Query gradients: apply the forward pass's 1/sqrt(qkv_dim) scaling, then
  // undo RoPE (rotation and scaling commute).
  for (size_t head = 0; head < heads; ++head) {
    for (size_t pos = 0; pos < num_tokens; ++pos) {
      float* HWY_RESTRICT b_q =
          backward.qkv.data() + (pos * (heads + 2) + head) * qkv_dim;
      MulByConst(query_scale, b_q, qkv_dim);
      Rope(b_q, qkv_dim, inv_timescale.Const(), -pos);
    }
  }

  // QKV projection.
  MatMulVJP(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
            backward.qkv.data(), model_dim, (heads + 2) * qkv_dim, num_tokens,
            grad.qkv_einsum_w.data(), backward.pre_att_rms_out.data(), pool);
  RMSNormVJP(weights.pre_attention_norm_scale.data(), forward.input.data(),
             backward.pre_att_rms_out.data(), model_dim, num_tokens,
             grad.pre_attention_norm_scale.data(), backward.input.data(), pool);
  // Residual connection around the attention block.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    AddFrom(backward.attention_out.data() + pos * model_dim,
            backward.input.data() + pos * model_dim, model_dim);
  }
}
|
||||
|
||||
// VJP of the logit soft-cap y = cap * tanh(x / cap). Since y / cap is the
// forward tanh, the local derivative is 1 - (y/cap)^2; multiplies `backward`
// by it in place.
static HWY_NOINLINE void SoftcapVJP(const float cap,
                                    const float* HWY_RESTRICT forward,
                                    float* HWY_RESTRICT backward,
                                    const size_t size) {
  namespace hn = hwy::HWY_NAMESPACE;
  using D = hn::ScalableTag<float>;
  const D d;

  const auto one = hn::Set(d, 1.0f);
  const auto vcap = hn::Set(d, cap);
  const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
  hn::Transform1(d, backward, size, forward,
                 [&](const auto d, const auto v, const auto y) HWY_ATTR {
                   const auto scaled = hn::Mul(vinv_cap, y);  // = tanh
                   return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
                 });
}
|
||||
|
||||
// Gradient of the cross-entropy loss w.r.t. the predicted probabilities:
// only the next-token entry of each row is nonzero, -1/(ln2 * p[target]).
// Positions inside the prompt prefix ("context") carry no loss.
static HWY_NOINLINE void CrossEntropyLossGrad(
    const float* HWY_RESTRICT x, float* HWY_RESTRICT grad,
    const Prompt& prompt, size_t vocab_size) {
  HWY_ASSERT(!prompt.tokens.empty());
  // Loss is measured in bits (log base 2), hence the 1/ln(2) factor.
  const float scaling = -1.0 / std::log(2.0);
  const size_t num_tokens = prompt.tokens.size() - 1;
  hwy::ZeroBytes(grad, num_tokens * vocab_size * sizeof(grad[0]));
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    if (pos + 1 < prompt.context_size) {
      continue;  // Prompt prefix is given, not predicted.
    }
    const int target = prompt.tokens[pos + 1];
    grad[pos * vocab_size + target] = scaling / x[pos * vocab_size + target];
  }
}
|
||||
|
||||
// Backpropagates the cross-entropy loss of `prompt` through the whole model:
// loss -> softmax -> (optional) soft-cap -> output head -> final RMSNorm ->
// layers (last to first) -> input embedding. Parameter gradients accumulate
// into `grad`; activation gradients are written to `backward`.
template <typename T>
void CrossEntropyLossBackwardPassInl(const Prompt& prompt,
                                     const ModelWeightsPtrs<T>& weights,
                                     const ForwardPass<float>& forward,
                                     ModelWeightsPtrs<T>& grad,
                                     ForwardPass<float>& backward,
                                     RowVectorBatch<float>& inv_timescale,
                                     hwy::ThreadPool& pool) {
  const ModelConfig& config = weights.weights_config;
  const size_t kVocabSize = config.vocab_size;
  const size_t model_dim = config.model_dim;
  const size_t kLayers = config.layer_configs.size();
  const float kEmbScaling = EmbeddingScaling(model_dim);
  // Unsupported model variants (no VJP implemented for them).
  HWY_ASSERT(!config.absolute_pe);
  HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None);
  HWY_ASSERT(config.layer_configs[0].kv_heads == 1);

  HWY_DASSERT(prompt.context_size > 0);
  HWY_DASSERT(prompt.context_size < prompt.tokens.size());
  // The final token has no next-token prediction.
  const size_t num_tokens = prompt.tokens.size() - 1;

  // dL/d(probs).
  CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
                       kVocabSize);

  // Through the final softmax, one row per token.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    SoftmaxVJP(forward.probs.data() + pos * kVocabSize,
               backward.logits.data() + pos * kVocabSize,
               kVocabSize);
  }

  // Through the optional logit soft-cap.
  if (config.final_cap > 0.0f) {
    for (size_t pos = 0; pos < num_tokens; ++pos) {
      SoftcapVJP(config.final_cap, forward.logits.data() + pos * kVocabSize,
                 backward.logits.data() + pos * kVocabSize, kVocabSize);
    }
  }

  // Output head (tied to the input embedding table).
  MatMulVJP(weights.embedder_input_embedding.data(),
            forward.final_norm_output.data(), backward.logits.data(), model_dim,
            kVocabSize, num_tokens, grad.embedder_input_embedding.data(),
            backward.final_norm_output.data(), pool);

  // Final RMSNorm.
  RMSNormVJP(weights.final_norm_scale.data(), forward.final_layer_output.data(),
             backward.final_norm_output.data(), model_dim, num_tokens,
             grad.final_norm_scale.data(), backward.final_layer_output.data(),
             pool);

  // Layers in reverse; each layer's input gradient feeds the layer below.
  for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
    auto layer_config = config.layer_configs[layer];
    // TODO(szabadka) Implement Griffin layer vjp.
    HWY_ASSERT(layer_config.type == LayerAttentionType::kGemma);
    float* next_layer_grad = layer + 1 < kLayers
                                 ? backward.layers[layer + 1].input.data()
                                 : backward.final_layer_output.data();
    LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
             num_tokens, *grad.GetLayer(layer), backward.layers[layer],
             inv_timescale, pool);
  }

  // Finally, scatter gradients into the embedding table.
  InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens,
                    kEmbScaling, backward.layers[0].input.data(),
                    grad.embedder_input_embedding.data(), model_dim);
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#endif  // THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
|
||||
|
|
@ -1,73 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "backprop/backward.h"
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||
// which we pass the filename via macro 'argument'.
|
||||
// clang-format off
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE "backprop/backward.cc" // NOLINT
|
||||
// clang-format on
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "backprop/backward-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
// Per-SIMD-target entry point, compiled once per target via foreach_target.h
// and registered below with HWY_EXPORT; simply forwards to the -inl
// implementation.
void CrossEntropyLossBackwardPassT(const Prompt& prompt,
                                   const ModelWeightsPtrs<float>& weights,
                                   const ForwardPass<float>& forward,
                                   ModelWeightsPtrs<float>& grad,
                                   ForwardPass<float>& backward,
                                   RowVectorBatch<float>& inv_timescale,
                                   hwy::ThreadPool& pool) {
  CrossEntropyLossBackwardPassInl(prompt, weights, forward, grad, backward,
                                  inv_timescale, pool);
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#if HWY_ONCE
|
||||
namespace gcpp {
|
||||
|
||||
HWY_EXPORT(CrossEntropyLossBackwardPassT);
|
||||
|
||||
// Public entry point: dispatches to the per-target implementation selected at
// runtime by Highway.
void CrossEntropyLossBackwardPass(const Prompt& prompt,
                                  const ModelWeightsPtrs<float>& weights,
                                  const ForwardPass<float>& forward,
                                  ModelWeightsPtrs<float>& grad,
                                  ForwardPass<float>& backward,
                                  RowVectorBatch<float>& inv_timescale,
                                  hwy::ThreadPool& pool) {
  return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)(
      prompt, weights, forward, grad, backward, inv_timescale, pool);
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // HWY_ONCE
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Computes the gradient of the cross-entropy loss of `prompt` with respect to
// all model weights. `forward` must hold the activations saved by the forward
// pass; parameter gradients are accumulated into `grad` and activation
// gradients are written to `backward`.
void CrossEntropyLossBackwardPass(const Prompt& prompt,
                                  const ModelWeightsPtrs<float>& weights,
                                  const ForwardPass<float>& forward,
                                  ModelWeightsPtrs<float>& grad,
                                  ForwardPass<float>& backward,
                                  RowVectorBatch<float>& inv_timescale,
                                  hwy::ThreadPool& pool);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
|
||||
|
|
@ -1,349 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/common_scalar.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/common.h" // EmbeddingScaling
|
||||
#include "gemma/weights.h"
|
||||
|
||||
namespace gcpp {
|
||||
// Scalar VJP of y = W x over K tokens. Shapes: w is N x M, x is K x M,
// dy is K x N, dw is N x M (accumulated), dx is K x M (overwritten).
template<typename T>
void MatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
                size_t N, size_t M, size_t K) {
  memset(dx, 0, M * K * sizeof(dx[0]));
  for (size_t token = 0; token < K; ++token) {
    const T* x_t = x + token * M;
    const T* dy_t = dy + token * N;
    T* dx_t = dx + token * M;
    for (size_t row = 0; row < N; ++row) {
      // dW[row, :] += dy[row] * x;  dx += dy[row] * W[row, :].
      MulByConstAndAddT(dy_t[row], x_t, &dw[row * M], M);
      MulByConstAndAddT(dy_t[row], &w[row * M], dx_t, M);
    }
  }
}
|
||||
// Scalar VJP of the multi-head projection y[t] = sum_h W_h * x[t, h].
// Shapes: w is H x N x M, x is K x H x M, dy is K x N; dw accumulates,
// dx is overwritten.
template<typename T>
void MultiHeadMatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
                         size_t H, size_t N, size_t M, size_t K) {
  memset(dx, 0, H * M * K * sizeof(dx[0]));
  for (size_t token = 0; token < K; ++token) {
    for (size_t row = 0; row < N; ++row) {
      const T g = dy[token * N + row];
      for (size_t h = 0; h < H; ++h) {
        const size_t woffs = h * N * M + row * M;
        const size_t xoffs = token * H * M + h * M;
        // All heads sum into the same output row, so each sees gradient g.
        MulByConstAndAddT(g, &x[xoffs], &dw[woffs], M);
        MulByConstAndAddT(g, &w[woffs], &dx[xoffs], M);
      }
    }
  }
}
|
||||
|
||||
// Scalar VJP of RMSNorm y_j = (1 + w_j) * x_j * ss with
// ss = 1 / sqrt(mean(x^2) + eps), over K tokens of width N.
// Accumulates dw across tokens; overwrites dx.
template<typename T>
void RMSNormVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
                 size_t N, size_t K) {
  for (size_t i = 0; i < K; ++i) {
    constexpr T eps(1e-6);
    // Recompute the forward normalizer ss for this token.
    T ss = SquaredL2(x + i * N, N);
    ss = T(1.0) / std::sqrt(ss / T(N) + eps);
    // dL/dw_j = dy_j * x_j * ss (ss does not depend on w).
    for (size_t j = 0; j < N; ++j) {
      dw[j] += dy[i * N + j] * x[i * N + j] * ss;
    }
    // d(ss)/dx_j = -ss^3 * x_j / N; accumulate the shared inner product.
    const T ss3 = ss * ss * ss / T(N);
    T tmp = 0.0;
    for (size_t j = 0; j < N; ++j) {
      tmp += (T(1.0) + w[j]) * dy[i* N + j] * x[i * N + j];
    }
    tmp *= ss3;
    // dL/dx_j = ss*(1+w_j)*dy_j - x_j * ss3 * sum_k (1+w_k) dy_k x_k.
    for (size_t j = 0; j < N; ++j) {
      dx[i * N + j] = ss * (T(1.0) + w[j]) * dy[i* N + j] - tmp * x[i * N + j];
    }
  }
}
|
||||
// In-place VJP of softmax: dy[i] = y[i] * (dy[i] - sum_j y[j] * dy[j]),
// where y is the forward softmax output.
template<typename T>
void SoftmaxVJPT(const T* y, T* dy, size_t N) {
  T inner = {};
  for (size_t i = 0; i < N; ++i) {
    inner += y[i] * dy[i];
  }
  // Subtract the common Jacobian term, then scale by y.
  for (size_t i = 0; i < N; ++i) {
    dy[i] = y[i] * (dy[i] - inner);
  }
}
|
||||
// Row-wise variant: applies the softmax VJP independently to each of the
// K rows of length N.
template<typename T>
void SoftmaxVJPT(const T* y, T* dy, size_t N, size_t K) {
  for (size_t row = 0; row < K; ++row) {
    const size_t offset = row * N;
    SoftmaxVJPT(y + offset, dy + offset, N);
  }
}
|
||||
|
||||
// Derivative of the tanh-approximated GELU:
// d/dx [0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3)))].
template<typename T>
T GeluDerivative(T x) {
  static const T kMul = 0.044715;
  static const T kSqrt2OverPi = 0.797884560804236;
  static const T kMul2 = kSqrt2OverPi * T(3.0) * kMul;

  const T x2 = x * x;
  const T x3 = x2 * x;
  const T arg = kSqrt2OverPi * (kMul * x3 + x);
  const T th = std::tanh(arg);
  // Forward Gelu is x * cdf with cdf = 0.5 * (1 + tanh(arg)).
  const T cdf = T(0.5) * (T(1.0) + th);
  const T dtanh = T(1.0) - th * th;
  const T darg = kMul2 * x2 + kSqrt2OverPi;
  // Product rule: cdf + x * d(cdf)/dx.
  return T(0.5) * x * dtanh * darg + cdf;
}
|
||||
|
||||
// VJP of the gated-GELU nonlinearity out = gelu(x1) * x2, where each token's
// hidden activations store x1 (gate) and x2 (multiplier) as two contiguous
// halves of length N.
template<typename T>
void GatedGeluVJP(const T* in, const T* d_out, T* d_in, size_t N, size_t K) {
  for (size_t token = 0; token < K; ++token) {
    const T* gate = in + token * 2 * N;
    const T* mul = gate + N;
    const T* dout = d_out + token * N;
    T* dgate = d_in + token * 2 * N;
    T* dmul = dgate + N;
    for (size_t j = 0; j < N; ++j) {
      // d/d(x1) = d_out * x2 * gelu'(x1);  d/d(x2) = d_out * gelu(x1).
      dgate[j] = dout[j] * mul[j] * GeluDerivative(gate[j]);
      dmul[j] = dout[j] * Gelu(gate[j]);
    }
  }
}
|
||||
|
||||
// VJP of the causal attention logits att[pos, pos2] = q_pos . k_pos2 for
// pos2 <= pos. Accumulates into the q and k slots of dqkv; the v slots are
// produced by MixByAttentionVJP. Per-position qkv layout: kHeads q vectors,
// then one k and one v vector (single kv head).
template <typename T>
void MaskedAttentionVJP(const T* qkv, const T* doutput, T* dqkv,
                        size_t num_tokens, size_t kHeads, size_t qkv_dim,
                        size_t seq_len) {
  // Zero only the q and k slots (kHeads + 1 vectors); v is zeroed elsewhere.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    const size_t offset = pos * (kHeads + 2) * qkv_dim;
    memset(dqkv + offset, 0, (kHeads + 1) * qkv_dim * sizeof(qkv[0]));
  }
  for (size_t head = 0; head < kHeads; ++head) {
    for (size_t pos = 0; pos < num_tokens; ++pos) {
      const size_t qoffs = (pos * (kHeads + 2) + head) * qkv_dim;
      const size_t aoffs = head * seq_len + pos * kHeads * seq_len;
      const T* q = qkv + qoffs;
      const T* dout = doutput + aoffs;
      T* dq = dqkv + qoffs;
      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
        const size_t koffs = (pos2 * (kHeads + 2) + kHeads) * qkv_dim;
        const T* k = qkv + koffs;
        T* dk = dqkv + koffs;
        // d(q.k)/dq = k and d(q.k)/dk = q, each scaled by the logit gradient.
        MulByConstAndAddT(dout[pos2], k, dq, qkv_dim);
        MulByConstAndAddT(dout[pos2], q, dk, qkv_dim);
      }
    }
  }
}
|
||||
|
||||
// Causal-attention softmax VJP: for each (pos, head) row, apply the softmax
// VJP over the first pos + 1 entries and zero the masked remainder.
template <typename T>
void MaskedSoftmaxVJPT(const T* y, T* dy, size_t num_tokens, size_t kHeads,
                       size_t seq_len) {
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    for (size_t head = 0; head < kHeads; ++head) {
      const size_t row = (pos * kHeads + head) * seq_len;
      SoftmaxVJPT(y + row, dy + row, pos + 1);
      // Entries beyond the causal mask carry no gradient.
      memset(dy + row + pos + 1, 0, (seq_len - pos - 1) * sizeof(T));
    }
  }
}
|
||||
|
||||
// VJP of att_out[pos, head] = sum_{pos2 <= pos} att[pos2] * v_pos2.
// Writes dattention and accumulates dv into the v slots of dqkv (the q/k
// slots are zeroed and filled by MaskedAttentionVJP).
template <typename T>
void MixByAttentionVJP(const T* qkv, const T* attention, const T* doutput,
                       T* dqkv, T* dattention, size_t num_tokens, size_t kHeads,
                       size_t qkv_dim, size_t seq_len) {
  // Offset of position pos's v vector (last of its kHeads + 2 slots).
  auto v_offset = [&](size_t pos) {
    return (pos * (kHeads + 2) + kHeads + 1) * qkv_dim;
  };
  // Clear only the v slots before accumulating.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    memset(&dqkv[v_offset(pos)], 0, qkv_dim * sizeof(qkv[0]));
  }
  for (size_t head = 0; head < kHeads; ++head) {
    for (size_t pos = 0; pos < num_tokens; ++pos) {
      const size_t offset = head * qkv_dim + pos * kHeads * qkv_dim;
      const size_t aoffset = head * seq_len + pos * kHeads * seq_len;
      const T* att = &attention[aoffset];
      const T* dout = &doutput[offset];
      T* datt = &dattention[aoffset];
      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
        // d/d(att[pos2]) = dout . v_pos2;  dv_pos2 += att[pos2] * dout.
        datt[pos2] = DotT(dout, &qkv[v_offset(pos2)], qkv_dim);
        MulByConstAndAddT(att[pos2], dout, &dqkv[v_offset(pos2)], qkv_dim);
      }
    }
  }
}
|
||||
|
||||
// Scatters activation gradients into embedding-table rows:
// dw[row(token)] += scaling * dy[pos]. Accumulates (no zeroing) because the
// table doubles as the output head. `w` is unused; kept for signature
// symmetry with the forward pass.
template<typename T>
void InputEmbeddingVJPT(const T* w, const std::vector<int>& tokens, T scaling,
                        const T* dy, T* dw, size_t N) {
  // The final token produces no prediction.
  const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    MulByConstAndAddT(scaling, dy + pos * N, dw + tokens[pos] * N, N);
  }
}
|
||||
|
||||
template <typename T>
|
||||
void LayerVJP(const LayerWeightsPtrs<T>& weights,
|
||||
const ForwardLayer<T>& forward, const T* dy,
|
||||
LayerWeightsPtrs<T>& grad, ForwardLayer<T>& backward,
|
||||
size_t num_tokens) {
|
||||
const LayerConfig& layer_config = weights.layer_config;
|
||||
const size_t model_dim = layer_config.model_dim;
|
||||
const size_t seq_len = forward.input.Rows();
|
||||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
const size_t kHeads = layer_config.heads;
|
||||
const size_t kFFHiddenDim = layer_config.ff_hidden_dim;
|
||||
const T kQueryScale = 1.0 / std::sqrt(T(qkv_dim));
|
||||
|
||||
MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(), dy,
|
||||
grad.linear_w.data(), backward.ffw_hidden_gated.data(), model_dim,
|
||||
kFFHiddenDim, num_tokens);
|
||||
|
||||
GatedGeluVJP(forward.ffw_hidden.data(), backward.ffw_hidden_gated.data(),
|
||||
backward.ffw_hidden.data(), kFFHiddenDim, num_tokens);
|
||||
|
||||
MatMulVJPT(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
|
||||
backward.ffw_hidden.data(), grad.gating_einsum_w.data(),
|
||||
backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, model_dim,
|
||||
num_tokens);
|
||||
|
||||
RMSNormVJPT(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
|
||||
backward.bf_pre_ffw_rms_out.data(),
|
||||
grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
|
||||
model_dim, num_tokens);
|
||||
|
||||
AddFromT(dy, backward.attention_out.data(), num_tokens * model_dim);
|
||||
|
||||
MultiHeadMatMulVJPT(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
|
||||
backward.attention_out.data(),
|
||||
grad.attn_vec_einsum_w.data(), backward.att_out.data(),
|
||||
kHeads, model_dim, qkv_dim, num_tokens);
|
||||
|
||||
MixByAttentionVJP(forward.qkv.data(), forward.att.data(),
|
||||
backward.att_out.data(), backward.qkv.data(),
|
||||
backward.att.data(), num_tokens, kHeads, qkv_dim, seq_len);
|
||||
|
||||
MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(), num_tokens, kHeads,
|
||||
seq_len);
|
||||
|
||||
MaskedAttentionVJP(forward.qkv.data(), backward.att.data(),
|
||||
backward.qkv.data(), num_tokens, kHeads, qkv_dim, seq_len);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim;
|
||||
MulByConstT(kQueryScale, qkv, kHeads * qkv_dim);
|
||||
}
|
||||
|
||||
for (int pos = 0; pos < num_tokens; ++pos) {
|
||||
T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim;
|
||||
for (size_t h = 0; h <= kHeads; ++h) {
|
||||
Rope(qkv + h * qkv_dim, qkv_dim, -pos);
|
||||
}
|
||||
}
|
||||
|
||||
MatMulVJPT(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
|
||||
backward.qkv.data(), grad.qkv_einsum_w.data(),
|
||||
backward.pre_att_rms_out.data(), (kHeads + 2) * qkv_dim, model_dim,
|
||||
num_tokens);
|
||||
RMSNormVJPT(weights.pre_attention_norm_scale.data(), forward.input.data(),
|
||||
backward.pre_att_rms_out.data(),
|
||||
grad.pre_attention_norm_scale.data(), backward.input.data(),
|
||||
model_dim, num_tokens);
|
||||
|
||||
AddFromT(backward.attention_out.data(), backward.input.data(),
|
||||
num_tokens * model_dim);
|
||||
}
|
||||
|
||||
// Scalar VJP of the soft-cap y = cap * tanh(x / cap): multiplies dy by
// 1 - tanh^2 in place, recovering the forward tanh as y / cap.
template <typename T>
void SoftcapVJPT(float cap, const T* y, T* dy, size_t N) {
  const T inv_cap = T{1.0} / static_cast<T>(cap);
  for (size_t i = 0; i < N; ++i) {
    const T th = y[i] * inv_cap;  // tanh from the forward output
    dy[i] *= (T{1.0} - th * th);
  }
}
|
||||
|
||||
template<typename T>
|
||||
void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) {
|
||||
T scaling = -1.0 / std::log(2.0);
|
||||
const std::vector<int> tokens = prompt.tokens;
|
||||
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
|
||||
memset(dx, 0, V * num_tokens * sizeof(x[0]));
|
||||
for (size_t i = 0; i < num_tokens; ++i) {
|
||||
if (i + 1 < prompt.context_size) {
|
||||
continue;
|
||||
}
|
||||
const int next_token = tokens[i + 1];
|
||||
dx[i * V + next_token] = scaling / x[i * V + next_token];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CrossEntropyLossBackwardPass(const Prompt& prompt,
|
||||
const ModelWeightsPtrs<T>& weights,
|
||||
const ForwardPass<T>& forward,
|
||||
ModelWeightsPtrs<T>& grad,
|
||||
ForwardPass<T>& backward) {
|
||||
const ModelConfig& config = weights.weights_config;
|
||||
const size_t model_dim = config.model_dim;
|
||||
const size_t vocab_size = config.vocab_size;
|
||||
const size_t layers = config.layer_configs.size();
|
||||
const std::vector<int> tokens = prompt.tokens;
|
||||
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
|
||||
|
||||
CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
|
||||
vocab_size);
|
||||
|
||||
SoftmaxVJPT(forward.probs.data(), backward.logits.data(), vocab_size,
|
||||
num_tokens);
|
||||
|
||||
if (config.final_cap > 0.0f) {
|
||||
for (size_t i = 0; i < num_tokens; ++i) {
|
||||
SoftcapVJPT(config.final_cap, forward.logits.data() + i * vocab_size,
|
||||
backward.logits.data() + i * vocab_size, vocab_size);
|
||||
}
|
||||
}
|
||||
|
||||
MatMulVJPT(
|
||||
weights.embedder_input_embedding.data(), forward.final_norm_output.data(),
|
||||
backward.logits.data(), grad.embedder_input_embedding.data(),
|
||||
backward.final_norm_output.data(), vocab_size, model_dim, num_tokens);
|
||||
|
||||
RMSNormVJPT(weights.final_norm_scale.data(),
|
||||
forward.final_layer_output.data(),
|
||||
backward.final_norm_output.data(), grad.final_norm_scale.data(),
|
||||
backward.final_layer_output.data(), model_dim, num_tokens);
|
||||
|
||||
for (int layer = static_cast<int>(layers) - 1; layer >= 0; --layer) {
|
||||
T* next_layer_grad = layer + 1 < layers
|
||||
? backward.layers[layer + 1].input.data()
|
||||
: backward.final_layer_output.data();
|
||||
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
|
||||
*grad.GetLayer(layer), backward.layers[layer], num_tokens);
|
||||
}
|
||||
|
||||
const T kEmbScaling = EmbeddingScaling(model_dim);
|
||||
InputEmbeddingVJPT(weights.embedder_input_embedding.data(), tokens,
|
||||
kEmbScaling, backward.layers[0].input.data(),
|
||||
grad.embedder_input_embedding.data(), model_dim);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
|
||||
|
|
@ -1,635 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "backprop/backward_scalar.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h> // memcpy
|
||||
|
||||
#include <complex>
|
||||
#include <limits>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/common_scalar.h"
|
||||
#include "backprop/forward_scalar.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "backprop/sampler.h"
|
||||
#include "backprop/test_util.h"
|
||||
#include "compression/compress.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/weights.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
TEST(BackPropTest, MatMulVJP) {
|
||||
static const size_t kRows = 8;
|
||||
static const size_t kCols = 64;
|
||||
static const size_t kTokens = 5;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> weights("weights", kRows, kCols);
|
||||
MatStorageT<T> x("x", kTokens, kCols);
|
||||
MatStorageT<T> grad("grad", kRows, kCols);
|
||||
MatStorageT<T> dx("dx", kTokens, kCols);
|
||||
MatStorageT<TC> c_weights("c_weights", kRows, kCols);
|
||||
MatStorageT<TC> c_x("c_x", kTokens, kCols);
|
||||
MatStorageT<TC> c_y("c_y", kTokens, kRows);
|
||||
MatStorageT<T> dy("dy", kTokens, kRows);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0 * (1 << iter), gen);
|
||||
RandInit(x, 1.0 * (1 << iter), gen);
|
||||
RandInit(dy, 1.0, gen);
|
||||
Complexify(weights, c_weights);
|
||||
Complexify(x, c_x);
|
||||
auto func = [&]() {
|
||||
MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens);
|
||||
return DotT(dy.data(), c_y.data(), kTokens * kRows);
|
||||
};
|
||||
grad.ZeroInit();
|
||||
MatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(),
|
||||
kRows, kCols, kTokens);
|
||||
TestGradient(dx, c_x, func, 1e-11, 1e-12, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 1e-14, 1e-12, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, MultiHeadMatMulVJP) {
|
||||
static const size_t kRows = 2;
|
||||
static const size_t kCols = 16;
|
||||
static const size_t kHeads = 4;
|
||||
static const size_t kTokens = 3;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> weights("weights", kRows, kCols * kHeads);
|
||||
MatStorageT<T> x("x", kTokens, kCols * kHeads);
|
||||
MatStorageT<T> grad("grad", kRows, kCols * kHeads);
|
||||
MatStorageT<T> dx("dx", kTokens, kCols * kHeads);
|
||||
MatStorageT<TC> c_weights("c_weights", kRows, kCols * kHeads);
|
||||
MatStorageT<TC> c_x("c_x", kTokens, kCols * kHeads);
|
||||
MatStorageT<TC> c_y("c_y", kTokens, kRows);
|
||||
MatStorageT<T> dy("dy", kTokens, kRows);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0 * (1 << iter), gen);
|
||||
RandInit(x, 1.0 * (1 << iter), gen);
|
||||
RandInit(dy, 1.0, gen);
|
||||
Complexify(weights, c_weights);
|
||||
Complexify(x, c_x);
|
||||
auto func = [&]() {
|
||||
MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows,
|
||||
kCols, kTokens);
|
||||
return DotT(dy.data(), c_y.data(), kTokens * kRows);
|
||||
};
|
||||
grad.ZeroInit();
|
||||
MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(),
|
||||
dx.data(), kHeads, kRows, kCols, kTokens);
|
||||
TestGradient(dx, c_x, func, 1e-15, 1e-13, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 1e-15, 1e-13, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, RMSNormVJP) {
|
||||
static const size_t K = 2;
|
||||
static const size_t N = 64;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> weights("weights", N, 1);
|
||||
MatStorageT<T> grad("grad", N, 1);
|
||||
MatStorageT<T> x("x", K, N);
|
||||
MatStorageT<T> dx("dx", K, N);
|
||||
MatStorageT<T> dy("dy", K, N);
|
||||
MatStorageT<TC> c_weights("c_weights", N, 1);
|
||||
MatStorageT<TC> c_x("c_x", K, N);
|
||||
MatStorageT<TC> c_y("c_y", K, N);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0 * (1 << iter), gen);
|
||||
RandInit(x, 1.0 * (1 << iter), gen);
|
||||
Complexify(weights, c_weights);
|
||||
Complexify(x, c_x);
|
||||
RandInit(dy, 1.0, gen);
|
||||
auto func = [&]() {
|
||||
RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K);
|
||||
return DotT(dy.data(), c_y.data(), K * N);
|
||||
};
|
||||
grad.ZeroInit();
|
||||
RMSNormVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(),
|
||||
N, K);
|
||||
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 1e-15, 1e-14, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, SoftmaxVJP) {
|
||||
static const size_t N = 64;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> x("x", N, 1);
|
||||
MatStorageT<T> dx("dx", N, 1);
|
||||
MatStorageT<T> dy("dy", N, 1);
|
||||
MatStorageT<TC> c_x("c_x", N, 1);
|
||||
MatStorageT<TC> c_y("c_y", N, 1);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(x, 1.0 * (1 << iter), gen);
|
||||
Complexify(x, c_x);
|
||||
RandInit(dy, 1.0, gen);
|
||||
auto func = [&]() {
|
||||
memcpy(c_y.data(), c_x.data(), c_x.SizeBytes());
|
||||
Softmax(c_y.data(), N);
|
||||
return DotT(dy.data(), c_y.data(), N);
|
||||
};
|
||||
Softmax(x.data(), N);
|
||||
memcpy(dx.data(), dy.data(), dx.SizeBytes());
|
||||
SoftmaxVJPT(x.data(), dx.data(), N);
|
||||
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, MaskedSoftmaxVJP) {
|
||||
static const size_t kSeqLen = 16;
|
||||
static const size_t kHeads = 2;
|
||||
static const size_t kTokens = 14;
|
||||
static const size_t N = kTokens * kHeads * kSeqLen;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> x("x", N, 1);
|
||||
MatStorageT<T> dy("dy", N, 1);
|
||||
MatStorageT<T> dx("dx", N, 1);
|
||||
MatStorageT<TC> c_x("c_x", N, 1);
|
||||
MatStorageT<TC> c_y("c_y", N, 1);
|
||||
dx.ZeroInit();
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(x, 1.0 * (1 << iter), gen);
|
||||
Complexify(x, c_x);
|
||||
RandInit(dy, 1.0, gen);
|
||||
auto func = [&]() {
|
||||
memcpy(c_y.data(), c_x.data(),
|
||||
kTokens * kHeads * kSeqLen * sizeof(c_x.At(0)));
|
||||
MaskedSoftmax(c_y.data(), kTokens, kHeads, kSeqLen);
|
||||
return DotT(dy.data(), c_y.data(), N);
|
||||
};
|
||||
MaskedSoftmax(x.data(), kTokens, kHeads, kSeqLen);
|
||||
memcpy(dx.data(), dy.data(), kTokens * kHeads * kSeqLen * sizeof(dx.At(0)));
|
||||
MaskedSoftmaxVJPT(x.data(), dx.data(), kTokens, kHeads, kSeqLen);
|
||||
TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, SoftcapVJP) {
|
||||
static const size_t N = 64;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> x("x", N, 1);
|
||||
MatStorageT<T> dx("dx", N, 1);
|
||||
MatStorageT<T> dy("dy", N, 1);
|
||||
MatStorageT<TC> c_x("c_x", N, 1);
|
||||
MatStorageT<TC> c_y("c_y", N, 1);
|
||||
|
||||
constexpr float kCap = 30.0f;
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(x, 1.0 * (1 << iter), gen);
|
||||
Complexify(x, c_x);
|
||||
RandInit(dy, 1.0, gen);
|
||||
auto func = [&]() {
|
||||
memcpy(c_y.data(), c_x.data(), N * sizeof(c_x.At(0)));
|
||||
Softcap(kCap, c_y.data(), N);
|
||||
return DotT(dy.data(), c_y.data(), N);
|
||||
};
|
||||
Softcap(kCap, x.data(), N);
|
||||
memcpy(dx.data(), dy.data(), dx.SizeBytes());
|
||||
SoftcapVJPT(kCap, x.data(), dx.data(), N);
|
||||
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, CrossEntropyLossGrad) {
|
||||
static const size_t K = 8;
|
||||
static const size_t V = 64;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> x("x", K, V);
|
||||
MatStorageT<T> dx("dx", K, V);
|
||||
MatStorageT<TC> c_x("c_x", K, V);
|
||||
Prompt prompt;
|
||||
prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 };
|
||||
|
||||
const float kCap = 30.0f;
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
prompt.context_size = 1 + (iter % 6);
|
||||
RandInit(x, 1.0 * (1 << iter), gen);
|
||||
Softcap(kCap, x.data(), V * K);
|
||||
Softmax(x.data(), V, K);
|
||||
CrossEntropyLossGrad(x.data(), dx.data(), prompt, V);
|
||||
Complexify(x, c_x);
|
||||
auto func = [&]() {
|
||||
return CrossEntropyLoss(c_x.data(), prompt, V);
|
||||
};
|
||||
TestGradient(dx, c_x, func, 1e-100, 1e-15, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, GatedGeluVJP) {
|
||||
static const size_t K = 2;
|
||||
static const size_t N = 64;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> x("x", K, 2 * N);
|
||||
MatStorageT<T> dx("dx", K, 2 * N);
|
||||
MatStorageT<T> dy("dy", K, N);
|
||||
MatStorageT<TC> c_x("c_x", K, 2 * N);
|
||||
MatStorageT<TC> c_y("c_y", K, N);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(x, 1.0, gen);
|
||||
Complexify(x, c_x);
|
||||
RandInit(dy, 1.0, gen);
|
||||
auto func = [&]() {
|
||||
GatedGelu(c_x.data(), c_y.data(), N, K);
|
||||
return DotT(dy.data(), c_y.data(), N * K);
|
||||
};
|
||||
GatedGeluVJP(x.data(), dy.data(), dx.data(), N, K);
|
||||
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, MaskedAttentionVJP) {
|
||||
static const size_t kSeqLen = 16;
|
||||
static const size_t kHeads = 2;
|
||||
static const size_t kQKVDim = 8;
|
||||
static const size_t kTokens = 14;
|
||||
static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim;
|
||||
static const size_t kOutSize = kTokens * kHeads * kSeqLen;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> x("x", kQKVSize, 1);
|
||||
MatStorageT<T> dx("dx", kQKVSize, 1);
|
||||
MatStorageT<T> dy("dy", kOutSize, 1);
|
||||
MatStorageT<TC> c_x("c_x", kQKVSize, 1);
|
||||
MatStorageT<TC> c_y("c_y", kOutSize, 1);
|
||||
dx.ZeroInit();
|
||||
c_y.ZeroInit();
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(x, 1.0, gen);
|
||||
Complexify(x, c_x);
|
||||
RandInit(dy, 1.0, gen);
|
||||
auto func = [&]() {
|
||||
MaskedAttention(c_x.data(), c_y.data(), kTokens, kHeads, kQKVDim,
|
||||
kSeqLen);
|
||||
return DotT(dy.data(), c_y.data(), kOutSize);
|
||||
};
|
||||
MaskedAttentionVJP(x.data(), dy.data(), dx.data(),
|
||||
kTokens, kHeads, kQKVDim, kSeqLen);
|
||||
TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, MixByAttentionVJP) {
|
||||
static const size_t kSeqLen = 16;
|
||||
static const size_t kHeads = 2;
|
||||
static const size_t kQKVDim = 8;
|
||||
static const size_t kTokens = 14;
|
||||
static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim;
|
||||
static const size_t kAttnSize = kSeqLen * kHeads * kSeqLen;
|
||||
static const size_t kOutSize = kSeqLen * kHeads * kQKVDim;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> qkv("qkv", kQKVSize, 1);
|
||||
MatStorageT<T> dqkv("dqkv", kQKVSize, 1);
|
||||
MatStorageT<T> attn("attn", kAttnSize, 1);
|
||||
MatStorageT<T> dattn("dattn", kAttnSize, 1);
|
||||
MatStorageT<T> dy("dy", kOutSize, 1);
|
||||
MatStorageT<TC> c_qkv("c_qkv", kQKVSize, 1);
|
||||
MatStorageT<TC> c_attn("c_attn", kAttnSize, 1);
|
||||
MatStorageT<TC> c_y("c_y", kOutSize, 1);
|
||||
dqkv.ZeroInit();
|
||||
dattn.ZeroInit();
|
||||
c_y.ZeroInit();
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(qkv, 1.0, gen);
|
||||
RandInit(attn, 1.0, gen);
|
||||
Complexify(qkv, c_qkv);
|
||||
Complexify(attn, c_attn);
|
||||
RandInit(dy, 1.0, gen);
|
||||
auto func = [&]() {
|
||||
MixByAttention(c_qkv.data(), c_attn.data(), c_y.data(),
|
||||
kTokens, kHeads, kQKVDim, kSeqLen);
|
||||
return DotT(dy.data(), c_y.data(), kOutSize);
|
||||
};
|
||||
MixByAttentionVJP(qkv.data(), attn.data(), dy.data(), dqkv.data(),
|
||||
dattn.data(), kTokens, kHeads, kQKVDim, kSeqLen);
|
||||
TestGradient(dqkv, c_qkv, func, 1e-14, 1e-15, __LINE__);
|
||||
TestGradient(dattn, c_attn, func, 1e-14, 1e-15, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, InputEmbeddingVJP) {
|
||||
static const size_t kSeqLen = 8;
|
||||
static const size_t kVocabSize = 4;
|
||||
static const size_t kModelDim = 16;
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> weights("weights", kVocabSize, kModelDim);
|
||||
MatStorageT<T> grad("grad", kVocabSize, kModelDim);
|
||||
MatStorageT<T> dy("dy", kSeqLen, kModelDim);
|
||||
MatStorageT<TC> c_weights("c_weights", kVocabSize, kModelDim);
|
||||
MatStorageT<TC> c_y("c_y", kSeqLen, kModelDim);
|
||||
std::vector<int> tokens = { 0, 1, 2, 3, 0, 1, 2 };
|
||||
size_t num_tokens = tokens.size() - 1;
|
||||
|
||||
for (size_t iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0, gen);
|
||||
RandInit(dy, 1.0, gen);
|
||||
Complexify(weights, c_weights);
|
||||
auto func = [&]() {
|
||||
InputEmbedding(c_weights.data(), tokens, TC(3.0), c_y.data(), kModelDim);
|
||||
return DotT(dy.data(), c_y.data(), num_tokens * kModelDim);
|
||||
};
|
||||
grad.ZeroInit();
|
||||
InputEmbeddingVJPT(weights.data(), tokens, 3.0, dy.data(), grad.data(),
|
||||
kModelDim);
|
||||
TestGradient(grad, c_weights, func, 1e-16, 1e-14, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
static ModelConfig TestConfig() {
|
||||
ModelConfig config;
|
||||
config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
|
||||
"gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
|
||||
config.model_dim = 32;
|
||||
config.vocab_size = 12;
|
||||
config.seq_len = 18;
|
||||
LayerConfig layer_config;
|
||||
layer_config.model_dim = config.model_dim;
|
||||
layer_config.ff_hidden_dim = 48;
|
||||
layer_config.heads = 3;
|
||||
layer_config.kv_heads = 1;
|
||||
layer_config.qkv_dim = 12;
|
||||
config.layer_configs = {2, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
|
||||
// This is required for optimize_test to pass.
|
||||
config.final_cap = 30.0f;
|
||||
return config;
|
||||
}
|
||||
|
||||
TEST(BackPropTest, LayerVJP) {
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
ModelConfig config = TestConfig();
|
||||
TensorIndex tensor_index(config, /*llm_layer_idx=*/0, /*img_layer_idx=*/-1,
|
||||
/*reshape_att=*/false);
|
||||
const size_t kOutputSize = config.seq_len * config.model_dim;
|
||||
LayerWeightsPtrs<T> weights(config.layer_configs[0], tensor_index);
|
||||
LayerWeightsPtrs<T> grad(config.layer_configs[0], tensor_index);
|
||||
ForwardLayer<T> forward(config.layer_configs[0], config.seq_len);
|
||||
ForwardLayer<T> backward(config.layer_configs[0], config.seq_len);
|
||||
LayerWeightsPtrs<TC> c_weights(config.layer_configs[0], tensor_index);
|
||||
ForwardLayer<TC> c_forward(config.layer_configs[0], config.seq_len);
|
||||
MatStorageT<T> y("y", kOutputSize, 1);
|
||||
MatStorageT<T> dy("dy", kOutputSize, 1);
|
||||
MatStorageT<TC> c_y("c_y", kOutputSize, 1);
|
||||
const size_t num_tokens = 3;
|
||||
std::vector<MatStorage> layer_storage;
|
||||
weights.Allocate(layer_storage);
|
||||
grad.Allocate(layer_storage);
|
||||
c_weights.Allocate(layer_storage);
|
||||
backward.input.ZeroInit();
|
||||
|
||||
for (size_t iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0, gen);
|
||||
RandInit(forward.input, 1.0, gen);
|
||||
RandInit(dy, 1.0, gen);
|
||||
Complexify(weights, c_weights);
|
||||
Complexify(forward.input, c_forward.input);
|
||||
auto func = [&]() {
|
||||
ApplyLayer(c_weights, c_forward, num_tokens, c_y.data());
|
||||
return DotT(dy.data(), c_y.data(), num_tokens * config.model_dim);
|
||||
};
|
||||
grad.ZeroInit(/*layer_idx=*/0);
|
||||
ApplyLayer(weights, forward, num_tokens, y.data());
|
||||
LayerVJP(weights, forward, dy.data(), grad, backward, num_tokens);
|
||||
TestGradient(backward.input, c_forward.input, func, 1e-11, 5e-11,
|
||||
__LINE__);
|
||||
TestGradient(grad, c_weights, func, 1e-11);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackPropTest, EndToEnd) {
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
ModelConfig config = TestConfig();
|
||||
WeightsWrapper<T> weights(config);
|
||||
WeightsWrapper<T> grad(config);
|
||||
ForwardPass<T> forward(config);
|
||||
ForwardPass<T> backward(config);
|
||||
WeightsWrapper<TC> c_weights(config);
|
||||
ForwardPass<TC> c_forward(config);
|
||||
|
||||
ReverseSequenceSampler training_task({0, 0, 1, 1});
|
||||
std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
|
||||
|
||||
for (const Prompt& prompt : batch) {
|
||||
ReverseSequenceSampler::LogPrompt(prompt);
|
||||
RandInit(weights.get(), 1.0, gen);
|
||||
CrossEntropyLossForwardPass(prompt, weights.get(), forward);
|
||||
grad.ZeroInit();
|
||||
CrossEntropyLossBackwardPass(
|
||||
prompt, weights.get(), forward, grad.get(), backward);
|
||||
|
||||
Complexify(weights.get(), c_weights.get());
|
||||
auto func = [&]() {
|
||||
return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward);
|
||||
};
|
||||
|
||||
TestGradient(grad.get(), c_weights.get(), func, 1e-11);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MulByConstAndAddT(T c, const LayerWeightsPtrs<T>& x,
|
||||
LayerWeightsPtrs<T>& out) {
|
||||
MulByConstAndAddT(c, x.pre_attention_norm_scale,
|
||||
out.pre_attention_norm_scale);
|
||||
MulByConstAndAddT(c, x.attn_vec_einsum_w, out.attn_vec_einsum_w);
|
||||
MulByConstAndAddT(c, x.qkv_einsum_w, out.qkv_einsum_w);
|
||||
MulByConstAndAddT(c, x.pre_ffw_norm_scale, out.pre_ffw_norm_scale);
|
||||
MulByConstAndAddT(c, x.gating_einsum_w, out.gating_einsum_w);
|
||||
MulByConstAndAddT(c, x.linear_w, out.linear_w);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MulByConstAndAddT(T c, const ModelWeightsPtrs<T>& x,
|
||||
ModelWeightsPtrs<T>& out) {
|
||||
const size_t layers = x.c_layers.size();
|
||||
MulByConstAndAddT(c, x.embedder_input_embedding,
|
||||
out.embedder_input_embedding);
|
||||
MulByConstAndAddT(c, x.final_norm_scale, out.final_norm_scale);
|
||||
for (size_t i = 0; i < layers; ++i) {
|
||||
MulByConstAndAddT(c, *x.GetLayer(i), *out.GetLayer(i));
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluates forward pass on a batch.
|
||||
template <typename T>
|
||||
T CrossEntropyLossForwardPass(const std::vector<Prompt>& batch,
|
||||
const WeightsWrapper<T>& weights,
|
||||
ForwardPass<T>& forward) {
|
||||
T loss = 0.0;
|
||||
for (const Prompt& prompt : batch) {
|
||||
loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
|
||||
}
|
||||
T scale = 1.0 / batch.size();
|
||||
return loss * scale;
|
||||
}
|
||||
|
||||
// Evaluates forward pass on a batch by applying gradient with the given
|
||||
// learning rate. Does not update weights, but uses the given tmp weights
|
||||
// instead.
|
||||
template <typename T>
|
||||
T CrossEntropyLossForwardPass(T learning_rate, const std::vector<Prompt>& batch,
|
||||
const WeightsWrapper<T>& weights,
|
||||
const WeightsWrapper<T>& grad,
|
||||
WeightsWrapper<T>& tmp, ForwardPass<T>& forward) {
|
||||
tmp.CopyFrom(weights);
|
||||
const T scale = -learning_rate / batch.size();
|
||||
MulByConstAndAddT(scale, grad.get(), tmp.get());
|
||||
return CrossEntropyLossForwardPass(batch, tmp, forward);
|
||||
}
|
||||
|
||||
// Uses line search in the negative gradient direction to update weights. We do
|
||||
// this so that we can test that each step during the gradient descent can
|
||||
// decrease the objective function value.
|
||||
template <typename T>
|
||||
T FindOptimalUpdate(const WeightsWrapper<T>& grad, WeightsWrapper<T>& weights,
|
||||
WeightsWrapper<T>& tmp, ForwardPass<T>& forward,
|
||||
const std::vector<Prompt>& batch, T loss,
|
||||
T initial_learning_rate) {
|
||||
T lr0 = initial_learning_rate;
|
||||
T loss0 = CrossEntropyLossForwardPass(
|
||||
lr0, batch, weights, grad, tmp, forward);
|
||||
for (size_t iter = 0; iter < 30; ++iter) {
|
||||
T lr1 = lr0 * 0.5;
|
||||
T loss1 = CrossEntropyLossForwardPass(
|
||||
lr1, batch, weights, grad, tmp, forward);
|
||||
if (loss0 < loss && loss1 >= loss0) {
|
||||
break;
|
||||
}
|
||||
loss0 = loss1;
|
||||
lr0 = lr1;
|
||||
}
|
||||
for (size_t iter = 0; iter < 30; ++iter) {
|
||||
T lr1 = lr0 * 2.0;
|
||||
T loss1 = CrossEntropyLossForwardPass(
|
||||
lr1, batch, weights, grad, tmp, forward);
|
||||
if (loss1 >= loss0) {
|
||||
break;
|
||||
}
|
||||
loss0 = loss1;
|
||||
lr0 = lr1;
|
||||
}
|
||||
const T scale = -lr0 / batch.size();
|
||||
MulByConstAndAddT(scale, grad.get(), weights.get());
|
||||
return lr0;
|
||||
}
|
||||
|
||||
TEST(BackProptest, Convergence) {
|
||||
std::mt19937 gen(42);
|
||||
using T = float;
|
||||
using TC = std::complex<double>;
|
||||
ModelConfig config = TestConfig();
|
||||
WeightsWrapper<T> weights(config);
|
||||
WeightsWrapper<T> grad(config);
|
||||
WeightsWrapper<T> tmp(config);
|
||||
ForwardPass<T> forward(config);
|
||||
ForwardPass<T> backward(config);
|
||||
WeightsWrapper<TC> c_weights(config);
|
||||
ForwardPass<TC> c_forward(config);
|
||||
constexpr size_t kBatchSize = 5;
|
||||
ReverseSequenceSampler training_task({0, 0, 0, 1, 1});
|
||||
T learning_rate = 0.01;
|
||||
|
||||
RandInit(weights.get(), T(1.0), gen);
|
||||
|
||||
printf("Sample batch:\n");
|
||||
for (size_t i = 0; i < 10; ++i) {
|
||||
ReverseSequenceSampler::LogPrompt(training_task.Sample(gen));
|
||||
}
|
||||
|
||||
T prev_loss = std::numeric_limits<T>::max();
|
||||
bool stop = false;
|
||||
size_t step = 0;
|
||||
while (!stop) {
|
||||
T loss = 0.0;
|
||||
grad.ZeroInit();
|
||||
std::mt19937 sgen(42);
|
||||
std::vector<Prompt> batch = training_task.SampleBatch(kBatchSize, sgen);
|
||||
for (const Prompt& prompt : batch) {
|
||||
loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
|
||||
CrossEntropyLossBackwardPass(
|
||||
prompt, weights.get(), forward, grad.get(), backward);
|
||||
}
|
||||
|
||||
if (step % 250 == 0) {
|
||||
printf("Checking gradient...\n");
|
||||
Complexify(weights.get(), c_weights.get());
|
||||
auto func = [&]() {
|
||||
TC scale = batch.size();
|
||||
return CrossEntropyLossForwardPass(batch, c_weights, c_forward) * scale;
|
||||
};
|
||||
|
||||
TestGradient(grad.get(), c_weights.get(), func, 5e-3f);
|
||||
}
|
||||
|
||||
loss /= batch.size();
|
||||
EXPECT_LT(loss, prev_loss);
|
||||
stop = step >= 10000 || loss < 1e-2;
|
||||
if (step % 10 == 0 || stop) {
|
||||
printf("step: %5zu loss: %.15f learning_rate: %.15f\n",
|
||||
step, loss, learning_rate);
|
||||
}
|
||||
if (!stop) {
|
||||
learning_rate = FindOptimalUpdate(
|
||||
grad, weights, tmp, forward, batch, loss, learning_rate);
|
||||
++step;
|
||||
}
|
||||
prev_loss = loss;
|
||||
}
|
||||
EXPECT_LT(step, 1000);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -1,279 +0,0 @@
|
|||
// Copyright 2023 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
||||
#endif
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <complex>
|
||||
#include <cstdlib> // std::abs
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/backward_scalar.h"
|
||||
#include "backprop/common_scalar.h"
|
||||
#include "backprop/forward_scalar.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "backprop/sampler.h"
|
||||
#include "backprop/test_util.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "ops/ops.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
// clang-format off
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE "backprop/backward_test.cc" //NOLINT
|
||||
// clang-format on
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/tests/test_util-inl.h"
|
||||
// After highway.h
|
||||
#include "backprop/backward-inl.h"
|
||||
#include "backprop/forward-inl.h"
|
||||
#include "compression/compress.h"
|
||||
#include "ops/ops-inl.h"
|
||||
#include "util/allocator.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
void TestMatMulVJP() {
|
||||
static const size_t kRows = 8;
|
||||
static const size_t kCols = 64;
|
||||
static const size_t kTokens = 5;
|
||||
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
|
||||
Allocator::Init(topology);
|
||||
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
|
||||
std::mt19937 gen(42);
|
||||
MatStorageT<float> weights("weights", kRows, kCols);
|
||||
MatStorageT<float> x("x", kTokens, kCols);
|
||||
MatStorageT<float> dy("dy", kTokens, kRows);
|
||||
MatStorageT<float> grad("grad", kRows, kCols);
|
||||
MatStorageT<float> dx("dx", kTokens, kCols);
|
||||
MatStorageT<float> grad_scalar("grad_scalar", kRows, kCols);
|
||||
MatStorageT<float> dx_scalar("dx_scalar", kTokens, kCols);
|
||||
using TC = std::complex<double>;
|
||||
MatStorageT<TC> c_weights("c_weights", kRows, kCols);
|
||||
MatStorageT<TC> c_x("c_x", kTokens, kCols);
|
||||
MatStorageT<TC> c_y("c_y", kTokens, kRows);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0f * (1 << iter), gen);
|
||||
RandInit(x, 1.0f * (1 << iter), gen);
|
||||
RandInit(dy, 1.0f, gen);
|
||||
Complexify(weights, c_weights);
|
||||
Complexify(x, c_x);
|
||||
auto func = [&]() {
|
||||
MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens);
|
||||
return DotT(dy.data(), c_y.data(), kTokens * kRows);
|
||||
};
|
||||
|
||||
grad.ZeroInit();
|
||||
MatMulVJP(weights.data(), x.data(), dy.data(), kCols, kRows, kTokens,
|
||||
grad.data(), dx.data(), pools.Pool());
|
||||
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
|
||||
|
||||
grad_scalar.ZeroInit();
|
||||
MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
|
||||
dx_scalar.data(), kRows, kCols, kTokens);
|
||||
TestNear(dx, dx_scalar, 5e-5, 1e-4, __LINE__);
|
||||
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
void TestMultiHeadMatMulVJP() {
|
||||
static const size_t kRows = 2;
|
||||
static const size_t kCols = 16;
|
||||
static const size_t kHeads = 4;
|
||||
static const size_t kTokens = 3;
|
||||
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
|
||||
Allocator::Init(topology);
|
||||
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
|
||||
std::mt19937 gen(42);
|
||||
MatStorageT<float> weights("weights", kRows, kCols * kHeads);
|
||||
MatStorageT<float> x("x", kTokens, kCols * kHeads);
|
||||
MatStorageT<float> grad("grad", kRows, kCols * kHeads);
|
||||
MatStorageT<float> dx("dx", kTokens, kCols * kHeads);
|
||||
MatStorageT<float> dy("dy", kTokens, kRows);
|
||||
MatStorageT<float> grad_scalar("grad_scalar", kRows, kCols * kHeads);
|
||||
MatStorageT<float> dx_scalar("dx_scalar", kTokens, kCols * kHeads);
|
||||
using TC = std::complex<double>;
|
||||
MatStorageT<TC> c_weights("c_weights", kRows, kCols * kHeads);
|
||||
MatStorageT<TC> c_x("c_x", kTokens, kCols * kHeads);
|
||||
MatStorageT<TC> c_y("c_y", kTokens, kRows);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0f * (1 << iter), gen);
|
||||
RandInit(x, 1.0f * (1 << iter), gen);
|
||||
RandInit(dy, 1.0f, gen);
|
||||
Complexify(weights, c_weights);
|
||||
Complexify(x, c_x);
|
||||
auto func = [&]() {
|
||||
MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows,
|
||||
kCols, kTokens);
|
||||
return DotT(dy.data(), c_y.data(), kTokens * kRows);
|
||||
};
|
||||
|
||||
grad.ZeroInit();
|
||||
MultiHeadMatMulVJP(weights.data(), x.data(), dy.data(), kHeads, kCols,
|
||||
kRows, kTokens, grad.data(), dx.data(), pools.Pool());
|
||||
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
|
||||
|
||||
grad_scalar.ZeroInit();
|
||||
MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
|
||||
dx_scalar.data(), kHeads, kRows, kCols, kTokens);
|
||||
TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
|
||||
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
// Verifies RMSNormVJP (the RMSNorm backward pass) two ways: against a
// numerical gradient computed through complexified copies of the inputs
// (TestGradient + the complex-valued scalar forward RMSNormT), and against
// the scalar reference backward pass RMSNormVJPT.
void TestRMSNormVJP() {
  static const size_t K = 2;   // number of rows (tokens)
  static const size_t N = 64;  // vector length per row
  const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
  Allocator::Init(topology);
  gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
  std::mt19937 gen(42);  // fixed seed for reproducibility
  MatStorageT<float> weights("weights", N, 1);
  MatStorageT<float> x("x", K, N);
  MatStorageT<float> grad("grad", N, 1);
  MatStorageT<float> dx("dx", K, N);
  MatStorageT<float> dy("dy", K, N);
  MatStorageT<float> grad_scalar("grad_scalar", N, 1);
  MatStorageT<float> dx_scalar("dx_scalar", K, N);
  using TC = std::complex<double>;
  // Complex doubles mirror the float inputs for the numerical gradient check.
  MatStorageT<TC> c_weights("c_weights", N, 1);
  MatStorageT<TC> c_x("c_x", K, N);
  MatStorageT<TC> c_y("c_y", K, N);

  for (int iter = 0; iter < 10; ++iter) {
    // Scale inputs by 2^iter to exercise a wide range of magnitudes.
    RandInit(weights, 1.0f * (1 << iter), gen);
    RandInit(x, 1.0f * (1 << iter), gen);
    RandInit(dy, 1.0f, gen);
    Complexify(weights, c_weights);
    Complexify(x, c_x);
    // Scalar loss: dy . RMSNorm(weights, x), evaluated on the complex copies.
    auto func = [&]() {
      RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K);
      return DotT(dy.data(), c_y.data(), K * N);
    };

    grad.ZeroInit();
    RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(),
               dx.data(), pools.Pool());
    TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
    TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);

    // Cross-check the vectorized backward pass against the scalar reference.
    grad_scalar.ZeroInit();
    RMSNormVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
                dx_scalar.data(), N, K);
    TestNear(dx, dx_scalar, 0, 2e-5, __LINE__);
    TestNear(grad, grad_scalar, 0, 2e-5, __LINE__);
  }
}
|
||||
|
||||
// Builds a deliberately tiny two-layer model configuration so the gradient
// tests below run quickly.
static ModelConfig TestConfig() {
  ModelConfig config;
  config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
                        "gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
  config.model_dim = 32;
  config.vocab_size = 16;
  config.seq_len = 24;
  LayerConfig layer_config;
  layer_config.model_dim = config.model_dim;
  layer_config.ff_hidden_dim = 64;
  layer_config.heads = 3;
  layer_config.kv_heads = 1;  // single shared KV head
  layer_config.qkv_dim = 16;
  // Two identical attention layers.
  config.layer_configs = {2, layer_config};
  config.num_tensor_scales = 4 * config.layer_configs.size();
  config.query_scale = QueryScaleType::SqrtKeySize;
  config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
  // This is required for optimize_test to pass.
  config.att_cap = 50.0f;
  config.final_cap = 30.0f;
  return config;
}
|
||||
|
||||
// End-to-end check of the full model: (1) the SIMD forward pass agrees with
// the scalar forward pass on the loss, and (2) the analytic backward pass
// agrees with a numerical gradient computed through complexified weights.
void TestEndToEnd() {
  std::mt19937 gen(42);  // fixed seed for reproducibility
  const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1));
  Allocator::Init(topology);
  gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
  ModelConfig config = TestConfig();
  WeightsWrapper<float> weights(config);
  WeightsWrapper<float> grad(config);
  ForwardPass<float> forward0(config);  // scalar forward pass
  ForwardPass<float> forward1(config);  // SIMD forward pass
  ForwardPass<float> backward(config);
  using TC = std::complex<double>;
  WeightsWrapper<TC> c_weights(config);
  ForwardPass<TC> c_forward(config);

  // Synthetic task: reverse a short token sequence.
  ReverseSequenceSampler training_task({0, 0, 1, 1});
  std::vector<Prompt> batch = training_task.SampleBatch(3, gen);

  RowVectorBatch<float> inv_timescale = CreateInvTimescale(
      config.layer_configs[0].qkv_dim,
      config.layer_configs[0].post_qk == PostQKType::HalfRope);
  for (const Prompt& prompt : batch) {
    ReverseSequenceSampler::LogPrompt(prompt);
    RandInit(weights.get(), 1.0f, gen);

    // Scalar reference loss.
    float loss0 = CrossEntropyLossForwardPass(prompt, weights.get(), forward0);

    // Vectorized loss over the same weights and prompt.
    float loss1 = CrossEntropyLossForwardPass(
        prompt.tokens, prompt.context_size, weights.get(), forward1,
        inv_timescale, pools.Pool());

    EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);

    grad.ZeroInit();
    CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(),
                                    backward, inv_timescale, pools.Pool());

    // Numerical gradient: loss as a function of the complexified weights.
    Complexify(weights.get(), c_weights.get());
    auto func = [&]() {
      return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward);
    };

    TestGradient(grad.get(), c_weights.get(), func, 2e-3f);
  }
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#if HWY_ONCE
|
||||
|
||||
namespace gcpp {
|
||||
HWY_BEFORE_TEST(BackwardTest);
|
||||
HWY_EXPORT_AND_TEST_P(BackwardTest, TestMatMulVJP);
|
||||
HWY_EXPORT_AND_TEST_P(BackwardTest, TestMultiHeadMatMulVJP);
|
||||
HWY_EXPORT_AND_TEST_P(BackwardTest, TestRMSNormVJP);
|
||||
HWY_EXPORT_AND_TEST_P(BackwardTest, TestEndToEnd);
|
||||
HWY_AFTER_TEST();
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif
|
||||
|
|
@ -1,121 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <complex>
|
||||
|
||||
#include "compression/compress.h" // MatStorageT
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Inner product of a[0..N) with b[0..N); the accumulator and result use b's
// element type U.
template <typename T, typename U>
U DotT(const T* a, const U* b, size_t N) {
  U acc = {};
  for (const T* a_end = a + N; a != a_end; ++a, ++b) {
    acc += *a * *b;
  }
  return acc;
}
|
||||
|
||||
template<>
|
||||
inline std::complex<double> DotT(const float* a, const std::complex<double>* b,
|
||||
size_t N) {
|
||||
std::complex<double> sum = {};
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
sum += static_cast<double>(a[i]) * b[i];
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
// Scales x[0..N) in place by the constant c.
template <typename T>
void MulByConstT(T c, T* x, size_t N) {
  for (T* end = x + N; x != end; ++x) {
    *x *= c;
  }
}
|
||||
|
||||
// Scalar axpy: out[0..N) += c * x[0..N).
template <typename T>
void MulByConstAndAddT(T c, const T* x, T* out, size_t N) {
  for (size_t k = 0; k < N; ++k) {
    out[k] = out[k] + c * x[k];
  }
}
|
||||
|
||||
// Matrix overload of out += c * x, applied over all of x's elements.
// NOTE(review): iterates x.NumElements() — assumes out has at least as many
// elements as x; confirm at call sites.
template <typename T>
void MulByConstAndAddT(T c, const MatPtrT<T>& x, MatPtrT<T>& out) {
  MulByConstAndAddT(c, x.data(), out.data(), x.NumElements());
}
|
||||
|
||||
// Element-wise accumulate: out[0..N) += a[0..N).
template <typename T>
void AddFromT(const T* a, T* out, size_t N) {
  for (const T* end = a + N; a != end; ++a, ++out) {
    *out += *a;
  }
}
|
||||
|
||||
// Returns the squared L2 norm of x[0..N), i.e. the sum of x[i] * x[i].
template <typename T>
T SquaredL2(const T* x, size_t N) {
  T total = {};
  for (size_t k = 0; k < N; ++k) {
    const T v = x[k];
    total += v * v;
  }
  return total;
}
|
||||
|
||||
// tanh approximation of the GELU activation:
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
template <typename T>
T Gelu(T x) {
  static const T kMul = 0.044715;
  static const T kSqrt2OverPi = 0.797884560804236;

  const T cubed = x * x * x;
  const T arg = kSqrt2OverPi * (kMul * cubed + x);
  // Smooth CDF-like gate in [0, 1].
  return x * (T(0.5) * (T(1.0) + std::tanh(arg)));
}
|
||||
|
||||
// Applies rotary positional embedding (RoPE) in place: for each dim in
// [0, N/2), the pair (x[dim], x[dim + N/2]) is rotated by the angle
// i / base^(2 * dim / N), where i is the sequence position.
template <typename T, typename U>
void Rope(T* x, U base, size_t N, int i) {
  const size_t half = N / 2;
  for (size_t dim = 0; dim < half; ++dim) {
    const T exponent = T(2 * dim) / T(N);
    const T timescale = std::pow(base, exponent);
    const T angle = T(i) / timescale;
    const T c = std::cos(angle);
    const T s = std::sin(angle);
    const T lo = x[dim];
    const T hi = x[dim + half];
    x[dim] = lo * c - hi * s;
    x[dim + half] = lo * s + hi * c;
  }
}
|
||||
|
||||
// Rope with the default base of 10000.
template<typename T>
void Rope(T* x, size_t N, int i) {
  Rope(x, T(10000.0), N, i);
}
|
||||
|
||||
// Overload for complex inputs (used by the numerical-gradient checks):
// passes the base as the underlying real type T; the generic Rope then
// operates on the complex values.
template<typename T>
void Rope(std::complex<T>* x, size_t N, int i) {
  Rope(x, T(10000.0), N, i);
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
|
||||
|
|
@ -1,296 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Include guard for non-SIMD code.
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
|
||||
|
||||
// Include guard for (potentially) SIMD code.
|
||||
#if defined(THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
||||
#ifdef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
|
||||
#undef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
|
||||
#else
|
||||
#define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
|
||||
#endif
|
||||
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "ops/matvec-inl.h"
|
||||
#include "ops/ops-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
// Writes the scaled embedding row of each prompt token into `output`
// (model_dim floats per position). The final token is skipped: it is only a
// prediction target, never an input.
template <typename ArrayT>
void InputEmbedding(const ArrayT& weights, const std::vector<int>& prompt,
                    const float scaling, float* HWY_RESTRICT output,
                    size_t model_dim, size_t vocab_size) {
  const hn::ScalableTag<float> df;
  HWY_ASSERT(!prompt.empty());
  for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
    int token = prompt[pos];
    // Reads one row of the (possibly compressed) embedding table into the
    // output slot for this position.
    DecompressAndZeroPad(df, MakeSpan(weights.data(), model_dim * vocab_size),
                         token * model_dim, output + pos * model_dim,
                         model_dim);
    MulByConst(scaling, output + pos * model_dim, model_dim);
  }
}
|
||||
|
||||
// Applies RMSNorm row by row: x holds num_tokens rows of model_dim values,
// each normalized with the shared `weights` vector into `output`.
// NOTE(review): `pool` is accepted but unused — the rows are processed
// serially here.
template<typename WT, typename XT, typename OutT>
void ApplyRMSNorm(const WT* HWY_RESTRICT weights, const XT* HWY_RESTRICT x,
                  size_t model_dim, size_t num_tokens,
                  OutT* HWY_RESTRICT output,
                  hwy::ThreadPool& pool) {
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    const size_t offset = pos * model_dim;
    RMSNorm(x + offset, weights, output + offset, model_dim);
  }
}
|
||||
|
||||
// Sums -log2(probability of the correct next token) over all positions that
// are actually predicted; positions whose successor is still part of the
// prompt context are skipped. `probs` holds one row of vocab_size
// probabilities per position.
// NOTE(review): `pool` is accepted but unused.
static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs,
                                           const std::vector<int>& prompt,
                                           size_t context_size,
                                           size_t vocab_size,
                                           hwy::ThreadPool& pool) {
  HWY_ASSERT(!prompt.empty());
  float loss = 0.0f;
  for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
    if (pos + 1 < context_size) {
      continue;  // next token is part of context, don't try to predict it
    }
    const int next_token = prompt[pos + 1];
    loss += std::log(probs[pos * vocab_size + next_token]);
  }
  // Negate and convert from nats to bits (log-probabilities are <= 0).
  float scaling = -1.0 / std::log(2.0);
  return loss * scaling;
}
|
||||
|
||||
// Runs one transformer layer forward over num_tokens positions: pre-attention
// RMSNorm, QKV projection, RoPE, causal multi-head attention (single shared
// KV head), attention output projection with residual, then the gated-GELU
// FFW block with residual, written to `output`.
template <typename T>
void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
                       ForwardLayer<float>& activations, size_t num_tokens,
                       float* HWY_RESTRICT output,
                       const RowVectorBatch<float>& inv_timescale,
                       hwy::ThreadPool& pool) {
  const LayerConfig& config = weights.layer_config;
  const size_t model_dim = config.model_dim;
  const size_t kSeqLen = activations.input.Rows();
  const size_t kQKVDim = config.qkv_dim;
  const size_t kHeads = config.heads;
  // NOTE(review): `static` means this is computed from the *first* call's
  // kQKVDim only — wrong if layers with different qkv_dim ever share this
  // template instantiation. Confirm all callers use one config.
  static const float query_scale =
      static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
  HWY_ASSERT(num_tokens <= kSeqLen);

  ApplyRMSNorm(weights.pre_attention_norm_scale.data(),
               activations.input.data(), model_dim, num_tokens,
               activations.pre_att_rms_out.data(), pool);

  // QKV projection: per position, (kHeads queries + 1 key + 1 value) vectors
  // of kQKVDim each, stored contiguously in activations.qkv.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    MatVec(weights.qkv_einsum_w, 0, (kHeads + 2) * kQKVDim, model_dim,
           activations.pre_att_rms_out.data() + pos * model_dim,
           activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool);
  }
  const size_t num_tasks = kHeads * num_tokens;

  // RoPE on the single shared key vector of each position.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    float* HWY_RESTRICT k =
        activations.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
    Rope(k, kQKVDim, inv_timescale.Const(), pos);
  }
  // RoPE + 1/sqrt(qkv_dim) scaling on every query, one task per (pos, head).
  pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
    const size_t head = task % kHeads;
    const size_t pos = task / kHeads;
    float* HWY_RESTRICT q =
        activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
    Rope(q, kQKVDim, inv_timescale.Const(), pos);
    MulByConst(query_scale, q, kQKVDim);
  });

  // Causal attention logits: dot each query with the keys of positions <= pos.
  pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
    const size_t head = task % kHeads;
    const size_t pos = task / kHeads;
    const float* HWY_RESTRICT q =
        activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
    float* HWY_RESTRICT head_att =
        activations.att.data() + (pos * kHeads + head) * kSeqLen;
    for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
      const float* HWY_RESTRICT k2 =
          activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
      const float score = Dot(q, k2, kQKVDim);
      head_att[pos2] = score;
    }
  });

  // Softmax over the pos+1 visible logits of each (pos, head) row.
  pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
    const size_t head = task % kHeads;
    const size_t pos = task / kHeads;
    float* HWY_RESTRICT head_att =
        activations.att.data() + (pos * kHeads + head) * kSeqLen;
    Softmax(head_att, pos + 1);
  });

  // Weighted sum of value vectors into att_out.
  pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
    const size_t head = task % kHeads;
    const size_t pos = task / kHeads;
    const float* HWY_RESTRICT head_att =
        activations.att.data() + (pos * kHeads + head) * kSeqLen;
    float* HWY_RESTRICT att_out =
        activations.att_out.data() + (pos * kHeads + head) * kQKVDim;
    hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
    for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
      float* HWY_RESTRICT v2 =
          activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
      MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
    }
  });

  // Project each head's attention output back to model_dim and accumulate.
  activations.attention_out.ZeroInit();
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    for (size_t head = 0; head < kHeads; ++head) {
      MatVec(
          weights.attn_vec_einsum_w, head * model_dim * kQKVDim, model_dim,
          kQKVDim,
          activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim,
          activations.att_post1.data() + pos * model_dim, pool);
      AddFrom(activations.att_post1.data() + pos * model_dim,
              activations.attention_out.data() + pos * model_dim, model_dim);
    }
  }

  // Residual connection around attention.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    AddFrom(activations.input.data() + pos * model_dim,
            activations.attention_out.data() + pos * model_dim, model_dim);
  }

  ApplyRMSNorm(weights.pre_ffw_norm_scale.data(),
               activations.attention_out.data(), model_dim, num_tokens,
               activations.bf_pre_ffw_rms_out.data(), pool);
  const size_t kFFHiddenDim = config.ff_hidden_dim;
  // Gating projection produces [gate | multiplier] pairs, 2 * hidden dim.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    MatVec(weights.gating_einsum_w, 0, kFFHiddenDim * 2, model_dim,
           activations.bf_pre_ffw_rms_out.data() + pos * model_dim,
           activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool);
  }
  // Gated GELU: out_gated = out_mul * Gelu(out), vectorized over lanes.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    const size_t hidden_offset = pos * kFFHiddenDim * 2;
    const float* HWY_RESTRICT out =
        activations.ffw_hidden.data() + hidden_offset;
    const float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
    float* HWY_RESTRICT out_gated =
        activations.ffw_hidden_gated.data() + pos * kFFHiddenDim;
    namespace hn = hwy::HWY_NAMESPACE;
    using DF = hn::ScalableTag<float>;
    DF df;
    for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) {
      const auto y = hn::Load(df, out + i);
      const auto x = hn::Load(df, out_mul + i);
      hn::Store(hn::Mul(x, Gelu(df, y)), df, out_gated + i);
    }
  }
  // Down-projection back to model_dim.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    MatVec(weights.linear_w, 0, model_dim, kFFHiddenDim,
           activations.ffw_hidden_gated.data() + pos * kFFHiddenDim,
           output + pos * model_dim, pool);
  }
  // Residual connection around the FFW block.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    AddFrom(activations.attention_out.data() + pos * model_dim,
            output + pos * model_dim, model_dim);
  }
}
|
||||
|
||||
// Vectorized forward pass over the whole model: embeds the prompt, applies
// every layer, final RMSNorm, logits via the tied embedding matrix, optional
// final soft-capping, softmax, and returns the cross-entropy loss in bits.
// Activations for the backward pass are left in `forward`.
template <typename T>
float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
                                  size_t context_size,
                                  const ModelWeightsPtrs<T>& weights,
                                  ForwardPass<float>& forward,
                                  const RowVectorBatch<float>& inv_timescale,
                                  hwy::ThreadPool& pool) {
  const ModelConfig& config = weights.weights_config;
  const size_t vocab_size = config.vocab_size;
  const size_t model_dim = config.model_dim;
  const size_t layers = config.layer_configs.size();
  const float emb_scaling = EmbeddingScaling(model_dim);
  // Unsupported model features for this backprop path.
  HWY_ASSERT(!config.absolute_pe);
  HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None);
  HWY_ASSERT(config.layer_configs[0].kv_heads == 1);

  HWY_DASSERT(context_size > 0);
  HWY_DASSERT(context_size < prompt.size());
  // The last token is a prediction target only.
  const size_t num_tokens = prompt.size() - 1;

  InputEmbedding(weights.embedder_input_embedding, prompt, emb_scaling,
                 forward.layers[0].input.data(), model_dim, vocab_size);

  for (size_t layer = 0; layer < config.layer_configs.size(); ++layer) {
    auto type = config.layer_configs[layer].type;
    // TODO(szabadka) Implement Griffin layer.
    HWY_ASSERT(type == LayerAttentionType::kGemma);
    // Each layer writes into the next layer's input buffer; the last layer
    // writes into final_layer_output.
    float* HWY_RESTRICT output = layer + 1 < layers
                                     ? forward.layers[layer + 1].input.data()
                                     : forward.final_layer_output.data();
    ApplyForwardLayer(*weights.GetLayer(layer), forward.layers[layer],
                      num_tokens, output, inv_timescale, pool);
  }

  ApplyRMSNorm(weights.final_norm_scale.data(),
               forward.final_layer_output.data(), model_dim, num_tokens,
               forward.final_norm_output.data(), pool);

  // Logits through the tied input embedding matrix.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    MatVec(weights.embedder_input_embedding, 0, vocab_size, model_dim,
           forward.final_norm_output.data() + pos * model_dim,
           forward.logits.data() + pos * vocab_size, pool);
  }

  if (config.final_cap > 0.0f) {
    for (size_t pos = 0; pos < num_tokens; ++pos) {
      LogitsSoftCap(config.final_cap, forward.logits.data() + pos * vocab_size,
                    vocab_size);
    }
  }

  // Keep the raw logits; softmax operates on a copy in `probs`.
  hwy::CopyBytes(forward.logits.data(), forward.probs.data(),
                 num_tokens * vocab_size * sizeof(forward.logits.At(0)));

  for (size_t pos = 0; pos < num_tokens; ++pos) {
    Softmax(forward.probs.data() + pos * vocab_size, vocab_size);
  }

  return CrossEntropyLoss(forward.probs.data(), prompt, context_size,
                          vocab_size, pool);
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#endif // NOLINT
|
||||
|
|
@ -1,68 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "backprop/forward.h"
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||
// which we pass the filename via macro 'argument'.
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE "backprop/forward.cc" // NOLINT
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "backprop/forward-inl.h"
|
||||
#include "gemma/weights.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
// Per-SIMD-target entry point: unpacks the Prompt and forwards to the
// header-inlined CrossEntropyLossForwardPass overload.
float CrossEntropyLossForwardPassT(const Prompt& prompt,
                                   const ModelWeightsPtrs<float>& weights,
                                   ForwardPass<float>& forward,
                                   RowVectorBatch<float>& inv_timescale,
                                   hwy::ThreadPool& pool) {
  return CrossEntropyLossForwardPass(prompt.tokens, prompt.context_size,
                                     weights, forward, inv_timescale, pool);
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#if HWY_ONCE
|
||||
namespace gcpp {
|
||||
|
||||
HWY_EXPORT(CrossEntropyLossForwardPassT);
|
||||
|
||||
// Public wrapper: dispatches to the implementation compiled for the best
// SIMD target available at runtime.
float CrossEntropyLossForwardPass(const Prompt& prompt,
                                  const ModelWeightsPtrs<float>& weights,
                                  ForwardPass<float>& forward,
                                  RowVectorBatch<float>& inv_timescale,
                                  hwy::ThreadPool& pool) {
  return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)(
      prompt, weights, forward, inv_timescale, pool);
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // HWY_ONCE
|
||||
|
|
@ -1,35 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
float CrossEntropyLossForwardPass(const Prompt& prompt,
|
||||
const ModelWeightsPtrs<float>& weights,
|
||||
ForwardPass<float>& forward,
|
||||
RowVectorBatch<float>& inv_timescale,
|
||||
hwy::ThreadPool& pool);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
|
||||
|
|
@ -1,294 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
#include <vector>
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/common_scalar.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/common.h" // EmbeddingScaling
|
||||
#include "gemma/weights.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// y = w * x, where w is an N x M matrix in row-major order, x is an M x K
// matrix in column-major order, and y is N x K in column-major order.
template <typename T>
void MatMulT(const T* w, const T* x, T* y, size_t N, size_t M, size_t K) {
  for (size_t col = 0; col < K; ++col) {
    const T* x_col = x + col * M;
    T* y_col = y + col * N;
    for (size_t row = 0; row < N; ++row) {
      // Dot product of w's row with x's column.
      const T* w_row = w + row * M;
      T acc = {};
      for (size_t m = 0; m < M; ++m) {
        acc += w_row[m] * x_col[m];
      }
      y_col[row] = acc;
    }
  }
}
|
||||
|
||||
// w packs H row-major N x M head matrices; each column of x packs H
// contiguous M-vectors (one per head). y = w' * x is an N x K column-major
// matrix, where w' is w rearranged into N x (H*M): head contributions are
// summed per output element.
template <typename T>
void MultiHeadMatMul(const T* w, const T* x, T* y, size_t H, size_t N,
                     size_t M, size_t K) {
  memset(y, 0, N * K * sizeof(y[0]));
  for (size_t col = 0; col < K; ++col) {
    for (size_t head = 0; head < H; ++head) {
      const T* x_head = x + col * H * M + head * M;
      const T* w_head = w + head * N * M;
      for (size_t row = 0; row < N; ++row) {
        const T* w_row = w_head + row * M;
        T acc = {};
        for (size_t m = 0; m < M; ++m) {
          acc += w_row[m] * x_head[m];
        }
        y[col * N + row] += acc;
      }
    }
  }
}
|
||||
|
||||
// RMSNorm over K rows of length N:
//   out = (1 + w) * x / sqrt(mean(x^2) + eps), per row.
template <typename T>
void RMSNormT(const T* w, const T* x, T* out, size_t N, size_t K) {
  constexpr T eps(1e-6);
  for (size_t row = 0; row < K; ++row) {
    const T* x_row = x + row * N;
    T* out_row = out + row * N;
    // Sum of squares of this row.
    T ss = {};
    for (size_t j = 0; j < N; ++j) {
      ss += x_row[j] * x_row[j];
    }
    // Reciprocal RMS with epsilon for numerical safety.
    ss = T(1.0) / std::sqrt(ss / T(N) + eps);
    for (size_t j = 0; j < N; j++) {
      out_row[j] = (T(1.0) + w[j]) * (ss * x_row[j]);
    }
  }
}
|
||||
// Numerically stable softmax of x[0..N) in place. Works for real and complex
// element types: the max used for stabilization is taken over real parts.
template <typename T>
void Softmax(T* x, size_t N) {
  auto max_re = std::real(x[0]);
  for (size_t k = 1; k < N; ++k) {
    if (std::real(x[k]) > max_re) {
      max_re = std::real(x[k]);
    }
  }
  // Exponentiate shifted values and accumulate their sum.
  T total = {};
  for (size_t k = 0; k < N; ++k) {
    x[k] = std::exp(x[k] - max_re);
    total += x[k];
  }
  const T inv_sum = T(1.0) / total;
  for (size_t k = 0; k < N; ++k) {
    x[k] *= inv_sum;
  }
}

// Row-wise softmax: applies Softmax independently to each of the K rows of
// length N.
template <typename T>
void Softmax(T* x, size_t N, size_t K) {
  for (size_t row = 0; row < K; ++row) {
    Softmax(x + row * N, N);
  }
}
|
||||
// Soft clamp: x -> cap * tanh(x / cap), smoothly bounding values to
// (-cap, cap) in place.
template <typename T>
void Softcap(float cap, T* x, size_t N) {
  const T cap_t = static_cast<T>(cap);
  const T inv_cap = T{1.0} / cap_t;
  for (size_t k = 0; k < N; ++k) {
    x[k] = cap_t * std::tanh(x[k] * inv_cap);
  }
}
|
||||
|
||||
// Gated GELU over K rows: each input row holds [x1 | x2] (each of length N)
// and the output row is x2 * Gelu(x1).
template <typename T>
void GatedGelu(const T* in, T* out, size_t N, size_t K) {
  for (size_t row = 0; row < K; ++row) {
    const T* gate = in + row * 2 * N;
    const T* mul = gate + N;
    T* dst = out + row * N;
    for (size_t j = 0; j < N; ++j) {
      dst[j] = mul[j] * Gelu(gate[j]);
    }
  }
}
|
||||
|
||||
// Copies one embedding row per token into y and scales it. The final token
// is skipped: it is only a prediction target, never an input.
template <typename T>
void InputEmbedding(const T* w, const std::vector<int>& tokens, T scaling,
                    T* y, size_t N) {
  HWY_ASSERT(w != nullptr);
  HWY_ASSERT(y != nullptr);
  const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
  for (size_t i = 0; i < num_tokens; ++i) {
    T* dst = y + i * N;
    memcpy(dst, w + tokens[i] * N, N * sizeof(y[0]));
    MulByConstT(scaling, dst, N);
  }
}
|
||||
|
||||
// Causal attention logits. Each qkv row (per position) holds `heads` query
// vectors followed by one shared key and one value vector, all qkv_dim wide.
// For every (pos, head), the query is dotted with the key of each position
// pos2 <= pos; the score lands at output[pos * heads * seq_len +
// head * seq_len + pos2]. Entries beyond pos are left untouched.
template <typename T>
void MaskedAttention(const T* qkv, T* output, size_t num_tokens, size_t heads,
                     size_t qkv_dim, size_t seq_len) {
  const size_t row_len = (heads + 2) * qkv_dim;
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    for (size_t head = 0; head < heads; ++head) {
      const T* q = qkv + pos * row_len + head * qkv_dim;
      T* att_row = output + pos * heads * seq_len + head * seq_len;
      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
        const T* k = qkv + pos2 * row_len + heads * qkv_dim;
        T score = {};
        for (size_t d = 0; d < qkv_dim; ++d) {
          score += q[d] * k[d];
        }
        att_row[pos2] = score;
      }
    }
  }
}
|
||||
// Causal softmax: for each (pos, head) attention row, normalizes the pos+1
// visible entries and zeroes everything after the current position.
template <typename T>
void MaskedSoftmax(T* x, size_t num_tokens, size_t heads, size_t seq_len) {
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    for (size_t head = 0; head < heads; ++head) {
      T* row = x + pos * heads * seq_len + head * seq_len;
      Softmax(row, pos + 1);
      memset(row + pos + 1, 0, (seq_len - pos - 1) * sizeof(T));
    }
  }
}
|
||||
// Weighted sum of value vectors: for each (pos, head), accumulates
// attention[pos2] * value(pos2) over pos2 <= pos into the output slot at
// pos * heads * qkv_dim + head * qkv_dim. The value vector of a position
// sits after its heads queries and shared key in the qkv row.
template <typename T>
void MixByAttention(const T* qkv, const T* attention, T* output,
                    size_t num_tokens, size_t heads, size_t qkv_dim,
                    size_t seq_len) {
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    for (size_t head = 0; head < heads; ++head) {
      const T* att = &attention[pos * heads * seq_len + head * seq_len];
      T* out = &output[head * qkv_dim + pos * heads * qkv_dim];
      memset(out, 0, qkv_dim * sizeof(out[0]));
      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
        const T* v = &qkv[(pos2 * (heads + 2) + heads + 1) * qkv_dim];
        for (size_t d = 0; d < qkv_dim; ++d) {
          out[d] += att[pos2] * v[d];
        }
      }
    }
  }
}
|
||||
// Scalar reference forward pass of one transformer layer: pre-attention
// RMSNorm, QKV projection, RoPE, causal attention, output projection with
// residual, then the gated-GELU FFW block with residual into `output`.
template <typename T>
void ApplyLayer(const LayerWeightsPtrs<T>& weights,
                ForwardLayer<T>& activations, size_t num_tokens, T* output) {
  const LayerConfig& layer_config = weights.layer_config;
  const size_t model_dim = layer_config.model_dim;
  const size_t seq_len = activations.input.Rows();
  const size_t qkv_dim = layer_config.qkv_dim;
  const size_t heads = layer_config.heads;
  const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
  // NOTE(review): `static` means this is computed from the *first* call's
  // qkv_dim only — wrong if layers with different qkv_dim share this
  // instantiation. Confirm all callers use one config.
  static const T query_scale = T(1.0) / std::sqrt(T(qkv_dim));

  RMSNormT(weights.pre_attention_norm_scale.data(), activations.input.data(),
           activations.pre_att_rms_out.data(), model_dim, num_tokens);

  // QKV projection: heads queries + 1 key + 1 value per position.
  MatMulT(weights.qkv_einsum_w.data(), activations.pre_att_rms_out.data(),
          activations.qkv.data(), (heads + 2) * qkv_dim, model_dim, num_tokens);

  // RoPE on all queries and the key (h == heads); the value is untouched.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim;
    for (size_t h = 0; h <= heads; ++h) {
      Rope(qkv + h * qkv_dim, qkv_dim, pos);
    }
  }

  // Scale only the query vectors (first heads * qkv_dim of each row).
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim;
    MulByConstT(query_scale, qkv, heads * qkv_dim);
  }

  MaskedAttention(activations.qkv.data(), activations.att.data(), num_tokens,
                  heads, qkv_dim, seq_len);

  MaskedSoftmax(activations.att.data(), num_tokens, heads, seq_len);

  MixByAttention(activations.qkv.data(), activations.att.data(),
                 activations.att_out.data(), num_tokens, heads, qkv_dim,
                 seq_len);

  // Project the per-head attention outputs back to model_dim and sum heads.
  MultiHeadMatMul(weights.attn_vec_einsum_w.data(), activations.att_out.data(),
                  activations.attention_out.data(), heads, model_dim, qkv_dim,
                  num_tokens);

  // Residual connection around attention.
  AddFromT(activations.input.data(), activations.attention_out.data(),
           num_tokens * model_dim);

  RMSNormT(weights.pre_ffw_norm_scale.data(), activations.attention_out.data(),
           activations.bf_pre_ffw_rms_out.data(), model_dim, num_tokens);

  // Gating projection produces [gate | multiplier] pairs (2x hidden dim).
  MatMulT(weights.gating_einsum_w.data(), activations.bf_pre_ffw_rms_out.data(),
          activations.ffw_hidden.data(), ff_hidden_dim * 2, model_dim,
          num_tokens);

  GatedGelu(activations.ffw_hidden.data(), activations.ffw_hidden_gated.data(),
            ff_hidden_dim, num_tokens);

  // Down-projection back to model_dim.
  MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(), output,
          model_dim, ff_hidden_dim, num_tokens);

  // Residual connection around the FFW block.
  AddFromT(activations.attention_out.data(), output, num_tokens * model_dim);
}
|
||||
|
||||
template<typename T>
|
||||
T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) {
|
||||
T loss = {};
|
||||
const std::vector<int> tokens = prompt.tokens;
|
||||
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
|
||||
for (size_t i = 0; i < num_tokens; ++i) {
|
||||
if (i + 1 < prompt.context_size) {
|
||||
continue; // next token is part of context, don't try to predict it
|
||||
}
|
||||
const int next_token = tokens[i + 1];
|
||||
loss += std::log(x[i * V + next_token]);
|
||||
}
|
||||
T scaling = -1.0 / std::log(2.0);
|
||||
return loss * scaling;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T CrossEntropyLossForwardPass(const Prompt& prompt,
|
||||
const ModelWeightsPtrs<T>& weights,
|
||||
ForwardPass<T>& forward) {
|
||||
const ModelConfig& config = weights.weights_config;
|
||||
const size_t model_dim = config.model_dim;
|
||||
const size_t vocab_size = config.vocab_size;
|
||||
const size_t layers = config.layer_configs.size();
|
||||
const std::vector<int> tokens = prompt.tokens;
|
||||
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
|
||||
|
||||
const T kEmbScaling = EmbeddingScaling(model_dim);
|
||||
InputEmbedding(weights.embedder_input_embedding.data(), tokens, kEmbScaling,
|
||||
forward.layers[0].input.data(), model_dim);
|
||||
|
||||
for (size_t layer = 0; layer < layers; ++layer) {
|
||||
T* output = layer + 1 < layers ? forward.layers[layer + 1].input.data()
|
||||
: forward.final_layer_output.data();
|
||||
ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens,
|
||||
output);
|
||||
}
|
||||
|
||||
RMSNormT(weights.final_norm_scale.data(), forward.final_layer_output.data(),
|
||||
forward.final_norm_output.data(), model_dim, num_tokens);
|
||||
|
||||
MatMulT(weights.embedder_input_embedding.data(),
|
||||
forward.final_norm_output.data(), forward.logits.data(), vocab_size,
|
||||
model_dim, num_tokens);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
if (config.final_cap > 0.0f) {
|
||||
Softcap(config.final_cap, forward.logits.data() + pos * vocab_size,
|
||||
vocab_size);
|
||||
}
|
||||
}
|
||||
|
||||
memcpy(forward.probs.data(), forward.logits.data(),
|
||||
num_tokens * vocab_size * sizeof(forward.logits.At(0)));
|
||||
Softmax(forward.probs.data(), vocab_size, num_tokens);
|
||||
|
||||
return CrossEntropyLoss(forward.probs.data(), prompt, vocab_size);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_
|
||||
|
|
@ -1,155 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/backward.h"
|
||||
#include "backprop/forward.h"
|
||||
#include "backprop/optimizer.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "backprop/sampler.h"
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "ops/ops.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// End-to-end training smoke test: fits a tiny Gemma model with Adam on the
// "reverse the sequence" toy task and checks that the loss drops within a
// bounded number of steps and that the trained model solves the task.
TEST(OptimizeTest, GradientDescent) {
  const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1));
  Allocator::Init(topology);
  NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
  MatMulEnv env(topology, pools);
  hwy::ThreadPool& pool = pools.Pool();
  std::mt19937 gen(42);  // fixed seed for reproducibility

  const ModelInfo info = {
      .model = Model::GEMMA_TINY,
      .wrapping = PromptWrapping::GEMMA_IT,
      .weight = Type::kF32,
  };
  ModelConfig config = ConfigFromModel(info.model);
  // Gradient accumulator plus Adam first/second moment estimates.
  ModelWeightsStorage grad, grad_m, grad_v;
  grad.Allocate(info.model, info.weight, pool);
  grad_m.Allocate(info.model, info.weight, pool);
  grad_v.Allocate(info.model, info.weight, pool);
  grad_m.ZeroInit();
  grad_v.ZeroInit();
  ForwardPass<float> forward(config), backward(config);
  KVCache kv_cache = KVCache::Create(config, /*prefill_tbatch_size=*/16);

  RowVectorBatch<float> inv_timescale = CreateInvTimescale(
      config.layer_configs[0].qkv_dim,
      config.layer_configs[0].post_qk == PostQKType::HalfRope);

  Gemma gemma(GemmaTokenizer(), info, env);

  // Generates a reply for `prompt`, streaming tokens until the task's end
  // token is produced (or max_generated_tokens is reached).
  const auto generate = [&](const std::vector<int>& prompt) {
    std::vector<int> reply;
    auto stream_token = [&reply](int token, float) {
      reply.push_back(token);
      return token != ReverseSequenceSampler::kEndToken;
    };
    RuntimeConfig runtime = {
        .max_generated_tokens = 16,
        .temperature = 1.0f,
        .gen = &gen,
        .verbosity = 0,
        .stream_token = stream_token,
        .eos_id = ReverseSequenceSampler::kEndToken,
    };
    TimingInfo timing_info;
    gemma.Generate(runtime, prompt, 0, kv_cache, timing_info);
    return reply;
  };

  // Sanity check of reply tokens.
  // 1) Its length should be greater than the prompt.
  // 2) The prompt should be a prefix of the reply.
  auto verify = [&](const Prompt& prompt) {
    const std::vector<int>& context = prompt.context();
    std::vector<int> reply = generate(context);
    if (reply.size() <= context.size()) return false;
    return std::equal(context.begin(), context.end(), reply.begin(),
                      reply.begin() + context.size());
  };

  gemma.MutableWeights().RandInit(gen);
  gemma.MutableWeights().AllocAndCopyWithTranspose(pool);

  printf("Initial weights:\n");
  gemma.MutableWeights().LogWeightStats();

  // Adam hyperparameters.
  constexpr size_t kBatchSize = 8;
  const float alpha = 0.001f;
  const float beta1 = 0.9f;
  const float beta2 = 0.999f;
  const float epsilon = 1e-8f;

  // Histogram over sequence lengths used when sampling training prompts.
  ReverseSequenceSampler training_task({
      0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1});
  size_t steps = 0;
  size_t num_ok;
  for (; steps < 1000000; ++steps) {
    // Re-seeded every step, so each batch consists of the same prompts.
    std::mt19937 sgen(42);
    grad.ZeroInit();
    float total_loss = 0.0f;
    num_ok = 0;
    for (size_t i = 0; i < kBatchSize; ++i) {
      Prompt prompt = training_task.Sample(sgen);
      total_loss += CrossEntropyLossForwardPass(
          prompt, *gemma.Weights().GetWeightsOfType<float>(), forward,
          inv_timescale, pool);
      CrossEntropyLossBackwardPass(
          prompt, *gemma.Weights().GetWeightsOfType<float>(), forward,
          *grad.GetWeightsOfType<float>(), backward, inv_timescale, pool);
      // Keep the transposed weight copies used by generation in sync.
      gemma.MutableWeights().CopyWithTranspose(pool);
      num_ok += verify(prompt) ? 1 : 0;
    }
    total_loss /= kBatchSize;

    // t = steps + 1: Adam bias correction expects a 1-based step count.
    AdamUpdate(info.weight, grad, alpha, beta1, beta2, epsilon, steps + 1,
               gemma.Weights(), grad_m, grad_v, pool);
    printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
           steps, total_loss, num_ok, kBatchSize);
    if (steps % 100 == 0) {
      printf("Batch gradient:\n");
      grad.LogWeightStats();
    }
    if (total_loss < 0.5f) {
      break;  // converged
    }
  }
  printf("Num steps: %zu\n", steps);
  printf("Final weights:\n");
  gemma.MutableWeights().LogWeightStats();
  // Expect quick convergence and a fully verified final batch.
  EXPECT_LT(steps, 300);
  EXPECT_EQ(num_ok, kBatchSize);
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -1,95 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "backprop/optimizer.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
namespace {
|
||||
|
||||
// Applies one Adam optimization step to a single tensor. Maintains
// per-element first (m) and second (v) moment estimates and uses
// bias-corrected estimates (mhat, vhat) for the weight update.
class AdamUpdater {
 public:
  // `t` is the 1-based step number used for bias correction.
  explicit AdamUpdater(float alpha, float beta1, float beta2, float epsilon,
                       size_t t)
      : alpha_(alpha), beta1_(beta1), beta2_(beta2), cbeta1_(1.0f - beta1),
        cbeta2_(1.0f - beta2), norm1_(1.0 / (1.0 - std::pow(beta1, t))),
        norm2_(1.0 / (1.0 - std::pow(beta2, t))), epsilon_(epsilon) {}

  // Updates `weights` in place from gradient `grad`, also updating the moment
  // accumulators `grad_m` and `grad_v`. All four tensors are accessed as f32
  // and must have the same number of elements.
  void operator()(const char* name, const MatPtr& grad, MatPtr& weights,
                  MatPtr& grad_m, MatPtr& grad_v) {
    const float* HWY_RESTRICT g = grad.data<float>();
    float* HWY_RESTRICT w = weights.data<float>();
    float* HWY_RESTRICT m = grad_m.data<float>();
    float* HWY_RESTRICT v = grad_v.data<float>();
    for (size_t i = 0; i < grad.NumElements(); ++i) {
      // Exponential moving averages of the gradient and its square.
      m[i] *= beta1_;
      m[i] += cbeta1_ * g[i];
      v[i] *= beta2_;
      v[i] += cbeta2_ * g[i] * g[i];
      // Bias-corrected moment estimates.
      const float mhat = m[i] * norm1_;
      const float vhat = v[i] * norm2_;
      w[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
    }
  }

 private:
  float alpha_;    // learning rate
  float beta1_;    // first-moment decay rate
  float beta2_;    // second-moment decay rate
  float cbeta1_;   // 1 - beta1
  float cbeta2_;   // 1 - beta2
  float norm1_;    // 1 / (1 - beta1^t): first-moment bias correction
  float norm2_;    // 1 / (1 - beta2^t): second-moment bias correction
  float epsilon_;  // numerical stabilizer added to sqrt(vhat)
};
|
||||
|
||||
// Applies one Adam step to every tensor of the model by iterating over the
// (grad, weights, grad_m, grad_v) tensor quadruples in lockstep.
void AdamUpdate(ModelWeightsPtrs<float>* grad, float alpha, float beta1,
                float beta2, float epsilon, size_t t,
                ModelWeightsPtrs<float>* weights,
                ModelWeightsPtrs<float>* grad_m,
                ModelWeightsPtrs<float>* grad_v, hwy::ThreadPool& pool) {
  AdamUpdater updater(alpha, beta1, beta2, epsilon, t);
  ModelWeightsPtrs<float>::ForEachTensor(
      {grad, weights, grad_m, grad_v}, ForEachType::kLoadNoToc,
      [&updater](const char* name, hwy::Span<MatPtr*> tensors) {
        // Order matches the list above: grad, weights, grad_m, grad_v.
        updater(name, *tensors[0], *tensors[1], *tensors[2], *tensors[3]);
      });
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Type-dispatching entry point for the Adam update. Currently only f32
// weights are supported. `t` is the 1-based step number used for Adam's
// bias correction.
void AdamUpdate(Type weight_type, const ModelWeightsStorage& grad, float alpha,
                float beta1, float beta2, float epsilon, size_t t,
                const ModelWeightsStorage& weights,
                const ModelWeightsStorage& grad_m,
                const ModelWeightsStorage& grad_v, hwy::ThreadPool& pool) {
  HWY_ASSERT(weight_type == Type::kF32);
  AdamUpdate(grad.GetWeightsOfType<float>(), alpha, beta1, beta2, epsilon, t,
             weights.GetWeightsOfType<float>(),
             grad_m.GetWeightsOfType<float>(), grad_v.GetWeightsOfType<float>(),
             pool);
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -1,33 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
|
||||
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
void AdamUpdate(Type weight_type, const ModelWeightsStorage& grad, float alpha,
|
||||
float beta1, float beta2, float epsilon, size_t t,
|
||||
const ModelWeightsStorage& weights,
|
||||
const ModelWeightsStorage& grad_m,
|
||||
const ModelWeightsStorage& grad_v, hwy::ThreadPool& pool);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
|
||||
|
|
@ -1,90 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "backprop/prompt.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
class PromptSampler {
|
||||
public:
|
||||
virtual Prompt Sample(std::mt19937& gen) = 0;
|
||||
virtual ~PromptSampler() = default;
|
||||
|
||||
std::vector<Prompt> SampleBatch(size_t batch_size, std::mt19937& gen) {
|
||||
std::vector<Prompt> batch;
|
||||
batch.reserve(batch_size);
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
batch.emplace_back(Sample(gen));
|
||||
}
|
||||
return batch;
|
||||
}
|
||||
};
|
||||
|
||||
class ReverseSequenceSampler : public PromptSampler {
|
||||
public:
|
||||
explicit ReverseSequenceSampler(const std::vector<int>& length_histo)
|
||||
: token_dist_(0, 9) {
|
||||
for (int i = 0; i < length_histo.size(); ++i) {
|
||||
const int count = length_histo[i];
|
||||
for (int j = 0; j < count; ++j) {
|
||||
length_lut_.push_back(i + 1);
|
||||
}
|
||||
}
|
||||
length_dist_ = std::uniform_int_distribution<>(0, length_lut_.size() - 1);
|
||||
}
|
||||
virtual ~ReverseSequenceSampler() = default;
|
||||
|
||||
static constexpr int kReverseToken = 10;
|
||||
static constexpr int kEndToken = 11;
|
||||
|
||||
Prompt Sample(std::mt19937& gen) override {
|
||||
Prompt prompt;
|
||||
int len = length_lut_[length_dist_(gen)];
|
||||
prompt.tokens.resize(2 * len + 2);
|
||||
prompt.tokens[len] = kReverseToken;
|
||||
prompt.tokens[2 * len + 1] = kEndToken;
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
prompt.tokens[i] = prompt.tokens[2 * len - i] = token_dist_(gen);
|
||||
}
|
||||
prompt.context_size = len + 1;
|
||||
return prompt;
|
||||
}
|
||||
|
||||
static void LogPrompt(const Prompt& prompt) {
|
||||
static const char* kVocab[] = {
|
||||
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "-->", "|",
|
||||
};
|
||||
for (int token : prompt.tokens) printf("%s", kVocab[token]);
|
||||
printf(" [context_size: %zu]\n", prompt.context_size);
|
||||
}
|
||||
|
||||
private:
|
||||
std::uniform_int_distribution<> token_dist_;
|
||||
std::uniform_int_distribution<> length_dist_;
|
||||
std::vector<int> length_lut_;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
|
||||
|
|
@ -1,209 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "compression/compress.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
template <typename T>
|
||||
void RandInit(MatPtrT<T>& x, T stddev, std::mt19937& gen) {
|
||||
std::normal_distribution<T> dist(0.0, stddev);
|
||||
for (size_t i = 0; i < x.NumElements(); ++i) {
|
||||
x.At(i) = dist(gen);
|
||||
}
|
||||
}
|
||||
|
||||
// Randomly initializes every trainable tensor of one transformer layer.
// TODO: make a member of Layer<T>.
template <typename T>
void RandInit(LayerWeightsPtrs<T>& w, T stddev, std::mt19937& gen) {
  RandInit(w.pre_attention_norm_scale, stddev, gen);
  RandInit(w.attn_vec_einsum_w, stddev, gen);
  RandInit(w.qkv_einsum_w, stddev, gen);
  RandInit(w.pre_ffw_norm_scale, stddev, gen);
  RandInit(w.gating_einsum_w, stddev, gen);
  RandInit(w.linear_w, stddev, gen);
}
|
||||
|
||||
// Randomly initializes all trainable tensors of the whole model: the token
// embedding, the final norm scale, and every layer.
template <typename T>
void RandInit(ModelWeightsPtrs<T>& w, T stddev, std::mt19937& gen) {
  const size_t kLayers = w.c_layers.size();
  RandInit(w.embedder_input_embedding, stddev, gen);
  RandInit(w.final_norm_scale, stddev, gen);
  for (size_t i = 0; i < kLayers; ++i) {
    RandInit(*w.GetLayer(i), stddev, gen);
  }
}
|
||||
|
||||
// Copies `x` into `c_x` element-wise as complex numbers with zero imaginary
// part, in preparation for complex-step gradient checking.
template <typename T, typename U>
void Complexify(const MatPtrT<T>& x, MatPtrT<std::complex<U>>& c_x) {
  for (size_t i = 0; i < x.NumElements(); ++i) {
    c_x.At(i) = std::complex<U>(x.At(i), 0.0);
  }
}
|
||||
|
||||
// Copies every tensor of one layer into its complex-valued counterpart.
template <typename T, typename U>
void Complexify(const LayerWeightsPtrs<T>& w, LayerWeightsPtrs<U>& c_w) {
  Complexify(w.pre_attention_norm_scale, c_w.pre_attention_norm_scale);
  Complexify(w.attn_vec_einsum_w, c_w.attn_vec_einsum_w);
  Complexify(w.qkv_einsum_w, c_w.qkv_einsum_w);
  Complexify(w.pre_ffw_norm_scale, c_w.pre_ffw_norm_scale);
  Complexify(w.gating_einsum_w, c_w.gating_einsum_w);
  Complexify(w.linear_w, c_w.linear_w);
}
|
||||
|
||||
// Copies all model tensors into their complex-valued counterparts.
template <typename T, typename U>
void Complexify(const ModelWeightsPtrs<T>& w, ModelWeightsPtrs<U>& c_w) {
  const size_t kLayers = w.c_layers.size();
  Complexify(w.embedder_input_embedding, c_w.embedder_input_embedding);
  Complexify(w.final_norm_scale, c_w.final_norm_scale);
  for (size_t i = 0; i < kLayers; ++i) {
    Complexify(*w.GetLayer(i), *c_w.GetLayer(i));
  }
}
|
||||
|
||||
// Somewhat duplicates ModelWeightsStorage, but that has neither double nor
// complex types allowed and it would cause code bloat to add them there.
template <typename T>
class WeightsWrapper {
 public:
  // Allocates backing storage for all weights described by `config`.
  explicit WeightsWrapper(const ModelConfig& config)
      : pool_(0), weights_(config) {
    weights_.Allocate(data_, pool_);
  }

  const ModelWeightsPtrs<T>& get() const { return weights_; }
  ModelWeightsPtrs<T>& get() { return weights_; }
  void ZeroInit() { weights_.ZeroInit(); }
  void CopyFrom(const WeightsWrapper<T>& other) {
    weights_.CopyFrom(other.weights_);
  }

 private:
  hwy::ThreadPool pool_;          // pool with 0 extra workers, used by Allocate
  std::vector<MatStorage> data_;  // presumably owns the tensor memory — weights_
                                  // is populated from it via Allocate()
  ModelWeightsPtrs<T> weights_;
};
|
||||
|
||||
// Asserts element-wise closeness of `actual` vs `expected` with a mixed
// absolute/relative tolerance, and additionally checks that their normalized
// dot product (cosine similarity) is ~1, which catches consistent small
// biases that per-element tolerances would miss. `line` is the caller's
// source line, included in failure messages.
template <typename T, typename U>
void TestNear(const MatPtrT<T>& actual, const MatPtrT<U>& expected,
              double max_abs_err, double max_rel_err, int line) {
  double sum0 = 0;   // ||actual||^2
  double sum1 = 0;   // ||expected||^2
  double sum01 = 0;  // <actual, expected>
  for (size_t i = 0; i < actual.NumElements(); ++i) {
    sum0 += actual.At(i) * actual.At(i);
    sum1 += expected.At(i) * expected.At(i);
    sum01 += actual.At(i) * expected.At(i);
    // Tolerance is the larger of the absolute and relative bounds.
    ASSERT_NEAR(actual.At(i), expected.At(i),
                std::max(max_abs_err, std::abs(expected.At(i)) * max_rel_err))
        << "line: " << line << " dim=" << expected.NumElements() << " i=" << i;
  }
  if (sum0 > 1e-40) {  // skip the cosine check when `actual` is ~zero
    double norm_dot = sum01 / std::sqrt(sum0) / std::sqrt(sum1);
    ASSERT_NEAR(norm_dot, 1.0, 1e-7)
        << "line: " << line << " sum0: " << sum0 << " sum1: " << sum1
        << " sum01: " << sum01;
  }
}
|
||||
|
||||
// Compute gradient with the finite difference method in the complex plane.
|
||||
// If f : R->R is the tested function and F : C->C is its extension on the
|
||||
// complex plane so that F is complex differentiable in x, then
|
||||
//
|
||||
// F(x + ih) = F(x) + ih F'(x) + O(h^2) F''(x)
|
||||
//
|
||||
// which means that
|
||||
//
|
||||
// F'(x) ~= Imag(F(x + ih)) / h
|
||||
//
|
||||
// This method is more numerically stable than the real-valued finite difference
|
||||
// method since we don't need to subtract floating point numbers that are near
|
||||
// to each other.
|
||||
// Numerically estimates the gradient of `func` w.r.t. each element of `x`
// using the complex-step method described above, then compares it to the
// analytic gradient `grad` via TestNear. Each element of `x` is temporarily
// perturbed along the imaginary axis and restored afterwards.
template <typename FUNC, typename T, typename U>
void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<U>>& x,
                  FUNC func, U step, T max_abs_err, T max_rel_err, int line) {
  MatStorageT<T> exp_grad("exp_grad", x.Rows(), x.Cols());
  const U inv_step = 1.0 / step;
  for (size_t i = 0; i < x.NumElements(); ++i) {
    const U x0 = std::real(x.At(i));
    // F'(x) ~= Imag(F(x + ih)) / h.
    const std::complex<U> x1 = std::complex<U>(x0, step);
    x.At(i) = x1;
    const std::complex<U> f1 = func();
    exp_grad.At(i) = std::imag(f1) * inv_step;
    x.At(i) = x0;  // restore the original (real) value
  }
  TestNear(grad, exp_grad, max_abs_err, max_rel_err, line);
}
|
||||
|
||||
// float overload: uses a tiny step (1e-30). The complex-step method involves
// no subtraction of nearby values, so an extremely small step is safe.
template <typename FUNC>
void TestGradient(const MatPtrT<float>& grad, MatPtrT<std::complex<float>>& x,
                  FUNC func, float max_abs_err, float max_rel_error, int line) {
  TestGradient(grad, x, func, 1e-30f, max_abs_err, max_rel_error, line);
}
|
||||
|
||||
// double overload: uses an even smaller step (1e-50), exploiting double's
// wider exponent range.
template <typename FUNC, typename T>
void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<double>>& x,
                  FUNC func, T max_abs_err, T max_rel_error, int line) {
  TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line);
}
|
||||
|
||||
// Checks the gradient of every tensor in one layer against the complex-step
// numerical estimate, using the same tolerance for absolute and relative
// error.
template <typename T, typename U, typename FUNC>
void TestGradient(const LayerWeightsPtrs<T>& grad,
                  LayerWeightsPtrs<U>& c_weights, FUNC func, T max_err) {
  TestGradient(grad.pre_attention_norm_scale,
               c_weights.pre_attention_norm_scale,
               func, max_err, max_err, __LINE__);
  TestGradient(grad.attn_vec_einsum_w, c_weights.attn_vec_einsum_w,
               func, max_err, max_err, __LINE__);
  TestGradient(grad.qkv_einsum_w, c_weights.qkv_einsum_w,
               func, max_err, max_err, __LINE__);
  TestGradient(grad.pre_ffw_norm_scale, c_weights.pre_ffw_norm_scale,
               func, max_err, max_err, __LINE__);
  TestGradient(grad.gating_einsum_w, c_weights.gating_einsum_w,
               func, max_err, max_err, __LINE__);
  TestGradient(grad.linear_w, c_weights.linear_w,
               func, max_err, max_err, __LINE__);
}
|
||||
|
||||
// Checks gradients of all model tensors. The embedding is given a 2x looser
// absolute tolerance than the other tensors.
template <typename T, typename U, typename FUNC>
void TestGradient(const ModelWeightsPtrs<T>& grad,
                  ModelWeightsPtrs<U>& c_weights, FUNC func, T max_err) {
  TestGradient(grad.embedder_input_embedding,
               c_weights.embedder_input_embedding,
               func, 2 * max_err, max_err, __LINE__);
  TestGradient(grad.final_norm_scale, c_weights.final_norm_scale,
               func, max_err, max_err, __LINE__);
  for (size_t i = 0; i < grad.c_layers.size(); ++i) {
    TestGradient(*grad.GetLayer(i), *c_weights.GetLayer(i), func, max_err);
  }
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
*
|
||||
!.gitignore
|
||||
!.hgignore
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
# Weight compression, I/O and analysis
|
||||
# Weight compression and analysis.
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [
|
||||
|
|
@ -20,78 +20,11 @@ config_setting(
|
|||
visibility = ["//visibility:private"],
|
||||
)
|
||||
|
||||
FILE_DEPS = select({
|
||||
"//conditions:default": [
|
||||
# Placeholder for io deps, do not remove
|
||||
],
|
||||
":android": [],
|
||||
# Placeholder for internal build rules, do not remove
|
||||
})
|
||||
|
||||
cc_library(
|
||||
name = "io",
|
||||
srcs = [
|
||||
"io.cc",
|
||||
# Placeholder for io backend, do not remove
|
||||
],
|
||||
hdrs = ["io.h"],
|
||||
local_defines = select({
|
||||
# Placeholder for internal build rules, do not remove
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
deps = [
|
||||
"@highway//:hwy",
|
||||
] + FILE_DEPS,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "fields",
|
||||
srcs = ["fields.cc"],
|
||||
hdrs = ["fields.h"],
|
||||
deps = [
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "fields_test",
|
||||
srcs = ["fields_test.cc"],
|
||||
deps = [
|
||||
":fields",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"@highway//:hwy_test_util",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "blob_store",
|
||||
srcs = ["blob_store.cc"],
|
||||
hdrs = ["blob_store.h"],
|
||||
deps = [
|
||||
":io",
|
||||
"@highway//:hwy",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "blob_store_test",
|
||||
srcs = ["blob_store_test.cc"],
|
||||
deps = [
|
||||
":blob_store",
|
||||
":io",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "distortion",
|
||||
hdrs = [
|
||||
"distortion.h",
|
||||
"shared.h",
|
||||
"types.h",
|
||||
],
|
||||
deps = [
|
||||
"//:basics",
|
||||
|
|
@ -115,21 +48,29 @@ cc_test(
|
|||
)
|
||||
|
||||
cc_library(
|
||||
name = "sfp",
|
||||
hdrs = ["shared.h"],
|
||||
textual_hdrs = ["sfp-inl.h"],
|
||||
name = "types",
|
||||
hdrs = ["types.h"],
|
||||
deps = [
|
||||
"//:basics",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "sfp",
|
||||
textual_hdrs = ["sfp-inl.h"],
|
||||
deps = [
|
||||
":types",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "nuq",
|
||||
hdrs = ["shared.h"],
|
||||
textual_hdrs = ["nuq-inl.h"],
|
||||
deps = [
|
||||
":sfp",
|
||||
":types",
|
||||
"//:basics",
|
||||
"@highway//:hwy",
|
||||
"@highway//hwy/contrib/sort:vqsort",
|
||||
|
|
@ -144,8 +85,10 @@ cc_library(
|
|||
deps = [
|
||||
":compress",
|
||||
":distortion",
|
||||
"//:mat",
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -153,7 +96,6 @@ cc_test(
|
|||
name = "sfp_test",
|
||||
size = "small",
|
||||
srcs = ["sfp_test.cc"],
|
||||
features = ["fully_static_link"],
|
||||
linkstatic = True,
|
||||
local_defines = ["HWY_IS_TEST"],
|
||||
# for test_suite.
|
||||
|
|
@ -174,7 +116,6 @@ cc_test(
|
|||
size = "small",
|
||||
timeout = "long",
|
||||
srcs = ["nuq_test.cc"],
|
||||
features = ["fully_static_link"],
|
||||
linkstatic = True,
|
||||
local_defines = ["HWY_IS_TEST"],
|
||||
# for test_suite.
|
||||
|
|
@ -182,7 +123,6 @@ cc_test(
|
|||
deps = [
|
||||
":distortion",
|
||||
":nuq",
|
||||
":sfp",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//:test_util",
|
||||
"@highway//:hwy",
|
||||
|
|
@ -196,21 +136,18 @@ cc_library(
|
|||
srcs = ["compress.cc"],
|
||||
hdrs = [
|
||||
"compress.h",
|
||||
"shared.h",
|
||||
"types.h",
|
||||
],
|
||||
textual_hdrs = ["compress-inl.h"],
|
||||
deps = [
|
||||
":blob_store",
|
||||
":distortion",
|
||||
":fields",
|
||||
":io",
|
||||
":nuq",
|
||||
":sfp",
|
||||
"//:allocator",
|
||||
"//:basics",
|
||||
"//:common",
|
||||
"//:mat",
|
||||
"@highway//:hwy",
|
||||
"@highway//:nanobenchmark",
|
||||
"@highway//:profiler",
|
||||
"@highway//:stats",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
|
|
@ -221,7 +158,6 @@ cc_test(
|
|||
size = "small",
|
||||
timeout = "long",
|
||||
srcs = ["compress_test.cc"],
|
||||
features = ["fully_static_link"],
|
||||
linkstatic = True,
|
||||
local_defines = ["HWY_IS_TEST"],
|
||||
# for test_suite.
|
||||
|
|
@ -245,51 +181,10 @@ cc_library(
|
|||
deps = [
|
||||
":nuq",
|
||||
":sfp",
|
||||
":types",
|
||||
"@highway//:hwy",
|
||||
"@highway//:stats",
|
||||
"@highway//:thread_pool",
|
||||
"@highway//hwy/contrib/sort:vqsort",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "compress_weights",
|
||||
srcs = ["compress_weights.cc"],
|
||||
deps = [
|
||||
":compress",
|
||||
":io",
|
||||
"//:allocator",
|
||||
"//:args",
|
||||
"//:common",
|
||||
"//:tokenizer",
|
||||
"//:weights",
|
||||
"@highway//:hwy",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "blob_compare",
|
||||
srcs = ["blob_compare.cc"],
|
||||
deps = [
|
||||
":blob_store",
|
||||
":io",
|
||||
"//:allocator",
|
||||
"//:basics",
|
||||
"//:threading",
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
"@highway//:nanobenchmark",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "migrate_weights",
|
||||
srcs = ["migrate_weights.cc"],
|
||||
deps = [
|
||||
"//:app",
|
||||
"//:args",
|
||||
"//:benchmark_helper",
|
||||
"//:gemma_lib",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
#include <cstdlib> // std::abs
|
||||
#include <vector>
|
||||
|
||||
#include "compression/shared.h"
|
||||
#include "compression/types.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/stats.h"
|
||||
|
|
|
|||
|
|
@ -1,230 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/blob_store.h"
|
||||
#include "compression/io.h" // Path
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h" // IndexRange
|
||||
#include "util/threading.h"
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
using KeySpan = hwy::Span<const hwy::uint128_t>;
|
||||
|
||||
// Returns false if any keys differ, because then blobs are not comparable.
|
||||
bool CompareKeys(const BlobReader& reader1, const BlobReader& reader2) {
|
||||
KeySpan keys1 = reader1.Keys();
|
||||
KeySpan keys2 = reader2.Keys();
|
||||
if (keys1.size() != keys2.size()) {
|
||||
fprintf(stderr, "#keys mismatch: %zu vs %zu\n", keys1.size(), keys2.size());
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < keys1.size(); ++i) {
|
||||
if (keys1[i] != keys2[i]) {
|
||||
fprintf(stderr, "key %zu mismatch: %s vs %s\n", i,
|
||||
StringFromKey(keys1[i]).c_str(), StringFromKey(keys2[i]).c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Total amount to allocate for all blobs.
|
||||
size_t TotalBytes(BlobReader& reader) {
|
||||
size_t total_bytes = 0;
|
||||
for (const hwy::uint128_t key : reader.Keys()) {
|
||||
total_bytes += reader.BlobSize(key);
|
||||
}
|
||||
return total_bytes;
|
||||
}
|
||||
|
||||
using BytePtr = hwy::AlignedFreeUniquePtr<uint8_t[]>;
|
||||
using ByteSpan = hwy::Span<uint8_t>; // Sections within BytePtr
|
||||
using BlobVec = std::vector<ByteSpan>; // in order of keys
|
||||
|
||||
// Allocates memory within the single allocation and updates `pos`.
|
||||
BlobVec ReserveMemory(BlobReader& reader, BytePtr& all_blobs, size_t& pos) {
|
||||
BlobVec blobs;
|
||||
for (const hwy::uint128_t key : reader.Keys()) {
|
||||
const size_t bytes = reader.BlobSize(key);
|
||||
blobs.push_back(ByteSpan(all_blobs.get() + pos, bytes));
|
||||
pos += bytes;
|
||||
}
|
||||
return blobs;
|
||||
}
|
||||
|
||||
// Reads one set of blobs in parallel (helpful if in disk cache).
|
||||
void ReadBlobs(BlobReader& reader, BlobVec& blobs, hwy::ThreadPool& pool) {
|
||||
HWY_ASSERT(reader.Keys().size() == blobs.size());
|
||||
for (size_t i = 0; i < blobs.size(); ++i) {
|
||||
reader.Enqueue(reader.Keys()[i], blobs[i].data(), blobs[i].size());
|
||||
}
|
||||
const BlobError err = reader.ReadAll(pool);
|
||||
if (err != 0) {
|
||||
HWY_ABORT("Parallel read failed: %d\n", err);
|
||||
}
|
||||
}
|
||||
|
||||
// Parallelizes ReadBlobs across (two) packages, if available.
|
||||
void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2, size_t total_bytes,
|
||||
BlobVec& blobs1, BlobVec& blobs2, NestedPools& pools) {
|
||||
const double t0 = hwy::platform::Now();
|
||||
fprintf(stderr, "Reading %zu GiB, %zux%zu cores: ", total_bytes >> 30,
|
||||
pools.AllPackages().NumWorkers(), pools.Pool().NumWorkers());
|
||||
pools.AllPackages().Run(0, 2, [&](size_t task, size_t pkg_idx) {
|
||||
ReadBlobs(task ? reader2 : reader1, task ? blobs2 : blobs1,
|
||||
pools.Pool(pkg_idx));
|
||||
});
|
||||
const double t1 = hwy::platform::Now();
|
||||
fprintf(stderr, "%.1f GB/s\n", total_bytes / (t1 - t0) * 1E-9);
|
||||
}
|
||||
|
||||
// Returns number of elements with a mismatch. For float and bf16 blobs, uses
|
||||
// L1 and relative error, otherwise byte-wise comparison.
|
||||
size_t BlobDifferences(const ByteSpan& data1, const ByteSpan& data2,
|
||||
const hwy::uint128_t key) {
|
||||
if (data1.size() != data2.size() || data1.size() == 0) {
|
||||
HWY_ABORT("key %s size mismatch: %zu vs %zu\n", StringFromKey(key).c_str(),
|
||||
data1.size(), data2.size());
|
||||
}
|
||||
|
||||
size_t mismatches = 0;
|
||||
char type;
|
||||
hwy::CopyBytes(&key, &type, 1);
|
||||
if (type == 'F') {
|
||||
HWY_ASSERT(data1.size() % sizeof(float) == 0);
|
||||
for (size_t j = 0; j < data1.size(); j += sizeof(float)) {
|
||||
float f1, f2;
|
||||
hwy::CopyBytes(&data1[j], &f1, sizeof(f1));
|
||||
hwy::CopyBytes(&data2[j], &f2, sizeof(f2));
|
||||
const float l1 = hwy::ScalarAbs(f1 - f2);
|
||||
const float rel = hwy::ScalarAbs(f1) == 0.0f ? 0.0f : l1 / f1;
|
||||
if (l1 > 1E-3f || rel > 1E-2f) {
|
||||
fprintf(stderr, "key %s %5zu: L1 %.5f rel %.4f\n",
|
||||
StringFromKey(key).c_str(), j, l1, rel);
|
||||
++mismatches;
|
||||
}
|
||||
}
|
||||
} else if (type == 'B') {
|
||||
for (size_t j = 0; j < data1.size(); j += sizeof(hwy::bfloat16_t)) {
|
||||
hwy::bfloat16_t b1, b2;
|
||||
hwy::CopyBytes(&data1[j], &b1, sizeof(b1));
|
||||
hwy::CopyBytes(&data2[j], &b2, sizeof(b2));
|
||||
const float f1 = hwy::ConvertScalarTo<float>(b1);
|
||||
const float f2 = hwy::ConvertScalarTo<float>(b2);
|
||||
const float l1 = hwy::ScalarAbs(f1 - f2);
|
||||
const float rel = hwy::ScalarAbs(f1) == 0.0f ? 0.0f : l1 / f1;
|
||||
if (l1 > 1E-2f || rel > 1E-1f) {
|
||||
fprintf(stderr, "key %s %5zu: L1 %.5f rel %.4f\n",
|
||||
StringFromKey(key).c_str(), j, l1, rel);
|
||||
++mismatches;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t j = 0; j < data1.size(); ++j) {
|
||||
if (data1[j] != data2[j]) {
|
||||
if (mismatches == 0) {
|
||||
fprintf(stderr, "key %s mismatch at byte %5zu\n",
|
||||
StringFromKey(key).c_str(), j);
|
||||
}
|
||||
++mismatches;
|
||||
}
|
||||
}
|
||||
}
|
||||
return mismatches;
|
||||
}
|
||||
|
||||
void CompareBlobs(const KeySpan& keys, BlobVec& blobs1, BlobVec& blobs2,
|
||||
size_t total_bytes, NestedPools& pools) {
|
||||
fprintf(stderr, "Comparing %zu blobs in parallel: ", keys.size());
|
||||
const double t0 = hwy::platform::Now();
|
||||
std::atomic<size_t> blobs_equal{};
|
||||
std::atomic<size_t> blobs_diff{};
|
||||
const IndexRangePartition ranges = StaticPartition(
|
||||
IndexRange(0, keys.size()), pools.AllPackages().NumWorkers(), 1);
|
||||
ParallelizeOneRange(
|
||||
ranges, pools.AllPackages(),
|
||||
[&](const IndexRange& range, size_t pkg_idx) {
|
||||
pools.Pool(pkg_idx).Run(
|
||||
range.begin(), range.end(), [&](size_t i, size_t /*thread*/) {
|
||||
const size_t mismatches =
|
||||
BlobDifferences(blobs1[i], blobs2[i], keys[i]);
|
||||
if (mismatches != 0) {
|
||||
fprintf(stderr, "key %s has %zu mismatches in %zu bytes!\n",
|
||||
StringFromKey(keys[i]).c_str(), mismatches,
|
||||
blobs1[i].size());
|
||||
blobs_diff.fetch_add(1);
|
||||
} else {
|
||||
blobs_equal.fetch_add(1);
|
||||
}
|
||||
});
|
||||
});
|
||||
const double t1 = hwy::platform::Now();
|
||||
fprintf(stderr, "%.1f GB/s; total blob matches=%zu, mismatches=%zu\n",
|
||||
total_bytes / (t1 - t0) * 1E-9, blobs_equal.load(),
|
||||
blobs_diff.load());
|
||||
}
|
||||
|
||||
// Compares two sbs files, including blob order.
|
||||
void ReadAndCompareBlobs(const char* path1, const char* path2) {
|
||||
// Open files.
|
||||
BlobReader reader1;
|
||||
BlobReader reader2;
|
||||
const BlobError err1 = reader1.Open(Path(path1));
|
||||
const BlobError err2 = reader2.Open(Path(path2));
|
||||
if (err1 != 0 || err2 != 0) {
|
||||
HWY_ABORT("Failed to open files: %s %s: %d %d\n", path1, path2, err1, err2);
|
||||
}
|
||||
|
||||
if (!CompareKeys(reader1, reader2)) return;
|
||||
|
||||
// Single allocation, avoid initializing the memory.
|
||||
BoundedTopology topology;
|
||||
Allocator::Init(topology);
|
||||
NestedPools pools(topology);
|
||||
const size_t total_bytes = TotalBytes(reader1) + TotalBytes(reader2);
|
||||
BytePtr all_blobs = hwy::AllocateAligned<uint8_t>(total_bytes);
|
||||
size_t pos = 0;
|
||||
BlobVec blobs1 = ReserveMemory(reader1, all_blobs, pos);
|
||||
BlobVec blobs2 = ReserveMemory(reader2, all_blobs, pos);
|
||||
|
||||
ReadBothBlobs(reader1, reader2, total_bytes, blobs1, blobs2, pools);
|
||||
|
||||
CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, pools);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc != 3) {
|
||||
HWY_ABORT("Usage: %s <sbs_path> <sbs_path>\n", argv[0]);
|
||||
}
|
||||
if (strcmp(argv[1], argv[2]) == 0) {
|
||||
HWY_ABORT("Filenames are the same, skipping comparison: %s\n", argv[1]);
|
||||
}
|
||||
gcpp::ReadAndCompareBlobs(argv[1], argv[2]);
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -1,341 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "compression/blob_store.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <cstdio>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/detect_compiler_arch.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
hwy::uint128_t MakeKey(const char* string) {
|
||||
size_t length = 0;
|
||||
for (size_t i = 0; string[i] != '\0'; ++i) {
|
||||
++length;
|
||||
}
|
||||
if (length > 16) {
|
||||
HWY_ABORT("Key %s is too long, please truncate to 16 chars.", string);
|
||||
}
|
||||
|
||||
hwy::uint128_t ret;
|
||||
hwy::ZeroBytes<sizeof(ret)>(&ret);
|
||||
hwy::CopyBytes(string, &ret, length);
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::string StringFromKey(hwy::uint128_t key) {
|
||||
std::string name(sizeof(key) + 1, '\0');
|
||||
hwy::CopyBytes(&key, name.data(), sizeof(key));
|
||||
name.resize(name.find('\0'));
|
||||
return name;
|
||||
}
|
||||
|
||||
namespace {
|
||||
void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data,
|
||||
std::vector<BlobIO>& requests) {
|
||||
// Split into chunks for load-balancing even if blob sizes vary.
|
||||
constexpr size_t kChunkSize = 4 * 1024 * 1024; // bytes
|
||||
|
||||
// Split into whole chunks and possibly one remainder.
|
||||
uint64_t pos = 0;
|
||||
if (size >= kChunkSize) {
|
||||
for (; pos <= size - kChunkSize; pos += kChunkSize) {
|
||||
requests.emplace_back(offset + pos, kChunkSize, data + pos, 0);
|
||||
}
|
||||
}
|
||||
if (pos != size) {
|
||||
requests.emplace_back(offset + pos, size - pos, data + pos, 0);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
static_assert(HWY_IS_LITTLE_ENDIAN, "Assumes little endian");
|
||||
|
||||
// On-disk representation (little-endian).
|
||||
//
|
||||
// Deliberately omits a version number because this file format is unchanging.
|
||||
// Additional data may be added only inside new blobs. Changes to the blob
|
||||
// contents or type should be handled by renaming keys.
|
||||
#pragma pack(push, 1)
|
||||
class BlobStore {
|
||||
static constexpr uint32_t kMagic = 0x0A534253; // SBS\n
|
||||
|
||||
public:
|
||||
// NOT including padding, so that we can also use ZeroFillPadding after
|
||||
// copying the header.
|
||||
static constexpr size_t HeaderSize(size_t num_blobs) {
|
||||
// 16-byte fixed fields plus per-blob: 16-byte key, 16-byte offset/size.
|
||||
return 16 + 32 * num_blobs;
|
||||
}
|
||||
|
||||
// Returns how many bytes to allocate for the header without the subsequent
|
||||
// blobs. Requires num_blobs_ to already be set, typically by reading
|
||||
// sizeof(BlobStore) bytes from disk.
|
||||
size_t PaddedHeaderSize() const {
|
||||
return hwy::RoundUpTo(HeaderSize(num_blobs_), kBlobAlign);
|
||||
}
|
||||
|
||||
// Returns aligned offset and zero-fills between that and `offset`.
|
||||
uint64_t ZeroFillPadding(uint64_t offset) {
|
||||
uint8_t* const bytes = reinterpret_cast<uint8_t*>(this);
|
||||
const uint64_t padded = hwy::RoundUpTo(offset, kBlobAlign);
|
||||
hwy::ZeroBytes(bytes + offset, padded - offset);
|
||||
return padded;
|
||||
}
|
||||
|
||||
BlobError CheckValidity(const uint64_t file_size) {
|
||||
if (magic_ != kMagic) return __LINE__;
|
||||
if (num_blobs_ == 0) return __LINE__;
|
||||
if (file_size_ != file_size) return __LINE__;
|
||||
|
||||
// Ensure blobs are back to back, and zero-pad.
|
||||
uint64_t offset = ZeroFillPadding(HeaderSize(num_blobs_));
|
||||
for (size_t i = 0; i < num_blobs_; ++i) {
|
||||
const hwy::uint128_t val = keys_[num_blobs_ + i];
|
||||
if (val.lo != offset) return __LINE__;
|
||||
offset = hwy::RoundUpTo(offset + val.hi, kBlobAlign);
|
||||
}
|
||||
|
||||
if (offset != file_size_) return __LINE__;
|
||||
|
||||
return 0; // all OK
|
||||
}
|
||||
|
||||
static BlobStorePtr Allocate(uint64_t total_size) {
|
||||
uint8_t* bytes =
|
||||
static_cast<uint8_t*>(hwy::AllocateAlignedBytes(total_size));
|
||||
if (!bytes) return BlobStorePtr();
|
||||
return BlobStorePtr(new (bytes) BlobStore(), hwy::AlignedFreer());
|
||||
}
|
||||
|
||||
static std::vector<BlobIO> PrepareWriteRequests(
|
||||
const hwy::uint128_t keys[], const hwy::Span<const uint8_t> blobs[],
|
||||
size_t num_blobs, BlobStore* bs) {
|
||||
// Sanity check and ensure the cast below is safe.
|
||||
HWY_ASSERT(num_blobs < (1ULL << 20));
|
||||
|
||||
// Allocate var-length header.
|
||||
const size_t header_size = HeaderSize(num_blobs);
|
||||
const size_t padded_header_size = hwy::RoundUpTo(header_size, kBlobAlign);
|
||||
const uint64_t padded_header_end = bs->ZeroFillPadding(header_size);
|
||||
HWY_ASSERT(padded_header_end == padded_header_size);
|
||||
|
||||
// All-zero buffer used to write padding to the file without copying the
|
||||
// input blobs.
|
||||
static uint8_t zeros[kBlobAlign] = {0};
|
||||
|
||||
// Total file size will be the header plus all padded blobs.
|
||||
uint64_t payload = 0;
|
||||
for (size_t i = 0; i < num_blobs; ++i) {
|
||||
payload += hwy::RoundUpTo(blobs[i].size(), kBlobAlign);
|
||||
}
|
||||
const size_t total_size = padded_header_size + payload;
|
||||
|
||||
// Fill header.
|
||||
bs->magic_ = kMagic;
|
||||
bs->num_blobs_ = static_cast<uint32_t>(num_blobs);
|
||||
bs->file_size_ = total_size;
|
||||
hwy::CopyBytes(keys, bs->keys_, num_blobs * sizeof(keys[0]));
|
||||
|
||||
// First IO request is for the header (not yet filled!).
|
||||
std::vector<BlobIO> requests;
|
||||
requests.reserve(1 + 2 * num_blobs);
|
||||
requests.emplace_back(/*offset=*/0, padded_header_size,
|
||||
reinterpret_cast<uint8_t*>(bs), 0);
|
||||
|
||||
// Fill second half of keys_ with offset/size and prepare IO requests.
|
||||
uint64_t offset = padded_header_end;
|
||||
for (size_t i = 0; i < num_blobs; ++i) {
|
||||
bs->keys_[num_blobs + i].lo = offset;
|
||||
bs->keys_[num_blobs + i].hi = blobs[i].size();
|
||||
|
||||
EnqueueChunkRequests(offset, blobs[i].size(),
|
||||
const_cast<uint8_t*>(blobs[i].data()), requests);
|
||||
offset += blobs[i].size();
|
||||
const size_t padded_size = hwy::RoundUpTo(blobs[i].size(), kBlobAlign);
|
||||
if (padded_size != blobs[i].size()) {
|
||||
const size_t padding = padded_size - blobs[i].size();
|
||||
HWY_ASSERT(padding <= kBlobAlign);
|
||||
requests.emplace_back(offset, padding, zeros, 0);
|
||||
offset += padding;
|
||||
}
|
||||
}
|
||||
|
||||
HWY_ASSERT(offset == total_size);
|
||||
return requests;
|
||||
}
|
||||
|
||||
bool FindKey(const hwy::uint128_t key, uint64_t& offset, size_t& size) const {
|
||||
for (size_t i = 0; i < num_blobs_; ++i) {
|
||||
if (keys_[i] == key) {
|
||||
const hwy::uint128_t val = keys_[num_blobs_ + i];
|
||||
offset = val.lo;
|
||||
size = val.hi;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
hwy::Span<const hwy::uint128_t> Keys() const {
|
||||
return hwy::Span<const hwy::uint128_t>(keys_, num_blobs_);
|
||||
}
|
||||
|
||||
private:
|
||||
uint32_t magic_;
|
||||
uint32_t num_blobs_; // never 0
|
||||
uint64_t file_size_; // must match actual size of file
|
||||
hwy::uint128_t keys_[1]; // length: 2 * num_blobs
|
||||
// Padding, then the blob identified by keys[0], then padding etc.
|
||||
};
|
||||
#pragma pack(pop)
|
||||
|
||||
BlobError BlobReader::Open(const Path& filename) {
|
||||
file_ = OpenFileOrNull(filename, "r");
|
||||
if (!file_) return __LINE__;
|
||||
|
||||
// Read first part of header to get actual size.
|
||||
BlobStore bs;
|
||||
if (!file_->Read(0, sizeof(bs), &bs)) return __LINE__;
|
||||
const size_t padded_size = bs.PaddedHeaderSize();
|
||||
HWY_ASSERT(padded_size >= sizeof(bs));
|
||||
|
||||
// Allocate full header.
|
||||
blob_store_ = BlobStore::Allocate(padded_size);
|
||||
if (!blob_store_) return __LINE__;
|
||||
|
||||
// Copy what we already read (more efficient than seek + re-read).
|
||||
hwy::CopySameSize(&bs, blob_store_.get());
|
||||
// Read the rest of the header, but not the full file.
|
||||
uint8_t* bytes = reinterpret_cast<uint8_t*>(blob_store_.get());
|
||||
if (!file_->Read(sizeof(bs), padded_size - sizeof(bs), bytes + sizeof(bs))) {
|
||||
return __LINE__;
|
||||
}
|
||||
|
||||
return blob_store_->CheckValidity(file_->FileSize());
|
||||
}
|
||||
|
||||
size_t BlobReader::BlobSize(hwy::uint128_t key) const {
|
||||
uint64_t offset;
|
||||
size_t size;
|
||||
if (!blob_store_->FindKey(key, offset, size)) return 0;
|
||||
return size;
|
||||
}
|
||||
|
||||
BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) {
|
||||
uint64_t offset;
|
||||
size_t actual_size;
|
||||
if (!blob_store_->FindKey(key, offset, actual_size)) return __LINE__;
|
||||
if (actual_size != size) {
|
||||
fprintf(stderr,
|
||||
"Mismatch between expected %d and actual %d KiB size of blob %s. "
|
||||
"Please see README.md on how to update the weights.\n",
|
||||
static_cast<int>(size >> 10), static_cast<int>(actual_size >> 10),
|
||||
StringFromKey(key).c_str());
|
||||
return __LINE__;
|
||||
}
|
||||
|
||||
EnqueueChunkRequests(offset, actual_size, reinterpret_cast<uint8_t*>(data),
|
||||
requests_);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Parallel synchronous I/O. Alternatives considered:
|
||||
// - readv is limited to 0x7FFFF000 bytes on Linux (even 64-bit). Note that
|
||||
// pread calls preadv with a single iovec.
|
||||
// - O_DIRECT seems undesirable because we do want to use the OS cache
|
||||
// between consecutive runs.
|
||||
// - memory-mapped I/O is less predictable and adds noise to measurements.
|
||||
BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
|
||||
File* pfile = file_.get(); // not owned
|
||||
const auto& requests = requests_;
|
||||
std::atomic_flag err = ATOMIC_FLAG_INIT;
|
||||
// >5x speedup from parallel reads when cached.
|
||||
pool.Run(0, requests.size(),
|
||||
[pfile, &requests, &err](uint64_t i, size_t /*thread*/) {
|
||||
if (!pfile->Read(requests[i].offset, requests[i].size,
|
||||
requests[i].data)) {
|
||||
fprintf(stderr, "Failed to read blob %zu\n",
|
||||
static_cast<size_t>(i));
|
||||
err.test_and_set();
|
||||
}
|
||||
});
|
||||
if (err.test_and_set()) return __LINE__;
|
||||
return 0;
|
||||
}
|
||||
|
||||
BlobError BlobReader::ReadOne(hwy::uint128_t key, void* data,
|
||||
size_t size) const {
|
||||
uint64_t offset;
|
||||
size_t actual_size;
|
||||
if (!blob_store_->FindKey(key, offset, actual_size)) return __LINE__;
|
||||
if (actual_size != size) {
|
||||
fprintf(stderr,
|
||||
"Mismatch between expected %d and actual %d KiB size of blob %s. "
|
||||
"Please see README.md on how to update the weights.\n",
|
||||
static_cast<int>(size >> 10), static_cast<int>(actual_size >> 10),
|
||||
StringFromKey(key).c_str());
|
||||
return __LINE__;
|
||||
}
|
||||
if (!file_->Read(offset, actual_size, data)) {
|
||||
return __LINE__;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
hwy::Span<const hwy::uint128_t> BlobReader::Keys() const {
|
||||
return blob_store_->Keys();
|
||||
}
|
||||
|
||||
BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
|
||||
HWY_ASSERT(keys_.size() == blobs_.size());
|
||||
|
||||
// Concatenate blobs in memory.
|
||||
const size_t header_size = BlobStore::HeaderSize(keys_.size());
|
||||
const size_t padded_header_size = hwy::RoundUpTo(header_size, kBlobAlign);
|
||||
const BlobStorePtr bs = BlobStore::Allocate(padded_header_size);
|
||||
const std::vector<BlobIO> requests = BlobStore::PrepareWriteRequests(
|
||||
keys_.data(), blobs_.data(), keys_.size(), bs.get());
|
||||
|
||||
// Create/replace existing file.
|
||||
std::unique_ptr<File> file = OpenFileOrNull(filename, "w+");
|
||||
if (!file) return __LINE__;
|
||||
File* pfile = file.get(); // not owned
|
||||
|
||||
std::atomic_flag err = ATOMIC_FLAG_INIT;
|
||||
pool.Run(0, requests.size(),
|
||||
[pfile, &requests, &err](uint64_t i, size_t /*thread*/) {
|
||||
if (!pfile->Write(requests[i].data, requests[i].size,
|
||||
requests[i].offset)) {
|
||||
err.test_and_set();
|
||||
}
|
||||
});
|
||||
if (err.test_and_set()) return __LINE__;
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -1,117 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // hwy::uint128_t
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Convenient way to construct a key from a string (<= 16 chars).
|
||||
hwy::uint128_t MakeKey(const char* string);
|
||||
|
||||
// Returns a string from a key.
|
||||
std::string StringFromKey(hwy::uint128_t key);
|
||||
|
||||
// Ordered list of opaque blobs (~hundreds), identified by unique opaque
|
||||
// 128-bit keys.
|
||||
class BlobStore;
|
||||
|
||||
// Incomplete type, so dtor will not be called.
|
||||
using BlobStorePtr = hwy::AlignedFreeUniquePtr<BlobStore>;
|
||||
|
||||
// 0 if successful, otherwise the line number of the failing check.
|
||||
using BlobError = int;
|
||||
|
||||
// Blob offsets on disk and memory addresses are a multiple of this, because
|
||||
// we pad the header and each blob's size. This matches CUDA alignment and the
|
||||
// maximum SVE vector size, and exceeds typical x86 cache line sizes (64 or
|
||||
// 128), which can help performance.
|
||||
static constexpr size_t kBlobAlign = 256;
|
||||
|
||||
// One I/O request, serviced by threads in a pool.
|
||||
struct BlobIO {
|
||||
BlobIO(uint64_t offset, size_t size, void* data, uint64_t padding)
|
||||
: offset(offset), size(size), data(data), padding(padding) {}
|
||||
|
||||
uint64_t offset;
|
||||
size_t size; // bytes
|
||||
void* data;
|
||||
uint64_t padding;
|
||||
};
|
||||
|
||||
class BlobReader {
|
||||
public:
|
||||
BlobReader() { requests_.reserve(500); }
|
||||
~BlobReader() = default;
|
||||
|
||||
// Opens `filename` and reads its header.
|
||||
BlobError Open(const Path& filename);
|
||||
|
||||
// Returns the size of the blob identified by `key`, or 0 if not found.
|
||||
size_t BlobSize(hwy::uint128_t key) const;
|
||||
|
||||
// Enqueues read requests if `key` is found and its size matches `size`, which
|
||||
// is in units of bytes.
|
||||
BlobError Enqueue(hwy::uint128_t key, void* data, size_t size);
|
||||
|
||||
// Reads all enqueued requests.
|
||||
BlobError ReadAll(hwy::ThreadPool& pool);
|
||||
|
||||
// Reads one blob directly.
|
||||
BlobError ReadOne(hwy::uint128_t key, void* data, size_t size) const;
|
||||
|
||||
// Returns all available blob keys.
|
||||
hwy::Span<const hwy::uint128_t> Keys() const;
|
||||
|
||||
private:
|
||||
BlobStorePtr blob_store_; // holds header, not the entire file
|
||||
std::vector<BlobIO> requests_;
|
||||
std::unique_ptr<File> file_;
|
||||
};
|
||||
|
||||
class BlobWriter {
|
||||
public:
|
||||
// `size` is in bytes.
|
||||
void Add(hwy::uint128_t key, const void* data, size_t size) {
|
||||
keys_.push_back(key);
|
||||
blobs_.emplace_back(static_cast<const uint8_t*>(data), size);
|
||||
}
|
||||
|
||||
// Stores all blobs to disk in the given order with padding for alignment.
|
||||
BlobError WriteAll(hwy::ThreadPool& pool, const Path& filename);
|
||||
|
||||
// Returns the number of blobs added.
|
||||
size_t DebugNumBlobsAdded() const { return keys_.size(); }
|
||||
|
||||
private:
|
||||
std::vector<hwy::uint128_t> keys_;
|
||||
std::vector<hwy::Span<const uint8_t>> blobs_;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_
|
||||
|
|
@ -1,86 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "compression/blob_store.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
|
||||
#include "compression/io.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
#include "hwy/tests/test_util-inl.h" // HWY_ASSERT_EQ
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
||||
#if !HWY_TEST_STANDALONE
|
||||
class BlobStoreTest : public testing::Test {};
|
||||
#endif
|
||||
|
||||
#if !HWY_OS_WIN
|
||||
TEST(BlobStoreTest, TestReadWrite) {
|
||||
static const std::array<float, 4> kOriginalData = {-1, 0, 3.14159, 2.71828};
|
||||
|
||||
// mkstemp will modify path_str so it holds a newly-created temporary file.
|
||||
char path_str[] = "/tmp/blob_store_test.sbs-XXXXXX";
|
||||
const int fd = mkstemp(path_str);
|
||||
HWY_ASSERT(fd > 0);
|
||||
|
||||
hwy::ThreadPool pool(4);
|
||||
const Path path(path_str);
|
||||
std::array<float, 4> buffer = kOriginalData;
|
||||
|
||||
const hwy::uint128_t keyA = MakeKey("0123456789abcdef");
|
||||
const hwy::uint128_t keyB = MakeKey("q");
|
||||
BlobWriter writer;
|
||||
writer.Add(keyA, "DATA", 5);
|
||||
writer.Add(keyB, buffer.data(), sizeof(buffer));
|
||||
HWY_ASSERT_EQ(writer.WriteAll(pool, path), 0);
|
||||
HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size());
|
||||
|
||||
std::fill(buffer.begin(), buffer.end(), 0);
|
||||
BlobReader reader;
|
||||
HWY_ASSERT_EQ(reader.Open(path), 0);
|
||||
HWY_ASSERT_EQ(reader.BlobSize(keyA), 5);
|
||||
HWY_ASSERT_EQ(reader.BlobSize(keyB), sizeof(buffer));
|
||||
|
||||
HWY_ASSERT_EQ(reader.Enqueue(keyB, buffer.data(), sizeof(buffer)), 0);
|
||||
HWY_ASSERT_EQ(reader.ReadAll(pool), 0);
|
||||
HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size());
|
||||
|
||||
{
|
||||
std::array<char, 5> buffer;
|
||||
HWY_ASSERT(reader.ReadOne(keyA, buffer.data(), 1) != 0);
|
||||
HWY_ASSERT_EQ(reader.ReadOne(keyA, buffer.data(), 5), 0);
|
||||
HWY_ASSERT_STRING_EQ("DATA", buffer.data());
|
||||
}
|
||||
|
||||
const hwy::Span<const hwy::uint128_t> keys = reader.Keys();
|
||||
HWY_ASSERT_EQ(keys.size(), 2);
|
||||
HWY_ASSERT_EQ(keys[0], keyA);
|
||||
HWY_ASSERT_EQ(keys[1], keyB);
|
||||
|
||||
close(fd);
|
||||
unlink(path_str);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
} // namespace gcpp
|
||||
|
||||
HWY_TEST_MAIN();
|
||||
|
|
@ -21,19 +21,20 @@
|
|||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <cmath> // lroundf, only if COMPRESS_STATS
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/blob_store.h"
|
||||
#include "compression/compress.h" // IWYU pragma: export
|
||||
#include "compression/distortion.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
#if COMPRESS_STATS
|
||||
#include <cmath> // lroundf
|
||||
#endif
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_INL_H_
|
||||
|
||||
// Include guard for (potentially) SIMD code.
|
||||
|
|
@ -64,8 +65,8 @@ static constexpr bool kIsTest = false;
|
|||
template <typename T> // primary, must specialize
|
||||
struct CompressTraits {};
|
||||
|
||||
// Used by backprop/, where weights are currently f32; also MatMul for f32
|
||||
// weights or activations, if native `ReorderWidenMulAccumulate` is available.
|
||||
// Used by MatMul for f32 weights or activations, if native
|
||||
// `ReorderWidenMulAccumulate` is available.
|
||||
template <>
|
||||
struct CompressTraits<float> {
|
||||
using Packed = float;
|
||||
|
|
@ -379,7 +380,7 @@ struct CompressTraits<SfpStream> {
|
|||
using Packed = SfpStream;
|
||||
|
||||
// Callers are responsible for scaling `raw` such that its magnitudes do not
|
||||
// exceed `SfpStream::kMax`. See CompressedArray::scale().
|
||||
// exceed `SfpStream::kMax`. See CompressedArray::Scale().
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw,
|
||||
size_t num, CompressPerThread& tls,
|
||||
|
|
@ -387,7 +388,7 @@ struct CompressTraits<SfpStream> {
|
|||
const size_t packed_ofs) {
|
||||
SfpCodec::Enc(df, raw, num, packed.ptr + packed_ofs);
|
||||
|
||||
if (COMPRESS_STATS) {
|
||||
if constexpr (COMPRESS_STATS) {
|
||||
const hn::Repartition<BF16, DF> dbf;
|
||||
auto distorted =
|
||||
hwy::AllocateAligned<BF16>(hwy::RoundUpTo(num, hn::Lanes(dbf)));
|
||||
|
|
@ -431,9 +432,10 @@ struct CompressTraits<NuqStream> {
|
|||
size_t num, CompressPerThread& tls,
|
||||
const PackedSpan<Packed>& packed,
|
||||
const size_t packed_ofs) {
|
||||
NuqCodec::Enc(df, raw, num, tls.buf, packed, packed_ofs);
|
||||
if (!tls.buf) tls.buf = std::make_unique<NuqStream::ClusterBuf>();
|
||||
NuqCodec::Enc(df, raw, num, *tls.buf, packed, packed_ofs);
|
||||
|
||||
if (COMPRESS_STATS) {
|
||||
if constexpr (COMPRESS_STATS) {
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
tls.stats.NotifyIn(static_cast<int>(lroundf(raw[i] * 100.0f + 500.0f)));
|
||||
}
|
||||
|
|
@ -477,7 +479,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
|
|||
const size_t packed_ofs, hwy::ThreadPool& pool) {
|
||||
packed.BoundsCheck(packed_ofs, num);
|
||||
work.tls.resize(pool.NumWorkers());
|
||||
if (COMPRESS_STATS) {
|
||||
if constexpr (COMPRESS_STATS) {
|
||||
for (auto& tls : work.tls) {
|
||||
tls.stats.Reset();
|
||||
}
|
||||
|
|
@ -486,7 +488,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
|
|||
const bool want_bench = COMPRESS_STATS || !kIsTest;
|
||||
const double t0 = want_bench ? hwy::platform::Now() : 0.0;
|
||||
|
||||
using Traits = CompressTraits<Packed>;
|
||||
using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
|
||||
constexpr size_t kBatch = 8192;
|
||||
const size_t num_batches = hwy::DivCeil(num, kBatch);
|
||||
pool.Run(0, num_batches,
|
||||
|
|
@ -507,7 +509,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
|
|||
fprintf(stderr, "Compress %.1f MB/s\n", mbps);
|
||||
}
|
||||
|
||||
if (COMPRESS_STATS) {
|
||||
if constexpr (COMPRESS_STATS) {
|
||||
for (size_t i = 1; i < work.tls.size(); ++i) {
|
||||
work.tls[0].stats.Assimilate(work.tls[i].stats);
|
||||
}
|
||||
|
|
@ -515,26 +517,25 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
|
|||
}
|
||||
}
|
||||
|
||||
// Adapter that compresses into `MatStorageT`. `raw` must already be scaled
|
||||
// to fit the value range, if `Packed` is `SfpStream`.
|
||||
// Same as above, but without parallelization nor benchmarking.
|
||||
template <typename Packed>
|
||||
HWY_INLINE void CompressScaled(const float* HWY_RESTRICT raw, size_t num,
|
||||
CompressWorkingSet& work,
|
||||
MatStorageT<Packed>& compressed,
|
||||
hwy::ThreadPool& pool) {
|
||||
Compress(raw, num, work,
|
||||
MakeSpan(compressed.data(), compressed.NumElements()),
|
||||
/*packed_ofs=*/0, pool);
|
||||
HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
|
||||
CompressPerThread& tls,
|
||||
const PackedSpan<Packed>& packed,
|
||||
const size_t packed_ofs) {
|
||||
packed.BoundsCheck(packed_ofs, num);
|
||||
using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
|
||||
const hn::ScalableTag<float> df;
|
||||
Traits::Compress(df, raw, num, tls, packed, packed_ofs);
|
||||
}
|
||||
|
||||
// Stores two f32 vectors to f32 or bf16; avoids duplicating RMSNorm and
|
||||
// RMSNormInplace for the two output types.
|
||||
// Stores two f32 vectors to f32 or bf16.
|
||||
template <class DF, typename Packed, HWY_IF_F32_D(DF), class VF = hn::Vec<DF>>
|
||||
void Compress2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
|
||||
const size_t packed_ofs) {
|
||||
static_assert(hwy::IsSameEither<Packed, float, BF16>());
|
||||
packed.BoundsCheck(packed_ofs, 2 * hn::Lanes(df));
|
||||
using Traits = CompressTraits<Packed>;
|
||||
using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
|
||||
Traits::Store2(df, raw0, raw1, packed, packed_ofs);
|
||||
}
|
||||
|
||||
|
|
@ -566,7 +567,7 @@ HWY_INLINE void Decompress2(DRaw d, const PackedSpan<Packed>& packed,
|
|||
// Decompresses from any type of `packed`, starting at (any) `packed_ofs`, to
|
||||
// (any) `num` elements in `raw`, then appends `[0, hn::Lanes(d))` zeroes as
|
||||
// required to round `num` up to one vector, if it is not already. The caller is
|
||||
// responsible for scaling `raw` to the original range because `EmbedToken`
|
||||
// responsible for scaling `raw` to the original range because `EmbedMMToken`
|
||||
// also wants to scale the decompressed elements.
|
||||
// `TRaw` can be `float/BF16`, or `double` if `Packed` is `float`.
|
||||
template <class DRaw, typename Packed, typename TRaw = hn::TFromD<DRaw>>
|
||||
|
|
@ -708,51 +709,6 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan<const VT> v,
|
|||
comp3);
|
||||
}
|
||||
|
||||
// Functor called for each tensor, which compresses and stores them along with
|
||||
// their scaling factors to BlobStore.
|
||||
class Compressor {
|
||||
public:
|
||||
explicit Compressor(hwy::ThreadPool& pool) : writer_(pool) {}
|
||||
|
||||
template <typename Packed>
|
||||
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name,
|
||||
const float* HWY_RESTRICT weights) {
|
||||
size_t num_weights = compressed->NumElements();
|
||||
if (num_weights == 0 || weights == nullptr || compressed->Ptr() == nullptr)
|
||||
return;
|
||||
size_t num_compressed = compressed->NumElements();
|
||||
PackedSpan<Packed> packed = MakeSpan(compressed->data(), num_compressed);
|
||||
fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name,
|
||||
num_weights / (1000 * 1000));
|
||||
Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0,
|
||||
writer_.pool());
|
||||
writer_(compressed, decorated_name);
|
||||
}
|
||||
|
||||
void AddTokenizer(const std::string& tokenizer) {
|
||||
writer_.AddTokenizer(tokenizer);
|
||||
}
|
||||
|
||||
void AddScales(const float* scales, size_t len) {
|
||||
writer_.AddScales(scales, len);
|
||||
}
|
||||
|
||||
// Writes all blobs to disk in the given order. The config is optional and
|
||||
// if given, it is written to the file, along with the TOC, making it
|
||||
// single-file format. Otherwise, the file is written in the multi-file format
|
||||
// without a TOC.
|
||||
BlobError WriteAll(const Path& blob_filename, const ModelConfig* config) {
|
||||
return writer_.WriteAll(blob_filename, config);
|
||||
}
|
||||
|
||||
// Returns the number of blobs added.
|
||||
size_t DebugNumBlobsAdded() const { return writer_.DebugNumBlobsAdded(); }
|
||||
|
||||
private:
|
||||
CompressWorkingSet work_;
|
||||
WriteToBlobStore writer_;
|
||||
};
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -15,8 +15,34 @@
|
|||
|
||||
#include "compression/compress.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "util/mat.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
MatPtr::~MatPtr() {}
|
||||
float ScaleWeights(float* HWY_RESTRICT raw, size_t num) {
|
||||
PROFILER_FUNC;
|
||||
|
||||
float maxabs = 0.0;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
maxabs = HWY_MAX(maxabs, hwy::ScalarAbs(raw[i]));
|
||||
}
|
||||
if (maxabs <= SfpStream::kMax) {
|
||||
return 1.0f;
|
||||
}
|
||||
const float scale = maxabs / SfpStream::kMax;
|
||||
const float inv_scale = static_cast<float>(1.0 / static_cast<double>(scale));
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
// Clamp because kMax may still be exceeded.
|
||||
const float magn =
|
||||
HWY_MIN(SfpStream::kMax, hwy::ScalarAbs(raw[i] * inv_scale));
|
||||
raw[i] = hwy::ScalarCopySign(magn, raw[i]);
|
||||
}
|
||||
return scale;
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -17,31 +17,19 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
|
||||
|
||||
#include "hwy/base.h"
|
||||
#define COMPRESS_STATS 0
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#if COMPRESS_STATS
|
||||
#include <stdio.h>
|
||||
#endif
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "compression/blob_store.h"
|
||||
#include "compression/fields.h"
|
||||
#include "compression/io.h"
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/tensor_index.h"
|
||||
#include "util/basics.h"
|
||||
// IWYU pragma: end_exports
|
||||
#include "gemma/configs.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/per_target.h"
|
||||
#include "compression/types.h" // IWYU pragma: export
|
||||
#if COMPRESS_STATS
|
||||
#include "compression/distortion.h"
|
||||
#include "hwy/stats.h"
|
||||
|
|
@ -49,388 +37,6 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
// Base class for rank-1 or 2 tensors (vector or matrix).
|
||||
// Supports both dynamic and compile-time sizing.
|
||||
// Holds metadata and a non-owning pointer to the data, owned by the derived
|
||||
// MatStorageT class.
|
||||
// This class also provides easy conversion from/to a table of contents for a
|
||||
// BlobStore file, and a templated (compile-time) accessor for a 2-d array of
|
||||
// fixed inner dimension and type.
|
||||
// It is designed to be put in a vector, and has default copy and operator=, so
|
||||
// it is easy to read/write a blob_store file.
|
||||
class MatPtr : public IFields {
|
||||
public:
|
||||
// Full constructor for dynamic sizing.
|
||||
MatPtr(const std::string& name, Type type, size_t element_size, size_t rows,
|
||||
size_t cols)
|
||||
: name_(name),
|
||||
type_(type),
|
||||
element_size_(element_size),
|
||||
num_elements_(rows * cols),
|
||||
rows_(rows),
|
||||
cols_(cols),
|
||||
ptr_(nullptr) {
|
||||
stride_ = cols;
|
||||
}
|
||||
// Default is to leave all fields default-initialized.
|
||||
MatPtr() = default;
|
||||
virtual ~MatPtr();
|
||||
|
||||
// Compatibility interface for CompressedArray.
|
||||
// TODO: remove.
|
||||
template <typename T>
|
||||
T* data() {
|
||||
return HWY_RCAST_ALIGNED(T*, ptr_);
|
||||
}
|
||||
template <typename T>
|
||||
const T* data() const {
|
||||
return HWY_RCAST_ALIGNED(const T*, ptr_);
|
||||
}
|
||||
|
||||
const void* Ptr() const { return ptr_; }
|
||||
void* Ptr() { return ptr_; }
|
||||
// Sets the pointer from another MatPtr.
|
||||
void SetPtr(const MatPtr& other) { ptr_ = other.ptr_; }
|
||||
|
||||
// Copying allowed as the metadata is small.
|
||||
MatPtr(const MatPtr& other) = default;
|
||||
MatPtr& operator=(const MatPtr& other) = default;
|
||||
|
||||
// Returns the name of the blob.
|
||||
const char* Name() const override { return name_.c_str(); }
|
||||
void SetName(const std::string& name) { name_ = name; }
|
||||
|
||||
// Returns the type of the blob.
|
||||
Type GetType() const { return type_; }
|
||||
|
||||
// Returns the size of each element in bytes.
|
||||
size_t ElementSize() const { return element_size_; }
|
||||
|
||||
// Returns the number of elements in the array.
|
||||
size_t NumElements() const { return num_elements_; }
|
||||
|
||||
// Returns the number of bytes in the array.
|
||||
size_t SizeBytes() const {
|
||||
if (this->GetType() == TypeEnum<NuqStream>()) {
|
||||
return NuqStream::PackedEnd(num_elements_);
|
||||
}
|
||||
return num_elements_ * element_size_;
|
||||
}
|
||||
|
||||
// Returns the number of rows in the 2-d array (outer dimension).
|
||||
size_t Rows() const { return rows_; }
|
||||
|
||||
// Returns the number of columns in the 2-d array (inner dimension).
|
||||
size_t Cols() const { return cols_; }
|
||||
|
||||
Extents2D Extents() const { return Extents2D(rows_, cols_); }
|
||||
|
||||
// Currently same as cols, but may differ in the future. This is the offset by
|
||||
// which to advance pointers to the next row.
|
||||
size_t Stride() const { return stride_; }
|
||||
|
||||
// Decoded elements should be multiplied by this to restore their original
|
||||
// range. This is required because SfpStream can only encode a limited range
|
||||
// of magnitudes.
|
||||
float scale() const { return scale_; }
|
||||
void set_scale(float scale) { scale_ = scale; }
|
||||
|
||||
std::string LayerName(int layer) const {
|
||||
std::string name = name_ + std::to_string(layer);
|
||||
HWY_ASSERT(name.size() <= sizeof(hwy::uint128_t));
|
||||
return name;
|
||||
}
|
||||
|
||||
// Sets all data to zero.
|
||||
void ZeroInit() {
|
||||
if (ptr_ == nullptr)
|
||||
HWY_ABORT("ptr_ is null on tensor %s\n", name_.c_str());
|
||||
hwy::ZeroBytes(ptr_, SizeBytes());
|
||||
}
|
||||
|
||||
void VisitFields(IFieldsVisitor& visitor) override {
|
||||
visitor(name_);
|
||||
visitor(type_);
|
||||
visitor(element_size_);
|
||||
visitor(num_elements_);
|
||||
visitor(rows_);
|
||||
visitor(cols_);
|
||||
visitor(scale_);
|
||||
visitor(stride_);
|
||||
}
|
||||
|
||||
// Calls func on the upcasted type. Since MatPtr by design is not templated,
|
||||
// here we provide a way to get to the derived type, provided that `Type()`
|
||||
// is one of the strings returned by `TypeName()`.
|
||||
template <class FuncT, typename... TArgs>
|
||||
decltype(auto) CallUpcasted(FuncT& func, TArgs&&... args);
|
||||
|
||||
protected:
|
||||
// Arbitrary name for the array of preferably <= 16 characters.
|
||||
std::string name_;
|
||||
// Should be the result of TypeEnum<T> for CallUpcasted() to work.
|
||||
Type type_;
|
||||
// sizeof(T)
|
||||
uint32_t element_size_ = 0;
|
||||
// Number of elements in the array.
|
||||
uint32_t num_elements_ = 0; // In element_size units.
|
||||
// Number of rows in the 2-d array (outer dimension).
|
||||
uint32_t rows_ = 0;
|
||||
// Number of columns in the 2-d array (inner dimension).
|
||||
uint32_t cols_ = 0;
|
||||
// Scaling to apply to each element.
|
||||
float scale_ = 1.0f;
|
||||
// Aligned data array. This is always a borrowed pointer. It should never be
|
||||
// freed. The underlying memory is owned by a subclass or some external class
|
||||
// and must outlive this object.
|
||||
void* ptr_ = nullptr;
|
||||
|
||||
uint32_t stride_;
|
||||
};
|
||||
|
||||
// MatPtrT adds a single template argument to MatPtr for an explicit type.
|
||||
// Use this class as a function argument where the type needs to be known.
|
||||
// Use MatPtr where the type does not need to be known.
|
||||
template <typename MatT>
|
||||
class MatPtrT : public MatPtr {
|
||||
public:
|
||||
// Full constructor for dynamic sizing.
|
||||
MatPtrT(const std::string& name, size_t rows, size_t cols)
|
||||
: MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), rows, cols) {}
|
||||
// Construction from TensorIndex entry to remove duplication of sizes.
|
||||
MatPtrT(const std::string& name, const TensorIndex& tensor_index)
|
||||
: MatPtrT<MatT>(name, tensor_index.FindName(name)) {}
|
||||
MatPtrT(const std::string& name, const TensorInfo* tensor)
|
||||
: MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), 0, 0) {
|
||||
if (tensor == nullptr) {
|
||||
cols_ = 0;
|
||||
rows_ = 0;
|
||||
} else {
|
||||
cols_ = tensor->shape.back();
|
||||
rows_ = 1;
|
||||
if (tensor->cols_take_extra_dims) {
|
||||
// The columns eat the extra dimensions.
|
||||
rows_ = tensor->shape[0];
|
||||
for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
|
||||
cols_ *= tensor->shape[i];
|
||||
}
|
||||
} else {
|
||||
// The rows eat the extra dimensions.
|
||||
for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
|
||||
rows_ *= tensor->shape[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
stride_ = cols_;
|
||||
num_elements_ = rows_ * cols_;
|
||||
}
|
||||
|
||||
// Copying allowed as the metadata is small.
|
||||
MatPtrT(const MatPtr& other) : MatPtr(other) {}
|
||||
MatPtrT& operator=(const MatPtr& other) {
|
||||
MatPtr::operator=(other);
|
||||
return *this;
|
||||
}
|
||||
MatPtrT(const MatPtrT& other) = default;
|
||||
MatPtrT& operator=(const MatPtrT& other) = default;
|
||||
|
||||
std::string CacheName(int layer = -1, char separator = ' ',
|
||||
int index = -1) const {
|
||||
// Already used/retired: s, S, n, 1
|
||||
const char prefix = hwy::IsSame<MatT, float>() ? 'F'
|
||||
: hwy::IsSame<MatT, BF16>() ? 'B'
|
||||
: hwy::IsSame<MatT, SfpStream>() ? '$'
|
||||
: hwy::IsSame<MatT, NuqStream>() ? '2'
|
||||
: '?';
|
||||
std::string name = std::string(1, prefix) + name_;
|
||||
if (layer >= 0 || index >= 0) {
|
||||
name += '_';
|
||||
if (layer >= 0) name += std::to_string(layer);
|
||||
if (index >= 0) {
|
||||
name += separator + std::to_string(index);
|
||||
}
|
||||
}
|
||||
return name;
|
||||
}
|
||||
|
||||
// Sets the number of elements in the array. For use when the number of
|
||||
// elements is != rows * cols ONLY.
|
||||
void SetNumElements(size_t num_elements) {
|
||||
num_elements_ = CompressedArrayElements<MatT>(num_elements);
|
||||
}
|
||||
|
||||
// 2-d Accessor for a specific type but with a dynamic inner dimension.
|
||||
template <typename T = MatT>
|
||||
const T& At(size_t row, size_t col) const {
|
||||
size_t index = row * cols_ + col;
|
||||
HWY_DASSERT(index < num_elements_);
|
||||
return HWY_RCAST_ALIGNED(const T*, ptr_)[index];
|
||||
}
|
||||
|
||||
// 1-d Accessor for a specific type.
|
||||
// TODO: replace this with a Foreach(), or at least a ForEachRow().
|
||||
const MatT& At(size_t index) const {
|
||||
HWY_DASSERT(index < num_elements_);
|
||||
return HWY_RCAST_ALIGNED(const MatT*, ptr_)[index];
|
||||
}
|
||||
MatT& At(size_t index) { return HWY_RCAST_ALIGNED(MatT*, ptr_)[index]; }
|
||||
|
||||
// Compatibility interface for CompressedArray.
|
||||
// TODO: remove
|
||||
template <typename T = MatT>
|
||||
T* data() {
|
||||
return HWY_RCAST_ALIGNED(T*, ptr_);
|
||||
}
|
||||
template <typename T = MatT>
|
||||
const T* data() const {
|
||||
return HWY_RCAST_ALIGNED(const T*, ptr_);
|
||||
}
|
||||
// The const accessor data_scale1() asserts (!) that the scale is 1.0f, so
|
||||
// calling it means "I am sure the scale is 1 and therefore ignore the scale".
|
||||
// A scale of 0 indicates that the scale has likely never been set, so is
|
||||
// "implicitly 1".
|
||||
const MatT* data_scale1() const {
|
||||
HWY_ASSERT(scale() == 1.f);
|
||||
return HWY_RCAST_ALIGNED(const MatT*, ptr_);
|
||||
}
|
||||
};
|
||||
|
||||
template <class FuncT, typename... TArgs>
|
||||
decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) {
|
||||
if (type_ == TypeEnum<float>()) {
|
||||
return func(dynamic_cast<MatPtrT<float>*>(this),
|
||||
std::forward<TArgs>(args)...);
|
||||
} else if (type_ == TypeEnum<BF16>()) {
|
||||
return func(dynamic_cast<MatPtrT<BF16>*>(this),
|
||||
std::forward<TArgs>(args)...);
|
||||
} else if (type_ == TypeEnum<SfpStream>()) {
|
||||
return func(dynamic_cast<MatPtrT<SfpStream>*>(this),
|
||||
std::forward<TArgs>(args)...);
|
||||
} else if (type_ == TypeEnum<NuqStream>()) {
|
||||
return func(dynamic_cast<MatPtrT<NuqStream>*>(this),
|
||||
std::forward<TArgs>(args)...);
|
||||
} else {
|
||||
HWY_ABORT("Type %d unknown.", type_);
|
||||
}
|
||||
}
|
||||
|
||||
// MatStorageT adds the actual data storage to MatPtrT.
|
||||
// TODO: use Extents2D instead of rows and cols.
|
||||
template <typename MatT>
|
||||
class MatStorageT : public MatPtrT<MatT> {
|
||||
public:
|
||||
// Full constructor for dynamic sizing.
|
||||
MatStorageT(const std::string& name, size_t rows, size_t cols)
|
||||
: MatPtrT<MatT>(name, rows, cols) {
|
||||
Allocate();
|
||||
}
|
||||
// Can copy the metadata, from a MatPtr, and allocate later.
|
||||
MatStorageT(const MatPtr& other) : MatPtrT<MatT>(other) {}
|
||||
~MatStorageT() = default;
|
||||
|
||||
// Move-only because this contains a unique_ptr.
|
||||
MatStorageT(const MatStorageT& other) = delete;
|
||||
MatStorageT& operator=(const MatStorageT& other) = delete;
|
||||
MatStorageT(MatStorageT&& other) = default;
|
||||
MatStorageT& operator=(MatStorageT&& other) = default;
|
||||
|
||||
// Allocate the memory and copy the pointer to the MatPtr.
|
||||
// num_elements is in elements. In the default (zero) case, it is computed
|
||||
// from the current num_elements_ which was set by the constructor from the
|
||||
// rows and cols.
|
||||
void Allocate(size_t num_elements = 0) {
|
||||
if (num_elements == 0) {
|
||||
num_elements = hwy::DivCeil(this->SizeBytes(), sizeof(MatT));
|
||||
} else {
|
||||
this->num_elements_ = num_elements;
|
||||
}
|
||||
// Pad to allow overrunning the last row by 2 BF16 vectors, hence at most
|
||||
// `2 * VectorBytes / sizeof(BF16)` elements of MatT.
|
||||
const size_t padding = hwy::VectorBytes();
|
||||
data_ = Allocator::Alloc<MatT>(num_elements + padding);
|
||||
hwy::ZeroBytes(&data_[num_elements], padding * sizeof(MatT));
|
||||
this->ptr_ = data_.get();
|
||||
}
|
||||
|
||||
// Zeros the content.
|
||||
void ZeroInit() {
|
||||
HWY_ASSERT(data_ != nullptr);
|
||||
hwy::ZeroBytes(data_.get(), this->SizeBytes());
|
||||
}
|
||||
|
||||
private:
|
||||
AlignedPtr<MatT> data_;
|
||||
};
|
||||
|
||||
// MatStorage allows heterogeneous tensors to be stored in a single vector.
|
||||
using MatStorage = MatStorageT<hwy::uint128_t>;
|
||||
|
||||
// Table of contents for a blob store file. Full metadata, but not actual data.
|
||||
class BlobToc {
|
||||
public:
|
||||
BlobToc() = default;
|
||||
|
||||
// Loads the table of contents from the given reader.
|
||||
BlobError LoadToc(BlobReader& reader) {
|
||||
hwy::uint128_t toc_key = MakeKey(kTocName);
|
||||
size_t toc_size = reader.BlobSize(toc_key);
|
||||
if (toc_size != 0) {
|
||||
std::vector<uint32_t> toc(toc_size / sizeof(uint32_t));
|
||||
BlobError err = reader.ReadOne(toc_key, toc.data(), toc_size);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "Failed to read toc (error %d)\n", err);
|
||||
return err;
|
||||
}
|
||||
size_t consumed = 0;
|
||||
size_t prev_consumed = static_cast<size_t>(-1);
|
||||
while (consumed < toc.size() && prev_consumed != consumed) {
|
||||
MatPtr blob;
|
||||
const IFields::ReadResult result =
|
||||
blob.Read(hwy::Span<const uint32_t>(toc), consumed);
|
||||
prev_consumed = consumed;
|
||||
consumed = result.pos;
|
||||
if (blob.NumElements() > 0) {
|
||||
AddToToc(blob);
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool Empty() const { return toc_map_.empty(); }
|
||||
|
||||
// Returns true if the table of contents contains the given name.
|
||||
bool Contains(const std::string& name) const {
|
||||
return toc_map_.find(name) != toc_map_.end();
|
||||
}
|
||||
|
||||
// Returns the blob with the given name, or nullptr if not found.
|
||||
const MatPtr* Get(const std::string& name) const {
|
||||
auto it = toc_map_.find(name);
|
||||
if (it == toc_map_.end()) return nullptr;
|
||||
return &toc_[it->second];
|
||||
}
|
||||
// The name of the toc in the blob store file.
|
||||
static constexpr char kTocName[] = "toc";
|
||||
|
||||
// The name of the config in the blob store file.
|
||||
static constexpr char kConfigName[] = "config";
|
||||
|
||||
// The name of the tokenizer in the blob store file.
|
||||
static constexpr char kTokenizerName[] = "tokenizer";
|
||||
|
||||
private:
|
||||
// Adds the blob to the table of contents.
|
||||
void AddToToc(const MatPtr& blob) {
|
||||
HWY_ASSERT(!Contains(blob.Name()));
|
||||
toc_map_[blob.Name()] = toc_.size();
|
||||
toc_.push_back(blob);
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, size_t> toc_map_;
|
||||
std::vector<MatPtr> toc_;
|
||||
};
|
||||
|
||||
#if COMPRESS_STATS
|
||||
class CompressStats {
|
||||
public:
|
||||
|
|
@ -489,7 +95,8 @@ struct CompressStats {
|
|||
#endif // COMPRESS_STATS
|
||||
|
||||
struct CompressPerThread {
|
||||
NuqStream::ClusterBuf buf;
|
||||
// Allocated the first time NUQ is used.
|
||||
std::unique_ptr<NuqStream::ClusterBuf> buf;
|
||||
CompressStats stats;
|
||||
};
|
||||
|
||||
|
|
@ -497,196 +104,11 @@ struct CompressWorkingSet {
|
|||
std::vector<CompressPerThread> tls;
|
||||
};
|
||||
|
||||
// Class to collect and write a set of tensors to a blob store file.
|
||||
class WriteToBlobStore {
|
||||
public:
|
||||
explicit WriteToBlobStore(hwy::ThreadPool& pool) : pool_(pool) {}
|
||||
|
||||
template <typename Packed>
|
||||
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name) {
|
||||
if (compressed->Ptr() == nullptr) return;
|
||||
writer_.Add(MakeKey(decorated_name), compressed->Ptr(),
|
||||
compressed->SizeBytes());
|
||||
MatPtr renamed_tensor(*compressed);
|
||||
renamed_tensor.SetName(decorated_name);
|
||||
renamed_tensor.AppendTo(toc_);
|
||||
}
|
||||
|
||||
void AddTokenizer(const std::string& tokenizer) {
|
||||
writer_.Add(MakeKey(BlobToc::kTokenizerName), tokenizer.data(),
|
||||
tokenizer.size() * sizeof(tokenizer[0]));
|
||||
}
|
||||
|
||||
void AddScales(const float* scales, size_t len) {
|
||||
if (len) {
|
||||
MatPtrT<float> scales_ptr("scales", 0, 1);
|
||||
writer_.Add(MakeKey(scales_ptr.CacheName().c_str()), scales,
|
||||
len * sizeof(scales[0]));
|
||||
}
|
||||
}
|
||||
|
||||
// Writes all blobs to disk in the given order. The config is optional and
|
||||
// if given, it is written to the file, along with the TOC, making it
|
||||
// single-file format. Otherwise, the file is written in the multi-file format
|
||||
// without a TOC.
|
||||
BlobError WriteAll(const Path& blob_filename, const ModelConfig* config) {
|
||||
if (config) {
|
||||
writer_.Add(MakeKey(BlobToc::kTocName), toc_.data(),
|
||||
toc_.size() * sizeof(toc_[0]));
|
||||
config_buffer_ = config->Write();
|
||||
writer_.Add(MakeKey(BlobToc::kConfigName), config_buffer_.data(),
|
||||
config_buffer_.size() * sizeof(config_buffer_[0]));
|
||||
}
|
||||
const BlobError err = writer_.WriteAll(pool_, blob_filename);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
|
||||
blob_filename.path.c_str(), err);
|
||||
}
|
||||
return err;
|
||||
}
|
||||
|
||||
// Returns the number of blobs added.
|
||||
size_t DebugNumBlobsAdded() const { return writer_.DebugNumBlobsAdded(); }
|
||||
|
||||
hwy::ThreadPool& pool() { return pool_; }
|
||||
|
||||
protected:
|
||||
hwy::ThreadPool& pool_;
|
||||
|
||||
private:
|
||||
std::vector<uint32_t> toc_;
|
||||
BlobWriter writer_;
|
||||
std::vector<uint32_t> config_buffer_;
|
||||
};
|
||||
|
||||
// Functor called for each tensor, which loads them and their scaling factors
|
||||
// from BlobStore.
|
||||
class ReadFromBlobStore {
|
||||
public:
|
||||
explicit ReadFromBlobStore(const Path& blob_filename) {
|
||||
err_ = reader_.Open(blob_filename);
|
||||
if (HWY_UNLIKELY(err_ != 0)) {
|
||||
fprintf(stderr, "Error %d opening BlobStore %s.\n", err_,
|
||||
blob_filename.path.c_str());
|
||||
return; // avoid overwriting err_ to ensure ReadAll will fail.
|
||||
}
|
||||
err_ = file_toc_.LoadToc(reader_);
|
||||
if (HWY_UNLIKELY(err_ != 0)) {
|
||||
fprintf(stderr, "Found a TOC, but failed to load it (code %d)\n", err_);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true if there is a TOC.
|
||||
bool HaveToc() const { return !file_toc_.Empty(); }
|
||||
|
||||
// Reads the config from the blob store file.
|
||||
BlobError LoadConfig(ModelConfig& config) {
|
||||
hwy::uint128_t config_key = MakeKey(BlobToc::kConfigName);
|
||||
size_t config_size = reader_.BlobSize(config_key);
|
||||
if (config_size == 0) return __LINE__;
|
||||
std::vector<uint32_t> config_buffer(config_size / sizeof(uint32_t));
|
||||
BlobError err =
|
||||
reader_.ReadOne(config_key, config_buffer.data(), config_size);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "Failed to read config (error %d)\n", err);
|
||||
return err;
|
||||
}
|
||||
config.Read(hwy::Span<const uint32_t>(config_buffer), 0);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Reads the tokenizer from the blob store file.
|
||||
BlobError LoadTokenizer(std::string& tokenizer) {
|
||||
hwy::uint128_t key = MakeKey(BlobToc::kTokenizerName);
|
||||
size_t tokenizer_size = reader_.BlobSize(key);
|
||||
if (tokenizer_size == 0) return __LINE__;
|
||||
tokenizer.resize(tokenizer_size);
|
||||
;
|
||||
BlobError err = reader_.ReadOne(key, tokenizer.data(), tokenizer_size);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "Failed to read tokenizer (error %d)\n", err);
|
||||
return err;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Called for each tensor, enqueues read requests.
|
||||
void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
if (file_toc_.Empty() || file_toc_.Contains(name)) {
|
||||
model_toc_.push_back(tensors[0]);
|
||||
file_keys_.push_back(name);
|
||||
}
|
||||
}
|
||||
|
||||
BlobError LoadScales(float* scales, size_t len) {
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
scales[i] = 1.0f;
|
||||
}
|
||||
MatPtrT<float> scales_ptr("scales", 0, 1);
|
||||
auto key = MakeKey(scales_ptr.CacheName().c_str());
|
||||
if (reader_.BlobSize(key) == 0) return 0;
|
||||
return reader_.Enqueue(key, scales, len * sizeof(scales[0]));
|
||||
}
|
||||
|
||||
// Returns whether all tensors are successfully loaded from cache.
|
||||
BlobError ReadAll(hwy::ThreadPool& pool,
|
||||
std::vector<MatStorage>& model_memory) {
|
||||
// reader_ invalid or any Enqueue failed
|
||||
if (err_ != 0) return err_;
|
||||
// Setup the model_memory.
|
||||
for (size_t b = 0; b < model_toc_.size(); ++b) {
|
||||
const std::string& file_key = file_keys_[b];
|
||||
MatPtr* blob = model_toc_[b];
|
||||
if (!file_toc_.Empty()) {
|
||||
const MatPtr* toc_blob = file_toc_.Get(file_key);
|
||||
if (toc_blob == nullptr) {
|
||||
fprintf(stderr, "Blob %s not found in TOC\n", file_key.c_str());
|
||||
return __LINE__;
|
||||
}
|
||||
if (toc_blob->Rows() != blob->Rows() ||
|
||||
toc_blob->Cols() != blob->Cols()) {
|
||||
fprintf(stderr, "Blob %s has size mismatch TOC\n", file_key.c_str());
|
||||
return __LINE__;
|
||||
}
|
||||
std::string name = blob->Name();
|
||||
*blob = *toc_blob;
|
||||
blob->SetName(name);
|
||||
}
|
||||
model_memory.emplace_back(*blob);
|
||||
model_memory.back().SetName(file_key);
|
||||
}
|
||||
// Allocate in parallel using the pool.
|
||||
pool.Run(0, model_memory.size(),
|
||||
[this, &model_memory](uint64_t task, size_t /*thread*/) {
|
||||
model_memory[task].Allocate();
|
||||
model_toc_[task]->SetPtr(model_memory[task]);
|
||||
});
|
||||
// Enqueue the read requests.
|
||||
for (auto& blob : model_memory) {
|
||||
err_ =
|
||||
reader_.Enqueue(MakeKey(blob.Name()), blob.data(), blob.SizeBytes());
|
||||
if (err_ != 0) {
|
||||
fprintf(stderr,
|
||||
"Failed to read blob %s (error %d) of size %zu x %zu x %zu\n",
|
||||
blob.Name(), err_, blob.Rows(), blob.Cols(),
|
||||
blob.ElementSize());
|
||||
return err_;
|
||||
}
|
||||
}
|
||||
return reader_.ReadAll(pool);
|
||||
}
|
||||
|
||||
private:
|
||||
BlobReader reader_;
|
||||
BlobError err_ = 0;
|
||||
// Table of contents from the file, if present.
|
||||
BlobToc file_toc_;
|
||||
// Table of contents from the model. Pointers to original MatPtrT so the
|
||||
// data pointers can be updated.
|
||||
std::vector<MatPtr*> model_toc_;
|
||||
// Mangled names of the tensors in model_toc_ for reading from the file.
|
||||
std::vector<std::string> file_keys_;
|
||||
};
|
||||
// Returns 1.0f if all magnitudes are <= `SfpStream::kMax`, otherwise scales
|
||||
// them such that the largest magnitude is `SfpStream::kMax`, and returns the
|
||||
// multiplier with which to restore the original values. This is only necessary
|
||||
// before compressing to `SfpStream` and `NuqStream`.
|
||||
float ScaleWeights(float* HWY_RESTRICT raw, size_t num);
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
|
||||
|
|
|
|||
|
|
@ -13,10 +13,10 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// SFP uses ConcatEven/Odd which are not supported; skip SVE for faster tests.
|
||||
#include "compression/types.h"
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE)
|
||||
#endif
|
||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
||||
#include "compression/compress.h"
|
||||
|
||||
|
|
@ -80,7 +80,7 @@ struct TestDecompress2T {
|
|||
stats.Notify(raw[i], hwy::ConvertScalarTo<float>(dec[i]));
|
||||
}
|
||||
|
||||
if constexpr (false) {
|
||||
if constexpr (true) { // leave enabled due to sporadic failures
|
||||
fprintf(stderr,
|
||||
"TypeName<Packed>() %s TypeName<T>() %s: num %zu: stats.SumL1() "
|
||||
"%f stats.GeomeanValueDivL1() %f stats.WeightedAverageL1() %f "
|
||||
|
|
|
|||
|
|
@ -1,286 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Command line tool to create compressed weights.
|
||||
|
||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||
// which we pass the filename via macro 'argument'.
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE \
|
||||
"compression/compress_weights.cc" // NOLINT
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "compression/compress-inl.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/tokenizer.h"
|
||||
|
||||
#ifndef GEMMA_COMPRESS_WEIGHTS_ONCE
|
||||
#define GEMMA_COMPRESS_WEIGHTS_ONCE
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm> // std::clamp
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <thread> // NOLINT
|
||||
#include <vector>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "compression/io.h" // Path
|
||||
#include "compression/shared.h" // PromptWrapping
|
||||
#include "gemma/common.h" // Model
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/args.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
namespace {
|
||||
|
||||
} // namespace
|
||||
|
||||
struct Args : public ArgsBase<Args> {
|
||||
static constexpr size_t kDefaultNumThreads = ~size_t{0};
|
||||
|
||||
void ChooseNumThreads() {
|
||||
if (num_threads == kDefaultNumThreads) {
|
||||
// This is a rough heuristic, replace with something better in the future.
|
||||
num_threads = static_cast<size_t>(std::clamp(
|
||||
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
Args(int argc, char* argv[]) {
|
||||
InitAndParse(argc, argv);
|
||||
ChooseNumThreads();
|
||||
}
|
||||
|
||||
// Returns error string or nullptr if OK.
|
||||
const char* Validate() {
|
||||
if (const char* err = ParseModelTypeAndWrapping(model_type_str, model_type_,
|
||||
prompt_wrapping_)) {
|
||||
return err;
|
||||
}
|
||||
if (const char* err = ParseType(weight_type_str, weight_type_)) {
|
||||
return err;
|
||||
}
|
||||
if (weights.path.empty()) {
|
||||
return "Missing --weights flag, a file for the uncompressed model.";
|
||||
}
|
||||
if (compressed_weights.path.empty()) {
|
||||
return "Missing --compressed_weights flag, a file for the compressed "
|
||||
"model.";
|
||||
}
|
||||
if (!weights.Exists()) {
|
||||
return "Can't open file specified with --weights flag.";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Path weights; // uncompressed weights file location
|
||||
Path compressed_weights; // compressed weights file location
|
||||
std::string model_type_str;
|
||||
std::string weight_type_str;
|
||||
size_t num_threads;
|
||||
// If non-empty, whether to include the config and TOC in the output file, as
|
||||
// well as the tokenizer.
|
||||
Path tokenizer;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(weights, "weights", Path(),
|
||||
"Path to model weights (.bin) file.\n"
|
||||
" Required argument.");
|
||||
visitor(model_type_str, "model", std::string(),
|
||||
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
|
||||
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
|
||||
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
|
||||
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
|
||||
"gr2b-pt = griffin 2B parameters, pretrained\n "
|
||||
" Required argument.");
|
||||
visitor(weight_type_str, "weight_type", std::string("sfp"),
|
||||
"Weight type\n f32 = float, bf16 = bfloat16, SFP = 8-bit FP\n"
|
||||
" Required argument.");
|
||||
visitor(compressed_weights, "compressed_weights", Path(),
|
||||
"Path name where compressed weights (.sbs) file will be written.\n"
|
||||
" Required argument.");
|
||||
visitor(num_threads, "num_threads",
|
||||
kDefaultNumThreads, // see ChooseNumThreads
|
||||
"Number of threads to use.\n Default = Estimate of the "
|
||||
"number of supported concurrent threads.",
|
||||
2);
|
||||
visitor(tokenizer, "tokenizer", Path(),
|
||||
"Path to tokenizer file. If given, the config and TOC are also "
|
||||
"added to the output file.");
|
||||
}
|
||||
|
||||
// Uninitialized before Validate, must call after that.
|
||||
gcpp::Model ModelType() const { return model_type_; }
|
||||
gcpp::PromptWrapping PromptWrappingType() const { return prompt_wrapping_; }
|
||||
gcpp::Type WeightType() const { return weight_type_; }
|
||||
|
||||
private:
|
||||
Model model_type_;
|
||||
PromptWrapping prompt_wrapping_;
|
||||
Type weight_type_;
|
||||
};
|
||||
|
||||
void ShowHelp(gcpp::Args& args) {
|
||||
std::cerr
|
||||
<< "Usage:\n./compress_weights --weights <path to uncompressed weights> "
|
||||
" --model <model type> --compressed_weights <output path>\n";
|
||||
std::cerr << "\n*Arguments*\n\n";
|
||||
args.Help();
|
||||
std::cerr << "\n";
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // GEMMA_COMPRESS_WEIGHTS_ONCE
|
||||
|
||||
// SIMD code, compiled once per target.
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
template <typename T>
|
||||
void CompressWeights(const Path& weights_path,
|
||||
const Path& compressed_weights_path, Model model_type,
|
||||
Type weight_type, PromptWrapping wrapping,
|
||||
const Path& tokenizer_path, hwy::ThreadPool& pool) {
|
||||
if (!weights_path.Exists()) {
|
||||
HWY_ABORT("The model weights file '%s' does not exist.",
|
||||
weights_path.path.c_str());
|
||||
}
|
||||
printf("Compressing weights from %s to %s\n", weights_path.path.c_str(),
|
||||
compressed_weights_path.path.c_str());
|
||||
ModelConfig config = ConfigFromModel(model_type);
|
||||
config.weight = weight_type;
|
||||
config.wrapping = wrapping;
|
||||
std::vector<MatStorage> model_storage;
|
||||
ModelWeightsPtrs<T> c_weights(config);
|
||||
c_weights.Allocate(model_storage, pool);
|
||||
ModelWeightsPtrs<float> uc_weights(config);
|
||||
uc_weights.Allocate(model_storage, pool);
|
||||
// Get uncompressed weights, compress, and store.
|
||||
FILE* fptr = fopen(weights_path.path.c_str(), "rb");
|
||||
if (fptr == nullptr) {
|
||||
HWY_ABORT("Failed to open model file %s - does it exist?",
|
||||
weights_path.path.c_str());
|
||||
}
|
||||
bool ok = true;
|
||||
uint64_t total_size = 0;
|
||||
ModelWeightsPtrs<float>::ForEachTensor(
|
||||
{&uc_weights}, ForEachType::kLoadNoToc,
|
||||
[&](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
fprintf(stderr, "Loading Parameters (size %zu): %s\n",
|
||||
tensors[0]->SizeBytes(), name);
|
||||
ok &= 1 == fread(tensors[0]->Ptr(), tensors[0]->SizeBytes(), 1, fptr);
|
||||
total_size += tensors[0]->SizeBytes();
|
||||
});
|
||||
if (!tokenizer_path.path.empty()) {
|
||||
uc_weights.AllocAndCopyWithTranspose(pool, model_storage);
|
||||
}
|
||||
const bool scale_for_compression = config.num_tensor_scales > 0;
|
||||
std::vector<float> scales;
|
||||
if (scale_for_compression) {
|
||||
uc_weights.GetOrApplyScales(scales);
|
||||
}
|
||||
Compressor compressor(pool);
|
||||
ModelWeightsPtrs<T>::ForEachTensor(
|
||||
{reinterpret_cast<ModelWeightsPtrs<T>*>(&uc_weights), &c_weights},
|
||||
tokenizer_path.path.empty() ? ForEachType::kLoadNoToc
|
||||
: ForEachType::kLoadWithToc,
|
||||
[&compressor](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
tensors[1]->CallUpcasted(
|
||||
compressor, name,
|
||||
reinterpret_cast<const float*>(tensors[0]->Ptr()));
|
||||
});
|
||||
if (!tokenizer_path.path.empty()) {
|
||||
std::string tokenizer_proto = ReadFileToString(tokenizer_path);
|
||||
compressor.AddTokenizer(tokenizer_proto);
|
||||
} else {
|
||||
compressor.AddScales(scales.data(), scales.size() * sizeof(scales[0]));
|
||||
}
|
||||
compressor.WriteAll(compressed_weights_path,
|
||||
tokenizer_path.path.empty() ? nullptr : &config);
|
||||
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#if HWY_ONCE
|
||||
namespace gcpp {
|
||||
|
||||
void Run(Args& args) {
|
||||
hwy::ThreadPool pool(args.num_threads);
|
||||
if (args.PromptWrappingType() == PromptWrapping::PALIGEMMA) {
|
||||
HWY_ABORT("PaliGemma is not supported in compress_weights.");
|
||||
}
|
||||
const Model model_type = args.ModelType();
|
||||
const Type weight_type = args.WeightType();
|
||||
switch (weight_type) {
|
||||
case Type::kF32:
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<float>)
|
||||
(args.weights, args.compressed_weights, model_type, weight_type,
|
||||
args.PromptWrappingType(), args.tokenizer, pool);
|
||||
break;
|
||||
case Type::kBF16:
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<BF16>)
|
||||
(args.weights, args.compressed_weights, model_type, weight_type,
|
||||
args.PromptWrappingType(), args.tokenizer, pool);
|
||||
break;
|
||||
case Type::kSFP:
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<SfpStream>)
|
||||
(args.weights, args.compressed_weights, model_type, weight_type,
|
||||
args.PromptWrappingType(), args.tokenizer, pool);
|
||||
break;
|
||||
case Type::kNUQ:
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<NuqStream>)
|
||||
(args.weights, args.compressed_weights, model_type, weight_type,
|
||||
args.PromptWrappingType(), args.tokenizer, pool);
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Weight type %d unsupported.", static_cast<int>(weight_type));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
gcpp::Args args(argc, argv);
|
||||
|
||||
if (gcpp::HasHelp(argc, argv)) {
|
||||
gcpp::ShowHelp(args);
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (const char* error = args.Validate()) {
|
||||
gcpp::ShowHelp(args);
|
||||
HWY_ABORT("\nInvalid args: %s", error);
|
||||
}
|
||||
|
||||
gcpp::Run(args);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // HWY_ONCE
|
||||
|
|
@ -1,209 +0,0 @@
|
|||
# Copyright 2024 Google LLC
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Converts pytorch to f32 for use by compress_weights.cc."""
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import os
|
||||
from gemma import config
|
||||
from gemma import model as gemma_model
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# Requires torch 2.2 and gemma package from
|
||||
# https://github.com/google/gemma_pytorch
|
||||
|
||||
|
||||
def check_file_exists(value):
|
||||
if not os.path.exists(str(value)):
|
||||
raise argparse.ArgumentTypeError(
|
||||
"The file %s does not appear to exist." % value
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def check_model_types(value):
|
||||
if str(value).lower() not in ["2b", "7b"]:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"Model type value %s is not in [2b, 7b]." % value
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--tokenizer",
|
||||
dest="tokenizer",
|
||||
default="models/tokenizer.spm",
|
||||
help="Location of tokenizer file (.model or .spm)",
|
||||
type=check_file_exists,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--weights",
|
||||
dest="weights",
|
||||
default="models/gemma-2b-it.ckpt",
|
||||
help="Location of input checkpoint file (.ckpt)",
|
||||
type=check_file_exists,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_file",
|
||||
dest="output_file",
|
||||
default="2bit-f32.sbs",
|
||||
help="Location to write converted weights",
|
||||
type=str,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
dest="model_type",
|
||||
default="2b",
|
||||
help="Model size / type (2b, 7b)",
|
||||
type=check_model_types,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
TRANSFORMATIONS = {
|
||||
"2b": collections.defaultdict(
|
||||
lambda: lambda x: x,
|
||||
{
|
||||
"embedder.weight": lambda x: x,
|
||||
"self_attn.qkv_proj.weight": lambda x: x.reshape((10, 256, 2048)),
|
||||
"self_attn.o_proj.weight": lambda x: x.reshape(
|
||||
(2048, 8, 256)
|
||||
).transpose([1, 0, 2]),
|
||||
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
|
||||
"mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
|
||||
"mlp.down_proj.weight": lambda x: x,
|
||||
},
|
||||
),
|
||||
"7b": collections.defaultdict(
|
||||
lambda: lambda x: x,
|
||||
{
|
||||
"embedder.weight": lambda x: x,
|
||||
"self_attn.qkv_proj.weight": lambda x: x.reshape(
|
||||
(3, 16, 256, 3072)
|
||||
).transpose([1, 0, 2, 3]),
|
||||
"self_attn.o_proj.weight": lambda x: x.reshape(
|
||||
(3072, 16, 256)
|
||||
).transpose([1, 0, 2]),
|
||||
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
|
||||
"mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
|
||||
"mlp.down_proj.weight": lambda x: x,
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
VALIDATIONS = {
|
||||
"2b": {
|
||||
"embedder.weight": lambda x: x.shape == (256000, 2048),
|
||||
"model.norm.weight": lambda x: x.shape == (2048,),
|
||||
"self_attn.qkv_proj.weight": lambda x: x.shape == (10, 256, 2048),
|
||||
"self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256),
|
||||
"mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048),
|
||||
"mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048),
|
||||
"mlp.down_proj.weight": lambda x: x.shape == (2048, 16384),
|
||||
"input_layernorm.weight": lambda x: x.shape == (2048,),
|
||||
"post_attention_layernorm.weight": lambda x: x.shape == (2048,),
|
||||
},
|
||||
"7b": {
|
||||
"embedder.weight": lambda x: x.shape == (256000, 3072),
|
||||
"model.norm.weight": lambda x: x.shape == (3072,),
|
||||
"self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072),
|
||||
"self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256),
|
||||
"mlp.gate_proj.weight": lambda x: x.shape == (1, 24576, 3072),
|
||||
"mlp.up_proj.weight": lambda x: x.shape == (1, 24576, 3072),
|
||||
"mlp.down_proj.weight": lambda x: x.shape == (3072, 24576),
|
||||
"input_layernorm.weight": lambda x: x.shape == (3072,),
|
||||
"post_attention_layernorm.weight": lambda x: x.shape == (3072,),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def param_names(num_hidden_layers: int):
|
||||
"""Return parameter names in the order they are expected for deserialization."""
|
||||
|
||||
# note *weight_scaler params are ignored in the forward computation unless
|
||||
# quantization is being used.
|
||||
#
|
||||
# since we are working with the full precision weights as input, don't
|
||||
# include these in the parameters being iterated over.
|
||||
|
||||
names = [
|
||||
("embedder.weight",) * 2, # embedder_input_embedding
|
||||
("model.norm.weight",) * 2, # final_norm_scale
|
||||
]
|
||||
layer_params = [
|
||||
"self_attn.o_proj.weight", # attn_vec_einsum_w
|
||||
"self_attn.qkv_proj.weight", # qkv_einsum_w
|
||||
"mlp.gate_proj.weight", # gating_einsum_w
|
||||
"mlp.up_proj.weight",
|
||||
"mlp.down_proj.weight", # linear_w
|
||||
"input_layernorm.weight", # pre_attention_norm_scale
|
||||
"post_attention_layernorm.weight", # pre_ffw_norm_scale
|
||||
]
|
||||
for layer in range(num_hidden_layers):
|
||||
for layer_param in layer_params:
|
||||
names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)]
|
||||
return names
|
||||
|
||||
|
||||
def convert_weights():
|
||||
"""Main function; loads weights, runs transformations, writes f32."""
|
||||
model_type = args.model_type
|
||||
output_file = args.output_file
|
||||
|
||||
model_config = config.get_model_config(model_type)
|
||||
model_config.dtype = "float32"
|
||||
model_config.tokenizer = args.tokenizer
|
||||
device = torch.device("cpu")
|
||||
torch.set_default_dtype(torch.float)
|
||||
model = gemma_model.GemmaForCausalLM(model_config)
|
||||
|
||||
model.load_weights(args.weights)
|
||||
model.to(device).eval()
|
||||
|
||||
model_dict = dict(model.named_parameters())
|
||||
param_order = param_names(model_config.num_hidden_layers)
|
||||
|
||||
all_ok = True
|
||||
print("Checking transformations ...")
|
||||
for name, layer_name in param_order:
|
||||
arr = model_dict[name].detach().numpy()
|
||||
arr = TRANSFORMATIONS[model_type][layer_name](arr)
|
||||
check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
|
||||
|
||||
if check == "FAILED":
|
||||
all_ok = False
|
||||
print(f" {name : <60}{str(arr.shape) : <20}{check}")
|
||||
|
||||
if all_ok:
|
||||
print("Writing parameters ...")
|
||||
with open(output_file, "wb") as bin_handle:
|
||||
for name, layer_name in param_order:
|
||||
arr = model_dict[name].detach().numpy()
|
||||
arr = TRANSFORMATIONS[model_type][layer_name](arr)
|
||||
check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
|
||||
print(f" {name : <60}{str(arr.shape) : <20}{check}")
|
||||
arr.flatten().astype(np.float32).tofile(bin_handle)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
convert_weights()
|
||||
print("Done")
|
||||
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
#include <stdio.h>
|
||||
|
||||
#include "compression/shared.h" // SfpStream::kMax
|
||||
#include "compression/types.h" // SfpStream::kMax
|
||||
#include "util/test_util.h"
|
||||
#include "hwy/nanobenchmark.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
|
|
|
|||
|
|
@ -1,121 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Safe to be first, does not include POSIX headers.
|
||||
#include "hwy/detect_compiler_arch.h"
|
||||
// Only compile this file on non-Windows; it replaces io_win.cc. It is easier to
|
||||
// check this in source code because we support multiple build systems.
|
||||
#if !HWY_OS_WIN
|
||||
|
||||
// Request POSIX 2008, including `pread()` and `posix_fadvise()`.
|
||||
#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700
|
||||
#undef _XOPEN_SOURCE
|
||||
#define _XOPEN_SOURCE 700
|
||||
#endif
|
||||
#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809
|
||||
#define _POSIX_C_SOURCE 200809
|
||||
#endif
|
||||
|
||||
// Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c.
|
||||
#undef _FILE_OFFSET_BITS
|
||||
#define _FILE_OFFSET_BITS 64
|
||||
|
||||
#include <fcntl.h> // open
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
|
||||
#include <sys/stat.h> // O_RDONLY
|
||||
#include <unistd.h> // read, write, close
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "compression/io.h"
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
class FilePosix : public File {
|
||||
int fd_ = 0;
|
||||
|
||||
public:
|
||||
explicit FilePosix(int fd) : fd_(fd) { HWY_ASSERT(fd > 0); }
|
||||
~FilePosix() override {
|
||||
if (fd_ != 0) {
|
||||
HWY_ASSERT(close(fd_) != -1);
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t FileSize() const override {
|
||||
static_assert(sizeof(off_t) == 8, "64-bit off_t required");
|
||||
const off_t size = lseek(fd_, 0, SEEK_END);
|
||||
if (size < 0) {
|
||||
return 0;
|
||||
}
|
||||
return static_cast<uint64_t>(size);
|
||||
}
|
||||
|
||||
bool Read(uint64_t offset, uint64_t size, void* to) const override {
|
||||
uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
|
||||
uint64_t pos = 0;
|
||||
for (;;) {
|
||||
// pread seems to be faster than lseek + read when parallelized.
|
||||
const auto bytes_read = pread(fd_, bytes + pos, size - pos, offset + pos);
|
||||
if (bytes_read <= 0) break;
|
||||
pos += bytes_read;
|
||||
HWY_ASSERT(pos <= size);
|
||||
if (pos == size) break;
|
||||
}
|
||||
return pos == size; // success if managed to read desired size
|
||||
}
|
||||
|
||||
bool Write(const void* from, uint64_t size, uint64_t offset) override {
|
||||
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(from);
|
||||
uint64_t pos = 0;
|
||||
for (;;) {
|
||||
const auto bytes_written =
|
||||
pwrite(fd_, bytes + pos, size - pos, offset + pos);
|
||||
if (bytes_written <= 0) break;
|
||||
pos += bytes_written;
|
||||
HWY_ASSERT(pos <= size);
|
||||
if (pos == size) break;
|
||||
}
|
||||
return pos == size; // success if managed to write desired size
|
||||
}
|
||||
}; // FilePosix
|
||||
|
||||
HWY_MAYBE_UNUSED extern std::unique_ptr<File> OpenFileGoogle(
|
||||
const Path& filename, const char* mode);
|
||||
|
||||
std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
|
||||
std::unique_ptr<File> file; // OpenFileGoogle omitted
|
||||
if (file) return file;
|
||||
|
||||
const bool is_read = mode[0] != 'w';
|
||||
const int flags = is_read ? O_RDONLY : O_CREAT | O_RDWR | O_TRUNC;
|
||||
const int fd = open(filename.path.c_str(), flags, 0644);
|
||||
if (fd < 0) return file;
|
||||
|
||||
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
|
||||
if (is_read) {
|
||||
// Doubles the readahead window, which seems slightly faster when cached.
|
||||
(void)posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL);
|
||||
}
|
||||
#endif
|
||||
|
||||
return std::make_unique<FilePosix>(fd);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // !HWY_OS_WIN
|
||||
|
|
@ -23,7 +23,7 @@
|
|||
|
||||
#include <cstdio>
|
||||
|
||||
#include "compression/shared.h"
|
||||
#include "compression/types.h"
|
||||
#include "util/basics.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
|
|
|
|||
|
|
@ -13,20 +13,20 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// SFP uses ConcatEven/Odd which are not supported; skip SVE for faster tests.
|
||||
#include "compression/types.h"
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE)
|
||||
#endif
|
||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm> // std::shuffle
|
||||
#include <array>
|
||||
#include <random>
|
||||
|
||||
#include "compression/distortion.h"
|
||||
#include "compression/shared.h"
|
||||
#include "util/test_util.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
|
|
@ -104,7 +104,7 @@ struct TestPlateaus {
|
|||
HWY_ASSERT(-0.5f <= in[i] && in[i] < 0.5f);
|
||||
}
|
||||
|
||||
std::random_device rd;
|
||||
std::random_device rd; // NOLINT
|
||||
std::mt19937 rng(rd());
|
||||
std::shuffle(in.get(), in.get() + kGroupSize, rng);
|
||||
|
||||
|
|
@ -151,7 +151,7 @@ struct TestRamp {
|
|||
HWY_ASSERT(-0.45f <= in[i] && in[i] < 0.55f);
|
||||
}
|
||||
|
||||
std::random_device rd;
|
||||
std::random_device rd; // NOLINT
|
||||
std::mt19937 rng(rd());
|
||||
std::shuffle(in.get(), in.get() + kGroupSize, rng);
|
||||
|
||||
|
|
@ -246,7 +246,8 @@ struct TestOffset {
|
|||
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
|
||||
auto dec1 = hwy::AllocateAligned<T>(total);
|
||||
auto dec2 = hwy::AllocateAligned<T>(kMidLen);
|
||||
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
|
||||
auto nuq = hwy::AllocateAligned<NuqStream>(
|
||||
hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes()));
|
||||
HWY_ASSERT(in && dec1 && dec2 && nuq);
|
||||
const auto nuq_span = MakeSpan(nuq.get(), total);
|
||||
|
||||
|
|
@ -296,7 +297,8 @@ struct TestUnalignedOffset {
|
|||
|
||||
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
|
||||
auto dec1 = hwy::AllocateAligned<T>(total);
|
||||
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
|
||||
auto nuq = hwy::AllocateAligned<NuqStream>(
|
||||
hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes()));
|
||||
auto dec2 = hwy::AllocateAligned<T>(num_decompressed);
|
||||
HWY_ASSERT(in && dec1 && dec2 && nuq);
|
||||
const auto nuq_span = MakeSpan(nuq.get(), total);
|
||||
|
|
@ -347,7 +349,8 @@ struct TestDec2 {
|
|||
auto dec0 = hwy::AllocateAligned<T>(total);
|
||||
auto dec1 = hwy::AllocateAligned<T>(total);
|
||||
auto dec2 = hwy::AllocateAligned<T>(kMidLen);
|
||||
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
|
||||
auto nuq = hwy::AllocateAligned<NuqStream>(
|
||||
hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes()));
|
||||
HWY_ASSERT(in && dec0 && dec1 && dec2 && nuq);
|
||||
const auto nuq_span = MakeSpan(nuq.get(), total);
|
||||
|
||||
|
|
@ -449,7 +452,8 @@ struct TestEncDec {
|
|||
const size_t num = 4 * kGroupSize;
|
||||
auto in = hwy::AllocateAligned<float>(num); // Enc() requires f32
|
||||
auto out = hwy::AllocateAligned<T>(num); // already padded
|
||||
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(num));
|
||||
auto nuq = hwy::AllocateAligned<NuqStream>(
|
||||
hwy::RoundUpTo(NuqStream::PackedEnd(num), hwy::VectorBytes()));
|
||||
HWY_ASSERT(in && out && nuq);
|
||||
const auto nuq_span = MakeSpan(nuq.get(), num);
|
||||
|
||||
|
|
@ -512,6 +516,7 @@ HWY_AFTER_NAMESPACE();
|
|||
#if HWY_ONCE
|
||||
namespace gcpp {
|
||||
HWY_BEFORE_TEST(NuqTest);
|
||||
#if GEMMA_ENABLE_NUQ
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllFlat);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllPlateaus);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllRamp);
|
||||
|
|
@ -525,6 +530,9 @@ HWY_EXPORT_AND_TEST_P(NuqTest, TestUnalignedOffsetF32);
|
|||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNibble);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecBF16);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecF32);
|
||||
#else
|
||||
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(NuqTest);
|
||||
#endif // GEMMA_ENABLE_NUQ
|
||||
HWY_AFTER_TEST();
|
||||
} // namespace gcpp
|
||||
#endif // HWY_ONCE
|
||||
|
|
|
|||
|
|
@ -14,11 +14,16 @@ cc_library(
|
|||
hdrs = ["compression_clif_aux.h"],
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
"@abseil-cpp//absl/types:span",
|
||||
"//:common",
|
||||
"//:basics",
|
||||
"//:configs",
|
||||
"//:mat",
|
||||
"//:model_store",
|
||||
"//:tensor_info",
|
||||
"//:threading_context",
|
||||
"//:tokenizer",
|
||||
"//compression:compress",
|
||||
"//compression:io",
|
||||
"//io",
|
||||
"//io:blob_store",
|
||||
"@highway//:hwy",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
|
|
@ -29,9 +34,9 @@ pybind_extension(
|
|||
srcs = ["compression_extension.cc"],
|
||||
deps = [
|
||||
":compression_clif_aux",
|
||||
"@abseil-cpp//absl/types:span",
|
||||
"//:common",
|
||||
"//compression:sfp",
|
||||
"//:mat",
|
||||
"//:tensor_info",
|
||||
"//compression:types",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -15,14 +15,29 @@
|
|||
|
||||
#include "compression/python/compression_clif_aux.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "compression/shared.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "compression/compress.h" // ScaleWeights
|
||||
#include "gemma/configs.h" // ModelConfig
|
||||
#include "gemma/model_store.h" // ModelStore
|
||||
#include "gemma/tensor_info.h" // TensorInfo
|
||||
#include "gemma/tokenizer.h"
|
||||
#include "io/blob_store.h" // BlobWriter
|
||||
#include "io/io.h" // Path
|
||||
#include "util/basics.h"
|
||||
#include "util/mat.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE \
|
||||
|
|
@ -32,157 +47,97 @@
|
|||
// After highway.h
|
||||
#include "compression/compress-inl.h"
|
||||
|
||||
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
|
||||
// compile pass, whereas we want this defined in the first.
|
||||
#ifndef GEMMA_ONCE
|
||||
#define GEMMA_ONCE
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "compression/io.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/tensor_index.h"
|
||||
#include "gemma/tokenizer.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
class WriterInterface {
|
||||
public:
|
||||
virtual ~WriterInterface() = default;
|
||||
|
||||
virtual void Insert(std::string name, absl::Span<const float> weights,
|
||||
Type type, const TensorInfo& tensor_info,
|
||||
float scale) = 0;
|
||||
virtual void InsertSfp(std::string name, absl::Span<const float> weights) = 0;
|
||||
virtual void InsertNUQ(std::string name, absl::Span<const float> weights) = 0;
|
||||
virtual void InsertBfloat16(std::string name,
|
||||
absl::Span<const float> weights) = 0;
|
||||
virtual void InsertFloat(std::string name,
|
||||
absl::Span<const float> weights) = 0;
|
||||
virtual void AddScales(const std::vector<float>& scales) = 0;
|
||||
virtual void AddTokenizer(const std::string& tokenizer_path) = 0;
|
||||
|
||||
virtual size_t DebugNumBlobsAdded() const = 0;
|
||||
|
||||
virtual int WriteWithConfig(std::string path, const ModelConfig* config) = 0;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // GEMMA_ONCE
|
||||
|
||||
// SIMD code, compiled once per target.
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
class SbsWriterImpl : public WriterInterface {
|
||||
// Implementation for the currently compiled SIMD target.
|
||||
class SbsWriterImpl : public ISbsWriter {
|
||||
template <typename Packed>
|
||||
void AllocateAndCompress(const std::string& name,
|
||||
absl::Span<const float> weights) {
|
||||
MatPtrT<Packed> storage(name, 1, weights.size());
|
||||
model_memory_.push_back(storage);
|
||||
model_memory_.back().Allocate();
|
||||
storage.SetPtr(model_memory_.back());
|
||||
std::string decorated_name = storage.CacheName();
|
||||
compressor_(&storage, decorated_name.c_str(), weights.data());
|
||||
}
|
||||
template <typename Packed>
|
||||
void AllocateWithShape(const std::string& name,
|
||||
absl::Span<const float> weights,
|
||||
const TensorInfo& tensor_info, float scale) {
|
||||
MatPtrT<Packed> storage(name, &tensor_info);
|
||||
storage.set_scale(scale);
|
||||
void InsertT(const char* name, F32Span weights,
|
||||
const TensorInfo& tensor_info) {
|
||||
// TODO(janwas): 1D parallel-for.
|
||||
hwy::ThreadPool& pool = ctx_.pools.Pool();
|
||||
|
||||
// Don't reset num_elements for NUQ.
|
||||
if (!hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>()) {
|
||||
storage.SetNumElements(CompressedArrayElements<Packed>(weights.size()));
|
||||
MatPtrT<Packed> mat(name, ExtentsFromInfo(&tensor_info));
|
||||
// SFP and NUQ (which uses SFP for cluster centers) have a limited range
|
||||
// and depending on the input values may require rescaling. Scaling is
|
||||
// cheap for matmul and probably not an issue for other ops, but it might be
|
||||
// beneficial for precision to keep the original data range for other types.
|
||||
if (mat.GetType() == Type::kSFP || mat.GetType() == Type::kNUQ) {
|
||||
mat.SetScale(ScaleWeights(weights.data(), weights.size()));
|
||||
}
|
||||
|
||||
model_memory_.push_back(storage);
|
||||
if (mode_ == CompressorMode::kTEST_ONLY) return;
|
||||
model_memory_.back().Allocate();
|
||||
storage.SetPtr(model_memory_.back());
|
||||
std::string decorated_name = storage.CacheName();
|
||||
compressor_(&storage, decorated_name.c_str(), weights.data());
|
||||
if (weights.size() == 0) {
|
||||
HWY_WARN("Ignoring zero-sized tensor %s.", name);
|
||||
return;
|
||||
}
|
||||
|
||||
mat.AppendTo(serialized_mat_ptrs_);
|
||||
MatOwner mat_owner;
|
||||
mat_owner.AllocateFor(mat, ctx_.allocator, MatPadding::kPacked);
|
||||
|
||||
// Handle gemma_export_test's MockArray. Write blobs so that the test
|
||||
// succeeds, but we only have 10 floats, not the full tensor.
|
||||
if (weights.size() == 10 && mat.Extents().Area() != 10) {
|
||||
Compress(weights.data(), weights.size(), working_set_, mat.Span(),
|
||||
/*packed_ofs=*/0, pool);
|
||||
writer_.Add(name, mat.Packed(), mat.ElementBytes() * 10);
|
||||
return;
|
||||
}
|
||||
|
||||
fprintf(stderr, "Compressing %s (%zu x %zu = %zuM) to %s, please wait\n",
|
||||
name, mat.Rows(), mat.Cols(), weights.size() / (1000 * 1000),
|
||||
TypeName(TypeEnum<Packed>()));
|
||||
HWY_ASSERT(weights.size() == mat.Extents().Area());
|
||||
Compress(weights.data(), weights.size(), working_set_, mat.Span(),
|
||||
/*packed_ofs=*/0, pool);
|
||||
writer_.Add(name, mat.Packed(), mat.PackedBytes());
|
||||
}
|
||||
|
||||
public:
|
||||
explicit SbsWriterImpl(CompressorMode mode)
|
||||
: pool_(0), compressor_(pool_), mode_(mode) {}
|
||||
SbsWriterImpl(const std::string& sbs_path)
|
||||
: ctx_(ThreadingArgs()),
|
||||
writer_(gcpp::Path(sbs_path), ctx_.pools.Pool()) {}
|
||||
|
||||
void Insert(std::string name, absl::Span<const float> weights, Type type,
|
||||
const TensorInfo& tensor_info, float scale) override {
|
||||
void Insert(const char* name, F32Span weights, Type type,
|
||||
const TensorInfo& tensor_info) override {
|
||||
switch (type) {
|
||||
case Type::kSFP:
|
||||
AllocateWithShape<SfpStream>(name, weights, tensor_info, scale);
|
||||
InsertT<SfpStream>(name, weights, tensor_info);
|
||||
break;
|
||||
case Type::kNUQ:
|
||||
AllocateWithShape<NuqStream>(name, weights, tensor_info, scale);
|
||||
InsertT<NuqStream>(name, weights, tensor_info);
|
||||
break;
|
||||
case Type::kBF16:
|
||||
AllocateWithShape<BF16>(name, weights, tensor_info, scale);
|
||||
InsertT<BF16>(name, weights, tensor_info);
|
||||
break;
|
||||
case Type::kF32:
|
||||
AllocateWithShape<float>(name, weights, tensor_info, scale);
|
||||
InsertT<float>(name, weights, tensor_info);
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Unsupported type");
|
||||
HWY_ABORT("Unsupported destination (compressed) type %s",
|
||||
TypeName(type));
|
||||
}
|
||||
}
|
||||
|
||||
void InsertSfp(std::string name, absl::Span<const float> weights) override {
|
||||
AllocateAndCompress<SfpStream>(name, weights);
|
||||
void Write(const ModelConfig& config,
|
||||
const std::string& tokenizer_path) override {
|
||||
const GemmaTokenizer tokenizer(
|
||||
tokenizer_path.empty() ? kMockTokenizer
|
||||
: ReadFileToString(Path(tokenizer_path)));
|
||||
WriteSingleFile(config, tokenizer, serialized_mat_ptrs_, writer_);
|
||||
}
|
||||
|
||||
void InsertNUQ(std::string name, absl::Span<const float> weights) override {
|
||||
AllocateAndCompress<NuqStream>(name, weights);
|
||||
}
|
||||
|
||||
void InsertBfloat16(std::string name,
|
||||
absl::Span<const float> weights) override {
|
||||
AllocateAndCompress<BF16>(name, weights);
|
||||
}
|
||||
|
||||
void InsertFloat(std::string name, absl::Span<const float> weights) override {
|
||||
AllocateAndCompress<float>(name, weights);
|
||||
}
|
||||
|
||||
void AddScales(const std::vector<float>& scales) override {
|
||||
HWY_ASSERT(scales_.empty());
|
||||
scales_ = scales;
|
||||
compressor_.AddScales(scales_.data(), scales_.size());
|
||||
}
|
||||
|
||||
void AddTokenizer(const std::string& tokenizer_path) override {
|
||||
Path path(tokenizer_path);
|
||||
GemmaTokenizer tokenizer(path);
|
||||
std::string tokenizer_proto = tokenizer.Serialize();
|
||||
HWY_ASSERT(!tokenizer_proto.empty());
|
||||
compressor_.AddTokenizer(tokenizer_proto);
|
||||
}
|
||||
|
||||
// Returns the number of blobs added.
|
||||
size_t DebugNumBlobsAdded() const {
|
||||
if (mode_ == CompressorMode::kTEST_ONLY) return model_memory_.size();
|
||||
return compressor_.DebugNumBlobsAdded();
|
||||
}
|
||||
|
||||
int WriteWithConfig(std::string path, const ModelConfig* config) override {
|
||||
return compressor_.WriteAll(gcpp::Path(path), config);
|
||||
}
|
||||
|
||||
hwy::ThreadPool pool_;
|
||||
Compressor compressor_;
|
||||
ThreadingContext ctx_;
|
||||
CompressWorkingSet working_set_;
|
||||
std::vector<MatStorage> model_memory_;
|
||||
std::vector<float> scales_;
|
||||
CompressorMode mode_;
|
||||
BlobWriter writer_;
|
||||
std::vector<uint32_t> serialized_mat_ptrs_;
|
||||
};
|
||||
|
||||
WriterInterface* NewSbsWriter(CompressorMode mode) {
|
||||
return new SbsWriterImpl(mode);
|
||||
ISbsWriter* NewSbsWriter(const std::string& sbs_path) {
|
||||
return new SbsWriterImpl(sbs_path);
|
||||
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
|
|
@ -194,43 +149,11 @@ namespace gcpp {
|
|||
|
||||
HWY_EXPORT(NewSbsWriter);
|
||||
|
||||
SbsWriter::SbsWriter(CompressorMode mode)
|
||||
: impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)(mode)) {}
|
||||
SbsWriter::~SbsWriter() = default;
|
||||
SbsWriter::SbsWriter(const std::string& path)
|
||||
: impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)(path)) {}
|
||||
|
||||
void SbsWriter::Insert(std::string name, absl::Span<const float> weights,
|
||||
Type type, const TensorInfo& tensor_info, float scale) {
|
||||
impl_->Insert(name, weights, type, tensor_info, scale);
|
||||
}
|
||||
void SbsWriter::InsertSfp(std::string name, absl::Span<const float> weights) {
|
||||
impl_->InsertSfp(name, weights);
|
||||
}
|
||||
void SbsWriter::InsertNUQ(std::string name, absl::Span<const float> weights) {
|
||||
impl_->InsertNUQ(name, weights);
|
||||
}
|
||||
void SbsWriter::InsertBfloat16(std::string name,
|
||||
absl::Span<const float> weights) {
|
||||
impl_->InsertBfloat16(name, weights);
|
||||
}
|
||||
void SbsWriter::InsertFloat(std::string name, absl::Span<const float> weights) {
|
||||
impl_->InsertFloat(name, weights);
|
||||
}
|
||||
|
||||
void SbsWriter::AddScales(const std::vector<float>& scales) {
|
||||
impl_->AddScales(scales);
|
||||
}
|
||||
|
||||
void SbsWriter::AddTokenizer(const std::string& tokenizer_path) {
|
||||
impl_->AddTokenizer(tokenizer_path);
|
||||
}
|
||||
|
||||
size_t SbsWriter::DebugNumBlobsAdded() const {
|
||||
return impl_->DebugNumBlobsAdded();
|
||||
}
|
||||
|
||||
int SbsWriter::WriteWithConfig(std::string path, const ModelConfig* config) {
|
||||
return impl_->WriteWithConfig(path, config);
|
||||
}
|
||||
SbsReader::SbsReader(const std::string& path)
|
||||
: reader_(Path(path)), model_(reader_) {}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // HWY_ONCE
|
||||
|
|
|
|||
|
|
@ -16,52 +16,67 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <stddef.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "compression/shared.h"
|
||||
#include "compression/types.h" // Type
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/tensor_index.h"
|
||||
#include "gemma/model_store.h"
|
||||
#include "gemma/tensor_info.h"
|
||||
#include "io/blob_store.h"
|
||||
#include "util/mat.h"
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// How to process the data.
|
||||
enum class CompressorMode {
|
||||
// No compression, no write to file, just for testing.
|
||||
kTEST_ONLY,
|
||||
// Old-style compression, no table of contents.
|
||||
kNO_TOC,
|
||||
// New-style compression, with table of contents.
|
||||
kWITH_TOC,
|
||||
// Can be modified in place by ScaleWeights.
|
||||
using F32Span = hwy::Span<float>;
|
||||
|
||||
// Interface because we compile one derived implementation per SIMD target,
|
||||
// because Compress() uses SIMD.
|
||||
class ISbsWriter {
|
||||
public:
|
||||
virtual ~ISbsWriter() = default;
|
||||
|
||||
virtual void Insert(const char* name, F32Span weights, Type type,
|
||||
const TensorInfo& tensor_info) = 0;
|
||||
|
||||
virtual void Write(const ModelConfig& config,
|
||||
const std::string& tokenizer_path) = 0;
|
||||
};
|
||||
|
||||
class WriterInterface;
|
||||
|
||||
// Non-virtual class used by pybind that calls the interface's virtual methods.
|
||||
// This avoids having to register the derived types with pybind.
|
||||
class SbsWriter {
|
||||
public:
|
||||
explicit SbsWriter(CompressorMode mode);
|
||||
~SbsWriter();
|
||||
explicit SbsWriter(const std::string& sbs_path);
|
||||
|
||||
void Insert(std::string name, absl::Span<const float> weights, Type type,
|
||||
const TensorInfo& tensor_info, float scale);
|
||||
void InsertSfp(std::string name, absl::Span<const float> weights);
|
||||
void InsertNUQ(std::string name, absl::Span<const float> weights);
|
||||
void InsertBfloat16(std::string name, absl::Span<const float> weights);
|
||||
void InsertFloat(std::string name, absl::Span<const float> weights);
|
||||
void AddScales(const std::vector<float>& scales);
|
||||
void AddTokenizer(const std::string& tokenizer_path);
|
||||
void Insert(const char* name, F32Span weights, Type type,
|
||||
const TensorInfo& tensor_info) {
|
||||
impl_->Insert(name, weights, type, tensor_info);
|
||||
}
|
||||
|
||||
size_t DebugNumBlobsAdded() const;
|
||||
|
||||
int Write(std::string path) { return WriteWithConfig(path, nullptr); }
|
||||
int WriteWithConfig(std::string path, const ModelConfig* config);
|
||||
void Write(const ModelConfig& config, const std::string& tokenizer_path) {
|
||||
impl_->Write(config, tokenizer_path);
|
||||
}
|
||||
|
||||
private:
|
||||
// Isolates Highway-dispatched types and other internals from CLIF.
|
||||
std::unique_ptr<WriterInterface> impl_;
|
||||
std::unique_ptr<ISbsWriter> impl_;
|
||||
};
|
||||
|
||||
// Limited metadata-only reader for tests.
|
||||
class SbsReader {
|
||||
public:
|
||||
SbsReader(const std::string& path);
|
||||
|
||||
const ModelConfig& Config() const { return model_.Config(); }
|
||||
const MatPtr* FindMat(const char* name) const { return model_.FindMat(name); }
|
||||
|
||||
private:
|
||||
gcpp::BlobReader reader_;
|
||||
gcpp::ModelStore model_;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -15,58 +15,54 @@
|
|||
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "compression/python/compression_clif_aux.h"
|
||||
#include "compression/shared.h"
|
||||
#include "compression/types.h" // Type
|
||||
#include "gemma/tensor_info.h"
|
||||
#include "util/mat.h"
|
||||
|
||||
using gcpp::CompressorMode;
|
||||
using gcpp::MatPtr;
|
||||
using gcpp::SbsReader;
|
||||
using gcpp::SbsWriter;
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace pybind11 {
|
||||
|
||||
namespace {
|
||||
template <auto Func>
|
||||
void wrap_span(SbsWriter& writer, std::string name, py::array_t<float> data) {
|
||||
static void CallWithF32Span(SbsWriter& writer, const char* name,
|
||||
array_t<float> data, gcpp::Type type,
|
||||
const gcpp::TensorInfo& tensor_info) {
|
||||
if (data.ndim() != 1 || data.strides(0) != sizeof(float)) {
|
||||
throw std::domain_error("Input array must be 1D and densely packed.");
|
||||
HWY_ABORT("Input array must be 1D (not %d) and contiguous floats.",
|
||||
static_cast<int>(data.ndim()));
|
||||
}
|
||||
std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()));
|
||||
std::invoke(Func, writer, name,
|
||||
gcpp::F32Span(data.mutable_data(0), data.size()), type,
|
||||
tensor_info);
|
||||
}
|
||||
template <auto Func>
|
||||
void wrap_span_typed(SbsWriter& writer, std::string name,
|
||||
py::array_t<float> data, gcpp::Type type,
|
||||
gcpp::TensorInfo tensor_info, float scale) {
|
||||
if (data.ndim() != 1 || data.strides(0) != sizeof(float)) {
|
||||
throw std::domain_error("Input array must be 1D and densely packed.");
|
||||
}
|
||||
std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()),
|
||||
type, tensor_info, scale);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
PYBIND11_MODULE(compression, m) {
|
||||
py::enum_<CompressorMode>(m, "CompressorMode")
|
||||
.value("TEST_ONLY", CompressorMode::kTEST_ONLY)
|
||||
.value("NO_TOC", CompressorMode::kNO_TOC)
|
||||
.value("WITH_TOC", CompressorMode::kWITH_TOC);
|
||||
class_<SbsWriter>(m, "SbsWriter")
|
||||
.def(init<std::string>())
|
||||
.def("insert", CallWithF32Span<&SbsWriter::Insert>)
|
||||
.def("write", &SbsWriter::Write, arg("config"), arg("tokenizer_path"));
|
||||
|
||||
py::class_<SbsWriter>(m, "SbsWriter")
|
||||
.def(py::init<CompressorMode>())
|
||||
// NOTE: Individual compression backends may impose constraints on the
|
||||
// array length, such as a minimum of (say) 32 elements.
|
||||
.def("insert", wrap_span_typed<&SbsWriter::Insert>)
|
||||
.def("insert_sfp", wrap_span<&SbsWriter::InsertSfp>)
|
||||
.def("insert_nuq", wrap_span<&SbsWriter::InsertNUQ>)
|
||||
.def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>)
|
||||
.def("insert_float", wrap_span<&SbsWriter::InsertFloat>)
|
||||
.def("add_scales", &SbsWriter::AddScales)
|
||||
.def("add_tokenizer", &SbsWriter::AddTokenizer)
|
||||
.def("debug_num_blobs_added", &SbsWriter::DebugNumBlobsAdded)
|
||||
.def("write", &SbsWriter::Write)
|
||||
.def("write_with_config", &SbsWriter::WriteWithConfig);
|
||||
class_<MatPtr>(m, "MatPtr")
|
||||
// No init, only created within C++.
|
||||
.def_property_readonly("rows", &MatPtr::Rows, "Number of rows")
|
||||
.def_property_readonly("cols", &MatPtr::Cols, "Number of cols")
|
||||
.def_property_readonly("type", &MatPtr::GetType, "Element type")
|
||||
.def_property_readonly("scale", &MatPtr::Scale, "Scaling factor");
|
||||
|
||||
class_<SbsReader>(m, "SbsReader")
|
||||
.def(init<std::string>())
|
||||
.def_property_readonly("config", &SbsReader::Config,
|
||||
return_value_policy::reference_internal,
|
||||
"ModelConfig")
|
||||
.def("find_mat", &SbsReader::FindMat,
|
||||
return_value_policy::reference_internal,
|
||||
"Returns MatPtr for given name.");
|
||||
}
|
||||
|
||||
} // namespace pybind11
|
||||
|
|
|
|||
|
|
@ -25,46 +25,120 @@ from python import configs
|
|||
class CompressionTest(absltest.TestCase):
|
||||
|
||||
def test_sbs_writer(self):
|
||||
info_192 = configs.TensorInfo()
|
||||
info_192.name = "ignored_192"
|
||||
info_192.axes = [0]
|
||||
info_192.shape = [192]
|
||||
|
||||
temp_file = self.create_tempfile("test.sbs")
|
||||
tensor_info = configs.TensorInfo()
|
||||
tensor_info.name = "foo"
|
||||
tensor_info.axes = [0]
|
||||
tensor_info.shape = [192]
|
||||
|
||||
writer = compression.SbsWriter(compression.CompressorMode.NO_TOC)
|
||||
writer = compression.SbsWriter(temp_file.full_path)
|
||||
writer.insert(
|
||||
"foo",
|
||||
np.array([0.0012] * 128 + [0.001] * 64, dtype=np.float32),
|
||||
"tensor0",
|
||||
# Large enough to require scaling.
|
||||
np.array([3.0012] * 128 + [4.001] * 64, dtype=np.float32),
|
||||
configs.Type.kSFP,
|
||||
tensor_info,
|
||||
1.0,
|
||||
info_192,
|
||||
)
|
||||
|
||||
tensor_info_nuq = configs.TensorInfo()
|
||||
tensor_info_nuq.name = "fooNUQ"
|
||||
tensor_info_nuq.axes = [0]
|
||||
tensor_info_nuq.shape = [256]
|
||||
# 2D tensor.
|
||||
info_2d = configs.TensorInfo()
|
||||
info_2d.name = "ignored_2d"
|
||||
info_2d.axes = [0, 1]
|
||||
info_2d.shape = [96, 192]
|
||||
writer.insert(
|
||||
"fooNUQ",
|
||||
"tensor_2d",
|
||||
np.array([i / 1e3 for i in range(96 * 192)], dtype=np.float32),
|
||||
configs.Type.kBF16,
|
||||
info_2d,
|
||||
)
|
||||
|
||||
# 3D collapsed into rows.
|
||||
info_3d = configs.TensorInfo()
|
||||
info_3d.name = "ignored_3d"
|
||||
info_3d.axes = [0, 1, 2]
|
||||
info_3d.shape = [10, 12, 192]
|
||||
info_3d.cols_take_extra_dims = False
|
||||
writer.insert(
|
||||
"tensor_3d",
|
||||
# Verification of scale below depends on the shape and multiplier here.
|
||||
np.array([i / 1e3 for i in range(10 * 12 * 192)], dtype=np.float32),
|
||||
configs.Type.kSFP,
|
||||
info_3d,
|
||||
)
|
||||
|
||||
# Exercise all types supported by Compress.
|
||||
info_256 = configs.TensorInfo()
|
||||
info_256.name = "ignored_256"
|
||||
info_256.axes = [0]
|
||||
info_256.shape = [256]
|
||||
writer.insert(
|
||||
"tensor_sfp",
|
||||
np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32),
|
||||
configs.Type.kNUQ,
|
||||
tensor_info_nuq,
|
||||
1.0,
|
||||
configs.Type.kSFP,
|
||||
info_256,
|
||||
)
|
||||
writer.insert_sfp(
|
||||
"bar", np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32)
|
||||
writer.insert(
|
||||
"tensor_bf",
|
||||
np.array([0.000375] * 128 + [0.00007] * 128, dtype=np.float32),
|
||||
configs.Type.kBF16,
|
||||
info_256,
|
||||
)
|
||||
writer.insert_nuq(
|
||||
"baz", np.array([0.000125] * 128 + [0.00008] * 128, dtype=np.float32)
|
||||
writer.insert(
|
||||
"tensor_f32",
|
||||
np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32),
|
||||
configs.Type.kF32,
|
||||
info_256,
|
||||
)
|
||||
writer.insert_bf16(
|
||||
"qux", np.array([0.000375] * 128 + [0.00007] * 128, dtype=np.float32)
|
||||
|
||||
config = configs.ModelConfig(
|
||||
configs.Model.GEMMA_TINY,
|
||||
configs.Type.kSFP,
|
||||
configs.PromptWrapping.GEMMA_IT,
|
||||
)
|
||||
writer.insert_float(
|
||||
"quux", np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32)
|
||||
)
|
||||
self.assertEqual(writer.debug_num_blobs_added(), 6)
|
||||
self.assertEqual(writer.write(temp_file.full_path), 0)
|
||||
tokenizer_path = "" # no tokenizer required for testing
|
||||
writer.write(config, tokenizer_path)
|
||||
|
||||
print("Ignore next two warnings; test does not enable model deduction.")
|
||||
reader = compression.SbsReader(temp_file.full_path)
|
||||
|
||||
self.assertEqual(reader.config.model, configs.Model.GEMMA_TINY)
|
||||
self.assertEqual(reader.config.weight, configs.Type.kSFP)
|
||||
|
||||
mat = reader.find_mat("tensor0")
|
||||
self.assertEqual(mat.cols, 192)
|
||||
self.assertEqual(mat.rows, 1)
|
||||
self.assertEqual(mat.type, configs.Type.kSFP)
|
||||
self.assertAlmostEqual(mat.scale, 4.001 / 1.875, places=5)
|
||||
|
||||
mat = reader.find_mat("tensor_2d")
|
||||
self.assertEqual(mat.cols, 192)
|
||||
self.assertEqual(mat.rows, 96)
|
||||
self.assertEqual(mat.type, configs.Type.kBF16)
|
||||
self.assertAlmostEqual(mat.scale, 1.0)
|
||||
|
||||
mat = reader.find_mat("tensor_3d")
|
||||
self.assertEqual(mat.cols, 192)
|
||||
self.assertEqual(mat.rows, 10 * 12)
|
||||
self.assertEqual(mat.type, configs.Type.kSFP)
|
||||
self.assertAlmostEqual(mat.scale, 192 * 120 / 1e3 / 1.875, places=2)
|
||||
|
||||
mat = reader.find_mat("tensor_sfp")
|
||||
self.assertEqual(mat.cols, 256)
|
||||
self.assertEqual(mat.rows, 1)
|
||||
self.assertEqual(mat.type, configs.Type.kSFP)
|
||||
self.assertAlmostEqual(mat.scale, 1.0)
|
||||
|
||||
mat = reader.find_mat("tensor_bf")
|
||||
self.assertEqual(mat.cols, 256)
|
||||
self.assertEqual(mat.rows, 1)
|
||||
self.assertEqual(mat.type, configs.Type.kBF16)
|
||||
self.assertAlmostEqual(mat.scale, 1.0)
|
||||
|
||||
mat = reader.find_mat("tensor_f32")
|
||||
self.assertEqual(mat.cols, 256)
|
||||
self.assertEqual(mat.rows, 1)
|
||||
self.assertEqual(mat.type, configs.Type.kF32)
|
||||
self.assertAlmostEqual(mat.scale, 1.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@
|
|||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "compression/shared.h"
|
||||
#include "compression/types.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_
|
||||
|
|
|
|||
|
|
@ -13,10 +13,10 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// We use ConcatEven/Odd which are not supported. Use HWY_EMU128 instead.
|
||||
#include "compression/types.h"
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
||||
#endif
|
||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
|
@ -25,7 +25,6 @@
|
|||
#include <set>
|
||||
|
||||
#include "compression/distortion.h"
|
||||
#include "compression/shared.h"
|
||||
#include "util/test_util.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
|
|
|
|||
|
|
@ -18,10 +18,13 @@
|
|||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "compression/compress.h"
|
||||
#include "compression/distortion.h"
|
||||
#include "util/mat.h"
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
|
||||
|
||||
// Include guard for (potentially) SIMD code.
|
||||
|
|
@ -59,7 +62,63 @@ void ForeachPackedAndRawType() {
|
|||
ForeachRawType<BF16, TestT>();
|
||||
ForeachRawType<float, TestT>();
|
||||
ForeachRawType<SfpStream, TestT>();
|
||||
ForeachRawType<NuqStream, TestT>();
|
||||
if constexpr (GEMMA_ENABLE_NUQ) {
|
||||
ForeachRawType<NuqStream, TestT>();
|
||||
}
|
||||
}
|
||||
|
||||
// Generates inputs: deterministic, within max SfpStream range.
|
||||
template <typename MatT>
|
||||
MatStorageT<MatT> GenerateMat(const Extents2D& extents,
|
||||
const Allocator& allocator, MatPadding padding,
|
||||
hwy::ThreadPool& pool) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
ws.tls.resize(pool.NumWorkers());
|
||||
MatStorageT<float> raw("raw", extents, allocator, MatPadding::kPacked);
|
||||
MatStorageT<MatT> compressed("mat", extents, allocator, padding);
|
||||
const float scale = SfpStream::kMax / extents.Area();
|
||||
pool.Run(0, extents.rows, [&](const size_t r, size_t thread) {
|
||||
float* HWY_RESTRICT row = raw.Row(r);
|
||||
for (size_t c = 0; c < extents.cols; c++) {
|
||||
float f = static_cast<float>(r * extents.cols + c) * scale;
|
||||
if ((r + c) & 1) f = -f; // Also generate some negative values.
|
||||
row[c] = f;
|
||||
}
|
||||
Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
|
||||
MakeSpan(compressed.Row(r), compressed.Cols()),
|
||||
/*packed_ofs=*/0);
|
||||
});
|
||||
|
||||
compressed.SetScale(0.6f); // Arbitrary value, different from 1.
|
||||
return compressed;
|
||||
}
|
||||
|
||||
// Same, but `extents` describes the transposed matrix.
|
||||
template <typename MatT>
|
||||
MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
|
||||
const Allocator& allocator,
|
||||
MatPadding padding,
|
||||
hwy::ThreadPool& pool) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
ws.tls.resize(pool.NumWorkers());
|
||||
MatStorageT<float> raw("raw", extents, allocator, MatPadding::kPacked);
|
||||
MatStorageT<MatT> compressed("trans", extents, allocator, padding);
|
||||
const float scale = SfpStream::kMax / extents.Area();
|
||||
pool.Run(0, extents.rows, [&](const size_t r, size_t thread) {
|
||||
float* HWY_RESTRICT row = raw.Row(r);
|
||||
for (size_t c = 0; c < extents.cols; c++) {
|
||||
float f = static_cast<float>(c * extents.rows + r) * scale;
|
||||
if ((r + c) & 1) f = -f; // Also generate some negative values.
|
||||
row[c] = f;
|
||||
}
|
||||
Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
|
||||
MakeSpan(compressed.Row(r), compressed.Cols()),
|
||||
/*packed_ofs=*/0);
|
||||
});
|
||||
|
||||
// Arbitrary value, different from 1, must match `GenerateMat`.
|
||||
compressed.SetScale(0.6f);
|
||||
return compressed;
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
|
|
|
|||
|
|
@ -13,18 +13,14 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Definitions shared between the public compress-inl.h interface and the
|
||||
// sfp-inl.h and nuq-inl.h implementation details.
|
||||
// Types shared between tensor definitions and `compress-inl.h`.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_TYPES_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TYPES_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <complex>
|
||||
#include <cstdio>
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "util/basics.h" // BF16
|
||||
#include "hwy/aligned_allocator.h"
|
||||
|
|
@ -33,6 +29,35 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
// EMU128 must not be disabled because we disable SCALAR.
|
||||
#define HWY_BROKEN_EMU128 0
|
||||
|
||||
// Allow user override of disabled targets.
|
||||
#ifndef GEMMA_DISABLED_TARGETS
|
||||
|
||||
// All platforms: exclude SCALAR because we use ReorderWidenMulAccumulate.
|
||||
|
||||
#if HWY_ARCH_ARM_V7
|
||||
// No NEON because we require double-precision support.
|
||||
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_ALL_NEON)
|
||||
#elif HWY_ARCH_ARM_A64
|
||||
// We do not yet use AES (e.g. for random generation), hence NEON is the same
|
||||
// as NEON_WITHOUT_AES. Also skip SVE because SVE2_128 and SVE_256 cover most.
|
||||
#define GEMMA_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON | HWY_SVE)
|
||||
#elif HWY_ARCH_X86
|
||||
// Skip anything older than Haswell (2013); also use Zen4 for recent CPUs,
|
||||
// because we do not use anything added by SPR (e.g. FP16) nor AVX 10.2.
|
||||
#define GEMMA_DISABLED_TARGETS \
|
||||
(HWY_SCALAR | HWY_SSE2 | HWY_SSSE3 | HWY_SSE4 | HWY_AVX3_SPR | HWY_AVX10_2)
|
||||
#endif // HWY_ARCH_*
|
||||
|
||||
#endif // GEMMA_DISABLED_TARGETS
|
||||
|
||||
// Only used in experiments, hence disable in default builds.
|
||||
#ifndef GEMMA_ENABLE_NUQ
|
||||
#define GEMMA_ENABLE_NUQ 0
|
||||
#endif
|
||||
|
||||
// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
|
||||
// inputs that combines the advantages of e4m3 and e5m2 into a single format.
|
||||
// It supports seeking at a granularity of 1 and decoding to bf16/f32.
|
||||
|
|
@ -63,30 +88,6 @@ struct SfpStream {
|
|||
};
|
||||
#pragma pack(pop)
|
||||
|
||||
// Returns 1.0f if all magnitudes are <= SfpStream::kMax, otherwise scales them
|
||||
// such that the largest magnitude is SfpStream::kMax, and returns the
|
||||
// multiplier with which to restore the original values. This is only necessary
|
||||
// before compressing to SfpStream.
|
||||
// TODO: vectorize
|
||||
static inline float ScaleWeights(float* HWY_RESTRICT raw, size_t num) {
|
||||
float maxabs = 0.0;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
maxabs = HWY_MAX(maxabs, hwy::ScalarAbs(raw[i]));
|
||||
}
|
||||
if (maxabs <= SfpStream::kMax) {
|
||||
return 1.0f;
|
||||
}
|
||||
const float scale = maxabs / SfpStream::kMax;
|
||||
const float inv_scale = static_cast<float>(1.0 / static_cast<double>(scale));
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
// Clamp because kMax may still be exceeded.
|
||||
const float magn =
|
||||
HWY_MIN(SfpStream::kMax, hwy::ScalarAbs(raw[i] * inv_scale));
|
||||
raw[i] = hwy::ScalarCopySign(magn, raw[i]);
|
||||
}
|
||||
return scale;
|
||||
}
|
||||
|
||||
// Non-uniform quantization: a compressed representation of f32 inputs that
|
||||
// supports seeking at a granularity of 1 (for `DecompressAndZeroPad`) or
|
||||
// two vectors (for `Decompress2`), and decoding to bf16/f32.
|
||||
|
|
@ -185,31 +186,25 @@ constexpr bool IsNuqStream() {
|
|||
return hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>();
|
||||
}
|
||||
|
||||
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
|
||||
enum class PromptWrapping {
|
||||
GEMMA_IT,
|
||||
GEMMA_PT,
|
||||
GEMMA_VLM,
|
||||
PALIGEMMA,
|
||||
kSentinel // must be last
|
||||
// Tensor types for loading weights.
|
||||
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64 };
|
||||
// These are used in `ModelConfig.Specifier`, hence the strings will not
|
||||
// change, though new ones may be added.
|
||||
static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16",
|
||||
"sfp", "nuq", "f64"};
|
||||
static constexpr size_t kNumTypes =
|
||||
sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
|
||||
static constexpr size_t kTypeBits[] = {
|
||||
0,
|
||||
8 * sizeof(float),
|
||||
8 * sizeof(BF16),
|
||||
8 * sizeof(SfpStream),
|
||||
4 /* NuqStream, actually 4.5 */,
|
||||
8 * sizeof(double),
|
||||
};
|
||||
|
||||
inline bool EnumValid(PromptWrapping type) {
|
||||
return static_cast<int>(type) >= 0 &&
|
||||
static_cast<int>(type) < static_cast<int>(PromptWrapping::kSentinel);
|
||||
}
|
||||
|
||||
// Tensor types for loading weights. Note that not all types are supported as
|
||||
// weights for a model, but can be used for other purposes, such as types for
|
||||
// ModelWeightsPtrs. When adding a new type that is supported, also
|
||||
// update gemma.cc, weights.*, and add instantiations/new_one.cc.
|
||||
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 };
|
||||
constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
|
||||
"nuq", "f64", "c64", "u128"};
|
||||
|
||||
inline bool EnumValid(Type type) {
|
||||
return static_cast<int>(type) >= 0 &&
|
||||
static_cast<int>(type) <= static_cast<int>(Type::kU128);
|
||||
static inline bool EnumValid(Type type) {
|
||||
return static_cast<size_t>(type) < kNumTypes;
|
||||
}
|
||||
|
||||
// Returns a Type enum for the type of the template parameter.
|
||||
|
|
@ -226,20 +221,22 @@ Type TypeEnum() {
|
|||
return Type::kNUQ;
|
||||
} else if constexpr (hwy::IsSame<Packed, double>()) {
|
||||
return Type::kF64;
|
||||
} else if constexpr (hwy::IsSame<Packed, std::complex<double>>()) {
|
||||
return Type::kC64;
|
||||
} else if constexpr (hwy::IsSame<Packed, hwy::uint128_t>()) {
|
||||
return Type::kU128;
|
||||
} else {
|
||||
HWY_DASSERT(false);
|
||||
return Type::kUnknown;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a string name for the type of the template parameter.
|
||||
static inline size_t TypeBits(Type type) {
|
||||
return kTypeBits[static_cast<int>(type)];
|
||||
}
|
||||
|
||||
static inline const char* TypeName(Type type) {
|
||||
return kTypeStrings[static_cast<int>(type)];
|
||||
}
|
||||
template <typename PackedT>
|
||||
const char* TypeName() {
|
||||
return kTypeStrings[static_cast<int>(TypeEnum<PackedT>())];
|
||||
return TypeName(TypeEnum<PackedT>());
|
||||
}
|
||||
|
||||
template <typename Packed>
|
||||
|
|
@ -248,7 +245,9 @@ constexpr bool IsCompressed() {
|
|||
}
|
||||
|
||||
// Returns the number of `MatT` elements required to store `capacity` values,
|
||||
// which must not be zero.
|
||||
// which must not be zero. This is only intended to support the extra tables
|
||||
// required for NUQ. `capacity` includes any padding and is `rows * stride`.
|
||||
// Deprecated, replaced by fixup within `MatPtr`. Only used by tests.
|
||||
template <typename Packed>
|
||||
constexpr size_t CompressedArrayElements(size_t capacity) {
|
||||
if constexpr (hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>()) {
|
||||
|
|
@ -304,4 +303,4 @@ HWY_INLINE PackedSpan<const Packed> MakeConst(PackedSpan<Packed> packed) {
|
|||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TYPES_H_
|
||||
|
|
@ -6,14 +6,12 @@
|
|||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <utility> // std::pair
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h" // Path
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "evals/cross_entropy.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "io/io.h" // Path
|
||||
#include "util/args.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/timer.h"
|
||||
|
|
@ -27,7 +25,6 @@ class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
|
|||
public:
|
||||
BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
|
||||
Path goldens;
|
||||
Path summarize_text;
|
||||
Path cross_entropy;
|
||||
Path trivia_qa;
|
||||
|
|
@ -36,8 +33,6 @@ class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
|
|||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(goldens.path, "goldens_dir", std::string(""),
|
||||
"Directory containing golden files", 2);
|
||||
visitor(summarize_text.path, "summarize_text", std::string(""),
|
||||
"Path to text file to summarize", 2);
|
||||
visitor(cross_entropy.path, "cross_entropy", std::string(""),
|
||||
|
|
@ -53,56 +48,6 @@ class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
|
|||
}
|
||||
};
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> load_goldens(
|
||||
const std::string& path) {
|
||||
std::ifstream goldens_file(path);
|
||||
if (!goldens_file) {
|
||||
std::cout << "Could not load goldens file: " << path << "\n" << std::flush;
|
||||
return {};
|
||||
}
|
||||
std::vector<std::pair<std::string, std::string>> res;
|
||||
std::string query_separator;
|
||||
std::string query;
|
||||
std::string answer_separator;
|
||||
std::string answer;
|
||||
while (std::getline(goldens_file, query_separator) &&
|
||||
std::getline(goldens_file, query) &&
|
||||
std::getline(goldens_file, answer_separator) &&
|
||||
std::getline(goldens_file, answer)) {
|
||||
res.push_back({query, answer});
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
int BenchmarkGoldens(GemmaEnv& env, const std::string& golden_path) {
|
||||
std::vector<std::pair<std::string, std::string>> queries_answers =
|
||||
load_goldens(golden_path);
|
||||
size_t correct_answers = 0;
|
||||
size_t total_tokens = 0;
|
||||
const double time_start = hwy::platform::Now();
|
||||
for (auto& [question, expected_answer] : queries_answers) {
|
||||
QueryResult result = env.QueryModel(question);
|
||||
total_tokens += result.tokens_generated;
|
||||
if (result.response.find(expected_answer) != std::string::npos) {
|
||||
correct_answers++;
|
||||
} else {
|
||||
std::cout << "Wrong!\n";
|
||||
std::cout << "Input: " << question << "\n";
|
||||
std::cout << "Expected: " << expected_answer << "\n";
|
||||
std::cout << "Output: " << result.response << "\n\n" << std::flush;
|
||||
}
|
||||
}
|
||||
LogSpeedStats(time_start, total_tokens);
|
||||
|
||||
std::cout << "Correct: " << correct_answers << " out of "
|
||||
<< queries_answers.size() << "\n"
|
||||
<< std::flush;
|
||||
if (correct_answers != queries_answers.size()) {
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
int BenchmarkSummary(GemmaEnv& env, const Path& text) {
|
||||
std::string prompt("Here is some text to summarize:\n");
|
||||
prompt.append(ReadFileToString(text));
|
||||
|
|
@ -117,6 +62,7 @@ int BenchmarkSummary(GemmaEnv& env, const Path& text) {
|
|||
|
||||
int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
|
||||
size_t batch_tokens) {
|
||||
const Gemma& gemma = *env.GetGemma();
|
||||
std::string input = ReadFileToString(text);
|
||||
std::vector<int> prompt = env.Tokenize(input);
|
||||
std::cout << "Number of input tokens: " << prompt.size() << "\n";
|
||||
|
|
@ -128,10 +74,11 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
|
|||
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
|
||||
std::vector<int> prompt_slice(prompt.begin() + pos,
|
||||
prompt.begin() + pos + num_tokens);
|
||||
KVCache kv_cache = KVCache::Create(env.GetModel()->GetModelConfig(),
|
||||
env.MutableConfig().prefill_tbatch_size);
|
||||
float entropy = ComputeCrossEntropy(
|
||||
*env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
|
||||
KVCache kv_cache(gemma.Config(), gemma.Inference(),
|
||||
env.MutableEnv().ctx.allocator);
|
||||
float entropy =
|
||||
ComputeCrossEntropy(*env.GetGemma(), num_tokens, prompt_slice, kv_cache,
|
||||
env.MutableEnv(), env.Verbosity());
|
||||
total_entropy += entropy;
|
||||
LogSpeedStats(time_start, pos + num_tokens);
|
||||
std::string text_slice = env.StringFromTokens(prompt_slice);
|
||||
|
|
@ -183,14 +130,7 @@ int main(int argc, char** argv) {
|
|||
gcpp::GemmaEnv env(argc, argv);
|
||||
gcpp::BenchmarkArgs benchmark_args(argc, argv);
|
||||
|
||||
if (!benchmark_args.goldens.Empty()) {
|
||||
const std::string golden_path =
|
||||
benchmark_args.goldens.path + "/" +
|
||||
gcpp::ModelString(env.GetModel()->Info().model,
|
||||
env.GetModel()->Info().wrapping) +
|
||||
".txt";
|
||||
return BenchmarkGoldens(env, golden_path);
|
||||
} else if (!benchmark_args.summarize_text.Empty()) {
|
||||
if (!benchmark_args.summarize_text.Empty()) {
|
||||
return BenchmarkSummary(env, benchmark_args.summarize_text);
|
||||
} else if (!benchmark_args.cross_entropy.Empty()) {
|
||||
return BenchmarkCrossEntropy(env, benchmark_args.cross_entropy,
|
||||
|
|
|
|||
|
|
@ -18,27 +18,21 @@
|
|||
#include <stdio.h>
|
||||
#include <time.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <ostream>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// Placeholder for internal header, do not modify.
|
||||
#include "compression/compress.h" // TypeName
|
||||
#include "compression/types.h" // TypeName
|
||||
#include "evals/cross_entropy.h"
|
||||
#include "gemma/common.h" // StringFromType
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/kv_cache.h"
|
||||
#include "util/app.h"
|
||||
#include "util/args.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/topology.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "ops/matmul.h" // MatMulEnv
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/per_target.h" // VectorBytes
|
||||
#include "hwy/per_target.h" // DispatchedTarget
|
||||
#include "hwy/profiler.h" // PROFILER_ENABLED
|
||||
#include "hwy/timer.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
|
@ -49,53 +43,37 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
|
|||
gen.seed(0x12345678);
|
||||
} else {
|
||||
// Depending on the library implementation, this may still be deterministic.
|
||||
std::random_device rd;
|
||||
std::random_device rd; // NOLINT
|
||||
gen.seed(rd());
|
||||
}
|
||||
}
|
||||
|
||||
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||
const AppArgs& app)
|
||||
: topology_(CreateTopology(app)),
|
||||
pools_(CreatePools(topology_, app)),
|
||||
env_(topology_, pools_) {
|
||||
InferenceArgs mutable_inference = inference;
|
||||
AbortIfInvalidArgs(mutable_inference);
|
||||
LoaderArgs mutable_loader = loader;
|
||||
if (const char* err = mutable_loader.Validate()) {
|
||||
mutable_loader.Help();
|
||||
fprintf(stderr, "Skipping model load because: %s\n", err);
|
||||
} else {
|
||||
fprintf(stderr, "Loading model...\n");
|
||||
model_ = AllocateGemma(mutable_loader, env_);
|
||||
// Only allocate one for starters because GenerateBatch might not be called.
|
||||
kv_caches_.resize(1);
|
||||
kv_caches_[0] = KVCache::Create(model_->GetModelConfig(),
|
||||
inference.prefill_tbatch_size);
|
||||
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
|
||||
const InferenceArgs& inference)
|
||||
: ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) {
|
||||
const ModelConfig& config = gemma_.Config();
|
||||
// Only allocate one for starters because GenerateBatch might not be called.
|
||||
kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));
|
||||
|
||||
if (inference.verbosity >= 2) {
|
||||
ShowConfig(loader, threading, inference, config, gemma_.WeightReadMode(),
|
||||
ctx_);
|
||||
}
|
||||
|
||||
InitGenerator(inference, gen_);
|
||||
|
||||
runtime_config_ = {
|
||||
.max_generated_tokens = inference.max_generated_tokens,
|
||||
.temperature = inference.temperature,
|
||||
.gen = &gen_,
|
||||
.verbosity = app.verbosity,
|
||||
.verbosity = inference.verbosity,
|
||||
};
|
||||
}
|
||||
|
||||
// Internal init must run before the GemmaEnv ctor above, hence it cannot occur
|
||||
// in the argv ctor below because its body runs *after* the delegating ctor.
|
||||
// This helper function takes care of the init, and could be applied to any of
|
||||
// the *Args classes, it does not matter which.
|
||||
static AppArgs MakeAppArgs(int argc, char** argv) {
|
||||
{ // So that indentation matches expectations.
|
||||
// Placeholder for internal init, do not modify.
|
||||
}
|
||||
return AppArgs(argc, argv);
|
||||
inference.CopyTo(runtime_config_);
|
||||
}
|
||||
|
||||
GemmaEnv::GemmaEnv(int argc, char** argv)
|
||||
: GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
|
||||
MakeAppArgs(argc, argv)) {}
|
||||
: GemmaEnv(LoaderArgs(argc, argv), ThreadingArgs(argc, argv),
|
||||
InferenceArgs(argc, argv)) {}
|
||||
|
||||
QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
|
||||
QueryResult result;
|
||||
|
|
@ -117,8 +95,8 @@ QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
|
|||
}
|
||||
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
|
||||
runtime_config_.batch_stream_token = batch_stream_token;
|
||||
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
|
||||
timing_info);
|
||||
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
|
||||
timing_info);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
@ -127,23 +105,25 @@ void GemmaEnv::QueryModel(
|
|||
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
|
||||
const StreamFunc previous_stream_token = runtime_config_.stream_token;
|
||||
runtime_config_.stream_token = stream_token;
|
||||
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
|
||||
timing_info);
|
||||
gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
|
||||
timing_info);
|
||||
runtime_config_.stream_token = previous_stream_token;
|
||||
}
|
||||
|
||||
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
||||
const QueriesPromptTokens& queries_prompt) {
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const hwy::Span<const size_t>& prefix_end) {
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_ASSERT(num_queries != 0);
|
||||
std::vector<QueryResult> res(num_queries);
|
||||
const BatchStreamFunc batch_stream_token = [&res, &queries_prompt, this](
|
||||
size_t query_index, size_t pos,
|
||||
int token, float) {
|
||||
const BatchStreamFunc batch_stream_token = [&, this](const size_t query_index,
|
||||
const size_t pos,
|
||||
const int token, float) {
|
||||
HWY_ASSERT(query_index < num_queries);
|
||||
std::string token_text;
|
||||
HWY_ASSERT(
|
||||
model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
HWY_ASSERT(gemma_.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
res[query_index].response.append(token_text);
|
||||
HWY_ASSERT(pos == res[query_index].tokens_generated);
|
||||
res[query_index].tokens_generated += 1;
|
||||
if (res[query_index].tokens_generated ==
|
||||
queries_prompt[query_index].size()) {
|
||||
|
|
@ -151,6 +131,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
|||
}
|
||||
return true;
|
||||
};
|
||||
runtime_config_.batch_stream_token = batch_stream_token;
|
||||
if (runtime_config_.verbosity >= 2) {
|
||||
fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
|
||||
runtime_config_.max_generated_tokens, runtime_config_.temperature,
|
||||
|
|
@ -158,23 +139,16 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
|||
runtime_config_.decode_qbatch_size);
|
||||
}
|
||||
|
||||
// Ensure we have one KVCache per query.
|
||||
if (kv_caches_.size() < num_queries) {
|
||||
kv_caches_.resize(num_queries);
|
||||
}
|
||||
for (size_t i = 1; i < num_queries; ++i) {
|
||||
if (kv_caches_[i].seq_len == 0) {
|
||||
kv_caches_[i] = KVCache::Create(model_->GetModelConfig(),
|
||||
runtime_config_.prefill_tbatch_size);
|
||||
}
|
||||
// Ensure we have at least one KVCache per query.
|
||||
while (kv_caches_.size() < num_queries) {
|
||||
kv_caches_.push_back(
|
||||
KVCache(gemma_.Config(), gemma_.Inference(), ctx_.allocator));
|
||||
}
|
||||
const hwy::Span<KVCache> kv_caches(&kv_caches_[0], num_queries);
|
||||
|
||||
gcpp::AllQueries all_queries(queries_prompt, kv_caches, prefix_end);
|
||||
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
|
||||
runtime_config_.batch_stream_token = batch_stream_token;
|
||||
std::vector<size_t> queries_pos(num_queries, 0);
|
||||
model_->GenerateBatch(runtime_config_, queries_prompt,
|
||||
QueriesPos(queries_pos.data(), num_queries),
|
||||
KVCaches(&kv_caches_[0], num_queries), timing_info);
|
||||
gemma_.GenerateBatch(runtime_config_, all_queries, env_, timing_info);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
|
@ -203,8 +177,8 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
|||
float GemmaEnv::CrossEntropy(const std::string& input) {
|
||||
std::vector<int> prompt = Tokenize(input);
|
||||
prompt.insert(prompt.begin(), BOS_ID);
|
||||
return ComputeCrossEntropy(*GetModel(), /*max_generated_tokens=*/3072, prompt,
|
||||
MutableKVCache(),
|
||||
return ComputeCrossEntropy(*GetGemma(), /*max_generated_tokens=*/3072, prompt,
|
||||
MutableKVCache(), env_,
|
||||
/*verbosity=*/0) /
|
||||
static_cast<int>(input.size());
|
||||
}
|
||||
|
|
@ -236,13 +210,37 @@ std::string CacheString() {
|
|||
return buf;
|
||||
}
|
||||
|
||||
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
|
||||
const BoundedTopology& topology, NestedPools& pools) {
|
||||
loader.Print(app.verbosity);
|
||||
inference.Print(app.verbosity);
|
||||
app.Print(app.verbosity);
|
||||
static constexpr const char* CompiledConfig() {
|
||||
if constexpr (HWY_IS_ASAN) {
|
||||
return "asan";
|
||||
} else if constexpr (HWY_IS_MSAN) {
|
||||
return "msan";
|
||||
} else if constexpr (HWY_IS_TSAN) {
|
||||
return "tsan";
|
||||
} else if constexpr (HWY_IS_HWASAN) {
|
||||
return "hwasan";
|
||||
} else if constexpr (HWY_IS_UBSAN) {
|
||||
return "ubsan";
|
||||
} else if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
return "dbg";
|
||||
} else {
|
||||
return "opt";
|
||||
}
|
||||
}
|
||||
|
||||
if (app.verbosity >= 2) {
|
||||
void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
|
||||
const InferenceArgs& inference, const ModelConfig& config,
|
||||
const WeightsPtrs::Mode weight_read_mode,
|
||||
const ThreadingContext& ctx) {
|
||||
threading.Print(inference.verbosity);
|
||||
loader.Print(inference.verbosity);
|
||||
inference.Print(inference.verbosity);
|
||||
fprintf(
|
||||
stderr, "Model : %s, to_bf16 %d, mmap %d => %s\n",
|
||||
config.Specifier().c_str(), static_cast<int>(loader.to_bf16),
|
||||
static_cast<int>(loader.map), WeightsPtrs::ToString(weight_read_mode));
|
||||
|
||||
if (inference.verbosity >= 2) {
|
||||
time_t now = time(nullptr);
|
||||
char* dt = ctime(&now); // NOLINT
|
||||
char cpu100[100] = "unknown";
|
||||
|
|
@ -250,38 +248,34 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
|
|||
|
||||
fprintf(stderr,
|
||||
"Date & Time : %s" // dt includes \n
|
||||
"CPU : %s\n"
|
||||
"CPU : %s, bind %d\n"
|
||||
"CPU topology : %s, %s, %s\n"
|
||||
"Instruction set : %s (%zu bits)\n"
|
||||
"Compiled config : %s\n"
|
||||
"Weight Type : %s\n"
|
||||
"EmbedderInput Type : %s\n",
|
||||
dt, cpu100, topology.TopologyString(), pools.PinString(),
|
||||
"Compiled config : %s, profiler %d\n"
|
||||
"Memory MiB : %4zu\n",
|
||||
dt, cpu100, static_cast<int>(threading.bind),
|
||||
ctx.topology.TopologyString(), ctx.pools.PinString(),
|
||||
CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
|
||||
hwy::VectorBytes() * 8, CompiledConfig(),
|
||||
StringFromType(loader.Info().weight), TypeName<EmbedderInputT>());
|
||||
ctx.allocator.VectorBytes() * 8, CompiledConfig(), PROFILER_ENABLED,
|
||||
ctx.allocator.TotalMiB());
|
||||
}
|
||||
}
|
||||
|
||||
void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||
void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
|
||||
const InferenceArgs& inference) {
|
||||
std::cerr
|
||||
<< "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
|
||||
"==========================================================\n\n"
|
||||
"To run gemma.cpp, you need to "
|
||||
"specify 3 required model loading arguments:\n"
|
||||
" --tokenizer\n"
|
||||
" --weights\n"
|
||||
" --model,\n"
|
||||
" or with the newer weights format, specify just:\n"
|
||||
" --weights\n";
|
||||
"To run with pre-2025 weights, specify --tokenizer and --weights.\n"
|
||||
"With the single-file weights format, specify just --weights.\n";
|
||||
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
|
||||
"--weights 2b-it-sfp.sbs --model 2b-it\n";
|
||||
"--weights gemma2-2b-it-sfp.sbs\n";
|
||||
std::cerr << "\n*Model Loading Arguments*\n\n";
|
||||
loader.Help();
|
||||
std::cerr << "\n*Threading Arguments*\n\n";
|
||||
threading.Help();
|
||||
std::cerr << "\n*Inference Arguments*\n\n";
|
||||
inference.Help();
|
||||
std::cerr << "\n*Application Arguments*\n\n";
|
||||
app.Help();
|
||||
std::cerr << "\n";
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -18,15 +18,16 @@
|
|||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "gemma/tokenizer.h" // WrapAndTokenize
|
||||
#include "ops/matmul.h"
|
||||
#include "util/app.h"
|
||||
#include "util/threading.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
|
@ -46,19 +47,21 @@ class GemmaEnv {
|
|||
public:
|
||||
// Calls the other constructor with *Args arguments initialized from argv.
|
||||
GemmaEnv(int argc, char** argv);
|
||||
GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||
const AppArgs& app);
|
||||
GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
|
||||
const InferenceArgs& inference);
|
||||
MatMulEnv& Env() { return env_; }
|
||||
|
||||
size_t MaxGeneratedTokens() const {
|
||||
return runtime_config_.max_generated_tokens;
|
||||
}
|
||||
void SetMaxGeneratedTokens(size_t max_generated_tokens) {
|
||||
runtime_config_.max_generated_tokens = max_generated_tokens;
|
||||
void SetMaxGeneratedTokens(int max_generated_tokens) {
|
||||
runtime_config_.max_generated_tokens =
|
||||
static_cast<size_t>(max_generated_tokens);
|
||||
}
|
||||
|
||||
std::vector<int> Tokenize(const std::string& input) const {
|
||||
std::vector<int> tokens;
|
||||
HWY_ASSERT(model_->Tokenizer().Encode(input, &tokens));
|
||||
HWY_ASSERT(gemma_.Tokenizer().Encode(input, &tokens));
|
||||
return tokens;
|
||||
}
|
||||
|
||||
|
|
@ -69,20 +72,23 @@ class GemmaEnv {
|
|||
}
|
||||
|
||||
std::vector<int> WrapAndTokenize(std::string& input) const {
|
||||
return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->Info(), 0, input);
|
||||
return gcpp::WrapAndTokenize(gemma_.Tokenizer(), gemma_.ChatTemplate(),
|
||||
gemma_.Config().wrapping, 0, input);
|
||||
}
|
||||
|
||||
std::string StringFromTokens(const std::vector<int>& tokens) const {
|
||||
std::string string;
|
||||
HWY_ASSERT(model_->Tokenizer().Decode(tokens, &string));
|
||||
HWY_ASSERT(gemma_.Tokenizer().Decode(tokens, &string));
|
||||
return string;
|
||||
}
|
||||
|
||||
// Runs inference on the given input and returns the top-1 result string and
|
||||
// the number of tokens that were generated.
|
||||
QueryResult QueryModel(const std::vector<int>& tokens);
|
||||
// The default prefix_end means "causal attention".
|
||||
std::vector<QueryResult> BatchQueryModel(
|
||||
const QueriesPromptTokens& queries_prompt);
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
|
||||
// Adds turn structure to input, tokenizes and calls the above overload.
|
||||
QueryResult QueryModel(std::string& input);
|
||||
std::vector<QueryResult> BatchQueryModel(
|
||||
|
|
@ -97,20 +103,19 @@ class GemmaEnv {
|
|||
// number of bits per token.
|
||||
float CrossEntropy(const std::string& input);
|
||||
|
||||
// Returns nullptr if the model failed to load.
|
||||
Gemma* GetModel() const { return model_.get(); }
|
||||
const Gemma* GetGemma() const { return &gemma_; }
|
||||
|
||||
int Verbosity() const { return runtime_config_.verbosity; }
|
||||
RuntimeConfig& MutableConfig() { return runtime_config_; }
|
||||
std::mt19937& MutableGen() { return gen_; }
|
||||
KVCache& MutableKVCache() { return kv_caches_[0]; }
|
||||
MatMulEnv& MutableEnv() { return env_; }
|
||||
|
||||
private:
|
||||
BoundedTopology topology_;
|
||||
NestedPools pools_; // Thread pool.
|
||||
ThreadingContext ctx_;
|
||||
MatMulEnv env_;
|
||||
std::mt19937 gen_; // Random number generator.
|
||||
std::unique_ptr<Gemma> model_;
|
||||
Gemma gemma_;
|
||||
std::mt19937 gen_; // Random number generator.
|
||||
std::vector<KVCache> kv_caches_; // Same number as query batch.
|
||||
RuntimeConfig runtime_config_;
|
||||
};
|
||||
|
|
@ -118,9 +123,12 @@ class GemmaEnv {
|
|||
// Logs the inference speed in tokens/sec.
|
||||
void LogSpeedStats(double time_start, size_t total_tokens);
|
||||
|
||||
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
|
||||
const BoundedTopology& topology, NestedPools& pools);
|
||||
void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
|
||||
void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
|
||||
const InferenceArgs& inference, const ModelConfig& config,
|
||||
WeightsPtrs::Mode weight_read_mode,
|
||||
const ThreadingContext& ctx);
|
||||
void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
|
||||
const InferenceArgs& inference);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,11 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||
// which we pass the filename via macro 'argument'.
|
||||
// clang-format off
|
||||
|
|
@ -38,17 +43,12 @@
|
|||
#include <vector>
|
||||
|
||||
#include "evals/cross_entropy.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
namespace {
|
||||
template <typename TConfig>
|
||||
struct GetVocabSize {
|
||||
int operator()() const { return TConfig::kVocabSize; }
|
||||
};
|
||||
|
||||
static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
|
||||
std::string token_str;
|
||||
|
|
@ -85,7 +85,7 @@ namespace gcpp {
|
|||
namespace HWY_NAMESPACE {
|
||||
|
||||
void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size) {
|
||||
Softmax(logits, vocab_size);
|
||||
Softmax(logits, vocab_size, /*worker=*/0);
|
||||
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
|
|
@ -97,12 +97,12 @@ namespace gcpp {
|
|||
|
||||
HWY_EXPORT(CallSoftmax);
|
||||
|
||||
float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
|
||||
float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
int verbosity) {
|
||||
MatMulEnv& env, int verbosity) {
|
||||
const StreamFunc stream_token = [](int, float) { return true; };
|
||||
|
||||
const int vocab_size = gemma.GetModelConfig().vocab_size;
|
||||
const int vocab_size = gemma.Config().vocab_size;
|
||||
float cross_entropy = std::log(vocab_size); // first token; == -log(1/v_s)
|
||||
size_t pos = 1;
|
||||
|
||||
|
|
@ -145,7 +145,7 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
|
|||
};
|
||||
TimingInfo timing_info;
|
||||
|
||||
gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info);
|
||||
gemma.Generate(runtime, prompt0, 0, kv_cache, env, timing_info);
|
||||
|
||||
const float scale = 1.0f / std::log(2.0f);
|
||||
return cross_entropy * scale;
|
||||
|
|
|
|||
|
|
@ -24,9 +24,9 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
|
||||
float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
int verbosity);
|
||||
MatMulEnv& env, int verbosity);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
|
|
|
|||
|
|
@ -18,9 +18,9 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h"
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/gemma.h" // LayersOutputFunc
|
||||
#include "io/io.h"
|
||||
#include "util/args.h"
|
||||
#include "hwy/base.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
|
|
|||
|
|
@ -13,25 +13,18 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "gemma/gemma.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/nanobenchmark.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
|
||||
// This test can be run manually with the downloaded gemma weights.
|
||||
// To run the test, pass the following flags:
|
||||
// --model <model> --tokenizer <tokenizer_path> --weights <weights_path>
|
||||
// It should pass for the following models:
|
||||
// Gemma1: 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), gr2b-it,
|
||||
// Gemma2: gemma2-2b-it, 9b-it, 27b-it,
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
||||
|
|
@ -40,61 +33,23 @@ namespace {
|
|||
// non-local static variables with dtors.
|
||||
GemmaEnv* s_env = nullptr;
|
||||
|
||||
class GemmaTest : public ::testing::Test {
|
||||
class GemmaBatchBench : public ::testing::Test {
|
||||
protected:
|
||||
std::vector<std::string> BatchGemmaReply(
|
||||
const std::vector<std::string>& inputs) {
|
||||
s_env->SetMaxGeneratedTokens(64);
|
||||
s_env->SetMaxGeneratedTokens(24);
|
||||
s_env->MutableConfig().temperature = 0.0f; // deterministic
|
||||
s_env->MutableConfig().verbosity = 5;
|
||||
s_env->MutableConfig().verbosity = 2;
|
||||
std::vector<std::string> replies;
|
||||
// Using the turn structure worsens results sometimes.
|
||||
// However, some models need the turn structure to work.
|
||||
// It would be good to make these tests more consistent.
|
||||
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
|
||||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
|
||||
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
|
||||
replies.push_back(result.response);
|
||||
}
|
||||
return replies;
|
||||
}
|
||||
// Otherwise, do not use turn structure.
|
||||
std::vector<std::vector<int>> prompts_vector;
|
||||
prompts_vector.reserve(inputs.size());
|
||||
for (const auto& input_string : inputs) {
|
||||
prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
|
||||
}
|
||||
std::vector<PromptTokens> prompt_spans;
|
||||
for (const auto& prompt : prompts_vector) {
|
||||
prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
|
||||
}
|
||||
QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
|
||||
for (const QueryResult& result : s_env->BatchQueryModel(prompts)) {
|
||||
for (const QueryResult& result : s_env->BatchQueryModel(inputs)) {
|
||||
replies.push_back(result.response);
|
||||
}
|
||||
return replies;
|
||||
}
|
||||
|
||||
void GenerateTokens(std::vector<std::string> &kQA, size_t num_questions) {
|
||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||
|
||||
std::vector<std::string> inputs;
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
inputs.push_back(kQA[i]);
|
||||
}
|
||||
std::vector<std::string> responses = BatchGemmaReply(inputs);
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
std::string response = responses.at(i);
|
||||
fprintf(stderr, "Batch answer %zu '%s'\n\n", i + 1, response.c_str());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(GemmaTest, RandomQuestionsBatched) {
|
||||
s_env->MutableConfig().decode_qbatch_size = 3;
|
||||
s_env->MutableConfig().verbosity = 5;
|
||||
|
||||
static std::vector<std::string> kQA = {
|
||||
TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
|
||||
const std::vector<std::string> questions = {
|
||||
{"Write me a poem about Australia?"},
|
||||
{"What's the history of Denmark?"},
|
||||
{"Write me a comedy story about the USA."},
|
||||
|
|
@ -128,13 +83,27 @@ TEST_F(GemmaTest, RandomQuestionsBatched) {
|
|||
{"Tell me about space travel."},
|
||||
{"Explain to me how electric cars work."},
|
||||
};
|
||||
static const size_t kNum = kQA.size();
|
||||
GenerateTokens(kQA, kNum);
|
||||
|
||||
// Fills prompts round robin from `questions` until the desired batch size.
|
||||
std::vector<std::string> inputs;
|
||||
inputs.reserve(s_env->MutableConfig().decode_qbatch_size);
|
||||
size_t qpos = 0;
|
||||
for (size_t i = 0; i < inputs.capacity(); ++i) {
|
||||
inputs.push_back(questions[qpos++]);
|
||||
if (qpos == questions.size()) qpos = 0;
|
||||
}
|
||||
std::vector<std::string> responses = BatchGemmaReply(inputs);
|
||||
for (size_t i = 0; i < hwy::Unpredictable1() * 3; ++i) {
|
||||
fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
|
||||
}
|
||||
|
||||
PROFILER_PRINT_RESULTS();
|
||||
}
|
||||
} // namespace
|
||||
} // namespace gcpp
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
fprintf(stderr, "GemmaEnv setup..\n");
|
||||
gcpp::GemmaEnv env(argc, argv);
|
||||
gcpp::s_env = &env;
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,8 @@
|
|||
#include <vector>
|
||||
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "io/io.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
|
||||
|
|
@ -36,132 +37,75 @@
|
|||
namespace gcpp {
|
||||
namespace {
|
||||
|
||||
// Shared state. Requires argc/argv, so construct in main and use the same raw
|
||||
// pointer approach as in benchmarks.cc. Note that the style guide forbids
|
||||
// non-local static variables with dtors.
|
||||
GemmaEnv* s_env = nullptr;
|
||||
|
||||
class GemmaTest : public ::testing::Test {
|
||||
protected:
|
||||
std::string GemmaReply(const std::string& prompt) {
|
||||
s_env->SetMaxGeneratedTokens(2048);
|
||||
s_env->MutableConfig().temperature = 0.0f; // deterministic
|
||||
s_env->MutableConfig().verbosity = 0;
|
||||
// Using the turn structure worsens results sometimes.
|
||||
// However, some models need the turn structure to work.
|
||||
// It would be good to make these tests more consistent.
|
||||
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
|
||||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
|
||||
std::string mutable_prompt = prompt;
|
||||
QueryResult result = s_env->QueryModel(mutable_prompt); // Uses turns.
|
||||
return result.response;
|
||||
}
|
||||
// Otherwise, do not use turn structure.
|
||||
const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
|
||||
QueryResult result = s_env->QueryModel(tokens);
|
||||
return result.response;
|
||||
public:
|
||||
// Requires argc/argv, hence do not use `SetUpTestSuite`.
|
||||
static void InitEnv(int argc, char** argv) {
|
||||
HWY_ASSERT(s_env == nullptr); // Should only be called once.
|
||||
s_env = new GemmaEnv(argc, argv);
|
||||
const gcpp::ModelConfig& config = s_env->GetGemma()->Config();
|
||||
fprintf(stderr, "Using %s\n", config.Specifier().c_str());
|
||||
}
|
||||
|
||||
static void DeleteEnv() { delete s_env; }
|
||||
|
||||
protected:
|
||||
std::vector<std::string> BatchGemmaReply(
|
||||
const std::vector<std::string>& inputs) {
|
||||
HWY_ASSERT(s_env); // must have called InitEnv()
|
||||
s_env->SetMaxGeneratedTokens(64);
|
||||
s_env->MutableConfig().temperature = 0.0f; // deterministic
|
||||
s_env->MutableConfig().verbosity = 0;
|
||||
// Always use turn structure (WrapAndTokenize).
|
||||
std::vector<std::string> replies;
|
||||
// Using the turn structure worsens results sometimes.
|
||||
// However, some models need the turn structure to work.
|
||||
// It would be good to make these tests more consistent.
|
||||
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
|
||||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
|
||||
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
|
||||
replies.push_back(result.response);
|
||||
}
|
||||
return replies;
|
||||
}
|
||||
// Otherwise, do not use turn structure.
|
||||
std::vector<std::vector<int>> prompts_vector;
|
||||
prompts_vector.reserve(inputs.size());
|
||||
for (const auto& input_string : inputs) {
|
||||
prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
|
||||
}
|
||||
std::vector<PromptTokens> prompt_spans;
|
||||
for (const auto& prompt : prompts_vector) {
|
||||
prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
|
||||
}
|
||||
QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
|
||||
for (const QueryResult& result : s_env->BatchQueryModel(prompts)) {
|
||||
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
|
||||
replies.push_back(result.response);
|
||||
}
|
||||
return replies;
|
||||
}
|
||||
|
||||
void TestQuestions(const char* kQA[][2], size_t num_questions, bool batch) {
|
||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||
if (batch) {
|
||||
std::vector<std::string> inputs;
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
fprintf(stderr, "Batch Question %zu\n\n", i + 1);
|
||||
inputs.push_back(kQA[i][0]);
|
||||
}
|
||||
std::vector<std::string> responses = BatchGemmaReply(inputs);
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
std::string response = responses.at(i);
|
||||
fprintf(stderr, "Batch answer %zu '%s'\n\n", i + 1, response.c_str());
|
||||
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
fprintf(stderr, "Question %zu\n\n", i + 1);
|
||||
std::string response = GemmaReply(kQA[i][0]);
|
||||
fprintf(stderr, "'%s'\n\n", response.c_str());
|
||||
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
|
||||
}
|
||||
}
|
||||
}
|
||||
// Shared state. Requires argc/argv, so construct in main via InitEnv.
|
||||
// Note that the style guide forbids non-local static variables with dtors.
|
||||
static GemmaEnv* s_env;
|
||||
};
|
||||
|
||||
TEST_F(GemmaTest, GeographyBatched) {
|
||||
s_env->MutableConfig().decode_qbatch_size = 3;
|
||||
// 6 are enough to test batching and the loop.
|
||||
GemmaEnv* GemmaTest::s_env = nullptr;
|
||||
|
||||
TEST_F(GemmaTest, Batched) {
|
||||
// Test remainder handling in MatMul (four rows per tile), but avoid a
|
||||
// second batch in debug builds to speed up the test.
|
||||
s_env->MutableConfig().decode_qbatch_size = HWY_IS_DEBUG_BUILD ? 6 : 3;
|
||||
static const char* kQA[][2] = {
|
||||
{"What is the capital of Australia?", "Canberra"},
|
||||
{"What is the capital of Denmark?", "Copenhagen"},
|
||||
{"Ljubljana is the capital of which country?", "Slovenia"},
|
||||
{"Is Chicago a country?", "city"},
|
||||
{"How many states does the US have?", "50"},
|
||||
{"What is the Pacific?", "ocean"},
|
||||
};
|
||||
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
|
||||
TestQuestions(kQA, HWY_MIN(kNum, 3), /*batch=*/false);
|
||||
TestQuestions(kQA, 1, /*batch=*/true);
|
||||
TestQuestions(kQA, kNum, /*batch=*/true);
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, History) {
|
||||
static const char* kQA[][2] = {
|
||||
{"When was the battle of Hastings?", "1066"},
|
||||
};
|
||||
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
|
||||
TestQuestions(kQA, kNum, /*batch=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, Arithmetic) {
|
||||
static const char* kQA[][2] = {
|
||||
{"what is 13 + 14?", "27"},
|
||||
{"what is 7 * 8?", "56"},
|
||||
};
|
||||
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
|
||||
TestQuestions(kQA, kNum, /*batch=*/false);
|
||||
const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
|
||||
std::vector<std::string> inputs;
|
||||
for (size_t i = 0; i < kNum; ++i) {
|
||||
inputs.push_back(kQA[i][0]);
|
||||
}
|
||||
std::vector<std::string> responses = BatchGemmaReply(inputs);
|
||||
HWY_ASSERT(responses.size() == kNum);
|
||||
for (size_t i = 0; i < kNum; ++i) {
|
||||
fprintf(stderr, "#%zu: '%s'\n\n", i, responses[i].c_str());
|
||||
EXPECT_TRUE(responses[i].find(kQA[i][1]) != std::string::npos); // NOLINT
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, Multiturn) {
|
||||
Gemma* model = s_env->GetModel();
|
||||
ASSERT_NE(model, nullptr);
|
||||
const Gemma* model = s_env->GetGemma();
|
||||
const ModelConfig& config = model->Config();
|
||||
size_t abs_pos = 0;
|
||||
std::string response;
|
||||
auto stream_token = [&](int token, float) {
|
||||
if (token == EOS_ID) return true;
|
||||
auto stream_token = [&](size_t query_idx, size_t pos, int token, float) {
|
||||
HWY_ASSERT(query_idx == 0);
|
||||
HWY_ASSERT(pos == abs_pos);
|
||||
++abs_pos;
|
||||
if (config.IsEOS(token)) return true;
|
||||
std::string token_text;
|
||||
EXPECT_TRUE(
|
||||
model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
|
|
@ -173,83 +117,44 @@ TEST_F(GemmaTest, Multiturn) {
|
|||
.temperature = 0.0f,
|
||||
.gen = &s_env->MutableGen(),
|
||||
.verbosity = 2,
|
||||
.stream_token = stream_token,
|
||||
.batch_stream_token = stream_token,
|
||||
};
|
||||
TimingInfo timing_info{.verbosity = 0};
|
||||
// First "say" something slightly unusual.
|
||||
std::string mutable_prompt = "I have a car and its color is turquoise.";
|
||||
std::vector<int> tokens = WrapAndTokenize(model->Tokenizer(), model->Info(),
|
||||
abs_pos, mutable_prompt);
|
||||
std::vector<int> tokens =
|
||||
WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(),
|
||||
config.wrapping, abs_pos, mutable_prompt);
|
||||
|
||||
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
|
||||
timing_info);
|
||||
s_env->MutableEnv(), timing_info);
|
||||
// Note: we do not rewind any <end_of_turn> tokens here. If the model
|
||||
// produced one and WrapAndTokenize() inserts another one, it will just be
|
||||
// duplicated.
|
||||
mutable_prompt = "Please repeat all prior statements.";
|
||||
tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
|
||||
mutable_prompt);
|
||||
tokens = WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(),
|
||||
config.wrapping, abs_pos, mutable_prompt);
|
||||
|
||||
// Reset the `response` string here, then check that the model actually has
|
||||
// access to the previous turn by asking to reproduce.
|
||||
response.clear();
|
||||
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
|
||||
timing_info);
|
||||
fprintf(stderr, "decoded: %s\n", response.c_str());
|
||||
s_env->MutableEnv(), timing_info);
|
||||
fprintf(stderr, "decoded: '%s'\n", response.c_str());
|
||||
bool remembered_turquoise =
|
||||
response.find("turquoise") != std::string::npos; // NOLINT
|
||||
bool remembered_car = response.find("car") != std::string::npos; // NOLINT
|
||||
EXPECT_TRUE(remembered_turquoise || remembered_car);
|
||||
}
|
||||
|
||||
static const char kJingleBells[] = R"(
|
||||
Dashing through the snow
|
||||
In a one-horse open sleigh
|
||||
O'er the fields we go
|
||||
Laughing all the way
|
||||
Bells on bobtails ring
|
||||
Making spirits bright
|
||||
What fun it is to ride and sing
|
||||
A sleighing song tonight
|
||||
)";
|
||||
|
||||
// The "Hay Draft" of the Gettysburg Address.
|
||||
static const char kGettysburg[] = {
|
||||
"Four score and seven years ago our fathers brought forth, upon this "
|
||||
"continent, a new nation, conceived in Liberty, and dedicated to the "
|
||||
"proposition that all men are created equal.\n\nNow we are engaged in a "
|
||||
"great civil war, testing whether that nation, or any nation, so "
|
||||
"conceived, and so dedicated, can long endure. We are met here on a great "
|
||||
"battlefield of that war. We have come to dedicate a portion of it as a "
|
||||
"final resting place for those who here gave their lives that that nation "
|
||||
"might live. It is altogether fitting and proper that we should do "
|
||||
"this.\n\nBut in a larger sense we can not dedicate -- we can not "
|
||||
"consecrate -- we can not hallow this ground. The brave men, living and "
|
||||
"dead, who struggled, here, have consecrated it far above our poor power "
|
||||
"to add or detract. The world will little note, nor long remember, what we "
|
||||
"say here, but can never forget what they did here. It is for us, the "
|
||||
"living, rather to be dedicated here to the unfinished work which they "
|
||||
"have, thus far, so nobly carried on. It is rather for us to be here "
|
||||
"dedicated to the great task remaining before us -- that from these "
|
||||
"honored dead we take increased devotion to that cause for which they here "
|
||||
"gave the last full measure of devotion -- that we here highly resolve "
|
||||
"that these dead shall not have died in vain; that this nation shall have "
|
||||
"a new birth of freedom; and that this government of the people, by the "
|
||||
"people, for the people, shall not perish from the earth.\n"};
|
||||
|
||||
TEST_F(GemmaTest, CrossEntropySmall) {
|
||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||
HWY_ASSERT(s_env->GetGemma() != nullptr);
|
||||
const ModelConfig& config = s_env->GetGemma()->Config();
|
||||
static const char kSmall[] =
|
||||
"The capital of Hungary is Budapest which is located in Europe.";
|
||||
float entropy = s_env->CrossEntropy(kSmall);
|
||||
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
||||
switch (s_env->GetModel()->Info().model) {
|
||||
case gcpp::Model::GEMMA_2B:
|
||||
// 2B v.1 and v.1.1 produce slightly different results.
|
||||
EXPECT_NEAR(entropy, 2.6f, 0.2f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA_7B:
|
||||
// 7B v.1 and v.1.1 produce slightly different results.
|
||||
EXPECT_NEAR(entropy, 2.8f, 0.2f);
|
||||
break;
|
||||
switch (config.model) {
|
||||
case gcpp::Model::GRIFFIN_2B:
|
||||
EXPECT_NEAR(entropy, 2.61f, 0.02f);
|
||||
break;
|
||||
|
|
@ -268,76 +173,14 @@ TEST_F(GemmaTest, CrossEntropySmall) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||
float entropy = s_env->CrossEntropy(kJingleBells);
|
||||
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
||||
switch (s_env->GetModel()->Info().model) {
|
||||
case gcpp::Model::GEMMA_2B:
|
||||
// 2B v.1 and v.1.1 produce slightly different results.
|
||||
EXPECT_NEAR(entropy, 1.9f, 0.2f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA_7B:
|
||||
// 7B v.1 and v.1.1 produce slightly different results.
|
||||
EXPECT_NEAR(entropy, 1.07f, 0.05f);
|
||||
break;
|
||||
case gcpp::Model::GRIFFIN_2B:
|
||||
EXPECT_NEAR(entropy, 1.62f, 0.02f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA2_2B:
|
||||
EXPECT_NEAR(entropy, 0.49f, 0.02f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA2_9B:
|
||||
EXPECT_NEAR(entropy, 0.37f, 0.02f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA2_27B:
|
||||
EXPECT_NEAR(entropy, 0.33f, 0.02f);
|
||||
break;
|
||||
default:
|
||||
FAIL() << "no entropy expectation for this model";
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, CrossEntropyGettysburg) {
|
||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||
float entropy = s_env->CrossEntropy(kGettysburg);
|
||||
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
||||
switch (s_env->GetModel()->Info().model) {
|
||||
case gcpp::Model::GEMMA_2B:
|
||||
// 2B v.1 and v.1.1 produce slightly different results.
|
||||
EXPECT_NEAR(entropy, 1.1f, 0.1f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA_7B:
|
||||
// 7B v.1 and v.1.1 produce slightly different results.
|
||||
EXPECT_NEAR(entropy, 0.75f, 0.1f);
|
||||
break;
|
||||
case gcpp::Model::GRIFFIN_2B:
|
||||
EXPECT_NEAR(entropy, 0.71f, 0.02f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA2_2B:
|
||||
EXPECT_NEAR(entropy, 0.20f, 0.02f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA2_9B:
|
||||
EXPECT_NEAR(entropy, 0.15f, 0.02f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA2_27B:
|
||||
EXPECT_NEAR(entropy, 0.14f, 0.02f);
|
||||
break;
|
||||
default:
|
||||
FAIL() << "no entropy expectation for this model";
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gcpp
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
gcpp::GemmaEnv env(argc, argv);
|
||||
gcpp::s_env = &env;
|
||||
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
|
||||
return RUN_ALL_TESTS();
|
||||
gcpp::InternalInit();
|
||||
gcpp::GemmaTest::InitEnv(argc, argv);
|
||||
int ret = RUN_ALL_TESTS();
|
||||
gcpp::GemmaTest::DeleteEnv();
|
||||
return ret;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,12 +19,11 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h" // Path
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/gemma.h" // Gemma
|
||||
#include "io/io.h" // Path
|
||||
#include "util/args.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
|
@ -89,7 +88,7 @@ void Run(GemmaEnv& env, JsonArgs& json) {
|
|||
"A", "B", "C", "D", //
|
||||
" A", " B", " C", " D", //
|
||||
"**", "**:", ":**", "The", "Answer", "is", ":", "."};
|
||||
const TokenSet accept_set(env.GetModel()->Tokenizer(), accept_strings);
|
||||
const TokenSet accept_set(env.GetGemma()->Tokenizer(), accept_strings);
|
||||
|
||||
for (auto sample : json_data["samples"]) {
|
||||
const int id = sample["i"];
|
||||
|
|
@ -131,8 +130,9 @@ void Run(GemmaEnv& env, JsonArgs& json) {
|
|||
.verbosity = env.Verbosity(),
|
||||
.stream_token = stream_token,
|
||||
};
|
||||
env.GetModel()->Generate(runtime_config, prompt, /*pos=*/0,
|
||||
env.MutableKVCache(), timing_info);
|
||||
env.GetGemma()->Generate(runtime_config, prompt, /*pos=*/0,
|
||||
env.MutableKVCache(), env.MutableEnv(),
|
||||
timing_info);
|
||||
|
||||
std::string output_string = env.StringFromTokens(predicted_token_ids);
|
||||
fprintf(stderr, "Correct %s, model '%s'\n", correct_answer.c_str(),
|
||||
|
|
|
|||
|
|
@ -10,13 +10,11 @@ cc_binary(
|
|||
name = "hello_world",
|
||||
srcs = ["run.cc"],
|
||||
deps = [
|
||||
# Placeholder for internal dep, do not remove.,
|
||||
"//:app",
|
||||
"//:args",
|
||||
"//:gemma_args",
|
||||
"//:gemma_lib",
|
||||
"//:threading",
|
||||
"//:threading_context",
|
||||
"//:tokenizer",
|
||||
"@highway//:hwy",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,19 +23,15 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// Placeholder for internal header, do not modify.
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/gemma_args.h" // LoaderArgs
|
||||
#include "gemma/tokenizer.h"
|
||||
#include "util/app.h" // LoaderArgs
|
||||
#include "util/args.h"
|
||||
#include "util/threading.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
int main(int argc, char **argv) { {
|
||||
// Placeholder for internal init, do not modify.
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
gcpp::LoaderArgs loader(argc, argv);
|
||||
gcpp::InferenceArgs inference(argc, argv);
|
||||
gcpp::AppArgs app(argc, argv);
|
||||
|
|
|
|||
|
|
@ -10,10 +10,10 @@ cc_library(
|
|||
name = "gemma",
|
||||
hdrs = ["gemma.hpp"],
|
||||
deps = [
|
||||
"//:app",
|
||||
"//:gemma_args",
|
||||
"//:gemma_lib",
|
||||
"//:ops",
|
||||
"//:threading",
|
||||
"//:matmul",
|
||||
"//:threading_context",
|
||||
"//:tokenizer",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
|
|
@ -24,15 +24,6 @@ cc_binary(
|
|||
srcs = ["run.cc"],
|
||||
deps = [
|
||||
":gemma",
|
||||
# Placeholder for internal dep, do not remove.,
|
||||
"//:app",
|
||||
"//:args",
|
||||
"//:common",
|
||||
"//:gemma_lib",
|
||||
"//:ops",
|
||||
"//:threading",
|
||||
"//:tokenizer",
|
||||
"@highway//:hwy",
|
||||
"@highway//:thread_pool",
|
||||
"//:gemma_args",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,10 +14,11 @@
|
|||
|
||||
cmake_minimum_required(VERSION 3.11)
|
||||
project(simplified_gemma)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9414b48aeec251b69e6cadbfa42bebb5ddae1c34)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
|
||||
FetchContent_MakeAvailable(sentencepiece)
|
||||
|
|
@ -31,7 +32,7 @@ if (NOT BUILD_MODE)
|
|||
endif()
|
||||
if (BUILD_MODE STREQUAL "local")
|
||||
# Relative path to gemma.cpp from examples/simplified_gemma/build/
|
||||
FetchContent_Declare(gemma SOURCE_DIR ../../..)
|
||||
FetchContent_Declare(gemma SOURCE_DIR ../../..)
|
||||
else()
|
||||
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e)
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ gemma.cpp specifying the tokenizer, compressed weights file, and model type, for
|
|||
example:
|
||||
|
||||
```sh
|
||||
./simplified_gemma --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it
|
||||
./simplified_gemma --tokenizer tokenizer.spm --weights 2b-it-sfp.sbs --model 2b-it
|
||||
```
|
||||
|
||||
Should print a greeting to the terminal:
|
||||
|
|
|
|||
|
|
@ -24,55 +24,39 @@
|
|||
#include <vector>
|
||||
|
||||
#include "third_party/gemma_cpp/gemma/gemma.h"
|
||||
#include "third_party/gemma_cpp/gemma/gemma_args.h" // LoaderArgs
|
||||
#include "third_party/gemma_cpp/gemma/tokenizer.h"
|
||||
#include "third_party/gemma_cpp/ops/matmul.h"
|
||||
#include "third_party/gemma_cpp/util/app.h" // LoaderArgs
|
||||
#include "third_party/gemma_cpp/util/threading.h"
|
||||
#include "third_party/gemma_cpp/util/threading_context.h"
|
||||
#include "third_party/highway/hwy/base.h"
|
||||
|
||||
class SimplifiedGemma {
|
||||
public:
|
||||
SimplifiedGemma(const gcpp::LoaderArgs& loader,
|
||||
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs(),
|
||||
const gcpp::AppArgs& app = gcpp::AppArgs())
|
||||
: loader_(loader),
|
||||
inference_(inference),
|
||||
app_(app),
|
||||
topology_(gcpp::CreateTopology(app_)),
|
||||
pools_(gcpp::CreatePools(topology_, app_)),
|
||||
env_(topology_, pools_),
|
||||
model_(gcpp::CreateGemma(loader_, env_)) {
|
||||
Init();
|
||||
}
|
||||
|
||||
SimplifiedGemma(int argc, char** argv)
|
||||
: loader_(argc, argv, /*validate=*/true),
|
||||
inference_(argc, argv),
|
||||
app_(argc, argv),
|
||||
topology_(gcpp::CreateTopology(app_)),
|
||||
pools_(gcpp::CreatePools(topology_, app_)),
|
||||
env_(topology_, pools_),
|
||||
model_(gcpp::CreateGemma(loader_, env_)) {
|
||||
Init();
|
||||
}
|
||||
|
||||
void Init() {
|
||||
// Instantiate model and KV Cache
|
||||
kv_cache_ = gcpp::KVCache::Create(model_.GetModelConfig(),
|
||||
inference_.prefill_tbatch_size);
|
||||
|
||||
const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
|
||||
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
|
||||
: ctx_(UpdateArgs(threading, inference)),
|
||||
env_(ctx_),
|
||||
gemma_(loader, inference, ctx_),
|
||||
kv_cache_(gemma_.Config(), inference, ctx_.allocator) {
|
||||
// Initialize random number generator
|
||||
std::random_device rd;
|
||||
gen_.seed(rd());
|
||||
}
|
||||
|
||||
SimplifiedGemma(int argc, char** argv)
|
||||
: SimplifiedGemma(gcpp::LoaderArgs(argc, argv),
|
||||
gcpp::ThreadingArgs(argc, argv),
|
||||
gcpp::InferenceArgs(argc, argv)) {}
|
||||
|
||||
void Generate(std::string& prompt, size_t max_generated_tokens = 1024,
|
||||
float temperature = 0.7,
|
||||
const std::set<int>& reject_tokens = {}) {
|
||||
size_t generated = 0;
|
||||
|
||||
const std::vector<int> tokens = gcpp::WrapAndTokenize(
|
||||
model_.Tokenizer(), loader_.Info(), generated, prompt);
|
||||
gemma_.Tokenizer(), gemma_.ChatTemplate(),
|
||||
gemma_.Config().wrapping, generated, prompt);
|
||||
const size_t prompt_size = tokens.size();
|
||||
|
||||
// This callback function gets invoked every time a token is generated
|
||||
|
|
@ -80,9 +64,9 @@ class SimplifiedGemma {
|
|||
++generated;
|
||||
if (generated < prompt_size) {
|
||||
// print feedback
|
||||
} else if (!this->model_.GetModelConfig().IsEOS(token)) {
|
||||
} else if (!gemma_.Config().IsEOS(token)) {
|
||||
std::string token_text;
|
||||
HWY_ASSERT(this->model_.Tokenizer().Decode({token}, &token_text));
|
||||
HWY_ASSERT(gemma_.Tokenizer().Decode({token}, &token_text));
|
||||
std::cout << token_text << std::flush;
|
||||
}
|
||||
return true;
|
||||
|
|
@ -100,19 +84,15 @@ class SimplifiedGemma {
|
|||
return !reject_tokens.contains(token);
|
||||
},
|
||||
};
|
||||
model_.Generate(runtime_config, tokens, 0, kv_cache_, timing_info);
|
||||
gemma_.Generate(runtime_config, tokens, 0, kv_cache_, env_, timing_info);
|
||||
}
|
||||
~SimplifiedGemma() = default;
|
||||
|
||||
private:
|
||||
gcpp::LoaderArgs loader_;
|
||||
gcpp::InferenceArgs inference_;
|
||||
gcpp::AppArgs app_;
|
||||
gcpp::BoundedTopology topology_;
|
||||
gcpp::NestedPools pools_;
|
||||
gcpp::ThreadingContext ctx_;
|
||||
gcpp::MatMulEnv env_;
|
||||
gcpp::Gemma model_;
|
||||
gcpp::Gemma gemma_;
|
||||
gcpp::KVCache kv_cache_;
|
||||
std::mt19937 gen_;
|
||||
std::string validation_error_;
|
||||
};
|
||||
};
|
||||
|
|
|
|||
|
|
@ -17,30 +17,25 @@
|
|||
|
||||
#include <string>
|
||||
|
||||
// Placeholder for internal header, do not modify.
|
||||
#include "third_party/gemma_cpp/examples/simplified_gemma/gemma.hpp"
|
||||
#include "util/app.h" // LoaderArgs
|
||||
#include "gemma/gemma_args.h" // LoaderArgs
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
{
|
||||
// Placeholder for internal init, do not modify.
|
||||
}
|
||||
|
||||
// Standard usage: LoaderArgs takes argc and argv as input, then parses
|
||||
// necessary flags.
|
||||
gcpp::LoaderArgs loader(argc, argv, /*validate=*/true);
|
||||
gcpp::LoaderArgs loader(argc, argv);
|
||||
|
||||
// Optional: LoaderArgs can also take tokenizer and weights paths directly.
|
||||
//
|
||||
// gcpp::LoaderArgs loader("/path/to/tokenizer", "/path/to/weights",
|
||||
// "model_identifier");
|
||||
|
||||
// Optional: InferenceArgs and AppArgs can be passed in as well. If not
|
||||
// Optional: ThreadingArgs and InferenceArgs can be passed in as well. If not
|
||||
// specified, default values will be used.
|
||||
//
|
||||
// gcpp::InferenceArgs inference(argc, argv);
|
||||
// gcpp::AppArgs app(argc, argv);
|
||||
// SimplifiedGemma gemma(loader, inference, app);
|
||||
// gcpp::ThreadingArgs threading(argc, argv);
|
||||
// SimplifiedGemma gemma(loader, threading, inference);
|
||||
|
||||
SimplifiedGemma gemma(loader);
|
||||
std::string prompt = "Write a greeting to the world.";
|
||||
|
|
|
|||
|
|
@ -16,104 +16,199 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
|
||||
|
||||
#include <math.h> // sqrtf
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "compression/shared.h" // BF16
|
||||
#include "gemma/configs.h"
|
||||
#include "ops/matmul.h" // MatMulEnv
|
||||
#include "ops/ops.h" // CreateInvTimescale
|
||||
#include "util/allocator.h" // RowVectorBatch
|
||||
#include "util/threading.h"
|
||||
#include "hwy/base.h" // HWY_DASSERT
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
|
||||
#include "gemma/configs.h" // ModelConfig
|
||||
#include "ops/matmul.h" // MatMulEnv
|
||||
#include "ops/ops.h" // CreateInvTimescale
|
||||
#include "util/allocator.h" // Allocator
|
||||
#include "util/basics.h" // BF16
|
||||
#include "util/mat.h" // MatStorageT
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
struct Activations {
|
||||
explicit Activations(const ModelConfig& config)
|
||||
: weights_config(config),
|
||||
layer_config(config.layer_configs[0]),
|
||||
seq_len(config.seq_len),
|
||||
cache_pos_size(config.CachePosSize()) {}
|
||||
struct GriffinActivations {
|
||||
GriffinActivations(const ModelConfig& config, size_t batch_size,
|
||||
const Allocator& allocator)
|
||||
: griffin_x(
|
||||
MatFactory("griffin_x", batch_size, config.model_dim, allocator)),
|
||||
griffin_y(
|
||||
MatFactory("griffin_y", batch_size, config.model_dim, allocator)),
|
||||
griffin_gate_x(MatFactory("griffin_gate_x", batch_size,
|
||||
config.model_dim, allocator)),
|
||||
griffin_multiplier(MatFactory("griffin_mul", batch_size,
|
||||
config.model_dim, allocator)) {}
|
||||
|
||||
RowVectorBatch<float> x; // input
|
||||
RowVectorBatch<float> q; // query, also KV if MHA.
|
||||
RowVectorBatch<float> logits;
|
||||
void SetBatchSize(size_t batch_size) {
|
||||
if (griffin_x.Rows() == 0) return;
|
||||
griffin_x.OverrideRows(batch_size);
|
||||
griffin_y.OverrideRows(batch_size);
|
||||
griffin_gate_x.OverrideRows(batch_size);
|
||||
griffin_multiplier.OverrideRows(batch_size);
|
||||
}
|
||||
|
||||
// Attention
|
||||
RowVectorBatch<float> pre_att_rms_out;
|
||||
RowVectorBatch<float> att; // attention vector
|
||||
RowVectorBatch<float> att_out; // attention output
|
||||
MatStorageT<float> griffin_x;
|
||||
MatStorageT<float> griffin_y;
|
||||
MatStorageT<float> griffin_gate_x;
|
||||
MatStorageT<float> griffin_multiplier;
|
||||
};
|
||||
|
||||
struct AttentionActivations {
|
||||
// Returns the scale value to use for the query in the attention computation.
|
||||
// Also called by ops_test.
|
||||
static inline float ChooseQueryScale(const ModelConfig& config) {
|
||||
if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
|
||||
return 1.0f / sqrtf(static_cast<float>(config.model_dim /
|
||||
config.layer_configs[0].heads));
|
||||
// QueryScaleType::SqrtKeySize
|
||||
return 1.0f / sqrtf(static_cast<float>(config.layer_configs[0].qkv_dim));
|
||||
}
|
||||
|
||||
AttentionActivations(
|
||||
const ModelConfig& config, const LayerConfig& layer_config,
|
||||
size_t batch_size, size_t seq_len, const Allocator& allocator,
|
||||
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
|
||||
: config(config),
|
||||
|
||||
// `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
|
||||
// and does not use an external KV cache.
|
||||
q(MatFactory("q", batch_size,
|
||||
config.vocab_size == 0
|
||||
? layer_config.heads * 3 * layer_config.qkv_dim
|
||||
: layer_config.heads * layer_config.qkv_dim,
|
||||
allocator)),
|
||||
|
||||
pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
|
||||
config.model_dim, allocator)),
|
||||
att(MatFactory("att", batch_size, layer_config.heads * seq_len,
|
||||
allocator)),
|
||||
att_out(MatFactory("att_out", batch_size,
|
||||
layer_config.heads * layer_config.qkv_dim,
|
||||
allocator)),
|
||||
att_sums(
|
||||
MatFactory("att_sums", batch_size, config.model_dim, allocator)),
|
||||
|
||||
inv_timescale(
|
||||
CreateInvTimescale(allocator, layer_config.qkv_dim,
|
||||
layer_config.post_qk == PostQKType::HalfRope)),
|
||||
inv_timescale_global(CreateInvTimescale(
|
||||
allocator, layer_config.qkv_dim,
|
||||
layer_config.post_qk == PostQKType::HalfRope, 1000000.0)),
|
||||
|
||||
div_seq_len(static_cast<uint32_t>(seq_len)),
|
||||
div_heads(static_cast<uint32_t>(layer_config.heads)),
|
||||
query_scale(ChooseQueryScale(config)) {
|
||||
// Batch size can be 0 in experimental code so do not assert.
|
||||
if (batch_size == 0) {
|
||||
static std::atomic_flag warned = ATOMIC_FLAG_INIT;
|
||||
if (!warned.test_and_set()) {
|
||||
HWY_WARN("Creating mostly empty activations with a batch_size of 0.");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// For MatMul outputs, precompute their row pointers.
|
||||
// If we forget any MatMul outputs here, debug builds print a warning but
|
||||
// fill them in each MatMul call.
|
||||
q.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
att_sums.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
}
|
||||
|
||||
void SetBatchSize(size_t batch_size) {
|
||||
q.OverrideRows(batch_size);
|
||||
|
||||
pre_att_rms_out.OverrideRows(batch_size);
|
||||
att.OverrideRows(batch_size);
|
||||
att_out.OverrideRows(batch_size);
|
||||
att_sums.OverrideRows(batch_size);
|
||||
}
|
||||
|
||||
const ModelConfig& config;
|
||||
|
||||
MatStorageT<float> q; // query
|
||||
|
||||
MatStorageT<float> pre_att_rms_out;
|
||||
MatStorageT<float> att; // attention vector
|
||||
MatStorageT<float> att_out; // attention output
|
||||
// Accumulation of attention outputs over heads
|
||||
RowVectorBatch<float> att_sums;
|
||||
|
||||
// Gated FFW
|
||||
RowVectorBatch<BF16> bf_pre_ffw_rms_out;
|
||||
RowVectorBatch<float> C1;
|
||||
RowVectorBatch<float> C2;
|
||||
RowVectorBatch<float> ffw_out;
|
||||
|
||||
// Griffin
|
||||
RowVectorBatch<float> griffin_x;
|
||||
RowVectorBatch<float> griffin_y;
|
||||
RowVectorBatch<float> griffin_gate_x;
|
||||
RowVectorBatch<float> griffin_multiplier;
|
||||
MatStorageT<BF16> att_sums;
|
||||
|
||||
// Rope
|
||||
RowVectorBatch<float> inv_timescale;
|
||||
RowVectorBatch<float> inv_timescale_global;
|
||||
MatStorageT<float> inv_timescale;
|
||||
MatStorageT<float> inv_timescale_global;
|
||||
|
||||
// Dynamic because no default ctor and only initialized in `Allocate`.
|
||||
MatMulEnv* env;
|
||||
hwy::Divisor div_seq_len;
|
||||
// Unfortunately, some models (Griffin) have non-power-of-two heads.
|
||||
hwy::Divisor div_heads;
|
||||
float query_scale;
|
||||
};
|
||||
|
||||
PostQKType post_qk = PostQKType::Rope;
|
||||
// And the config.
|
||||
const ModelConfig& weights_config;
|
||||
const LayerConfig& layer_config;
|
||||
size_t seq_len;
|
||||
size_t cache_pos_size = 0;
|
||||
struct Activations {
|
||||
Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
|
||||
const Allocator& allocator,
|
||||
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
|
||||
: layer_config(config.layer_configs[0]),
|
||||
|
||||
void Allocate(size_t batch_size, MatMulEnv* env) {
|
||||
post_qk = layer_config.post_qk;
|
||||
const size_t model_dim = weights_config.model_dim;
|
||||
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
|
||||
const size_t vocab_size = weights_config.vocab_size;
|
||||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
const size_t heads = layer_config.heads;
|
||||
x(MatFactory("x", batch_size, config.model_dim, allocator)),
|
||||
logits(MatFactory("logits", batch_size, config.vocab_size, allocator)),
|
||||
|
||||
x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
q = RowVectorBatch<float>(
|
||||
Extents2D(batch_size, heads * layer_config.QStride()));
|
||||
if (vocab_size > 0) {
|
||||
logits = RowVectorBatch<float>(Extents2D(batch_size, vocab_size));
|
||||
}
|
||||
pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
|
||||
config.model_dim, allocator)),
|
||||
C1(MatFactory("C1", batch_size, layer_config.ff_hidden_dim, allocator)),
|
||||
C2(MatFactory("C2", batch_size, layer_config.ff_hidden_dim, allocator)),
|
||||
ffw_out(MatFactory("ffw_out", batch_size, config.model_dim, allocator)),
|
||||
|
||||
pre_att_rms_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
att = RowVectorBatch<float>(
|
||||
Extents2D(batch_size, heads * weights_config.seq_len));
|
||||
att_out = RowVectorBatch<float>(Extents2D(batch_size, heads * qkv_dim));
|
||||
att_sums = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
attention(config, layer_config, batch_size, seq_len, allocator,
|
||||
row_ptrs),
|
||||
griffin(config, config.model == Model::GRIFFIN_2B ? batch_size : 0,
|
||||
allocator) {
|
||||
HWY_ASSERT(batch_size != 0);
|
||||
|
||||
bf_pre_ffw_rms_out = RowVectorBatch<BF16>(Extents2D(batch_size, model_dim));
|
||||
C1 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
|
||||
C2 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
|
||||
ffw_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
// For MatMul outputs, precompute their row pointers.
|
||||
// If we forget any MatMul outputs here, debug builds print a warning but
|
||||
// fill them in each MatMul call.
|
||||
x.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
logits.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
C1.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
C2.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
ffw_out.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
|
||||
if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
|
||||
griffin_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
griffin_y = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
griffin_gate_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
griffin_multiplier =
|
||||
RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
}
|
||||
|
||||
inv_timescale = CreateInvTimescale(layer_config.qkv_dim,
|
||||
post_qk == PostQKType::HalfRope);
|
||||
inv_timescale_global =
|
||||
CreateInvTimescale(qkv_dim, post_qk == PostQKType::HalfRope, 1000000.0);
|
||||
|
||||
this->env = env;
|
||||
// Note that BindC on any MatMul output considerably slows down Prefill.
|
||||
}
|
||||
|
||||
// Negligible CPU time.
|
||||
void SetBatchSize(size_t batch_size) {
|
||||
x.OverrideRows(batch_size);
|
||||
logits.OverrideRows(batch_size);
|
||||
|
||||
pre_ffw_rms_out.OverrideRows(batch_size);
|
||||
C1.OverrideRows(batch_size);
|
||||
C2.OverrideRows(batch_size);
|
||||
ffw_out.OverrideRows(batch_size);
|
||||
|
||||
attention.SetBatchSize(batch_size);
|
||||
griffin.SetBatchSize(batch_size);
|
||||
}
|
||||
|
||||
const LayerConfig& layer_config;
|
||||
|
||||
MatStorageT<float> x; // input
|
||||
MatStorageT<float> logits;
|
||||
|
||||
// Gated FFW
|
||||
MatStorageT<BF16> pre_ffw_rms_out;
|
||||
// Norm may be large, so prefer to keep as f32.
|
||||
MatStorageT<float> C1;
|
||||
MatStorageT<float> C2;
|
||||
MatStorageT<BF16> ffw_out;
|
||||
|
||||
AttentionActivations attention;
|
||||
GriffinActivations griffin;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -0,0 +1,358 @@
|
|||
// Copyright 2025 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/configs.h" // kMaxQKVDim
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||
// which we pass the filename via macro 'argument'.
|
||||
// clang-format off
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE "gemma/attention.cc" // NOLINT
|
||||
// clang-format on
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "compression/compress-inl.h"
|
||||
#include "ops/ops-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
// Computes Q.K scores, which are "logits" (or scores) stored to att.
|
||||
// `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
|
||||
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
|
||||
const hwy::Divisor& div_seq_len,
|
||||
const float* HWY_RESTRICT q,
|
||||
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
|
||||
const size_t worker) {
|
||||
PROFILER_ZONE2(worker, "Gen.Attention.QDotK");
|
||||
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
|
||||
// Slightly faster: no wraparound.
|
||||
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
|
||||
const float score = Dot(q, k.Row(pos), k.Cols());
|
||||
att[pos] = score;
|
||||
}
|
||||
} else {
|
||||
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
|
||||
const size_t pos_modulo = div_seq_len.Remainder(pos);
|
||||
const float score = Dot(q, k.Row(pos_modulo), k.Cols());
|
||||
att[pos_modulo] = score;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void PositionalEncodingQK(float* qk, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer,
|
||||
const AttentionActivations& activations,
|
||||
const size_t worker, const size_t pos,
|
||||
const float mul = 1.0f) {
|
||||
const size_t qkv_dim = layer.layer_config.qkv_dim;
|
||||
const PostQKType& post_qk = layer.layer_config.post_qk;
|
||||
// qk is either q or k, so qkv_dim is the length we operate on.
|
||||
const float* inv_timescale = activations.inv_timescale.PackedScale1();
|
||||
const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx);
|
||||
// TODO: add a config flag instead of hardcoding the model.
|
||||
if (is_global_layer && IsVLM(activations.config.model)) {
|
||||
inv_timescale = activations.inv_timescale_global.PackedScale1();
|
||||
}
|
||||
// PostQKType::Rope
|
||||
if (post_qk == PostQKType::HalfRope) {
|
||||
Rope(qk, qkv_dim / 2, inv_timescale, pos, worker);
|
||||
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, worker);
|
||||
} else {
|
||||
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, worker);
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulates the sum of v (from `kv_cache`) * probability (`att`) into
|
||||
// `att_out`. Equivalent in gemma/modules.py:
|
||||
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
|
||||
// `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
|
||||
static HWY_INLINE void WeightedSumV(
|
||||
const size_t start_pos, const size_t last_pos,
|
||||
const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
|
||||
const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, const size_t worker) {
|
||||
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
|
||||
// Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
|
||||
// we supported non-transposed B.
|
||||
// TODO: 2..4x unroll
|
||||
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), worker);
|
||||
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
|
||||
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), worker);
|
||||
}
|
||||
} else {
|
||||
{
|
||||
const size_t pos_mod = div_seq_len.Remainder(start_pos);
|
||||
MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
|
||||
}
|
||||
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
|
||||
const size_t pos_mod = div_seq_len.Remainder(pos);
|
||||
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Calculates the attention outputs for a single q, which may be updated
|
||||
// in place for RMSNorm.
|
||||
void SingleDotSoftmaxWeightedSum(
|
||||
const size_t pos, const size_t start_pos, const size_t last_pos,
|
||||
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
|
||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||
const AttentionActivations& activations, float* HWY_RESTRICT att,
|
||||
float* HWY_RESTRICT att_out, const size_t worker) {
|
||||
const float att_cap = activations.config.att_cap;
|
||||
const float query_scale = activations.query_scale;
|
||||
const size_t seq_len =
|
||||
static_cast<size_t>(activations.div_seq_len.GetDivisor());
|
||||
|
||||
// Apply rope and scaling to Q.
|
||||
if (layer.query_norm_scale.HasPtr()) {
|
||||
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
|
||||
RMSNormInplace(weights_t->PackedScale1(), 0, q,
|
||||
layer.layer_config.qkv_dim, worker);
|
||||
});
|
||||
}
|
||||
|
||||
PositionalEncodingQK(q, layer_idx, layer, activations, worker, pos,
|
||||
query_scale);
|
||||
|
||||
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, worker);
|
||||
|
||||
// SoftMax with optional SoftCap yields "probabilities" in att.
|
||||
const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
|
||||
MaybeLogitsSoftCap(att_cap, att, att_len, worker);
|
||||
Softmax(att, att_len, worker, /*temperature=*/1.0f);
|
||||
|
||||
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
|
||||
worker);
|
||||
}
|
||||
|
||||
// The attention window usually starts at 0 unless `pos` is larger than
|
||||
// the attention window size, then it is `pos` - window_size + 1.
|
||||
static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
|
||||
size_t layer_idx) {
|
||||
const size_t att_window_size = config.attention_window_sizes[layer_idx];
|
||||
return pos - HWY_MIN(att_window_size - 1, pos);
|
||||
}
|
||||
|
||||
// Runs Q.K, softmax, and weighted-V for every (token, query, head) task in
// parallel across `pools`. Reads `activations.q` and writes per-head results
// into `activations.att` (scores) and `activations.att_out` (outputs).
void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
                           const LayerWeightsPtrs& layer,
                           AttentionActivations& activations, QBatch& qbatch,
                           NestedPools& pools) {
  static const uint32_t HWY_MAYBE_UNUSED zone_id_par =
      PROFILER_ADD_ZONE("Gen.Attention.DotSoftmax.par");

  const hwy::Divisor div_qbatch(qbatch.Size());
  const LayerConfig& layer_config = layer.layer_config;
  const size_t qkv_dim = layer_config.qkv_dim;

  // A "head group" in the context of GQA refers to a collection of query
  // heads that share the same key and value heads.
  const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;

  const size_t cache_layer_size = layer_config.CacheLayerSize();
  const size_t seq_len =
      static_cast<size_t>(activations.div_seq_len.GetDivisor());
  // All layers should have the same number of heads.
  HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);

  // For each head/token/query, compute Q.K, softmax, and weighted V.
  const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
    // Task index is (token/query index) * heads + head.
    const size_t tq_idx = activations.div_heads.Divide(task);
    const size_t head = activations.div_heads.Remainder(task);
#if PROFILER_ENABLED
    const hwy::Zone zone(worker, zone_id_par);
#endif

    // Decompose the interleaved index into query index and token batch index.
    const size_t qi = div_qbatch.Remainder(tq_idx);
    const size_t batch_idx = div_qbatch.Divide(tq_idx);
    auto& kv_cache = qbatch.KV(qi).kv_cache;

    // Find the token position in the query and calculate
    // the range of cache positions to attend to.
    const size_t pos = qbatch.Pos(qi) + batch_idx;
    const size_t start_pos = StartPos(pos, activations.config, layer_idx);
    size_t last_pos = pos;
    const size_t prefix_end = qbatch.PrefixEnd(qi);
    if (prefix_end > 0 && prefix_end - 1 > last_pos) {
      // last_pos in QDotK and WeightedSumV is inclusive.
      last_pos = prefix_end - 1;
    }

    // Per-head slices of this row's activations.
    float* HWY_RESTRICT q = activations.q.Row(tq_idx) + head * qkv_dim;
    float* HWY_RESTRICT att = activations.att.Row(tq_idx) + head * seq_len;
    float* HWY_RESTRICT att_out =
        activations.att_out.Row(tq_idx) + head * qkv_dim;

    // Make strided read-only views into the kv cache for
    // this query and head. K and V are stored adjacently (hence * 2), and
    // GQA maps each group of kHeadGroups query heads to one kv head.
    const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
    const size_t kv_head_offset = layer_idx * cache_layer_size + head_offset;
    MatPtrT<KV_t> k("k_view", Extents2D(seq_len, qkv_dim));
    k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
    MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
    v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());

    SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
                                layer, activations, att, att_out, worker);
  };

  {
    PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
    const size_t pkg_idx = 0;
    // Full parallelism is helpful, SmallParallelFor is insufficient.
    ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
                pools, pkg_idx, func);
  }
}
|
||||
|
||||
// Different functions use different naming conventions for the number of
|
||||
// tokens. Functions that are query-independent, such as RMSNorm*, call the
|
||||
// count `num_interleaved`. Functions that are query-dependent, such as
|
||||
// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
|
||||
// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.
|
||||
|
||||
// Fills activations.q and writes K/V to the KV cache.
// Reads `activations.pre_att_rms_out`; `flags` is currently unreferenced in
// this body (NOTE(review): presumably reserved for caller options — confirm).
static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
                                  const LayerWeightsPtrs& layer,
                                  AttentionActivations& activations,
                                  const QBatch& qbatch, const int flags,
                                  MatMulEnv& env) {
  PROFILER_ZONE("Gen.Attention.QKV");
  const hwy::Divisor div_qbatch(qbatch.Size());
  // One row per (query, token) pair.
  const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
  const LayerConfig& layer_config = layer.layer_config;
  const size_t qkv_dim = layer_config.qkv_dim;
  const size_t kv_heads = layer_config.kv_heads;
  const size_t cache_layer_size = layer_config.CacheLayerSize();

  // The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim,
  // model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows.
  CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w1,
             /*add=*/nullptr, env, activations.q);

  // Set up MatMul row pointers for writing to KV, which consists of
  // `kv_heads` pairs of (k, v) vectors. This safely handles wraparound
  // because rows are computed modulo seq_len.
  MatPtrT<KV_t> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
                                        layer.qkv_einsum_w2.Rows()));
  for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
       ++interleaved_idx) {
    const size_t qi = div_qbatch.Remainder(interleaved_idx);
    const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
    // Ring-buffer row for this token's position.
    const size_t cache_pos =
        activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
    env.row_ptrs[2][interleaved_idx] = reinterpret_cast<uint8_t*>(
        qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
  }
  kv_rows.AttachRowPtrs(env.row_ptrs[2].get());
  // Write K and V directly into the cache rows set up above.
  CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
             /*add=*/nullptr, env, kv_rows);

  // Apply positional encodings for K.
  // Note that 2D parallelism is not worth the fork/join overhead because the
  // tasks are very lightweight.
  env.ctx.pools.Pool(0).Run(
      0, kv_heads * num_interleaved,
      [&](uint64_t task, size_t thread) HWY_ATTR {
        const size_t head = task % kv_heads;
        const size_t interleaved_idx = task / kv_heads;
        const size_t qi = div_qbatch.Remainder(interleaved_idx);
        const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
        const size_t pos = qbatch.Pos(qi) + batch_idx;
        const size_t cache_pos = activations.div_seq_len.Remainder(pos);
        auto& kv_cache = qbatch.KV(qi).kv_cache;
        // Adjacent (k, v) pair for this head within the layer's cache row.
        KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
                                layer_idx * cache_layer_size +
                                head * qkv_dim * 2;

        // Decompress the stored (k, v) pair to f32 for norm + RoPE.
        HWY_ALIGN float kv_f32[2 * kMaxQKVDim];
        const hn::ScalableTag<float> df;
        DecompressAndZeroPad(df, MakeSpan(kv, 2 * qkv_dim), 0, kv_f32,
                             2 * qkv_dim);

        // Apply further processing to K.
        if (layer.key_norm_scale.HasPtr()) {
          CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
            RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim,
                           thread);
          });
        }

        PositionalEncodingQK(kv_f32, layer_idx, layer, activations, thread,
                             pos);
        // Recompress the processed pair back into the cache in place.
        CompressPerThread tls;
        Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
      });
}
|
||||
|
||||
// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
|
||||
// head_dim (`qkv_dim`) into output (`layer_out`).
|
||||
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
|
||||
AttentionActivations& activations,
|
||||
MatMulEnv& env) {
|
||||
PROFILER_ZONE("Gen.Attention.SumHeads");
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
// att_weights and att_out are concatenated heads, each of length
|
||||
// layer_config.qkv_dim. Thus the [num_interleaved,
|
||||
// layer_config.model_dim] matmul output is the sum over heads. Compare
|
||||
// gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD',
|
||||
// encoded)
|
||||
HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 &&
|
||||
layer_config.qkv_dim != 0);
|
||||
const float* add = layer_config.softmax_attn_output_biases
|
||||
? layer.attention_output_biases.PackedScale1()
|
||||
: nullptr;
|
||||
CallMatMul(activations.att_out, layer.att_weights, add, env,
|
||||
activations.att_sums);
|
||||
}
|
||||
|
||||
// Full attention for one layer: QKV projection + KV-cache write, then the
// per-head dot-softmax-weighted-sum, then the output projection over heads.
void GemmaAttention(size_t num_tokens, const size_t layer_idx,
                    const LayerWeightsPtrs& layer,
                    AttentionActivations& activations, QBatch& qbatch,
                    MatMulEnv& env, int flags) {
  const LayerConfig& layer_config = layer.layer_config;
  HWY_DASSERT(!layer_config.IsMHA());  // No longer supported.
  // GQA requires each kv head to serve a whole number of query heads.
  HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,
                "query heads must be a multiple of key-value heads");
  (void)layer_config;  // only used in HWY_DASSERT

  ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
  DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
                        env.ctx.pools);
  SumHeads(layer, activations, env);
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
// Copyright 2025 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ATTENTION_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ATTENTION_H_

// Declares GemmaAttention for all SIMD targets.

#include <stddef.h>

#include "gemma/gemma.h"
#include "hwy/highway.h"

namespace gcpp {

// Passed to HWY_VISIT_TARGETS; declares for one target.
// Each expansion declares the three attention entry points inside that
// target's namespace so per-target code can call them directly.
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE)                              \
  namespace NAMESPACE {                                                      \
  void SingleDotSoftmaxWeightedSum(                                          \
      const size_t pos, const size_t start_pos, const size_t last_pos,       \
      float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
      size_t layer_idx, const LayerWeightsPtrs& layer,                       \
      const AttentionActivations& activations, float* HWY_RESTRICT att,      \
      float* HWY_RESTRICT att_out, size_t worker);                           \
                                                                             \
  void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx,      \
                             const LayerWeightsPtrs& layer,                  \
                             AttentionActivations& activations,             \
                             QBatch& qbatch, NestedPools& pools);            \
                                                                             \
  void GemmaAttention(size_t num_tokens, const size_t layer_idx,             \
                      const LayerWeightsPtrs& layer,                         \
                      AttentionActivations& activations, QBatch& qbatch,     \
                      MatMulEnv& env, int flags);                            \
  /* NOLINTNEXTLINE(google-readability-namespace-comments) */                \
  }  // namespace NAMESPACE

// Function declarations for each SIMD target. Allows direct call from the
// per-target namespace. We may later replace this with dynamic dispatch if
// the overhead is acceptable.
HWY_VISIT_TARGETS(GEMMA_DECL_ATTENTION)

#undef GEMMA_DECL_ATTENTION

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_ATTENTION_H_
|
||||
|
|
@ -0,0 +1,473 @@
|
|||
// Copyright 2025 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
using System;
|
||||
using System.Diagnostics;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Text;
|
||||
namespace GemmaCpp
|
||||
{
|
||||
    /// <summary>
    /// Exception thrown for errors originating in this managed wrapper around
    /// the native Gemma library (e.g. a null native context).
    /// </summary>
    public class GemmaException : Exception
    {
        public GemmaException(string message) : base(message) { }
    }
|
||||
|
||||
public class Gemma : IDisposable
|
||||
{
|
||||
private IntPtr _context;
|
||||
private bool _disposed;
|
||||
|
||||
// Optional: Allow setting DLL path
|
||||
public static string DllPath { get; set; } = "gemma.dll";
|
||||
|
||||
[DllImport("kernel32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
|
||||
private static extern IntPtr LoadLibrary(string lpFileName);
|
||||
|
||||
        // Eagerly loads the native library from DllPath so that subsequent
        // DllImport("gemma") bindings resolve against the already-loaded module.
        static Gemma()
        {
            // Load DLL from specified path
            if (LoadLibrary(DllPath) == IntPtr.Zero)
            {
                throw new DllNotFoundException($"Failed to load {DllPath}. Error: {Marshal.GetLastWin32Error()}");
            }
        }
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern IntPtr GemmaCreate(
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string tokenizerPath,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string modelType,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string weightsPath,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string weightType,
|
||||
int maxGeneratedTokens);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern void GemmaDestroy(IntPtr context);
|
||||
|
||||
// Delegate type for token callbacks
|
||||
public delegate bool TokenCallback(string token);
|
||||
|
||||
// Keep delegate alive for duration of calls
|
||||
private GCHandle _callbackHandle;
|
||||
|
||||
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
|
||||
private delegate bool GemmaTokenCallback(
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string text,
|
||||
IntPtr userData);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern int GemmaGenerate(
|
||||
IntPtr context,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
|
||||
[Out] byte[] output,
|
||||
int maxOutputChars,
|
||||
GemmaTokenCallback callback,
|
||||
IntPtr userData);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern int GemmaGenerateMultimodal(
|
||||
IntPtr context,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
|
||||
IntPtr image_data, // Renamed param to match C API
|
||||
int image_width, // Added dimension
|
||||
int image_height, // Added dimension
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] StringBuilder output, // Output should be StringBuilder for multimodal
|
||||
int maxOutputChars,
|
||||
GemmaTokenCallback callback,
|
||||
IntPtr userData);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern int GemmaCountTokens(
|
||||
IntPtr context,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string text);
|
||||
|
||||
// Configuration function imports
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern void GemmaSetMaxGeneratedTokens(IntPtr context, int value);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern void GemmaSetMultiturn(IntPtr context, int value);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern void GemmaSetTemperature(IntPtr context, float value);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern void GemmaSetTopK(IntPtr context, int value);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern void GemmaSetDeterministic(IntPtr context, int value);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern void GemmaSetPrefillTbatchSize(IntPtr context, int value);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaResetConversation")]
|
||||
private static extern void GemmaResetConversation(IntPtr context);
|
||||
|
||||
// Conversation management function imports
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaCreateConversation")]
|
||||
private static extern int GemmaCreateConversation(
|
||||
IntPtr context,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaSwitchConversation")]
|
||||
private static extern int GemmaSwitchConversation(
|
||||
IntPtr context,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaDeleteConversation")]
|
||||
private static extern int GemmaDeleteConversation(
|
||||
IntPtr context,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaHasConversation")]
|
||||
private static extern int GemmaHasConversation(
|
||||
IntPtr context,
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaGetCurrentConversation")]
|
||||
[return: MarshalAs(UnmanagedType.LPUTF8Str)] // Marshal the const char* return value as a string
|
||||
private static extern string GemmaGetCurrentConversation(IntPtr context);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaSaveConversation")]
|
||||
private static extern void GemmaSaveConversation(IntPtr context);
|
||||
|
||||
// Native callback delegate type
|
||||
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
|
||||
private delegate void GemmaLogCallback(
|
||||
[MarshalAs(UnmanagedType.LPUTF8Str)] string message,
|
||||
IntPtr userData);
|
||||
|
||||
[DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
|
||||
private static extern void GemmaSetLogCallback(
|
||||
IntPtr context,
|
||||
GemmaLogCallback callback,
|
||||
IntPtr userData);
|
||||
|
||||
private GCHandle _logCallbackHandle;
|
||||
private bool _loggingEnabled = false;
|
||||
|
||||
public Gemma(string tokenizerPath, string weightsPath, int maxGeneratedTokens = 8192)
|
||||
{
|
||||
_context = GemmaCreate(tokenizerPath, weightsPath, maxGeneratedTokens);
|
||||
if (_context == IntPtr.Zero)
|
||||
{
|
||||
throw new GemmaException("Failed to create Gemma context");
|
||||
}
|
||||
}
|
||||
|
||||
// Enable debug logging
|
||||
// Enables or disables forwarding of native log messages to Debug output.
// Idempotent: enabling while enabled (or disabling while disabled) is a no-op.
public void EnableLogging(bool enable = true)
{
    // Consistency fix: guard against use after Dispose, like the other
    // public methods of this class.
    if (_disposed)
        throw new ObjectDisposedException(nameof(Gemma));

    if (enable && !_loggingEnabled)
    {
        GemmaLogCallback logCallback = (message, _) =>
        {
            Debug.WriteLine($"Gemma: {message}");
        };
        // Keep the delegate alive for as long as native code may invoke it.
        _logCallbackHandle = GCHandle.Alloc(logCallback);
        GemmaSetLogCallback(_context, logCallback, IntPtr.Zero);
        _loggingEnabled = true;
    }
    else if (!enable && _loggingEnabled)
    {
        // Bug fix: unregister on the native side FIRST so the callback can no
        // longer fire, then release the GCHandle. The previous order (free,
        // then unregister) left a window in which native code could invoke a
        // delegate that was already eligible for collection.
        GemmaSetLogCallback(_context, null, IntPtr.Zero);
        if (_logCallbackHandle.IsAllocated)
            _logCallbackHandle.Free();
        _loggingEnabled = false;
    }
}
|
||||
|
||||
// Configuration methods
|
||||
// Turns multi-turn conversation mode on or off in the native context.
public void SetMultiturn(bool enable)
{
    if (_disposed)
    {
        throw new ObjectDisposedException(nameof(Gemma));
    }
    if (_context == IntPtr.Zero)
    {
        throw new GemmaException("Gemma context is invalid");
    }

    int flag = enable ? 1 : 0;
    GemmaSetMultiturn(_context, flag);
    Debug.WriteLine($"Gemma: Set multiturn to {(enable ? "enabled" : "disabled")}");
}
|
||||
|
||||
// Sets the sampling temperature used by subsequent generation calls.
public void SetTemperature(float temperature)
{
    if (_disposed)
    {
        throw new ObjectDisposedException(nameof(Gemma));
    }
    if (_context == IntPtr.Zero)
    {
        throw new GemmaException("Gemma context is invalid");
    }

    GemmaSetTemperature(_context, temperature);
    Debug.WriteLine($"Gemma: Set temperature to {temperature}");
}
|
||||
|
||||
// Sets the top-K sampling cutoff used by subsequent generation calls.
public void SetTopK(int topK)
{
    if (_disposed)
    {
        throw new ObjectDisposedException(nameof(Gemma));
    }
    if (_context == IntPtr.Zero)
    {
        throw new GemmaException("Gemma context is invalid");
    }

    GemmaSetTopK(_context, topK);
    Debug.WriteLine($"Gemma: Set topK to {topK}");
}
|
||||
|
||||
// Switches deterministic sampling on or off in the native context.
public void SetDeterministic(bool deterministic)
{
    if (_disposed)
    {
        throw new ObjectDisposedException(nameof(Gemma));
    }
    if (_context == IntPtr.Zero)
    {
        throw new GemmaException("Gemma context is invalid");
    }

    int flag = deterministic ? 1 : 0;
    GemmaSetDeterministic(_context, flag);
    Debug.WriteLine($"Gemma: Set deterministic to {(deterministic ? "true" : "false")}");
}
|
||||
|
||||
// Resets the active conversation in the native context, discarding its
// accumulated dialogue state. (Renamed from an earlier public name.)
public void ResetConversation()
{
    if (_disposed)
        throw new ObjectDisposedException(nameof(Gemma));

    if (_context == IntPtr.Zero)
        throw new GemmaException("Gemma context is invalid");

    GemmaResetConversation(_context); // Call P/Invoke method
    Debug.WriteLine("Gemma: Reset active conversation");
}
|
||||
|
||||
// Conversation management methods
|
||||
// Creates a new named conversation in the native context.
// Returns true when the native call reports success.
public bool CreateConversation(string conversationName)
{
    if (_disposed)
    {
        throw new ObjectDisposedException(nameof(Gemma));
    }
    if (_context == IntPtr.Zero)
    {
        throw new GemmaException("Gemma context is invalid");
    }

    int rc = GemmaCreateConversation(_context, conversationName);
    bool result = rc != 0;
    Debug.WriteLine($"Gemma: Create conversation '{conversationName}' - {(result ? "succeeded" : "failed")}");
    return result;
}
|
||||
|
||||
// Makes the named conversation the active one in the native context.
// Returns true when the native call reports success.
public bool SwitchConversation(string conversationName)
{
    if (_disposed)
    {
        throw new ObjectDisposedException(nameof(Gemma));
    }
    if (_context == IntPtr.Zero)
    {
        throw new GemmaException("Gemma context is invalid");
    }

    int rc = GemmaSwitchConversation(_context, conversationName);
    bool result = rc != 0;
    Debug.WriteLine($"Gemma: Switch to conversation '{conversationName}' - {(result ? "succeeded" : "failed")}");
    return result;
}
|
||||
|
||||
// Deletes the named conversation from the native context.
// Returns true when the native call reports success.
public bool DeleteConversation(string conversationName)
{
    if (_disposed)
        throw new ObjectDisposedException(nameof(Gemma));

    if (_context == IntPtr.Zero)
        throw new GemmaException("Gemma context is invalid");

    bool result = GemmaDeleteConversation(_context, conversationName) != 0; // Call P/Invoke method
    Debug.WriteLine($"Gemma: Delete conversation '{conversationName}' - {(result ? "succeeded" : "failed")}");
    return result;
}
|
||||
|
||||
// Returns true when a conversation with the given name exists in the
// native context.
public bool HasConversation(string conversationName)
{
    if (_disposed)
        throw new ObjectDisposedException(nameof(Gemma));

    if (_context == IntPtr.Zero)
        throw new GemmaException("Gemma context is invalid");

    bool result = GemmaHasConversation(_context, conversationName) != 0; // Call P/Invoke method
    Debug.WriteLine($"Gemma: Has conversation '{conversationName}' - {result}");
    return result;
}
|
||||
|
||||
// Returns the name of the currently active conversation.
// NOTE(review): the extern declaration marshals the native const char* return
// as a string; confirm the runtime does not free the context-owned pointer.
public string GetCurrentConversation()
{
    if (_disposed)
        throw new ObjectDisposedException(nameof(Gemma));

    if (_context == IntPtr.Zero)
        throw new GemmaException("Gemma context is invalid");

    string currentConversation = GemmaGetCurrentConversation(_context); // Call P/Invoke method
    Debug.WriteLine($"Gemma: Current conversation is '{currentConversation}'");
    return currentConversation;
}
|
||||
|
||||
// Saves the current conversation to the native prewarmed cache.
public void SaveConversation()
{
    if (_disposed)
        throw new ObjectDisposedException(nameof(Gemma));

    if (_context == IntPtr.Zero)
        throw new GemmaException("Gemma context is invalid");

    GemmaSaveConversation(_context);
    // Note: this log line performs a second native call via GetCurrentConversation().
    Debug.WriteLine($"Gemma: Saved current conversation ('{GetCurrentConversation()}') to prewarmed cache.");
}
|
||||
|
||||
// Counts the tokenizer tokens in the given prompt.
// Throws GemmaException when the native call fails, consistent with the
// other methods of this class (previously the native error code -1 was
// silently returned to the caller as if it were a count).
public int CountTokens(string prompt)
{
    if (_disposed)
        throw new ObjectDisposedException(nameof(Gemma));

    if (_context == IntPtr.Zero)
        throw new GemmaException("Gemma context is invalid");

    int count = GemmaCountTokens(_context, prompt);
    if (count < 0)
        throw new GemmaException("Token counting failed");
    return count;
}
|
||||
|
||||
// Convenience overload: generate without a streaming token callback.
public string Generate(string prompt, int maxOutputChars = 4096) =>
    Generate(prompt, null, maxOutputChars);
|
||||
|
||||
// Generates a completion for the prompt, optionally streaming each decoded
// token to the supplied callback. Returns the full UTF-8 decoded result.
// Throws ObjectDisposedException / GemmaException on invalid state or on a
// negative length from the native call.
// NOTE(review): the callback GCHandle is stored in the shared _callbackHandle
// field, so concurrent Generate calls on one instance would interfere — the
// class does not appear designed for concurrent use; confirm with callers.
public string Generate(string prompt, TokenCallback callback, int maxOutputChars = 4096)
{
    if (_disposed)
        throw new ObjectDisposedException(nameof(Gemma));

    if (_context == IntPtr.Zero)
        throw new GemmaException("Gemma context is invalid");

    var outputBuffer = new byte[maxOutputChars * 4]; // Allow for worst case UTF-8 size
    GemmaTokenCallback nativeCallback = null;

    // Track token count for debugging
    int tokenCount = 0;

    if (callback != null)
    {
        // Wrap the user callback so we can count/log tokens; the wrapper's
        // return value tells the native side whether to keep generating.
        nativeCallback = (text, _) =>
        {
            tokenCount++;
            // Log token for debugging
            Debug.WriteLine($"Token {tokenCount}: '{text}'");

            // Pass token to user callback
            return callback(text);
        };
        // Keep the delegate alive while native code holds the function pointer.
        _callbackHandle = GCHandle.Alloc(nativeCallback);
    }

    try
    {
        int length = GemmaGenerate(_context, prompt, outputBuffer, maxOutputChars,
            nativeCallback, IntPtr.Zero);

        if (length < 0)
            throw new GemmaException("Generation failed");

        Debug.WriteLine($"Generation complete: {tokenCount} tokens processed, result length: {length}");

        // Convert the byte buffer to a string using UTF-8 encoding
        string result = Encoding.UTF8.GetString(outputBuffer, 0, length);
        return result;
    }
    finally
    {
        // Release the delegate handle once the native call has returned.
        if (_callbackHandle.IsAllocated)
            _callbackHandle.Free();
    }
}
|
||||
|
||||
// Convenience overload: multimodal generation without a streaming callback.
public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, int maxOutputChars = 4096) =>
    GenerateMultimodal(prompt, imageData, imageWidth, imageHeight, null, maxOutputChars);
|
||||
|
||||
// Generates a completion conditioned on a prompt plus a float RGB image
// (imageWidth * imageHeight * 3 values), optionally streaming tokens to the
// callback. Throws ArgumentException on null/undersized image data or
// non-positive dimensions.
// NOTE(review): unlike Generate (byte[] + UTF-8 decode), this passes a
// StringBuilder to the native char* output and ignores the returned length —
// confirm non-ASCII output round-trips correctly through this path.
public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, TokenCallback callback, int maxOutputChars = 4096)
{
    if (_disposed)
        throw new ObjectDisposedException(nameof(Gemma));

    if (_context == IntPtr.Zero)
        throw new GemmaException("Gemma context is invalid");

    if (imageData == null || imageData.Length == 0)
        throw new ArgumentException("Image data cannot be null or empty", nameof(imageData));

    if (imageWidth <= 0 || imageHeight <= 0)
        throw new ArgumentException("Image dimensions must be positive");

    // Expect packed RGB: three floats per pixel.
    if (imageData.Length < imageWidth * imageHeight * 3)
        throw new ArgumentException("Image data array is too small for the specified dimensions");

    var output = new StringBuilder(maxOutputChars);
    GemmaTokenCallback nativeCallback = null;

    if (callback != null)
    {
        nativeCallback = (text, _) => callback(text);
        // Keep the delegate alive while native code holds the function pointer.
        _callbackHandle = GCHandle.Alloc(nativeCallback);
    }

    // Pin the image data so it doesn't move during the native call
    GCHandle imageHandle = GCHandle.Alloc(imageData, GCHandleType.Pinned);

    try
    {
        IntPtr imagePtr = imageHandle.AddrOfPinnedObject();

        // Pass image dimensions to the native call
        int length = GemmaGenerateMultimodal(_context, prompt, imagePtr, imageWidth, imageHeight, output, maxOutputChars,
            nativeCallback, IntPtr.Zero);

        if (length < 0)
            throw new GemmaException("Multimodal generation failed");

        return output.ToString();
    }
    finally
    {
        // Always unpin the image and release the delegate handle.
        imageHandle.Free();

        if (_callbackHandle.IsAllocated)
            _callbackHandle.Free();
    }
}
|
||||
|
||||
// Releases the native context and any callback GCHandles. Safe to call
// multiple times; subsequent public-method calls throw ObjectDisposedException.
public void Dispose()
{
    if (!_disposed)
    {
        if (_context != IntPtr.Zero)
        {
            GemmaDestroy(_context);
            _context = IntPtr.Zero;
        }
        if (_logCallbackHandle.IsAllocated)
            _logCallbackHandle.Free();
        // Defensive fix: also release a token-callback handle that a failed
        // generation call may have left allocated (previously leaked).
        if (_callbackHandle.IsAllocated)
            _callbackHandle.Free();
        _disposed = true;
        // Standard dispose pattern fix: the finalizer has nothing left to do,
        // so take the object off the finalization queue.
        GC.SuppressFinalize(this);
    }
}
|
||||
|
||||
// Finalizer: last-chance release of the native context if the caller
// forgot to call Dispose().
~Gemma()
{
    Dispose();
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,139 @@
|
|||
// Copyright 2025 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef GEMMA_EXPORTS
|
||||
#define GEMMA_EXPORTS
|
||||
#endif
|
||||
|
||||
#include "gemma/bindings/c_api.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
// C ABI entry point: builds a GemmaContext from tokenizer and weights paths.
// Returns nullptr on any failure so no exception escapes across the C boundary.
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
                                    const char* weights_path,
                                    int max_generated_tokens) {
  try {
    GemmaContext* ctx = GemmaContext::Create(tokenizer_path, weights_path,
                                             max_generated_tokens);
    return ctx;
  } catch (...) {
    // Swallow everything: C callers cannot handle C++ exceptions.
    return nullptr;
  }
}
|
||||
|
||||
// Frees a context created by GemmaCreate. Deleting nullptr is a no-op.
GEMMA_API void GemmaDestroy(GemmaContext* ctx) {
  delete ctx;
}
|
||||
|
||||
// C ABI entry point for text-only generation. Writes a NUL-terminated result
// into `output` and returns its length, or -1 on invalid arguments/failure.
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
                            int max_output_chars, GemmaTokenCallback callback,
                            void* user_data) {
  // Reject invalid arguments up front (consistent with GemmaCountTokens).
  // In particular, max_output_chars <= 0 would otherwise reach
  // output[max_output_chars - 1] in the implementation — undefined behavior.
  if (!ctx || !prompt || !output || max_output_chars <= 0) return -1;
  return ctx->Generate(prompt, output, max_output_chars, callback, user_data);
}
|
||||
|
||||
// C ABI entry point for image+text generation. `image_data` points to float
// pixel data of the given dimensions. Writes a NUL-terminated result into
// `output` and returns its length, or -1 on invalid arguments/failure.
GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt,
                                      const void* image_data, int image_width,
                                      int image_height, char* output,
                                      int max_output_chars,
                                      GemmaTokenCallback callback,
                                      void* user_data) {
  // Reject invalid arguments up front (consistent with GemmaCountTokens);
  // max_output_chars <= 0 would otherwise index output[max_output_chars - 1].
  if (!ctx || !prompt || !output || max_output_chars <= 0) return -1;

  return ctx->GenerateMultimodal(prompt, image_data, image_width, image_height,
                                 output, max_output_chars, callback, user_data);
}
|
||||
|
||||
// Counts tokenizer tokens in `text`; returns -1 on invalid arguments.
GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text) {
  const bool args_ok = (ctx != nullptr) && (text != nullptr);
  if (!args_ok) {
    return -1;
  }
  return ctx->CountTokens(text);
}
|
||||
|
||||
// Registers (or clears, when callback is null) the context's log callback.
// No-op when ctx is null.
GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback,
                                   void* user_data) {
  if (!ctx) return;
  ctx->SetLogCallback(callback, user_data);
}
|
||||
|
||||
// Configuration functions implementation.
// Each setter forwards one runtime parameter to the context; all are no-ops
// when ctx is null.

// Caps the number of tokens generated per reply.
GEMMA_API void GemmaSetMaxGeneratedTokens(GemmaContext* ctx, int value) {
  if (!ctx) return;
  ctx->SetMaxGeneratedTokens(value);
}

// Non-zero enables multi-turn conversation state.
GEMMA_API void GemmaSetMultiturn(GemmaContext* ctx, int value) {
  if (!ctx) return;
  ctx->SetMultiturn(value);
}

// Sets the sampling temperature.
GEMMA_API void GemmaSetTemperature(GemmaContext* ctx, float value) {
  if (!ctx) return;
  ctx->SetTemperature(value);
}

// Sets the top-K sampling cutoff.
GEMMA_API void GemmaSetTopK(GemmaContext* ctx, int value) {
  if (!ctx) return;
  ctx->SetTopK(value);
}

// Non-zero enables deterministic sampling.
GEMMA_API void GemmaSetDeterministic(GemmaContext* ctx, int value) {
  if (!ctx) return;
  ctx->SetDeterministic(value != 0);
}

// Sets the prefill token-batch size.
GEMMA_API void GemmaSetPrefillTbatchSize(GemmaContext* ctx, int value) {
  if (!ctx) return;
  ctx->SetPrefillTbatchSize(value);
}
|
||||
|
||||
// Resets the active conversation's state. No-op when ctx is null.
GEMMA_API void GemmaResetConversation(GemmaContext* ctx) { // Renamed function
  if (!ctx) return;
  ctx->ResetConversation();
}

// Creates a named conversation. Returns 1 on success, 0 on failure or
// invalid arguments.
GEMMA_API int GemmaCreateConversation(GemmaContext* ctx,
                                      const char* conversation_name) {
  if (!ctx || !conversation_name) return 0;
  return ctx->CreateConversation(conversation_name) ? 1 : 0;
}

// Makes the named conversation active. Returns 1 on success, 0 otherwise.
GEMMA_API int GemmaSwitchConversation(GemmaContext* ctx,
                                      const char* conversation_name) {
  if (!ctx || !conversation_name) return 0;
  return ctx->SwitchConversation(conversation_name) ? 1 : 0;
}

// Deletes the named conversation. Returns 1 on success, 0 otherwise.
GEMMA_API int GemmaDeleteConversation(GemmaContext* ctx,
                                      const char* conversation_name) {
  if (!ctx || !conversation_name) return 0;
  return ctx->DeleteConversation(conversation_name) ? 1 : 0;
}

// Returns 1 when the named conversation exists, 0 otherwise.
GEMMA_API int GemmaHasConversation(GemmaContext* ctx,
                                   const char* conversation_name) {
  if (!ctx || !conversation_name) return 0;
  return ctx->HasConversation(conversation_name) ? 1 : 0;
}

// Returns the active conversation's name. The pointer is owned by the
// context (see GemmaContext::GetCurrentConversation); callers must not free
// it. Returns nullptr when ctx is null.
GEMMA_API const char* GemmaGetCurrentConversation(GemmaContext* ctx) {
  if (!ctx) return nullptr;
  return ctx->GetCurrentConversation();
}

// Saves the current conversation state. No-op when ctx is null.
GEMMA_API void GemmaSaveConversation(GemmaContext* ctx) {
  if (!ctx) return;
  ctx->SaveConversation();
}
|
||||
}
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
// Copyright 2025 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_C_API_H_
#define THIRD_PARTY_GEMMA_C_API_H_

#include "gemma/bindings/context.h"

// GEMMA_API marks symbols exported from the shared library.
#ifdef _WIN32
#ifdef GEMMA_EXPORTS
#define GEMMA_API __declspec(dllexport)
#else
#define GEMMA_API __declspec(dllimport)
#endif
#else
#define GEMMA_API __attribute__((visibility("default")))
#endif

#ifdef __cplusplus
extern "C" {
#endif

// C callers see GemmaContext as an opaque struct; C++ uses the real type.
#ifdef __cplusplus
typedef gcpp::GemmaContext GemmaContext;
#else
typedef struct GemmaContext GemmaContext;
#endif

// Invoked once per generated token; returning false stops generation.
typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
// Invoked with each debug/log message from the library.
typedef void (*GemmaLogCallback)(const char* message, void* user_data);

// Creates a context; returns NULL on failure. Destroy with GemmaDestroy.
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
                                    const char* weights_path,
                                    int max_generated_tokens);
GEMMA_API void GemmaDestroy(GemmaContext* ctx);
// Both generation calls write a NUL-terminated result into `output` and
// return its length, or -1 on failure.
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
                            int max_output_chars, GemmaTokenCallback callback,
                            void* user_data);
GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt,
                                      const void* image_data, int image_width,
                                      int image_height, char* output,
                                      int max_output_chars,
                                      GemmaTokenCallback callback,
                                      void* user_data);

GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text);

GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback,
                                   void* user_data);

// Configuration functions
// Fix: declare GemmaSetMaxGeneratedTokens and GemmaSetPrefillTbatchSize,
// which are defined and exported by c_api.cc but were missing from this
// header, making them unreachable for header-based C/C++ callers.
GEMMA_API void GemmaSetMaxGeneratedTokens(GemmaContext* ctx, int value);
GEMMA_API void GemmaSetMultiturn(GemmaContext* ctx, int value);
GEMMA_API void GemmaSetTemperature(GemmaContext* ctx, float value);
GEMMA_API void GemmaSetTopK(GemmaContext* ctx, int value);
GEMMA_API void GemmaSetDeterministic(GemmaContext* ctx, int value);
GEMMA_API void GemmaSetPrefillTbatchSize(GemmaContext* ctx, int value);
GEMMA_API void GemmaResetConversation(GemmaContext* ctx);

// Conversation management functions (renamed)
GEMMA_API int GemmaCreateConversation(GemmaContext* ctx,
                                      const char* conversation_name);
GEMMA_API int GemmaSwitchConversation(GemmaContext* ctx,
                                      const char* conversation_name);
GEMMA_API int GemmaDeleteConversation(GemmaContext* ctx,
                                      const char* conversation_name);
GEMMA_API int GemmaHasConversation(GemmaContext* ctx,
                                   const char* conversation_name);
// Returned pointer is owned by the context; do not free.
GEMMA_API const char* GemmaGetCurrentConversation(GemmaContext* ctx);
GEMMA_API void GemmaSaveConversation(GemmaContext* ctx);

#ifdef __cplusplus
}
#endif

#endif  // THIRD_PARTY_GEMMA_C_API_H_
|
||||
|
|
@ -0,0 +1,350 @@
|
|||
// Copyright 2025 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "gemma/bindings/context.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <string.h> // strncpy
|
||||
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "evals/benchmark_helper.h" // InitGenerator
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "gemma/tokenizer.h" // WrapAndTokenize
|
||||
#include "util/threading.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <Windows.h>
|
||||
#endif
|
||||
|
||||
#include "gemma/kv_cache.h"
|
||||
#include "paligemma/image.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// ConversationData constructor implementation.
// Allocates a fresh KVCache for this conversation and starts at position 0.
ConversationData::ConversationData(const ModelConfig& model_config,
                                   const InferenceArgs& inference_args,
                                   const Allocator& allocator)
    : kv_cache(
          std::make_unique<KVCache>(model_config, inference_args, allocator)),
      abs_pos(0) {}

// ConversationData copy constructor implementation.
// Deep-copies the KV cache via KVCache::Copy() so the new conversation's
// state evolves independently of the source.
ConversationData::ConversationData(const ConversationData& other)
    : kv_cache(nullptr), abs_pos(other.abs_pos) {
  if (other.kv_cache) {
    kv_cache = std::make_unique<KVCache>(other.kv_cache->Copy());
  }
}

// Initialize static members: process-wide log callback and its user pointer,
// unset until GemmaSetLogCallback is called.
GemmaLogCallback GemmaContext::s_log_callback = nullptr;
void* GemmaContext::s_log_user_data = nullptr;
|
||||
|
||||
// Factory: assembles loader/inference/threading arguments with defaults
// (temperature 0.7, top_k 1, non-deterministic, no spinning) and returns a
// heap-allocated context. Exceptions propagate to the caller (the C wrapper
// GemmaCreate converts them to nullptr).
GemmaContext* GemmaContext::Create(const char* tokenizer_path,
                                   const char* weights_path,
                                   int max_generated_tokens) {
  std::stringstream ss;
  ss << "Creating GemmaContext with tokenizer_path: "
     << (tokenizer_path ? tokenizer_path : "null")
     << ", weights_path: " << (weights_path ? weights_path : "null")
     << ", max_generated_tokens: " << max_generated_tokens;
  LogDebug(ss.str().c_str());

  ThreadingArgs threading_args;
  threading_args.spin = gcpp::Tristate::kFalse;

  LoaderArgs loader(tokenizer_path, weights_path);
  LogDebug("LoaderArgs created");

  // Initialize cached args
  LogDebug("Initializing inference args");
  InferenceArgs inference_args;
  inference_args.Init();
  inference_args.max_generated_tokens = max_generated_tokens;
  inference_args.temperature = 0.7f;
  inference_args.top_k = 1;
  inference_args.deterministic = false;

  ss.str("");
  ss << "Inference args initialized with max_tokens: " << max_generated_tokens
     << ", temperature: " << inference_args.temperature
     << ", top_k: " << inference_args.top_k << ", deterministic: "
     << (inference_args.deterministic ? "true" : "false");
  LogDebug(ss.str().c_str());

  return new GemmaContext(loader, inference_args, threading_args,
                          max_generated_tokens);
}
|
||||
|
||||
// Constructs the context: copies the argument structs, builds the threading
// context and matmul environment, loads the model, and seeds the conversation
// cache with a "default" conversation that becomes active.
// NOTE(review): members are initialized in their declaration order in the
// class, not in the order written here — confirm the declaration order keeps
// `ctx` initialized before `matmul_env` and `model`.
GemmaContext::GemmaContext(const LoaderArgs& loader,
                           const InferenceArgs& inference_args,
                           const ThreadingArgs& threading_args,
                           int max_generated_tokens)
    : inference_args(inference_args),
      threading_args(threading_args),
      ctx(UpdateArgs(threading_args, inference_args)),
      matmul_env(ctx),
      active_conversation_name("default"),
      model(loader, inference_args, matmul_env.ctx) {
  // (Cleanup: removed an unused local std::stringstream.)
  LogDebug("Creating initial ConversationData");
  // Create the initial ConversationData object using make_shared
  active_conversation = std::make_shared<ConversationData>(
      model.Config(), inference_args, ctx.allocator);

  LogDebug(
      "Storing initial ConversationData in conversation_cache[\"default\"]");
  // Store the shared_ptr in the map under the "default" key
  conversation_cache["default"] = active_conversation;

  LogDebug("GemmaContext constructor completed");
}
|
||||
|
||||
// Internal implementation shared by Generate and GenerateMultimodal.
// Runs one generation turn against the active conversation's KV cache:
// tokenizes (and, when image_data is non-null, embeds image tokens for
// PaliGemma / Gemma-VLM), streams decoded tokens through the optional
// callback, accumulates the reply in result_buffer, and copies it into
// `output` as a NUL-terminated C string. Returns the output length or -1.
// NOTE(review): assumes max_output_chars >= 1 — max_output_chars <= 0 would
// make `output[max_output_chars - 1]` out of bounds; callers must guard.
int GemmaContext::GenerateInternal(const char* prompt_string,
                                   const void* image_data, int image_width,
                                   int image_height, char* output,
                                   int max_output_chars,
                                   GemmaTokenCallback callback,
                                   void* user_data) {
  PROFILER_ZONE("Gen.Internal");
  size_t tokens_generated_this_turn = 0;  // differentiates prefill from reply
  size_t prompt_size = 0;
  std::stringstream ss;
  result_buffer.clear();

  // Re-seed the RNG for this turn.
  InitGenerator(inference_args, gen);

  // Ensure we have an active conversation
  if (!active_conversation || !active_conversation->kv_cache) {
    LogDebug("Generate called with null active_conversation or kv_cache");
    return -1;
  }

  // callback function invoked for each generated token.
  // Returns false to abort generation (when the user callback asks to stop).
  auto stream_token = [&, callback, user_data](int token, float) {
    // Use abs_pos from the active conversation
    ++(active_conversation->abs_pos);
    const bool in_prompt = tokens_generated_this_turn < prompt_size;
    const bool first_response_token = tokens_generated_this_turn == prompt_size;
    ++tokens_generated_this_turn;
    // Skip prompt (prefill) tokens and end-of-sequence tokens.
    if (in_prompt || model.Config().IsEOS(token)) {
      return true;
    }

    std::string token_text;
    HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
    // Strip leading whitespace from the very first reply token.
    if (first_response_token) {
      token_text.erase(0, token_text.find_first_not_of(" \t\n"));
    }

    // if we have a managed callback, pass it the token text
    if (callback) {
      if (!callback(token_text.c_str(), user_data)) {
        LogDebug("Callback returned false, stopping generation");
        return false;
      }
    }

    result_buffer.append(token_text);
    return true;
  };

  // set up runtime config
  TimingInfo timing_info = {};
  RuntimeConfig runtime_config = {.gen = &gen,
                                  .stream_token = stream_token,
                                  .use_spinning = threading_args.spin};
  inference_args.CopyTo(runtime_config);
  size_t prefix_end = 0;

  const ModelConfig& model_config = model.Config();

  // generate
  std::vector<int> prompt;
  const size_t pool_dim = model_config.vit_config.pool_dim;
  // Allocate image-token storage only when an image was supplied.
  ImageTokens image_tokens(
      "image_tokens",
      image_data
          ? Extents2D(model_config.vit_config.seq_len / (pool_dim * pool_dim),
                      model_config.model_dim)
          : Extents2D(0, 0),
      ctx.allocator, MatPadding::kOdd);
  if (image_data != nullptr) {
    // Only vision-capable wrappings accept image input.
    HWY_ASSERT(model_config.wrapping == PromptWrapping::PALIGEMMA ||
               model_config.wrapping == PromptWrapping::GEMMA_VLM);

    Image image;
    image.Set(image_width, image_height, static_cast<const float*>(image_data));

    // We may need to resize the supplied image depending on whether we're using
    // PaliGemma or Gemma 3.
    const size_t image_size = model_config.vit_config.image_size;
    image.Resize(image_size, image_size);

    // Use the existing runtime_config defined earlier in the function.
    // RuntimeConfig runtime_config = { ... }; // This was already defined
    double image_tokens_start = hwy::platform::Now();
    // Pass the populated image object to GenerateImageTokens
    model.GenerateImageTokens(runtime_config,
                              active_conversation->kv_cache->SeqLen(), image,
                              image_tokens, matmul_env);
    double image_tokens_duration = hwy::platform::Now() - image_tokens_start;

    ss.str("");
    ss << "\n\n[ Timing info ] Image token generation took: ";
    // NOTE(review): the trailing comma below is the comma operator chaining
    // into LogDebug — legal but easy to misread; consider a ';' in a future
    // code change.
    ss << static_cast<int>(image_tokens_duration * 1000) << " ms\n",
        LogDebug(ss.str().c_str());

    prompt = WrapAndTokenize(
        model.Tokenizer(), model.ChatTemplate(), model_config.wrapping,
        active_conversation->abs_pos, prompt_string, image_tokens.Rows());
    runtime_config.image_tokens = &image_tokens;
    prompt_size = prompt.size();
    // The end of the prefix for prefix-LM style attention in Paligemma.
    // See Figure 2 of https://arxiv.org/abs/2407.07726.
    prefix_end = prompt_size;
  } else {
    // Text-only case (original logic)
    // Use abs_pos from the active conversation
    prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
                             model_config.wrapping,
                             active_conversation->abs_pos, prompt_string);
    prompt_size = prompt.size();
  }

  // Check if prompt generation failed (e.g., multimodal not implemented yet)
  if (prompt.empty() && image_data != nullptr) {
    // Already logged the error, just ensure we don't proceed.
    return -1;
  }

  // Create a span from the prompt vector - Generate() expects a hwy::Span,
  // which has a different memory footprint to that of a std::vector.
  hwy::Span<const int> prompt_span(prompt.data(), prompt.size());

  // Pass the KVCache object by reference from the active conversation
  model.Generate(runtime_config, prompt_span, active_conversation->abs_pos,
                 prefix_end, *active_conversation->kv_cache, matmul_env,
                 timing_info);

  // prepare for next turn
  if (!inference_args.multiturn ||
      model_config.wrapping == PromptWrapping::PALIGEMMA) {
    // If not multiturn, or Paligemma (which handles turns differently),
    // reset the *active* conversation's position.
    active_conversation->abs_pos = 0;
    InitGenerator(inference_args, gen);
  } else {
    // Multi-turn Gemma: Rewind position in the active conversation
    // The last token was either EOS, then it should be ignored because it is
    // never part of the dialog, see Table 5 in the Gemma-2 paper:
    // https://arxiv.org/pdf/2408.00118
    // Or we have hit max_generated_tokens, then the last token will be lost.
    // (We could store it in stream_token, and then prepend to the next turn,
    // but it's not worth the complexity, as multi-turn with max_generated is
    // not a common use case.)
    // In either case, we need to rewind the active conversation's abs_pos by
    // one.
    HWY_ASSERT(active_conversation->abs_pos > 0);
    active_conversation->abs_pos--;
  }

  // Copy result buffer to output C-string (ensure null termination)
  strncpy(output, result_buffer.c_str(), max_output_chars - 1);
  output[max_output_chars - 1] = '\0';

  return static_cast<int>(strlen(output));
}
|
||||
|
||||
// Public Generate method (wrapper for text-only).
// Delegates to GenerateInternal with no image; returns the output length
// written into `output`, or -1 on failure.
int GemmaContext::Generate(const char* prompt_string, char* output,
                           int max_output_chars, GemmaTokenCallback callback,
                           void* user_data) {
  // Call the internal implementation with null image_data and 0 dimensions
  return GenerateInternal(prompt_string, nullptr, 0, 0, output,
                          max_output_chars, callback, user_data);
}
|
||||
|
||||
// Public GenerateMultimodal method (wrapper).
// Requires non-null image data; text-only requests must use Generate.
// Returns the output length written into `output`, or -1 on failure.
int GemmaContext::GenerateMultimodal(const char* prompt_string,
                                     const void* image_data, int image_width,
                                     int image_height, char* output,
                                     int max_output_chars,
                                     GemmaTokenCallback callback,
                                     void* user_data) {
  if (image_data == nullptr) {
    LogDebug(
        "GenerateMultimodal called with null image_data. Use Generate for "
        "text-only.");
    // Or potentially call GenerateInternal with null image_data anyway?
    // Returning error seems safer.
    return -1;
  }

  return GenerateInternal(prompt_string, image_data, image_width, image_height,
                          output, max_output_chars, callback, user_data);
}
|
||||
|
||||
int GemmaContext::CountTokens(const char* text) {
|
||||
LogDebug("CountTokens method started");
|
||||
std::stringstream ss;
|
||||
ss << "CountTokens called with text: '" << (text ? text : "null") << "'";
|
||||
LogDebug(ss.str().c_str());
|
||||
|
||||
if (!text) {
|
||||
LogDebug("CountTokens failed: Invalid parameters");
|
||||
if (!text) LogDebug(" text is null");
|
||||
return -1;
|
||||
}
|
||||
|
||||
try {
|
||||
LogDebug("Creating text string");
|
||||
std::string text_str(text);
|
||||
|
||||
LogDebug("Creating tokens vector");
|
||||
std::vector<int> tokens;
|
||||
|
||||
LogDebug("Encoding text to tokens");
|
||||
HWY_ASSERT(model.Tokenizer().Encode(text_str, &tokens));
|
||||
|
||||
ss.str("");
|
||||
ss << "Text tokenized into " << tokens.size() << " tokens";
|
||||
LogDebug(ss.str().c_str());
|
||||
|
||||
LogDebug("CountTokens completed successfully");
|
||||
return static_cast<int>(tokens.size());
|
||||
} catch (...) {
|
||||
LogDebug("Unknown exception in CountTokens");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the name of the currently active conversation.
// The returned pointer refers to storage owned by this GemmaContext; it is
// invalidated when the active conversation name changes (SwitchConversation)
// or when the context is destroyed.
const char* GemmaContext::GetCurrentConversation() {
  return active_conversation_name.c_str();
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -0,0 +1,316 @@
|
|||
// Copyright 2025 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
|
||||
|
||||
#include <memory> // For std::shared_ptr, std::make_shared
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
// Logging
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
#else
|
||||
#include <stdio.h>
|
||||
#endif
|
||||
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "gemma/kv_cache.h"
|
||||
#include "ops/matmul.h" // MatMulEnv
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/highway.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Struct to hold data for a single conversation thread.
struct ConversationData {
  // Constructs a fresh conversation with a newly allocated KVCache
  // (defined in the .cc file).
  ConversationData(const ModelConfig& model_config,
                   const InferenceArgs& inference_args,
                   const Allocator& allocator);
  // Copy constructor (defined in the .cc file). Used by SaveConversation to
  // snapshot a conversation — presumably performs a deep copy of the cache;
  // verify against the implementation.
  ConversationData(const ConversationData& other);

  // Attention key/value cache holding this conversation's processed tokens.
  std::unique_ptr<KVCache> kv_cache;
  // Absolute token position reached so far in this conversation.
  size_t abs_pos = 0;
};
|
||||
|
||||
typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
|
||||
typedef void (*GemmaLogCallback)(const char* message, void* user_data);
|
||||
|
||||
// C-API-friendly wrapper around a loaded Gemma model. Owns the model, the
// inference/threading configuration, and a set of named conversations, each
// with its own KVCache. Construction goes through the static Create factory.
class GemmaContext {
 private:
  // Private: instances are created via Create().
  GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args,
               const ThreadingArgs& threading_args, int max_generated_tokens);

 public:
  // Factory. Returns a heap-allocated context (caller owns it); defined in
  // the .cc file.
  static GemmaContext* Create(const char* tokenizer_path,
                              const char* weights_path,
                              int max_generated_tokens);

  // Returns length of generated text, or -1 on error.
  int Generate(const char* prompt_string, char* output, int max_output_chars,
               GemmaTokenCallback callback, void* user_data);
  // Returns length of generated text, or -1 on error. Requires non-null
  // image_data; text-only callers should use Generate.
  int GenerateMultimodal(const char* prompt_string, const void* image_data,
                         int image_width, int image_height, char* output,
                         int max_output_chars, GemmaTokenCallback callback,
                         void* user_data);

  // Returns number of tokens in text, or -1 on error.
  int CountTokens(const char* text);

  // Installs a process-wide log sink; pass nullptr to revert to the default
  // (OutputDebugStringA on Windows, stdout elsewhere). Not thread-safe with
  // concurrent LogDebug calls.
  static void SetLogCallback(GemmaLogCallback callback, void* user_data) {
    s_log_callback = callback;
    s_log_user_data = user_data;
  }

  // Set max generated tokens.
  void SetMaxGeneratedTokens(size_t value) {
    inference_args.max_generated_tokens = value;
    LogDebug("Setting max_generated_tokens to configured value");
  }

  // Set multiturn flag (0 = disabled, 1 = enabled).
  void SetMultiturn(int value) {
    inference_args.multiturn = value;
    LogDebug("Setting multiturn to configured value");
  }

  // Set temperature for token generation.
  void SetTemperature(float value) {
    inference_args.temperature = value;
    LogDebug("Setting temperature to configured value");
  }

  // Set top_k parameter for sampling.
  void SetTopK(int value) {
    inference_args.top_k = value;
    LogDebug("Setting top_k to configured value");
  }

  // Set deterministic flag. Enabling also reseeds the RNG with a fixed seed
  // so subsequent generations are reproducible.
  void SetDeterministic(bool value) {
    inference_args.deterministic = value;
    // Reset the random number generator for deterministic generation.
    if (value) {
      gen.seed(0x87654321);
    }
    LogDebug("Setting deterministic flag to configured value");
  }

  // Set prefill_tbatch_size.
  void SetPrefillTbatchSize(size_t value) {
    inference_args.prefill_tbatch_size = value;
    LogDebug("Setting prefill_tbatch_size to configured value");
  }

  // Snapshots the active conversation (deep copy via ConversationData's copy
  // ctor) into prewarmed_cache under its current name, overwriting any
  // previous snapshot. No-op if there is no active conversation or its name
  // is empty.
  void SaveConversation() {
    if (!active_conversation || active_conversation_name.empty()) {
      if (!active_conversation) {
        LogDebug("SaveConversation: No active conversation to save.");
      } else {  // active_conversation_name must be empty
        LogDebug(
            "SaveConversation: Active conversation name is empty. Cannot "
            "save.");
      }
      return;
    }
    std::string log_msg = "SaveConversation: Attempting to save '";
    log_msg += active_conversation_name;
    log_msg += "' to prewarmed_cache.";
    LogDebug(log_msg.c_str());

    // Create a deep copy of the active_conversation via copy ctor.
    auto conversation_copy =
        std::make_shared<ConversationData>(*active_conversation);

    // Store the deep copy in prewarmed_cache.
    // If a conversation with the same name already exists, it will be
    // overwritten. std::shared_ptr will handle the destruction of the old
    // object if it's being replaced.
    prewarmed_cache[active_conversation_name] = conversation_copy;

    log_msg = "SaveConversation: Successfully saved '";
    log_msg += active_conversation_name;
    log_msg += "' to prewarmed_cache.";
    LogDebug(log_msg.c_str());
  }

  // Resets the currently active conversation: restores the snapshot from
  // prewarmed_cache if one exists for the current name, otherwise rewinds to
  // position 0 with a freshly allocated KVCache.
  void ResetConversation() {
    if (active_conversation) {
      std::string log_prefix = "ResetConversation ('";
      log_prefix += active_conversation_name.empty() ? "[unnamed]"
                                                     : active_conversation_name;
      log_prefix += "'): ";
      LogDebug((log_prefix + "Attempting to reset.").c_str());
      // Attempt to restore from prewarmed_cache first, regardless of name.
      auto it = prewarmed_cache.find(active_conversation_name);
      if (it != prewarmed_cache.end() && it->second && it->second->kv_cache) {
        // Found in prewarmed_cache and the cached entry is valid.
        LogDebug((log_prefix + "Found in prewarmed_cache. Restoring state.")
                     .c_str());
        active_conversation->abs_pos = it->second->abs_pos;
        // Perform a deep copy of the KVCache from the prewarmed version.
        active_conversation->kv_cache =
            std::make_unique<KVCache>(it->second->kv_cache->Copy());
        LogDebug((log_prefix + "Successfully restored from prewarmed_cache.")
                     .c_str());
        return;
      }

      // If not found in prewarmed_cache or prewarmed_cache entry is invalid,
      // rewind to initial state.
      active_conversation->abs_pos = 0;
      // Replace the cache within the current ConversationData object.
      active_conversation->kv_cache = std::make_unique<KVCache>(
          model.Config(), inference_args, ctx.allocator);

      LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
    } else {
      LogDebug("Cannot reset conversation: active_conversation is null");
    }
  }

  // Creates a new named conversation with a fresh KVCache. Returns false if
  // the name already exists. Does not switch to it.
  bool CreateConversation(const char* conversation_name) {
    std::string name(conversation_name);
    if (conversation_cache.count(name)) {
      LogDebug("Conversation already exists");
      return false;
    }
    LogDebug("Creating new conversation");
    // Create a new ConversationData object using make_shared.
    conversation_cache[name] = std::make_shared<ConversationData>(
        model.Config(), inference_args, ctx.allocator);
    return true;
  }

  // Makes the named conversation active. Returns false if it does not exist.
  bool SwitchConversation(const char* conversation_name) {
    std::string name(conversation_name);
    auto it = conversation_cache.find(name);
    if (it == conversation_cache.end()) {
      LogDebug("Conversation not found");
      return false;
    }
    LogDebug("Switching active conversation");
    active_conversation = it->second;
    active_conversation_name = conversation_name;
    return true;
  }

  // Deletes a named conversation and any prewarmed snapshot of it. Refuses
  // to delete "default" or the currently active conversation.
  bool DeleteConversation(const char* conversation_name) {
    std::string name(conversation_name);
    auto it = conversation_cache.find(name);

    if (it == conversation_cache.end()) {
      LogDebug("Conversation not found for deletion");
      return false;
    }
    if (name == "default") {
      LogDebug("Cannot delete the default conversation");
      return false;
    }
    if (it->second == active_conversation) {
      LogDebug("Cannot delete the currently active conversation");
      return false;
    }

    LogDebug("Deleting conversation");
    conversation_cache.erase(it);

    // Also drop any saved snapshot so it cannot be restored later.
    auto it2 = prewarmed_cache.find(name);
    if (it2 != prewarmed_cache.end()) {
      prewarmed_cache.erase(it2);
    }

    return true;
  }

  // Check if a named conversation exists.
  bool HasConversation(const char* conversation_name) {
    std::string name(conversation_name);
    return conversation_cache.count(name);
  }

  // Get the name of the currently active conversation.
  const char* GetCurrentConversation();

 private:
  // Internal implementation shared by Generate and GenerateMultimodal.
  int GenerateInternal(const char* prompt_string,
                       const void* image_data,  // Null for text-only generation
                       int image_width,
                       int image_height,
                       char* output, int max_output_chars,
                       GemmaTokenCallback callback, void* user_data);

  // Pointer to the currently active conversation's data.
  std::shared_ptr<ConversationData> active_conversation;

  // Cache of all named conversations.
  std::unordered_map<std::string, std::shared_ptr<ConversationData>>
      conversation_cache;
  // Saved snapshots keyed by conversation name; consulted by
  // ResetConversation and populated by SaveConversation.
  std::unordered_map<std::string, std::shared_ptr<ConversationData>>
      prewarmed_cache;

  // Buffers (potentially could be moved into ConversationData if needed
  // per-conversation).
  std::string prompt_buffer;
  std::string result_buffer;
  std::vector<int> token_buffer;

  // Cached args (remain global for the context).
  InferenceArgs inference_args;
  ThreadingArgs threading_args;
  ThreadingContext ctx;
  MatMulEnv matmul_env;

  // Name of the active conversation; returned by GetCurrentConversation().
  std::string active_conversation_name;

  // Model itself (don't move this, needs to be below the args above —
  // member initialization order depends on declaration order).
  Gemma model;

  // Random generator (remains global for the context).
  std::mt19937 gen;

  // Static members for logging (shared by all contexts in the process).
  static GemmaLogCallback s_log_callback;
  static void* s_log_user_data;

  // Routes a message to the installed log callback if one is set, otherwise
  // to the platform default sink.
  static void LogDebug(const char* message) {
    if (s_log_callback != nullptr) {
      s_log_callback(message, s_log_user_data);
    } else {
#ifdef _WIN32
      OutputDebugStringA(message);
#else
      printf("%s", message);
#endif
    }
  }
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
|
||||
177
gemma/common.cc
177
gemma/common.cc
|
|
@ -1,177 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "gemma/common.h"
|
||||
|
||||
#include <math.h> // sqrtf
|
||||
#include <stddef.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <algorithm> // std::transform
|
||||
#include <cctype>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/shared.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Parallel tables: for each index i, kModelFlags[i] is the command-line flag
// string, kModelTypes[i] the corresponding Model enum value, and
// kPromptWrapping[i] the prompt wrapping scheme. The static_asserts below
// enforce that the three arrays stay the same length; keep entries in the
// same order when adding a model.
constexpr const char* kModelFlags[] = {
    "2b-pt", "2b-it",                // Gemma 2B
    "7b-pt", "7b-it",                // Gemma 7B
    "gr2b-pt", "gr2b-it",            // RecurrentGemma
    "tiny",                          // Gemma Tiny (mostly for debugging)
    "gemma2-2b-pt", "gemma2-2b-it",  // Gemma2 2B
    "9b-pt", "9b-it",                // Gemma2 9B
    "27b-pt", "27b-it",              // Gemma2 27B
    "paligemma-224",                 // PaliGemma 224
    "paligemma-448",                 // PaliGemma 448
    "paligemma2-3b-224",             // PaliGemma2 3B 224
    "paligemma2-3b-448",             // PaliGemma2 3B 448
    "paligemma2-10b-224",            // PaliGemma2 10B 224
    "paligemma2-10b-448",            // PaliGemma2 10B 448
    "gemma3-4b",                     // Gemma3 4B
    "gemma3-1b",                     // Gemma3 1B
    "gemma3-12b",                    // Gemma3 12B
    "gemma3-27b",                    // Gemma3 27B
};
constexpr Model kModelTypes[] = {
    Model::GEMMA_2B, Model::GEMMA_2B,      // Gemma 2B
    Model::GEMMA_7B, Model::GEMMA_7B,      // Gemma 7B
    Model::GRIFFIN_2B, Model::GRIFFIN_2B,  // RecurrentGemma
    Model::GEMMA_TINY,                     // Gemma Tiny
    Model::GEMMA2_2B, Model::GEMMA2_2B,    // Gemma2 2B
    Model::GEMMA2_9B, Model::GEMMA2_9B,    // Gemma2 9B
    Model::GEMMA2_27B, Model::GEMMA2_27B,  // Gemma2 27B
    Model::PALIGEMMA_224,                  // PaliGemma 224
    Model::PALIGEMMA_448,                  // PaliGemma 448
    Model::PALIGEMMA2_3B_224,              // PaliGemma2 3B 224
    Model::PALIGEMMA2_3B_448,              // PaliGemma2 3B 448
    Model::PALIGEMMA2_10B_224,             // PaliGemma2 10B 224
    Model::PALIGEMMA2_10B_448,             // PaliGemma2 10B 448
    Model::GEMMA3_4B,                      // Gemma3 4B
    Model::GEMMA3_1B,                      // Gemma3 1B
    Model::GEMMA3_12B,                     // Gemma3 12B
    Model::GEMMA3_27B,                     // Gemma3 27B
};
constexpr PromptWrapping kPromptWrapping[] = {
    PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT,    // Gemma 2B
    PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT,    // Gemma 7B
    PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT,    // RecurrentGemma
    PromptWrapping::GEMMA_IT,                              // Gemma Tiny
    PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT,    // Gemma2 2B
    PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT,    // Gemma2 9B
    PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT,    // Gemma2 27B
    PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA,  // PaliGemma 224/448
    PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA,  // PG2 3B 224/448
    PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA,  // PG2 10B 224/448
    PromptWrapping::GEMMA_VLM,                             // Gemma3 4B
    PromptWrapping::GEMMA_PT,                              // Gemma3 1B
    PromptWrapping::GEMMA_VLM,                             // Gemma3 12B
    PromptWrapping::GEMMA_VLM,                             // Gemma3 27B
};

constexpr size_t kNumModelFlags = std::size(kModelFlags);
static_assert(kNumModelFlags == std::size(kModelTypes));
static_assert(kNumModelFlags == std::size(kPromptWrapping));
|
||||
|
||||
const char* ParseModelTypeAndWrapping(const std::string& model_flag,
|
||||
Model& model, PromptWrapping& wrapping) {
|
||||
static std::string kErrorMessageBuffer =
|
||||
"Invalid or missing model flag, need to specify one of ";
|
||||
for (size_t i = 0; i + 1 < kNumModelFlags; ++i) {
|
||||
kErrorMessageBuffer.append(kModelFlags[i]);
|
||||
kErrorMessageBuffer.append(", ");
|
||||
}
|
||||
kErrorMessageBuffer.append(kModelFlags[kNumModelFlags - 1]);
|
||||
kErrorMessageBuffer.append(".");
|
||||
std::string model_type_lc = model_flag;
|
||||
std::transform(model_type_lc.begin(), model_type_lc.end(),
|
||||
model_type_lc.begin(), ::tolower);
|
||||
for (size_t i = 0; i < kNumModelFlags; ++i) {
|
||||
if (kModelFlags[i] == model_type_lc) {
|
||||
model = kModelTypes[i];
|
||||
wrapping = kPromptWrapping[i];
|
||||
HWY_ASSERT(std::string(ModelString(model, wrapping)) == model_type_lc);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return kErrorMessageBuffer.c_str();
|
||||
}
|
||||
|
||||
// Inverse of ParseModelTypeAndWrapping: maps (model, wrapping) back to its
// flag string. Aborts if the pair is not present in the tables.
const char* ModelString(Model model, PromptWrapping wrapping) {
  for (size_t i = 0; i < kNumModelFlags; ++i) {
    const bool model_matches = (kModelTypes[i] == model);
    const bool wrapping_matches = (kPromptWrapping[i] == wrapping);
    if (model_matches && wrapping_matches) {
      return kModelFlags[i];
    }
  }
  HWY_ABORT("Unknown model %d wrapping %d\n", static_cast<int>(model),
            static_cast<int>(wrapping));
}
|
||||
|
||||
// Returns the canonical name of `type`. Assumes `type` is a valid enumerator
// (no bounds check); out-of-range values index past kTypeStrings.
const char* StringFromType(Type type) {
  return kTypeStrings[static_cast<size_t>(type)];
}
|
||||
|
||||
const char* ParseType(const std::string& type_string, Type& type) {
|
||||
constexpr size_t kNum = std::size(kTypeStrings);
|
||||
static std::string kErrorMessageBuffer =
|
||||
"Invalid or missing type, need to specify one of ";
|
||||
for (size_t i = 0; i + 1 < kNum; ++i) {
|
||||
kErrorMessageBuffer.append(kTypeStrings[i]);
|
||||
kErrorMessageBuffer.append(", ");
|
||||
}
|
||||
kErrorMessageBuffer.append(kTypeStrings[kNum - 1]);
|
||||
kErrorMessageBuffer.append(".");
|
||||
std::string type_lc = type_string;
|
||||
std::transform(type_lc.begin(), type_lc.end(), type_lc.begin(), ::tolower);
|
||||
for (size_t i = 0; i < kNum; ++i) {
|
||||
if (kTypeStrings[i] == type_lc) {
|
||||
type = static_cast<Type>(i);
|
||||
HWY_ASSERT(std::string(StringFromType(type)) == type_lc);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return kErrorMessageBuffer.c_str();
|
||||
}
|
||||
|
||||
void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
|
||||
|
||||
// Instruction-tuned models are trained to expect control tokens.
|
||||
if (info.wrapping == PromptWrapping::GEMMA_IT) {
|
||||
// Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
|
||||
const std::string start = (pos == 0)
|
||||
? "<start_of_turn>user\n"
|
||||
: "<end_of_turn>\n<start_of_turn>user\n";
|
||||
prompt = start + prompt + "<end_of_turn>\n<start_of_turn>model\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Returns sqrt(model_dim), rounded through bf16.
float EmbeddingScaling(size_t model_dim) {
  // Round to bf16 to match Gemma's Embedder, which casts before mul.
  return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
      sqrtf(static_cast<float>(model_dim))));
}
|
||||
|
||||
// Returns the scale applied to queries in attention, chosen by the model's
// query_scale type: either 1/sqrt(model_dim / heads) or 1/sqrt(qkv_dim).
// Uses layer_configs[0], i.e. assumes these dimensions are uniform across
// layers.
float ChooseQueryScale(const ModelConfig& config) {
  if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
    return 1.0f / sqrtf(static_cast<float>(config.model_dim /
                                           config.layer_configs[0].heads));
  // QueryScaleType::SqrtKeySize
  return 1.0f / sqrtf(static_cast<float>(config.layer_configs[0].qkv_dim));
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -1,57 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "compression/shared.h" // PromptWrapping
|
||||
#include "gemma/configs.h" // IWYU pragma: export
|
||||
#include "hwy/base.h" // ConvertScalarTo
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Struct to bundle model information.
struct ModelInfo {
  // Which model architecture/size this is.
  Model model;
  // How prompts are wrapped (e.g. IT control tokens vs. plain text).
  PromptWrapping wrapping;
  // Weight storage type of the loaded model.
  Type weight;
};
|
||||
|
||||
// Returns error string or nullptr if OK.
|
||||
// Thread-hostile.
|
||||
const char* ParseModelTypeAndWrapping(const std::string& model_flag,
|
||||
Model& model, PromptWrapping& wrapping);
|
||||
const char* ParseType(const std::string& type_string, Type& type);
|
||||
|
||||
// Inverse of ParseModelTypeAndWrapping.
|
||||
const char* ModelString(Model model, PromptWrapping wrapping);
|
||||
const char* StringFromType(Type type);
|
||||
|
||||
// Wraps the given prompt using the expected control tokens for IT models.
|
||||
void Wrap(const ModelInfo& info, size_t pos, std::string& prompt);
|
||||
|
||||
// Returns the scale value to use for the embedding (basically sqrt model_dim).
|
||||
float EmbeddingScaling(size_t model_dim);
|
||||
|
||||
// Returns the scale value to use for the query in the attention computation.
|
||||
float ChooseQueryScale(const ModelConfig& config);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
|
||||
575
gemma/configs.cc
575
gemma/configs.cc
|
|
@ -15,22 +15,31 @@
|
|||
|
||||
#include "gemma/configs.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <iostream>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/types.h" // Type
|
||||
#include "io/fields.h" // IFields
|
||||
#include "io/io.h" // Path
|
||||
#include "hwy/base.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
static constexpr size_t kVocabSize = 256000;
|
||||
|
||||
static constexpr size_t kGemmaV3VocabSize = 262144;
|
||||
|
||||
static ModelConfig ConfigNoSSM() {
|
||||
ModelConfig config;
|
||||
config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
|
||||
"gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
|
||||
config.scale_base_names = {"att_ein", "qkv_ein", "gr_lin_x_w",
|
||||
"gr_lin_y_w", "gr_lin_out_w", "gr_gate_w",
|
||||
"gating_ein", "linear_w"};
|
||||
return config;
|
||||
}
|
||||
|
||||
static ModelConfig ConfigBaseGemmaV1() { return ConfigNoSSM(); }
|
||||
|
||||
static ModelConfig ConfigBaseGemmaV2() {
|
||||
ModelConfig config = ConfigNoSSM();
|
||||
config.att_cap = 50.0f;
|
||||
|
|
@ -54,17 +63,17 @@ static LayerConfig LayerConfigGemma2_27B(size_t model_dim) {
|
|||
|
||||
static ModelConfig ConfigGemma2_27B() {
|
||||
ModelConfig config = ConfigBaseGemmaV2();
|
||||
config.model_name = "Gemma2_27B";
|
||||
config.display_name = "Gemma2_27B";
|
||||
config.model = Model::GEMMA2_27B;
|
||||
config.model_dim = 4608;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = 8192;
|
||||
config.max_seq_len = 8192;
|
||||
LayerConfig layer_config = LayerConfigGemma2_27B(config.model_dim);
|
||||
config.layer_configs = {46, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.num_layers = 46;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtModelDimDivNumHeads;
|
||||
config.attention_window_sizes =
|
||||
RepeatedAttentionWindowSizes<46, 2>({4096, 8192});
|
||||
RepeatedAttentionWindowSizes<46, 2>({4096, config.max_seq_len});
|
||||
return config;
|
||||
}
|
||||
|
||||
|
|
@ -82,17 +91,17 @@ static LayerConfig LayerConfigGemma2_9B(size_t model_dim) {
|
|||
|
||||
static ModelConfig ConfigGemma2_9B() {
|
||||
ModelConfig config = ConfigBaseGemmaV2();
|
||||
config.model_name = "Gemma2_9B";
|
||||
config.display_name = "Gemma2_9B";
|
||||
config.model = Model::GEMMA2_9B;
|
||||
config.model_dim = 3584;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = 8192;
|
||||
config.max_seq_len = 8192;
|
||||
LayerConfig layer_config = LayerConfigGemma2_9B(config.model_dim);
|
||||
config.layer_configs = {42, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.num_layers = 42;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
config.attention_window_sizes =
|
||||
RepeatedAttentionWindowSizes<42, 2>({4096, 8192});
|
||||
RepeatedAttentionWindowSizes<42, 2>({4096, config.max_seq_len});
|
||||
return config;
|
||||
}
|
||||
|
||||
|
|
@ -110,66 +119,17 @@ static LayerConfig LayerConfigGemma2_2B(size_t model_dim) {
|
|||
|
||||
static ModelConfig ConfigGemma2_2B() {
|
||||
ModelConfig config = ConfigBaseGemmaV2();
|
||||
config.model_name = "Gemma2_2B";
|
||||
config.display_name = "Gemma2_2B";
|
||||
config.model = Model::GEMMA2_2B;
|
||||
config.model_dim = 2304;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = 8192;
|
||||
config.max_seq_len = 8192;
|
||||
LayerConfig layer_config = LayerConfigGemma2_2B(config.model_dim);
|
||||
config.layer_configs = {26, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.num_layers = 26;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
config.attention_window_sizes =
|
||||
RepeatedAttentionWindowSizes<26, 2>({4096, 8192});
|
||||
return config;
|
||||
}
|
||||
|
||||
static LayerConfig LayerConfigGemma7B(size_t model_dim) {
|
||||
LayerConfig config;
|
||||
config.model_dim = model_dim;
|
||||
config.ff_hidden_dim = 16 * 3072 / 2; // = 24576
|
||||
config.heads = 16;
|
||||
config.kv_heads = 16;
|
||||
config.qkv_dim = 256;
|
||||
return config;
|
||||
}
|
||||
|
||||
static ModelConfig ConfigGemma7B() {
|
||||
ModelConfig config = ConfigBaseGemmaV1();
|
||||
config.model_name = "Gemma7B";
|
||||
config.model = Model::GEMMA_7B;
|
||||
config.model_dim = 3072;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = kSeqLen;
|
||||
LayerConfig layer_config = LayerConfigGemma7B(config.model_dim);
|
||||
config.layer_configs = {28, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
config.attention_window_sizes = FixedAttentionWindowSizes<28>(kSeqLen);
|
||||
return config;
|
||||
}
|
||||
|
||||
static LayerConfig LayerConfigGemma2B(size_t model_dim) {
|
||||
LayerConfig config;
|
||||
config.model_dim = model_dim;
|
||||
config.ff_hidden_dim = 16 * 2048 / 2; // = 16384
|
||||
config.heads = 8;
|
||||
config.kv_heads = 1;
|
||||
config.qkv_dim = 256;
|
||||
return config;
|
||||
}
|
||||
|
||||
static ModelConfig ConfigGemma2B() {
|
||||
ModelConfig config = ConfigBaseGemmaV1();
|
||||
config.model_name = "Gemma2B";
|
||||
config.model = Model::GEMMA_2B;
|
||||
config.model_dim = 2048;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = kSeqLen;
|
||||
LayerConfig layer_config = LayerConfigGemma2B(config.model_dim);
|
||||
config.layer_configs = {18, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.attention_window_sizes = FixedAttentionWindowSizes<18>(kSeqLen);
|
||||
RepeatedAttentionWindowSizes<26, 2>({4096, config.max_seq_len});
|
||||
return config;
|
||||
}
|
||||
|
||||
|
|
@ -185,17 +145,18 @@ static LayerConfig LayerConfigGemmaTiny(size_t model_dim) {
|
|||
|
||||
static ModelConfig ConfigGemmaTiny() {
|
||||
ModelConfig config = ConfigNoSSM();
|
||||
config.model_name = "GemmaTiny";
|
||||
config.display_name = "GemmaTiny";
|
||||
config.model = Model::GEMMA_TINY;
|
||||
config.model_dim = 128;
|
||||
config.vocab_size = 64;
|
||||
config.seq_len = 32;
|
||||
config.wrapping = PromptWrapping::GEMMA_IT;
|
||||
config.model_dim = 32;
|
||||
config.vocab_size = 32; // at least two f32 vectors
|
||||
config.max_seq_len = 32;
|
||||
LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim);
|
||||
config.layer_configs = {3, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.num_layers = 2;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
config.attention_window_sizes = FixedAttentionWindowSizes<3>(32);
|
||||
// This is required for optimize_test to pass.
|
||||
config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
|
||||
config.att_cap = 50.0f;
|
||||
config.final_cap = 30.0f;
|
||||
config.eos_id = 11;
|
||||
config.secondary_eos_id = 11;
|
||||
|
|
@ -223,23 +184,23 @@ static LayerConfig LayerConfigGriffin2B(size_t model_dim) {
|
|||
|
||||
static ModelConfig ConfigGriffin2B() {
|
||||
ModelConfig config = ConfigNoSSM();
|
||||
config.model_name = "Griffin2B";
|
||||
config.display_name = "Griffin2B";
|
||||
config.model = Model::GRIFFIN_2B;
|
||||
// Griffin uses local attention, so kSeqLen is actually the local attention
|
||||
// window.
|
||||
// Griffin uses local attention, so max_seq_len is actually the local
|
||||
// attention window.
|
||||
config.model_dim = 2560;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = 2048;
|
||||
config.max_seq_len = 2048;
|
||||
LayerConfig layer_config = LayerConfigGriffin2B(config.model_dim);
|
||||
config.layer_configs = {26, layer_config};
|
||||
for (size_t i = 2; i < config.layer_configs.size(); i += 3) {
|
||||
config.num_layers = 26;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
for (size_t i = 2; i < config.num_layers; i += 3) {
|
||||
config.layer_configs[i].type = LayerAttentionType::kGemma;
|
||||
config.layer_configs[i].griffin_dim = 0;
|
||||
}
|
||||
config.num_tensor_scales = 140;
|
||||
config.attention_window_sizes = FixedAttentionWindowSizes<26>(config.seq_len);
|
||||
config.attention_window_sizes =
|
||||
FixedAttentionWindowSizes<26>(config.max_seq_len);
|
||||
config.use_local_attention = true;
|
||||
// This is required for optimize_test to pass.
|
||||
config.final_cap = 0.0f;
|
||||
return config;
|
||||
}
|
||||
|
|
@ -273,26 +234,10 @@ static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
|
|||
config.vit_config.num_scales = 4 * config.vit_config.layer_configs.size();
|
||||
}
|
||||
|
||||
static ModelConfig ConfigPaliGemma_224() {
|
||||
ModelConfig config = ConfigGemma2B();
|
||||
config.model_name = "PaliGemma_224";
|
||||
config.model = Model::PALIGEMMA_224;
|
||||
AddVitConfig(config);
|
||||
return config;
|
||||
}
|
||||
|
||||
static ModelConfig ConfigPaliGemma_448() {
|
||||
ModelConfig config = ConfigGemma2B();
|
||||
config.model_name = "PaliGemma_448";
|
||||
config.model = Model::PALIGEMMA_448;
|
||||
AddVitConfig(config, /*image_size=*/448);
|
||||
return config;
|
||||
}
|
||||
|
||||
ModelConfig GetVitConfig(const ModelConfig& config) {
|
||||
ModelConfig vit_config = ConfigNoSSM();
|
||||
vit_config.model_dim = config.vit_config.model_dim;
|
||||
vit_config.seq_len = config.vit_config.seq_len;
|
||||
vit_config.max_seq_len = config.vit_config.seq_len;
|
||||
vit_config.layer_configs = config.vit_config.layer_configs;
|
||||
vit_config.pool_dim = config.vit_config.pool_dim;
|
||||
vit_config.wrapping = config.wrapping;
|
||||
|
|
@ -303,32 +248,36 @@ ModelConfig GetVitConfig(const ModelConfig& config) {
|
|||
|
||||
static ModelConfig ConfigPaliGemma2_3B_224() {
|
||||
ModelConfig config = ConfigGemma2_2B();
|
||||
config.model_name = "PaliGemma2_3B_224";
|
||||
config.display_name = "PaliGemma2_3B_224";
|
||||
config.model = Model::PALIGEMMA2_3B_224;
|
||||
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||
AddVitConfig(config);
|
||||
return config;
|
||||
}
|
||||
|
||||
static ModelConfig ConfigPaliGemma2_3B_448() {
|
||||
ModelConfig config = ConfigGemma2_2B();
|
||||
config.model_name = "PaliGemma2_3B_448";
|
||||
config.display_name = "PaliGemma2_3B_448";
|
||||
config.model = Model::PALIGEMMA2_3B_448;
|
||||
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||
AddVitConfig(config, /*image_size=*/448);
|
||||
return config;
|
||||
}
|
||||
|
||||
static ModelConfig ConfigPaliGemma2_10B_224() {
|
||||
ModelConfig config = ConfigGemma2_9B();
|
||||
config.model_name = "PaliGemma2_10B_224";
|
||||
config.display_name = "PaliGemma2_10B_224";
|
||||
config.model = Model::PALIGEMMA2_10B_224;
|
||||
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||
AddVitConfig(config);
|
||||
return config;
|
||||
}
|
||||
|
||||
static ModelConfig ConfigPaliGemma2_10B_448() {
|
||||
ModelConfig config = ConfigGemma2_9B();
|
||||
config.model_name = "PaliGemma2_10B_448";
|
||||
config.display_name = "PaliGemma2_10B_448";
|
||||
config.model = Model::PALIGEMMA2_10B_448;
|
||||
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||
AddVitConfig(config, /*image_size=*/448);
|
||||
return config;
|
||||
}
|
||||
|
|
@ -358,18 +307,19 @@ static LayerConfig LayerConfigGemma3_1B_LM(size_t model_dim) {
|
|||
|
||||
static ModelConfig ConfigGemma3_1B() {
|
||||
ModelConfig config = ConfigBaseGemmaV3();
|
||||
config.model_name = "Gemma3_1B";
|
||||
config.display_name = "Gemma3_1B";
|
||||
config.model = Model::GEMMA3_1B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.model_dim = 1152;
|
||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||
config.seq_len = 32 * 1024;
|
||||
config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer
|
||||
config.max_seq_len = 32 * 1024;
|
||||
LayerConfig layer_config = LayerConfigGemma3_1B_LM(config.model_dim);
|
||||
config.layer_configs = {26, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.num_layers = 26;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
// interleaved local / global attention
|
||||
config.attention_window_sizes = RepeatedAttentionWindowSizes<26, 6>(
|
||||
{512, 512, 512, 512, 512, config.seq_len});
|
||||
{512, 512, 512, 512, 512, config.max_seq_len});
|
||||
return config;
|
||||
}
|
||||
|
||||
|
|
@ -389,27 +339,29 @@ static LayerConfig LayerConfigGemma3_4B_LM(size_t model_dim) {
|
|||
// Until we have the SigLIP checkpoints included, we use the LM config directly.
|
||||
static ModelConfig ConfigGemma3_4B_LM() {
|
||||
ModelConfig config = ConfigBaseGemmaV3();
|
||||
config.model_name = "Gemma3_4B";
|
||||
config.display_name = "Gemma3_4B";
|
||||
config.model = Model::GEMMA3_4B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.model_dim = 2560;
|
||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||
config.seq_len = 32 * 1024;
|
||||
config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer
|
||||
config.max_seq_len = 32 * 1024;
|
||||
LayerConfig layer_config = LayerConfigGemma3_4B_LM(config.model_dim);
|
||||
config.layer_configs = {34, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.num_layers = 34;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
// interleaved local / global attention
|
||||
config.attention_window_sizes = RepeatedAttentionWindowSizes<34, 6>(
|
||||
{1024, 1024, 1024, 1024, 1024, config.seq_len});
|
||||
{1024, 1024, 1024, 1024, 1024, config.max_seq_len});
|
||||
return config;
|
||||
}
|
||||
|
||||
static ModelConfig ConfigGemma3_4B() {
|
||||
ModelConfig config = ConfigGemma3_4B_LM();
|
||||
config.model_name = "Gemma3_4B";
|
||||
config.display_name = "Gemma3_4B";
|
||||
config.model = Model::GEMMA3_4B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
AddVitConfig(config, /*image_size=*/896);
|
||||
config.vocab_size = 262144;
|
||||
config.vocab_size = kGemmaV3VocabSize;
|
||||
config.vit_config.pool_dim = 4;
|
||||
const size_t num_patches =
|
||||
config.vit_config.image_size / config.vit_config.patch_width;
|
||||
|
|
@ -436,27 +388,29 @@ static LayerConfig LayerConfigGemma3_12B_LM(size_t model_dim) {
|
|||
|
||||
static ModelConfig ConfigGemma3_12B_LM() {
|
||||
ModelConfig config = ConfigBaseGemmaV3();
|
||||
config.model_name = "Gemma3_12B";
|
||||
config.display_name = "Gemma3_12B";
|
||||
config.model = Model::GEMMA3_12B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.model_dim = 3840;
|
||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||
config.seq_len = 32 * 1024;
|
||||
config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer
|
||||
config.max_seq_len = 32 * 1024;
|
||||
LayerConfig layer_config = LayerConfigGemma3_12B_LM(config.model_dim);
|
||||
config.layer_configs = {48, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.num_layers = 48;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
// interleaved local / global attention
|
||||
config.attention_window_sizes = RepeatedAttentionWindowSizes<48, 6>(
|
||||
{1024, 1024, 1024, 1024, 1024, config.seq_len});
|
||||
{1024, 1024, 1024, 1024, 1024, config.max_seq_len});
|
||||
return config;
|
||||
}
|
||||
|
||||
static ModelConfig ConfigGemma3_12B() {
|
||||
ModelConfig config = ConfigGemma3_12B_LM();
|
||||
config.model_name = "Gemma3_12B";
|
||||
config.display_name = "Gemma3_12B";
|
||||
config.model = Model::GEMMA3_12B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
AddVitConfig(config, /*image_size=*/896);
|
||||
config.vocab_size = 262144;
|
||||
config.vocab_size = kGemmaV3VocabSize;
|
||||
config.vit_config.pool_dim = 4;
|
||||
const size_t num_patches =
|
||||
config.vit_config.image_size / config.vit_config.patch_width;
|
||||
|
|
@ -483,27 +437,29 @@ static LayerConfig LayerConfigGemma3_27B_LM(size_t model_dim) {
|
|||
|
||||
static ModelConfig ConfigGemma3_27B_LM() {
|
||||
ModelConfig config = ConfigBaseGemmaV3();
|
||||
config.model_name = "Gemma3_27B";
|
||||
config.display_name = "Gemma3_27B";
|
||||
config.model = Model::GEMMA3_27B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.model_dim = 5376;
|
||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||
config.seq_len = 32 * 1024;
|
||||
config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer
|
||||
config.max_seq_len = 32 * 1024;
|
||||
LayerConfig layer_config = LayerConfigGemma3_27B_LM(config.model_dim);
|
||||
config.layer_configs = {62, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.num_layers = 62;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
// interleaved local / global attention
|
||||
config.attention_window_sizes = RepeatedAttentionWindowSizes<62, 6>(
|
||||
{1024, 1024, 1024, 1024, 1024, config.seq_len});
|
||||
{1024, 1024, 1024, 1024, 1024, config.max_seq_len});
|
||||
return config;
|
||||
}
|
||||
|
||||
static ModelConfig ConfigGemma3_27B() {
|
||||
ModelConfig config = ConfigGemma3_27B_LM();
|
||||
config.model_name = "Gemma3_27B";
|
||||
config.display_name = "Gemma3_27B";
|
||||
config.model = Model::GEMMA3_27B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
AddVitConfig(config, /*image_size=*/896);
|
||||
config.vocab_size = 262144;
|
||||
config.vocab_size = kGemmaV3VocabSize;
|
||||
config.vit_config.pool_dim = 4;
|
||||
const size_t num_patches =
|
||||
config.vit_config.image_size / config.vit_config.patch_width;
|
||||
|
|
@ -515,12 +471,8 @@ static ModelConfig ConfigGemma3_27B() {
|
|||
return config;
|
||||
}
|
||||
|
||||
ModelConfig ConfigFromModel(Model model) {
|
||||
static ModelConfig ConfigFromModel(Model model) {
|
||||
switch (model) {
|
||||
case Model::GEMMA_2B:
|
||||
return ConfigGemma2B();
|
||||
case Model::GEMMA_7B:
|
||||
return ConfigGemma7B();
|
||||
case Model::GEMMA2_2B:
|
||||
return ConfigGemma2_2B();
|
||||
case Model::GEMMA2_9B:
|
||||
|
|
@ -531,10 +483,6 @@ ModelConfig ConfigFromModel(Model model) {
|
|||
return ConfigGriffin2B();
|
||||
case Model::GEMMA_TINY:
|
||||
return ConfigGemmaTiny();
|
||||
case Model::PALIGEMMA_224:
|
||||
return ConfigPaliGemma_224();
|
||||
case Model::PALIGEMMA_448:
|
||||
return ConfigPaliGemma_448();
|
||||
case Model::PALIGEMMA2_3B_224:
|
||||
return ConfigPaliGemma2_3B_224();
|
||||
case Model::PALIGEMMA2_3B_448:
|
||||
|
|
@ -556,124 +504,249 @@ ModelConfig ConfigFromModel(Model model) {
|
|||
}
|
||||
}
|
||||
|
||||
#define TEST_EQUAL(a, b) \
|
||||
if (a != b) { \
|
||||
if (debug) \
|
||||
std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n"; \
|
||||
result = false; \
|
||||
const char* ModelPrefix(Model model) {
|
||||
switch (model) {
|
||||
case Model::UNKNOWN:
|
||||
return "unknown";
|
||||
case Model::GEMMA2_2B:
|
||||
return "gemma2-2b";
|
||||
case Model::GEMMA2_9B:
|
||||
return "9b";
|
||||
case Model::GEMMA2_27B:
|
||||
return "27b";
|
||||
case Model::GRIFFIN_2B:
|
||||
return "gr2b";
|
||||
case Model::GEMMA_TINY:
|
||||
return "tiny";
|
||||
case Model::PALIGEMMA2_3B_224:
|
||||
return "paligemma2-3b-224";
|
||||
case Model::PALIGEMMA2_3B_448:
|
||||
return "paligemma2-3b-448";
|
||||
case Model::PALIGEMMA2_10B_224:
|
||||
return "paligemma2-10b-224";
|
||||
case Model::PALIGEMMA2_10B_448:
|
||||
return "paligemma2-10b-448";
|
||||
case Model::GEMMA3_4B:
|
||||
return "gemma3-4b";
|
||||
case Model::GEMMA3_1B:
|
||||
return "gemma3-1b";
|
||||
case Model::GEMMA3_12B:
|
||||
return "gemma3-12b";
|
||||
case Model::GEMMA3_27B:
|
||||
return "gemma3-27b";
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||
}
|
||||
|
||||
#define RETURN_IF_NOT_EQUAL(a, b) \
|
||||
if (a != b) { \
|
||||
if (debug) \
|
||||
std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n"; \
|
||||
return false; \
|
||||
}
|
||||
|
||||
#define WARN_IF_NOT_EQUAL(a, b) \
|
||||
if (a != b) { \
|
||||
std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n"; \
|
||||
}
|
||||
|
||||
bool LayerConfig::TestEqual(const LayerConfig& other, bool partial,
|
||||
bool debug) const {
|
||||
bool result = true;
|
||||
// Optimized gating may not be set correctly in the c++ configs.
|
||||
if (debug) {
|
||||
WARN_IF_NOT_EQUAL(optimized_gating, other.optimized_gating)
|
||||
}
|
||||
TEST_EQUAL(model_dim, other.model_dim);
|
||||
TEST_EQUAL(griffin_dim, other.griffin_dim);
|
||||
TEST_EQUAL(ff_hidden_dim, other.ff_hidden_dim);
|
||||
TEST_EQUAL(heads, other.heads);
|
||||
TEST_EQUAL(kv_heads, other.kv_heads);
|
||||
TEST_EQUAL(qkv_dim, other.qkv_dim);
|
||||
TEST_EQUAL(conv1d_width, other.conv1d_width);
|
||||
if (!partial) {
|
||||
TEST_EQUAL(ff_biases, other.ff_biases);
|
||||
TEST_EQUAL(softmax_attn_output_biases, other.softmax_attn_output_biases);
|
||||
}
|
||||
TEST_EQUAL(static_cast<int>(post_norm), static_cast<int>(other.post_norm));
|
||||
TEST_EQUAL(static_cast<int>(type), static_cast<int>(other.type));
|
||||
TEST_EQUAL(static_cast<int>(activation), static_cast<int>(other.activation));
|
||||
TEST_EQUAL(static_cast<int>(post_qk), static_cast<int>(other.post_qk));
|
||||
return result;
|
||||
}
|
||||
|
||||
bool VitConfig::TestEqual(const VitConfig& other, bool partial,
|
||||
bool debug) const {
|
||||
bool result = true;
|
||||
TEST_EQUAL(model_dim, other.model_dim);
|
||||
TEST_EQUAL(seq_len, other.seq_len);
|
||||
if (!partial) {
|
||||
TEST_EQUAL(num_scales, other.num_scales);
|
||||
PromptWrapping ChooseWrapping(const Model model, Tristate wrapping) {
|
||||
if (IsPaliGemma(model)) {
|
||||
if (wrapping != Tristate::kDefault) {
|
||||
HWY_WARN("Ignoring unnecessary --wrapping for PaliGemma models.");
|
||||
}
|
||||
return PromptWrapping::PALIGEMMA;
|
||||
}
|
||||
TEST_EQUAL(patch_width, other.patch_width);
|
||||
TEST_EQUAL(image_size, other.image_size);
|
||||
RETURN_IF_NOT_EQUAL(layer_configs.size(), other.layer_configs.size());
|
||||
for (size_t i = 0; i < layer_configs.size(); ++i) {
|
||||
result &=
|
||||
layer_configs[i].TestEqual(other.layer_configs[i], partial, debug);
|
||||
if (IsVLM(model)) {
|
||||
if (wrapping != Tristate::kDefault) {
|
||||
HWY_WARN("Ignoring unnecessary --wrapping for VLM models.");
|
||||
}
|
||||
return PromptWrapping::GEMMA_VLM;
|
||||
}
|
||||
return result;
|
||||
// Default to IT unless --wrapping=0.
|
||||
return wrapping == Tristate::kFalse ? PromptWrapping::GEMMA_PT
|
||||
: PromptWrapping::GEMMA_IT;
|
||||
}
|
||||
|
||||
bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
|
||||
bool debug) const {
|
||||
bool result = true;
|
||||
TEST_EQUAL(model_family_version, other.model_family_version);
|
||||
// We don't care about model_name, model, wrapping, or weight being different,
|
||||
// but will output in debug mode if they are.
|
||||
if (debug) {
|
||||
WARN_IF_NOT_EQUAL(model_name, other.model_name);
|
||||
WARN_IF_NOT_EQUAL(static_cast<int>(model), static_cast<int>(other.model));
|
||||
WARN_IF_NOT_EQUAL(static_cast<int>(wrapping),
|
||||
static_cast<int>(other.wrapping));
|
||||
WARN_IF_NOT_EQUAL(static_cast<int>(weight), static_cast<int>(other.weight));
|
||||
ModelConfig::ModelConfig(const Model model, Type weight,
|
||||
PromptWrapping wrapping) {
|
||||
HWY_ASSERT(weight != Type::kUnknown);
|
||||
HWY_ASSERT(wrapping != PromptWrapping::kSentinel);
|
||||
this->model = model;
|
||||
if (model != Model::UNKNOWN) *this = ConfigFromModel(model);
|
||||
HWY_ASSERT(this->model == model);
|
||||
this->weight = weight;
|
||||
this->wrapping = wrapping;
|
||||
}
|
||||
|
||||
static Model FindModel(const std::string& specifier) {
|
||||
Model found_model = Model::UNKNOWN;
|
||||
ForEachModel([&](Model model) {
|
||||
// Some model names are prefixes of other model names
|
||||
const std::string prefix = std::string(ModelPrefix(model)) + "-";
|
||||
if (specifier.rfind(prefix, 0) == 0) { // Starts with prefix.
|
||||
// We only expect one match.
|
||||
HWY_ASSERT_M(found_model == Model::UNKNOWN, specifier.c_str());
|
||||
found_model = model;
|
||||
}
|
||||
});
|
||||
HWY_ASSERT_M(found_model != Model::UNKNOWN, specifier.c_str());
|
||||
return found_model;
|
||||
}
|
||||
|
||||
static Type FindType(const std::string& specifier) {
|
||||
Type found_type = Type::kUnknown;
|
||||
for (size_t i = 1; i < kNumTypes; ++i) {
|
||||
const Type type = static_cast<Type>(i);
|
||||
if (specifier.find(TypeName(type)) != std::string::npos) { // NOLINT
|
||||
// We only expect one match.
|
||||
HWY_ASSERT_M(found_type == Type::kUnknown, specifier.c_str());
|
||||
found_type = type;
|
||||
}
|
||||
}
|
||||
TEST_EQUAL(model_dim, other.model_dim);
|
||||
TEST_EQUAL(vocab_size, other.vocab_size);
|
||||
TEST_EQUAL(seq_len, other.seq_len);
|
||||
if (!partial) {
|
||||
TEST_EQUAL(num_tensor_scales, other.num_tensor_scales);
|
||||
HWY_ASSERT_M(found_type != Type::kUnknown, specifier.c_str());
|
||||
return found_type;
|
||||
}
|
||||
|
||||
static PromptWrapping FindWrapping(const std::string& specifier) {
|
||||
PromptWrapping found_wrapping = PromptWrapping::kSentinel;
|
||||
for (size_t i = 0; i < static_cast<size_t>(PromptWrapping::kSentinel); ++i) {
|
||||
const PromptWrapping w = static_cast<PromptWrapping>(i);
|
||||
if (specifier.find(WrappingSuffix(w)) != std::string::npos) { // NOLINT
|
||||
// We expect zero or one match.
|
||||
HWY_ASSERT_M(found_wrapping == PromptWrapping::kSentinel,
|
||||
specifier.c_str());
|
||||
found_wrapping = w;
|
||||
}
|
||||
}
|
||||
TEST_EQUAL(att_cap, other.att_cap);
|
||||
TEST_EQUAL(final_cap, other.final_cap);
|
||||
TEST_EQUAL(absolute_pe, other.absolute_pe);
|
||||
TEST_EQUAL(use_local_attention, other.use_local_attention);
|
||||
TEST_EQUAL(static_cast<int>(query_scale),
|
||||
static_cast<int>(other.query_scale));
|
||||
RETURN_IF_NOT_EQUAL(layer_configs.size(), other.layer_configs.size());
|
||||
for (size_t i = 0; i < layer_configs.size(); ++i) {
|
||||
result &=
|
||||
layer_configs[i].TestEqual(other.layer_configs[i], partial, debug);
|
||||
if (found_wrapping == PromptWrapping::kSentinel) {
|
||||
return ChooseWrapping(FindModel(specifier));
|
||||
}
|
||||
RETURN_IF_NOT_EQUAL(attention_window_sizes.size(),
|
||||
other.attention_window_sizes.size());
|
||||
for (size_t i = 0; i < attention_window_sizes.size(); ++i) {
|
||||
TEST_EQUAL(attention_window_sizes[i], other.attention_window_sizes[i]);
|
||||
return found_wrapping;
|
||||
}
|
||||
|
||||
// Obtains model/weight/wrapping by finding prefix and suffix strings.
|
||||
ModelConfig::ModelConfig(const std::string& specifier)
|
||||
: ModelConfig(FindModel(specifier), FindType(specifier),
|
||||
FindWrapping(specifier)) {}
|
||||
|
||||
std::string ModelConfig::Specifier() const {
|
||||
HWY_ASSERT(model != Model::UNKNOWN);
|
||||
HWY_ASSERT(weight != Type::kUnknown);
|
||||
HWY_ASSERT(wrapping != PromptWrapping::kSentinel);
|
||||
|
||||
std::string base_name = ModelPrefix(model);
|
||||
|
||||
base_name += '-';
|
||||
base_name += TypeName(weight);
|
||||
|
||||
if (wrapping != PromptWrapping::GEMMA_VLM &&
|
||||
wrapping != PromptWrapping::PALIGEMMA) {
|
||||
base_name += WrappingSuffix(wrapping);
|
||||
}
|
||||
if (!partial) {
|
||||
if (scale_names != other.scale_names) {
|
||||
result = false;
|
||||
if (debug) {
|
||||
std::cerr << "scale_names mismatch\n";
|
||||
|
||||
return base_name;
|
||||
}
|
||||
|
||||
// Returns whether all fields match.
|
||||
static bool AllEqual(const IFields& a, const IFields& b, bool print) {
|
||||
const std::vector<uint32_t> serialized_a = a.Write();
|
||||
const std::vector<uint32_t> serialized_b = b.Write();
|
||||
if (serialized_a != serialized_b) {
|
||||
if (print) {
|
||||
fprintf(stderr, "%s differs. Recommend generating a diff:\n", a.Name());
|
||||
a.Print();
|
||||
b.Print();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LayerConfig::TestEqual(const LayerConfig& other, bool print) const {
|
||||
return AllEqual(*this, other, print);
|
||||
}
|
||||
|
||||
bool VitConfig::TestEqual(const VitConfig& other, bool print) const {
|
||||
return AllEqual(*this, other, print);
|
||||
}
|
||||
|
||||
bool ModelConfig::TestEqual(const ModelConfig& other, bool print) const {
|
||||
// Early out to guard the loop below; a differing number of layers will anyway
|
||||
// cause a mismatch.
|
||||
if (layer_configs.size() != other.layer_configs.size()) {
|
||||
if (print) {
|
||||
HWY_WARN("Layer configs size mismatch %zu vs %zu", layer_configs.size(),
|
||||
other.layer_configs.size());
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Copy so we can 'ignore' fields by setting them to the same value.
|
||||
ModelConfig a = *this;
|
||||
ModelConfig b = other;
|
||||
// Called by `OverwriteWithCanonical`, so ignore the fields it will set.
|
||||
a.display_name = b.display_name;
|
||||
a.model = b.model;
|
||||
|
||||
// The following are not yet set by config_converter.py, so we here ignore
|
||||
// them for purposes of comparison, and there overwrite the converter's config
|
||||
// with the canonical ModelConfig constructed via (deduced) enum, so that
|
||||
// these fields will be set.
|
||||
// `vit_config` is also not yet set, but we must not ignore it because
|
||||
// otherwise PaliGemma models will be indistinguishable for `configs_test`.
|
||||
a.pool_dim = b.pool_dim; // ViT
|
||||
a.eos_id = b.eos_id;
|
||||
a.secondary_eos_id = b.secondary_eos_id;
|
||||
a.scale_base_names = b.scale_base_names;
|
||||
for (size_t i = 0; i < a.layer_configs.size(); ++i) {
|
||||
a.layer_configs[i].optimized_gating = b.layer_configs[i].optimized_gating;
|
||||
}
|
||||
|
||||
return AllEqual(a, b, print);
|
||||
}
|
||||
|
||||
// Constructs the canonical ModelConfig for each model. If there is one for
|
||||
// which TestEqual returns true, overwrites `*this` with that and returns true.
|
||||
bool ModelConfig::OverwriteWithCanonical() {
|
||||
bool found = false;
|
||||
const bool print = false;
|
||||
ForEachModel([&](Model model) {
|
||||
const ModelConfig config(model, weight, wrapping);
|
||||
if (config.TestEqual(*this, print)) {
|
||||
HWY_ASSERT(!found); // Should only find one.
|
||||
found = true;
|
||||
*this = config;
|
||||
}
|
||||
});
|
||||
return found;
|
||||
}
|
||||
|
||||
Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
|
||||
switch (layers) {
|
||||
case 2:
|
||||
return Model::GEMMA_TINY;
|
||||
case 26:
|
||||
if (layer_types & kDeducedGriffin) return Model::GRIFFIN_2B;
|
||||
if (layer_types & kDeducedViT) return Model::GEMMA3_1B;
|
||||
return Model::GEMMA2_2B;
|
||||
case 27:
|
||||
return (layer_types & kDeduced448) ? Model::PALIGEMMA2_3B_448
|
||||
: Model::PALIGEMMA2_3B_224;
|
||||
case 34:
|
||||
return Model::GEMMA3_4B;
|
||||
case 42:
|
||||
if (layer_types & kDeducedViT) {
|
||||
return (layer_types & kDeduced448) ? Model::PALIGEMMA2_10B_448
|
||||
: Model::PALIGEMMA2_10B_224;
|
||||
}
|
||||
}
|
||||
}
|
||||
TEST_EQUAL(norm_num_groups, other.norm_num_groups);
|
||||
result &= vit_config.TestEqual(other.vit_config, partial, debug);
|
||||
return result;
|
||||
}
|
||||
return Model::GEMMA2_9B;
|
||||
case 46:
|
||||
return Model::GEMMA2_27B;
|
||||
case 48:
|
||||
return Model::GEMMA3_12B;
|
||||
case 62:
|
||||
return Model::GEMMA3_27B;
|
||||
|
||||
Model ModelFromConfig(const ModelConfig& config) {
|
||||
for (Model model : kAllModels) {
|
||||
ModelConfig model_config = ConfigFromModel(model);
|
||||
if (config.TestEqual(model_config, /*partial=*/true, /*debug=*/false)) {
|
||||
return model;
|
||||
}
|
||||
// TODO: detect these.
|
||||
/*
|
||||
return Model::GEMMA2_772M;
|
||||
return Model::PALIGEMMA2_772M_224;
|
||||
*/
|
||||
default:
|
||||
HWY_WARN("Failed to deduce model type from %s, layer count %zu types %x.",
|
||||
blob_path.path.c_str(), layers, layer_types);
|
||||
return Model::UNKNOWN;
|
||||
}
|
||||
return Model::UNKNOWN;
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
391
gemma/configs.h
391
gemma/configs.h
|
|
@ -19,35 +19,52 @@
|
|||
// Model configurations
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/fields.h" // IFieldsVisitor
|
||||
#include "compression/shared.h" // BF16
|
||||
#include "compression/types.h" // Type
|
||||
#include "io/fields.h" // IFieldsVisitor
|
||||
#include "io/io.h" // Path
|
||||
#include "util/basics.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Allow changing pre-allocated kv cache size as a compiler flag
|
||||
#ifndef GEMMA_MAX_SEQLEN
|
||||
#define GEMMA_MAX_SEQLEN 4096
|
||||
#endif // !GEMMA_MAX_SEQLEN
|
||||
|
||||
// Allow changing k parameter of `SampleTopK` as a compiler flag
|
||||
#ifndef GEMMA_TOPK
|
||||
#define GEMMA_TOPK 1
|
||||
#endif // !GEMMA_TOPK
|
||||
|
||||
static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
|
||||
static constexpr size_t kTopK = GEMMA_TOPK;
|
||||
static constexpr size_t kVocabSize = 256000;
|
||||
static constexpr size_t kMaxConv1DWidth = 4;
|
||||
static constexpr size_t kMaxQKVDim = 1024;
|
||||
|
||||
using EmbedderInputT = BF16;
|
||||
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
|
||||
enum class PromptWrapping {
|
||||
GEMMA_IT,
|
||||
GEMMA_PT,
|
||||
GEMMA_VLM, // for >1B Gemma3
|
||||
PALIGEMMA,
|
||||
kSentinel // must be last
|
||||
};
|
||||
|
||||
// This is used in `ModelConfig.Specifier`, so the strings will not change,
|
||||
// though new ones may be added.
|
||||
static inline const char* WrappingSuffix(PromptWrapping wrapping) {
|
||||
switch (wrapping) {
|
||||
case PromptWrapping::GEMMA_IT:
|
||||
return "-it";
|
||||
case PromptWrapping::GEMMA_PT:
|
||||
return "-pt";
|
||||
case PromptWrapping::GEMMA_VLM:
|
||||
return "-vlm";
|
||||
case PromptWrapping::PALIGEMMA:
|
||||
return "-pg";
|
||||
default:
|
||||
return "-?";
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool EnumValid(PromptWrapping wrapping) {
|
||||
return static_cast<size_t>(wrapping) <
|
||||
static_cast<size_t>(PromptWrapping::kSentinel);
|
||||
}
|
||||
|
||||
enum class LayerAttentionType {
|
||||
kGemma,
|
||||
|
|
@ -55,63 +72,68 @@ enum class LayerAttentionType {
|
|||
kVit,
|
||||
};
|
||||
|
||||
inline bool EnumValid(LayerAttentionType type) {
|
||||
return static_cast<int>(type) >= 0 &&
|
||||
static_cast<int>(type) <= static_cast<int>(LayerAttentionType::kVit);
|
||||
static inline bool EnumValid(LayerAttentionType type) {
|
||||
return type == LayerAttentionType::kGemma ||
|
||||
type == LayerAttentionType::kGriffinRecurrentBlock ||
|
||||
type == LayerAttentionType::kVit;
|
||||
}
|
||||
|
||||
// Post attention and ffw normalization type.
|
||||
enum class PostNormType {
|
||||
None,
|
||||
Scale,
|
||||
kSentinel // must be last
|
||||
};
|
||||
|
||||
inline bool EnumValid(PostNormType type) {
|
||||
return static_cast<int>(type) >= 0 &&
|
||||
static_cast<int>(type) <= static_cast<int>(PostNormType::Scale);
|
||||
static inline bool EnumValid(PostNormType type) {
|
||||
return static_cast<size_t>(type) <
|
||||
static_cast<size_t>(PostNormType::kSentinel);
|
||||
}
|
||||
|
||||
// Post qk projection operation type.
|
||||
enum class PostQKType {
|
||||
Rope,
|
||||
HalfRope,
|
||||
kSentinel // must be last
|
||||
};
|
||||
|
||||
inline bool EnumValid(PostQKType type) {
|
||||
return static_cast<int>(type) >= 0 &&
|
||||
static_cast<int>(type) <= static_cast<int>(PostQKType::HalfRope);
|
||||
static inline bool EnumValid(PostQKType type) {
|
||||
return static_cast<size_t>(type) <
|
||||
static_cast<size_t>(PostNormType::kSentinel);
|
||||
}
|
||||
|
||||
// FFW activation function.
|
||||
enum class ActivationType {
|
||||
Gelu,
|
||||
kSentinel // must be last
|
||||
};
|
||||
|
||||
inline bool EnumValid(ActivationType type) {
|
||||
return static_cast<int>(type) >= 0 &&
|
||||
static_cast<int>(type) <= static_cast<int>(ActivationType::Gelu);
|
||||
static inline bool EnumValid(ActivationType type) {
|
||||
return static_cast<size_t>(type) <
|
||||
static_cast<size_t>(ActivationType::kSentinel);
|
||||
}
|
||||
|
||||
// Attention query scale.
|
||||
enum class QueryScaleType {
|
||||
SqrtKeySize,
|
||||
SqrtModelDimDivNumHeads,
|
||||
kSentinel // must be last
|
||||
};
|
||||
|
||||
inline bool EnumValid(QueryScaleType type) {
|
||||
return static_cast<int>(type) >= 0 &&
|
||||
static_cast<int>(type) <=
|
||||
static_cast<int>(QueryScaleType::SqrtModelDimDivNumHeads);
|
||||
static inline bool EnumValid(QueryScaleType type) {
|
||||
return static_cast<size_t>(type) <
|
||||
static_cast<size_t>(QueryScaleType::kSentinel);
|
||||
}
|
||||
|
||||
// Residual connection type.
|
||||
enum class ResidualType {
|
||||
Add,
|
||||
kSentinel // must be last
|
||||
};
|
||||
|
||||
inline bool EnumValid(ResidualType type) {
|
||||
return static_cast<int>(type) >= 0 &&
|
||||
static_cast<int>(type) <= static_cast<int>(ResidualType::Add);
|
||||
static inline bool EnumValid(ResidualType type) {
|
||||
return static_cast<size_t>(type) <
|
||||
static_cast<size_t>(ResidualType::kSentinel);
|
||||
}
|
||||
|
||||
template <size_t kNum>
|
||||
|
|
@ -137,17 +159,15 @@ std::vector<uint32_t> RepeatedAttentionWindowSizes(
|
|||
|
||||
// Model variants: see configs.cc for details.
|
||||
enum class Model {
|
||||
UNKNOWN,
|
||||
GEMMA_2B,
|
||||
GEMMA_7B,
|
||||
GEMMA2_9B,
|
||||
UNKNOWN = 0,
|
||||
// 1 and 2 are obsolete.
|
||||
GEMMA2_9B = 3,
|
||||
GEMMA2_27B,
|
||||
GRIFFIN_2B,
|
||||
GEMMA_TINY,
|
||||
GEMMA_TINY, // for testing only
|
||||
GEMMA2_2B,
|
||||
PALIGEMMA_224,
|
||||
PALIGEMMA_448,
|
||||
PALIGEMMA2_3B_224,
|
||||
// 8 and 9 are obsolete.
|
||||
PALIGEMMA2_3B_224 = 10,
|
||||
PALIGEMMA2_3B_448,
|
||||
PALIGEMMA2_10B_224,
|
||||
PALIGEMMA2_10B_448,
|
||||
|
|
@ -155,43 +175,64 @@ enum class Model {
|
|||
GEMMA3_1B,
|
||||
GEMMA3_12B,
|
||||
GEMMA3_27B,
|
||||
kSentinel,
|
||||
};
|
||||
|
||||
// Allows the Model enum to be iterated over.
|
||||
static constexpr Model kAllModels[] = {
|
||||
Model::GEMMA_2B, Model::GEMMA_7B, Model::GEMMA2_9B, Model::GEMMA2_27B,
|
||||
Model::GRIFFIN_2B, Model::GEMMA_TINY, Model::GEMMA2_2B,
|
||||
Model::PALIGEMMA_224, Model::PALIGEMMA_448, Model::PALIGEMMA2_3B_224,
|
||||
Model::PALIGEMMA2_3B_448, Model::PALIGEMMA2_10B_224,
|
||||
Model::PALIGEMMA2_10B_448, Model::GEMMA3_4B, Model::GEMMA3_1B,
|
||||
Model::GEMMA3_12B, Model::GEMMA3_27B,
|
||||
};
|
||||
// Returns canonical model name without the PromptWrapping suffix. This is used
|
||||
// in Specifier and thus does not change.
|
||||
const char* ModelPrefix(Model model);
|
||||
|
||||
inline bool EnumValid(Model model) {
|
||||
for (Model m : kAllModels) {
|
||||
if (m == model) return true;
|
||||
// Gemma3 is multimodal and has a different prompt wrapping than PaliGemma.
|
||||
// This is used for deducing the PromptWrapping for pre-2025 BlobStore.
|
||||
static inline bool IsVLM(Model model) {
|
||||
return model == Model::GEMMA3_4B || model == Model::GEMMA3_1B ||
|
||||
model == Model::GEMMA3_12B || model == Model::GEMMA3_27B;
|
||||
}
|
||||
|
||||
static inline bool IsPaliGemma(Model model) {
|
||||
if (model == Model::PALIGEMMA2_3B_224 || model == Model::PALIGEMMA2_3B_448 ||
|
||||
model == Model::PALIGEMMA2_10B_224 ||
|
||||
model == Model::PALIGEMMA2_10B_448) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Visits every valid model enum, skipping `UNKNOWN` and `kSentinel`.
|
||||
template <class Func>
|
||||
void ForEachModel(const Func& func) {
|
||||
for (size_t i = static_cast<size_t>(Model::GEMMA2_9B);
|
||||
i < static_cast<size_t>(Model::kSentinel); ++i) {
|
||||
if (i == 8 || i == 9) continue;
|
||||
func(static_cast<Model>(i));
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool EnumValid(Model model) {
|
||||
// Valid for purposes of serialization, even if unknown.
|
||||
if (model == Model::UNKNOWN) return true;
|
||||
const size_t i = static_cast<size_t>(model);
|
||||
if (i >= static_cast<size_t>(Model::GEMMA2_9B) &&
|
||||
i < static_cast<size_t>(Model::kSentinel) && i != 8 && i != 9) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
struct InternalLayerConfig : public IFields {
|
||||
const char* Name() const override { return "InternalLayerConfig"; }
|
||||
|
||||
// Source of truth for field ordering.
|
||||
void VisitFields(IFieldsVisitor& visitor) override {
|
||||
// Append new fields here, then update `python/configs.cc`.
|
||||
}
|
||||
};
|
||||
|
||||
// Per-layer configuration.
|
||||
struct LayerConfig : public IFields {
|
||||
// Returns true if *this and other are equal.
|
||||
// If partial is true, then we don't check for items that are only set after
|
||||
// the tensors are loaded from the checkpoint.
|
||||
// If debug is true, then we output the mismatched fields to stderr.
|
||||
bool TestEqual(const LayerConfig& other, bool partial, bool debug) const;
|
||||
|
||||
size_t CacheLayerSize() const { return kv_heads * qkv_dim * 2; }
|
||||
|
||||
// Multi-Head Attention?
|
||||
bool IsMHA() const { return heads == kv_heads; }
|
||||
|
||||
// Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
|
||||
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
|
||||
size_t QStride() const { return qkv_dim * (IsMHA() ? 3 : 1); }
|
||||
|
||||
const char* Name() const override { return "LayerConfig"; }
|
||||
|
||||
// Source of truth for field ordering.
|
||||
void VisitFields(IFieldsVisitor& visitor) override {
|
||||
visitor(model_dim);
|
||||
visitor(griffin_dim);
|
||||
|
|
@ -208,35 +249,41 @@ struct LayerConfig : public IFields {
|
|||
visitor(activation);
|
||||
visitor(post_qk);
|
||||
visitor(use_qk_norm);
|
||||
internal.VisitFields(visitor);
|
||||
// Append new fields here, then update `python/configs.cc`.
|
||||
}
|
||||
|
||||
// Returns whether all fields match.
|
||||
bool TestEqual(const LayerConfig& other, bool print) const;
|
||||
|
||||
size_t CacheLayerSize() const { return kv_heads * qkv_dim * 2; }
|
||||
|
||||
// Multi-Head Attention?
|
||||
bool IsMHA() const { return heads == kv_heads; }
|
||||
|
||||
uint32_t model_dim = 0;
|
||||
uint32_t griffin_dim = 0;
|
||||
uint32_t ff_hidden_dim = 0;
|
||||
uint32_t heads = 0;
|
||||
uint32_t kv_heads = 0;
|
||||
uint32_t qkv_dim = 0;
|
||||
uint32_t conv1d_width = 0; // griffin only
|
||||
uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous).
|
||||
uint32_t conv1d_width = 0; // Griffin only
|
||||
bool ff_biases = false;
|
||||
bool softmax_attn_output_biases = false;
|
||||
bool optimized_gating = true;
|
||||
bool softmax_attn_output_biases = false; // for Griffin
|
||||
bool optimized_gating = true; // for Gemma3
|
||||
PostNormType post_norm = PostNormType::None;
|
||||
LayerAttentionType type = LayerAttentionType::kGemma;
|
||||
ActivationType activation = ActivationType::Gelu;
|
||||
PostQKType post_qk = PostQKType::Rope;
|
||||
bool use_qk_norm = false;
|
||||
InternalLayerConfig internal;
|
||||
};
|
||||
|
||||
// Dimensions related to image processing.
|
||||
struct VitConfig : public IFields {
|
||||
// Returns true if *this and other are equal.
|
||||
// If partial is true, then we don't check for items that are only set after
|
||||
// the tensors are loaded from the checkpoint.
|
||||
// If debug is true, then we output the mismatched fields to stderr.
|
||||
bool TestEqual(const VitConfig& other, bool partial, bool debug) const;
|
||||
|
||||
const char* Name() const override { return "VitConfig"; }
|
||||
|
||||
// Source of truth for field ordering.
|
||||
void VisitFields(IFieldsVisitor& visitor) override {
|
||||
visitor(model_dim);
|
||||
visitor(seq_len);
|
||||
|
|
@ -245,8 +292,12 @@ struct VitConfig : public IFields {
|
|||
visitor(image_size);
|
||||
visitor(layer_configs);
|
||||
visitor(pool_dim);
|
||||
// Append new fields here, then update `python/configs.cc`.
|
||||
}
|
||||
|
||||
// Returns whether all fields match.
|
||||
bool TestEqual(const VitConfig& other, bool print) const;
|
||||
|
||||
uint32_t model_dim = 0;
|
||||
uint32_t seq_len = 0;
|
||||
uint32_t num_scales = 0;
|
||||
|
|
@ -256,20 +307,96 @@ struct VitConfig : public IFields {
|
|||
std::vector<LayerConfig> layer_configs;
|
||||
};
|
||||
|
||||
// Returns a valid `PromptWrapping` for the given `model`, for passing to the
|
||||
// `ModelConfig` ctor when the caller does not care about the wrapping. The
|
||||
// wrapping mode is either determined by the model (for PaliGemma and Gemma3),
|
||||
// or defaults to IT, subject to user override for PT.
|
||||
PromptWrapping ChooseWrapping(Model model,
|
||||
Tristate wrapping = Tristate::kDefault);
|
||||
|
||||
struct InternalModelConfig : public IFields {
|
||||
const char* Name() const override { return "InternalModelConfig"; }
|
||||
|
||||
// Source of truth for field ordering.
|
||||
void VisitFields(IFieldsVisitor& visitor) override {
|
||||
// Append new fields here, then update `python/configs.cc`.
|
||||
}
|
||||
};
|
||||
|
||||
struct ModelConfig : public IFields {
|
||||
// Returns true if *this and other are equal.
|
||||
// If partial is true, then we don't check for items that are only set after
|
||||
// the tensors are loaded from the checkpoint.
|
||||
// If debug is true, then we output the mismatched fields to stderr.
|
||||
bool TestEqual(const ModelConfig& other, bool partial, bool debug) const;
|
||||
// Preferred usage (single-file format): default-construct, then deserialize
|
||||
// from a blob. Also used by `config_converter.py`, which sets sufficient
|
||||
// fields for `TestEqual` and then calls `OverwriteWithCanonical()`.
|
||||
ModelConfig() = default;
|
||||
// For use by `model_store.cc` for pre-2025 format after deducing the model
|
||||
// from tensors plus a user-specified `wrapping` override (`ChooseWrapping`).
|
||||
ModelConfig(Model model, Type weight, PromptWrapping wrapping);
|
||||
// Parses a string returned by `Specifier()`. Used by the exporter to select
|
||||
// the model from command line arguments. Do not use this elsewhere - the
|
||||
// second ctor is preferred because it is type-checked.
|
||||
ModelConfig(const std::string& specifier);
|
||||
|
||||
const char* Name() const override { return "ModelConfig"; }
|
||||
|
||||
// Source of truth for field ordering.
|
||||
void VisitFields(IFieldsVisitor& visitor) override {
|
||||
visitor(model_family_version);
|
||||
visitor(display_name);
|
||||
visitor(model);
|
||||
visitor(wrapping);
|
||||
visitor(weight);
|
||||
|
||||
visitor(num_layers);
|
||||
visitor(model_dim);
|
||||
visitor(vocab_size);
|
||||
visitor(max_seq_len);
|
||||
|
||||
visitor(unused_num_tensor_scales);
|
||||
|
||||
visitor(att_cap);
|
||||
visitor(final_cap);
|
||||
|
||||
visitor(absolute_pe);
|
||||
visitor(use_local_attention);
|
||||
visitor(query_scale);
|
||||
visitor(layer_configs);
|
||||
visitor(attention_window_sizes);
|
||||
visitor(norm_num_groups);
|
||||
visitor(vit_config);
|
||||
visitor(pool_dim);
|
||||
|
||||
visitor(eos_id);
|
||||
visitor(secondary_eos_id);
|
||||
|
||||
visitor(scale_base_names);
|
||||
|
||||
internal.VisitFields(visitor);
|
||||
|
||||
// Append new fields here, then update `python/configs.cc`.
|
||||
}
|
||||
|
||||
// Returns whether all fields match except `model` and `display_name`, and
|
||||
// some others that are not yet set by config_converter.py. This is for
|
||||
// internal use by `OverwriteWithCanonical`, but potentially useful elsewhere.
|
||||
bool TestEqual(const ModelConfig& other, bool print) const;
|
||||
|
||||
// For each model, constructs its canonical `ModelConfig` and if `TestEqual`
|
||||
// returns true, overwrites `*this` with that. Otherwise, returns false to
|
||||
// indicate this is not a known model. Called by `config_converter.py`.
|
||||
bool OverwriteWithCanonical();
|
||||
|
||||
// Returns a string encoding of the model family, size, weight, and
|
||||
// `PromptWrapping`. Stable/unchanging; can be used as the model file name.
|
||||
// The third ctor also expects a string returned by this.
|
||||
std::string Specifier() const;
|
||||
|
||||
void AddLayerConfig(const LayerConfig& layer_config) {
|
||||
layer_configs.push_back(layer_config);
|
||||
HWY_ASSERT(layer_configs.size() <= num_layers);
|
||||
}
|
||||
|
||||
size_t CachePosSize() const {
|
||||
size_t num_layers = layer_configs.size();
|
||||
return num_layers * layer_configs[0].CacheLayerSize();
|
||||
bool IsGlobalLayer(size_t layer_idx) const {
|
||||
return attention_window_sizes[layer_idx] == max_seq_len;
|
||||
}
|
||||
|
||||
size_t NumLayersOfTypeBefore(LayerAttentionType type, size_t num) const {
|
||||
|
|
@ -287,77 +414,77 @@ struct ModelConfig : public IFields {
|
|||
size_t NumHeads() const {
|
||||
uint32_t num_heads = 0;
|
||||
for (const auto& layer_config : layer_configs) {
|
||||
num_heads = std::max(num_heads, layer_config.heads);
|
||||
num_heads = HWY_MAX(num_heads, layer_config.heads);
|
||||
}
|
||||
return num_heads;
|
||||
}
|
||||
|
||||
const char* Name() const override { return "ModelConfig"; }
|
||||
size_t KVCacheCols() const {
|
||||
size_t num_layers = layer_configs.size();
|
||||
return num_layers * layer_configs[0].CacheLayerSize();
|
||||
}
|
||||
|
||||
bool IsEOS(int id) const { return (id == eos_id || id == secondary_eos_id); }
|
||||
|
||||
void VisitFields(IFieldsVisitor& visitor) override {
|
||||
visitor(model_family_version);
|
||||
visitor(model_name);
|
||||
visitor(model);
|
||||
visitor(wrapping);
|
||||
visitor(weight);
|
||||
visitor(num_layers);
|
||||
visitor(model_dim);
|
||||
visitor(vocab_size);
|
||||
visitor(seq_len);
|
||||
visitor(num_tensor_scales);
|
||||
visitor(att_cap);
|
||||
visitor(final_cap);
|
||||
visitor(absolute_pe);
|
||||
visitor(use_local_attention);
|
||||
visitor(query_scale);
|
||||
visitor(layer_configs);
|
||||
visitor(attention_window_sizes);
|
||||
visitor(norm_num_groups);
|
||||
visitor(vit_config);
|
||||
visitor(pool_dim);
|
||||
visitor(eos_id);
|
||||
visitor(secondary_eos_id);
|
||||
}
|
||||
|
||||
// Major version of the model family. It is used as a fallback to distinguish
|
||||
// between model types when there is no explicit information in the config.
|
||||
// Major version of the model family, reflecting architecture changes. This is
|
||||
// more convenient to compare than `Model` because that also includes the
|
||||
// model size.
|
||||
uint32_t model_family_version = 1;
|
||||
std::string model_name;
|
||||
Model model = Model::UNKNOWN;
|
||||
// For display only, may change. Use `Specifier()` for setting the
|
||||
// file name. Not checked by `TestEqual` because `config_converter.py` does
|
||||
// not set this.
|
||||
std::string display_name;
|
||||
Model model = Model::UNKNOWN; // Not checked by `TestEqual`, see above.
|
||||
PromptWrapping wrapping = PromptWrapping::GEMMA_PT;
|
||||
Type weight = Type::kUnknown;
|
||||
|
||||
uint32_t num_layers = 0;
|
||||
uint32_t model_dim = 0;
|
||||
uint32_t vocab_size = 0;
|
||||
uint32_t seq_len = 0;
|
||||
uint32_t num_tensor_scales = 0;
|
||||
uint32_t max_seq_len = 0;
|
||||
|
||||
// We no longer set nor use this: config_converter is not able to set this,
|
||||
// and only pre-2025 format stores scales, and we do not require advance
|
||||
// knowledge of how many there will be. Any scales present will just be
|
||||
// assigned in order to the tensors matching `scale_base_names`.
|
||||
uint32_t unused_num_tensor_scales = 0;
|
||||
|
||||
float att_cap = 0.0f;
|
||||
float final_cap = 0.0f;
|
||||
|
||||
bool absolute_pe = false;
|
||||
bool use_local_attention = false; // griffin only
|
||||
bool use_local_attention = false; // Griffin only
|
||||
QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
|
||||
std::vector<LayerConfig> layer_configs;
|
||||
std::vector<uint32_t> attention_window_sizes;
|
||||
std::unordered_set<std::string> scale_names;
|
||||
uint32_t norm_num_groups = 1;
|
||||
|
||||
// Dimensions related to image processing.
|
||||
VitConfig vit_config;
|
||||
uint32_t pool_dim = 1; // used only for VitConfig copy
|
||||
|
||||
int eos_id = 1;
|
||||
int secondary_eos_id = 1;
|
||||
|
||||
// Tensor base names without a layer suffix, used by `ModelStore` only for
|
||||
// pre-2025 format.
|
||||
std::vector<std::string> scale_base_names;
|
||||
|
||||
InternalModelConfig internal;
|
||||
};
|
||||
|
||||
// Returns the config for the given model.
|
||||
ModelConfig ConfigFromModel(Model model);
|
||||
|
||||
// Returns the model for the given config, if it matches any standard model.
|
||||
Model ModelFromConfig(const ModelConfig& config);
|
||||
|
||||
// Returns the sub-config for the ViT model of the PaliGemma model.
|
||||
ModelConfig GetVitConfig(const ModelConfig& config);
|
||||
|
||||
enum DeducedLayerTypes {
|
||||
kDeducedGriffin = 1,
|
||||
kDeducedViT = 2,
|
||||
kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224.
|
||||
};
|
||||
|
||||
// layer_types is one or more of `DeducedLayerTypes`.
|
||||
Model DeduceModel(const Path& blob_path, size_t layers, int layer_types);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_
|
||||
|
|
|
|||
|
|
@ -1,461 +1,44 @@
|
|||
#include "gemma/configs.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "compression/types.h" // Type
|
||||
#include "io/fields.h" // Type
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
template <size_t kNum>
|
||||
constexpr std::array<LayerAttentionType, kNum> OldFixedLayerConfig(
|
||||
LayerAttentionType type) {
|
||||
std::array<LayerAttentionType, kNum> config = {};
|
||||
for (LayerAttentionType& l : config) {
|
||||
l = type;
|
||||
}
|
||||
return config;
|
||||
}
|
||||
TEST(ConfigsTest, TestAll) {
|
||||
ForEachModel([&](Model model) {
|
||||
ModelConfig config(model, Type::kSFP, ChooseWrapping(model));
|
||||
fprintf(stderr, "Testing %s (%s)\n", config.display_name.c_str(),
|
||||
config.Specifier().c_str());
|
||||
HWY_ASSERT(config.model == model);
|
||||
|
||||
template <size_t kNum>
|
||||
constexpr std::array<size_t, kNum> OldFixedAttentionWindowSizes(
|
||||
size_t window_size) {
|
||||
std::array<size_t, kNum> window_size_configs = {};
|
||||
for (size_t& l : window_size_configs) {
|
||||
l = window_size;
|
||||
}
|
||||
return window_size_configs;
|
||||
}
|
||||
// We can deduce the model/display_name from all other fields.
|
||||
config.model = Model::UNKNOWN;
|
||||
const std::string saved_display_name = config.display_name;
|
||||
config.display_name.clear();
|
||||
HWY_ASSERT(config.OverwriteWithCanonical());
|
||||
HWY_ASSERT(config.model == model);
|
||||
HWY_ASSERT(config.display_name == saved_display_name);
|
||||
|
||||
// Repeat window_size_pattern for kNum / kPatternSize times.
|
||||
template <size_t kNum, size_t kPatternSize>
|
||||
constexpr std::array<size_t, kNum> OldRepeatedAttentionWindowSizes(
|
||||
const std::array<size_t, kPatternSize>& window_size_pattern) {
|
||||
static_assert(kNum % kPatternSize == 0,
|
||||
"kNum must be a multiple of kPatternSize");
|
||||
std::array<size_t, kNum> window_size_configs = {};
|
||||
for (size_t i = 0; i < kNum; ++i) {
|
||||
window_size_configs[i] = window_size_pattern[i % kPatternSize];
|
||||
}
|
||||
return window_size_configs;
|
||||
}
|
||||
|
||||
template <size_t kNumLayers>
|
||||
constexpr size_t OldNumLayersOfTypeBefore(
|
||||
const std::array<LayerAttentionType, kNumLayers>& layers,
|
||||
LayerAttentionType type, size_t num) {
|
||||
size_t count = 0;
|
||||
for (size_t i = 0; i < num; i++) {
|
||||
if (layers[i] == type) count++;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
template <class TConfig, typename = void>
|
||||
struct CacheLayerSize {
|
||||
constexpr size_t operator()() const {
|
||||
return TConfig::kKVHeads * TConfig::kQKVDim * 2;
|
||||
}
|
||||
};
|
||||
|
||||
template <class TConfig, typename = void>
|
||||
struct CachePosSize {
|
||||
constexpr size_t operator()() const {
|
||||
return TConfig::kGemmaLayers * CacheLayerSize<TConfig>()();
|
||||
}
|
||||
};
|
||||
|
||||
struct OldConfigNoVit {
|
||||
struct VitConfig {
|
||||
// Some of these are needed to make the compiler happy when trying to
|
||||
// generate code that will actually never be used.
|
||||
using Weight = float;
|
||||
static constexpr int kLayers = 0;
|
||||
static constexpr std::array<LayerAttentionType, 0> kLayerConfig =
|
||||
OldFixedLayerConfig<0>(LayerAttentionType::kVit);
|
||||
static constexpr int kModelDim = 0;
|
||||
static constexpr int kFFHiddenDim = 0;
|
||||
static constexpr int kHeads = 1; // Avoid division by 0 in griffin gate_w.
|
||||
static constexpr int kKVHeads = 0;
|
||||
static constexpr int kQKVDim = 0;
|
||||
static constexpr int kSeqLen = 0;
|
||||
static constexpr ResidualType kResidual = ResidualType::Add;
|
||||
static constexpr int kGriffinLayers = 0;
|
||||
static constexpr int kConv1dWidth = 0;
|
||||
static constexpr bool kFFBiases = false;
|
||||
static constexpr bool kSoftmaxAttnOutputBiases = false;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
};
|
||||
};
|
||||
|
||||
struct OldConfigNoSSM : OldConfigNoVit {
|
||||
static constexpr int kGriffinLayers = 0;
|
||||
|
||||
static constexpr int kConv1dWidth = 0;
|
||||
static constexpr bool kFFBiases = false;
|
||||
static constexpr bool kSoftmaxAttnOutputBiases = false;
|
||||
static constexpr bool kUseHalfRope = false;
|
||||
static constexpr bool kUseLocalAttention = false;
|
||||
static constexpr bool kInterleaveQKV = true;
|
||||
static constexpr PostQKType kPostQK = PostQKType::Rope;
|
||||
static constexpr ActivationType kActivation = ActivationType::Gelu;
|
||||
static constexpr ResidualType kResidual = ResidualType::Add;
|
||||
};
|
||||
|
||||
struct OldConfigBaseGemmaV1 : OldConfigNoSSM {
|
||||
static constexpr float kAttCap = 0.0f;
|
||||
static constexpr float kFinalCap = 0.0f;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
|
||||
};
|
||||
|
||||
struct OldConfigBaseGemmaV2 : OldConfigNoSSM {
|
||||
static constexpr float kAttCap = 50.0f;
|
||||
static constexpr float kFinalCap = 30.0f;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::Scale;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigGemma2_27B : public OldConfigBaseGemmaV2 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = 8192;
|
||||
static constexpr int kVocabSize = gcpp::kVocabSize;
|
||||
static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
|
||||
OldFixedLayerConfig<46>(LayerAttentionType::kGemma);
|
||||
static constexpr std::array<size_t, 46> kAttentionWindowSizes =
|
||||
OldRepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen});
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static constexpr int kModelDim = 4608;
|
||||
static constexpr int kFFHiddenDim = 16 * 4608 / 2; // = 36864
|
||||
static constexpr int kHeads = 32;
|
||||
static constexpr int kKVHeads = 16;
|
||||
static constexpr int kQKVDim = 128; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr QueryScaleType kQueryScale =
|
||||
QueryScaleType::SqrtModelDimDivNumHeads;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigGemma2_9B : public OldConfigBaseGemmaV2 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = 8192;
|
||||
static constexpr int kVocabSize = gcpp::kVocabSize;
|
||||
static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
|
||||
OldFixedLayerConfig<42>(LayerAttentionType::kGemma);
|
||||
static constexpr std::array<size_t, 42> kAttentionWindowSizes =
|
||||
OldRepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen});
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static constexpr int kModelDim = 3584;
|
||||
static constexpr int kFFHiddenDim = 8 * 3584 / 2; // = 14336
|
||||
static constexpr int kHeads = 16;
|
||||
static constexpr int kKVHeads = 8;
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigGemma7B : public OldConfigBaseGemmaV1 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = gcpp::kSeqLen;
|
||||
static constexpr int kVocabSize = gcpp::kVocabSize;
|
||||
static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
|
||||
OldFixedLayerConfig<28>(LayerAttentionType::kGemma);
|
||||
static constexpr std::array<size_t, 28> kAttentionWindowSizes =
|
||||
OldFixedAttentionWindowSizes<28>(kSeqLen);
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static constexpr int kModelDim = 3072;
|
||||
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
|
||||
static constexpr int kHeads = 16;
|
||||
static constexpr int kKVHeads = 16; // standard MHA
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigGemma2B : public OldConfigBaseGemmaV1 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = gcpp::kSeqLen;
|
||||
static constexpr int kVocabSize = gcpp::kVocabSize;
|
||||
static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
|
||||
OldFixedLayerConfig<18>(LayerAttentionType::kGemma);
|
||||
static constexpr std::array<size_t, 18> kAttentionWindowSizes =
|
||||
OldFixedAttentionWindowSizes<18>(kSeqLen);
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static constexpr int kModelDim = 2048;
|
||||
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
|
||||
static constexpr int kHeads = 8;
|
||||
static constexpr int kKVHeads = 1;
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigPaliGemma_224 : public OldConfigGemma2B<TWeight> {
|
||||
// On the LM side, the vocab size is one difference to Gemma1-2B in the
|
||||
// architecture. PaliGemma adds 1024 <locNNNN> and 128 <segNNN> tokens.
|
||||
static constexpr int kVocabSize = 256000 + 1024 + 128; // = 257152
|
||||
|
||||
// Sub-config for the Vision-Transformer part.
|
||||
struct VitConfig : public OldConfigNoSSM {
|
||||
using Weight = TWeight;
|
||||
// The ViT parts. https://arxiv.org/abs/2305.13035
|
||||
// "SoViT-400m/14 [...] has a width of 1152, depth 27, and MLP dim 4304."
|
||||
static constexpr std::array<LayerAttentionType, 27> kLayerConfig =
|
||||
OldFixedLayerConfig<27>(LayerAttentionType::kVit);
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr int kModelDim = 1152;
|
||||
static constexpr int kFFHiddenDim = 4304;
|
||||
static constexpr int kHeads = 16;
|
||||
static constexpr int kKVHeads = 16; // standard MHA
|
||||
static constexpr int kQKVDim = 72;
|
||||
static constexpr int kSeqLen = 16 * 16; // 256
|
||||
static constexpr bool kFFBiases = true;
|
||||
// The Vit part does not have a vocabulary, the image patches are embedded.
|
||||
static constexpr int kVocabSize = 0;
|
||||
// Dimensions related to image processing.
|
||||
static constexpr int kPatchWidth = 14;
|
||||
static constexpr int kImageSize = 224;
|
||||
// Necessary constant for the layer configuration.
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
};
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigGemma2_2B : public OldConfigBaseGemmaV2 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = 8192;
|
||||
static constexpr int kVocabSize = gcpp::kVocabSize;
|
||||
static constexpr std::array<LayerAttentionType, 26> kLayerConfig =
|
||||
OldFixedLayerConfig<26>(LayerAttentionType::kGemma);
|
||||
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
|
||||
OldRepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen});
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static constexpr int kModelDim = 2304;
|
||||
static constexpr int kFFHiddenDim = 8 * 2304 / 2; // = 9216
|
||||
static constexpr int kHeads = 8;
|
||||
static constexpr int kKVHeads = 4;
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigGemmaTiny : public OldConfigNoSSM {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = 32;
|
||||
static constexpr int kVocabSize = 64;
|
||||
static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
|
||||
OldFixedLayerConfig<3>(LayerAttentionType::kGemma);
|
||||
static constexpr std::array<size_t, 3> kAttentionWindowSizes =
|
||||
OldFixedAttentionWindowSizes<3>(kSeqLen);
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static constexpr int kModelDim = 128;
|
||||
static constexpr int kFFHiddenDim = 256;
|
||||
static constexpr int kHeads = 4;
|
||||
static constexpr int kKVHeads = 1;
|
||||
static constexpr int kQKVDim = 16; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
|
||||
|
||||
static constexpr float kAttCap = 0.0f;
|
||||
// This is required for optimize_test to pass.
|
||||
static constexpr float kFinalCap = 30.0f;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigGriffin2B : OldConfigNoVit {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
// Griffin uses local attention, so kSeqLen is actually the local attention
|
||||
// window.
|
||||
static constexpr int kSeqLen = 2048;
|
||||
static constexpr int kVocabSize = gcpp::kVocabSize;
|
||||
static constexpr std::array<LayerAttentionType, 26> kLayerConfig = {
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
};
|
||||
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
|
||||
OldFixedAttentionWindowSizes<26>(kSeqLen);
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kGemmaLayers = OldNumLayersOfTypeBefore(
|
||||
kLayerConfig, LayerAttentionType::kGemma, kLayers);
|
||||
static constexpr int kGriffinLayers = OldNumLayersOfTypeBefore(
|
||||
kLayerConfig, LayerAttentionType::kGriffinRecurrentBlock, kLayers);
|
||||
static constexpr int kModelDim = 2560;
|
||||
static constexpr int kFFHiddenDim = 7680;
|
||||
static constexpr int kHeads = 10;
|
||||
static constexpr int kKVHeads = 1;
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
|
||||
// No SoftCap.
|
||||
static constexpr float kAttCap = 0.0f;
|
||||
static constexpr float kFinalCap = 0.0f;
|
||||
|
||||
// SSM config.
|
||||
static constexpr int kConv1dWidth = 4;
|
||||
static constexpr bool kFFBiases = true;
|
||||
static constexpr bool kSoftmaxAttnOutputBiases = true;
|
||||
static constexpr bool kUseHalfRope = true;
|
||||
static constexpr bool kUseLocalAttention = true;
|
||||
static constexpr bool kInterleaveQKV = false;
|
||||
static constexpr int kNumTensorScales = 140;
|
||||
static constexpr PostQKType kPostQK = PostQKType::Rope;
|
||||
static constexpr ActivationType kActivation = ActivationType::Gelu;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
|
||||
static constexpr ResidualType kResidual = ResidualType::Add;
|
||||
};
|
||||
|
||||
template <class TConfig>
|
||||
void AssertMatch(const ModelConfig& config) {
|
||||
ASSERT_EQ(TConfig::kModelDim, config.model_dim);
|
||||
if constexpr (TConfig::VitConfig::kModelDim != 0) {
|
||||
ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_config.model_dim);
|
||||
ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_config.seq_len);
|
||||
ASSERT_EQ(TConfig::VitConfig::kNumTensorScales,
|
||||
config.vit_config.num_scales);
|
||||
for (size_t i = 0; i < config.vit_config.layer_configs.size(); ++i) {
|
||||
ASSERT_EQ(TConfig::VitConfig::kLayerConfig[i],
|
||||
config.vit_config.layer_configs[i].type);
|
||||
}
|
||||
}
|
||||
ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);
|
||||
ASSERT_EQ(TConfig::kSeqLen, config.seq_len);
|
||||
ASSERT_EQ(TConfig::kAttCap, config.att_cap);
|
||||
ASSERT_EQ(TConfig::kFinalCap, config.final_cap);
|
||||
ASSERT_EQ(TConfig::kAbsolutePE, config.absolute_pe);
|
||||
ASSERT_EQ(TConfig::kUseLocalAttention, config.use_local_attention);
|
||||
ASSERT_EQ(TConfig::kQueryScale, config.query_scale);
|
||||
ASSERT_EQ(TConfig::kGemmaLayers,
|
||||
config.NumLayersOfType(LayerAttentionType::kGemma));
|
||||
ASSERT_EQ(TConfig::kGriffinLayers,
|
||||
config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock));
|
||||
for (size_t i = 0; i < config.layer_configs.size(); ++i) {
|
||||
ASSERT_EQ(TConfig::kModelDim, config.layer_configs[i].model_dim);
|
||||
ASSERT_EQ(TConfig::kFFHiddenDim, config.layer_configs[i].ff_hidden_dim);
|
||||
ASSERT_EQ(TConfig::kHeads, config.layer_configs[i].heads);
|
||||
ASSERT_EQ(TConfig::kKVHeads, config.layer_configs[i].kv_heads);
|
||||
ASSERT_EQ(TConfig::kQKVDim, config.layer_configs[i].qkv_dim);
|
||||
ASSERT_EQ(TConfig::kConv1dWidth, config.layer_configs[i].conv1d_width);
|
||||
ASSERT_EQ(TConfig::kFFBiases, config.layer_configs[i].ff_biases);
|
||||
ASSERT_EQ(TConfig::kSoftmaxAttnOutputBiases,
|
||||
config.layer_configs[i].softmax_attn_output_biases);
|
||||
ASSERT_EQ(TConfig::kPostNorm, config.layer_configs[i].post_norm);
|
||||
ASSERT_EQ(TConfig::kLayerConfig[i], config.layer_configs[i].type);
|
||||
ASSERT_EQ(TConfig::kActivation, config.layer_configs[i].activation);
|
||||
PostQKType post_qk = TConfig::kPostQK;
|
||||
if (TConfig::kUseHalfRope) {
|
||||
post_qk = PostQKType::HalfRope;
|
||||
}
|
||||
ASSERT_EQ(post_qk, config.layer_configs[i].post_qk);
|
||||
}
|
||||
|
||||
ASSERT_EQ(TConfig::kAttentionWindowSizes.size(),
|
||||
config.attention_window_sizes.size());
|
||||
for (size_t i = 0; i < config.attention_window_sizes.size(); ++i) {
|
||||
ASSERT_EQ(TConfig::kAttentionWindowSizes[i],
|
||||
config.attention_window_sizes[i]);
|
||||
}
|
||||
ASSERT_EQ(TConfig::kNumTensorScales, config.num_tensor_scales);
|
||||
}
|
||||
|
||||
ModelConfig RoundTripSerialize(const ModelConfig& config) {
|
||||
std::vector<uint32_t> config_buffer = config.Write();
|
||||
ModelConfig deserialized;
|
||||
deserialized.Read(hwy::Span<const uint32_t>(config_buffer), 0);
|
||||
return deserialized;
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemma2B) {
|
||||
AssertMatch<OldConfigGemma2B<float>>(ConfigFromModel(Model::GEMMA_2B));
|
||||
ModelConfig config = RoundTripSerialize(ConfigFromModel(Model::GEMMA_2B));
|
||||
AssertMatch<OldConfigGemma2B<float>>(config);
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemma7B) {
|
||||
AssertMatch<OldConfigGemma7B<float>>(ConfigFromModel(Model::GEMMA_7B));
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemma2_2B) {
|
||||
AssertMatch<OldConfigGemma2_2B<float>>(ConfigFromModel(Model::GEMMA2_2B));
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemma2_9B) {
|
||||
AssertMatch<OldConfigGemma2_9B<float>>(ConfigFromModel(Model::GEMMA2_9B));
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemma2_27B) {
|
||||
AssertMatch<OldConfigGemma2_27B<float>>(ConfigFromModel(Model::GEMMA2_27B));
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGriffin2B) {
|
||||
AssertMatch<OldConfigGriffin2B<float>>(ConfigFromModel(Model::GRIFFIN_2B));
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemmaTiny) {
|
||||
AssertMatch<OldConfigGemmaTiny<float>>(ConfigFromModel(Model::GEMMA_TINY));
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigPaliGemma_224) {
|
||||
AssertMatch<OldConfigPaliGemma_224<float>>(
|
||||
ConfigFromModel(Model::PALIGEMMA_224));
|
||||
const std::vector<uint32_t> serialized = config.Write();
|
||||
ModelConfig deserialized;
|
||||
const IFields::ReadResult result =
|
||||
deserialized.Read(hwy::Span<const uint32_t>(serialized), /*pos=*/0);
|
||||
HWY_ASSERT(result.pos == serialized.size());
|
||||
// We wrote it, so all fields should be known, and no extra.
|
||||
HWY_ASSERT(result.extra_u32 == 0);
|
||||
HWY_ASSERT(result.missing_fields == 0);
|
||||
// All fields should match.
|
||||
HWY_ASSERT(deserialized.TestEqual(config, /*print=*/true));
|
||||
HWY_ASSERT(deserialized.model == model);
|
||||
HWY_ASSERT(deserialized.display_name == saved_display_name);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
1622
gemma/gemma-inl.h
1622
gemma/gemma-inl.h
File diff suppressed because it is too large
Load Diff
747
gemma/gemma.cc
747
gemma/gemma.cc
|
|
@ -13,173 +13,654 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Defines Gemma member functions; the actual implementations are in
|
||||
// gemma-inl.h, included from instantiations/*.cc.
|
||||
// Defines Gemma member functions which dynamic-dispatch into the SIMD
|
||||
// implementations in gemma-inl.h.
|
||||
|
||||
#include "gemma/gemma.h"
|
||||
|
||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||
// which we pass the filename via macro 'argument'.
|
||||
// clang-format off
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE "gemma/gemma.cc" // NOLINT
|
||||
// clang-format on
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "gemma/attention.h" // includes highway.h
|
||||
#include "gemma/gemma-inl.h"
|
||||
#include "gemma/griffin.h" // includes highway.h
|
||||
#include "gemma/vit.h" // includes highway.h
|
||||
|
||||
#ifndef GEMMA_CC_ONCE
|
||||
#define GEMMA_CC_ONCE
|
||||
|
||||
#include <math.h> // sqrtf
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <string>
|
||||
#include <utility> // std::move
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h" // Path
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/model_store.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "ops/ops-inl.h"
|
||||
#include "io/blob_store.h"
|
||||
#include "io/io.h" // Path
|
||||
#include "ops/matmul.h"
|
||||
#include "paligemma/image.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
#endif // GEMMA_CC_ONCE
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
|
||||
const ModelInfo& info, MatMulEnv& env)
|
||||
: env_(env), tokenizer_(tokenizer_path) {
|
||||
model_.Load(weights, info.model, info.weight, info.wrapping,
|
||||
env_.parallel.Pools().Pool(0),
|
||||
/*tokenizer_proto=*/nullptr);
|
||||
}
|
||||
|
||||
Gemma::Gemma(const Path& weights, MatMulEnv& env) : env_(env) {
|
||||
std::string tokenizer_proto;
|
||||
model_.Load(weights, Model::UNKNOWN, Type::kUnknown, PromptWrapping::GEMMA_IT,
|
||||
env_.parallel.Pools().Pool(0), &tokenizer_proto);
|
||||
tokenizer_.Deserialize(tokenizer_proto);
|
||||
}
|
||||
|
||||
Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env)
|
||||
: env_(env), tokenizer_(std::move(tokenizer)) {
|
||||
HWY_ASSERT(info.weight == Type::kF32);
|
||||
model_.Allocate(info.model, info.weight, env_.parallel.Pools().Pool(0));
|
||||
}
|
||||
|
||||
Gemma::~Gemma() {
|
||||
}
|
||||
|
||||
// There are >=3 types of the inference code. To reduce compile time,
|
||||
// we shard them across multiple translation units in instantiations/*.cc.
|
||||
// This declares the functions defined there. We use overloading because
|
||||
// explicit instantiations are still too slow to compile.
|
||||
#define GEMMA_DECLARE(TWEIGHT) \
|
||||
extern void GenerateSingle(TWEIGHT, const ModelWeightsStorage& model, \
|
||||
const RuntimeConfig& runtime_config, \
|
||||
const PromptTokens& prompt, size_t pos, \
|
||||
size_t prefix_end, KVCache& kv_cache, \
|
||||
MatMulEnv* env, TimingInfo& timing_info); \
|
||||
extern void GenerateBatch( \
|
||||
TWEIGHT, const ModelWeightsStorage& model, \
|
||||
const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \
|
||||
const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, \
|
||||
const KVCaches& kv_caches, MatMulEnv* env, TimingInfo& timing_info); \
|
||||
extern void GenerateImageTokens(TWEIGHT, const ModelWeightsStorage& model, \
|
||||
const RuntimeConfig& runtime_config, \
|
||||
const Image& image, \
|
||||
ImageTokens& image_tokens, MatMulEnv* env);
|
||||
GEMMA_DECLARE(float)
|
||||
GEMMA_DECLARE(BF16)
|
||||
GEMMA_DECLARE(NuqStream)
|
||||
GEMMA_DECLARE(SfpStream)
|
||||
|
||||
// Adapters to select from the above overloads via CallForModelWeight.
|
||||
template <class TConfig>
|
||||
struct GenerateSingleT {
|
||||
void operator()(const ModelWeightsStorage& model,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
KVCache& kv_cache, MatMulEnv* env,
|
||||
TimingInfo& timing_info) const {
|
||||
GenerateSingle(TConfig(), model, runtime_config, prompt, pos, prefix_end,
|
||||
kv_cache, env, timing_info);
|
||||
void Attention(LayerAttentionType type, const size_t num_tokens,
|
||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||
Activations& activations, QBatch& qbatch, MatMulEnv& env) {
|
||||
if (type == LayerAttentionType::kGemma) {
|
||||
GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
|
||||
env,
|
||||
/*flags=*/0);
|
||||
} else {
|
||||
HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
|
||||
// KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer,
|
||||
// so map `layer` to the Griffin layer index.
|
||||
const size_t griffin_layer =
|
||||
activations.attention.config.NumLayersOfTypeBefore(type, layer_idx);
|
||||
GriffinRecurrent(num_tokens, griffin_layer, &layer, activations, qbatch,
|
||||
env);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
struct GenerateBatchT {
|
||||
void operator()(const ModelWeightsStorage& model,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end,
|
||||
const KVCaches& kv_caches, MatMulEnv* env,
|
||||
TimingInfo& timing_info) const {
|
||||
GenerateBatch(TConfig(), model, runtime_config, queries_prompt, queries_pos,
|
||||
queries_prefix_end, kv_caches, env, timing_info);
|
||||
}
|
||||
};
|
||||
static HWY_NOINLINE void TransformerLayer(const size_t num_tokens,
|
||||
const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer,
|
||||
Activations& activations,
|
||||
QBatch& qbatch, MatMulEnv& env) {
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
|
||||
template <class TConfig>
|
||||
struct GenerateImageTokensT {
|
||||
void operator()(const ModelWeightsStorage& model,
|
||||
const RuntimeConfig& runtime_config, const Image& image,
|
||||
ImageTokens& image_tokens, MatMulEnv* env) const {
|
||||
GenerateImageTokens(TConfig(), model, runtime_config, image, image_tokens,
|
||||
env);
|
||||
RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
|
||||
activations.attention.pre_att_rms_out, env.ctx);
|
||||
|
||||
Attention(layer_config.type, num_tokens, layer_idx, layer, activations,
|
||||
qbatch, env);
|
||||
|
||||
PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
|
||||
activations.attention.att_sums, env.ctx);
|
||||
|
||||
ResidualConnection(activations.attention.att_sums, activations.x, layer,
|
||||
/*is_attention=*/true, env.ctx);
|
||||
|
||||
RMSNormBatched(activations.x, layer.pre_ffw_norm_scale,
|
||||
activations.pre_ffw_rms_out, env.ctx);
|
||||
|
||||
if (layer_config.type == LayerAttentionType::kVit) {
|
||||
FFWVit(layer, activations, env);
|
||||
} else {
|
||||
FFWNoVit(layer, activations, env);
|
||||
}
|
||||
};
|
||||
|
||||
PostNorm(layer_config.post_norm, layer.post_ffw_norm_scale,
|
||||
activations.ffw_out, env.ctx);
|
||||
|
||||
ResidualConnection(activations.ffw_out, activations.x, layer,
|
||||
/*is_attention=*/false, env.ctx);
|
||||
}
|
||||
|
||||
// Returns the scale value to use for the embedding (basically sqrt model_dim).
|
||||
static float EmbeddingScaling(size_t model_dim) {
|
||||
// Round to bf16 to match Gemma's Embedder, which casts before mul.
|
||||
return hwy::ConvertScalarTo<float>(
|
||||
hwy::ConvertScalarTo<BF16>(sqrtf(static_cast<float>(model_dim))));
|
||||
}
|
||||
|
||||
// `batch_idx` indicates which row of `x` to write to.
|
||||
// `pos` is the *token*'s position, not the start of the batch, because this is
|
||||
// called for batches of tokens in prefill, but batches of queries in decode.
|
||||
//
|
||||
// For GEMMA_VLM, image tokens are copied into -2 locations (per the Gemma 3
|
||||
// spec) until we run out of image tokens. This allows for a multi-image prompt
|
||||
// if -2 locations with appropriate begin/end image tokens are created by the
|
||||
// calling application.
|
||||
// Returns new image_token_position.
|
||||
static HWY_NOINLINE size_t
|
||||
EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
|
||||
const ModelConfig& model_config, const WeightsPtrs& weights,
|
||||
MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr,
|
||||
size_t image_token_position = 0) {
|
||||
// Image tokens just need to be copied.
|
||||
if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
|
||||
image_tokens != nullptr && token == -2 &&
|
||||
image_token_position < image_tokens->Rows()) {
|
||||
hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(qi),
|
||||
x.Cols() * x.ElementBytes());
|
||||
return image_token_position + 1;
|
||||
}
|
||||
|
||||
if (model_config.wrapping == PromptWrapping::PALIGEMMA &&
|
||||
image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
|
||||
hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(qi),
|
||||
x.Cols() * x.ElementBytes());
|
||||
return image_token_position;
|
||||
}
|
||||
|
||||
const size_t model_dim = model_config.model_dim;
|
||||
const float emb_scaling = EmbeddingScaling(model_dim);
|
||||
const size_t worker = 0; // Not yet parallelized.
|
||||
|
||||
HWY_DASSERT(token >= 0);
|
||||
HWY_DASSERT(token < static_cast<int>(model_config.vocab_size));
|
||||
|
||||
CallUpcasted(&weights.embedder_input_embedding, [&](const auto* weights_t) {
|
||||
// Using `Stride` to compute the offset works for both NUQ (because we use
|
||||
// an offset and NUQ is never padded) and padded, because non-NUQ types are
|
||||
// seekable, hence the offset can also skip any padding.
|
||||
const size_t embedding_ofs = token * weights_t->Stride();
|
||||
HWY_ASSERT(weights_t->Cols() == model_dim);
|
||||
const auto embedding_span =
|
||||
MakeSpan(weights_t->Row(0), embedding_ofs + model_dim);
|
||||
const hn::ScalableTag<float> df;
|
||||
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
|
||||
model_dim);
|
||||
MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim, worker);
|
||||
});
|
||||
|
||||
if (model_config.absolute_pe) {
|
||||
AddAbsolutePositionalEmbeddings(x.Row(qi), model_dim, pos);
|
||||
}
|
||||
return image_token_position;
|
||||
}
|
||||
|
||||
// Populates KV cache for batches of tokens from one query at a time. This is
|
||||
// called if prompts are longer than the query batch size, and also in
|
||||
// prefix-LM mode (end > 0), which must see all tokens in one batch.
|
||||
static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const WeightsPtrs& weights,
|
||||
Activations& activations, QBatch& qbatch,
|
||||
MatMulEnv& env,
|
||||
hwy::BitSet4096<>& non_eos) {
|
||||
PROFILER_ZONE("Gen.PrefillT");
|
||||
|
||||
// Batches are important for amortizing loading weights over multiple tokens.
|
||||
// This is possible in prefill because we know all tokens beforehand, whereas
|
||||
// decode depends on the previous output token. However, each prefill batch of
|
||||
// a query requires that preceding batches already wrote to the KV cache,
|
||||
// hence we sequentially loop over token batches. We can reduce the number of
|
||||
// iterations by increasing the batch size, but this also increases arithmetic
|
||||
// intensity, and so we are eventually compute-limited. TransformerLayer uses
|
||||
// all available threads, so we do not also parallelize over queries, but note
|
||||
// that PrefillQBatch uses queries as the batch dimension.
|
||||
const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
|
||||
|
||||
// For each query. `qi` is within the batch, not the global query index.
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
non_eos.Set(qi);
|
||||
|
||||
// One query at a time, batching will be the query's prompt tokens.
|
||||
QBatch qbatch_1 = qbatch.Single(qi);
|
||||
|
||||
const size_t prompt_size = qbatch_1.Prompt(0).size();
|
||||
// In autoregressive mode, we don't need to prefill the last token, so - 1.
|
||||
size_t prefill_this_query = prompt_size - 1;
|
||||
const size_t prefix_end_this_query = qbatch_1.PrefixEnd(0);
|
||||
// We can't attend beyond the prompt_size.
|
||||
HWY_ASSERT(prefix_end_this_query <= prompt_size);
|
||||
// Special case: if the prefix includes the last token, we need to prefill
|
||||
// the last token, too. However, we need to rewind this for the generation
|
||||
// of the first token. So we need to keep track of this.
|
||||
// TODO: consider implementing masking instead of this logic?
|
||||
const bool attend_to_last_token =
|
||||
(prefill_this_query < prefix_end_this_query);
|
||||
if (attend_to_last_token) {
|
||||
// The difference can be at most 1.
|
||||
prefill_this_query += 1;
|
||||
HWY_ASSERT(prefill_this_query == prefix_end_this_query);
|
||||
}
|
||||
// In prefix-LM mode, we need to look at all the tokens for the prefix in
|
||||
// one iteration through the layers, so we need a large enough batch size.
|
||||
HWY_ASSERT(prefix_end_this_query == 0 ||
|
||||
max_tbatch_size >= prefill_this_query);
|
||||
|
||||
// For each batch of tokens in the query:
|
||||
for (size_t tbatch_start = 0; tbatch_start < prefill_this_query;
|
||||
tbatch_start += max_tbatch_size) {
|
||||
const size_t tbatch_size =
|
||||
HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start);
|
||||
activations.SetBatchSize(tbatch_size);
|
||||
|
||||
// Fill activations.x (much faster than TransformerLayer).
|
||||
size_t image_token_position = 0;
|
||||
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
||||
const size_t pos = qbatch_1.Pos(0) + ti;
|
||||
const size_t pos_in_prompt = tbatch_start + ti;
|
||||
const int token = qbatch_1.Prompt(0)[pos_in_prompt];
|
||||
image_token_position = EmbedMMToken(
|
||||
token, ti, pos, pos_in_prompt, config, weights, activations.x,
|
||||
runtime_config.image_tokens, image_token_position);
|
||||
}
|
||||
|
||||
// Transformer with one batch of tokens from a single query.
|
||||
for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
|
||||
++layer_idx) {
|
||||
TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx),
|
||||
activations, qbatch_1, env);
|
||||
}
|
||||
|
||||
// NOTE: we unconditionally call StreamToken, even if EOS.
|
||||
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
||||
const size_t pos = qbatch_1.Pos(0) + ti;
|
||||
const size_t pos_in_prompt = tbatch_start + ti;
|
||||
const int token = qbatch_1.Prompt(0)[pos_in_prompt];
|
||||
if (pos_in_prompt < prompt_size - 1) {
|
||||
runtime_config.StreamToken(qbatch_1.QueryIdx(0), pos, token, 0.0f);
|
||||
} else {
|
||||
// The last token will be streamed later and we should only get here
|
||||
// if we need to attend to the last token because it is in the prefix.
|
||||
HWY_ASSERT(attend_to_last_token);
|
||||
}
|
||||
}
|
||||
|
||||
qbatch_1.MutablePos(0) += tbatch_size;
|
||||
} // for tbatch_start
|
||||
if (attend_to_last_token) {
|
||||
// We need to rewind the position for the last token that we only
|
||||
// attended to to make sure the prefix LM sees everything.
|
||||
// This means we duplicate work on the last prompt token in autoregressive
|
||||
// decoding. Alternatives: (1) real masking; (2) always prefill the last
|
||||
// token and only generate the next one from the already prefilled
|
||||
// activations.
|
||||
qbatch_1.MutablePos(0) -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Embeds PrevToken (one from each query) and calls each TransformerLayer.
|
||||
// Called by query-batched `PrefillQBatch` and `DecodeStepT`, but not the
|
||||
// token-batched `PrefillTBatch`.
|
||||
static HWY_NOINLINE void Transformer(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const WeightsPtrs& weights,
|
||||
Activations& activations, QBatch& qbatch,
|
||||
MatMulEnv& env) {
|
||||
if (HWY_UNLIKELY(runtime_config.layers_output)) {
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
const float token_f = qbatch.PrevToken(qi);
|
||||
runtime_config.layers_output(qbatch.QueryIdx(qi), qbatch.Pos(qi),
|
||||
"tokens", -1, &token_f, 1);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: parallelize?
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
EmbedMMToken(qbatch.PrevToken(qi), qi, qbatch.Pos(qi),
|
||||
/*pos_in_prompt=*/0, config, weights, activations.x);
|
||||
}
|
||||
|
||||
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
|
||||
TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx),
|
||||
activations, qbatch, env);
|
||||
|
||||
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
||||
runtime_config.activations_observer(
|
||||
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
|
||||
activations);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Populates KV cache for the batch queries, one token at a time. Only called
|
||||
// for autoregressive (non-prefix-LM) prefill, so `queries_prefix_end` == 0.
|
||||
static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
|
||||
const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const WeightsPtrs& weights,
|
||||
Activations& activations, QBatch& qbatch,
|
||||
MatMulEnv& env,
|
||||
hwy::BitSet4096<>& non_eos) {
|
||||
PROFILER_ZONE("Gen.PrefillQ");
|
||||
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
non_eos.Set(qi);
|
||||
HWY_DASSERT(qbatch.PrefixEnd(qi) == 0);
|
||||
}
|
||||
|
||||
// In autoregressive mode, we don't prefill the last token, hence - 1.
|
||||
for (size_t pos_in_prompt = 0; pos_in_prompt < max_prompt_size - 1;
|
||||
++pos_in_prompt) {
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
int token = config.eos_id;
|
||||
if (pos_in_prompt < qbatch.Prompt(qi).size() - 1) {
|
||||
token = qbatch.Prompt(qi)[pos_in_prompt];
|
||||
// Ignore StreamToken return value because requesting to stop does not
|
||||
// make sense during prefill.
|
||||
(void)runtime_config.StreamToken(qbatch.QueryIdx(qi), pos_in_prompt,
|
||||
token, 0.0f);
|
||||
qbatch.MutablePos(qi) = pos_in_prompt;
|
||||
}
|
||||
|
||||
qbatch.PrevToken(qi) = token;
|
||||
}
|
||||
|
||||
// The input (PrevToken) is one token from each query in the batch.
|
||||
// Do not call DecodeStepT because it computes logits for token
|
||||
// probabilities, which are not required for the prompt tokens.
|
||||
Transformer(config, runtime_config, weights, activations, qbatch, env);
|
||||
}
|
||||
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
qbatch.MutablePos(qi) = qbatch.Prompt(qi).size() - 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
|
||||
// `DecodeStepT`, and increments `MutablePos`. Also updates `non_eos` if the
|
||||
// query is at the end of its sequence.
|
||||
static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
|
||||
const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
QBatch& qbatch, hwy::BitSet4096<>& non_eos) {
|
||||
HWY_DASSERT(non_eos.Get(qi)); // otherwise, should not be called.
|
||||
|
||||
if (HWY_UNLIKELY(!runtime_config.StreamToken(qbatch.QueryIdx(qi),
|
||||
qbatch.Pos(qi), token, prob))) {
|
||||
// User decided to stop: set token to primary EOS to trigger IsEOS below.
|
||||
token = config.eos_id;
|
||||
HWY_DASSERT(config.IsEOS(token));
|
||||
}
|
||||
|
||||
qbatch.PrevToken(qi) = token;
|
||||
qbatch.MutablePos(qi) += 1;
|
||||
|
||||
// Primary or secondary EOS: mark query as EOS, but still increment (for
|
||||
// multi-turn, we should still keep the prior EOS).
|
||||
if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
|
||||
}
|
||||
|
||||
// For a batch of queries, runs Transformer, computes logits, samples and
|
||||
// streams the token.
|
||||
static void DecodeStepT(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const WeightsPtrs& weights,
|
||||
const SampleFunc& sample_token,
|
||||
Activations& activations, QBatch& qbatch,
|
||||
MatMulEnv& env, hwy::BitSet4096<>& non_eos,
|
||||
TimingInfo& timing_info) {
|
||||
HWY_DASSERT(qbatch.Size() == activations.x.Rows());
|
||||
|
||||
Transformer(config, runtime_config, weights, activations, qbatch, env);
|
||||
|
||||
RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx);
|
||||
|
||||
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
||||
runtime_config.activations_observer(
|
||||
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
|
||||
}
|
||||
|
||||
{
|
||||
PROFILER_ZONE("Gen.EmbeddingMatmul");
|
||||
// Compute logits from last layer activations.
|
||||
CallMatMul(activations.x, weights.embedder_input_embedding,
|
||||
/*add=*/nullptr, env, activations.logits);
|
||||
}
|
||||
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
|
||||
const size_t worker = 0; // TODO: parallelize
|
||||
non_eos.Foreach([&](size_t qi) {
|
||||
float* HWY_RESTRICT logits = activations.logits.Row(qi);
|
||||
MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size, worker);
|
||||
const TokenAndProb tp = sample_token(logits, config.vocab_size);
|
||||
timing_info.NotifyGenerated();
|
||||
|
||||
StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch,
|
||||
non_eos);
|
||||
});
|
||||
}
|
||||
|
||||
static HWY_INLINE SampleFunc
|
||||
ChooseSampleFunc(const RuntimeConfig& runtime_config) {
|
||||
// If user provided a sample_func, use it.
|
||||
if (runtime_config.sample_func) return runtime_config.sample_func;
|
||||
|
||||
const size_t worker = 0; // TODO: parallelize
|
||||
|
||||
// Fast path for top-1 with no accept_token.
|
||||
if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
|
||||
return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
|
||||
PROFILER_ZONE2(worker, "Gen.Sample Top1");
|
||||
return Top1OfSoftmax(logits, vocab_size);
|
||||
};
|
||||
}
|
||||
|
||||
// General case: Softmax with top-k sampling.
|
||||
return [&runtime_config](float* logits,
|
||||
size_t vocab_size) HWY_ATTR -> TokenAndProb {
|
||||
PROFILER_ZONE("Gen.Sample general");
|
||||
return FusedSoftmaxAndSampleTopK(
|
||||
logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
|
||||
runtime_config.temperature, runtime_config.accept_token, worker);
|
||||
};
|
||||
}
|
||||
|
||||
// Decode: generates one continuation token for each query in `qbatch`.
|
||||
static void GenerateT(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const WeightsPtrs& weights, Activations& activations,
|
||||
QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) {
|
||||
// Griffin assumes that the recurrent block cache is zero-initialized.
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
if (qbatch.MutablePos(qi) == 0) {
|
||||
qbatch.KV(qi).ZeroGriffinCache(); // No-op for non-Griffin models.
|
||||
}
|
||||
}
|
||||
|
||||
size_t max_prompt_size = 0;
|
||||
bool all_prefix_end_are_zero = true;
|
||||
size_t total_prefill_tokens = 0; // only for throughput stats.
|
||||
const size_t seq_len = qbatch.KV(0).SeqLen();
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
const PromptTokens& prompt = qbatch.Prompt(qi);
|
||||
max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
|
||||
|
||||
// Prefill stops before size - 1 because the last prompt token is the
|
||||
// first input token for generation.
|
||||
total_prefill_tokens += prompt.size() - 1;
|
||||
|
||||
// Sanity check: prompts should not be empty, nor start with EOS.
|
||||
HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
|
||||
|
||||
all_prefix_end_are_zero &= qbatch.PrefixEnd(qi) == 0;
|
||||
|
||||
// We use a single divisor, so all sequence lengths must be the same.
|
||||
HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
|
||||
}
|
||||
if (max_prompt_size >= seq_len) {
|
||||
HWY_ABORT("max_prompt_size = %zu, increase --seq_len to at least that.",
|
||||
max_prompt_size);
|
||||
}
|
||||
HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);
|
||||
|
||||
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have
|
||||
// qi loops anyway.
|
||||
hwy::BitSet4096<> non_eos; // indexed by qi
|
||||
|
||||
timing_info.prefill_start = hwy::platform::Now();
|
||||
// Batch over the larger of prompt length, or queries.
|
||||
if ((qbatch.Size() > max_prompt_size) && all_prefix_end_are_zero) {
|
||||
activations.SetBatchSize(qbatch.Size()); // required before PrefillQBatch
|
||||
PrefillQBatch(max_prompt_size, config, runtime_config, weights, activations,
|
||||
qbatch, env, non_eos);
|
||||
} else {
|
||||
PrefillTBatch(config, runtime_config, weights, activations, qbatch, env,
|
||||
non_eos);
|
||||
activations.SetBatchSize(qbatch.Size()); // Restore after PrefillTBatch.
|
||||
}
|
||||
HWY_DASSERT(non_eos.Count() == qbatch.Size());
|
||||
timing_info.NotifyPrefill(total_prefill_tokens);
|
||||
// queries_pos have been incremented by Prefill.
|
||||
|
||||
// Stream the last prompt token from each query, fill activations.gen_tokens.
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
|
||||
StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config,
|
||||
runtime_config, qbatch, non_eos);
|
||||
}
|
||||
|
||||
size_t max_gen_steps = runtime_config.max_generated_tokens;
|
||||
if (max_prompt_size + max_gen_steps > seq_len) {
|
||||
HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.",
|
||||
max_prompt_size, max_gen_steps, seq_len);
|
||||
max_gen_steps = seq_len - max_prompt_size;
|
||||
}
|
||||
|
||||
const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
|
||||
|
||||
{
|
||||
timing_info.generate_start = hwy::platform::Now();
|
||||
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
|
||||
DecodeStepT(config, runtime_config, weights, sample_token, activations,
|
||||
qbatch, env, non_eos, timing_info);
|
||||
}
|
||||
timing_info.NotifyGenerateDone();
|
||||
}
|
||||
}
|
||||
|
||||
void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const WeightsPtrs& weights, KVCache& kv_cache,
|
||||
MatMulEnv& env, TimingInfo& timing_info) {
|
||||
Activations activations(config, runtime_config.prefill_tbatch_size,
|
||||
kv_cache.SeqLen(), env.ctx.allocator, env.row_ptrs);
|
||||
|
||||
AllQueries all_queries(prompt, pos, prefix_end,
|
||||
hwy::Span<KVCache>(&kv_cache, 1));
|
||||
QBatch qbatch(/*start=*/0, /*max_size=*/1, all_queries);
|
||||
GenerateT(config, runtime_config, weights, activations, qbatch, env,
|
||||
timing_info);
|
||||
}
|
||||
|
||||
// Splits the input into batches of at most `runtime_config.decode_qbatch_size`
|
||||
// queries, and calls `GenerateT` on each batch.
|
||||
void GenerateBatchT(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const WeightsPtrs& weights, AllQueries& all_queries,
|
||||
MatMulEnv& env, TimingInfo& timing_info) {
|
||||
const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
|
||||
runtime_config.prefill_tbatch_size);
|
||||
Activations activations(config, max_batch_size,
|
||||
all_queries[0].kv_cache.SeqLen(), env.ctx.allocator,
|
||||
env.row_ptrs);
|
||||
|
||||
for (size_t start = 0; start < all_queries.NumQueries();
|
||||
start += runtime_config.decode_qbatch_size) {
|
||||
QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
|
||||
// Generate a batch of one token for each of `qbatch.Size()` queries.
|
||||
GenerateT(config, runtime_config, weights, activations, qbatch, env,
|
||||
timing_info);
|
||||
}
|
||||
}
|
||||
|
||||
void GenerateImageTokensT(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config, size_t seq_len,
|
||||
const WeightsPtrs& weights, const Image& image,
|
||||
ImageTokens& image_tokens, MatMulEnv& env) {
|
||||
if (config.vit_config.layer_configs.empty()) {
|
||||
HWY_ABORT("Model does not support generating image tokens.");
|
||||
}
|
||||
RuntimeConfig prefill_runtime_config = runtime_config;
|
||||
const ModelConfig vit_config = GetVitConfig(config);
|
||||
const size_t num_tokens = vit_config.max_seq_len;
|
||||
prefill_runtime_config.prefill_tbatch_size =
|
||||
num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
|
||||
Activations prefill_activations(vit_config, num_tokens, num_tokens,
|
||||
env.ctx.allocator, env.row_ptrs);
|
||||
// Weights are for the full PaliGemma model, not just the ViT part.
|
||||
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
|
||||
prefill_activations, env);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#if HWY_ONCE
|
||||
namespace gcpp {
|
||||
HWY_EXPORT(GenerateSingleT);
|
||||
HWY_EXPORT(GenerateBatchT);
|
||||
HWY_EXPORT(GenerateImageTokensT);
|
||||
|
||||
Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||
ThreadingContext& ctx)
|
||||
: reader_(loader.weights),
|
||||
model_(reader_, loader.tokenizer, loader.wrapping),
|
||||
weights_(model_.Config()),
|
||||
chat_template_(model_.Tokenizer(), model_.Config().model),
|
||||
inference_(inference) {
|
||||
weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference,
|
||||
mat_owners_, ctx);
|
||||
// Read everything into memory, or `weights_.mapped_` keeps the mapping alive.
|
||||
reader_.CloseFile();
|
||||
}
|
||||
|
||||
Gemma::~Gemma() = default;
|
||||
|
||||
void Gemma::Save(const Path& weights_path, NestedPools& pools) const {
|
||||
BlobWriter writer(weights_path, pools.Pool());
|
||||
const std::vector<uint32_t> serialized_mat_ptrs =
|
||||
weights_.AddTensorDataToWriter(writer);
|
||||
WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs,
|
||||
writer);
|
||||
}
|
||||
|
||||
void Gemma::Generate(const RuntimeConfig& runtime_config,
|
||||
const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
KVCache& kv_cache, TimingInfo& timing_info) {
|
||||
env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning);
|
||||
KVCache& kv_cache, MatMulEnv& env,
|
||||
TimingInfo& timing_info) const {
|
||||
env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
||||
|
||||
model_.CallForModelWeight<GenerateSingleT>(
|
||||
runtime_config, prompt, pos, prefix_end, kv_cache, &env_, timing_info);
|
||||
HWY_DYNAMIC_DISPATCH(GenerateSingleT)(prompt, pos, prefix_end,
|
||||
model_.Config(), runtime_config,
|
||||
weights_, kv_cache, env, timing_info);
|
||||
|
||||
env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning);
|
||||
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
||||
void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end,
|
||||
const KVCaches& kv_caches, TimingInfo& timing_info) {
|
||||
// If we did not get passed prefix ends (size 0), assume 0 and pass that on.
|
||||
QueriesPos mutable_queries_prefix_end = queries_prefix_end;
|
||||
std::vector<size_t> prefix_end_vec;
|
||||
if (queries_prefix_end.size() == 0) {
|
||||
prefix_end_vec.resize(queries_prompt.size(), 0);
|
||||
mutable_queries_prefix_end =
|
||||
QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
|
||||
}
|
||||
AllQueries& all_queries, MatMulEnv& env,
|
||||
TimingInfo& timing_info) const {
|
||||
env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
||||
|
||||
env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning);
|
||||
HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config,
|
||||
weights_, all_queries, env, timing_info);
|
||||
|
||||
model_.CallForModelWeight<GenerateBatchT>(
|
||||
runtime_config, queries_prompt, queries_pos, mutable_queries_prefix_end,
|
||||
kv_caches, &env_, timing_info);
|
||||
|
||||
env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning);
|
||||
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
||||
void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
|
||||
const Image& image, ImageTokens& image_tokens) {
|
||||
env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning);
|
||||
size_t seq_len, const Image& image,
|
||||
ImageTokens& image_tokens,
|
||||
MatMulEnv& env) const {
|
||||
env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
||||
|
||||
model_.CallForModelWeight<GenerateImageTokensT>(runtime_config, image,
|
||||
image_tokens, &env_);
|
||||
HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(model_.Config(), runtime_config,
|
||||
seq_len, weights_, image,
|
||||
image_tokens, env);
|
||||
|
||||
env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
||||
// Non-template functions moved from gemma-inl.h to avoid ODR violations.
|
||||
|
||||
void RangeChecks(const ModelConfig& weights_config,
|
||||
size_t& max_generated_tokens, const size_t prompt_size) {
|
||||
if (!weights_config.use_local_attention) {
|
||||
if (max_generated_tokens > weights_config.seq_len) {
|
||||
fprintf(stderr,
|
||||
"WARNING: max_generated_tokens %zu > kSeqLen %u, truncating.\n",
|
||||
max_generated_tokens, weights_config.seq_len);
|
||||
max_generated_tokens = weights_config.seq_len;
|
||||
}
|
||||
}
|
||||
HWY_ASSERT(prompt_size > 0);
|
||||
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // HWY_ONCE
|
||||
|
|
|
|||
294
gemma/gemma.h
294
gemma/gemma.h
|
|
@ -16,122 +16,154 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
|
||||
|
||||
#include <functional>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "gemma/kv_cache.h"
|
||||
#include "gemma/tokenizer.h"
|
||||
#include "gemma/model_store.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "io/blob_store.h"
|
||||
#include "io/io.h" // Path
|
||||
#include "ops/matmul.h" // MatMulEnv
|
||||
#include "paligemma/image.h"
|
||||
#include "util/allocator.h" // RowVectorBatch
|
||||
#include "util/basics.h" // TokenAndProb
|
||||
#include "util/basics.h" // TokenAndProb
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/timer.h"
|
||||
// IWYU pragma: end_exports
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
|
||||
namespace gcpp {
|
||||
using PromptTokens = hwy::Span<const int>;
|
||||
|
||||
// Batches of independent queries have their own prompt, previous token,
|
||||
// position in the sequence, and KVCache.
|
||||
using QueriesPromptTokens = hwy::Span<const PromptTokens>;
|
||||
using QueriesToken = hwy::Span<const int>;
|
||||
using QueriesPos = hwy::Span<const size_t>;
|
||||
using KVCaches = hwy::Span<KVCache>;
|
||||
struct PerQuery {
|
||||
PromptTokens prompt;
|
||||
|
||||
// StreamFunc is called with (token, probability). For prompt tokens,
|
||||
// probability is 0.0f. StreamFunc should return false to stop generation and
|
||||
// true to continue generation.
|
||||
using StreamFunc = std::function<bool(int, float)>;
|
||||
// BatchStreamFunc is called with (query_idx, pos, token, probability).
|
||||
// For prompt tokens, probability is 0.0f.
|
||||
// StreamFunc should return false to stop generation and true to continue.
|
||||
using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
|
||||
// If not empty, AcceptFunc is called with token. It should return false for
|
||||
// tokens you don't want to generate and true for tokens you want to generate.
|
||||
using AcceptFunc = std::function<bool(int, float)>;
|
||||
// If not empty, SampleFunc is called with the logits for the next token, which
|
||||
// it may modify/overwrite, and its return value is the next generated token
|
||||
// together with its probability.
|
||||
using SampleFunc = std::function<TokenAndProb(float*, size_t)>;
|
||||
// If not empty, LayersOutputFunc is called for layer outputs, specified with:
|
||||
// - index of query within containing batch (if any); zero otherwise.
|
||||
// - position in the tokens sequence
|
||||
// - name of the data, e.g. "tokens" for token IDs
|
||||
// - layer index (or -1 for global outputs)
|
||||
// - pointer to the data array
|
||||
// - size of the data array
|
||||
using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
|
||||
int, const float*, size_t)>;
|
||||
// If not empty, ActivationsObserverFunc is invoked after each layer with:
|
||||
// - per-query position within the tokens sequence
|
||||
// - layer index (or -1 for post-norm output)
|
||||
// - activations
|
||||
using ActivationsObserverFunc =
|
||||
std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;
|
||||
// Position in the KV cache: initially zero for the first turn, or when
|
||||
// multi-turn is NOT desired. Incremented by prefill and `StreamAndUpdateEOS`.
|
||||
size_t mutable_pos;
|
||||
// Allows computing the last prefill token as `mutable_pos - initial_pos`,
|
||||
// which might differ from `prompt.size() - 1` for prefix-LM.
|
||||
size_t initial_pos;
|
||||
// Zero for causal attention, or the end of the prefix for prefix-LM style
|
||||
// attention in Paligemma.
|
||||
size_t prefix_end;
|
||||
|
||||
// ImageTokens are represented as a RowVectorBatch, where each "batch" index
|
||||
// corresponds to a token for an image patch as computed by the image encoder.
|
||||
using ImageTokens = RowVectorBatch<float>;
|
||||
KVCache& kv_cache;
|
||||
|
||||
// RuntimeConfig holds configuration for a single generation run.
|
||||
struct RuntimeConfig {
|
||||
// If not empty, batch_stream_token is called for each token in the batch,
|
||||
// instead of stream_token.
|
||||
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
|
||||
if (batch_stream_token) {
|
||||
return batch_stream_token(query_idx, pos, token, prob);
|
||||
// Previous token generated for this query, or the last prompt token. Will be
|
||||
// fed into the next Transformer() call.
|
||||
int prev_token = 0;
|
||||
};
|
||||
|
||||
// Array of `PerQuery`. Referenced by `QBatch` and passed to `GenerateBatch`.
|
||||
struct AllQueries {
|
||||
AllQueries() = default;
|
||||
|
||||
// For `GenerateSingleT`: same prompt/pos, replicated for each KV cache.
|
||||
AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
const hwy::Span<KVCache>& kv_caches) {
|
||||
per_query_.reserve(kv_caches.size());
|
||||
for (size_t i = 0; i < kv_caches.size(); ++i) {
|
||||
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
|
||||
per_query_.push_back(PerQuery{
|
||||
.prompt = prompt,
|
||||
.mutable_pos = pos,
|
||||
.initial_pos = pos,
|
||||
.prefix_end = prefix_end,
|
||||
.kv_cache = kv_caches[i],
|
||||
});
|
||||
}
|
||||
return stream_token(token, prob);
|
||||
}
|
||||
|
||||
// Limit on the number of tokens generated.
|
||||
size_t max_generated_tokens;
|
||||
// Batch of queries with initial position set to zero. Causal attention
|
||||
// is requested via empty or all-zero `prefix_end`.
|
||||
AllQueries(
|
||||
const hwy::Span<const PromptTokens>& prompts,
|
||||
const hwy::Span<KVCache>& kv_caches,
|
||||
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
|
||||
HWY_ASSERT(prompts.size() == kv_caches.size());
|
||||
HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
|
||||
per_query_.reserve(kv_caches.size());
|
||||
for (size_t i = 0; i < kv_caches.size(); ++i) {
|
||||
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
|
||||
per_query_.push_back(PerQuery{
|
||||
.prompt = prompts[i],
|
||||
.mutable_pos = 0,
|
||||
.initial_pos = 0,
|
||||
.prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i],
|
||||
.kv_cache = kv_caches[i],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// These defaults are overridden by InferenceArgs::CopyTo(*this):
|
||||
// Max tokens per batch during prefill.
|
||||
size_t prefill_tbatch_size = 256;
|
||||
// Max queries per batch (one token from each) during decode.
|
||||
size_t decode_qbatch_size = 16;
|
||||
void Reserve(size_t size) { per_query_.reserve(size); }
|
||||
void Append(const PerQuery& query) { per_query_.push_back(query); }
|
||||
|
||||
// Sampling-related parameters.
|
||||
float temperature; // Temperature for sampling.
|
||||
size_t top_k = kTopK; // Top-k for sampling.
|
||||
std::mt19937* gen; // Random number generator used for sampling.
|
||||
size_t NumQueries() const { return per_query_.size(); }
|
||||
|
||||
int verbosity; // Controls verbosity of printed messages.
|
||||
PerQuery& operator[](size_t query_idx) {
|
||||
HWY_DASSERT(query_idx < NumQueries());
|
||||
return per_query_[query_idx];
|
||||
}
|
||||
const PerQuery& operator[](size_t query_idx) const {
|
||||
HWY_DASSERT(query_idx < NumQueries());
|
||||
return per_query_[query_idx];
|
||||
}
|
||||
|
||||
// Functions operating on the generated tokens.
|
||||
StreamFunc stream_token;
|
||||
BatchStreamFunc batch_stream_token;
|
||||
AcceptFunc accept_token; // if empty, accepts all tokens.
|
||||
SampleFunc sample_func; // if empty, uses SampleTopK.
|
||||
private:
|
||||
std::vector<PerQuery> per_query_;
|
||||
};
|
||||
|
||||
// Observer callbacks for intermediate data.
|
||||
LayersOutputFunc layers_output; // if not empty, called after each layer.
|
||||
ActivationsObserverFunc activations_observer; // if set, called per-layer.
|
||||
// View into AllQueries: either a batch of queries, or a single query for use
|
||||
// in PrefillTBatch or GenerateSingleT. Cheap to create because it holds a
|
||||
// reference to AllQueries.
|
||||
class QBatch {
|
||||
public:
|
||||
QBatch(size_t start, size_t max_size, AllQueries& queries)
|
||||
: start_(start),
|
||||
max_size_(max_size),
|
||||
queries_(queries),
|
||||
size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) {
|
||||
HWY_ASSERT(max_size_ <= 4096); // non_eos uses `BitSet4096`.
|
||||
HWY_DASSERT(size_ != 0);
|
||||
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
|
||||
}
|
||||
|
||||
// If not empty, these point to the image tokens and are used in the
|
||||
// PaliGemma prefix-LM style attention.
|
||||
const ImageTokens *image_tokens = nullptr;
|
||||
// Returns a single-query view starting at `qi` relative to this batch.
|
||||
QBatch Single(size_t qi) const { return QBatch(start_ + qi, 1, queries_); }
|
||||
|
||||
// Whether to use thread spinning to reduce barrier synchronization latency.
|
||||
// Mutable so we can change kDefault to kTrue/kFalse during Generate, because
|
||||
// RuntimeConfig is const there and is not passed to the Gemma ctor. This
|
||||
// default decision is likely sufficient because it is based on whether
|
||||
// threads are successfully pinned.
|
||||
mutable Tristate use_spinning = Tristate::kDefault;
|
||||
// How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`.
|
||||
size_t Size() const { return size_; }
|
||||
|
||||
// End-of-sequence token.
|
||||
int eos_id = EOS_ID;
|
||||
// Returns index for use with `AllQueries` and `BatchStreamToken`.
|
||||
size_t QueryIdx(size_t qi) const {
|
||||
HWY_DASSERT(qi < size_);
|
||||
return start_ + qi;
|
||||
}
|
||||
|
||||
// Accessor functions to bridge the previous SoA and current AoS layout.
|
||||
const PromptTokens& Prompt(size_t qi) const {
|
||||
return queries_[QueryIdx(qi)].prompt;
|
||||
}
|
||||
size_t Pos(size_t qi) const { return queries_[QueryIdx(qi)].mutable_pos; }
|
||||
size_t& MutablePos(size_t qi) { return queries_[QueryIdx(qi)].mutable_pos; }
|
||||
size_t InitialPos(size_t qi) const {
|
||||
return queries_[QueryIdx(qi)].initial_pos;
|
||||
}
|
||||
size_t PrefixEnd(size_t qi) const {
|
||||
return queries_[QueryIdx(qi)].prefix_end;
|
||||
}
|
||||
KVCache& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
|
||||
int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
|
||||
|
||||
private:
|
||||
size_t start_;
|
||||
size_t max_size_;
|
||||
AllQueries& queries_;
|
||||
size_t size_;
|
||||
};
|
||||
|
||||
struct TimingInfo {
|
||||
|
|
@ -193,82 +225,58 @@ struct TimingInfo {
|
|||
size_t tokens_generated = 0;
|
||||
};
|
||||
|
||||
// After construction, all methods are const and thread-compatible if using
|
||||
// separate ThreadingContext for each thread.
|
||||
class Gemma {
|
||||
public:
|
||||
// Reads old format weights file and tokenizer file.
|
||||
// `env` must remain valid for the lifetime of this Gemma.
|
||||
Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
|
||||
MatMulEnv& env);
|
||||
// Reads new format weights file that contains everything in a single file.
|
||||
// `env` must remain valid for the lifetime of this Gemma.
|
||||
Gemma(const Path& weights, MatMulEnv& env);
|
||||
// Allocates weights, caller is responsible for filling them.
|
||||
Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env);
|
||||
// Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
|
||||
// `ctx` is only used to read tensors, but it is typically also referenced
|
||||
// by the `MatMulEnv` passed to the Generate* methods.
|
||||
Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||
ThreadingContext& ctx);
|
||||
~Gemma();
|
||||
|
||||
const ModelConfig& GetModelConfig() const { return model_.Config(); }
|
||||
ModelInfo Info() const {
|
||||
return ModelInfo({.model = model_.Config().model,
|
||||
.wrapping = model_.Config().wrapping,
|
||||
.weight = model_.Config().weight});
|
||||
}
|
||||
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
|
||||
const ModelWeightsStorage& Weights() const { return model_; }
|
||||
ModelWeightsStorage& MutableWeights() { return model_; }
|
||||
void Save(const Path& weights, hwy::ThreadPool& pool) {
|
||||
std::string tokenizer_proto = tokenizer_.Serialize();
|
||||
model_.Save(tokenizer_proto, weights, pool);
|
||||
}
|
||||
const ModelConfig& Config() const { return model_.Config(); }
|
||||
const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
|
||||
const WeightsPtrs& Weights() const { return weights_; }
|
||||
WeightsPtrs::Mode WeightReadMode() const { return weight_read_mode_; }
|
||||
const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
|
||||
const InferenceArgs& Inference() const { return inference_; }
|
||||
|
||||
void Save(const Path& weights_path, NestedPools& pools) const;
|
||||
|
||||
// `pos` is the position in the KV cache. Users are responsible for
|
||||
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
|
||||
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
|
||||
size_t pos, KVCache& kv_cache, TimingInfo& timing_info) {
|
||||
Generate(runtime_config, prompt, pos, /*prefix_end=*/0, kv_cache,
|
||||
size_t pos, KVCache& kv_cache, MatMulEnv& env,
|
||||
TimingInfo& timing_info) const {
|
||||
Generate(runtime_config, prompt, pos, /*prefix_end=*/0, kv_cache, env,
|
||||
timing_info);
|
||||
}
|
||||
// For prefix-LM style attention, we can pass the end of the prefix.
|
||||
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
|
||||
size_t pos, size_t prefix_end, KVCache& kv_cache,
|
||||
TimingInfo& timing_info);
|
||||
MatMulEnv& env, TimingInfo& timing_info) const;
|
||||
|
||||
// `queries_pos` are the positions in the KV cache. Users are responsible for
|
||||
// incrementing them in `BatchStreamFunc`, or setting to zero for single-turn.
|
||||
void GenerateBatch(const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos, const KVCaches& kv_caches,
|
||||
TimingInfo& timing_info) {
|
||||
GenerateBatch(runtime_config, queries_prompt, queries_pos,
|
||||
/*queries_prefix_end=*/{}, kv_caches, timing_info);
|
||||
}
|
||||
// For prefix-LM style attention, we can pass the ends of the prefixes.
|
||||
void GenerateBatch(const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end,
|
||||
const KVCaches& kv_caches, TimingInfo& timing_info);
|
||||
AllQueries& all_queries, MatMulEnv& env,
|
||||
TimingInfo& timing_info) const;
|
||||
|
||||
// Generates the image tokens by running the image encoder ViT.
|
||||
void GenerateImageTokens(const RuntimeConfig& runtime_config,
|
||||
const Image& image, ImageTokens& image_tokens);
|
||||
void GenerateImageTokens(const RuntimeConfig& runtime_config, size_t seq_len,
|
||||
const Image& image, ImageTokens& image_tokens,
|
||||
MatMulEnv& env) const;
|
||||
|
||||
private:
|
||||
MatMulEnv& env_;
|
||||
|
||||
GemmaTokenizer tokenizer_;
|
||||
// Type-erased so that this can be defined in the header.
|
||||
ModelWeightsStorage model_;
|
||||
BlobReader reader_;
|
||||
ModelStore model_;
|
||||
std::vector<MatOwner> mat_owners_;
|
||||
WeightsPtrs weights_;
|
||||
WeightsPtrs::Mode weight_read_mode_;
|
||||
GemmaChatTemplate chat_template_;
|
||||
InferenceArgs inference_;
|
||||
};
|
||||
|
||||
// Adds BOS token and possibly 'turn' annotations, which depend on `info`
|
||||
// and `pos`, the number of tokens decoded so far; returns the corresponding
|
||||
// tokens. Asserts that tokenization is successful.
|
||||
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
|
||||
const ModelInfo& info, size_t pos,
|
||||
std::string& prompt);
|
||||
void RangeChecks(const ModelConfig& weights_config,
|
||||
size_t& max_generated_tokens, size_t prompt_size);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
|
||||
|
|
|
|||
|
|
@ -0,0 +1,273 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Shared between various frontends.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <functional>
|
||||
#include <random>
|
||||
#include <string>
|
||||
|
||||
#include "io/io.h" // Path
|
||||
#include "ops/matmul.h" // MMStorage::kMax*
|
||||
#include "util/args.h"
|
||||
#include "util/basics.h" // Tristate
|
||||
#include "util/mat.h"
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
#include "hwy/base.h" // HWY_ABORT
|
||||
#include "hwy/profiler.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
LoaderArgs(const std::string& tokenizer_path,
|
||||
const std::string& weights_path) {
|
||||
Init(); // Init sets to defaults, so assignments must come after Init().
|
||||
tokenizer.path = tokenizer_path;
|
||||
weights.path = weights_path;
|
||||
};
|
||||
|
||||
Path tokenizer;
|
||||
Path weights; // weights file location
|
||||
Tristate map;
|
||||
Tristate to_bf16;
|
||||
Tristate wrapping;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(tokenizer, "tokenizer", Path(),
|
||||
"Path name of tokenizer model; only required for pre-2025 format.");
|
||||
visitor(weights, "weights", Path(),
|
||||
"Path name of model weights (.sbs) file.\n Required argument.\n");
|
||||
visitor(map, "map", Tristate::kDefault,
|
||||
"Enable memory-mapping? -1 = auto, 0 = no, 1 = yes.");
|
||||
visitor(to_bf16, "to_bf16", Tristate::kDefault,
|
||||
"Convert weights to bf16? -1 = auto, 0 = no, 1 = yes.");
|
||||
visitor(wrapping, "wrapping", Tristate::kDefault,
|
||||
"Enable prompt wrapping? Specify 0 for pre-2025 format PT models.");
|
||||
}
|
||||
};
|
||||
|
||||
using PromptTokens = hwy::Span<const int>;
|
||||
|
||||
// Batches of independent queries have their own prompt, previous token,
|
||||
// position in the sequence, and KVCache.
|
||||
using QueriesPromptTokens = hwy::Span<const PromptTokens>;
|
||||
using QueriesToken = hwy::Span<const int>;
|
||||
using QueriesPos = hwy::Span<const size_t>;
|
||||
|
||||
// ImageTokens are represented as a matrix, where each row corresponds
|
||||
// to a token for an image patch as computed by the image encoder.
|
||||
using ImageTokens = MatStorageT<float>;
|
||||
|
||||
// StreamFunc is called with (token, probability). For prompt tokens,
|
||||
// probability is 0.0f. StreamFunc should return false to stop generation and
|
||||
// true to continue generation.
|
||||
using StreamFunc = std::function<bool(int, float)>;
|
||||
// BatchStreamFunc is called with (query_idx, pos, token, probability).
|
||||
// For prompt tokens, probability is 0.0f. Generation continues if this returns
|
||||
// true and stops if it returns false. Note that query_idx is absolute, not
|
||||
// relative to the batch.
|
||||
using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
|
||||
// If not empty, AcceptFunc is called with token. It should return false for
|
||||
// tokens you don't want to generate and true for tokens you want to generate.
|
||||
using AcceptFunc = std::function<bool(int, float)>;
|
||||
// If not empty, SampleFunc is called with the logits for the next token, which
|
||||
// it may modify/overwrite, and its return value is the next generated token
|
||||
// together with its probability.
|
||||
using SampleFunc = std::function<TokenAndProb(float*, size_t)>;
|
||||
// If not empty, LayersOutputFunc is called for layer outputs, specified with:
|
||||
// - index of query within containing batch (if any); zero otherwise.
|
||||
// - position in the tokens sequence
|
||||
// - name of the data, e.g. "tokens" for token IDs
|
||||
// - layer index (or -1 for global outputs)
|
||||
// - pointer to the data array
|
||||
// - size of the data array
|
||||
using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
|
||||
int, const float*, size_t)>;
|
||||
// If not empty, ActivationsObserverFunc is invoked after each layer with:
|
||||
// - per-query position within the tokens sequence
|
||||
// - layer index (or -1 for post-norm output)
|
||||
// - activations
|
||||
struct Activations;
|
||||
using ActivationsObserverFunc =
|
||||
std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;
|
||||
|
||||
// RuntimeConfig holds configuration for a single generation run.
|
||||
// TODO: move into InferenceArgs, use that directly.
|
||||
struct RuntimeConfig {
|
||||
// If non-null, `batch_stream_token` is called for each token in the batch,
|
||||
// otherwise `stream_token`. `query_idx` is absolute, not batch-relative.
|
||||
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
|
||||
PROFILER_ZONE("Gen.StreamToken");
|
||||
if (batch_stream_token) {
|
||||
return batch_stream_token(query_idx, pos, token, prob);
|
||||
}
|
||||
return stream_token(token, prob);
|
||||
}
|
||||
|
||||
// Limit on the number of tokens generated.
|
||||
size_t max_generated_tokens;
|
||||
|
||||
// These defaults are overridden by InferenceArgs::CopyTo(*this):
|
||||
// Max tokens per batch during prefill.
|
||||
size_t prefill_tbatch_size = 256;
|
||||
// Max queries per batch (one token from each) during decode.
|
||||
size_t decode_qbatch_size = 16;
|
||||
|
||||
// Sampling-related parameters.
|
||||
float temperature; // Temperature for sampling.
|
||||
|
||||
size_t top_k = 1; // Top-k for sampling.
|
||||
std::mt19937* gen; // Random number generator used for sampling.
|
||||
|
||||
int verbosity; // Controls verbosity of printed messages.
|
||||
|
||||
// Functions operating on the generated tokens.
|
||||
StreamFunc stream_token;
|
||||
BatchStreamFunc batch_stream_token;
|
||||
AcceptFunc accept_token; // if empty, accepts all tokens.
|
||||
SampleFunc sample_func; // if empty, uses SampleTopK.
|
||||
|
||||
// Observer callbacks for intermediate data.
|
||||
LayersOutputFunc layers_output; // if not empty, called after each layer.
|
||||
ActivationsObserverFunc activations_observer; // if set, called per-layer.
|
||||
|
||||
// If not empty, these point to the image tokens and are used in the
|
||||
// PaliGemma prefix-LM style attention.
|
||||
const ImageTokens* image_tokens = nullptr;
|
||||
|
||||
// Whether to use thread spinning to reduce barrier synchronization latency.
|
||||
// Mutable so we can change kDefault to kTrue/kFalse during Generate, because
|
||||
// RuntimeConfig is const there and is not passed to the Gemma ctor. This
|
||||
// default decision is likely sufficient because it is based on whether
|
||||
// threads are successfully pinned.
|
||||
mutable Tristate use_spinning = Tristate::kDefault;
|
||||
};
|
||||
|
||||
// Command-line arguments controlling inference. Parsed/initialized through
// ArgsBase, which invokes ForEach below for both defaulting and parsing.
struct InferenceArgs : public ArgsBase<InferenceArgs> {
  InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
  InferenceArgs() { Init(); };

  // Interactive mode is implied when neither --prompt nor --prompt_file
  // was given.
  bool IsInteractive() const { return prompt.empty() && prompt_file.Empty(); }

  int verbosity;  // 0 = generation only, 1 = standard UI, 2 = debug.

  size_t seq_len;               // Capped by ModelConfig.max_seq_len.
  size_t max_generated_tokens;  // Upper bound on generated tokens.

  size_t prefill_tbatch_size;  // Prefill: max tokens per batch.
  size_t decode_qbatch_size;   // Decode: max queries per batch.

  float temperature;   // Sampling temperature for top-K.
  size_t top_k;        // Number of top-K tokens to sample from.
  bool deterministic;  // Make top-k sampling deterministic.
  bool multiturn;      // Keep KV cache across interactions if true.
  Path image_file;     // Optional image input.

  std::string prompt;  // Bypasses std::getline
  // For prompts longer than the Linux terminal's 4K line edit buffer.
  Path prompt_file;
  std::string eot_line;  // Custom end-of-turn marker line; empty = newline.

  // Declares each argument (field, flag name, default, help text, optional
  // verbosity level) to `visitor`. Called by ArgsBase for init and parsing.
  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(verbosity, "verbosity", 1,
            "Show verbose developer information\n 0 = only print generation "
            "output\n 1 = standard user-facing terminal ui\n 2 = show "
            "developer/debug info).\n Default = 1.",
            1);

    visitor(seq_len, "seq_len", size_t{8192},
            "Sequence length, capped by ModelConfig.max_seq_len.");
    visitor(max_generated_tokens, "max_generated_tokens", size_t{4096},
            "Maximum number of tokens to generate.");

    visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},
            "Prefill: max tokens per batch.");
    visitor(decode_qbatch_size, "decode_qbatch", size_t{16},
            "Decode: max queries per batch.");

    visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
    visitor(top_k, "top_k", size_t{1}, "Number of top-K tokens to sample from",
            2);
    visitor(deterministic, "deterministic", false,
            "Make top-k sampling deterministic", 2);
    visitor(multiturn, "multiturn", false,
            "Multiturn mode\n 0 = clear KV cache after every "
            "interaction\n 1 = continue KV cache after every interaction\n "
            " Default : 0 (conversation "
            "resets every turn)");
    visitor(image_file, "image_file", Path(), "Image file to load.");

    visitor(prompt, "prompt", std::string(""),
            "Initial prompt for non-interactive mode. When specified, "
            "generates a response and exits.",
            1);
    visitor(prompt_file, "prompt_file", Path(),
            "Path to file containing the prompt for non-interactive mode. When "
            " specified, generates a response and exits.",
            1);

    visitor(
        eot_line, "eot_line", std::string(""),
        "End of turn line. "
        "When you specify this, the prompt will be all lines "
        "before the line where only the given string appears.\n Default = "
        "When a newline is encountered, that signals the end of the turn.",
        2);
  }

  // Copies the fields shared with RuntimeConfig. Aborts when a batch size
  // exceeds the MMStorage bound, because MatMul scratch space is sized by
  // kMaxM and a larger batch would overflow it later.
  void CopyTo(RuntimeConfig& runtime_config) const {
    runtime_config.max_generated_tokens = max_generated_tokens;
    runtime_config.prefill_tbatch_size = prefill_tbatch_size;
    runtime_config.decode_qbatch_size = decode_qbatch_size;
    if (prefill_tbatch_size > MMStorage::kMaxM) {
      HWY_ABORT(
          "prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, "
          "or increase the constant in MMStorage.\n",
          prefill_tbatch_size, MMStorage::kMaxM);
    }
    if (decode_qbatch_size > MMStorage::kMaxM) {
      HWY_ABORT(
          "decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, "
          "or increase the constant in MMStorage.\n",
          decode_qbatch_size, MMStorage::kMaxM);
    }

    runtime_config.temperature = temperature;
    runtime_config.top_k = top_k;
  }
};
|
||||
|
||||
// Returns a copy of `threading_args`, restricted to a single package when the
// decode batch is large (>= 256 queries). NOTE(review): presumably large
// batches saturate one package and cross-package synchronization is not
// worthwhile — confirm rationale with the threading code.
static inline ThreadingArgs UpdateArgs(const ThreadingArgs& threading_args,
                                       const InferenceArgs& inference_args) {
  ThreadingArgs updated = threading_args;
  const bool large_decode_batch = inference_args.decode_qbatch_size >= 256;
  if (large_decode_batch) {
    updated.max_packages = 1;
  }
  return updated;
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
|
||||
|
|
@ -0,0 +1,192 @@
|
|||
// Copyright 2025 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||
// which we pass the filename via macro 'argument'.
|
||||
// clang-format off
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE "gemma/griffin.cc" // NOLINT
|
||||
// clang-format on
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "ops/matvec-inl.h"
|
||||
#include "ops/ops-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
// Computes one Griffin recurrent block for all tokens/queries in the batch:
// X/Y linear projections, a causal Conv1D over the per-query cache, the
// gated RG-LRU recurrence, and a final output projection into att_sums.
// `griffin_layer` indexes this layer's rows in the per-query conv1d/rglru
// caches. Mutates `activations.griffin` scratch and the KV caches in qbatch.
void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
                      const LayerWeightsPtrs* layer_weights,
                      Activations& activations, QBatch& qbatch,
                      MatMulEnv& env) {
  PROFILER_ZONE("Gen.Griffin");
  hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
  namespace hn = hwy::HWY_NAMESPACE;
  using D = hn::ScalableTag<float>;
  const D df;

  const size_t model_dim = layer_weights->layer_config.model_dim;
  // Vector loops below assume model_dim is a multiple of the lane count.
  HWY_DASSERT(model_dim % hn::Lanes(df) == 0);

  const size_t heads = layer_weights->layer_config.heads;
  const size_t conv_1d_width = layer_weights->layer_config.conv1d_width;
  // The conv loop below processes two taps per iteration.
  HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
  const size_t kHeadDim = model_dim / heads;
  const size_t kMatrixSize = kHeadDim * kHeadDim;

  // One row per (token, query) pair; qi varies fastest (see div_qbatch use).
  const size_t num_interleaved = num_tokens * qbatch.Size();
  const hwy::Divisor div_qbatch(static_cast<uint32_t>(qbatch.Size()));
  GriffinActivations& griffin = activations.griffin;

  // X / Y linear layers.
  // TODO: MatMul
  HWY_DASSERT(griffin.griffin_y.Rows() == griffin.griffin_x.Rows());
  HWY_DASSERT(num_interleaved == griffin.griffin_y.Rows());
  CallUpcastedSame(
      &layer_weights->griffin.linear_x_w, &layer_weights->griffin.linear_y_w,
      [&](const auto* wx, const auto* wy) {
        for (size_t r = 0; r < num_interleaved; ++r) {
          float* HWY_RESTRICT y = griffin.griffin_y.Row(r);
          float* HWY_RESTRICT x = griffin.griffin_x.Row(r);
          // Both projections share the input row, so compute them together.
          TwoMatVecAdd(
              *wx, *wy, 0, model_dim, model_dim,
              activations.attention.pre_att_rms_out.Row(r),
              /*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
              /*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
              /*out0=*/x, /*out1=*/y, pool);
          Gelu(y, model_dim);  // y becomes the gating branch.
        }
      });

  // Conv1D.
  // NOTE(review): rows are indexed by `qi` here and in the RGLRU below, but
  // by the interleaved index `r` in the linear layers above; this looks
  // correct only when num_tokens == 1 per query — confirm against callers
  // (there is a TODO elsewhere about Griffin query batching).
  for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
       ++interleaved_idx) {
    const size_t qi = div_qbatch.Remainder(interleaved_idx);
    const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
    const size_t pos = qbatch.Pos(qi) + batch_idx;
    float* HWY_RESTRICT x = griffin.griffin_x.Row(qi);

    // cache[i] = input at time t-i.
    float* HWY_RESTRICT cache[kMaxConv1DWidth];
    cache[0] = x;
    // Remaining taps come from the per-query ring buffer of past inputs,
    // which stores conv_1d_width - 1 positions, indexed mod its length.
    for (size_t i = 1; i < conv_1d_width; i++) {
      cache[i] =
          qbatch.KV(qi).conv1d_cache.Row(griffin_layer) +
          ((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
    }
    for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
      auto xv = hn::Load(df, x + i);  // Save input before overwriting x.
      // Two accumulators (even/odd taps) to hide FMA latency.
      auto accum0 =
          hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i);
      auto accum1 = hn::Zero(df);
      for (size_t l = 0; 2 * l < conv_1d_width; l++) {
        auto wv0 =
            hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
                             (conv_1d_width - 1 - 2 * l) * model_dim + i);
        auto wv1 =
            hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
                             (conv_1d_width - 2 - 2 * l) * model_dim + i);
        accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
        accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
      }
      // x now holds the conv output; the saved input replaces the oldest
      // ring-buffer slot for the next position.
      hn::Store(hn::Add(accum0, accum1), df, x + i);
      hn::Store(xv, df, cache[HWY_MAX(conv_1d_width, 1) - 1] + i);
    }
  }

  // RGLRU
  for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
       ++interleaved_idx) {
    const size_t qi = div_qbatch.Remainder(interleaved_idx);
    const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
    const size_t pos = qbatch.Pos(qi) + batch_idx;

    float* HWY_RESTRICT x = griffin.griffin_x.Row(qi);
    float* HWY_RESTRICT y = griffin.griffin_y.Row(qi);
    float* HWY_RESTRICT gate_x = griffin.griffin_gate_x.Row(qi);
    float* HWY_RESTRICT a = griffin.griffin_multiplier.Row(qi);
    float* HWY_RESTRICT rnn_state =
        qbatch.KV(qi).rglru_cache.Row(griffin_layer);

    // Heads are independent, so parallelize the per-head gating + scan.
    pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
      size_t head_offset = head * kHeadDim;
      // Input gate and recurrence gate share the input slice; the second
      // matrix is offset by `heads` blocks within gate_w.
      CallUpcasted(&layer_weights->griffin.gate_w, [&](const auto* gate_w) {
        TwoOfsMatVecAddLoop(
            *gate_w, kMatrixSize * head, kMatrixSize * (heads + head), kHeadDim,
            kHeadDim, x + head_offset,
            /*add0=*/layer_weights->griffin.gate_biases.PackedScale1() +
                head_offset,
            /*add1=*/layer_weights->griffin.gate_biases.PackedScale1() +
                model_dim + head_offset,
            /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
      });
      Sigmoid(gate_x + head_offset, kHeadDim);
      Sigmoid(a + head_offset, kHeadDim);
      const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
          HWY_ATTR { return hn::Mul(x, gate_x); };
      // a := sigmoid(gate) * log-coefficient; x := x * input gate.
      hn::Transform1(D(), a + head_offset, kHeadDim,
                     layer_weights->griffin.a.PackedScale1() + head_offset,
                     fn_mul);
      hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
                     fn_mul);
      // RNN scan
      HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0);
      for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) {
        auto log_a = hn::Load(df, a + head_offset + i);
        auto gated_x = hn::Load(df, x + head_offset + i);
        auto rnn = hn::Load(df, rnn_state + head_offset + i);
        auto a = hn::Exp(df, log_a);  // decay in (0, 1); shadows outer `a`.
        // Input scale sqrt(1 - a^2) keeps the state variance bounded.
        auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0f)));
        if (pos == 0) {
          // First position: no prior state, pass the input through unscaled.
          x_multiplier = hn::Set(df, 1.0f);
        }
        auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
        hn::Store(new_x, df, rnn_state + head_offset + i);

        // Join branches.
        auto yv = hn::Load(df, y + head_offset + i);
        auto pre_out = hn::Mul(yv, new_x);
        hn::Store(pre_out, df, x + head_offset + i);
      }
    });
  }  // interleaved_idx

  // Final linear layer.
  CallMatMul(griffin.griffin_x, layer_weights->griffin.linear_out_w,
             layer_weights->griffin.linear_out_biases.PackedScale1(), env,
             activations.attention.att_sums);
}  // GriffinRecurrent
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
// Copyright 2025 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_
|
||||
|
||||
// Declares GriffinRecurrent for all SIMD targets.
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "gemma/gemma.h"
|
||||
#include "hwy/highway.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Passed to HWY_VISIT_TARGETS; declares for one target.
|
||||
#define GEMMA_DECL_GRIFFIN(TARGET, NAMESPACE) \
|
||||
namespace NAMESPACE { \
|
||||
void GriffinRecurrent(size_t num_tokens, size_t griffin_layer, \
|
||||
const LayerWeightsPtrs* layer_weights, \
|
||||
Activations& activations, QBatch& qbatch, \
|
||||
MatMulEnv& env); \
|
||||
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||
} // namespace NAMESPACE
|
||||
|
||||
// Function declarations for each SIMD target. Allows direct call from the
|
||||
// per-target namespace. We may later replace this with dynamic dispatch if
|
||||
// the overhead is acceptable.
|
||||
HWY_VISIT_TARGETS(GEMMA_DECL_GRIFFIN)
|
||||
|
||||
#undef GEMMA_DECL_GRIFFIN
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_
|
||||
|
|
@ -15,91 +15,75 @@
|
|||
|
||||
#include "gemma/kv_cache.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <stddef.h>
|
||||
|
||||
#include "gemma/common.h" // CallForModel
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // ZeroBytes
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "util/mat.h" // ZeroInit
|
||||
#include "hwy/base.h" // HWY_MAX
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
void KVCache::ZeroGriffinCache() {
|
||||
if (conv1d_cache_size != 0) {
|
||||
hwy::ZeroBytes(conv1d_cache.get(),
|
||||
conv1d_cache_size * sizeof(conv1d_cache[0]));
|
||||
}
|
||||
if (rglru_cache_size != 0) {
|
||||
hwy::ZeroBytes(rglru_cache.get(),
|
||||
rglru_cache_size * sizeof(rglru_cache[0]));
|
||||
}
|
||||
if (conv1d_cache.Rows() == 0) return;
|
||||
ZeroInit(conv1d_cache);
|
||||
ZeroInit(rglru_cache);
|
||||
}
|
||||
|
||||
// prefill_tbatch_size is the maximum number of tokens from one query to
|
||||
// prefill at a time.
|
||||
KVCache KVCache::Create(const ModelConfig& weights_config,
|
||||
size_t prefill_tbatch_size) {
|
||||
KVCache kv_cache = {};
|
||||
|
||||
const size_t size_cache_pos = weights_config.CachePosSize();
|
||||
if (size_cache_pos != 0) {
|
||||
// Allocate more so that prefill can always access one batch, even if
|
||||
// near the end of the sequence.
|
||||
kv_cache.seq_len = weights_config.seq_len + prefill_tbatch_size;
|
||||
kv_cache.kv_cache =
|
||||
hwy::AllocateAligned<float>(kv_cache.seq_len * size_cache_pos);
|
||||
}
|
||||
|
||||
const size_t num_griffin_layers = weights_config.NumLayersOfType(
|
||||
LayerAttentionType::kGriffinRecurrentBlock);
|
||||
// TODO(patrickms): Add query batching support for Griffin.
|
||||
if (num_griffin_layers > 0) {
|
||||
uint32_t conv1d_width = 0;
|
||||
for (const auto& layer_config : weights_config.layer_configs) {
|
||||
conv1d_width = std::max(conv1d_width, layer_config.conv1d_width);
|
||||
}
|
||||
const size_t conv1d_cache_size =
|
||||
num_griffin_layers * (conv1d_width == 0 ? 0 : conv1d_width - 1) *
|
||||
weights_config.model_dim;
|
||||
kv_cache.conv1d_cache_size = conv1d_cache_size;
|
||||
if (conv1d_cache_size != 0) {
|
||||
kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
|
||||
}
|
||||
|
||||
const size_t rglru_cache_size =
|
||||
num_griffin_layers * weights_config.model_dim;
|
||||
kv_cache.rglru_cache_size = rglru_cache_size;
|
||||
if (rglru_cache_size != 0) {
|
||||
kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
|
||||
}
|
||||
} // num_griffin_layers
|
||||
|
||||
return kv_cache;
|
||||
static size_t GriffinLayers(const ModelConfig& config) {
|
||||
return config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock);
|
||||
}
|
||||
|
||||
KVCache KVCache::Copy(const ModelConfig& weights_config,
|
||||
size_t prefill_tbatch_size) {
|
||||
KVCache kv_cache_copy = Create(weights_config, prefill_tbatch_size);
|
||||
static size_t GriffinConv1dCols(const ModelConfig& config) {
|
||||
size_t conv1d_width = 0;
|
||||
for (const auto& layer_config : config.layer_configs) {
|
||||
conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width);
|
||||
}
|
||||
// The row offset, in blocks of model_dim is computed mod (conv1d_width - 1),
|
||||
// hence allocate conv1d_width * model_dim total columns.
|
||||
return conv1d_width * config.model_dim;
|
||||
}
|
||||
|
||||
const size_t size_cache_pos = weights_config.CachePosSize();
|
||||
if (size_cache_pos != 0) {
|
||||
std::copy(kv_cache.get(), kv_cache.get() + size_cache_pos * seq_len,
|
||||
kv_cache_copy.kv_cache.get());
|
||||
// Number of rows for KV cache. Note that both rows and cols are u32, and
|
||||
// the total number of elements can exceed 2^32.
|
||||
static size_t CappedSeqLen(const ModelConfig& config,
|
||||
const InferenceArgs& inference_args) {
|
||||
if (inference_args.seq_len > config.max_seq_len) {
|
||||
HWY_WARN("Capping seq_len %zu to config.max_seq_len %u.",
|
||||
inference_args.seq_len, config.max_seq_len);
|
||||
return config.max_seq_len;
|
||||
}
|
||||
return inference_args.seq_len;
|
||||
}
|
||||
|
||||
KVCache::KVCache(const Extents2D& conv1d_extents,
|
||||
const Extents2D& rglru_extents, const Extents2D& kv_extents,
|
||||
const Allocator& allocator)
|
||||
: conv1d_cache("conv1d_cache", conv1d_extents, allocator, MatPadding::kOdd),
|
||||
rglru_cache("rglru_cache", rglru_extents, allocator, MatPadding::kOdd),
|
||||
kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
|
||||
allocator_(allocator) {}
|
||||
|
||||
KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
|
||||
const Allocator& allocator)
|
||||
: KVCache(
|
||||
Extents2D(GriffinLayers(config), GriffinConv1dCols(config)),
|
||||
Extents2D(GriffinLayers(config), config.model_dim),
|
||||
Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()),
|
||||
allocator) {}
|
||||
|
||||
KVCache KVCache::Copy() {
|
||||
KVCache copy(conv1d_cache.Extents(), rglru_cache.Extents(),
|
||||
kv_cache.Extents(), allocator_);
|
||||
|
||||
if (conv1d_cache.Rows() != 0) {
|
||||
CopyMat(conv1d_cache, copy.conv1d_cache);
|
||||
CopyMat(rglru_cache, copy.rglru_cache);
|
||||
}
|
||||
|
||||
const size_t num_griffin_layers = weights_config.NumLayersOfType(
|
||||
LayerAttentionType::kGriffinRecurrentBlock);
|
||||
if (num_griffin_layers > 0) {
|
||||
if (conv1d_cache_size != 0) {
|
||||
std::copy(conv1d_cache.get(), conv1d_cache.get() + conv1d_cache_size,
|
||||
kv_cache_copy.conv1d_cache.get());
|
||||
}
|
||||
if (rglru_cache_size != 0) {
|
||||
std::copy(rglru_cache.get(),
|
||||
rglru_cache.get() + rglru_cache_size * sizeof(rglru_cache[0]),
|
||||
kv_cache_copy.rglru_cache.get());
|
||||
}
|
||||
}
|
||||
return kv_cache_copy;
|
||||
CopyMat(kv_cache, copy.kv_cache);
|
||||
|
||||
return copy;
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -18,34 +18,41 @@
|
|||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "gemma/common.h" // Model
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "gemma/configs.h" // ModelConfig
|
||||
#include "gemma/gemma_args.h" // InferenceArgs
|
||||
#include "util/basics.h" // BF16
|
||||
#include "util/mat.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
using KV_t = float;
|
||||
|
||||
struct KVCache {
|
||||
size_t seq_len = 0; // = kSeqLen + prefill_tbatch_size
|
||||
KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
|
||||
const Allocator& allocator);
|
||||
|
||||
// seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2
|
||||
hwy::AlignedFreeUniquePtr<float[]> kv_cache;
|
||||
|
||||
// (kConv1dWidth - 1) * kModelDim * kGriffinLayers
|
||||
hwy::AlignedFreeUniquePtr<float[]> conv1d_cache;
|
||||
size_t conv1d_cache_size = 0;
|
||||
|
||||
// kModelDim * kGriffinLayers
|
||||
hwy::AlignedFreeUniquePtr<float[]> rglru_cache;
|
||||
size_t rglru_cache_size = 0;
|
||||
// Returns a deep copy of the KVCache. Use explicit function instead of
|
||||
// copy ctor to make the cost explicit.
|
||||
KVCache Copy();
|
||||
|
||||
// Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache
|
||||
// and rglru_cache.
|
||||
void ZeroGriffinCache();
|
||||
|
||||
static KVCache Create(const ModelConfig& weights_config,
|
||||
size_t prefill_tbatch_size);
|
||||
size_t SeqLen() const { return kv_cache.Rows(); }
|
||||
|
||||
// Returns a deep copy of the KVCache.
|
||||
KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
|
||||
// [griffin_layers, griffin_conv1d_cols * model_dim]
|
||||
MatStorageT<float> conv1d_cache;
|
||||
MatStorageT<float> rglru_cache; // [griffin_layers, model_dim]
|
||||
|
||||
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
|
||||
|
||||
private:
|
||||
const Allocator& allocator_;
|
||||
|
||||
// For use by other ctor and Copy()
|
||||
KVCache(const Extents2D& conv1d_extents, const Extents2D& rglru_extents,
|
||||
const Extents2D& kv_extents, const Allocator& allocator);
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -0,0 +1,464 @@
|
|||
// Copyright 2025 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "gemma/model_store.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <array>
|
||||
#include <charconv>
|
||||
#include <cstdlib>
|
||||
#include <cstring> // strcmp
|
||||
#include <string>
|
||||
#include <system_error> // std::errc // NOLINT
|
||||
|
||||
#include "compression/types.h"
|
||||
#include "gemma/configs.h" // ModelConfig, kMaxQKVDim
|
||||
#include "gemma/tensor_info.h"
|
||||
#include "gemma/tokenizer.h"
|
||||
#include "io/blob_store.h"
|
||||
#include "io/fields.h"
|
||||
#include "io/io.h" // Path
|
||||
#include "util/basics.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Single-file format contains blobs with these names:
|
||||
static constexpr char kConfigName[] = "config";
|
||||
static constexpr char kTokenizerName[] = "tokenizer";
|
||||
static constexpr char kMatPtrsName[] = "toc";
|
||||
// Pre-2025 format has one metadata blob. 'F' denoted f32.
|
||||
static constexpr char kDecoratedScalesName[] = "Fscales";
|
||||
|
||||
// Warns when the serialized blob carried fields this binary does not know
// about, which suggests the code is older than the file.
static void WarnIfExtra(const IFields::ReadResult& result, const char* name) {
  // No warning if missing_fields > 0: those fields are default-initialized.
  if (result.extra_u32 == 0) return;
  HWY_WARN(
      "Serialized blob %s has %u extra fields the code is not aware of. "
      "Consider updating to the latest code from GitHub.",
      name, result.extra_u32);
}
|
||||
|
||||
// Returns the serialized tokenizer (std::string is required for proto).
// Reads it from a blob or from a separate file if pre-2025. A user-supplied
// --tokenizer file always wins over the embedded blob. Falls back to a mock
// tokenizer (with a warning) so tests can proceed without one.
static std::string ReadTokenizer(BlobReader& reader,
                                 const Path& tokenizer_path) {
  std::string tokenizer;
  // Check prevents `CallWithSpan` from printing a warning.
  if (reader.Find(kTokenizerName)) {
    if (!reader.CallWithSpan<char>(
            kTokenizerName, [&tokenizer](const hwy::Span<const char> bytes) {
              tokenizer.assign(bytes.data(), bytes.size());
            })) {
      HWY_WARN(
          "Reading tokenizer blob failed, please raise an issue. You can "
          "instead specify a tokenizer file via --tokenizer.");
    }
  }

  // Read actual tokenizer from blob.
  if (!tokenizer.empty() && tokenizer != kMockTokenizer) {
    // Explicit file overrides the embedded blob; warn so users notice.
    if (!tokenizer_path.Empty()) {
      HWY_WARN("--weights has tokenizer but overriding with %s.",
               tokenizer_path.path.c_str());
      return ReadFileToString(tokenizer_path);
    }

    return tokenizer;
  }

  // No blob but user specified path to file: read it or abort.
  if (!tokenizer_path.Empty()) {
    return ReadFileToString(tokenizer_path);
  }

  // Neither blob nor file: return a placeholder so tests can continue.
  HWY_WARN(
      "BlobStore does not contain a tokenizer and no --tokenizer was "
      "specified. Tests may continue but inference will fail.");
  return kMockTokenizer;
}
|
||||
|
||||
using KeyVec = std::vector<std::string>;
|
||||
|
||||
// Accumulates per-type byte/blob histograms from blob keys whose first
// character encodes the weight type (pre-2025 single-file format).
class TypePrefix {
 public:
  // Maps the pre-2025 single-character key prefix to a weight type.
  static Type TypeFromChar(char c) {
    switch (c) {
      case 'F':
        return Type::kF32;
      case 'B':
        return Type::kBF16;
      case '$':
        return Type::kSFP;
      case '2':
        return Type::kNUQ;
      default:
        // The other types were not written to pre-2025 files, hence no need to
        // encode and check for them here.
        return Type::kUnknown;
    }
  }

  // Builds the histograms over all blobs. NOTE(review): reads `key[0]`, so
  // this assumes keys are non-empty — confirm BlobReader guarantees that.
  TypePrefix(const KeyVec& keys, const BlobReader& reader) {
    for (size_t key_idx = 0; key_idx < keys.size(); ++key_idx) {
      const std::string& key = keys[key_idx];
      const Type type = TypeFromChar(key[0]);
      const uint64_t bytes = reader.Range(key_idx).bytes;
      bytes_[static_cast<size_t>(type)] += bytes;
      blobs_[static_cast<size_t>(type)]++;
      total_bytes_ += bytes;
    }
  }

  // Returns true for pre-2025 format, which has type prefixes and thus the
  // functions below may be used. (If every blob mapped to kUnknown, all bytes
  // landed in that bucket and there are no prefixes.)
  bool HasPrefixes() const {
    return bytes_[static_cast<size_t>(Type::kUnknown)] != total_bytes_;
  }

  // Returns the weight type deduced from the histogram of blobs per type.
  // Rationale: We expect a mix of types due to varying precision requirements
  // for each tensor. The preferred weight type might not even be the most
  // common, because we prioritize higher compression for the *large* tensors.
  // Ignore types which only have a few blobs (might be metadata), and assume
  // that there would be at least 4 of the large tensors (in particular, global
  // attention layers). Hence return the smallest type with >= 4 blobs.
  Type DeduceWeightType() const {
    size_t min_bits = ~size_t{0};
    Type weight_type = Type::kUnknown;
    for (size_t i = 0; i < kNumTypes; ++i) {
      if (blobs_[i] < 4) continue;
      const size_t bits = TypeBits(static_cast<Type>(i));
      if (bits < min_bits) {
        min_bits = bits;
        weight_type = static_cast<Type>(i);
      }
    }
    return weight_type;
  }

  // Prints statistics on the total size of tensors by type.
  void PrintTypeBytes() const {
    for (size_t type_idx = 0; type_idx < kNumTypes; ++type_idx) {
      const Type type = static_cast<Type>(type_idx);
      const uint64_t bytes = bytes_[type_idx];
      if (bytes == 0) continue;  // Skip types that never occurred.
      const double percent = 100.0 * bytes / total_bytes_;
      fprintf(stderr, "%12zu blob bytes (%5.2f%%) of %4s\n",
              static_cast<size_t>(bytes), percent, TypeName(type));
    }
  }

 private:
  uint64_t total_bytes_ = 0;
  // Per-type totals, indexed by static_cast<size_t>(Type).
  std::array<size_t, kNumTypes> bytes_{0};
  std::array<size_t, kNumTypes> blobs_{0};
};
|
||||
|
||||
// Returns 0 if the blob does not seem to be a per-layer tensor, otherwise the
|
||||
// layer index.
|
||||
static size_t LayerIdxFromKey(const std::string& key) {
|
||||
const auto parse_num = [&key](size_t begin, size_t end) -> int {
|
||||
HWY_DASSERT(begin <= end);
|
||||
HWY_DASSERT(end <= key.size());
|
||||
int val = 0;
|
||||
auto [ptr, ec] = std::from_chars(key.data() + begin, key.data() + end, val);
|
||||
return (ec == std::errc()) ? val : -1;
|
||||
};
|
||||
|
||||
const size_t suffix_pos = key.rfind('_');
|
||||
// If there is no digit after the last underscore, it is not a layer name.
|
||||
if (suffix_pos == std::string::npos) return 0;
|
||||
if (suffix_pos == key.size() - 1) return 0;
|
||||
|
||||
int layer_idx = parse_num(suffix_pos + 1, key.size());
|
||||
|
||||
HWY_ASSERT(layer_idx < 999);
|
||||
return layer_idx == -1 ? 0 : static_cast<size_t>(layer_idx);
|
||||
}
|
||||
|
||||
// Returns the number of layers based on the largest blob name suffix seen.
|
||||
// This works with or without type prefixes because it searches for suffixes.
|
||||
static size_t DeduceNumLayers(const KeyVec& keys) {
|
||||
// Built-in self-test.
|
||||
{
|
||||
HWY_ASSERT(LayerIdxFromKey("gr_conv_w_2") == 2); // common case
|
||||
HWY_ASSERT(LayerIdxFromKey("prefix_") == 0); // no number
|
||||
HWY_ASSERT(LayerIdxFromKey("c_embedding") == 0); // per-model
|
||||
HWY_ASSERT(LayerIdxFromKey("c_final_norm") == 0); // per-model, two _
|
||||
}
|
||||
|
||||
size_t max_layer_idx = 0;
|
||||
for (const std::string& key : keys) {
|
||||
max_layer_idx = HWY_MAX(max_layer_idx, LayerIdxFromKey(key));
|
||||
}
|
||||
return max_layer_idx + 1;
|
||||
}
|
||||
|
||||
// Looks for known tensor names associated with model families.
|
||||
// This works with or without type prefixes because it searches for substrings.
|
||||
static int DeduceLayerTypes(const BlobReader& reader) {
|
||||
int layer_types = 0;
|
||||
for (size_t key_idx = 0; key_idx < reader.Keys().size(); ++key_idx) {
|
||||
const std::string& key = reader.Keys()[key_idx];
|
||||
if (key.find("gr_conv_w") != std::string::npos) { // NOLINT
|
||||
return kDeducedGriffin;
|
||||
}
|
||||
if (key.find("qkv_ein_w") != std::string::npos) { // NOLINT
|
||||
layer_types |= kDeducedViT;
|
||||
}
|
||||
if (key.find("img_pos_emb") != std::string::npos) { // NOLINT
|
||||
// About 5.88 elements per pixel; assume at least bf16.
|
||||
if (reader.Range(key_idx).bytes > 448 * 448 * 5 * sizeof(BF16)) {
|
||||
layer_types |= kDeduced448;
|
||||
}
|
||||
}
|
||||
}
|
||||
return layer_types;
|
||||
}
|
||||
|
||||
// `wrapping_override` is forwarded from the command line. For pre-2025 files
// without `ModelConfig`, it is the only way to force PT.
// Returns the deserialized `ModelConfig` if the file has one, otherwise a
// config synthesized from the deduced model/weight/wrapping.
static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
                                      Tristate wrapping_override) {
  // Pre-2025 files encode the tensor type in a one-character key prefix
  // (see `CreateMatPtrs`); if present, use it to deduce the weight type.
  const TypePrefix type_prefix(reader.Keys(), reader);
  Type deduced_weight = Type::kUnknown;
  if (type_prefix.HasPrefixes()) {
    deduced_weight = type_prefix.DeduceWeightType();
    type_prefix.PrintTypeBytes();
  }

  // Always deduce so we can verify it against the config we read.
  const size_t layers = DeduceNumLayers(reader.Keys());
  const int layer_types = DeduceLayerTypes(reader);
  const Model deduced_model =
      DeduceModel(reader.blob_path(), layers, layer_types);

  ModelConfig config;
  // Check first to prevent `CallWithSpan` from printing a warning.
  if (reader.Find(kConfigName)) {
    // Single-file format: deserialize the config blob in place.
    HWY_ASSERT(reader.CallWithSpan<uint32_t>(
        kConfigName, [&config](const SerializedSpan serialized) {
          const IFields::ReadResult result = config.Read(serialized, 0);
          WarnIfExtra(result, kConfigName);
          HWY_ASSERT_M(result.pos != 0, "Error deserializing config");
        }));

    // A valid serialized config must have all three fields set.
    HWY_ASSERT(config.model != Model::UNKNOWN);
    HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
    HWY_ASSERT(config.weight != Type::kUnknown);
    for (const LayerConfig& layer_config : config.layer_configs) {
      if (static_cast<size_t>(layer_config.qkv_dim) > kMaxQKVDim) {
        HWY_ABORT("Increase kMaxQKVDim to at least %u.", layer_config.qkv_dim);
      }
    }

    // We trust the deserialized config, but checking helps to validate the
    // deduction, which we rely on below for pre-2025 files.
    if (config.model != deduced_model) {
      const std::string suffix = WrappingSuffix(config.wrapping);
      HWY_WARN("Detected model %s does not match config %s.",
               (std::string(ModelPrefix(deduced_model)) + suffix).c_str(),
               (std::string(ModelPrefix(config.model)) + suffix).c_str());
    }
    return config;
  }

  // Pre-2025 format: no config, rely on deduction plus `wrapping_override`.
  return ModelConfig(deduced_model, deduced_weight,
                     ChooseWrapping(deduced_model, wrapping_override));
}
|
||||
|
||||
// Reads the optional per-tensor scaling factors blob; returns an empty vector
// if the blob is absent. `config` is currently unused; its name is commented
// out to avoid an unused-parameter warning while keeping the call signature
// stable for existing callers.
static std::vector<float> ReadScales(BlobReader& reader,
                                     const ModelConfig& /*config*/) {
  std::vector<float> scales;
  // Check first to prevent `CallWithSpan` from printing a warning. This blob is
  // optional even in pre-2025 format; Griffin was the first to include it.
  if (reader.Find(kDecoratedScalesName)) {
    HWY_ASSERT(reader.CallWithSpan<float>(
        kDecoratedScalesName,
        [&scales](const hwy::Span<const float> scales_blob) {
          // Copy the blob contents; the span is only valid inside the call.
          scales.assign(scales_blob.cbegin(), scales_blob.cend());
        }));
  }
  return scales;
}
|
||||
|
||||
// Single-file format: reads `MatPtr` from the blob; returns false if not found.
bool ModelStore::ReadMatPtrs(BlobReader& reader) {
  // Check first to prevent `CallWithSpan` from printing a warning.
  if (!reader.Find(kMatPtrsName)) return false;

  // For verifying `config_.weight`.
  size_t min_bits = ~size_t{0};
  Type weight_type = Type::kUnknown;

  HWY_ASSERT(reader.CallWithSpan<uint32_t>(
      kMatPtrsName, [&, this](SerializedSpan serialized) {
        // The blob is a concatenation of serialized `MatPtr`; each `Read`
        // returns the position of the next record, which advances `pos`.
        for (size_t pos = 0; pos < serialized.size();) {
          MatPtr mat;
          const IFields::ReadResult result = mat.Read(serialized, pos);
          WarnIfExtra(result, mat.Name());
          // `result.pos == 0` is the deserialization-failure sentinel.
          if (result.pos == 0) {
            HWY_ABORT("Deserializing MatPtr %s failed (pos %zu of %zu).",
                      mat.Name(), pos, serialized.size());
          }
          pos = result.pos + result.extra_u32;

          // Retrieve actual key index because a writer may have written other
          // blobs before the tensor data.
          const BlobRange* range = reader.Find(mat.Name());
          HWY_ASSERT(range);
          const size_t key_idx = range->key_idx;
          AddMatPtr(key_idx, mat);

          // Track the tensor type with the fewest bits seen so far; this is
          // treated as the file's overall weight type.
          const size_t bits = TypeBits(mat.GetType());
          if (bits < min_bits) {
            min_bits = bits;
            weight_type = mat.GetType();
          }
        }
      }));

  // At least one tensor must exist, and the narrowest type must agree with
  // the weight type recorded in the config.
  HWY_ASSERT(weight_type != Type::kUnknown);
  HWY_ASSERT(weight_type == config_.weight);

  return true;
}
|
||||
|
||||
// Pre-2025 format: synthesizes `MatPtr` from the blob names if `!ReadMatPtrs`.
void ModelStore::CreateMatPtrs(BlobReader& reader) {
  // Maps tensor names to their expected shapes for this config.
  const TensorInfoRegistry tensors(config_);

  const KeyVec& keys = reader.Keys();
  mat_ptrs_.reserve(keys.size());
  // `key_idx` is the blob index. It is not the same as the index of the
  // `MatPtr` in `mat_ptrs_` because not all blobs are tensors.
  for (size_t key_idx = 0; key_idx < keys.size(); ++key_idx) {
    // Pre-2025 keys encode the tensor type in their first character.
    const Type type = TypePrefix::TypeFromChar(keys[key_idx][0]);
    if (type == Type::kUnknown) continue;  // likely not a tensor

    // Strip type prefix from the key. Still includes layer suffix.
    const std::string name = keys[key_idx].substr(1);
    const TensorInfo* info = tensors.Find(name);
    if (HWY_UNLIKELY(!info)) {
      if (name == "scales") continue;  // ignore, not a tensor.
      HWY_ABORT("Unknown tensor %s.", name.c_str());
    }
    // Unable to set scale already because they are ordered according to
    // `ForEachTensor`, which we do not know here. The initial value is 1.0f
    // and we set the correct value in `FindAndUpdateMatPtr`.
    AddMatPtr(key_idx, MatPtr(name.c_str(), type, ExtentsFromInfo(info)));
  }
  // Some blobs may have been skipped, but the two parallel arrays filled by
  // `AddMatPtr` must agree.
  HWY_ASSERT(mat_ptrs_.size() <= keys.size());
  HWY_ASSERT(mat_ptrs_.size() == key_idx_.size());
}
|
||||
|
||||
// Reads config/tokenizer, then either deserializes `MatPtr`s (single-file
// format) or synthesizes them plus the optional scales (pre-2025 format).
// `tokenizer_path` and `wrapping` are only consulted for pre-2025 files.
ModelStore::ModelStore(BlobReader& reader, const Path& tokenizer_path,
                       Tristate wrapping)
    : config_(ReadOrDeduceConfig(reader, wrapping)),
      tokenizer_(ReadTokenizer(reader, tokenizer_path)) {
  if (!ReadMatPtrs(reader)) {  // Pre-2025 format.
    CreateMatPtrs(reader);
    scales_ = ReadScales(reader, config_);
    // ModelConfig serialized a vector of strings. Unpack into a set for more
    // efficient lookup.
    for (const std::string& name : config_.scale_base_names) {
      scale_base_names_.insert(name);
    }
    // If the model has scales, the config must know about it.
    HWY_ASSERT(scales_.empty() || !scale_base_names_.empty());
  }

  // Invariant: one blob key index per `MatPtr`, regardless of which path
  // populated them.
  HWY_ASSERT(key_idx_.size() == mat_ptrs_.size());
}
|
||||
|
||||
ModelStore::~ModelStore() {
  // Sanity check: every scaling factor read from the file must have been
  // handed out via `FindAndUpdateMatPtr` by now.
  HWY_ASSERT(scales_.size() == scales_consumed_);
}
|
||||
|
||||
// Returns the stored metadata for tensor `name`, or nullptr if absent.
const MatPtr* ModelStore::FindMat(const char* name) const {
  const auto it = mat_idx_for_name_.find(name);
  if (it == mat_idx_for_name_.end()) return nullptr;  // not in the file
  const MatPtr& found = mat_ptrs_[it->second];
  // The name->index map and the vector must agree on which tensor this is.
  HWY_ASSERT(strcmp(found.Name(), name) == 0);
  return &found;
}
|
||||
|
||||
// Fills in `mat`'s type/scale from the file metadata and sets `key_idx` to
// the blob index for reading its data. Returns false if the tensor is not in
// the file; aborts on shape or type mismatch.
bool ModelStore::FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const {
  const MatPtr* file_mat = FindMat(mat.Name());
  if (!file_mat) return false;
  // The caller's expected shape comes from the config; a mismatch means this
  // file does not belong to the model, so abort rather than read garbage.
  if (file_mat->Rows() != mat.Rows() || file_mat->Cols() != mat.Cols()) {
    HWY_ABORT("Tensor %s shape %zu %zu mismatches file %zu %zu.", mat.Name(),
              mat.Rows(), mat.Cols(), file_mat->Rows(), file_mat->Cols());
  }
  // `Compress()` output is always packed because it assumes a 1D array.
  HWY_ASSERT(mat.IsPacked());
  // Update fields. Name already matched, otherwise we would not find it.
  // For MatPtr tensors, the type will be `kUnknown`. If it was a `MatPtrT`,
  // ensure the type set via code matches the file.
  HWY_ASSERT_M(
      mat.GetType() == Type::kUnknown || mat.GetType() == file_mat->GetType(),
      mat.Name());
  mat.SetType(file_mat->GetType());
  if (scales_.empty()) {
    // `file_mat->Scale()` is either read from file, or we have pre-2025 format
    // without the optional scales, and it is default-initialized to 1.0f.
    mat.SetScale(file_mat->Scale());
  } else {  // Pre-2025 with scaling factors: set next if `mat` wants one.
    // NOTE: scales are consumed in call order, so callers must request
    // tensors in the same order the scales were written. `scales_consumed_`
    // is mutable, which lets this const method advance it.
    if (scale_base_names_.find(StripLayerSuffix(mat.Name())) !=
        scale_base_names_.end()) {
      HWY_ASSERT(scales_consumed_ < scales_.size());
      mat.SetScale(scales_[scales_consumed_++]);
    }
  }

  // Translate the position within `mat_ptrs_` into the blob key index.
  key_idx = key_idx_[file_mat - mat_ptrs_.data()];
  return true;
}
|
||||
|
||||
static void AddBlob(const char* name, const std::vector<uint32_t>& data,
|
||||
BlobWriter& writer) {
|
||||
HWY_ASSERT(!data.empty());
|
||||
writer.Add(name, data.data(), data.size() * sizeof(data[0]));
|
||||
}
|
||||
|
||||
void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
|
||||
const std::vector<uint32_t>& serialized_mat_ptrs,
|
||||
BlobWriter& writer) {
|
||||
HWY_ASSERT(config.model != Model::UNKNOWN);
|
||||
HWY_ASSERT(config.weight != Type::kUnknown);
|
||||
HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
|
||||
const std::vector<uint32_t> serialized_config = config.Write();
|
||||
AddBlob(kConfigName, serialized_config, writer);
|
||||
|
||||
const std::string serialized_tokenizer = tokenizer.Serialize();
|
||||
HWY_ASSERT(!serialized_tokenizer.empty());
|
||||
writer.Add(kTokenizerName, serialized_tokenizer.data(),
|
||||
serialized_tokenizer.size());
|
||||
|
||||
AddBlob(kMatPtrsName, serialized_mat_ptrs, writer);
|
||||
|
||||
writer.WriteAll();
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
// Copyright 2025 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Reads/writes model metadata (all but the weights) from/to a `BlobStore`.
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_MODEL_STORE_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_MODEL_STORE_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "gemma/configs.h" // ModelConfig
|
||||
#include "gemma/tokenizer.h"
|
||||
#include "io/blob_store.h"
|
||||
#include "io/io.h" // Path
|
||||
#include "util/basics.h" // Tristate
|
||||
#include "util/mat.h" // MatPtr
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Reads and holds the model config, tokenizer and all `MatPtr`: everything
|
||||
// except the tensor data, which are read/written by `weights.cc`.
|
||||
//
|
||||
// As of 2025-04, the `BlobStore` format includes blobs for `ModelConfig`,
|
||||
// tokenizer, and all `MatPtr` metadata. "Pre-2025" format instead stored the
|
||||
// tokenizer in a separate file, encoded tensor type in a prefix of the blob
|
||||
// name, and had a blob for tensor scaling factors. We still support reading
|
||||
// both, but only write single-file format.
|
||||
class ModelStore {
 public:
  // Reads from file(s) or aborts on error. The latter two arguments are only
  // used for pre-2025 files.
  ModelStore(BlobReader& reader, const Path& tokenizer_path = Path(),
             Tristate wrapping = Tristate::kDefault);
  // Sanity-checks that all scaling factors (if any) were consumed.
  ~ModelStore();

  // Returns the config read from the file or deduced from its blob names;
  // asserts that one of the two succeeded.
  const ModelConfig& Config() const {
    HWY_ASSERT(config_.model != Model::UNKNOWN);
    return config_;
  }

  const GemmaTokenizer& Tokenizer() const { return tokenizer_; }

  // Returns nullptr if `name` is not available for loading, otherwise the
  // metadata of that tensor.
  const MatPtr* FindMat(const char* name) const;

  // Returns false if `mat` is not available for loading, otherwise updates
  // `mat` with metadata from the file and sets `key_idx` for use by
  // `BlobReader`. Called via `ReadOrAllocate` in `weights.cc`.
  bool FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const;

 private:
  // Registers `mat` (stored in blob `key_idx`) in the parallel arrays and in
  // the name->index map; aborts if the name was already registered.
  void AddMatPtr(const size_t key_idx, const MatPtr& mat) {
    auto pair_ib = mat_idx_for_name_.insert({mat.Name(), mat_ptrs_.size()});
    HWY_ASSERT_M(pair_ib.second, mat.Name());  // Ensure inserted/unique.
    mat_ptrs_.push_back(mat);
    key_idx_.push_back(key_idx);
  }

  bool ReadMatPtrs(BlobReader& reader);
  void CreateMatPtrs(BlobReader& reader);  // Aborts on error.

  ModelConfig config_;
  GemmaTokenizer tokenizer_;

  // All `MatPtr` present in the `BlobStore`, see `ReadMatPtrs`/`CreateMatPtrs`.
  std::vector<MatPtr> mat_ptrs_;
  // For each of `mat_ptrs_`, the index within `BlobReader::Keys()`. This is
  // not necessarily iota because some blobs are not tensors, and callers may
  // have added blobs before ours.
  std::vector<size_t> key_idx_;
  // Index within `mat_ptrs_` and `key_idx_` for each tensor name.
  std::unordered_map<std::string, size_t> mat_idx_for_name_;

  // Only used if `!ReadMatPtrs` (pre-2025 format):
  // Contents of the optional scaling-factors blob; may be empty.
  std::vector<float> scales_;
  // Base names (layer suffix stripped) of tensors that expect a scale.
  std::unordered_set<std::string> scale_base_names_;
  // How many entries of `scales_` were handed out; mutable because
  // `FindAndUpdateMatPtr` is const.
  mutable size_t scales_consumed_ = 0;
};
|
||||
|
||||
// Adds metadata blobs to `writer` and writes everything to `path`. This
|
||||
// produces a single BlobStore file holding everything required for inference.
|
||||
void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
|
||||
const std::vector<uint32_t>& serialized_mat_ptrs,
|
||||
BlobWriter& writer);
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_MODEL_STORE_H_
|
||||
237
gemma/run.cc
237
gemma/run.cc
|
|
@ -15,22 +15,22 @@
|
|||
|
||||
// Command line text interface to gemma.
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
// Placeholder for internal header, do not modify.
|
||||
#include "compression/shared.h" // PromptWrapping
|
||||
#include "compression/types.h" // PromptWrapping
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h" // Gemma
|
||||
#include "ops/matmul.h" // MatMulEnv
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "gemma/tokenizer.h" // WrapAndTokenize
|
||||
#include "ops/matmul.h" // MatMulEnv
|
||||
#include "paligemma/image.h"
|
||||
#include "util/app.h"
|
||||
#include "util/args.h" // HasHelp
|
||||
#include "util/threading.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
|
@ -55,9 +55,8 @@ static constexpr std::string_view kAsciiArtBanner = R""(
|
|||
|___/ |_| |_|
|
||||
)"";
|
||||
|
||||
std::string GetPrompt(std::istream& input, int verbosity,
|
||||
std::string_view eot_line) {
|
||||
PROFILER_ZONE("Gen.input");
|
||||
std::string GetPromptFromStream(std::istream& input, int verbosity,
|
||||
std::string_view eot_line) {
|
||||
if (verbosity >= 1) {
|
||||
std::cout << "> " << std::flush;
|
||||
}
|
||||
|
|
@ -77,36 +76,55 @@ std::string GetPrompt(std::istream& input, int verbosity,
|
|||
return prompt_string;
|
||||
}
|
||||
|
||||
// Get prompt either from interactive input or command line
// Precedence: --prompt flag, then --prompt_file, then interactive stdin.
std::string GetPrompt(const InferenceArgs& inference) {
  PROFILER_ZONE("Gen.input");
  // If prompt is provided via command line, use that
  if (!inference.prompt.empty()) {
    return inference.prompt;
  }
  // Otherwise, a prompt file (if given) is read in full.
  if (!inference.prompt_file.Empty()) {
    return ReadFileToString(inference.prompt_file);
  }

  // Fall back to reading interactively from stdin.
  return GetPromptFromStream(std::cin, inference.verbosity, inference.eot_line);
}
|
||||
|
||||
// The main Read-Eval-Print Loop.
|
||||
void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
||||
const InferenceArgs& args, const AcceptFunc& accept_token,
|
||||
std::string& eot_line) {
|
||||
void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
||||
const Gemma& gemma, KVCache& kv_cache, MatMulEnv& env) {
|
||||
PROFILER_ZONE("Gen.misc");
|
||||
size_t abs_pos = 0; // across turns
|
||||
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
|
||||
size_t prompt_size = 0;
|
||||
const ModelConfig& config = gemma.Config();
|
||||
|
||||
std::mt19937 gen;
|
||||
InitGenerator(args, gen);
|
||||
InitGenerator(inference, gen);
|
||||
|
||||
const bool have_image = !args.image_file.path.empty();
|
||||
const bool have_image = !inference.image_file.path.empty();
|
||||
Image image;
|
||||
ImageTokens image_tokens;
|
||||
const size_t pool_dim = config.vit_config.pool_dim;
|
||||
ImageTokens image_tokens(
|
||||
"image_tokens",
|
||||
have_image ? Extents2D(config.vit_config.seq_len / (pool_dim * pool_dim),
|
||||
config.model_dim)
|
||||
: Extents2D(0, 0),
|
||||
env.ctx.allocator, MatPadding::kOdd);
|
||||
image_tokens.AllocateAndAttachRowPtrs(env.row_ptrs);
|
||||
if (have_image) {
|
||||
size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
|
||||
image_tokens = ImageTokens(Extents2D(
|
||||
model.GetModelConfig().vit_config.seq_len / (pool_dim * pool_dim),
|
||||
model.GetModelConfig().model_dim));
|
||||
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA ||
|
||||
model.Info().wrapping == PromptWrapping::GEMMA_VLM);
|
||||
HWY_ASSERT(image.ReadPPM(args.image_file.path));
|
||||
const size_t image_size = model.GetModelConfig().vit_config.image_size;
|
||||
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA ||
|
||||
config.wrapping == PromptWrapping::GEMMA_VLM);
|
||||
HWY_ASSERT(image.ReadPPM(inference.image_file.path));
|
||||
const size_t image_size = config.vit_config.image_size;
|
||||
image.Resize(image_size, image_size);
|
||||
RuntimeConfig runtime_config = {
|
||||
.gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin};
|
||||
RuntimeConfig runtime_config = {.gen = &gen,
|
||||
.verbosity = inference.verbosity,
|
||||
.use_spinning = threading.spin};
|
||||
double image_tokens_start = hwy::platform::Now();
|
||||
model.GenerateImageTokens(runtime_config, image, image_tokens);
|
||||
if (app.verbosity >= 1) {
|
||||
gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
|
||||
image_tokens, env);
|
||||
if (inference.verbosity >= 1) {
|
||||
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
|
||||
fprintf(stderr,
|
||||
"\n\n[ Timing info ] Image token generation took: %d ms\n",
|
||||
|
|
@ -121,21 +139,21 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
|||
const bool first_response_token = tokens_generated_this_turn == prompt_size;
|
||||
++tokens_generated_this_turn;
|
||||
if (in_prompt) {
|
||||
if (app.verbosity >= 1) {
|
||||
std::cerr << "." << std::flush;
|
||||
if (inference.verbosity >= 1) {
|
||||
std::cout << "." << std::flush;
|
||||
}
|
||||
return true;
|
||||
} else if (model.GetModelConfig().IsEOS(token)) {
|
||||
if (app.verbosity >= 2) {
|
||||
} else if (config.IsEOS(token)) {
|
||||
if (inference.verbosity >= 2) {
|
||||
std::cout << "\n[ End ]\n";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
std::string token_text;
|
||||
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
HWY_ASSERT(gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
if (first_response_token) {
|
||||
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
|
||||
if (app.verbosity >= 1) {
|
||||
if (inference.verbosity >= 1) {
|
||||
std::cout << "\n\n";
|
||||
}
|
||||
}
|
||||
|
|
@ -146,72 +164,77 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
|||
while (true) { // Loop until user quits.
|
||||
tokens_generated_this_turn = 0;
|
||||
|
||||
// Read prompt and handle special commands.
|
||||
std::string prompt_string = GetPrompt(std::cin, app.verbosity, eot_line);
|
||||
if (!std::cin) return;
|
||||
// If !eot_line.empty(), we append \n, so only look at the first 2 chars.
|
||||
if (prompt_string.size() >= 2 && prompt_string[0] == '%') {
|
||||
if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
|
||||
if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
|
||||
abs_pos = 0;
|
||||
const std::string prompt_string = GetPrompt(inference);
|
||||
const bool is_interactive = inference.IsInteractive();
|
||||
if (is_interactive) { // handle special commands:
|
||||
if (!std::cin) return;
|
||||
|
||||
// If !eot_line.empty(), we append \n, so only look at the first 2 chars.
|
||||
if (prompt_string.size() >= 2 && prompt_string[0] == '%') {
|
||||
if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
|
||||
if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
|
||||
abs_pos = 0;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (prompt_string.empty()) {
|
||||
std::cout << "Use '%q' to quit.\n";
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (prompt_string.empty()) {
|
||||
std::cout << "Use '%q' to quit.\n";
|
||||
continue;
|
||||
|
||||
// Set up runtime config.
|
||||
TimingInfo timing_info = {.verbosity = inference.verbosity};
|
||||
RuntimeConfig runtime_config = {.gen = &gen,
|
||||
.verbosity = inference.verbosity,
|
||||
.stream_token = stream_token,
|
||||
.use_spinning = threading.spin};
|
||||
inference.CopyTo(runtime_config);
|
||||
std::vector<int> prompt;
|
||||
size_t prefix_end = 0;
|
||||
if (have_image) {
|
||||
prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
|
||||
config.wrapping, abs_pos, prompt_string,
|
||||
image_tokens.Rows());
|
||||
runtime_config.image_tokens = &image_tokens;
|
||||
prompt_size = prompt.size();
|
||||
if (config.wrapping == PromptWrapping::PALIGEMMA) {
|
||||
// The end of the prefix for prefix-LM style attention in Paligemma.
|
||||
// See Figure 2 of https://arxiv.org/abs/2407.07726.
|
||||
prefix_end = prompt_size;
|
||||
// We need to look at all the tokens for the prefix.
|
||||
// NOTE: Online softmax is on the roadmap, after which this requirement
|
||||
// can be lifted.
|
||||
runtime_config.prefill_tbatch_size = prompt_size;
|
||||
}
|
||||
} else {
|
||||
prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
|
||||
config.wrapping, abs_pos, prompt_string);
|
||||
prompt_size = prompt.size();
|
||||
}
|
||||
|
||||
// Wrap, tokenize and maybe log prompt tokens.
|
||||
std::vector<int> prompt = WrapAndTokenize(
|
||||
model.Tokenizer(), model.Info(), abs_pos, prompt_string);
|
||||
prompt_size = prompt.size();
|
||||
if constexpr (kVerboseLogTokens) {
|
||||
for (int i = 0; i < prompt_size; ++i) {
|
||||
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Set up runtime config.
|
||||
TimingInfo timing_info = {.verbosity = app.verbosity};
|
||||
RuntimeConfig runtime_config = {.gen = &gen,
|
||||
.verbosity = app.verbosity,
|
||||
.stream_token = stream_token,
|
||||
.accept_token = accept_token,
|
||||
.use_spinning = app.spin};
|
||||
args.CopyTo(runtime_config);
|
||||
size_t prefix_end = 0;
|
||||
if (have_image) {
|
||||
runtime_config.image_tokens = &image_tokens;
|
||||
if (model.Info().wrapping == PromptWrapping::PALIGEMMA) {
|
||||
prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
|
||||
} else if (model.Info().wrapping == PromptWrapping::GEMMA_VLM) {
|
||||
size_t seq_len = model.GetModelConfig().vit_config.seq_len;
|
||||
size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
|
||||
prompt =
|
||||
WrapVLM(model.Tokenizer(), model.Info(), abs_pos, prompt,
|
||||
image_tokens.BatchSize(), seq_len / (pool_dim * pool_dim));
|
||||
}
|
||||
prompt_size = prompt.size();
|
||||
// The end of the prefix for prefix-LM style attention in Paligemma.
|
||||
// See Figure 2 of https://arxiv.org/abs/2407.07726.
|
||||
prefix_end = prompt_size;
|
||||
// We need to look at all the tokens for the prefix.
|
||||
runtime_config.prefill_tbatch_size = prompt_size;
|
||||
}
|
||||
|
||||
// Generate until EOS or max_generated_tokens.
|
||||
if (app.verbosity >= 1) {
|
||||
if (inference.verbosity >= 1) {
|
||||
std::cerr << "\n[ Reading prompt ] " << std::flush;
|
||||
}
|
||||
model.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache,
|
||||
gemma.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache, env,
|
||||
timing_info);
|
||||
std::cout << "\n\n";
|
||||
|
||||
// In non-interactive mode, we only process one prompt/turn.
|
||||
if (!is_interactive) break;
|
||||
|
||||
// Prepare for the next turn. Works only for PaliGemma.
|
||||
if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) {
|
||||
if (!inference.multiturn || config.wrapping == PromptWrapping::PALIGEMMA) {
|
||||
abs_pos = 0; // Start a new turn at position 0.
|
||||
InitGenerator(args, gen);
|
||||
InitGenerator(inference, gen);
|
||||
} else {
|
||||
// The last token was either EOS, then it should be ignored because it is
|
||||
// never part of the dialog, see Table 5 in the Gemma-2 paper:
|
||||
|
|
@ -227,20 +250,17 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
|||
}
|
||||
}
|
||||
|
||||
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||
void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
|
||||
const InferenceArgs& inference) {
|
||||
PROFILER_ZONE("Run.misc");
|
||||
|
||||
// Note that num_threads is an upper bound; we also limit to the number of
|
||||
// detected and enabled cores.
|
||||
const BoundedTopology topology = CreateTopology(app);
|
||||
NestedPools pools = CreatePools(topology, app);
|
||||
MatMulEnv env(topology, pools);
|
||||
if (app.verbosity >= 2) env.print_best = true;
|
||||
Gemma model = CreateGemma(loader, env);
|
||||
KVCache kv_cache =
|
||||
KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);
|
||||
ThreadingContext ctx(UpdateArgs(threading, inference));
|
||||
MatMulEnv env(ctx);
|
||||
if (inference.verbosity >= 2) env.print_best = true;
|
||||
const Gemma gemma(loader, inference, ctx);
|
||||
KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
|
||||
|
||||
if (app.verbosity >= 1) {
|
||||
if (inference.verbosity >= 1) {
|
||||
std::string instructions =
|
||||
"*Usage*\n"
|
||||
" Enter an instruction and press enter (%C resets conversation, "
|
||||
|
|
@ -261,46 +281,37 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
instructions += multiturn;
|
||||
instructions += examples;
|
||||
|
||||
std::cout << "\033[2J\033[1;1H" // clear screen
|
||||
<< kAsciiArtBanner << "\n\n";
|
||||
ShowConfig(loader, inference, app, topology, pools);
|
||||
std::cout << "\n" << instructions << "\n";
|
||||
// Skip the banner and instructions in non-interactive mode
|
||||
if (inference.IsInteractive()) {
|
||||
std::cout << "\033[2J\033[1;1H" // clear screen
|
||||
<< kAsciiArtBanner << "\n\n";
|
||||
ShowConfig(loader, threading, inference, gemma.Config(),
|
||||
gemma.WeightReadMode(), ctx);
|
||||
std::cout << "\n" << instructions << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
ReplGemma(model, kv_cache, app, inference, AcceptFunc(), app.eot_line);
|
||||
ReplGemma(threading, inference, gemma, kv_cache, env);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
gcpp::InternalInit();
|
||||
{
|
||||
PROFILER_ZONE("Startup.misc");
|
||||
|
||||
// Placeholder for internal init, do not modify.
|
||||
|
||||
gcpp::LoaderArgs loader(argc, argv);
|
||||
gcpp::ThreadingArgs threading(argc, argv);
|
||||
gcpp::InferenceArgs inference(argc, argv);
|
||||
gcpp::AppArgs app(argc, argv);
|
||||
|
||||
if (gcpp::HasHelp(argc, argv)) {
|
||||
std::cerr << gcpp::kAsciiArtBanner;
|
||||
gcpp::ShowHelp(loader, inference, app);
|
||||
gcpp::ShowHelp(loader, threading, inference);
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (const char* error = loader.Validate()) {
|
||||
std::cerr << gcpp::kAsciiArtBanner;
|
||||
gcpp::ShowHelp(loader, inference, app);
|
||||
HWY_ABORT("\nInvalid args: %s", error);
|
||||
}
|
||||
|
||||
if (const char* error = inference.Validate()) {
|
||||
std::cerr << gcpp::kAsciiArtBanner;
|
||||
gcpp::ShowHelp(loader, inference, app);
|
||||
HWY_ABORT("\nInvalid args: %s", error);
|
||||
}
|
||||
|
||||
gcpp::Run(loader, inference, app);
|
||||
gcpp::Run(loader, threading, inference);
|
||||
}
|
||||
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
|
||||
return 0;
|
||||
|
|
|
|||
|
|
@ -1,607 +0,0 @@
|
|||
#include "gemma/tensor_index.h"
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/configs.h"
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
||||
// Returns the non-layer tensors for the model.
// NOTE(review): field semantics assumed from usage — `source_names` are the
// checkpoint names a tensor may be imported from, `shape` is derived from
// `config`; confirm `axes`/`min_size` against the `TensorInfo` declaration.
std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
  return {
      TensorInfo{
          .name = "c_embedding",
          .source_names = {"embedder/input_embedding"},
          .axes = {0, 1},
          .shape = {config.vocab_size, config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "c_final_norm",
          .source_names = {"final_norm/scale"},
          .axes = {0},
          .shape = {config.model_dim},
          .min_size = Type::kBF16,
      },
      // The following entries use `config.vit_config` dimensions and "img"/
      // "enc" source paths — presumably the vision (ViT) encoder tensors.
      TensorInfo{
          .name = "enc_norm_bias",
          .source_names = {"img/Transformer/encoder_norm/bias"},
          .axes = {0},
          .shape = {config.vit_config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "enc_norm_scale",
          .source_names = {"img/Transformer/encoder_norm/scale"},
          .axes = {0},
          .shape = {config.vit_config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "img_emb_bias",
          .source_names = {"img/embedding/bias"},
          .axes = {0},
          .shape = {config.vit_config.model_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "img_emb_kernel",
          .source_names = {"img/embedding/kernel"},
          .axes = {3, 0, 1, 2},
          .shape = {config.vit_config.model_dim, config.vit_config.patch_width,
                    config.vit_config.patch_width, 3},
          .min_size = Type::kBF16,
          .cols_take_extra_dims = true,
      },
      TensorInfo{
          .name = "img_head_bias",
          .source_names = {"img/head/bias", "embedder/mm_input_projection/b"},
          .axes = {0},
          .shape = {config.model_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "img_head_kernel",
          .source_names = {"img/head/kernel", "embedder/mm_input_projection/w"},
          .axes = {1, 0},
          .shape = {config.model_dim, config.vit_config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "img_pos_emb",
          .source_names = {"img/pos_embedding"},
          .axes = {0, 1},
          .shape = {/*1,*/ config.vit_config.seq_len,
                    config.vit_config.model_dim},
          .min_size = Type::kF32,
      },
      // RMS norm applied to soft tokens prior to pos embedding.
      TensorInfo{
          .name = "mm_embed_norm",
          .source_names = {"embedder/mm_soft_embedding_norm/scale"},
          .axes = {0},
          .shape = {config.vit_config.model_dim},
          .min_size = Type::kBF16,
      },
  };
}
|
||||
|
||||
// Returns the tensors for the given image (ViT encoder) layer config.
// `img_layer_idx` is only used to build the numbered per-layer source names
// ("encoderblock_<idx>/...") for the LayerNorm tensors; all other source
// names are layer-independent and matched by suffix.
std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
                                          const LayerConfig& layer_config,
                                          const int img_layer_idx) {
  return {
      // Vit layers.
      TensorInfo{
          .name = "attn_out_w",
          .source_names = {"MultiHeadDotProductAttention_0/out/kernel"},
          .axes = {2, 0, 1},
          .shape = {config.vit_config.model_dim, layer_config.heads,
                    layer_config.qkv_dim},
          .min_size = Type::kBF16,
          .cols_take_extra_dims = true,
      },
      TensorInfo{
          .name = "attn_out_b",
          .source_names = {"MultiHeadDotProductAttention_0/out/bias"},
          .axes = {0},
          .shape = {config.vit_config.model_dim},
          .min_size = Type::kF32,
      },
      // Q/K/V kernels are exported separately and concatenated into
      // "qkv_ein_w" along axis 1 (see TensorInfo::concat_names: the first
      // entry is the result name, the rest are appended after this tensor).
      TensorInfo{
          .name = "q_ein_w",
          .source_names = {"MultiHeadDotProductAttention_0/query/kernel"},
          .axes = {1, 2, 0},
          .shape = {layer_config.heads, layer_config.qkv_dim,
                    config.vit_config.model_dim},
          .concat_names = {"qkv_ein_w", "k_ein_w", "v_ein_w"},
          .concat_axis = 1,
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "k_ein_w",
          .source_names = {"MultiHeadDotProductAttention_0/key/kernel"},
          .axes = {1, 2, 0},
          .shape = {layer_config.heads, layer_config.qkv_dim,
                    config.vit_config.model_dim},
          // Single empty string: consumed by the "q_ein_w" concatenation.
          .concat_names = {""},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "v_ein_w",
          .source_names = {"MultiHeadDotProductAttention_0/value/kernel"},
          .axes = {1, 2, 0},
          .shape = {layer_config.heads, layer_config.qkv_dim,
                    config.vit_config.model_dim},
          .concat_names = {""},
          .min_size = Type::kBF16,
      },
      // Pre-fused QKV kernel, present in some checkpoints instead of the
      // separate query/key/value tensors above.
      TensorInfo{
          .name = "qkv_ein_w",
          .source_names = {"MultiHeadDotProductAttention_0/qkv/kernel"},
          .axes = {1, 2, 0},
          .shape = {layer_config.heads, 3 * layer_config.qkv_dim,
                    config.vit_config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "q_ein_b",
          .source_names = {"MultiHeadDotProductAttention_0/query/bias"},
          .axes = {0, 1},
          .shape = {layer_config.heads, layer_config.qkv_dim},
          .concat_names = {"qkv_ein_b", "k_ein_b", "v_ein_b"},
          .concat_axis = 1,
          .min_size = Type::kF32,
      },
      // NOTE(review): the biases use kv_heads where the kernels use heads;
      // presumably heads == kv_heads for the ViT — confirm against configs.
      TensorInfo{
          .name = "k_ein_b",
          .source_names = {"MultiHeadDotProductAttention_0/key/bias"},
          .axes = {0, 1},
          .shape = {layer_config.kv_heads, layer_config.qkv_dim},
          .concat_names = {""},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "v_ein_b",
          .source_names = {"MultiHeadDotProductAttention_0/value/bias"},
          .axes = {0, 1},
          .shape = {layer_config.kv_heads, layer_config.qkv_dim},
          .concat_names = {""},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "qkv_ein_b",
          .source_names = {"MultiHeadDotProductAttention_0/qkv/bias"},
          .axes = {0, 1},
          .shape = {layer_config.heads + layer_config.kv_heads * 2,
                    layer_config.qkv_dim},
          .min_size = Type::kF32,
      },
      // MLP block: Dense_0 is the up-projection, Dense_1 the down-projection.
      TensorInfo{
          .name = "linear_0_w",
          .source_names = {"MlpBlock_0/Dense_0/kernel"},
          .axes = {1, 0},
          .shape = {layer_config.ff_hidden_dim, config.vit_config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "linear_0_b",
          .source_names = {"MlpBlock_0/Dense_0/bias"},
          .axes = {0},
          .shape = {layer_config.ff_hidden_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "linear_1_w",
          .source_names = {"MlpBlock_0/Dense_1/kernel"},
          .axes = {1, 0},
          .shape = {config.vit_config.model_dim, layer_config.ff_hidden_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "linear_1_b",
          .source_names = {"MlpBlock_0/Dense_1/bias"},
          .axes = {0},
          .shape = {config.vit_config.model_dim},
          .min_size = Type::kF32,
      },
      // LayerNorms: matched either by the unnumbered scan-style name or the
      // explicitly numbered per-layer name.
      TensorInfo{
          .name = "ln_0_bias",
          .source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias",
                           "img/Transformer/encoderblock_" +
                               std::to_string(img_layer_idx) +
                               "/LayerNorm_0/bias"},
          .axes = {0},
          .shape = {config.vit_config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "ln_0_scale",
          .source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale",
                           "img/Transformer/encoderblock_" +
                               std::to_string(img_layer_idx) +
                               "/LayerNorm_0/scale"},
          .axes = {0},
          .shape = {config.vit_config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "ln_1_bias",
          .source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias",
                           "img/Transformer/encoderblock_" +
                               std::to_string(img_layer_idx) +
                               "/LayerNorm_1/bias"},
          .axes = {0},
          .shape = {config.vit_config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "ln_1_scale",
          .source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale",
                           "img/Transformer/encoderblock_" +
                               std::to_string(img_layer_idx) +
                               "/LayerNorm_1/scale"},
          .axes = {0},
          .shape = {config.vit_config.model_dim},
          .min_size = Type::kBF16,
      },
  };
}
|
||||
|
||||
// Returns the tensors for the given LLM layer config. Covers transformer
// attention tensors, Griffin recurrent-block tensors, and the MLP/norm
// tensors; which subset actually matches depends on the checkpoint's source
// names. If `reshape_att` is true, "att_w" carries the source names and the
// attn_vec_einsum output is exported transposed to (model_dim, heads,
// qkv_dim); otherwise "att_ein" carries them.
std::vector<TensorInfo> LLMLayerTensors(const ModelConfig& config,
                                        const LayerConfig& layer_config,
                                        bool reshape_att) {
  std::vector<TensorInfo> tensors = {
      TensorInfo{
          .name = "key_norm",
          .source_names = {"attn/_key_norm/scale"},
          .axes = {0},
          .shape = {layer_config.qkv_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "query_norm",
          .source_names = {"attn/_query_norm/scale"},
          .axes = {0},
          .shape = {layer_config.qkv_dim},
          .min_size = Type::kBF16,
      },
      // Separate Q and KV einsum weights, concatenated into "qkv_ein".
      TensorInfo{
          .name = "qkv1_w",
          .source_names = {"attn/q_einsum/w"},
          .axes = {0, 2, 1},
          .shape = {layer_config.heads * layer_config.qkv_dim,
                    config.model_dim},
          .concat_names = {"qkv_ein", "qkv2_w"},
      },
      TensorInfo{
          .name = "qkv2_w",
          .source_names = {"attn/kv_einsum/w"},
          .axes = {1, 0, 3, 2},
          .shape = {2 * layer_config.kv_heads * layer_config.qkv_dim,
                    config.model_dim},
          .concat_names = {""},
      },
      // Griffin-style separate projections, also concatenated into "qkv_ein".
      // NOTE(review): q_ein uses layer_config.model_dim for both dims while
      // k_ein/v_ein use qkv_dim x model_dim — verify this asymmetry is
      // intended (it holds when heads * qkv_dim == model_dim).
      TensorInfo{
          .name = "q_ein",
          .source_names = {"attention_block/proj_q/kernel"},
          .axes = {1, 0},
          .shape = {layer_config.model_dim, layer_config.model_dim},
          .concat_names = {"qkv_ein", "k_ein", "v_ein"},
      },
      TensorInfo{
          .name = "k_ein",
          .source_names = {"attention_block/proj_k/kernel"},
          .axes = {1, 0},
          .shape = {layer_config.qkv_dim, layer_config.model_dim},
          .concat_names = {""},
      },
      TensorInfo{
          .name = "v_ein",
          .source_names = {"attention_block/proj_v/kernel"},
          .axes = {1, 0},
          .shape = {layer_config.qkv_dim, layer_config.model_dim},
          .concat_names = {""},
      },
      // Pre-fused QKV einsum weight, present in some checkpoints.
      TensorInfo{
          .name = "qkv_ein",
          .source_names = {"attn/qkv_einsum/w"},
          .axes = {1, 0, 3, 2},
          .shape = {(layer_config.heads + 2 * layer_config.kv_heads) *
                        layer_config.qkv_dim,
                    config.model_dim},
      },
      TensorInfo{
          .name = "attn_ob",
          .source_names = {"attention_block/proj_final/bias"},
          .axes = {0},
          .shape = {config.model_dim},
          .min_size = Type::kF32,
      },
      // Griffin layers.
      TensorInfo{
          .name = "gr_lin_x_w",
          .source_names = {"recurrent_block/linear_x/kernel"},
          .axes = {1, 0},
          .shape = {layer_config.griffin_dim, layer_config.griffin_dim},
      },
      TensorInfo{
          .name = "gr_lin_x_b",
          .source_names = {"recurrent_block/linear_x/bias"},
          .axes = {0},
          .shape = {layer_config.griffin_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "gr_lin_y_w",
          .source_names = {"recurrent_block/linear_y/kernel"},
          .axes = {1, 0},
          .shape = {layer_config.griffin_dim, layer_config.griffin_dim},
      },
      TensorInfo{
          .name = "gr_lin_y_b",
          .source_names = {"recurrent_block/linear_y/bias"},
          .axes = {0},
          .shape = {layer_config.griffin_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "gr_lin_out_w",
          .source_names = {"recurrent_block/linear_out/kernel"},
          .axes = {1, 0},
          .shape = {layer_config.griffin_dim, layer_config.griffin_dim},
      },
      TensorInfo{
          .name = "gr_lin_out_b",
          .source_names = {"recurrent_block/linear_out/bias"},
          .axes = {0},
          .shape = {layer_config.griffin_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "gr_conv_w",
          .source_names = {"recurrent_block/conv_1d/w"},
          .axes = {0, 1},
          .shape = {layer_config.conv1d_width, layer_config.griffin_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "gr_conv_b",
          .source_names = {"recurrent_block/conv_1d/b"},
          .axes = {0},
          .shape = {layer_config.griffin_dim},
          .min_size = Type::kF32,
      },
      // RG-LRU input/a gates, concatenated into "gr_gate_w".
      TensorInfo{
          .name = "gr1_gate_w",
          .source_names = {"recurrent_block/rg_lru/input_gate/w"},
          .axes = {0, 2, 1},
          .shape = {layer_config.heads,
                    layer_config.griffin_dim / layer_config.heads,
                    layer_config.griffin_dim / layer_config.heads},
          .concat_names = {"gr_gate_w", "gr2_gate_w"},
      },
      TensorInfo{
          .name = "gr2_gate_w",
          .source_names = {"recurrent_block/rg_lru/a_gate/w"},
          .axes = {0, 2, 1},
          .shape = {layer_config.heads,
                    layer_config.griffin_dim / layer_config.heads,
                    layer_config.griffin_dim / layer_config.heads},
          .concat_names = {""},
      },
      TensorInfo{
          .name = "gr_gate_w",
          .source_names = {"recurrent_block/rg_lru/gate/w"},
          .axes = {0, 2, 1},
          .shape = {2 * layer_config.heads,
                    layer_config.griffin_dim / layer_config.heads,
                    layer_config.griffin_dim / layer_config.heads},
      },
      TensorInfo{
          .name = "gr1_gate_b",
          .source_names = {"recurrent_block/rg_lru/input_gate/b"},
          .axes = {0},
          .shape = {layer_config.griffin_dim},
          .concat_names = {"gr_gate_b", "gr2_gate_b"},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "gr2_gate_b",
          .source_names = {"recurrent_block/rg_lru/a_gate/b"},
          .axes = {0},
          .shape = {layer_config.griffin_dim},
          .concat_names = {""},
          .min_size = Type::kF32,
      },
      // NOTE(review): "gr_gate_b" reuses the input_gate/b source name already
      // claimed by "gr1_gate_b", and declares two axes for a 1-D shape —
      // looks like it should be a fused "rg_lru/gate/b" source; confirm
      // against the exporter before changing.
      TensorInfo{
          .name = "gr_gate_b",
          .source_names = {"recurrent_block/rg_lru/input_gate/b"},
          .axes = {0, 1},
          .shape = {2 * layer_config.griffin_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "gr_a",
          .source_names = {"recurrent_block/rg_lru/a_param"},
          .axes = {0},
          .shape = {layer_config.griffin_dim},
          .min_size = Type::kF32,
          .scaled_softplus = true,
      },

      // Gating einsum: axis order depends on whether the checkpoint stores
      // the optimized (pre-transposed) gating layout.
      TensorInfo{
          .name = "gating_ein",
          .source_names = {"mlp/gating_einsum/w", "mlp/gating_einsum",
                           "mlp_block/ffw_up/w"},
          .axes = {0, layer_config.optimized_gating ? 1u : 2u,
                   layer_config.optimized_gating ? 2u : 1u},
          .shape = {2, layer_config.ff_hidden_dim, config.model_dim},
      },
      // "none" source: gating1_w/gating2_w are derived, never read directly.
      TensorInfo{
          .name = "gating1_w",
          .source_names = {"none"},
          .axes = {0, layer_config.optimized_gating ? 1u : 2u,
                   layer_config.optimized_gating ? 2u : 1u},
          .shape = {layer_config.ff_hidden_dim, config.model_dim},
      },
      TensorInfo{
          .name = "gating2_w",
          .source_names = {"none"},
          .axes = {0, layer_config.optimized_gating ? 1u : 2u,
                   layer_config.optimized_gating ? 2u : 1u},
          .shape = {layer_config.ff_hidden_dim, config.model_dim},
      },
      TensorInfo{
          .name = "linear_w",
          .source_names = {"mlp/linear/w", "mlp/linear",
                           "mlp_block/ffw_down/kernel"},
          .axes = {1, 0},
          .shape = {config.model_dim, layer_config.ff_hidden_dim},
      },
      TensorInfo{
          .name = "pre_att_ns",
          .source_names = {"pre_attention_norm/scale",
                           "temporal_pre_norm/scale"},
          .axes = {0},
          .shape = {config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "pre_ff_ns",
          .source_names = {"pre_ffw_norm/scale", "channel_pre_norm/scale"},
          .axes = {0},
          .shape = {config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "post_att_ns",
          .source_names = {"post_attention_norm/scale"},
          .axes = {0},
          .shape = {config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "post_ff_ns",
          .source_names = {"post_ffw_norm/scale"},
          .axes = {0},
          .shape = {config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "ffw_gat_b",
          .source_names = {"mlp_block/ffw_up/b"},
          .axes = {0},
          .shape = {2 * layer_config.ff_hidden_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "ffw_out_b",
          .source_names = {"mlp_block/ffw_down/bias"},
          .axes = {0},
          .shape = {config.model_dim},
          .min_size = Type::kF32,
      },
  };
  // Exactly one of att_w/att_ein carries the source names (and is thus the
  // tensor actually read from the checkpoint); the other is a placeholder
  // with only the expected shape.
  if (reshape_att) {
    tensors.push_back(TensorInfo{
        .name = "att_w",
        .source_names = {"attn/attn_vec_einsum/w",
                         "attention_block/proj_final/kernel"},
        .preshape = {layer_config.heads, layer_config.qkv_dim,
                     config.model_dim},
        .axes = {2, 0, 1},
        .shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
        .cols_take_extra_dims = true,
    });
    tensors.push_back(TensorInfo{
        .name = "att_ein",
        .shape = {layer_config.heads, config.model_dim, layer_config.qkv_dim},
    });
  } else {
    tensors.push_back(TensorInfo{
        .name = "att_ein",
        .source_names = {"attn/attn_vec_einsum/w",
                         "attention_block/proj_final/kernel"},
        .preshape = {layer_config.heads, layer_config.qkv_dim,
                     config.model_dim},
        .axes = {0, 2, 1},
        .shape = {layer_config.heads, config.model_dim, layer_config.qkv_dim},
    });
    tensors.push_back(TensorInfo{
        .name = "att_w",
        .shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
        .cols_take_extra_dims = true,
    });
  }
  return tensors;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Builds the tensor list for exactly one of: the per-model tensors (both
// layer indices < 0), one image layer, or one LLM layer. Tensor names are
// registered in name_map_ with a "_<layer_idx>" suffix for layer tensors.
TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx,
                         int img_layer_idx, bool reshape_att)
    : config_(config),
      llm_layer_idx_(llm_layer_idx),
      img_layer_idx_(img_layer_idx) {
  // At most one of the indices is >= 0; that one selects the suffix.
  const int layer_idx = std::max(llm_layer_idx_, img_layer_idx_);
  std::string suffix;
  if (layer_idx >= 0) {
    suffix = "_" + std::to_string(layer_idx);
  }
  // Casts are safe after the `0 <=` checks and avoid signed/unsigned
  // comparison warnings (int vs. std::vector::size()).
  if (llm_layer_idx < 0 && img_layer_idx < 0) {
    tensors_ = ModelTensors(config);
  } else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx &&
             static_cast<size_t>(img_layer_idx) <
                 config.vit_config.layer_configs.size()) {
    const auto& layer_config = config.vit_config.layer_configs[img_layer_idx];
    tensors_ = ImageLayerTensors(config, layer_config, img_layer_idx);
  } else if (0 <= llm_layer_idx &&
             static_cast<size_t>(llm_layer_idx) <
                 config.layer_configs.size()) {
    const auto& layer_config = config.layer_configs[llm_layer_idx];
    tensors_ = LLMLayerTensors(config, layer_config, reshape_att);
  }
  // Index every tensor by its suffixed name for FindName() lookups.
  for (size_t i = 0; i < tensors_.size(); ++i) {
    std::string key = tensors_[i].name + suffix;
    name_map_.insert({key, i});
  }
}
|
||||
|
||||
// Returns a copy of the TensorInfo whose source_names contains a suffix of
// `path`, or a default-constructed TensorInfo when nothing matches.
TensorInfo TensorIndex::TensorInfoFromSourcePath(
    const std::string& path) const {
  // A tensor matches when `path` ends with one of its source names.
  const auto path_ends_with = [&path](const std::string& suffix) {
    if (suffix.size() > path.size()) return false;
    return path.compare(path.size() - suffix.size(), suffix.size(), suffix) ==
           0;
  };
  for (const TensorInfo& candidate : tensors_) {
    for (const std::string& source_name : candidate.source_names) {
      if (path_ends_with(source_name)) return candidate;
    }
  }
  return TensorInfo();
}
|
||||
|
||||
// Looks up `name` in name_map_, appending this index's layer suffix when the
// caller passed an unsuffixed name. Returns nullptr if not found.
const TensorInfo* TensorIndex::FindName(const std::string& name) const {
  std::string name_to_find = name;
  // Only append a layer suffix when the name does not already end in a digit.
  // Guard against empty names (name.back() would be UB) and cast to unsigned
  // char: std::isdigit has undefined behavior for negative char values.
  if (name.empty() ||
      !std::isdigit(static_cast<unsigned char>(name.back()))) {
    if (img_layer_idx_ >= 0 && llm_layer_idx_ < 0) {
      name_to_find = name + "_" + std::to_string(img_layer_idx_);
    } else if (llm_layer_idx_ >= 0) {
      name_to_find = name + "_" + std::to_string(llm_layer_idx_);
    }
  }
  auto it = name_map_.find(name_to_find);
  if (it == name_map_.end()) {
    return nullptr;
  }
  return &tensors_[it->second];
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -1,101 +0,0 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/configs.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Universal tensor information. Holds enough information to construct a
// tensor in LayerWeightsPtrs/ModelWeightsPtrs, as well as to export the
// tensor from the python model with necessary transpose/reshape info.
struct TensorInfo {
  // The name of the tensor in the sbs file. Also used (with an optional
  // "_<layer_idx>" suffix) as the lookup key in TensorIndex.
  std::string name;
  // Strings to match to the end of the name of the tensor in the python model.
  std::vector<std::string> source_names;
  // Initial reshape shape. Use only as a last resort when input may have
  // dimensions combined that need to be split before the transpose, as it
  // defeats the post-transpose shape checking. Normally empty.
  std::vector<size_t> preshape;
  // Transpose axes arg. If the input tensor has more dimensions than axes,
  // then leading dimensions are collapsed until the number of axes matches.
  std::vector<size_t> axes;
  // Expected final shape of the tensor after reshape/transpose.
  // Note that this is the shape of the tensor during export,
  // not the shape of the tensor in the sbs file, as the sbs file
  // is restricted to 2D tensors. With few exceptions, the sbs file
  // tensor rows gather all the excess dimensions. See cols_take_extra_dims.
  std::vector<size_t> shape;
  // List of names to concatenate with, used only if multiple tensors are
  // concatenated into one. The first tensor in the concatenation should have
  // concat names thus: The first name is the name of the result, and the
  // tensors with the remaining names are concatenated after this.
  // The remaining tensors to be concatenated should have just a single
  // empty string in concat_names to indicate that they have been consumed.
  std::vector<std::string> concat_names;
  // Axis at which to concatenate.
  size_t concat_axis = 0;
  // The minimum compression weight type for this tensor. The default is
  // kNUQ, which provides maximum compression. Other values such as kBF16
  // or kF32 can be used to limit the compression to a specific type.
  Type min_size = Type::kNUQ;
  // Whether to apply scaled softplus to the data (used e.g. for the Griffin
  // "gr_a" recurrence parameter).
  bool scaled_softplus = false;
  // Whether the columns or the rows take any extra dimensions when the
  // exported tensor is flattened to the 2D sbs layout.
  // If false, then [10, 20, 30] -> [10*20, 30] and [30] -> [1, 30].
  // If true, then [10, 20, 30] -> [10, 20*30] and [30] -> [1, 30].
  bool cols_take_extra_dims = false;
};
|
||||
|
||||
// Universal index of tensor information, which can be built for a specific
// layer_idx. An instance covers either the per-model tensors (both layer
// indices negative), one image layer, or one LLM layer.
class TensorIndex {
 public:
  // Builds a list of TensorInfo for the given layer_idx.
  // If reshape_att is true, the attn_vec_einsum tensor is reshaped.
  TensorIndex(const ModelConfig& config, int llm_layer_idx, int img_layer_idx,
              bool reshape_att);
  ~TensorIndex() = default;

  // Returns the TensorInfo whose source_name matches the end of the given
  // path, or an empty TensorInfo if not found.
  // NOTE: the returned TensorInfo is a copy, so that the source
  // TensorIndex can be destroyed without affecting the returned TensorInfo.
  TensorInfo TensorInfoFromSourcePath(const std::string& path) const;

  // Returns the TensorInfo whose name matches the given name,
  // or an empty TensorInfo if not found.
  // NOTE: the returned TensorInfo is a copy, so that the source
  // TensorIndex can be destroyed without affecting the returned TensorInfo.
  TensorInfo TensorInfoFromName(const std::string& name) const {
    const TensorInfo* info = FindName(name);
    if (info == nullptr) return TensorInfo();
    return *info;
  }

  // Returns the TensorInfo for the given tensor name (nullptr if absent),
  // for concise construction of ModelWeightsPtrs/LayerWeightsPtrs.
  const TensorInfo* FindName(const std::string& name) const;

 private:
  // Config that was used to build the tensor index.
  const ModelConfig& config_;
  // Layer that this tensor index is for - either LLM or image.
  // Negative means "not this kind of layer".
  int llm_layer_idx_;
  int img_layer_idx_;
  // List of tensor information for this layer.
  std::vector<TensorInfo> tensors_;
  // Map from (suffixed) tensor name to index in tensors_.
  std::unordered_map<std::string, size_t> name_map_;
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_
|
||||
|
|
@ -1,72 +0,0 @@
|
|||
#include "gemma/tensor_index.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "compression/compress.h"
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/basics.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
||||
// Tests that each tensor in the model can be found by exactly one TensorIndex,
|
||||
// and that the TensorIndex returns the correct shape and name for the tensor,
|
||||
// for all models.
|
||||
TEST(TensorIndexTest, FindName) {
|
||||
for (Model model : kAllModels) {
|
||||
fprintf(stderr, "Testing model %d\n", static_cast<int>(model));
|
||||
ModelConfig config = ConfigFromModel(model);
|
||||
std::vector<TensorIndex> tensor_indexes;
|
||||
tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1,
|
||||
/*img_layer_idx=*/-1,
|
||||
/*split_and_reshape=*/false);
|
||||
for (size_t llm_layer_idx = 0; llm_layer_idx < config.layer_configs.size();
|
||||
++llm_layer_idx) {
|
||||
tensor_indexes.emplace_back(config, static_cast<int>(llm_layer_idx),
|
||||
/*img_layer_idx=*/-1,
|
||||
/*split_and_reshape=*/false);
|
||||
}
|
||||
for (size_t img_layer_idx = 0;
|
||||
img_layer_idx < config.vit_config.layer_configs.size();
|
||||
++img_layer_idx) {
|
||||
tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1,
|
||||
static_cast<int>(img_layer_idx),
|
||||
/*split_and_reshape=*/false);
|
||||
}
|
||||
// For each tensor in any model, exactly one TensorIndex should find it.
|
||||
ModelWeightsPtrs<SfpStream> weights(config);
|
||||
ModelWeightsPtrs<SfpStream>::ForEachTensor(
|
||||
{&weights}, ForEachType::kInitNoToc,
|
||||
[&tensor_indexes](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
int num_found = 0;
|
||||
const MatPtr& tensor = *tensors[0];
|
||||
for (const auto& tensor_index : tensor_indexes) {
|
||||
// Skip the type marker prefix, but we want the layer index suffix.
|
||||
std::string name_to_find(name + 1, strlen(name) - 1);
|
||||
const TensorInfo* info = tensor_index.FindName(name_to_find);
|
||||
if (info != nullptr) {
|
||||
// Test that the MatPtr can be constructed from the TensorInfo,
|
||||
// and that the dimensions match.
|
||||
MatPtrT<SfpStream> mat_ptr(tensor.Name(), tensor_index);
|
||||
EXPECT_STREQ(tensor.Name(), mat_ptr.Name())
|
||||
<< "on tensor " << name;
|
||||
EXPECT_EQ(tensor.Rows(), mat_ptr.Rows()) << "on tensor " << name;
|
||||
EXPECT_EQ(tensor.Cols(), mat_ptr.Cols()) << "on tensor " << name;
|
||||
++num_found;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(num_found, 1) << " for tensor " << name;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gcpp
|
||||
|
|
@ -0,0 +1,593 @@
|
|||
#include "gemma/tensor_info.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "compression/types.h"
|
||||
#include "gemma/configs.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
void TensorInfoRegistry::Add(const std::string& suffix,
|
||||
const TensorInfo& info) {
|
||||
const size_t idx = tensors_.size();
|
||||
tensors_.push_back(info);
|
||||
// Also add suffix to `concat_names`.
|
||||
for (std::string& name : tensors_.back().concat_names) {
|
||||
name += suffix;
|
||||
}
|
||||
|
||||
const std::string name = info.base_name + suffix;
|
||||
// Ensure successful insertion because `suffix` ensures uniqueness for
|
||||
// per-layer tensors, and per-model should only be inserted once.
|
||||
HWY_ASSERT_M(idx_from_name_.insert({name, idx}).second, name.c_str());
|
||||
}
|
||||
|
||||
// Non-layer tensors: registered once per model with an empty suffix. Covers
// the LLM embedding/final norm plus the ViT encoder-level tensors.
void TensorInfoRegistry::AddModelTensors(const ModelConfig& config) {
  const std::string no_suffix;
  Add(no_suffix, {
      .base_name = "c_embedding",
      .source_names = {"embedder/input_embedding"},
      .axes = {0, 1},
      .shape = {config.vocab_size, config.model_dim},
      .min_size = Type::kBF16,
  });
  Add(no_suffix, {
      .base_name = "c_final_norm",
      .source_names = {"final_norm/scale"},
      .axes = {0},
      .shape = {config.model_dim},
      .min_size = Type::kBF16,
  });
  Add(no_suffix, {
      .base_name = "enc_norm_bias",
      .source_names = {"img/Transformer/encoder_norm/bias"},
      .axes = {0},
      .shape = {config.vit_config.model_dim},
      .min_size = Type::kBF16,
  });
  Add(no_suffix, {
      .base_name = "enc_norm_scale",
      .source_names = {"img/Transformer/encoder_norm/scale"},
      .axes = {0},
      .shape = {config.vit_config.model_dim},
      .min_size = Type::kBF16,
  });
  Add(no_suffix, {
      .base_name = "img_emb_bias",
      .source_names = {"img/embedding/bias"},
      .axes = {0},
      .shape = {config.vit_config.model_dim},
      .min_size = Type::kF32,
  });
  // Patch-embedding conv kernel; the three trailing dims (patch_width,
  // patch_width, RGB channels) collapse into the columns on export.
  Add(no_suffix,
      {
          .base_name = "img_emb_kernel",
          .source_names = {"img/embedding/kernel"},
          .axes = {3, 0, 1, 2},
          .shape = {config.vit_config.model_dim, config.vit_config.patch_width,
                    config.vit_config.patch_width, 3},
          .min_size = Type::kBF16,
          .cols_take_extra_dims = true,
      });
  Add(no_suffix,
      {
          .base_name = "img_head_bias",
          .source_names = {"img/head/bias", "embedder/mm_input_projection/b"},
          .axes = {0},
          .shape = {config.model_dim},
          .min_size = Type::kF32,
      });
  Add(no_suffix,
      {
          .base_name = "img_head_kernel",
          .source_names = {"img/head/kernel", "embedder/mm_input_projection/w"},
          .axes = {1, 0},
          .shape = {config.model_dim, config.vit_config.model_dim},
          .min_size = Type::kBF16,
      });
  Add(no_suffix, {
      .base_name = "img_pos_emb",
      .source_names = {"img/pos_embedding"},
      .axes = {0, 1},
      // The leading batch dim of 1 in the checkpoint is dropped here.
      .shape = {/*1,*/ config.vit_config.seq_len,
                config.vit_config.model_dim},
      .min_size = Type::kF32,
  });
  // RMS norm applied to soft tokens prior to pos embedding.
  Add(no_suffix, {
      .base_name = "mm_embed_norm",
      .source_names = {"embedder/mm_soft_embedding_norm/scale"},
      .axes = {0},
      .shape = {config.vit_config.model_dim},
      .min_size = Type::kBF16,
  });
}
|
||||
|
||||
// Adds the tensors for the given image (ViT) encoder layer; tensor names get
// the `img_layer_idx` suffix via `LayerSuffix`.
void TensorInfoRegistry::AddImageLayerTensors(const ModelConfig& config,
                                              const LayerConfig& layer_config,
                                              const size_t img_layer_idx) {
  const std::string suffix = LayerSuffix(img_layer_idx);

  // Vit layers.
  // Attention output projection: kernel (cols gather heads * qkv_dim because
  // cols_take_extra_dims) plus bias.
  Add(suffix, {
                  .base_name = "attn_out_w",
                  .source_names = {"MultiHeadDotProductAttention_0/out/kernel"},
                  .axes = {2, 0, 1},
                  .shape = {config.vit_config.model_dim, layer_config.heads,
                            layer_config.qkv_dim},
                  .min_size = Type::kBF16,
                  .cols_take_extra_dims = true,
              });
  Add(suffix, {
                  .base_name = "attn_out_b",
                  .source_names = {"MultiHeadDotProductAttention_0/out/bias"},
                  .axes = {0},
                  .shape = {config.vit_config.model_dim},
                  .min_size = Type::kF32,
              });
  // Separate Q/K/V kernels. Per the `TensorInfo::concat_names` convention, the
  // first tensor names the concatenation result ("qkv_ein_w") and the tensors
  // to append; the consumed tensors carry a single empty concat name.
  Add(suffix,
      {
          .base_name = "q_ein_w",
          .source_names = {"MultiHeadDotProductAttention_0/query/kernel"},
          .axes = {1, 2, 0},
          .shape = {layer_config.heads, layer_config.qkv_dim,
                    config.vit_config.model_dim},
          .concat_names = {"qkv_ein_w", "k_ein_w", "v_ein_w"},
          .concat_axis = 1,
          .min_size = Type::kBF16,
      });
  Add(suffix, {
                  .base_name = "k_ein_w",
                  .source_names = {"MultiHeadDotProductAttention_0/key/kernel"},
                  .axes = {1, 2, 0},
                  .shape = {layer_config.heads, layer_config.qkv_dim,
                            config.vit_config.model_dim},
                  .concat_names = {""},
                  .min_size = Type::kBF16,
              });
  Add(suffix,
      {
          .base_name = "v_ein_w",
          .source_names = {"MultiHeadDotProductAttention_0/value/kernel"},
          .axes = {1, 2, 0},
          .shape = {layer_config.heads, layer_config.qkv_dim,
                    config.vit_config.model_dim},
          .concat_names = {""},
          .min_size = Type::kBF16,
      });
  // Fused QKV kernel: 3 * qkv_dim along axis 1, matching the concat above.
  Add(suffix, {
                  .base_name = "qkv_ein_w",
                  .source_names = {"MultiHeadDotProductAttention_0/qkv/kernel"},
                  .axes = {1, 2, 0},
                  .shape = {layer_config.heads, 3 * layer_config.qkv_dim,
                            config.vit_config.model_dim},
                  .min_size = Type::kBF16,
              });
  // Q/K/V biases, concatenated into "qkv_ein_b" with the same convention.
  Add(suffix, {
                  .base_name = "q_ein_b",
                  .source_names = {"MultiHeadDotProductAttention_0/query/bias"},
                  .axes = {0, 1},
                  .shape = {layer_config.heads, layer_config.qkv_dim},
                  .concat_names = {"qkv_ein_b", "k_ein_b", "v_ein_b"},
                  .concat_axis = 1,
                  .min_size = Type::kF32,
              });
  // NOTE(review): q uses `heads` but k/v use `kv_heads` here — presumably
  // equal for ViT; confirm against the exporter.
  Add(suffix, {
                  .base_name = "k_ein_b",
                  .source_names = {"MultiHeadDotProductAttention_0/key/bias"},
                  .axes = {0, 1},
                  .shape = {layer_config.kv_heads, layer_config.qkv_dim},
                  .concat_names = {""},
                  .min_size = Type::kF32,
              });
  Add(suffix, {
                  .base_name = "v_ein_b",
                  .source_names = {"MultiHeadDotProductAttention_0/value/bias"},
                  .axes = {0, 1},
                  .shape = {layer_config.kv_heads, layer_config.qkv_dim},
                  .concat_names = {""},
                  .min_size = Type::kF32,
              });
  Add(suffix, {
                  .base_name = "qkv_ein_b",
                  .source_names = {"MultiHeadDotProductAttention_0/qkv/bias"},
                  .axes = {0, 1},
                  .shape = {layer_config.heads + layer_config.kv_heads * 2,
                            layer_config.qkv_dim},
                  .min_size = Type::kF32,
              });
  // MLP block: two dense layers, kernel + bias each.
  Add(suffix,
      {
          .base_name = "linear_0_w",
          .source_names = {"MlpBlock_0/Dense_0/kernel"},
          .axes = {1, 0},
          .shape = {layer_config.ff_hidden_dim, config.vit_config.model_dim},
          .min_size = Type::kBF16,
      });
  Add(suffix, {
                  .base_name = "linear_0_b",
                  .source_names = {"MlpBlock_0/Dense_0/bias"},
                  .axes = {0},
                  .shape = {layer_config.ff_hidden_dim},
                  .min_size = Type::kF32,
              });
  Add(suffix,
      {
          .base_name = "linear_1_w",
          .source_names = {"MlpBlock_0/Dense_1/kernel"},
          .axes = {1, 0},
          .shape = {config.vit_config.model_dim, layer_config.ff_hidden_dim},
          .min_size = Type::kBF16,
      });
  Add(suffix, {
                  .base_name = "linear_1_b",
                  .source_names = {"MlpBlock_0/Dense_1/bias"},
                  .axes = {0},
                  .shape = {config.vit_config.model_dim},
                  .min_size = Type::kF32,
              });
  // LayerNorms: both the un-numbered "encoderblock" path and the per-layer
  // "encoderblock_<idx>" path are accepted as source names.
  Add(suffix,
      {
          .base_name = "ln_0_bias",
          .source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias",
                           "img/Transformer/encoderblock_" +
                               std::to_string(img_layer_idx) +
                               "/LayerNorm_0/bias"},
          .axes = {0},
          .shape = {config.vit_config.model_dim},
          .min_size = Type::kBF16,
      });
  Add(suffix,
      {
          .base_name = "ln_0_scale",
          .source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale",
                           "img/Transformer/encoderblock_" +
                               std::to_string(img_layer_idx) +
                               "/LayerNorm_0/scale"},
          .axes = {0},
          .shape = {config.vit_config.model_dim},
          .min_size = Type::kBF16,
      });
  Add(suffix,
      {
          .base_name = "ln_1_bias",
          .source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias",
                           "img/Transformer/encoderblock_" +
                               std::to_string(img_layer_idx) +
                               "/LayerNorm_1/bias"},
          .axes = {0},
          .shape = {config.vit_config.model_dim},
          .min_size = Type::kBF16,
      });
  Add(suffix,
      {
          .base_name = "ln_1_scale",
          .source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale",
                           "img/Transformer/encoderblock_" +
                               std::to_string(img_layer_idx) +
                               "/LayerNorm_1/scale"},
          .axes = {0},
          .shape = {config.vit_config.model_dim},
          .min_size = Type::kBF16,
      });
}
|
||||
|
||||
// Adds the Griffin recurrent-block tensors for layer `layer_idx`.
void TensorInfoRegistry::AddGriffinLayerTensors(const LayerConfig& layer_config,
                                                const size_t layer_idx) {
  const std::string suffix = LayerSuffix(layer_idx);
  // Linear projections (x and y inputs, final output): kernel + bias each.
  Add(suffix, {
                  .base_name = "gr_lin_x_w",
                  .source_names = {"recurrent_block/linear_x/kernel"},
                  .axes = {1, 0},
                  .shape = {layer_config.griffin_dim, layer_config.griffin_dim},
              });
  Add(suffix, {
                  .base_name = "gr_lin_x_b",
                  .source_names = {"recurrent_block/linear_x/bias"},
                  .axes = {0},
                  .shape = {layer_config.griffin_dim},
                  .min_size = Type::kF32,
              });
  Add(suffix, {
                  .base_name = "gr_lin_y_w",
                  .source_names = {"recurrent_block/linear_y/kernel"},
                  .axes = {1, 0},
                  .shape = {layer_config.griffin_dim, layer_config.griffin_dim},
              });
  Add(suffix, {
                  .base_name = "gr_lin_y_b",
                  .source_names = {"recurrent_block/linear_y/bias"},
                  .axes = {0},
                  .shape = {layer_config.griffin_dim},
                  .min_size = Type::kF32,
              });
  Add(suffix, {
                  .base_name = "gr_lin_out_w",
                  .source_names = {"recurrent_block/linear_out/kernel"},
                  .axes = {1, 0},
                  .shape = {layer_config.griffin_dim, layer_config.griffin_dim},
              });
  Add(suffix, {
                  .base_name = "gr_lin_out_b",
                  .source_names = {"recurrent_block/linear_out/bias"},
                  .axes = {0},
                  .shape = {layer_config.griffin_dim},
                  .min_size = Type::kF32,
              });
  // Temporal 1D convolution weights and bias.
  Add(suffix,
      {
          .base_name = "gr_conv_w",
          .source_names = {"recurrent_block/conv_1d/w"},
          .axes = {0, 1},
          .shape = {layer_config.conv1d_width, layer_config.griffin_dim},
          .min_size = Type::kF32,
      });
  Add(suffix, {
                  .base_name = "gr_conv_b",
                  .source_names = {"recurrent_block/conv_1d/b"},
                  .axes = {0},
                  .shape = {layer_config.griffin_dim},
                  .min_size = Type::kF32,
              });
  // RG-LRU gates: input gate and "a" gate, per head. The first tensor names
  // the concatenation result ("gr_gate_w") and the tensor appended after it;
  // the consumed tensor has a single empty concat name.
  Add(suffix, {
                  .base_name = "gr1_gate_w",
                  .source_names = {"recurrent_block/rg_lru/input_gate/w"},
                  .axes = {0, 2, 1},
                  .shape = {layer_config.heads,
                            layer_config.griffin_dim / layer_config.heads,
                            layer_config.griffin_dim / layer_config.heads},
                  .concat_names = {"gr_gate_w", "gr2_gate_w"},
              });
  Add(suffix, {
                  .base_name = "gr2_gate_w",
                  .source_names = {"recurrent_block/rg_lru/a_gate/w"},
                  .axes = {0, 2, 1},
                  .shape = {layer_config.heads,
                            layer_config.griffin_dim / layer_config.heads,
                            layer_config.griffin_dim / layer_config.heads},
                  .concat_names = {""},
              });
  // Fused gate weights: 2 * heads from the concatenation above.
  Add(suffix, {
                  .base_name = "gr_gate_w",
                  .source_names = {"recurrent_block/rg_lru/gate/w"},
                  .axes = {0, 2, 1},
                  .shape = {2 * layer_config.heads,
                            layer_config.griffin_dim / layer_config.heads,
                            layer_config.griffin_dim / layer_config.heads},
              });
  // Gate biases, concatenated into "gr_gate_b" analogously.
  Add(suffix, {
                  .base_name = "gr1_gate_b",
                  .source_names = {"recurrent_block/rg_lru/input_gate/b"},
                  .axes = {0},
                  .shape = {layer_config.griffin_dim},
                  .concat_names = {"gr_gate_b", "gr2_gate_b"},
                  .min_size = Type::kF32,
              });
  Add(suffix, {
                  .base_name = "gr2_gate_b",
                  .source_names = {"recurrent_block/rg_lru/a_gate/b"},
                  .axes = {0},
                  .shape = {layer_config.griffin_dim},
                  .concat_names = {""},
                  .min_size = Type::kF32,
              });
  // NOTE(review): this reuses gr1_gate_b's source path ("input_gate/b") and
  // lists two axes for a 1-D shape; a fused "rg_lru/gate/b" source (mirroring
  // gr_gate_w) may have been intended — confirm against the exporter.
  Add(suffix, {
                  .base_name = "gr_gate_b",
                  .source_names = {"recurrent_block/rg_lru/input_gate/b"},
                  .axes = {0, 1},
                  .shape = {2 * layer_config.griffin_dim},
                  .min_size = Type::kF32,
              });
  // Learned recurrence parameter; scaled softplus is applied to the data
  // (see `TensorInfo::scaled_softplus`).
  Add(suffix, {
                  .base_name = "gr_a",
                  .source_names = {"recurrent_block/rg_lru/a_param"},
                  .axes = {0},
                  .shape = {layer_config.griffin_dim},
                  .min_size = Type::kF32,
                  .scaled_softplus = true,
              });
}
|
||||
|
||||
// Adds the per-layer tensors common to all model families; Griffin layers
// additionally get the recurrent-block tensors at the end.
void TensorInfoRegistry::AddLayerTensors(const ModelConfig& config,
                                         const LayerConfig& layer_config,
                                         const size_t layer_idx) {
  const std::string suffix = LayerSuffix(layer_idx);
  // RMS norm scales for keys and queries (per qkv_dim).
  Add(suffix, {
                  .base_name = "key_norm",
                  .source_names = {"attn/_key_norm/scale"},
                  .axes = {0},
                  .shape = {layer_config.qkv_dim},
                  .min_size = Type::kBF16,
              });
  Add(suffix, {
                  .base_name = "query_norm",
                  .source_names = {"attn/_query_norm/scale"},
                  .axes = {0},
                  .shape = {layer_config.qkv_dim},
                  .min_size = Type::kBF16,
              });
  // Separate Q and fused KV einsum weights; per the concat_names convention,
  // qkv1_w names the result ("qkv_ein") and qkv2_w is appended (empty concat
  // name marks it consumed).
  Add(suffix, {
                  .base_name = "qkv1_w",
                  .source_names = {"attn/q_einsum/w"},
                  .axes = {0, 2, 1},
                  .shape = {layer_config.heads * layer_config.qkv_dim,
                            config.model_dim},
                  .concat_names = {"qkv_ein", "qkv2_w"},
              });
  Add(suffix, {
                  .base_name = "qkv2_w",
                  .source_names = {"attn/kv_einsum/w"},
                  .axes = {1, 0, 3, 2},
                  .shape = {2 * layer_config.kv_heads * layer_config.qkv_dim,
                            config.model_dim},
                  .concat_names = {""},
              });
  // Separate Q/K/V projections from the "attention_block" export, also
  // concatenated into "qkv_ein".
  Add(suffix, {
                  .base_name = "q_ein",
                  .source_names = {"attention_block/proj_q/kernel"},
                  .axes = {1, 0},
                  .shape = {layer_config.model_dim, layer_config.model_dim},
                  .concat_names = {"qkv_ein", "k_ein", "v_ein"},
              });
  Add(suffix, {
                  .base_name = "k_ein",
                  .source_names = {"attention_block/proj_k/kernel"},
                  .axes = {1, 0},
                  .shape = {layer_config.qkv_dim, layer_config.model_dim},
                  .concat_names = {""},
              });
  Add(suffix, {
                  .base_name = "v_ein",
                  .source_names = {"attention_block/proj_v/kernel"},
                  .axes = {1, 0},
                  .shape = {layer_config.qkv_dim, layer_config.model_dim},
                  .concat_names = {""},
              });
  // Fused QKV projection: rows cover all Q heads plus K and V for each KV
  // head.
  Add(suffix, {
                  .base_name = "qkv_ein",
                  .source_names = {"attn/qkv_einsum/w"},
                  .axes = {1, 0, 3, 2},
                  .shape = {(layer_config.heads + 2 * layer_config.kv_heads) *
                                layer_config.qkv_dim,
                            config.model_dim},
              });
  Add(suffix, {
                  .base_name = "attn_ob",
                  .source_names = {"attention_block/proj_final/bias"},
                  .axes = {0},
                  .shape = {config.model_dim},
                  .min_size = Type::kF32,
              });

  // FFW gating weights. The transpose order depends on whether the export
  // already uses the optimized gating layout.
  Add(suffix, {
                  .base_name = "gating_ein",
                  .source_names = {"mlp/gating_einsum/w", "mlp/gating_einsum",
                                   "mlp_block/ffw_up/w"},
                  .axes = {0, layer_config.optimized_gating ? 1u : 2u,
                           layer_config.optimized_gating ? 2u : 1u},
                  .shape = {2, layer_config.ff_hidden_dim, config.model_dim},
              });
  // Halves of the gating weights; "none" indicates no direct source tensor.
  Add(suffix, {
                  .base_name = "gating1_w",
                  .source_names = {"none"},
                  .axes = {0, layer_config.optimized_gating ? 1u : 2u,
                           layer_config.optimized_gating ? 2u : 1u},
                  .shape = {layer_config.ff_hidden_dim, config.model_dim},
              });
  Add(suffix, {
                  .base_name = "gating2_w",
                  .source_names = {"none"},
                  .axes = {0, layer_config.optimized_gating ? 1u : 2u,
                           layer_config.optimized_gating ? 2u : 1u},
                  .shape = {layer_config.ff_hidden_dim, config.model_dim},
              });
  Add(suffix, {
                  .base_name = "linear_w",
                  .source_names = {"mlp/linear/w", "mlp/linear",
                                   "mlp_block/ffw_down/kernel"},
                  .axes = {1, 0},
                  .shape = {config.model_dim, layer_config.ff_hidden_dim},
              });
  // Pre/post norm scales; the alternate source names cover the
  // "temporal/channel" export naming.
  Add(suffix, {
                  .base_name = "pre_att_ns",
                  .source_names = {"pre_attention_norm/scale",
                                   "temporal_pre_norm/scale"},
                  .axes = {0},
                  .shape = {config.model_dim},
                  .min_size = Type::kBF16,
              });
  Add(suffix,
      {
          .base_name = "pre_ff_ns",
          .source_names = {"pre_ffw_norm/scale", "channel_pre_norm/scale"},
          .axes = {0},
          .shape = {config.model_dim},
          .min_size = Type::kBF16,
      });
  Add(suffix, {
                  .base_name = "post_att_ns",
                  .source_names = {"post_attention_norm/scale"},
                  .axes = {0},
                  .shape = {config.model_dim},
                  .min_size = Type::kBF16,
              });
  Add(suffix, {
                  .base_name = "post_ff_ns",
                  .source_names = {"post_ffw_norm/scale"},
                  .axes = {0},
                  .shape = {config.model_dim},
                  .min_size = Type::kBF16,
              });
  // FFW biases, present only in exports using the "mlp_block" naming.
  Add(suffix, {
                  .base_name = "ffw_gat_b",
                  .source_names = {"mlp_block/ffw_up/b"},
                  .axes = {0},
                  .shape = {2 * layer_config.ff_hidden_dim},
                  .min_size = Type::kF32,
              });
  Add(suffix, {
                  .base_name = "ffw_out_b",
                  .source_names = {"mlp_block/ffw_down/bias"},
                  .axes = {0},
                  .shape = {config.model_dim},
                  .min_size = Type::kF32,
              });
  // Attention output einsum. `preshape` splits possibly-combined source dims
  // before the {0, 2, 1} transpose (see `TensorInfo::preshape`).
  Add(suffix,
      {
          .base_name = "att_ein",
          .source_names = {"attn/attn_vec_einsum/w",
                           "attention_block/proj_final/kernel"},
          .preshape = {layer_config.heads, layer_config.qkv_dim,
                       config.model_dim},
          .axes = {0, 2, 1},
          .shape = {layer_config.heads, config.model_dim, layer_config.qkv_dim},
      });
  // Alternative attention-output layout; no source names, so it is never
  // matched by TensorInfoFromSourcePath.
  Add(suffix,
      {
          .base_name = "att_w",
          .shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
          .cols_take_extra_dims = true,
      });

  if (config.model == Model::GRIFFIN_2B) {
    AddGriffinLayerTensors(layer_config, layer_idx);
  }
}
|
||||
|
||||
// Registers all per-model tensors, then the tensors of every transformer and
// ViT layer in `config`.
TensorInfoRegistry::TensorInfoRegistry(const ModelConfig& config) {
  // Reserve an upper bound on the total number of `Add()` calls made by the
  // `Add*Tensors()` helpers. The bound is deliberately loose so it stays valid
  // if those helpers gain tensors without this being updated; allocating a
  // little extra beats 1.5-2x growth during registration.
  const size_t num_layers = config.layer_configs.size();
  const size_t num_vit_layers = config.vit_config.layer_configs.size();
  tensors_.reserve(10 + 32 * num_layers + 24 * num_vit_layers);

  AddModelTensors(config);
  for (size_t layer = 0; layer < num_layers; ++layer) {
    AddLayerTensors(config, config.layer_configs[layer], layer);
  }
  for (size_t layer = 0; layer < num_vit_layers; ++layer) {
    AddImageLayerTensors(config, config.vit_config.layer_configs[layer], layer);
  }
}
|
||||
|
||||
// Returns a copy of the `TensorInfo` whose source_name is a suffix of `path`,
// with `LayerSuffix(layer_idx)` appended to the looked-up name when
// `layer_idx` >= 0. Returns a default-constructed `TensorInfo` if no source
// name matches.
TensorInfo TensorInfoRegistry::TensorInfoFromSourcePath(const std::string& path,
                                                        int layer_idx) const {
  // True iff `full` ends with `tail` (an empty `tail` always matches).
  const auto ends_with = [](const std::string& full, const std::string& tail) {
    return tail.size() <= full.size() &&
           full.compare(full.size() - tail.size(), tail.size(), tail) == 0;
  };

  for (const TensorInfo& tensor : tensors_) {
    for (const std::string& source_name : tensor.source_names) {
      if (!ends_with(path, source_name)) continue;
      std::string name = tensor.base_name;
      if (layer_idx >= 0) name += LayerSuffix(static_cast<size_t>(layer_idx));
      return TensorInfoFromName(name);
    }
  }
  return TensorInfo();
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -0,0 +1,141 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/types.h" // Type
|
||||
#include "gemma/configs.h"
|
||||
#include "util/basics.h" // Extents2D
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Tensor metadata. This is far more than required to construct the `MatPtr` in
// `LayerWeightsPtrs/WeightsPtrs`; they only use `.shape` via `ExtentsFromInfo`.
// This is also bound to Python and filled by the exporter.
struct TensorInfo {
  // The base name of the tensor without a layer suffix.
  std::string base_name;
  // Strings to match to the end of the name of the tensor in the python model.
  // May be empty for derived tensors that have no direct source.
  std::vector<std::string> source_names;
  // Initial reshape shape. Use only as a last resort when input may have
  // dimensions combined that need to be split before the transpose, as it
  // defeats the post-transpose shape checking. Normally empty.
  std::vector<size_t> preshape;
  // Transpose axes arg. If the input tensor has more dimensions than axes,
  // then leading dimensions are collapsed until the number of axes matches.
  std::vector<size_t> axes;
  // Expected final shape of the tensor after reshape/transpose.
  // Note that this is the shape of the tensor during export,
  // not the shape of the tensor in the sbs file, as the sbs file
  // is restricted to 2D tensors. With few exceptions, the sbs file
  // tensor rows gather all the excess dimensions. See cols_take_extra_dims.
  std::vector<size_t> shape;
  // List of names to concatenate with, used only if multiple tensors are
  // concatenated into one. The first tensor in the concatenation should have
  // concat names thus: The first name is the name of the result, and the
  // tensors with the remaining names are concatenated after this.
  // The remaining tensors to be concatenated should have just a single
  // empty string in concat_names to indicate that they have been consumed.
  std::vector<std::string> concat_names;
  // Axis at which to concatenate.
  size_t concat_axis = 0;
  // The highest permissible compression for this tensor. The default is
  // kNUQ, which provides maximum compression. Other values such as kBF16
  // or kF32 can be used to limit the compression to a specific type.
  // (Despite the name, this bounds the compression type, not a byte count.)
  Type min_size = Type::kNUQ;
  // Whether to apply scaled softplus to the data.
  bool scaled_softplus = false;
  // Whether the columns or the rows take any extra dimensions.
  // If false, then [10, 20, 30] -> [10*20, 30] and [30] -> [1, 30].
  // If true, then [10, 20, 30] -> [10, 20*30] and [30] -> [1, 30].
  bool cols_take_extra_dims = false;
};
|
||||
|
||||
// Collapses/expands the tensor dims into 2D extents, which may be 0, 0 for
|
||||
// not-present tensors such as ViT in a text-only model. Safely handles nullptr
|
||||
// returned from `TensorInfoRegistry::Find`, hence not a member function.
|
||||
static inline Extents2D ExtentsFromInfo(const TensorInfo* tensor) {
|
||||
if (tensor == nullptr) return Extents2D(0, 0);
|
||||
|
||||
size_t cols = tensor->shape.back();
|
||||
size_t rows = 1;
|
||||
if (tensor->cols_take_extra_dims) {
|
||||
rows = tensor->shape[0];
|
||||
for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
|
||||
cols *= tensor->shape[i];
|
||||
}
|
||||
} else { // rows take extra dims
|
||||
for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
|
||||
rows *= tensor->shape[i];
|
||||
}
|
||||
}
|
||||
// Sometimes only one of rows or cols is zero; set both for consistency.
|
||||
if (rows == 0 || cols == 0) rows = cols = 0;
|
||||
return Extents2D(rows, cols);
|
||||
}
|
||||
|
||||
// Returns the per-layer tensor name suffix, e.g. 3 -> "_3".
static inline std::string LayerSuffix(size_t layer_idx) {
  std::string suffix = "_";
  suffix += std::to_string(layer_idx);
  return suffix;
}
|
||||
|
||||
// Returns the tensor base name without any layer suffix: everything before
// the last '_'. A name without '_' is returned unchanged, because `rfind`
// yields npos and `substr(0, npos)` keeps the whole string.
static inline std::string StripLayerSuffix(const std::string& name) {
  const size_t underscore = name.rfind('_');
  return name.substr(0, underscore);
}
|
||||
|
||||
// Holds all `TensorInfo` for a model and retrieves them by (unique) name.
class TensorInfoRegistry {
 public:
  explicit TensorInfoRegistry(const ModelConfig& config);
  ~TensorInfoRegistry() = default;

  // Returns nullptr if not found, otherwise the `TensorInfo` for the given
  // `name`, which either lacks a suffix, or is per-layer and ends with
  // `LayerSuffix(layer_idx)`. Used in `WeightsPtrs/LayerWeightsPtrs`.
  // The returned pointer is owned by the registry; it is invalidated when the
  // registry is destroyed.
  const TensorInfo* Find(const std::string& name) const {
    auto it = idx_from_name_.find(name);
    if (it == idx_from_name_.end()) return nullptr;
    return &tensors_[it->second];
  }

  // Returns a copy of the `TensorInfo` whose name matches the given name, or a
  // default-constructed `TensorInfo` if not found. Destroying
  // `TensorInfoRegistry` afterward will not invalidate the returned value.
  TensorInfo TensorInfoFromName(const std::string& name) const {
    const TensorInfo* info = Find(name);
    if (info == nullptr) return TensorInfo();
    return *info;
  }

  // Returns a copy of the `TensorInfo` whose source_name matches the end of the
  // given path, and whose name ends with the given layer_idx, otherwise a
  // default-constructed `TensorInfo`. Destroying `TensorInfoRegistry`
  // afterward will not invalidate the returned value.
  TensorInfo TensorInfoFromSourcePath(const std::string& path,
                                      int layer_idx) const;

 private:
  // `suffix` is empty (only) for per-model tensors, otherwise `LayerSuffix`.
  void Add(const std::string& suffix, const TensorInfo& info);
  // Registers the per-model (non-layer) tensors.
  void AddModelTensors(const ModelConfig& config);
  // Registers the tensors of transformer layer `layer_idx`.
  void AddLayerTensors(const ModelConfig& config,
                       const LayerConfig& layer_config, size_t layer_idx);
  // Registers the Griffin recurrent-block tensors for `layer_idx`.
  void AddGriffinLayerTensors(const LayerConfig& layer_config,
                              size_t layer_idx);

  // Registers the ViT (image encoder) layer tensors for `img_layer_idx`.
  void AddImageLayerTensors(const ModelConfig& config,
                            const LayerConfig& layer_config,
                            size_t img_layer_idx);

  // Storage for all registered infos; `idx_from_name_` indexes into this.
  std::vector<TensorInfo> tensors_;
  // Includes entries for base name *and* the suffixed name for each layer.
  std::unordered_map<std::string, size_t> idx_from_name_;
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
#include "gemma/tensor_info.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "compression/types.h" // SfpStream
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/mat.h"
|
||||
#include "hwy/base.h" // HWY_ASSERT_M
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
||||
// Tests for all models that each tensor in the model can be found and that the
// TensorInfoRegistry returns the correct shape and name for the tensor.
TEST(TensorInfoRegistryTest, Find) {
  ForEachModel([&](Model model) {
    const ModelConfig config(model, Type::kSFP, ChooseWrapping(model));
    fprintf(stderr, "Testing %s (%s)\n", config.display_name.c_str(),
            config.Specifier().c_str());
    const TensorInfoRegistry tensors(config);
    // Each tensor in the model should be known/found.
    WeightsPtrs weights(config);
    weights.ForEachTensor(nullptr, nullptr, [&tensors](const TensorArgs& t) {
      const TensorInfo* info = tensors.Find(t.mat.Name());
      HWY_ASSERT_M(info, t.mat.Name());
      // Test that the `MatPtr` can be constructed from the TensorInfo,
      // and that the dimensions match.
      // NOTE(review): the second `Find` call duplicates the lookup above and
      // could reuse `info`.
      const MatPtr mat_ptr(t.mat.Name(), Type::kUnknown,
                           ExtentsFromInfo(tensors.Find(t.mat.Name())));
      EXPECT_STREQ(t.mat.Name(), mat_ptr.Name()) << t.mat.Name();
      EXPECT_EQ(t.mat.Rows(), mat_ptr.Rows()) << t.mat.Name();
      EXPECT_EQ(t.mat.Cols(), mat_ptr.Cols()) << t.mat.Name();
    });
  });
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gcpp
|
||||
|
|
@ -21,9 +21,7 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h" // Path
|
||||
#include "compression/shared.h" // PromptWrapping
|
||||
#include "gemma/common.h" // Wrap
|
||||
#include "gemma/configs.h" // PromptWrapping
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
#include "hwy/profiler.h"
|
||||
// copybara:import_next_line:sentencepiece
|
||||
|
|
@ -37,24 +35,20 @@ constexpr bool kShowTokenization = false;
|
|||
class GemmaTokenizer::Impl {
|
||||
public:
|
||||
Impl() = default;
|
||||
explicit Impl(const Path& tokenizer_path) {
|
||||
PROFILER_ZONE("Startup.tokenizer");
|
||||
spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
|
||||
if (!spp_->Load(tokenizer_path.path).ok()) {
|
||||
HWY_ABORT("Failed to load the tokenizer file.");
|
||||
}
|
||||
}
|
||||
// Loads the tokenizer from a serialized proto.
|
||||
explicit Impl(const std::string& tokenizer_proto) {
|
||||
if (tokenizer_proto == kMockTokenizer) return;
|
||||
PROFILER_ZONE("Startup.tokenizer");
|
||||
spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
|
||||
if (!spp_->LoadFromSerializedProto(tokenizer_proto).ok()) {
|
||||
fprintf(stderr, "serialized proto size=%zu.\n", tokenizer_proto.size());
|
||||
HWY_ABORT("Failed to load the tokenizer from serialized proto.");
|
||||
HWY_ABORT("Failed to load tokenizer from %zu byte serialized proto.",
|
||||
tokenizer_proto.size());
|
||||
}
|
||||
}
|
||||
|
||||
std::string Serialize() const { return spp_->serialized_model_proto(); }
|
||||
std::string Serialize() const {
|
||||
return spp_ ? spp_->serialized_model_proto() : kMockTokenizer;
|
||||
}
|
||||
|
||||
bool Encode(const std::string& input,
|
||||
std::vector<std::string>* pieces) const {
|
||||
|
|
@ -82,22 +76,18 @@ class GemmaTokenizer::Impl {
|
|||
std::unique_ptr<sentencepiece::SentencePieceProcessor> spp_;
|
||||
};
|
||||
|
||||
GemmaTokenizer::GemmaTokenizer(const Path& tokenizer_path) {
|
||||
impl_ = std::make_unique<Impl>(tokenizer_path);
|
||||
GemmaTokenizer::GemmaTokenizer(const std::string& tokenizer_proto)
|
||||
: impl_(std::make_unique<Impl>(tokenizer_proto)) {
|
||||
HWY_ASSERT(impl_);
|
||||
}
|
||||
|
||||
// Default suffices, but they must be defined after GemmaTokenizer::Impl.
|
||||
GemmaTokenizer::GemmaTokenizer() = default;
|
||||
GemmaTokenizer::~GemmaTokenizer() = default;
|
||||
GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default;
|
||||
GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default;
|
||||
|
||||
std::string GemmaTokenizer::Serialize() const { return impl_->Serialize(); }
|
||||
|
||||
void GemmaTokenizer::Deserialize(const std::string& tokenizer_proto) {
|
||||
impl_ = std::make_unique<Impl>(tokenizer_proto);
|
||||
}
|
||||
|
||||
bool GemmaTokenizer::Encode(const std::string& input,
|
||||
std::vector<std::string>* pieces) const {
|
||||
return impl_->Encode(input, pieces);
|
||||
|
|
@ -114,57 +104,109 @@ bool GemmaTokenizer::Decode(const std::vector<int>& ids,
|
|||
return impl_->Decode(ids, detokenized);
|
||||
}
|
||||
|
||||
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
|
||||
const ModelInfo& info, size_t pos,
|
||||
std::string& prompt) {
|
||||
Wrap(info, pos, prompt);
|
||||
GemmaChatTemplate::GemmaChatTemplate(const GemmaTokenizer& tokenizer,
|
||||
Model model) {
|
||||
sot_user_.reserve(3);
|
||||
if (!tokenizer.Encode("<start_of_turn>user\n", &sot_user_)) return;
|
||||
sot_model_.reserve(3);
|
||||
HWY_ASSERT(tokenizer.Encode("<start_of_turn>model\n", &sot_model_));
|
||||
eot_.reserve(2);
|
||||
HWY_ASSERT(tokenizer.Encode("<end_of_turn>\n", &eot_));
|
||||
|
||||
std::vector<int> tokens;
|
||||
HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
|
||||
// Both pre-trained and instruction-tuned require BOS as first token.
|
||||
if (pos == 0) {
|
||||
tokens.insert(tokens.begin(), BOS_ID);
|
||||
}
|
||||
|
||||
// PaliGemma separator. The SEP token "\n" is always tokenized separately.
|
||||
if (info.wrapping == PromptWrapping::PALIGEMMA
|
||||
// || info.wrapping == PromptWrapping::GEMMA_VLM
|
||||
) {
|
||||
std::vector<int> sep_tokens;
|
||||
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
|
||||
tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
|
||||
}
|
||||
|
||||
return tokens;
|
||||
HWY_ASSERT(tokenizer.Encode("\n", &pali_sep_));
|
||||
vlm_soi_.reserve(2);
|
||||
HWY_ASSERT(tokenizer.Encode("\n\n<start_of_image>", &vlm_soi_));
|
||||
vlm_eoi_.reserve(2);
|
||||
HWY_ASSERT(tokenizer.Encode("<end_of_image>\n\n", &vlm_eoi_));
|
||||
}
|
||||
|
||||
std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
|
||||
size_t pos, std::vector<int>& tokens,
|
||||
size_t image_batch_size, size_t max_image_batch_size) {
|
||||
HWY_ASSERT(info.wrapping == PromptWrapping::GEMMA_VLM);
|
||||
size_t num_images = hwy::DivCeil(image_batch_size, max_image_batch_size);
|
||||
std::vector<int> GemmaChatTemplate::Apply(size_t pos,
|
||||
const std::vector<int>& ids) const {
|
||||
HWY_ASSERT_M(!sot_user_.empty() && !sot_model_.empty() && !eot_.empty(),
|
||||
"GemmaChatTemplate has not been initialized.");
|
||||
std::vector<int> out;
|
||||
out.reserve(eot_.size() + sot_user_.size() + ids.size() + eot_.size() +
|
||||
sot_model_.size());
|
||||
|
||||
std::vector<int> sep_tokens;
|
||||
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
|
||||
|
||||
std::string begin_image_prompt = "\n\n<start_of_image>";
|
||||
std::vector<int> begin_image_tokens =
|
||||
WrapAndTokenize(tokenizer, info, pos, begin_image_prompt);
|
||||
|
||||
std::string end_image_prompt = "<end_of_image>\n\n";
|
||||
std::vector<int> end_image_tokens =
|
||||
WrapAndTokenize(tokenizer, info, pos, end_image_prompt);
|
||||
|
||||
for (size_t i = 0; i < num_images; ++i) {
|
||||
tokens.insert(tokens.begin(), begin_image_tokens.begin(),
|
||||
begin_image_tokens.end());
|
||||
tokens.insert(tokens.begin() + begin_image_tokens.size(), image_batch_size,
|
||||
-2);
|
||||
tokens.insert(tokens.begin() + begin_image_tokens.size() + image_batch_size,
|
||||
end_image_tokens.begin(), end_image_tokens.end());
|
||||
// Start with BOS, or prepend end_of_turn if this is a continuation.
|
||||
if (pos == 0) {
|
||||
out.push_back(BOS_ID);
|
||||
} else {
|
||||
out.insert(out.cend(), eot_.cbegin(), eot_.cend());
|
||||
}
|
||||
// Start of user turn, user prompt, end of turn; then start of model turn.
|
||||
out.insert(out.cend(), sot_user_.cbegin(), sot_user_.cend());
|
||||
out.insert(out.cend(), ids.cbegin(), ids.cend());
|
||||
out.insert(out.cend(), eot_.cbegin(), eot_.cend());
|
||||
out.insert(out.cend(), sot_model_.cbegin(), sot_model_.cend());
|
||||
return out;
|
||||
}
|
||||
|
||||
return tokens;
|
||||
std::vector<int> GemmaChatTemplate::WrapPali(const std::vector<int>& text_part,
|
||||
size_t image_batch_size) const {
|
||||
HWY_ASSERT_M(!pali_sep_.empty(),
|
||||
"GemmaChatTemplate has not been initialized.");
|
||||
std::vector<int> out;
|
||||
out.reserve(image_batch_size + 1 + text_part.size() + pali_sep_.size());
|
||||
out.resize(image_batch_size, 0);
|
||||
out.push_back(BOS_ID);
|
||||
out.insert(out.cend(), text_part.cbegin(), text_part.cend());
|
||||
out.insert(out.cend(), pali_sep_.cbegin(), pali_sep_.cend());
|
||||
return out;
|
||||
}
|
||||
|
||||
std::vector<int> GemmaChatTemplate::WrapVLM(const std::vector<int>& text_part,
|
||||
size_t image_batch_size) const {
|
||||
HWY_ASSERT_M(!vlm_soi_.empty() && !vlm_eoi_.empty(),
|
||||
"GemmaChatTemplate has not been initialized.");
|
||||
std::vector<int> out;
|
||||
out.reserve(text_part.size() + vlm_soi_.size() + image_batch_size +
|
||||
vlm_eoi_.size());
|
||||
out.insert(out.cend(), text_part.cbegin(), text_part.cend());
|
||||
out.insert(out.cend(), vlm_soi_.cbegin(), vlm_soi_.cend());
|
||||
out.insert(out.cend(), image_batch_size, -2);
|
||||
out.insert(out.cend(), vlm_eoi_.cbegin(), vlm_eoi_.cend());
|
||||
return out;
|
||||
}
|
||||
|
||||
// Text
|
||||
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
|
||||
const GemmaChatTemplate& chat_template,
|
||||
const PromptWrapping wrapping, size_t pos,
|
||||
const std::string& prompt) {
|
||||
std::vector<int> tokens;
|
||||
HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
|
||||
|
||||
switch (wrapping) {
|
||||
case PromptWrapping::GEMMA_IT:
|
||||
case PromptWrapping::GEMMA_VLM:
|
||||
return chat_template.Apply(pos, tokens);
|
||||
default:
|
||||
if (pos == 0) {
|
||||
tokens.insert(tokens.cbegin(), BOS_ID);
|
||||
}
|
||||
return tokens;
|
||||
}
|
||||
}
|
||||
|
||||
// Vision
|
||||
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
|
||||
const GemmaChatTemplate& chat_template,
|
||||
const PromptWrapping wrapping, size_t pos,
|
||||
const std::string& prompt,
|
||||
size_t image_batch_size) {
|
||||
std::vector<int> text_part;
|
||||
HWY_ASSERT(tokenizer.Encode(prompt, &text_part));
|
||||
switch (wrapping) {
|
||||
case PromptWrapping::PALIGEMMA:
|
||||
HWY_ASSERT(pos == 0);
|
||||
return chat_template.WrapPali(text_part, image_batch_size);
|
||||
case PromptWrapping::GEMMA_VLM:
|
||||
return chat_template.Apply(
|
||||
pos, chat_template.WrapVLM(text_part, image_batch_size));
|
||||
default:
|
||||
HWY_ASSERT_M(false, "Current variant does not support vision prompt.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue