diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..ae66414 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,37 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +2b-pt-sfp.sbs filter=lfs diff=lfs merge=lfs -text +tokenizer.spm filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b796814..2052a82 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -17,12 +17,10 @@ jobs: fail-fast: false matrix: # When adding another, also add to copybara's github_check_runs. 
- os: ['ubuntu-latest', 'macos-latest', 'windows-latest', 'ubuntu-20.04'] + os: ['ubuntu-latest', 'macos-latest', 'windows-latest'] build_type: ['Release'] preset: ['make', 'windows'] exclude: - - os: ubuntu-20.04 - preset: windows - os: ubuntu-latest preset: windows - os: macos-latest @@ -62,44 +60,6 @@ jobs: ${{ github.workspace }}/build/gemma ${{ github.workspace }}/build/libgemma.a - - if: matrix.os == 'ubuntu-20.04' - name: Upload build artifacts to Kaggle - uses: pculliton/push-kaggle-dataset@v1.0.0 - env: - KAGGLE_USERNAME: ${{ secrets.KAGGLE_USERNAME }} - KAGGLE_KEY: ${{ secrets.KAGGLE_KEY }} - with: - id: "phillipculliton/gemma-build-artifacts" - files: | - build/gemma - build/_deps/sentencepiece-build/src/libsentencepiece.so.0 - - - if: matrix.os == 'ubuntu-20.04' - name: Create code for new test notebook version - run: | - cat > runner.py << EOF - import subprocess - subprocess.run(["cp", "/kaggle/input/gemma-build-artifacts/gemma", "/kaggle/working"]) - subprocess.run(["chmod", "700", "/kaggle/working/gemma"]) - subprocess.run(["cp", "/kaggle/input/gemma-build-artifacts/_deps/sentencepiece-build/src/libsentencepiece.so.0", "/kaggle/working"]) - output = subprocess.run(["/kaggle/working/gemma", "--tokenizer", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/tokenizer.spm", "--compressed_weights", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/2b-it-sfp.sbs", "--model", "2b-it", "--verbosity", "0", "--max_generated_tokens", "128"], stdout=subprocess.PIPE, input='Write an email to the moon.', encoding='ascii').stdout - assert("write an email to the moon." not in output.lower()); - assert("moon" in output.lower()); - EOF - - - if: matrix.os == 'ubuntu-20.04' - name: Run kaggle test notebook - uses: pculliton/kaggle-action@v1.0.28 - with: - username: ${{ secrets.KAGGLE_USERNAME }} - key: ${{ secrets.KAGGLE_KEY }} - title: GemmaCPP-CI-2 - code_file: runner.py - dataset_sources: "phillipculliton/gemma-build-artifacts" - model_sources: "google/gemma/gemmaCpp/2b-it-sfp/4" - enable_gpu: False - kernel_type: script - bazel: runs-on: ubuntu-latest steps: @@ -116,4 +76,4 @@ jobs: with: path: ~/.cache/bazel key: bazel-${{ runner.os }} - - run: bazel build --cxxopt=-std=c++20 //:all + - run: bazel build --cxxopt=-std=c++20 //:gemma --jobs=10 --show_progress_rate_limit=1 diff --git a/.gitignore b/.gitignore index d4264cb..1c13032 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,25 @@ +# Build directories .cache/ bazel-*/ build-*/ +build/ + +# Python cache python/*/__pycache__ + +# Model files +*.sbs +*.spm +*.data +*.bin +*.weights + +# IDE and editor files +.vscode/ +.idea/ +*.swp +*~ + +# Local development +.env +.env.local \ No newline at end of file diff --git a/.vscode/c_cpp_properties.json b/.vscode/c_cpp_properties.json new file mode 100644 index 0000000..64d3f90 --- /dev/null +++ b/.vscode/c_cpp_properties.json @@ -0,0 +1,15 @@ +{ + "configurations": [ + { + "name": "Linux", + "includePath": [ + "${workspaceFolder}/**" + ], + "defines": [], + "cStandard": "c17", + "cppStandard": "c++17", + "intelliSenseMode": "linux-clang-x64" + } + ], + "version": 4 +} \ No newline at end of file diff --git a/BUILD.bazel b/BUILD.bazel index 8c32631..2628bc3 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -19,7 +19,10 @@ license( # Dual-licensed Apache 2 and 3-clause BSD. 
licenses(["notice"]) -exports_files(["LICENSE"]) +exports_files([ + "LICENSE", + ".github/workflows/build.yml", +]) cc_library( name = "basics", @@ -29,6 +32,16 @@ cc_library( ], ) +cc_library( + name = "args", + hdrs = ["util/args.h"], + deps = [ + ":basics", + "//io", # Path + "@highway//:hwy", + ], +) + # Split from :threading to break a circular dependency with :allocator. cc_library( name = "topology", @@ -59,6 +72,7 @@ cc_library( hdrs = ["util/threading.h"], deps = [ ":allocator", + ":args", ":basics", ":topology", # Placeholder for container detection, do not remove @@ -68,14 +82,28 @@ cc_library( ], ) +cc_library( + name = "threading_context", + srcs = ["util/threading_context.cc"], + hdrs = ["util/threading_context.h"], + deps = [ + ":allocator", + ":args", + ":basics", + ":threading", + ":topology", + "@highway//:hwy", + "@highway//:profiler", + ], +) + cc_test( name = "threading_test", srcs = ["util/threading_test.cc"], deps = [ - ":allocator", ":basics", - ":threading", - "@googletest//:gtest_main", + ":threading_context", + "@googletest//:gtest_main", # buildcleaner: keep "@highway//:auto_tune", "@highway//:hwy", "@highway//:hwy_test_util", @@ -97,6 +125,124 @@ cc_library( ], ) +cc_library( + name = "configs", + srcs = ["gemma/configs.cc"], + hdrs = ["gemma/configs.h"], + deps = [ + ":basics", + "//compression:types", + "//io", + "//io:fields", + "@highway//:hwy", # base.h + ], +) + +cc_test( + name = "configs_test", + srcs = ["gemma/configs_test.cc"], + deps = [ + ":configs", + "@googletest//:gtest_main", # buildcleaner: keep + "//compression:types", + "//io:fields", + ], +) + +cc_library( + name = "tensor_info", + srcs = ["gemma/tensor_info.cc"], + hdrs = ["gemma/tensor_info.h"], + deps = [ + ":basics", + ":configs", + "//compression:types", + ], +) + +cc_library( + name = "mat", + srcs = ["util/mat.cc"], + hdrs = ["util/mat.h"], + deps = [ + ":allocator", + ":basics", + ":tensor_info", + ":threading_context", + "//compression:types", + "//io:fields", + "@highway//:hwy", + "@highway//:profiler", + ], +) + +cc_library( + name = "tokenizer", + srcs = ["gemma/tokenizer.cc"], + hdrs = ["gemma/tokenizer.h"], + deps = [ + ":configs", + "@highway//:hwy", + "@highway//:profiler", + "@com_google_sentencepiece//:sentencepiece_processor", + ], +) + +cc_library( + name = "model_store", + srcs = ["gemma/model_store.cc"], + hdrs = ["gemma/model_store.h"], + deps = [ + ":allocator", + ":basics", + ":configs", + ":mat", + ":tensor_info", + ":threading_context", + ":tokenizer", + "//compression:types", + "//io", + "//io:blob_store", + "//io:fields", + "@highway//:hwy", + "@highway//:thread_pool", + ], +) + +cc_library( + name = "weights", + srcs = ["gemma/weights.cc"], + hdrs = ["gemma/weights.h"], + deps = [ + ":configs", + ":gemma_args", + ":mat", + ":matmul", + ":model_store", + ":tensor_info", + ":threading_context", + "//compression:compress", + "//io:blob_store", + "@highway//:hwy", + "@highway//:profiler", + "@highway//:thread_pool", + ], +) + +cc_test( + name = "tensor_info_test", + srcs = ["gemma/tensor_info_test.cc"], + deps = [ + ":configs", + ":mat", + ":tensor_info", + ":weights", + "@googletest//:gtest_main", # buildcleaner: keep + "//compression:compress", + "@highway//:hwy", # aligned_allocator.h + ], +) + # For building all tests in one command, so we can test several. 
test_suite( name = "ops_tests", @@ -104,34 +250,75 @@ test_suite( ) cc_library( - name = "ops", + name = "matmul", + srcs = ["ops/matmul.cc"], + hdrs = ["ops/matmul.h"], + textual_hdrs = ["ops/matmul-inl.h"], + deps = [ + ":allocator", + ":basics", + ":mat", + ":threading_context", + "//compression:compress", + "@highway//:bit_set", + "@highway//:hwy", + "@highway//:nanobenchmark", + "@highway//:profiler", + ], +) + +cc_library( + name = "matmul_static", srcs = [ - "ops/matmul.cc", + # single-file build time is ~30sec for msan, hence shard. + "ops/matmul_static_bf16.cc", + "ops/matmul_static_f32.cc", + "ops/matmul_static_nuq.cc", + "ops/matmul_static_sfp.cc", ], hdrs = [ - "ops/matmul.h", - "ops/ops.h", + "ops/matmul_static.h", ], + textual_hdrs = [ + "ops/matmul_static-inl.h", + "ops/matmul-inl.h", + ], + deps = [ + ":allocator", + ":basics", + ":mat", + ":matmul", + ":threading_context", + "//compression:compress", + "//compression:types", + "@highway//:hwy", + "@highway//:profiler", + "@highway//:timer", + ], +) + +cc_library( + name = "ops", + hdrs = ["ops/ops.h"], textual_hdrs = [ "ops/dot-inl.h", "ops/sum-inl.h", "ops/fp_arith-inl.h", - "ops/matmul-inl.h", "ops/matvec-inl.h", "ops/ops-inl.h", ], deps = [ ":allocator", ":basics", - ":threading", - ":topology", + ":mat", + ":matmul", + ":matmul_static", + ":threading_context", "//compression:compress", "@highway//:algo", - "@highway//:bit_set", "@highway//:hwy", "@highway//:math", "@highway//:matvec", - "@highway//:nanobenchmark", "@highway//:profiler", "@highway//:thread_pool", "@highway//hwy/contrib/sort:vqsort", @@ -143,15 +330,15 @@ cc_test( size = "small", timeout = "long", srcs = ["ops/dot_test.cc"], + linkstatic = True, local_defines = ["HWY_IS_TEST"], # for test_suite. tags = ["ops_tests"], deps = [ ":allocator", - ":app", ":ops", ":test_util", - ":threading", + ":threading_context", "@googletest//:gtest_main", # buildcleaner: keep "//compression:compress", "//compression:test_util", @@ -167,20 +354,23 @@ cc_test( cc_test( name = "ops_test", size = "small", - timeout = "eternal", + timeout = "long", srcs = ["ops/ops_test.cc"], + linkstatic = True, local_defines = ["HWY_IS_TEST"], # for test_suite. tags = ["ops_tests"], deps = [ ":allocator", - ":app", - ":common", + ":basics", + ":configs", + ":gemma_lib", + ":mat", ":ops", ":test_util", - ":threading", + ":threading_context", "@googletest//:gtest_main", # buildcleaner: keep - "//compression:compress", + "//compression:types", "@highway//:hwy", "@highway//:hwy_test_util", "@highway//:nanobenchmark", #buildcleaner: keep @@ -192,11 +382,14 @@ cc_test( size = "small", timeout = "long", srcs = ["ops/gemma_matvec_test.cc"], + linkstatic = True, local_defines = ["HWY_IS_TEST"], # for test_suite. tags = ["ops_tests"], deps = [ + ":mat", ":ops", + ":threading_context", "@googletest//:gtest_main", # buildcleaner: keep "//compression:compress", "@highway//:hwy", @@ -210,18 +403,23 @@ cc_test( size = "small", timeout = "long", srcs = ["ops/matmul_test.cc"], + linkstatic = True, local_defines = ["HWY_IS_TEST"], # for test_suite. 
tags = ["ops_tests"], deps = [ - ":allocator", ":basics", + ":mat", + ":matmul", + ":matmul_static", ":ops", - ":threading", + ":threading_context", "@googletest//:gtest_main", # buildcleaner: keep "//compression:compress", + "//compression:test_util", "@highway//:hwy", "@highway//:hwy_test_util", + "@highway//:nanobenchmark", "@highway//:thread_pool", ], ) @@ -231,6 +429,7 @@ cc_test( size = "small", timeout = "long", srcs = ["ops/bench_matmul.cc"], + linkstatic = True, local_defines = ["HWY_IS_TEST"], tags = [ "manual", @@ -238,12 +437,12 @@ cc_test( "ops_tests", # for test_suite. ], deps = [ - ":allocator", ":basics", - ":ops", - ":threading", + ":matmul", + ":threading_context", "@googletest//:gtest_main", # buildcleaner: keep "//compression:compress", + "//compression:test_util", "@highway//:hwy", "@highway//:hwy_test_util", "@highway//:nanobenchmark", @@ -252,101 +451,47 @@ cc_test( ], ) -cc_library( - name = "common", - srcs = [ - "gemma/common.cc", - "gemma/configs.cc", - "gemma/tensor_index.cc", - ], - hdrs = [ - "gemma/common.h", - "gemma/configs.h", - "gemma/tensor_index.h", - ], - deps = [ - "//compression:fields", - "//compression:sfp", - "@highway//:hwy", # base.h - "@highway//:thread_pool", - ], -) - -cc_test( - name = "configs_test", - srcs = ["gemma/configs_test.cc"], - deps = [ - ":common", - "@googletest//:gtest_main", - "@highway//:hwy", - ], -) - -cc_test( - name = "tensor_index_test", - srcs = ["gemma/tensor_index_test.cc"], - deps = [ - ":basics", - ":common", - ":weights", - "@googletest//:gtest_main", - "//compression:compress", - "@highway//:hwy", - ], -) - -cc_library( - name = "weights", - srcs = ["gemma/weights.cc"], - hdrs = ["gemma/weights.h"], - deps = [ - ":common", - "//compression:blob_store", - "//compression:compress", - "//compression:io", - "@highway//:hwy", - "@highway//:profiler", - "@highway//:stats", - "@highway//:thread_pool", - ], -) - -cc_library( - name = "tokenizer", - srcs = ["gemma/tokenizer.cc"], - hdrs = ["gemma/tokenizer.h"], - deps = [ - ":common", - "//compression:io", - "//compression:sfp", - "@highway//:hwy", - "@highway//:profiler", - "@com_google_sentencepiece//:sentencepiece_processor", - ], -) - cc_library( name = "kv_cache", srcs = ["gemma/kv_cache.cc"], hdrs = ["gemma/kv_cache.h"], deps = [ - ":common", + ":basics", + ":configs", + ":gemma_args", + ":mat", "@highway//:hwy", ], ) +cc_library( + name = "gemma_args", + hdrs = ["gemma/gemma_args.h"], + deps = [ + ":args", + ":basics", + ":mat", + ":matmul", + "//io", + "@highway//:hwy", + "@highway//:profiler", + ], +) + cc_library( name = "gemma_lib", srcs = [ + "gemma/attention.cc", "gemma/gemma.cc", - "gemma/instantiations/bf16.cc", - "gemma/instantiations/f32.cc", - "gemma/instantiations/nuq.cc", - "gemma/instantiations/sfp.cc", + "gemma/griffin.cc", + "gemma/vit.cc", ], hdrs = [ "gemma/activations.h", + "gemma/attention.h", "gemma/gemma.h", + "gemma/griffin.h", + "gemma/vit.h", ], exec_properties = { # Avoid linker OOMs when building with sanitizer instrumentation. 
@@ -354,23 +499,26 @@ cc_library( }, textual_hdrs = [ "gemma/gemma-inl.h", - # Placeholder for internal file2, do not remove, ], deps = [ ":allocator", ":basics", - ":common", - ":ops", - ":tokenizer", + ":configs", + ":gemma_args", ":kv_cache", - ":weights", + ":mat", + ":matmul", + ":model_store", + ":ops", ":threading", + ":threading_context", + ":weights", "//compression:compress", - "//compression:io", - "//compression:sfp", + "//compression:types", + "//io:blob_store", + "//io", "//paligemma:image", "@highway//:hwy", - "@highway//:bit_set", "@highway//:nanobenchmark", # timer "@highway//:profiler", "@highway//:thread_pool", @@ -382,35 +530,9 @@ cc_library( srcs = ["evals/cross_entropy.cc"], hdrs = ["evals/cross_entropy.h"], deps = [ - ":common", ":gemma_lib", ":ops", - "@highway//:hwy", - ], -) - -cc_library( - name = "args", - hdrs = ["util/args.h"], - deps = [ - ":basics", - "//compression:io", - "@highway//:hwy", - ], -) - -cc_library( - name = "app", - hdrs = ["util/app.h"], - deps = [ - ":args", - ":basics", - ":common", - ":gemma_lib", - ":ops", - ":threading", - "//compression:io", - "//compression:sfp", + "//compression:types", "@highway//:hwy", ], ) @@ -420,20 +542,49 @@ cc_library( srcs = ["evals/benchmark_helper.cc"], hdrs = ["evals/benchmark_helper.h"], deps = [ - ":app", - ":args", - ":common", + ":configs", ":cross_entropy", + ":gemma_args", ":gemma_lib", - ":kv_cache", + ":matmul", ":ops", - ":threading", - # Placeholder for internal dep, do not remove., + ":threading_context", + ":tokenizer", "@google_benchmark//:benchmark", "//compression:compress", "@highway//:hwy", "@highway//:nanobenchmark", - "@highway//:topology", + "@highway//:profiler", + ], +) + +cc_library( + name = "gemma_shared_lib", + srcs = [ + "gemma/bindings/c_api.cc", + "gemma/bindings/context.cc", + ], + hdrs = [ + "gemma/bindings/c_api.h", + "gemma/bindings/context.h", + ], + exec_properties = { + # Avoid linker OOMs when building with sanitizer instrumentation. 
+ "mem": "28g", + }, + deps = [ + ":benchmark_helper", + ":gemma_args", + ":gemma_lib", + ":kv_cache", + ":matmul", + ":threading", + ":threading_context", + ":tokenizer", + "//paligemma:image", + "@highway//:hwy", + "@highway//:profiler", + "@highway//:timer", ], ) @@ -449,9 +600,10 @@ cc_test( ], deps = [ ":benchmark_helper", - ":common", + ":configs", ":gemma_lib", - "@googletest//:gtest_main", + "@googletest//:gtest_main", # buildcleaner: keep + "//io", "@highway//:hwy", "@highway//:hwy_test_util", ], @@ -460,6 +612,7 @@ cc_test( cc_test( name = "gemma_batch_bench", srcs = ["evals/gemma_batch_bench.cc"], + linkstatic = True, # Requires model files tags = [ "local", @@ -468,12 +621,12 @@ cc_test( ], deps = [ ":benchmark_helper", - ":common", ":gemma_lib", - ":tokenizer", - "@googletest//:gtest_main", + "@googletest//:gtest_main", # buildcleaner: keep "@highway//:hwy", "@highway//:hwy_test_util", + "@highway//:nanobenchmark", + "@highway//:profiler", ], ) @@ -481,15 +634,13 @@ cc_binary( name = "gemma", srcs = ["gemma/run.cc"], deps = [ - ":app", ":args", ":benchmark_helper", - ":common", + ":gemma_args", ":gemma_lib", - ":ops", - ":threading", - # Placeholder for internal dep, do not remove., - "//compression:sfp", + ":matmul", + ":tokenizer", + "//compression:types", "//paligemma:image", "@highway//:hwy", "@highway//:profiler", @@ -502,22 +653,15 @@ cc_binary( deps = [ ":args", ":benchmark_helper", - ":common", ":cross_entropy", ":gemma_lib", - "//compression:io", + "//io", "@highway//:hwy", "@highway//:nanobenchmark", "@nlohmann_json//:json", ], ) -cc_library( - name = "benchmark_prompts", - hdrs = ["evals/prompts.h"], - deps = ["@highway//:hwy"], -) - cc_binary( name = "benchmarks", srcs = [ @@ -526,7 +670,6 @@ cc_binary( ], deps = [ ":benchmark_helper", - ":benchmark_prompts", "@google_benchmark//:benchmark", "@highway//:hwy", # base.h ], @@ -534,14 +677,12 @@ cc_binary( cc_binary( name = "debug_prompt", - srcs = [ - "evals/debug_prompt.cc", - ], + srcs = ["evals/debug_prompt.cc"], deps = [ ":args", ":benchmark_helper", ":gemma_lib", - "//compression:io", + "//io", "@highway//:hwy", "@nlohmann_json//:json", ], @@ -554,158 +695,9 @@ cc_binary( ":args", ":benchmark_helper", ":gemma_lib", - "//compression:io", + "//io", "@highway//:hwy", "@highway//:profiler", - "@highway//:thread_pool", "@nlohmann_json//:json", ], ) - -cc_library( - name = "prompt", - hdrs = ["backprop/prompt.h"], - deps = [], -) - -cc_library( - name = "sampler", - hdrs = ["backprop/sampler.h"], - deps = [ - ":prompt", - ], -) - -cc_library( - name = "backprop", - srcs = [ - "backprop/backward.cc", - "backprop/forward.cc", - ], - hdrs = [ - "backprop/activations.h", - "backprop/backward.h", - "backprop/forward.h", - ], - textual_hdrs = [ - "backprop/backward-inl.h", - "backprop/forward-inl.h", - ], - deps = [ - ":allocator", - ":common", - ":ops", - ":prompt", - ":weights", - "//compression:compress", - "@highway//:dot", - "@highway//:hwy", # base.h - "@highway//:thread_pool", - ], -) - -cc_library( - name = "backprop_scalar", - hdrs = [ - "backprop/activations.h", - "backprop/backward_scalar.h", - "backprop/common_scalar.h", - "backprop/forward_scalar.h", - ], - deps = [ - ":common", - ":prompt", - ":weights", - "//compression:compress", - "@highway//:hwy", - ], -) - -cc_test( - name = "backward_scalar_test", - size = "large", - srcs = [ - "backprop/backward_scalar_test.cc", - "backprop/test_util.h", - ], - deps = [ - ":backprop_scalar", - ":common", - ":prompt", - ":sampler", - ":weights", - 
"@googletest//:gtest_main", - "//compression:compress", - "@highway//:thread_pool", - ], -) - -cc_test( - name = "backward_test", - size = "large", - srcs = [ - "backprop/backward_test.cc", - "backprop/test_util.h", - ], - exec_properties = { - # Avoid linker OOMs when building with sanitizer instrumentation. - "mem": "28g", - }, - deps = [ - ":allocator", - ":backprop", - ":backprop_scalar", - ":common", - ":ops", - ":prompt", - ":sampler", - ":threading", - ":weights", - "@googletest//:gtest_main", - "//compression:compress", - "@highway//:hwy", - "@highway//:hwy_test_util", - "@highway//:thread_pool", - ], -) - -cc_library( - name = "optimizer", - srcs = ["backprop/optimizer.cc"], - hdrs = ["backprop/optimizer.h"], - deps = [ - ":allocator", - ":common", - ":weights", - "//compression:compress", - "@highway//:hwy", - "@highway//:thread_pool", - ], -) - -cc_test( - name = "optimize_test", - srcs = [ - "backprop/optimize_test.cc", - ], - exec_properties = { - # Avoid linker OOMs when building with sanitizer instrumentation. - "mem": "28g", - }, - deps = [ - ":allocator", - ":backprop", - ":basics", - ":common", - ":gemma_lib", - ":ops", - ":optimizer", - ":prompt", - ":sampler", - ":threading", - ":weights", - "@googletest//:gtest_main", - "//compression:sfp", - "@highway//:thread_pool", - ], -) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1737c2d..ef2f2c8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06 EXCLUDE_FROM_ALL) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9414b48aeec251b69e6cadbfa42bebb5ddae1c34 EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(highway) ## Note: absl needs to be installed by sentencepiece. 
This will only happen if @@ -39,58 +39,54 @@ set(BENCHMARK_ENABLE_GTEST_TESTS OFF) FetchContent_Declare(benchmark GIT_REPOSITORY https://github.com/google/benchmark.git GIT_TAG v1.8.2 EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(benchmark) +# Base source files set(SOURCES - compression/blob_store.cc - compression/blob_store.h + compression/compress-inl.h compression/compress.cc compression/compress.h - compression/compress-inl.h - compression/fields.cc - compression/fields.h - compression/io_win.cc - compression/io.cc - compression/io.h compression/nuq-inl.h compression/sfp-inl.h - compression/shared.h + compression/types.h compression/test_util-inl.h - backprop/activations.h - backprop/backward.cc - backprop/backward.h - backprop/backward-inl.h - backprop/backward_scalar.h - backprop/common_scalar.h - backprop/forward.cc - backprop/forward.h - backprop/forward-inl.h - backprop/forward_scalar.h - backprop/optimizer.cc - backprop/optimizer.h evals/benchmark_helper.cc evals/benchmark_helper.h evals/cross_entropy.cc evals/cross_entropy.h gemma/activations.h - gemma/common.cc - gemma/common.h + gemma/attention.cc + gemma/attention.h gemma/configs.cc gemma/configs.h + gemma/gemma_args.h gemma/gemma-inl.h gemma/gemma.cc gemma/gemma.h - gemma/instantiations/bf16.cc - gemma/instantiations/f32.cc - gemma/instantiations/nuq.cc - gemma/instantiations/sfp.cc + gemma/griffin.cc + gemma/griffin.h gemma/kv_cache.cc gemma/kv_cache.h - gemma/tensor_index.cc - gemma/tensor_index.h + gemma/model_store.cc + gemma/model_store.h + gemma/tensor_info.cc + gemma/tensor_info.h gemma/tokenizer.cc gemma/tokenizer.h + gemma/vit.cc + gemma/vit.h gemma/weights.cc gemma/weights.h + io/blob_store.cc + io/blob_store.h + io/fields.cc + io/fields.h + io/io_win.cc + io/io.cc + io/io.h ops/dot-inl.h + ops/matmul_static_bf16.cc + ops/matmul_static_f32.cc + ops/matmul_static_nuq.cc + ops/matmul_static_sfp.cc ops/matmul-inl.h ops/matmul.cc ops/matmul.h @@ -102,15 +98,28 @@ set(SOURCES paligemma/image.h util/allocator.cc util/allocator.h - util/app.h - util/args.h util/basics.h + util/mat.cc + util/mat.h util/test_util.h + util/threading_context.cc + util/threading_context.h util/threading.cc util/threading.h util/topology.cc util/topology.h +) + +# Add C API sources only when building DLL +if(BUILD_GEMMA_DLL) + list(APPEND SOURCES + gemma/bindings/context.h + gemma/bindings/context.cc + gemma/bindings/c_api.h + gemma/bindings/c_api.cc ) + message(STATUS "Including C API files for DLL build") +endif() if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE "Release") @@ -131,6 +140,33 @@ target_compile_definitions(libgemma PRIVATE $<$:_CRT_SECURE target_compile_options(libgemma PRIVATE $<$:-Wno-deprecated-declarations>) install(TARGETS libgemma DESTINATION lib) +# Shared library target for C# interop +if(BUILD_GEMMA_DLL) + add_library(gemma_shared SHARED ${SOURCES}) +set_property(TARGET gemma_shared PROPERTY CXX_STANDARD 17) +set_target_properties(gemma_shared PROPERTIES + PREFIX "" + OUTPUT_NAME "gemma" +) +set_property(TARGET gemma_shared PROPERTY POSITION_INDEPENDENT_CODE ON) +target_include_directories(gemma_shared PUBLIC ./) +target_link_libraries(gemma_shared PRIVATE + $ + $ + $ +) +target_include_directories(gemma_shared PUBLIC ${sentencepiece_SOURCE_DIR}) +target_compile_definitions(gemma_shared + PRIVATE + GEMMA_EXPORTS + $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX> +) +target_compile_options(gemma_shared PRIVATE $<$:-Wno-deprecated-declarations>) +install(TARGETS gemma_shared DESTINATION lib) +install(FILES gemma/c_api.h DESTINATION 
include/gemma)
+install(FILES gemma/GemmaInterop.cs DESTINATION include/gemma)
+endif()
+
 # Executable Target
 add_executable(gemma gemma/run.cc)

@@ -154,17 +190,14 @@ enable_testing()
 include(GoogleTest)

 set(GEMMA_TEST_FILES
-  backprop/backward_scalar_test.cc
-  backprop/backward_test.cc
-  backprop/optimize_test.cc
-  compression/blob_store_test.cc
   compression/compress_test.cc
   compression/distortion_test.cc
-  compression/fields_test.cc
   compression/nuq_test.cc
   compression/sfp_test.cc
   evals/gemma_test.cc
-  gemma/tensor_index_test.cc
+  gemma/tensor_info_test.cc
+  io/blob_store_test.cc
+  io/fields_test.cc
   ops/bench_matmul.cc
   ops/dot_test.cc
   ops/gemma_matvec_test.cc
@@ -197,8 +230,5 @@ endif()  # GEMMA_ENABLE_TESTS

 ## Tools

-add_executable(compress_weights compression/compress_weights.cc)
-target_link_libraries(compress_weights libgemma hwy hwy_contrib)
-
-add_executable(migrate_weights compression/migrate_weights.cc)
+add_executable(migrate_weights io/migrate_weights.cc)
 target_link_libraries(migrate_weights libgemma hwy hwy_contrib)
diff --git a/CMakePresets.json b/CMakePresets.json
index 5fe13c8..a34b5bf 100644
--- a/CMakePresets.json
+++ b/CMakePresets.json
@@ -31,6 +31,24 @@
         "lhs": "${hostSystemName}",
         "rhs": "Windows"
       }
+    },
+    {
+      "name": "windows-dll",
+      "inherits": "__defaults__",
+      "displayName": "Windows DLL",
+      "description": "Visual Studio 2022 with Clang/LLVM frontend (DLL build)",
+      "generator": "Visual Studio 17 2022",
+      "toolset": "ClangCL",
+      "condition": {
+        "type": "equals",
+        "lhs": "${hostSystemName}",
+        "rhs": "Windows"
+      },
+      "cacheVariables": {
+        "BUILD_SHARED_LIBS": "OFF",
+        "CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS": "ON",
+        "BUILD_GEMMA_DLL": "ON"
+      }
     }
   ],
   "buildPresets": [
@@ -54,6 +72,15 @@
       "displayName": "Windows",
       "configuration": "Release",
       "configurePreset": "windows"
+    },
+    {
+      "name": "windows-dll",
+      "displayName": "Windows DLL",
+      "configuration": "Release",
+      "configurePreset": "windows-dll",
+      "targets": [
+        "gemma_shared"
+      ]
     }
   ]
 }
diff --git a/DEVELOPERS.md b/DEVELOPERS.md
index fdebad4..5d70fdb 100644
--- a/DEVELOPERS.md
+++ b/DEVELOPERS.md
@@ -96,21 +96,10 @@
 https://github.com/keras-team/keras-nlp/blob/master/tools/gemma/export_gemma_to_

 From Pytorch, use the following script to generate uncompressed weights:
 https://github.com/google/gemma.cpp/blob/dev/compression/convert_weights.py

-Then run `compression/compress_weights.cc` (Bazel target
-`compression:compress_weights`), specifying the resulting file as `--weights`
-and the desired .sbs name as the `--compressed_weights`.
+For PaliGemma, use `python/convert_from_safetensors` to create an SBS file
+directly, as sketched below.
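+
+A minimal sketch of the conversion invocation (the `[...]` path is a
+placeholder for your checkpoint; see the README FAQ on converting fine-tunes
+for environment setup):
+
+```sh
+python3 python/convert_from_safetensors.py \
+  --load_path [...].safetensors.index.json
+```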
-
-## Compile-Time Flags (Advanced)
-
-There are several compile-time flags to be aware of (note these may or may not
-be exposed to the build system):
-
-- `GEMMA_MAX_SEQ_LEN` : Sets maximum sequence length to preallocate for the KV
-  Cache. The default is 4096 tokens but can be overridden. This is not exposed
-  through `CMakeLists.txt` yet.
-
-In the medium term this will likely be deprecated in favor of handling options
-at runtime - dynamically resizing the KV cache as needed.
+For other models, `gemma_export_main.py` is not yet open sourced.

 ## Using gemma.cpp as a Library (Advanced)

@@ -164,7 +153,7 @@
 constrained decoding type of use cases where you want to force the generation
 to fit a grammar. If you're not doing this, you can send an empty lambda or
 `std::function` as a no-op which is what `run.cc` does.

-### `Transformer()` implements the inference (i.e. `forward()` method in PyTorch or Jax) computation of the neural network
+### `Transformer()` implements inference (i.e. `forward()` in PyTorch or Jax)

 For high-level applications, you might only call `model.Generate()` and never
 interact directly with the neural network, but if you're doing something a bit
@@ -172,9 +161,6 @@
 more custom you can call transformer which performs a single inference
 operation on a single token and mutates the Activations and the KVCache through
 the neural network computation.

-Note that an experimental backward pass is available in backprop/, which may be
-useful for fine tuning.
-
 ### For low level operations, defining new architectures, call `ops.h` functions directly

 You use `ops.h` if you're writing other NN architectures or modifying the
diff --git a/MODULE.bazel b/MODULE.bazel
index 77690fa..95fb5cc 100644
--- a/MODULE.bazel
+++ b/MODULE.bazel
@@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
 # Require a more recent version.
 git_override(
     module_name = "highway",
-    commit = "c5bebf84ad01edec97e336f5c97ca4e0df6b4d06",
+    commit = "9414b48aeec251b69e6cadbfa42bebb5ddae1c34",
     remote = "https://github.com/google/highway",
 )

@@ -71,6 +71,7 @@ pip.parse(
     requirements_lock = "//compression/python:requirements.txt",
 )
 use_repo(pip, "compression_deps")
+
 pip.parse(
     hub_name = "python_deps",
     python_version = "3.11",
diff --git a/README.md b/README.md
index e9a6745..b389934 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@ foundation models from Google. For additional information about Gemma, see
 [ai.google.dev/gemma](https://ai.google.dev/gemma).

 Model weights, including gemma.cpp specific artifacts, are
-[available on kaggle](https://www.kaggle.com/models/google/gemma).
+[available on kaggle](https://www.kaggle.com/models/google/gemma-2).

 ## Who is this project for?

@@ -18,8 +18,8 @@ deployment-oriented C++ inference runtimes, which are not designed for
 experimentation, and Python-centric ML research frameworks, which abstract away
 low-level computation through compilation.

-gemma.cpp provides a minimalist implementation of Gemma-1, Gemma-2, Gemma-3, and
-PaliGemma models, focusing on simplicity and directness rather than full
+gemma.cpp provides a minimalist implementation of Gemma-2, Gemma-3, and
+PaliGemma-2 models, focusing on simplicity and directness rather than full
 generality. This is inspired by vertically-integrated model implementations
 such as [ggml](https://github.com/ggerganov/ggml),
 [llama.c](https://github.com/karpathy/llama2.c), and
@@ -45,9 +45,41 @@ this invite link](https://discord.gg/H5jCBAWxAe). This project follows
 [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).

-*Active development is currently done on the `dev` branch. Please open pull
-requests targeting `dev` branch instead of `main`, which is intended to be more
-stable.*
+> [!NOTE]
+> Active development is currently done on the `dev` branch. Please open
+> pull requests targeting the `dev` branch instead of `main`, which is intended
+> to be more stable.
+
+## What's inside?
+
+- LLM
+
+    - CPU-only inference for: Gemma 2-3, Griffin (SSM), PaliGemma 2.
+    - Sampling with TopK and temperature.
+    - Backward pass (VJP) and Adam optimizer for Gemma research.
+
+- Optimizations
+
+    - Mixed-precision (fp8, bf16, fp32, fp64) GEMM:
+        - Designed for BF16 instructions; can efficiently emulate them where
+          unavailable.
+        - Automatic runtime autotuning of 7 parameters per matrix shape.
+    - Weight compression integrated directly into GEMM:
+        - Custom fp8 format with 2..3 mantissa bits; tensor scaling.
+        - Also bf16, f32 and non-uniform 4-bit (NUQ); easy to add new formats.
+
+- Infrastructure
+
+    - SIMD: single implementation via Highway. Chooses ISA at runtime.
+    - Tensor parallelism: CCX-aware, multi-socket thread pool.
+    - Disk I/O: memory map or parallel read (heuristic with user override).
+    - Custom format with forward/backward-compatible metadata serialization.
+    - Model conversion from Safetensors, not yet open sourced.
+    - Portability: Linux, Windows and macOS supported. CMake/Bazel. 'Any' CPU.
+
+- Frontends
+
+    - C++ APIs with streaming for single query and batched inference.
+    - Basic interactive command-line app.
+    - Basic Python bindings (pybind11).

 ## Quick Start

@@ -74,57 +106,20 @@ winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "-

 Visit the
 [Kaggle page for Gemma-2](https://www.kaggle.com/models/google/gemma-2/gemmaCpp)
-[or Gemma-1](https://www.kaggle.com/models/google/gemma/frameworks/gemmaCpp),
 and select `Model Variations |> Gemma C++`. On this tab, the `Variation`
 dropdown includes the options below. Note bfloat16 weights are higher fidelity,
 while 8-bit switched floating point weights enable faster inference. In
 general, we recommend starting with the `-sfp` checkpoints.

-If you are unsure which model to start with, we recommend starting with the
-smallest Gemma-2 model, i.e. `2.0-2b-it-sfp`.
-
-Alternatively, visit the
-[gemma.cpp](https://huggingface.co/models?other=gemma.cpp) models on the Hugging
-Face Hub. First go the model repository of the model of interest (see
-recommendations below). Then, click the `Files and versions` tab and download
-the model and tokenizer files. For programmatic downloading, if you have
-`huggingface_hub` installed, you can also download by running:
-
-```
-huggingface-cli login # Just the first time
-huggingface-cli download google/gemma-2b-sfp-cpp --local-dir build/
-```
-
-Gemma-1 2B instruction-tuned (`it`) and pre-trained (`pt`) models:
-
-| Model name | Description |
-| ----------- | ----------- |
-| `2b-it` | 2 billion parameter instruction-tuned model, bfloat16 |
-| `2b-it-sfp` | 2 billion parameter instruction-tuned model, 8-bit switched floating point |
-| `2b-pt` | 2 billion parameter pre-trained model, bfloat16 |
-| `2b-pt-sfp` | 2 billion parameter pre-trained model, 8-bit switched floating point |
-
-Gemma-1 7B instruction-tuned (`it`) and pre-trained (`pt`) models:
-
-| Model name | Description |
-| ----------- | ----------- |
-| `7b-it` | 7 billion parameter instruction-tuned model, bfloat16 |
-| `7b-it-sfp` | 7 billion parameter instruction-tuned model, 8-bit switched floating point |
-| `7b-pt` | 7 billion parameter pre-trained model, bfloat16 |
-| `7b-pt-sfp` | 7 billion parameter pre-trained model, 8-bit switched floating point |
-
-> [!NOTE]
-> **Important**: We strongly recommend starting off with the `2b-it-sfp` model to
-> get up and running.
+> [!NOTE]
+> **Important**: We strongly recommend starting off with the
+> `gemma2-2b-it-sfp` model to get up and running.

 Gemma 2 models are named `gemma2-2b-it` for 2B and `9b-it` or `27b-it`. See the
-`kModelFlags` definition in `common.cc`.
+`ModelPrefix` function in `configs.cc`.

 ### Step 2: Extract Files

-If you downloaded the models from Hugging Face, skip to step 3.
-
 After filling out the consent form, the download should proceed to retrieve a
 tar archive file `archive.tar.gz`.

 Extract files from `archive.tar.gz` (this can take a few minutes):

@@ -162,10 +157,9 @@
 cmake --build --preset make -j [number of parallel threads to use]
 ```

 Replace `[number of parallel threads to use]` with a number - the number of
-cores available on your system is a reasonable heuristic. For example,
-`make -j4 gemma` will build using 4 threads. If the `nproc` command is
-available, you can use `make -j$(nproc) gemma` as a reasonable default
-for the number of threads.
+cores available on your system is a reasonable heuristic. For example, `make -j4
+gemma` will build using 4 threads. If the `nproc` command is available, you can
+use `make -j$(nproc) gemma` as a reasonable default for the number of threads.

 If you aren't sure of the right value for the `-j` flag, you can simply run
 `make gemma` instead and it should still build the `./gemma` executable.

@@ -174,7 +168,8 @@ If you aren't sure of the right value for the `-j` flag, you can simply run
 > On Windows Subsystem for Linux (WSL) users should set the number of
 > parallel threads to 1. Using a larger number may result in errors.

-If the build is successful, you should now have a `gemma` executable in the `build/` directory.
+If the build is successful, you should now have a `gemma` executable in the
+`build/` directory.

 #### Windows

 ```sh
 cmake --preset windows
 cmake --build --preset windows -j [number of parallel threads to use]
 ```

-If the build is successful, you should now have a `gemma.exe` executable in the `build/` directory.
+If the build is successful, you should now have a `gemma.exe` executable in the
+`build/` directory.
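+
+To additionally build the new shared library for the C API (intended for C#
+interop), a minimal sketch using the `windows-dll` preset defined in
+`CMakePresets.json` (it enables `BUILD_GEMMA_DLL` and builds only the
+`gemma_shared` target):
+
+```sh
+cmake --preset windows-dll
+cmake --build --preset windows-dll -j [number of parallel threads to use]
+```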
 #### Bazel

 ```
 bazel build -c opt --cxxopt=-std=c++20 :gemma
 ```

-If the build is successful, you should now have a `gemma` executable in the `bazel-bin/` directory.
+If the build is successful, you should now have a `gemma` executable in the
+`bazel-bin/` directory.

 #### Make

@@ -208,33 +205,21 @@
 You can now run `gemma` from inside the `build/` directory.

 `gemma` has the following required arguments:

-Argument | Description | Example value
---------------- | ---------------------------- | -----------------------
-`--model` | The model type. | `2b-it` ... (see below)
-`--weights` | The compressed weights file. | `2b-it-sfp.sbs`
-`--weight_type` | The compressed weight type. | `sfp`
-`--tokenizer` | The tokenizer file. | `tokenizer.spm`
-
-`gemma` is invoked as:
-
-```sh
-./gemma \
---tokenizer [tokenizer file] \
---weights [compressed weights file] \
---weight_type [f32 or bf16 or sfp (default:sfp)] \
---model [2b-it or 2b-pt or 7b-it or 7b-pt or ...]
-```
+Argument | Description | Example value
+------------- | ---------------------------- | ---------------
+`--weights` | The compressed weights file. | `2b-it-sfp.sbs`
+`--tokenizer` | The tokenizer file. | `tokenizer.spm`

 Example invocation for the following configuration:

-- Compressed weights file `2b-it-sfp.sbs` (2B instruction-tuned model, 8-bit
-  switched floating point).
-- Tokenizer file `tokenizer.spm`.
+- Weights file `gemma2-2b-it-sfp.sbs` (Gemma2 2B instruction-tuned model,
+  8-bit switched floating point).
+- Tokenizer file `tokenizer.spm` (can omit for single-file weights files
+  created after 2025-05-06, or those output by migrate_weights.cc).

 ```sh
 ./gemma \
---tokenizer tokenizer.spm \
---weights 2b-it-sfp.sbs --model 2b-it
+--tokenizer tokenizer.spm --weights gemma2-2b-it-sfp.sbs
 ```
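+
+For non-interactive use, a prompt can also be piped on stdin; a minimal sketch
+(same weights and tokenizer as above; `--verbosity 0` gives minimal output, as
+in the pipe example later in this README):
+
+```sh
+echo "What is switched floating point?" | ./gemma \
+--tokenizer tokenizer.spm --weights gemma2-2b-it-sfp.sbs --verbosity 0
+```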

 ### RecurrentGemma

@@ -256,23 +241,20 @@ Step 1, and run the binary as follows:

 ### PaliGemma Vision-Language Model

-This repository includes a version of the PaliGemma VLM
-([paper](https://arxiv.org/abs/2407.07726),
-[code](https://github.com/google-research/big_vision/tree/main/big_vision/configs/proj/paligemma))
-and its successor PaliGemma 2 ([paper](https://arxiv.org/abs/2412.03555)). We
-provide a C++ implementation of the PaliGemma model family here.
+This repository includes a version of the PaliGemma 2 VLM
+([paper](https://arxiv.org/abs/2412.03555)). We provide a C++ implementation of
+the PaliGemma 2 model here.

 To use the version of PaliGemma included in this repository, build the gemma
 binary as noted above in Step 3. Download the compressed weights and tokenizer
 from
-[Kaggle](https://www.kaggle.com/models/google/paligemma/gemmaCpp/paligemma-3b-mix-224)
+[Kaggle](https://www.kaggle.com/models/google/paligemma-2/gemmaCpp/paligemma2-3b-mix-224)
 and run the binary as follows:

 ```sh
 ./gemma \
 --tokenizer paligemma_tokenizer.model \
---model paligemma-224 \
---weights paligemma-3b-mix-224-sfp.sbs \
+--weights paligemma2-3b-mix-224-sfp.sbs \
 --image_file paligemma/testdata/image.ppm
 ```

@@ -312,12 +294,12 @@
 allows the file to contain the tokenizer (and the model type) directly. A tool
 to migrate from the multi-file format to the single-file format is available.

 ```sh
-compression/migrate_weights \
+io/migrate_weights \
   --tokenizer .../tokenizer.spm --weights .../gemma2-2b-it-sfp.sbs \
-  --model gemma2-2b-it --output_weights .../gemma2-2b-it-sfp-single.sbs
+  --output_weights .../gemma2-2b-it-sfp-single.sbs
 ```

-After migration, you can use the new weights file with gemma.cpp like this:
+After migration, you can omit the tokenizer argument like this:

 ```sh
 ./gemma --weights .../gemma2-2b-it-sfp-single.sbs
 ```

 ### Troubleshooting and FAQs

-**Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
-
-The most common problem is that the `--weight_type` argument does not match that
-of the model file. Revisit step #3 and check which weights you downloaded.
-
-Note that we have already moved weight type from a compile-time decision to a
-runtime argument. In a subsequent step, we plan to bake this information into
-the weights.
-
 **Problems building in Windows / Visual Studio**

 Currently if you're using Windows, we recommend building in WSL (Windows
 Subsystem for Linux). There are also experimental native build
 configurations, see issues for active discussion.

 **Model does not respond to instructions and produces strange output**

 A common issue is that you are using a pre-trained model, which is not
 instruction-tuned and thus does not respond to instructions. Make sure you are
-using an instruction-tuned model (`2b-it-sfp`, `2b-it`, `7b-it-sfp`, `7b-it`)
-and not a pre-trained model (any model with a `-pt` suffix).
+using an instruction-tuned model (`gemma2-2b-it-sfp`) and not a pre-trained
+model (any model with a `-pt` suffix).

 **What sequence lengths are supported?**

-See `seq_len` in `configs.cc`. For the Gemma 3 models larger than 1B, this is
-typically 32K but 128K would also work given enough RAM. Note that long
-sequences will be slow due to the quadratic cost of attention.
+See `max_seq_len` in `configs.cc` and `InferenceArgs.seq_len`. For the Gemma 3
+models larger than 1B, this is typically 32K but 128K would also work given
+enough RAM. Note that long sequences will be slow due to the quadratic cost of
+attention.
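+
+For example, to request a longer context at runtime (a sketch only: this
+assumes `InferenceArgs.seq_len` is exposed as a `--seq_len` flag, so check
+`gemma/gemma_args.h` for the actual argument name):
+
+```sh
+./gemma --weights gemma2-2b-it-sfp.sbs --seq_len 8192
+```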

 **How do I convert my fine-tune to a `.sbs` compressed model file?**

-For PaliGemma (1 and 2) checkpoints, you can use
-python/convert_from_safetensors.py to convert from safetensors format (tested
-with building via bazel). For an adapter model, you will likely need to call
-merge_and_unload() to convert the adapter model to a single-file format before
-converting it.
+For PaliGemma 2 checkpoints, you can use python/convert_from_safetensors.py to
+convert from safetensors format (tested with building via bazel). For an adapter
+model, you will likely need to call merge_and_unload() to convert the adapter
+model to a single-file format before converting it.

 Here is how to use it with a bazel build of the compression library, assuming
 locally installed (venv) torch, numpy, safetensors, absl-py, etc.:

@@ -373,22 +346,18 @@
 ln -s $BAZEL_OUTPUT_DIR [...]/site-packages/compression
 python3 python/convert_from_safetensors.py --load_path [...].safetensors.index.json
 ```

-See also compression/convert_weights.py for a slightly older option to convert a
-pytorch checkpoint. (The code may need updates to work with Gemma-2 models.)
-
 **What are some easy ways to make the model run faster?**

 1. Make sure you are using the 8-bit switched floating point `-sfp` models.
    These are half the size of bf16 and thus use less memory bandwidth and cache
    space.
-2. If you're on a laptop, make sure power mode is set to maximize performance
+2. Due to auto-tuning, the second and especially third query will be faster.
+3. If you're on a laptop, make sure power mode is set to maximize performance
    and saving mode is **off**. For most laptops, the power saving modes get
    activated automatically if the computer is not plugged in.
-3. Close other unused cpu-intensive applications.
-4. On macs, anecdotally we observe a "warm-up" ramp-up in speed as performance
+4. Close other unused cpu-intensive applications.
+5. On Macs, anecdotally we observe a "warm-up" ramp-up in speed as performance
    cores get engaged.
-5. Experiment with the `--num_threads` argument value. Depending on the device,
-   larger numbers don't always mean better performance.

 We're also working on algorithmic and optimization approaches for faster
 inference, stay tuned.

@@ -411,7 +380,7 @@ newline input.

 By default, verbosity is set to 1, bringing up a terminal-based interactive
 interface when `gemma` is invoked:

-```console
+```sh
 $ ./gemma [...]
   __ _  ___ _ __ ___  _ __ ___   __ _   ___ _ __  _ __
  / _` |/ _ \ '_ ` _ \| '_ ` _ \ / _` | / __| '_ \| '_ \
 | (_| |  __/ | | | | | | | | | | (_| | | (__| |_) | |_) |
  \__, |\___|_| |_| |_|_| |_| |_|\__,_| (_) \___| .__/| .__/
   __/ |                                        | |   | |
  |___/                                         |_|   |_|

-tokenizer            : tokenizer.spm
-compressed_weights   : 2b-it-sfp.sbs
-model                : 2b-it
-weights              : [no path specified]
-max_generated_tokens : 2048
+...

 *Usage*
   Enter an instruction and press enter (%C reset conversation, %Q quits).
@@ -462,7 +427,7 @@ For using the `gemma` executable as a command line tool, it may be useful to create an alias for gemma.cpp with arguments fully specified: ```sh -alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --weights ~/gemma.cpp/build/gemma2-2b-it-sfp.sbs --model gemma2-2b-it --verbosity 0" +alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --weights ~/gemma.cpp/build/gemma2-2b-it-sfp.sbs --verbosity 0" ``` Replace the above paths with your own paths to the model and tokenizer paths @@ -481,7 +446,7 @@ cat configs.h | tail -n 35 | tr '\n' ' ' | xargs -0 echo "What does this C++ cod The output of the above command should look like: -```console +```sh [ Reading prompt ] [...] This C++ code snippet defines a set of **constants** used in a large language model (LLM) implementation, likely related to the **attention mechanism**. @@ -492,8 +457,8 @@ Let's break down the code: ### Incorporating gemma.cpp as a Library in your Project The easiest way to incorporate gemma.cpp in your own project is to pull in -gemma.cpp and dependencies using `FetchContent`. You can add the following to your -CMakeLists.txt: +gemma.cpp and dependencies using `FetchContent`. You can add the following to +your CMakeLists.txt: ``` include(FetchContent) @@ -562,9 +527,10 @@ submit a PR with a `README.md` edit. ## Acknowledgements and Contacts -gemma.cpp was started in fall 2023 by [Austin Huang](mailto:austinvhuang@google.com) -and [Jan Wassenberg](mailto:janwas@google.com), and subsequently released February 2024 -thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng. +gemma.cpp was started in fall 2023 by +[Austin Huang](mailto:austinvhuang@google.com) and +[Jan Wassenberg](mailto:janwas@google.com), and subsequently released February +2024 thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng. Griffin support was implemented in April 2024 thanks to contributions by Andrey Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode diff --git a/backprop/activations.h b/backprop/activations.h deleted file mode 100644 index c616759..0000000 --- a/backprop/activations.h +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_ -#define THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_ - -#include - -#include - -#include "compression/compress.h" // MatStorageT -#include "gemma/configs.h" // ModelConfig - -namespace gcpp { - -template -struct ForwardLayer { - ForwardLayer(const LayerConfig& config, size_t seq_len) - : input("input", seq_len, config.model_dim), - pre_att_rms_out("pre_att_rms_out", seq_len, config.model_dim), - qkv("qkv", seq_len * (config.heads + 2), config.qkv_dim), - att("att", seq_len * config.heads, seq_len), - att_out("att_out", seq_len * config.heads, config.qkv_dim), - att_post1("att_post1", seq_len, config.model_dim), - attention_out("attention_out", seq_len, config.model_dim), - bf_pre_ffw_rms_out("bf_pre_ffw_rms_out", seq_len, config.model_dim), - ffw_hidden("ffw_hidden", seq_len, config.ff_hidden_dim * 2), - ffw_hidden_gated("ffw_hidden_gated", seq_len, config.ff_hidden_dim), - layer_config(config) {} - - MatStorageT input; - MatStorageT pre_att_rms_out; - MatStorageT qkv; - MatStorageT att; - MatStorageT att_out; - MatStorageT att_post1; - MatStorageT attention_out; - MatStorageT bf_pre_ffw_rms_out; - MatStorageT ffw_hidden; - MatStorageT ffw_hidden_gated; - const LayerConfig& layer_config; -}; - -template -struct ForwardPass { - ForwardPass(const ModelConfig& config) - : final_layer_output("final_layer_output", config.seq_len, - config.model_dim), - final_norm_output("final_norm_output", config.seq_len, - config.model_dim), - logits("logits", config.seq_len, config.vocab_size), - probs("probs", config.seq_len, config.vocab_size), - weights_config(config) { - for (const auto& layer_config : config.layer_configs) { - layers.emplace_back(layer_config, config.seq_len); - } - } - - std::vector> layers; - MatStorageT final_layer_output; - MatStorageT final_norm_output; - MatStorageT logits; - MatStorageT probs; - const ModelConfig& weights_config; -}; - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_BACKPROP_ACTIVATIONS_H_ diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h deleted file mode 100644 index 2a0f330..0000000 --- a/backprop/backward-inl.h +++ /dev/null @@ -1,404 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Implementation of the Vector-Jacobian Products (VJP) of the individual -// operations of the forward pass. - -// Include guard for non-SIMD code. -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_ - -#include - -#include -#include - -#include "backprop/activations.h" -#include "backprop/prompt.h" -#include "gemma/common.h" -#include "gemma/weights.h" -#include "util/allocator.h" -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_ - -// Include guard for (potentially) SIMD code. 
-#if defined(THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE) -#ifdef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE -#undef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE -#else -#define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE -#endif - -#include "hwy/highway.h" -// After highway.h -#include "ops/matmul-inl.h" -#include "ops/ops-inl.h" -#include "hwy/contrib/dot/dot-inl.h" - -HWY_BEFORE_NAMESPACE(); -namespace gcpp { -namespace HWY_NAMESPACE { -namespace hn = hwy::HWY_NAMESPACE; - -HWY_INLINE void MatMulVJP(const float* HWY_RESTRICT weights, // kRows * kCols, - const float* HWY_RESTRICT x, // num_tokens * kCols - const float* HWY_RESTRICT v, // num_tokens * kRows - size_t cols, size_t rows, size_t num_tokens, - float* HWY_RESTRICT grad_w, // kRows * kCols, - float* HWY_RESTRICT grad_x, // num_tokens * kCols - hwy::ThreadPool& pool) { - hwy::ZeroBytes(grad_x, num_tokens * cols * sizeof(grad_x[0])); - for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t voffs = pos * rows; - const size_t xoffs = pos * cols; - for (size_t j = 0; j < rows; ++j) { - MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * cols], cols); - MulByConstAndAdd(v[voffs + j], &weights[j * cols], &grad_x[xoffs], cols); - } - } -} - -HWY_INLINE void MultiHeadMatMulVJP( - const float* HWY_RESTRICT weights, // heads * kRows * kCols - const float* HWY_RESTRICT x, // num_tokens * heads * kCols - const float* HWY_RESTRICT v, // num_tokens * kRows - size_t heads, size_t cols, size_t rows, size_t num_tokens, - float* HWY_RESTRICT grad_w, // heads * kRows * kCols - float* HWY_RESTRICT grad_x, // num_tokens * heads * kCols - hwy::ThreadPool& pool) { - hwy::ZeroBytes(grad_x, num_tokens * heads * cols * sizeof(grad_x[0])); - for (size_t pos = 0; pos < num_tokens; ++pos) { - for (size_t j = 0; j < rows; ++j) { - for (size_t h = 0; h < heads; ++h) { - MulByConstAndAdd(v[pos * rows + j], &x[pos * heads * cols + h * cols], - &grad_w[h * rows * cols + j * cols], cols); - MulByConstAndAdd(v[pos * rows + j], - &weights[h * rows * cols + j * cols], - &grad_x[pos * heads * cols + h * cols], cols); - } - } - } -} - -template -static HWY_INLINE hn::Vec DGelu(D d, hn::Vec v) { - const hn::Vec kMul = hn::Set(d, 0.044715f); - const hn::Vec kSqrt2OverPi = hn::Set(d, 0.797884560804236f); - const hn::Vec kHalf = hn::Set(d, 0.5f); - const hn::Vec kOne = hn::Set(d, 1.0f); - // kSqrtOverPi*3*kMul - const hn::Vec kMulv2 = hn::Set(d, 0.1070322244f); - - const hn::Vec v2 = hn::Mul(v, v); - const hn::Vec v3 = hn::Mul(v2, v); - const hn::Vec arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v)); - const hn::Vec tanh = hn::Tanh(d, arg); - const hn::Vec cdf = hn::MulAdd(kHalf, tanh, kHalf); - const hn::Vec dtanh = hn::Sub(kOne, hn::Mul(tanh, tanh)); - const hn::Vec darg = hn::MulAdd(kMulv2, v2, kSqrt2OverPi); - return hn::MulAdd(kHalf, hn::Mul(v, hn::Mul(dtanh, darg)), cdf); -} - -static HWY_NOINLINE void SoftmaxVJP(const float* HWY_RESTRICT forward, - float* HWY_RESTRICT backward, - const size_t size) { - namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - const D d; - - const auto offset = - hn::Set(d, hn::Dot::Compute<0>(d, forward, backward, size)); - hn::Transform1( - d, backward, size, forward, - [&offset](const auto d, const auto v, const auto y) - HWY_ATTR { return hn::Mul(y, hn::Sub(v, offset)); }); -} - -static HWY_NOINLINE void RMSNormVJP( - const float* HWY_RESTRICT weights, const float* HWY_RESTRICT x, - const float* HWY_RESTRICT v, size_t model_dim, size_t num_tokens, - float* HWY_RESTRICT grad_w, float* HWY_RESTRICT 
grad_x, - hwy::ThreadPool& pool) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t offset = pos * model_dim; - const float ss = detail::RMSNormMul(x + offset, model_dim); - - for (size_t i = 0; i < model_dim; ++i) { - grad_w[i] += v[offset + i] * x[offset + i] * ss; - } - const float ss3 = ss * ss * ss / StaticCast(model_dim); - float tmp = 0.0f; - for (size_t i = 0; i < model_dim; ++i) { - tmp += (1.0f + weights[i]) * v[offset + i] * x[offset + i]; - } - tmp *= ss3; - for (size_t i = 0; i < model_dim; ++i) { - grad_x[offset + i] = ss * (1.0f + weights[i]) * v[offset + i] - - tmp * x[offset + i]; - } - } -} - -static HWY_NOINLINE void InputEmbeddingVJP( - const float* weights, const std::vector& prompt, - const float scaling, const float* HWY_RESTRICT v, - float* HWY_RESTRICT grad, size_t model_dim) { - HWY_ASSERT(!prompt.empty()); - for (size_t pos = 0; pos < prompt.size() - 1; ++pos) { - int token = prompt[pos]; - MulByConstAndAdd(scaling, v + pos * model_dim, - grad + token * model_dim, model_dim); - } -} - -template -void LayerVJP(const LayerWeightsPtrs& weights, - const ForwardLayer& forward, - const float* HWY_RESTRICT next_layer_grad, size_t num_tokens, - LayerWeightsPtrs& grad, ForwardLayer& backward, - const RowVectorBatch& inv_timescale, - hwy::ThreadPool& pool) { - const LayerConfig& config = weights.layer_config; - const size_t model_dim = config.model_dim; - const size_t qkv_dim = config.qkv_dim; - const size_t heads = config.heads; - const size_t seq_len = forward.input.Rows(); - const size_t ff_hidden_dim = config.ff_hidden_dim; - const float query_scale = - static_cast(1.0 / sqrt(static_cast(qkv_dim))); - HWY_ASSERT(num_tokens <= seq_len); - - MatMulVJP(weights.linear_w.data(), forward.ffw_hidden_gated.data(), - next_layer_grad, ff_hidden_dim, model_dim, num_tokens, - grad.linear_w.data(), backward.ffw_hidden_gated.data(), pool); - - for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t hidden_offset = pos * ff_hidden_dim * 2; - const float* HWY_RESTRICT f_out = forward.ffw_hidden.data() + hidden_offset; - const float* HWY_RESTRICT f_out_mul = f_out + ff_hidden_dim; - const float* HWY_RESTRICT b_out_gated = - backward.ffw_hidden_gated.data() + pos * ff_hidden_dim; - float* HWY_RESTRICT b_out = backward.ffw_hidden.data() + hidden_offset; - float* HWY_RESTRICT b_out_mul = b_out + ff_hidden_dim; - namespace hn = hwy::HWY_NAMESPACE; - using DF = hn::ScalableTag; - DF df; - for (size_t i = 0; i < ff_hidden_dim; i += Lanes(df)) { - const auto y = Load(df, f_out + i); - const auto x = Load(df, f_out_mul + i); - const auto v = Load(df, b_out_gated + i); - hn::Store(hn::Mul(v, Gelu(df, y)), df, b_out_mul + i); - hn::Store(hn::Mul(v, hn::Mul(x, DGelu(df, y))), df, b_out + i); - } - } - - MatMulVJP(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(), - backward.ffw_hidden.data(), model_dim, ff_hidden_dim * 2, - num_tokens, grad.gating_einsum_w.data(), - backward.bf_pre_ffw_rms_out.data(), pool); - RMSNormVJP(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(), - backward.bf_pre_ffw_rms_out.data(), model_dim, num_tokens, - grad.pre_ffw_norm_scale.data(), backward.attention_out.data(), - pool); - - for (size_t pos = 0; pos < num_tokens; ++pos) { - AddFrom(next_layer_grad + pos * model_dim, - backward.attention_out.data() + pos * model_dim, model_dim); - } - - backward.qkv.ZeroInit(); - - MultiHeadMatMulVJP(weights.attn_vec_einsum_w.data(), forward.att_out.data(), - backward.attention_out.data(), heads, qkv_dim, model_dim, - num_tokens, 
-
-  for (size_t head = 0; head < heads; ++head) {
-    for (size_t pos = 0; pos < num_tokens; ++pos) {
-      const size_t aoffset = head * seq_len + pos * heads * seq_len;
-      const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
-      const float* HWY_RESTRICT b_att_out =
-          backward.att_out.data() + (pos * heads + head) * qkv_dim;
-      float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
-      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
-        const size_t v2offs = (pos2 * (heads + 2) + heads + 1) * qkv_dim;
-        const float* HWY_RESTRICT f_v2 = forward.qkv.data() + v2offs;
-        float* HWY_RESTRICT b_v2 = backward.qkv.data() + v2offs;
-        b_head_att[pos2] = Dot(b_att_out, f_v2, qkv_dim);
-        MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, qkv_dim);
-      }
-    }
-  }
-
-  for (size_t head = 0; head < heads; ++head) {
-    for (size_t pos = 0; pos < num_tokens; ++pos) {
-      const size_t aoffset = head * seq_len + pos * heads * seq_len;
-      const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
-      float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
-      SoftmaxVJP(f_head_att, b_head_att, pos + 1);
-    }
-  }
-
-  for (size_t head = 0; head < heads; ++head) {
-    for (size_t pos = 0; pos < num_tokens; ++pos) {
-      const size_t qoffs = (pos * (heads + 2) + head) * qkv_dim;
-      const size_t aoffs = head * seq_len + pos * heads * seq_len;
-      const float* HWY_RESTRICT f_q = forward.qkv.data() + qoffs;
-      const float* HWY_RESTRICT b_head_att = backward.att.data() + aoffs;
-      float* HWY_RESTRICT b_q = backward.qkv.data() + qoffs;
-      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
-        const size_t k2offs = (pos2 * (heads + 2) + heads) * qkv_dim;
-        const float* HWY_RESTRICT f_k2 = forward.qkv.data() + k2offs;
-        float* HWY_RESTRICT b_k2 = backward.qkv.data() + k2offs;
-        MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, qkv_dim);
-        MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, qkv_dim);
-      }
-    }
-  }
-
-  for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
-    float* HWY_RESTRICT b_kv =
-        backward.qkv.data() + (pos * (heads + 2) + heads) * qkv_dim;
-    Rope(b_kv, qkv_dim, inv_timescale.Const(), -pos);
-  }
-
-  for (size_t head = 0; head < heads; ++head) {
-    for (size_t pos = 0; pos < num_tokens; ++pos) {
-      float* HWY_RESTRICT b_q =
-          backward.qkv.data() + (pos * (heads + 2) + head) * qkv_dim;
-      MulByConst(query_scale, b_q, qkv_dim);
-      Rope(b_q, qkv_dim, inv_timescale.Const(), -pos);
-    }
-  }
-
-  MatMulVJP(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
-            backward.qkv.data(), model_dim, (heads + 2) * qkv_dim, num_tokens,
-            grad.qkv_einsum_w.data(), backward.pre_att_rms_out.data(), pool);
-  RMSNormVJP(weights.pre_attention_norm_scale.data(), forward.input.data(),
-             backward.pre_att_rms_out.data(), model_dim, num_tokens,
-             grad.pre_attention_norm_scale.data(), backward.input.data(),
-             pool);
-  for (size_t pos = 0; pos < num_tokens; ++pos) {
-    AddFrom(backward.attention_out.data() + pos * model_dim,
-            backward.input.data() + pos * model_dim, model_dim);
-  }
-}
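LayerVJP undoes RoPE by calling Rope with `-pos`: the forward pass rotates each (x_d, x_{d+N/2}) pair by an angle proportional to the position, and because a rotation is orthogonal, its VJP is the transpose, i.e. the rotation by the negated angle. A self-contained sketch with a local Rope mirroring the scalar reference (base 10000 assumed, as in common_scalar.h below):

```cpp
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Local stand-in for the scalar Rope: rotate pair (x[d], x[d+n/2]) by
// theta = pos / 10000^(2d/n).
void Rope(double* x, size_t n, int pos) {
  const size_t n2 = n / 2;
  for (size_t d = 0; d < n2; ++d) {
    const double timescale = std::pow(10000.0, 2.0 * d / n);
    const double theta = pos / timescale;
    const double c = std::cos(theta), s = std::sin(theta);
    const double x0 = x[d], x1 = x[d + n2];
    x[d] = x0 * c - x1 * s;
    x[d + n2] = x0 * s + x1 * c;
  }
}

int main() {
  std::vector<double> x{1, 2, 3, 4, 5, 6, 7, 8};
  Rope(x.data(), x.size(), 13);   // forward rotation at position 13
  Rope(x.data(), x.size(), -13);  // inverse rotation, as in the VJP above
  for (double v : x) std::printf("%.12f ", v);  // recovers 1 2 ... 8
  std::printf("\n");
}
```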
-
-static HWY_NOINLINE void SoftcapVJP(const float cap,
-                                    const float* HWY_RESTRICT forward,
-                                    float* HWY_RESTRICT backward,
-                                    const size_t size) {
-  namespace hn = hwy::HWY_NAMESPACE;
-  using D = hn::ScalableTag<float>;
-  const D d;
-
-  const auto one = hn::Set(d, 1.0f);
-  const auto vcap = hn::Set(d, cap);
-  const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
-  hn::Transform1(d, backward, size, forward,
-                 [&](const auto d, const auto v, const auto y) HWY_ATTR {
-                   const auto scaled = hn::Mul(vinv_cap, y);  // = tanh
-                   return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
-                 });
-}
-
-static HWY_NOINLINE void CrossEntropyLossGrad(
-    const float* HWY_RESTRICT x, float* HWY_RESTRICT grad,
-    const Prompt& prompt, size_t vocab_size) {
-  HWY_ASSERT(!prompt.tokens.empty());
-  const float scaling = -1.0 / std::log(2.0);
-  size_t num_tokens = prompt.tokens.size() - 1;
-  hwy::ZeroBytes(grad, num_tokens * vocab_size * sizeof(grad[0]));
-  for (size_t pos = 0; pos < num_tokens; ++pos) {
-    if (pos + 1 < prompt.context_size) {
-      continue;
-    }
-    const int next_token = prompt.tokens[pos + 1];
-    grad[pos * vocab_size + next_token] =
-        scaling / x[pos * vocab_size + next_token];
-  }
-}
-
-template <typename T>
-void CrossEntropyLossBackwardPassInl(const Prompt& prompt,
-                                     const ModelWeightsPtrs<T>& weights,
-                                     const ForwardPass<T>& forward,
-                                     ModelWeightsPtrs<T>& grad,
-                                     ForwardPass<T>& backward,
-                                     RowVectorBatch<float>& inv_timescale,
-                                     hwy::ThreadPool& pool) {
-  const ModelConfig& config = weights.weights_config;
-  const size_t kVocabSize = config.vocab_size;
-  const size_t model_dim = config.model_dim;
-  const size_t kLayers = config.layer_configs.size();
-  const float kEmbScaling = EmbeddingScaling(model_dim);
-  HWY_ASSERT(!config.absolute_pe);
-  HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None);
-  HWY_ASSERT(config.layer_configs[0].kv_heads == 1);
-
-  HWY_DASSERT(prompt.context_size > 0);
-  HWY_DASSERT(prompt.context_size < prompt.tokens.size());
-  const size_t num_tokens = prompt.tokens.size() - 1;
-
-  CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
-                       kVocabSize);
-
-  for (size_t pos = 0; pos < num_tokens; ++pos) {
-    SoftmaxVJP(forward.probs.data() + pos * kVocabSize,
-               backward.logits.data() + pos * kVocabSize,
-               kVocabSize);
-  }
-
-  if (config.final_cap > 0.0f) {
-    for (size_t pos = 0; pos < num_tokens; ++pos) {
-      SoftcapVJP(config.final_cap, forward.logits.data() + pos * kVocabSize,
-                 backward.logits.data() + pos * kVocabSize, kVocabSize);
-    }
-  }
-
-  MatMulVJP(weights.embedder_input_embedding.data(),
-            forward.final_norm_output.data(), backward.logits.data(),
-            model_dim, kVocabSize, num_tokens,
-            grad.embedder_input_embedding.data(),
-            backward.final_norm_output.data(), pool);
-
-  RMSNormVJP(weights.final_norm_scale.data(),
-             forward.final_layer_output.data(),
-             backward.final_norm_output.data(), model_dim, num_tokens,
-             grad.final_norm_scale.data(), backward.final_layer_output.data(),
-             pool);
-
-  for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
-    auto layer_config = config.layer_configs[layer];
-    // TODO(szabadka) Implement Griffin layer vjp.
-    HWY_ASSERT(layer_config.type == LayerAttentionType::kGemma);
-    float* next_layer_grad = layer + 1 < kLayers
-                                 ?
backward.layers[layer + 1].input.data() - : backward.final_layer_output.data(); - LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad, - num_tokens, *grad.GetLayer(layer), backward.layers[layer], - inv_timescale, pool); - } - - InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens, - kEmbScaling, backward.layers[0].input.data(), - grad.embedder_input_embedding.data(), model_dim); -} - -// NOLINTNEXTLINE(google-readability-namespace-comments) -} // namespace HWY_NAMESPACE -} // namespace gcpp -HWY_AFTER_NAMESPACE(); - -#endif // NOLINT diff --git a/backprop/backward.cc b/backprop/backward.cc deleted file mode 100644 index 868b391..0000000 --- a/backprop/backward.cc +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "backprop/backward.h" - -#include "backprop/activations.h" -#include "backprop/prompt.h" -#include "gemma/common.h" -#include "gemma/weights.h" -#include "util/allocator.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -// Compiles this file for multiple architectures via "foreach_target.h", to -// which we pass the filename via macro 'argument'. -// clang-format off -#undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "backprop/backward.cc" // NOLINT -// clang-format on -#include "hwy/foreach_target.h" // IWYU pragma: keep - -#include "hwy/highway.h" -// After highway.h -#include "backprop/backward-inl.h" - -HWY_BEFORE_NAMESPACE(); -namespace gcpp { -namespace HWY_NAMESPACE { - -void CrossEntropyLossBackwardPassT(const Prompt& prompt, - const ModelWeightsPtrs& weights, - const ForwardPass& forward, - ModelWeightsPtrs& grad, - ForwardPass& backward, - RowVectorBatch& inv_timescale, - hwy::ThreadPool& pool) { - CrossEntropyLossBackwardPassInl(prompt, weights, forward, grad, backward, - inv_timescale, pool); -} - -} // namespace HWY_NAMESPACE -} // namespace gcpp -HWY_AFTER_NAMESPACE(); - -#if HWY_ONCE -namespace gcpp { - -HWY_EXPORT(CrossEntropyLossBackwardPassT); - -void CrossEntropyLossBackwardPass(const Prompt& prompt, - const ModelWeightsPtrs& weights, - const ForwardPass& forward, - ModelWeightsPtrs& grad, - ForwardPass& backward, - RowVectorBatch& inv_timescale, - hwy::ThreadPool& pool) { - return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)( - prompt, weights, forward, grad, backward, inv_timescale, pool); -} - -} // namespace gcpp -#endif // HWY_ONCE diff --git a/backprop/backward.h b/backprop/backward.h deleted file mode 100644 index d8e50c7..0000000 --- a/backprop/backward.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
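The deleted backward.cc above is a textbook instance of Highway's foreach_target pattern: the translation unit is recompiled once per SIMD target, each pass lands in its own `HWY_NAMESPACE`, and `HWY_DYNAMIC_DISPATCH` picks the best variant for the running CPU. A minimal sketch of the same pattern, with placeholder names (`demo`, `Add`, `add.cc` are illustrative, not gemma.cpp code):

```cpp
// add.cc -- recompiled once per SIMD target via foreach_target.h.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "add.cc"  // path of this file, as seen by includes
#include "hwy/foreach_target.h"      // must come before highway.h
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace demo {
namespace HWY_NAMESPACE {  // a distinct namespace per SIMD target

float AddImpl(float a, float b) { return a + b; }  // target-specific code

}  // namespace HWY_NAMESPACE
}  // namespace demo
HWY_AFTER_NAMESPACE();

#if HWY_ONCE  // compiled exactly once, after all per-target passes
namespace demo {

HWY_EXPORT(AddImpl);  // builds the table of per-target implementations

float Add(float a, float b) {
  // Selects the best available implementation at runtime.
  return HWY_DYNAMIC_DISPATCH(AddImpl)(a, b);
}

}  // namespace demo
#endif  // HWY_ONCE
```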
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_ - -#include "backprop/activations.h" -#include "backprop/prompt.h" -#include "gemma/weights.h" -#include "util/allocator.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -namespace gcpp { - -void CrossEntropyLossBackwardPass(const Prompt& prompt, - const ModelWeightsPtrs& weights, - const ForwardPass& forward, - ModelWeightsPtrs& grad, - ForwardPass& backward, - RowVectorBatch& inv_timescale, - hwy::ThreadPool& pool); - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_ diff --git a/backprop/backward_scalar.h b/backprop/backward_scalar.h deleted file mode 100644 index b0a37b3..0000000 --- a/backprop/backward_scalar.h +++ /dev/null @@ -1,349 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
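The backward_scalar.h file removed next holds the plain-loop reference VJPs that the SIMD kernels are tested against. Its `MatMulVJPT(w, x, dy, dw, dx, N, M, K)` shape convention is compact but easy to misread; written out in matrix form (row-major, matching the loops in the code below):

```latex
% W \in \mathbb{R}^{N \times M},\; X \in \mathbb{R}^{K \times M},\;
% Y = X W^{\mathsf T} \in \mathbb{R}^{K \times N}.
% Given upstream dY, the two MulByConstAndAddT calls accumulate
dW = dY^{\mathsf T} X \in \mathbb{R}^{N \times M},
\qquad
dX = dY\, W \in \mathbb{R}^{K \times M}.
```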
- -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_ - -#include -#include - -#include -#include - -#include "backprop/activations.h" -#include "backprop/common_scalar.h" -#include "backprop/prompt.h" -#include "gemma/common.h" // EmbeddingScaling -#include "gemma/weights.h" - -namespace gcpp { -template -void MatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx, - size_t N, size_t M, size_t K) { - memset(dx, 0, M * K * sizeof(dx[0])); - for (size_t i = 0; i < K; ++i) { - for (size_t j = 0; j < N; ++j) { - MulByConstAndAddT(dy[i * N + j], &x[i * M], &dw[j * M], M); - MulByConstAndAddT(dy[i * N + j], &w[j * M], &dx[i * M], M); - } - } -} -template -void MultiHeadMatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx, - size_t H, size_t N, size_t M, size_t K) { - memset(dx, 0, H * M * K * sizeof(dx[0])); - for (size_t i = 0; i < K; ++i) { - for (size_t j = 0; j < N; ++j) { - for (size_t h = 0; h < H; ++h) { - MulByConstAndAddT(dy[i * N + j], &x[i * H * M + h * M], - &dw[h * N * M + j * M], M); - MulByConstAndAddT(dy[i * N + j], &w[h * N * M + j * M], - &dx[i * H * M + h * M], M); - } - } - } -} - -template -void RMSNormVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx, - size_t N, size_t K) { - for (size_t i = 0; i < K; ++i) { - constexpr T eps(1e-6); - T ss = SquaredL2(x + i * N, N); - ss = T(1.0) / std::sqrt(ss / T(N) + eps); - for (size_t j = 0; j < N; ++j) { - dw[j] += dy[i * N + j] * x[i * N + j] * ss; - } - const T ss3 = ss * ss * ss / T(N); - T tmp = 0.0; - for (size_t j = 0; j < N; ++j) { - tmp += (T(1.0) + w[j]) * dy[i* N + j] * x[i * N + j]; - } - tmp *= ss3; - for (size_t j = 0; j < N; ++j) { - dx[i * N + j] = ss * (T(1.0) + w[j]) * dy[i* N + j] - tmp * x[i * N + j]; - } - } -} -template -void SoftmaxVJPT(const T* y, T* dy, size_t N) { - T sum = {}; - for (size_t i = 0; i < N; ++i) { - sum += y[i] * dy[i]; - } - for (size_t i = 0; i < N; ++i) { - dy[i] = y[i] * (dy[i] - sum); - } -} -template -void SoftmaxVJPT(const T* y, T* dy, size_t N, size_t K) { - for (size_t i = 0; i < K; ++i) { - SoftmaxVJPT(y + i * N, dy + i * N, N); - } -} - -template -T GeluDerivative(T x) { - static const T kMul = 0.044715; - static const T kSqrt2OverPi = 0.797884560804236; - static const T kMul2 = kSqrt2OverPi * T(3.0) * kMul; - - const T x2 = x * x; - const T x3 = x2 * x; - const T arg = kSqrt2OverPi * (kMul * x3 + x); - const T tanh = std::tanh(arg); - const T cdf = T(0.5) * (T(1.0) + tanh); - const T dtanh = T(1.0) - tanh * tanh; - const T darg = kMul2 * x2 + kSqrt2OverPi; - return T(0.5) * x * dtanh * darg + cdf; -} - -template -void GatedGeluVJP(const T* in, const T* d_out, T* d_in, size_t N, size_t K) { - for (size_t i = 0; i < K; ++i) { - const T* x1 = in + i * 2 * N; - const T* x2 = x1 + N; - const T* v = d_out + i * N; - T* dx1 = d_in + i * 2 * N; - T* dx2 = dx1 + N; - for (size_t j = 0; j < N; ++j) { - dx1[j] = v[j] * x2[j] * GeluDerivative(x1[j]); - dx2[j] = v[j] * Gelu(x1[j]); - } - } -} - -template -void MaskedAttentionVJP(const T* qkv, const T* doutput, T* dqkv, - size_t num_tokens, size_t kHeads, size_t qkv_dim, - size_t seq_len) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t offset = pos * (kHeads + 2) * qkv_dim; - memset(dqkv + offset, 0, (kHeads + 1) * qkv_dim * sizeof(qkv[0])); - } - for (size_t head = 0; head < kHeads; ++head) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t qoffs = (pos * (kHeads + 2) + head) * qkv_dim; - const size_t aoffs = head * seq_len + pos * 
kHeads * seq_len; - const T* q = qkv + qoffs; - const T* dout = doutput + aoffs; - T* dq = dqkv + qoffs; - for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - const size_t koffs = (pos2 * (kHeads + 2) + kHeads) * qkv_dim; - const T* k = qkv + koffs; - T* dk = dqkv + koffs; - MulByConstAndAddT(dout[pos2], k, dq, qkv_dim); - MulByConstAndAddT(dout[pos2], q, dk, qkv_dim); - } - } - } -} - -template -void MaskedSoftmaxVJPT(const T* y, T* dy, size_t num_tokens, size_t kHeads, - size_t seq_len) { - for (size_t head = 0; head < kHeads; ++head) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - size_t offset = pos * kHeads * seq_len + head * seq_len; - SoftmaxVJPT(y + offset, dy + offset, pos + 1); - memset(dy + offset + pos + 1, 0, (seq_len - pos - 1) * sizeof(T)); - } - } -} - -template -void MixByAttentionVJP(const T* qkv, const T* attention, const T* doutput, - T* dqkv, T* dattention, size_t num_tokens, size_t kHeads, - size_t qkv_dim, size_t seq_len) { - auto v_offset = [&](size_t pos) { - return (pos * (kHeads + 2) + kHeads + 1) * qkv_dim; - }; - for (size_t pos = 0; pos < num_tokens; ++pos) { - memset(&dqkv[v_offset(pos)], 0, qkv_dim * sizeof(qkv[0])); - } - for (size_t head = 0; head < kHeads; ++head) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t offset = head * qkv_dim + pos * kHeads * qkv_dim; - const size_t aoffset = head * seq_len + pos * kHeads * seq_len; - const T* att = &attention[aoffset]; - const T* dout = &doutput[offset]; - T* datt = &dattention[aoffset]; - for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - datt[pos2] = DotT(dout, &qkv[v_offset(pos2)], qkv_dim); - MulByConstAndAddT(att[pos2], dout, &dqkv[v_offset(pos2)], qkv_dim); - } - } - } -} - -template -void InputEmbeddingVJPT(const T* w, const std::vector& tokens, T scaling, - const T* dy, T* dw, size_t N) { - const size_t num_tokens = tokens.empty() ? 
0 : tokens.size() - 1; - for (size_t i = 0; i < num_tokens; ++i) { - int token = tokens[i]; - MulByConstAndAddT(scaling, dy + i * N, dw + token * N, N); - } -} - -template -void LayerVJP(const LayerWeightsPtrs& weights, - const ForwardLayer& forward, const T* dy, - LayerWeightsPtrs& grad, ForwardLayer& backward, - size_t num_tokens) { - const LayerConfig& layer_config = weights.layer_config; - const size_t model_dim = layer_config.model_dim; - const size_t seq_len = forward.input.Rows(); - const size_t qkv_dim = layer_config.qkv_dim; - const size_t kHeads = layer_config.heads; - const size_t kFFHiddenDim = layer_config.ff_hidden_dim; - const T kQueryScale = 1.0 / std::sqrt(T(qkv_dim)); - - MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(), dy, - grad.linear_w.data(), backward.ffw_hidden_gated.data(), model_dim, - kFFHiddenDim, num_tokens); - - GatedGeluVJP(forward.ffw_hidden.data(), backward.ffw_hidden_gated.data(), - backward.ffw_hidden.data(), kFFHiddenDim, num_tokens); - - MatMulVJPT(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(), - backward.ffw_hidden.data(), grad.gating_einsum_w.data(), - backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, model_dim, - num_tokens); - - RMSNormVJPT(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(), - backward.bf_pre_ffw_rms_out.data(), - grad.pre_ffw_norm_scale.data(), backward.attention_out.data(), - model_dim, num_tokens); - - AddFromT(dy, backward.attention_out.data(), num_tokens * model_dim); - - MultiHeadMatMulVJPT(weights.attn_vec_einsum_w.data(), forward.att_out.data(), - backward.attention_out.data(), - grad.attn_vec_einsum_w.data(), backward.att_out.data(), - kHeads, model_dim, qkv_dim, num_tokens); - - MixByAttentionVJP(forward.qkv.data(), forward.att.data(), - backward.att_out.data(), backward.qkv.data(), - backward.att.data(), num_tokens, kHeads, qkv_dim, seq_len); - - MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(), num_tokens, kHeads, - seq_len); - - MaskedAttentionVJP(forward.qkv.data(), backward.att.data(), - backward.qkv.data(), num_tokens, kHeads, qkv_dim, seq_len); - - for (size_t pos = 0; pos < num_tokens; ++pos) { - T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim; - MulByConstT(kQueryScale, qkv, kHeads * qkv_dim); - } - - for (int pos = 0; pos < num_tokens; ++pos) { - T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim; - for (size_t h = 0; h <= kHeads; ++h) { - Rope(qkv + h * qkv_dim, qkv_dim, -pos); - } - } - - MatMulVJPT(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(), - backward.qkv.data(), grad.qkv_einsum_w.data(), - backward.pre_att_rms_out.data(), (kHeads + 2) * qkv_dim, model_dim, - num_tokens); - RMSNormVJPT(weights.pre_attention_norm_scale.data(), forward.input.data(), - backward.pre_att_rms_out.data(), - grad.pre_attention_norm_scale.data(), backward.input.data(), - model_dim, num_tokens); - - AddFromT(backward.attention_out.data(), backward.input.data(), - num_tokens * model_dim); -} - -template -void SoftcapVJPT(float cap, const T* y, T* dy, size_t N) { - const T inv_cap = T{1.0} / static_cast(cap); - for (size_t i = 0; i < N; ++i) { - T scaled = y[i] * inv_cap; // tanh - dy[i] *= (T{1.0} - scaled * scaled); - } -} - -template -void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) { - T scaling = -1.0 / std::log(2.0); - const std::vector tokens = prompt.tokens; - const size_t num_tokens = tokens.empty() ? 
0 : tokens.size() - 1; - memset(dx, 0, V * num_tokens * sizeof(x[0])); - for (size_t i = 0; i < num_tokens; ++i) { - if (i + 1 < prompt.context_size) { - continue; - } - const int next_token = tokens[i + 1]; - dx[i * V + next_token] = scaling / x[i * V + next_token]; - } -} - -template -void CrossEntropyLossBackwardPass(const Prompt& prompt, - const ModelWeightsPtrs& weights, - const ForwardPass& forward, - ModelWeightsPtrs& grad, - ForwardPass& backward) { - const ModelConfig& config = weights.weights_config; - const size_t model_dim = config.model_dim; - const size_t vocab_size = config.vocab_size; - const size_t layers = config.layer_configs.size(); - const std::vector tokens = prompt.tokens; - const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1; - - CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt, - vocab_size); - - SoftmaxVJPT(forward.probs.data(), backward.logits.data(), vocab_size, - num_tokens); - - if (config.final_cap > 0.0f) { - for (size_t i = 0; i < num_tokens; ++i) { - SoftcapVJPT(config.final_cap, forward.logits.data() + i * vocab_size, - backward.logits.data() + i * vocab_size, vocab_size); - } - } - - MatMulVJPT( - weights.embedder_input_embedding.data(), forward.final_norm_output.data(), - backward.logits.data(), grad.embedder_input_embedding.data(), - backward.final_norm_output.data(), vocab_size, model_dim, num_tokens); - - RMSNormVJPT(weights.final_norm_scale.data(), - forward.final_layer_output.data(), - backward.final_norm_output.data(), grad.final_norm_scale.data(), - backward.final_layer_output.data(), model_dim, num_tokens); - - for (int layer = static_cast(layers) - 1; layer >= 0; --layer) { - T* next_layer_grad = layer + 1 < layers - ? backward.layers[layer + 1].input.data() - : backward.final_layer_output.data(); - LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad, - *grad.GetLayer(layer), backward.layers[layer], num_tokens); - } - - const T kEmbScaling = EmbeddingScaling(model_dim); - InputEmbeddingVJPT(weights.embedder_input_embedding.data(), tokens, - kEmbScaling, backward.layers[0].input.data(), - grad.embedder_input_embedding.data(), model_dim); -} - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_ diff --git a/backprop/backward_scalar_test.cc b/backprop/backward_scalar_test.cc deleted file mode 100644 index e40f3ed..0000000 --- a/backprop/backward_scalar_test.cc +++ /dev/null @@ -1,635 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
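The `CrossEntropyLossGrad` and `SoftcapVJPT` definitions above are direct chain-rule consequences. The loss is measured in bits, hence the 1/ln 2 factor, and the softcap derivative can reuse the saved forward output:

```latex
% Loss in bits for the observed next token, with probability p:
L = -\log_2 p = -\frac{\ln p}{\ln 2}
\;\Longrightarrow\;
\frac{\partial L}{\partial p} = \frac{-1}{p \ln 2},
% which is exactly "scaling / x[i * V + next_token]" with
% scaling = -1 / log(2); all other entries of dx stay zero.
%
% Softcap y = c \tanh(x/c):
\frac{dy}{dx} = 1 - \tanh^{2}\!\frac{x}{c} = 1 - \left(\frac{y}{c}\right)^{2},
% so the VJP needs only the saved y ("scaled" = y/c in the code).
```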
- -#include "backprop/backward_scalar.h" - -#include -#include -#include // memcpy - -#include -#include -#include -#include - -#include "gtest/gtest.h" -#include "backprop/activations.h" -#include "backprop/common_scalar.h" -#include "backprop/forward_scalar.h" -#include "backprop/prompt.h" -#include "backprop/sampler.h" -#include "backprop/test_util.h" -#include "compression/compress.h" -#include "gemma/configs.h" -#include "gemma/weights.h" - -namespace gcpp { - -TEST(BackPropTest, MatMulVJP) { - static const size_t kRows = 8; - static const size_t kCols = 64; - static const size_t kTokens = 5; - std::mt19937 gen(42); - using T = double; - using TC = std::complex; - MatStorageT weights("weights", kRows, kCols); - MatStorageT x("x", kTokens, kCols); - MatStorageT grad("grad", kRows, kCols); - MatStorageT dx("dx", kTokens, kCols); - MatStorageT c_weights("c_weights", kRows, kCols); - MatStorageT c_x("c_x", kTokens, kCols); - MatStorageT c_y("c_y", kTokens, kRows); - MatStorageT dy("dy", kTokens, kRows); - - for (int iter = 0; iter < 10; ++iter) { - RandInit(weights, 1.0 * (1 << iter), gen); - RandInit(x, 1.0 * (1 << iter), gen); - RandInit(dy, 1.0, gen); - Complexify(weights, c_weights); - Complexify(x, c_x); - auto func = [&]() { - MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens); - return DotT(dy.data(), c_y.data(), kTokens * kRows); - }; - grad.ZeroInit(); - MatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(), - kRows, kCols, kTokens); - TestGradient(dx, c_x, func, 1e-11, 1e-12, __LINE__); - TestGradient(grad, c_weights, func, 1e-14, 1e-12, __LINE__); - } -} - -TEST(BackPropTest, MultiHeadMatMulVJP) { - static const size_t kRows = 2; - static const size_t kCols = 16; - static const size_t kHeads = 4; - static const size_t kTokens = 3; - std::mt19937 gen(42); - using T = double; - using TC = std::complex; - MatStorageT weights("weights", kRows, kCols * kHeads); - MatStorageT x("x", kTokens, kCols * kHeads); - MatStorageT grad("grad", kRows, kCols * kHeads); - MatStorageT dx("dx", kTokens, kCols * kHeads); - MatStorageT c_weights("c_weights", kRows, kCols * kHeads); - MatStorageT c_x("c_x", kTokens, kCols * kHeads); - MatStorageT c_y("c_y", kTokens, kRows); - MatStorageT dy("dy", kTokens, kRows); - - for (int iter = 0; iter < 10; ++iter) { - RandInit(weights, 1.0 * (1 << iter), gen); - RandInit(x, 1.0 * (1 << iter), gen); - RandInit(dy, 1.0, gen); - Complexify(weights, c_weights); - Complexify(x, c_x); - auto func = [&]() { - MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows, - kCols, kTokens); - return DotT(dy.data(), c_y.data(), kTokens * kRows); - }; - grad.ZeroInit(); - MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), - dx.data(), kHeads, kRows, kCols, kTokens); - TestGradient(dx, c_x, func, 1e-15, 1e-13, __LINE__); - TestGradient(grad, c_weights, func, 1e-15, 1e-13, __LINE__); - } -} - -TEST(BackPropTest, RMSNormVJP) { - static const size_t K = 2; - static const size_t N = 64; - std::mt19937 gen(42); - using T = double; - using TC = std::complex; - MatStorageT weights("weights", N, 1); - MatStorageT grad("grad", N, 1); - MatStorageT x("x", K, N); - MatStorageT dx("dx", K, N); - MatStorageT dy("dy", K, N); - MatStorageT c_weights("c_weights", N, 1); - MatStorageT c_x("c_x", K, N); - MatStorageT c_y("c_y", K, N); - - for (int iter = 0; iter < 10; ++iter) { - RandInit(weights, 1.0 * (1 << iter), gen); - RandInit(x, 1.0 * (1 << iter), gen); - Complexify(weights, c_weights); - Complexify(x, c_x); - 
RandInit(dy, 1.0, gen); - auto func = [&]() { - RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K); - return DotT(dy.data(), c_y.data(), K * N); - }; - grad.ZeroInit(); - RMSNormVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(), - N, K); - TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__); - TestGradient(grad, c_weights, func, 1e-15, 1e-14, __LINE__); - } -} - -TEST(BackPropTest, SoftmaxVJP) { - static const size_t N = 64; - std::mt19937 gen(42); - using T = double; - using TC = std::complex; - MatStorageT x("x", N, 1); - MatStorageT dx("dx", N, 1); - MatStorageT dy("dy", N, 1); - MatStorageT c_x("c_x", N, 1); - MatStorageT c_y("c_y", N, 1); - - for (int iter = 0; iter < 10; ++iter) { - RandInit(x, 1.0 * (1 << iter), gen); - Complexify(x, c_x); - RandInit(dy, 1.0, gen); - auto func = [&]() { - memcpy(c_y.data(), c_x.data(), c_x.SizeBytes()); - Softmax(c_y.data(), N); - return DotT(dy.data(), c_y.data(), N); - }; - Softmax(x.data(), N); - memcpy(dx.data(), dy.data(), dx.SizeBytes()); - SoftmaxVJPT(x.data(), dx.data(), N); - TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__); - } -} - -TEST(BackPropTest, MaskedSoftmaxVJP) { - static const size_t kSeqLen = 16; - static const size_t kHeads = 2; - static const size_t kTokens = 14; - static const size_t N = kTokens * kHeads * kSeqLen; - std::mt19937 gen(42); - using T = double; - using TC = std::complex; - MatStorageT x("x", N, 1); - MatStorageT dy("dy", N, 1); - MatStorageT dx("dx", N, 1); - MatStorageT c_x("c_x", N, 1); - MatStorageT c_y("c_y", N, 1); - dx.ZeroInit(); - - for (int iter = 0; iter < 10; ++iter) { - RandInit(x, 1.0 * (1 << iter), gen); - Complexify(x, c_x); - RandInit(dy, 1.0, gen); - auto func = [&]() { - memcpy(c_y.data(), c_x.data(), - kTokens * kHeads * kSeqLen * sizeof(c_x.At(0))); - MaskedSoftmax(c_y.data(), kTokens, kHeads, kSeqLen); - return DotT(dy.data(), c_y.data(), N); - }; - MaskedSoftmax(x.data(), kTokens, kHeads, kSeqLen); - memcpy(dx.data(), dy.data(), kTokens * kHeads * kSeqLen * sizeof(dx.At(0))); - MaskedSoftmaxVJPT(x.data(), dx.data(), kTokens, kHeads, kSeqLen); - TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__); - } -} - -TEST(BackPropTest, SoftcapVJP) { - static const size_t N = 64; - std::mt19937 gen(42); - using T = double; - using TC = std::complex; - MatStorageT x("x", N, 1); - MatStorageT dx("dx", N, 1); - MatStorageT dy("dy", N, 1); - MatStorageT c_x("c_x", N, 1); - MatStorageT c_y("c_y", N, 1); - - constexpr float kCap = 30.0f; - for (int iter = 0; iter < 10; ++iter) { - RandInit(x, 1.0 * (1 << iter), gen); - Complexify(x, c_x); - RandInit(dy, 1.0, gen); - auto func = [&]() { - memcpy(c_y.data(), c_x.data(), N * sizeof(c_x.At(0))); - Softcap(kCap, c_y.data(), N); - return DotT(dy.data(), c_y.data(), N); - }; - Softcap(kCap, x.data(), N); - memcpy(dx.data(), dy.data(), dx.SizeBytes()); - SoftcapVJPT(kCap, x.data(), dx.data(), N); - TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__); - } -} - -TEST(BackPropTest, CrossEntropyLossGrad) { - static const size_t K = 8; - static const size_t V = 64; - std::mt19937 gen(42); - using T = double; - using TC = std::complex; - MatStorageT x("x", K, V); - MatStorageT dx("dx", K, V); - MatStorageT c_x("c_x", K, V); - Prompt prompt; - prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 }; - - const float kCap = 30.0f; - for (int iter = 0; iter < 10; ++iter) { - prompt.context_size = 1 + (iter % 6); - RandInit(x, 1.0 * (1 << iter), gen); - Softcap(kCap, x.data(), V * K); - Softmax(x.data(), V, K); - CrossEntropyLossGrad(x.data(), dx.data(), prompt, 
V); - Complexify(x, c_x); - auto func = [&]() { - return CrossEntropyLoss(c_x.data(), prompt, V); - }; - TestGradient(dx, c_x, func, 1e-100, 1e-15, __LINE__); - } -} - -TEST(BackPropTest, GatedGeluVJP) { - static const size_t K = 2; - static const size_t N = 64; - std::mt19937 gen(42); - using T = double; - using TC = std::complex; - MatStorageT x("x", K, 2 * N); - MatStorageT dx("dx", K, 2 * N); - MatStorageT dy("dy", K, N); - MatStorageT c_x("c_x", K, 2 * N); - MatStorageT c_y("c_y", K, N); - - for (int iter = 0; iter < 10; ++iter) { - RandInit(x, 1.0, gen); - Complexify(x, c_x); - RandInit(dy, 1.0, gen); - auto func = [&]() { - GatedGelu(c_x.data(), c_y.data(), N, K); - return DotT(dy.data(), c_y.data(), N * K); - }; - GatedGeluVJP(x.data(), dy.data(), dx.data(), N, K); - TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__); - } -} - -TEST(BackPropTest, MaskedAttentionVJP) { - static const size_t kSeqLen = 16; - static const size_t kHeads = 2; - static const size_t kQKVDim = 8; - static const size_t kTokens = 14; - static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim; - static const size_t kOutSize = kTokens * kHeads * kSeqLen; - std::mt19937 gen(42); - using T = double; - using TC = std::complex; - MatStorageT x("x", kQKVSize, 1); - MatStorageT dx("dx", kQKVSize, 1); - MatStorageT dy("dy", kOutSize, 1); - MatStorageT c_x("c_x", kQKVSize, 1); - MatStorageT c_y("c_y", kOutSize, 1); - dx.ZeroInit(); - c_y.ZeroInit(); - - for (int iter = 0; iter < 10; ++iter) { - RandInit(x, 1.0, gen); - Complexify(x, c_x); - RandInit(dy, 1.0, gen); - auto func = [&]() { - MaskedAttention(c_x.data(), c_y.data(), kTokens, kHeads, kQKVDim, - kSeqLen); - return DotT(dy.data(), c_y.data(), kOutSize); - }; - MaskedAttentionVJP(x.data(), dy.data(), dx.data(), - kTokens, kHeads, kQKVDim, kSeqLen); - TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__); - } -} - -TEST(BackPropTest, MixByAttentionVJP) { - static const size_t kSeqLen = 16; - static const size_t kHeads = 2; - static const size_t kQKVDim = 8; - static const size_t kTokens = 14; - static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim; - static const size_t kAttnSize = kSeqLen * kHeads * kSeqLen; - static const size_t kOutSize = kSeqLen * kHeads * kQKVDim; - std::mt19937 gen(42); - using T = double; - using TC = std::complex; - MatStorageT qkv("qkv", kQKVSize, 1); - MatStorageT dqkv("dqkv", kQKVSize, 1); - MatStorageT attn("attn", kAttnSize, 1); - MatStorageT dattn("dattn", kAttnSize, 1); - MatStorageT dy("dy", kOutSize, 1); - MatStorageT c_qkv("c_qkv", kQKVSize, 1); - MatStorageT c_attn("c_attn", kAttnSize, 1); - MatStorageT c_y("c_y", kOutSize, 1); - dqkv.ZeroInit(); - dattn.ZeroInit(); - c_y.ZeroInit(); - - for (int iter = 0; iter < 10; ++iter) { - RandInit(qkv, 1.0, gen); - RandInit(attn, 1.0, gen); - Complexify(qkv, c_qkv); - Complexify(attn, c_attn); - RandInit(dy, 1.0, gen); - auto func = [&]() { - MixByAttention(c_qkv.data(), c_attn.data(), c_y.data(), - kTokens, kHeads, kQKVDim, kSeqLen); - return DotT(dy.data(), c_y.data(), kOutSize); - }; - MixByAttentionVJP(qkv.data(), attn.data(), dy.data(), dqkv.data(), - dattn.data(), kTokens, kHeads, kQKVDim, kSeqLen); - TestGradient(dqkv, c_qkv, func, 1e-14, 1e-15, __LINE__); - TestGradient(dattn, c_attn, func, 1e-14, 1e-15, __LINE__); - } -} - -TEST(BackPropTest, InputEmbeddingVJP) { - static const size_t kSeqLen = 8; - static const size_t kVocabSize = 4; - static const size_t kModelDim = 16; - std::mt19937 gen(42); - using T = double; - using TC = std::complex; - MatStorageT 
weights("weights", kVocabSize, kModelDim); - MatStorageT grad("grad", kVocabSize, kModelDim); - MatStorageT dy("dy", kSeqLen, kModelDim); - MatStorageT c_weights("c_weights", kVocabSize, kModelDim); - MatStorageT c_y("c_y", kSeqLen, kModelDim); - std::vector tokens = { 0, 1, 2, 3, 0, 1, 2 }; - size_t num_tokens = tokens.size() - 1; - - for (size_t iter = 0; iter < 10; ++iter) { - RandInit(weights, 1.0, gen); - RandInit(dy, 1.0, gen); - Complexify(weights, c_weights); - auto func = [&]() { - InputEmbedding(c_weights.data(), tokens, TC(3.0), c_y.data(), kModelDim); - return DotT(dy.data(), c_y.data(), num_tokens * kModelDim); - }; - grad.ZeroInit(); - InputEmbeddingVJPT(weights.data(), tokens, 3.0, dy.data(), grad.data(), - kModelDim); - TestGradient(grad, c_weights, func, 1e-16, 1e-14, __LINE__); - } -} - -static ModelConfig TestConfig() { - ModelConfig config; - config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w", - "gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"}; - config.model_dim = 32; - config.vocab_size = 12; - config.seq_len = 18; - LayerConfig layer_config; - layer_config.model_dim = config.model_dim; - layer_config.ff_hidden_dim = 48; - layer_config.heads = 3; - layer_config.kv_heads = 1; - layer_config.qkv_dim = 12; - config.layer_configs = {2, layer_config}; - config.num_tensor_scales = 4 * config.layer_configs.size(); - config.query_scale = QueryScaleType::SqrtKeySize; - config.attention_window_sizes = FixedAttentionWindowSizes<2>(32); - // This is required for optimize_test to pass. - config.final_cap = 30.0f; - return config; -} - -TEST(BackPropTest, LayerVJP) { - std::mt19937 gen(42); - using T = double; - using TC = std::complex; - ModelConfig config = TestConfig(); - TensorIndex tensor_index(config, /*llm_layer_idx=*/0, /*img_layer_idx=*/-1, - /*reshape_att=*/false); - const size_t kOutputSize = config.seq_len * config.model_dim; - LayerWeightsPtrs weights(config.layer_configs[0], tensor_index); - LayerWeightsPtrs grad(config.layer_configs[0], tensor_index); - ForwardLayer forward(config.layer_configs[0], config.seq_len); - ForwardLayer backward(config.layer_configs[0], config.seq_len); - LayerWeightsPtrs c_weights(config.layer_configs[0], tensor_index); - ForwardLayer c_forward(config.layer_configs[0], config.seq_len); - MatStorageT y("y", kOutputSize, 1); - MatStorageT dy("dy", kOutputSize, 1); - MatStorageT c_y("c_y", kOutputSize, 1); - const size_t num_tokens = 3; - std::vector layer_storage; - weights.Allocate(layer_storage); - grad.Allocate(layer_storage); - c_weights.Allocate(layer_storage); - backward.input.ZeroInit(); - - for (size_t iter = 0; iter < 10; ++iter) { - RandInit(weights, 1.0, gen); - RandInit(forward.input, 1.0, gen); - RandInit(dy, 1.0, gen); - Complexify(weights, c_weights); - Complexify(forward.input, c_forward.input); - auto func = [&]() { - ApplyLayer(c_weights, c_forward, num_tokens, c_y.data()); - return DotT(dy.data(), c_y.data(), num_tokens * config.model_dim); - }; - grad.ZeroInit(/*layer_idx=*/0); - ApplyLayer(weights, forward, num_tokens, y.data()); - LayerVJP(weights, forward, dy.data(), grad, backward, num_tokens); - TestGradient(backward.input, c_forward.input, func, 1e-11, 5e-11, - __LINE__); - TestGradient(grad, c_weights, func, 1e-11); - } -} - -TEST(BackPropTest, EndToEnd) { - std::mt19937 gen(42); - using T = double; - using TC = std::complex; - ModelConfig config = TestConfig(); - WeightsWrapper weights(config); - WeightsWrapper grad(config); - ForwardPass forward(config); - ForwardPass backward(config); - 
WeightsWrapper c_weights(config); - ForwardPass c_forward(config); - - ReverseSequenceSampler training_task({0, 0, 1, 1}); - std::vector batch = training_task.SampleBatch(3, gen); - - for (const Prompt& prompt : batch) { - ReverseSequenceSampler::LogPrompt(prompt); - RandInit(weights.get(), 1.0, gen); - CrossEntropyLossForwardPass(prompt, weights.get(), forward); - grad.ZeroInit(); - CrossEntropyLossBackwardPass( - prompt, weights.get(), forward, grad.get(), backward); - - Complexify(weights.get(), c_weights.get()); - auto func = [&]() { - return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward); - }; - - TestGradient(grad.get(), c_weights.get(), func, 1e-11); - } -} - -template -void MulByConstAndAddT(T c, const LayerWeightsPtrs& x, - LayerWeightsPtrs& out) { - MulByConstAndAddT(c, x.pre_attention_norm_scale, - out.pre_attention_norm_scale); - MulByConstAndAddT(c, x.attn_vec_einsum_w, out.attn_vec_einsum_w); - MulByConstAndAddT(c, x.qkv_einsum_w, out.qkv_einsum_w); - MulByConstAndAddT(c, x.pre_ffw_norm_scale, out.pre_ffw_norm_scale); - MulByConstAndAddT(c, x.gating_einsum_w, out.gating_einsum_w); - MulByConstAndAddT(c, x.linear_w, out.linear_w); -} - -template -void MulByConstAndAddT(T c, const ModelWeightsPtrs& x, - ModelWeightsPtrs& out) { - const size_t layers = x.c_layers.size(); - MulByConstAndAddT(c, x.embedder_input_embedding, - out.embedder_input_embedding); - MulByConstAndAddT(c, x.final_norm_scale, out.final_norm_scale); - for (size_t i = 0; i < layers; ++i) { - MulByConstAndAddT(c, *x.GetLayer(i), *out.GetLayer(i)); - } -} - -// Evaluates forward pass on a batch. -template -T CrossEntropyLossForwardPass(const std::vector& batch, - const WeightsWrapper& weights, - ForwardPass& forward) { - T loss = 0.0; - for (const Prompt& prompt : batch) { - loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward); - } - T scale = 1.0 / batch.size(); - return loss * scale; -} - -// Evaluates forward pass on a batch by applying gradient with the given -// learning rate. Does not update weights, but uses the given tmp weights -// instead. -template -T CrossEntropyLossForwardPass(T learning_rate, const std::vector& batch, - const WeightsWrapper& weights, - const WeightsWrapper& grad, - WeightsWrapper& tmp, ForwardPass& forward) { - tmp.CopyFrom(weights); - const T scale = -learning_rate / batch.size(); - MulByConstAndAddT(scale, grad.get(), tmp.get()); - return CrossEntropyLossForwardPass(batch, tmp, forward); -} - -// Uses line search in the negative gradient direction to update weights. We do -// this so that we can test that each step during the gradient descent can -// decrease the objective function value. 
-template -T FindOptimalUpdate(const WeightsWrapper& grad, WeightsWrapper& weights, - WeightsWrapper& tmp, ForwardPass& forward, - const std::vector& batch, T loss, - T initial_learning_rate) { - T lr0 = initial_learning_rate; - T loss0 = CrossEntropyLossForwardPass( - lr0, batch, weights, grad, tmp, forward); - for (size_t iter = 0; iter < 30; ++iter) { - T lr1 = lr0 * 0.5; - T loss1 = CrossEntropyLossForwardPass( - lr1, batch, weights, grad, tmp, forward); - if (loss0 < loss && loss1 >= loss0) { - break; - } - loss0 = loss1; - lr0 = lr1; - } - for (size_t iter = 0; iter < 30; ++iter) { - T lr1 = lr0 * 2.0; - T loss1 = CrossEntropyLossForwardPass( - lr1, batch, weights, grad, tmp, forward); - if (loss1 >= loss0) { - break; - } - loss0 = loss1; - lr0 = lr1; - } - const T scale = -lr0 / batch.size(); - MulByConstAndAddT(scale, grad.get(), weights.get()); - return lr0; -} - -TEST(BackProptest, Convergence) { - std::mt19937 gen(42); - using T = float; - using TC = std::complex; - ModelConfig config = TestConfig(); - WeightsWrapper weights(config); - WeightsWrapper grad(config); - WeightsWrapper tmp(config); - ForwardPass forward(config); - ForwardPass backward(config); - WeightsWrapper c_weights(config); - ForwardPass c_forward(config); - constexpr size_t kBatchSize = 5; - ReverseSequenceSampler training_task({0, 0, 0, 1, 1}); - T learning_rate = 0.01; - - RandInit(weights.get(), T(1.0), gen); - - printf("Sample batch:\n"); - for (size_t i = 0; i < 10; ++i) { - ReverseSequenceSampler::LogPrompt(training_task.Sample(gen)); - } - - T prev_loss = std::numeric_limits::max(); - bool stop = false; - size_t step = 0; - while (!stop) { - T loss = 0.0; - grad.ZeroInit(); - std::mt19937 sgen(42); - std::vector batch = training_task.SampleBatch(kBatchSize, sgen); - for (const Prompt& prompt : batch) { - loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward); - CrossEntropyLossBackwardPass( - prompt, weights.get(), forward, grad.get(), backward); - } - - if (step % 250 == 0) { - printf("Checking gradient...\n"); - Complexify(weights.get(), c_weights.get()); - auto func = [&]() { - TC scale = batch.size(); - return CrossEntropyLossForwardPass(batch, c_weights, c_forward) * scale; - }; - - TestGradient(grad.get(), c_weights.get(), func, 5e-3f); - } - - loss /= batch.size(); - EXPECT_LT(loss, prev_loss); - stop = step >= 10000 || loss < 1e-2; - if (step % 10 == 0 || stop) { - printf("step: %5zu loss: %.15f learning_rate: %.15f\n", - step, loss, learning_rate); - } - if (!stop) { - learning_rate = FindOptimalUpdate( - grad, weights, tmp, forward, batch, loss, learning_rate); - ++step; - } - prev_loss = loss; - } - EXPECT_LT(step, 1000); -} - -} // namespace gcpp diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc deleted file mode 100644 index f1c97b2..0000000 --- a/backprop/backward_test.cc +++ /dev/null @@ -1,279 +0,0 @@ -// Copyright 2023 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef HWY_DISABLED_TARGETS -#define HWY_DISABLED_TARGETS HWY_SCALAR -#endif - -#include - -#include -#include // std::abs -#include -#include - -#include "backprop/activations.h" -#include "backprop/backward_scalar.h" -#include "backprop/common_scalar.h" -#include "backprop/forward_scalar.h" -#include "backprop/prompt.h" -#include "backprop/sampler.h" -#include "backprop/test_util.h" -#include "gemma/configs.h" -#include "ops/ops.h" -#include "util/threading.h" -#include "hwy/base.h" - -// clang-format off -#undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "backprop/backward_test.cc" //NOLINT -// clang-format on -#include "hwy/foreach_target.h" // IWYU pragma: keep -#include "hwy/highway.h" -#include "hwy/tests/test_util-inl.h" -// After highway.h -#include "backprop/backward-inl.h" -#include "backprop/forward-inl.h" -#include "compression/compress.h" -#include "ops/ops-inl.h" -#include "util/allocator.h" - -HWY_BEFORE_NAMESPACE(); -namespace gcpp { -namespace HWY_NAMESPACE { - -void TestMatMulVJP() { - static const size_t kRows = 8; - static const size_t kCols = 64; - static const size_t kTokens = 5; - const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8)); - Allocator::Init(topology); - gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse); - std::mt19937 gen(42); - MatStorageT weights("weights", kRows, kCols); - MatStorageT x("x", kTokens, kCols); - MatStorageT dy("dy", kTokens, kRows); - MatStorageT grad("grad", kRows, kCols); - MatStorageT dx("dx", kTokens, kCols); - MatStorageT grad_scalar("grad_scalar", kRows, kCols); - MatStorageT dx_scalar("dx_scalar", kTokens, kCols); - using TC = std::complex; - MatStorageT c_weights("c_weights", kRows, kCols); - MatStorageT c_x("c_x", kTokens, kCols); - MatStorageT c_y("c_y", kTokens, kRows); - - for (int iter = 0; iter < 10; ++iter) { - RandInit(weights, 1.0f * (1 << iter), gen); - RandInit(x, 1.0f * (1 << iter), gen); - RandInit(dy, 1.0f, gen); - Complexify(weights, c_weights); - Complexify(x, c_x); - auto func = [&]() { - MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens); - return DotT(dy.data(), c_y.data(), kTokens * kRows); - }; - - grad.ZeroInit(); - MatMulVJP(weights.data(), x.data(), dy.data(), kCols, kRows, kTokens, - grad.data(), dx.data(), pools.Pool()); - TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__); - TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__); - - grad_scalar.ZeroInit(); - MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), - dx_scalar.data(), kRows, kCols, kTokens); - TestNear(dx, dx_scalar, 5e-5, 1e-4, __LINE__); - TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__); - } -} - -void TestMultiHeadMatMulVJP() { - static const size_t kRows = 2; - static const size_t kCols = 16; - static const size_t kHeads = 4; - static const size_t kTokens = 3; - const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8)); - Allocator::Init(topology); - gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse); - std::mt19937 gen(42); - MatStorageT weights("weights", kRows, kCols * kHeads); - MatStorageT x("x", kTokens, kCols * kHeads); - MatStorageT grad("grad", kRows, kCols * kHeads); - MatStorageT dx("dx", kTokens, kCols * kHeads); - MatStorageT dy("dy", kTokens, kRows); - MatStorageT grad_scalar("grad_scalar", kRows, kCols * kHeads); - MatStorageT dx_scalar("dx_scalar", kTokens, kCols * kHeads); - using TC = std::complex; - MatStorageT c_weights("c_weights", kRows, kCols * kHeads); - MatStorageT c_x("c_x", kTokens, kCols * kHeads); - 
MatStorageT c_y("c_y", kTokens, kRows); - - for (int iter = 0; iter < 10; ++iter) { - RandInit(weights, 1.0f * (1 << iter), gen); - RandInit(x, 1.0f * (1 << iter), gen); - RandInit(dy, 1.0f, gen); - Complexify(weights, c_weights); - Complexify(x, c_x); - auto func = [&]() { - MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows, - kCols, kTokens); - return DotT(dy.data(), c_y.data(), kTokens * kRows); - }; - - grad.ZeroInit(); - MultiHeadMatMulVJP(weights.data(), x.data(), dy.data(), kHeads, kCols, - kRows, kTokens, grad.data(), dx.data(), pools.Pool()); - TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__); - TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__); - - grad_scalar.ZeroInit(); - MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), - dx_scalar.data(), kHeads, kRows, kCols, kTokens); - TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__); - TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__); - } -} - -void TestRMSNormVJP() { - static const size_t K = 2; - static const size_t N = 64; - const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8)); - Allocator::Init(topology); - gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse); - std::mt19937 gen(42); - MatStorageT weights("weights", N, 1); - MatStorageT x("x", K, N); - MatStorageT grad("grad", N, 1); - MatStorageT dx("dx", K, N); - MatStorageT dy("dy", K, N); - MatStorageT grad_scalar("grad_scalar", N, 1); - MatStorageT dx_scalar("dx_scalar", K, N); - using TC = std::complex; - MatStorageT c_weights("c_weights", N, 1); - MatStorageT c_x("c_x", K, N); - MatStorageT c_y("c_y", K, N); - - for (int iter = 0; iter < 10; ++iter) { - RandInit(weights, 1.0f * (1 << iter), gen); - RandInit(x, 1.0f * (1 << iter), gen); - RandInit(dy, 1.0f, gen); - Complexify(weights, c_weights); - Complexify(x, c_x); - auto func = [&]() { - RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K); - return DotT(dy.data(), c_y.data(), K * N); - }; - - grad.ZeroInit(); - RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(), - dx.data(), pools.Pool()); - TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__); - TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__); - - grad_scalar.ZeroInit(); - RMSNormVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), - dx_scalar.data(), N, K); - TestNear(dx, dx_scalar, 0, 2e-5, __LINE__); - TestNear(grad, grad_scalar, 0, 2e-5, __LINE__); - } -} - -static ModelConfig TestConfig() { - ModelConfig config; - config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w", - "gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"}; - config.model_dim = 32; - config.vocab_size = 16; - config.seq_len = 24; - LayerConfig layer_config; - layer_config.model_dim = config.model_dim; - layer_config.ff_hidden_dim = 64; - layer_config.heads = 3; - layer_config.kv_heads = 1; - layer_config.qkv_dim = 16; - config.layer_configs = {2, layer_config}; - config.num_tensor_scales = 4 * config.layer_configs.size(); - config.query_scale = QueryScaleType::SqrtKeySize; - config.attention_window_sizes = FixedAttentionWindowSizes<2>(32); - // This is required for optimize_test to pass. 
- config.att_cap = 50.0f; - config.final_cap = 30.0f; - return config; -} - -void TestEndToEnd() { - std::mt19937 gen(42); - const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1)); - Allocator::Init(topology); - gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse); - ModelConfig config = TestConfig(); - WeightsWrapper weights(config); - WeightsWrapper grad(config); - ForwardPass forward0(config); - ForwardPass forward1(config); - ForwardPass backward(config); - using TC = std::complex; - WeightsWrapper c_weights(config); - ForwardPass c_forward(config); - - ReverseSequenceSampler training_task({0, 0, 1, 1}); - std::vector batch = training_task.SampleBatch(3, gen); - - RowVectorBatch inv_timescale = CreateInvTimescale( - config.layer_configs[0].qkv_dim, - config.layer_configs[0].post_qk == PostQKType::HalfRope); - for (const Prompt& prompt : batch) { - ReverseSequenceSampler::LogPrompt(prompt); - RandInit(weights.get(), 1.0f, gen); - - float loss0 = CrossEntropyLossForwardPass(prompt, weights.get(), forward0); - - float loss1 = CrossEntropyLossForwardPass( - prompt.tokens, prompt.context_size, weights.get(), forward1, - inv_timescale, pools.Pool()); - - EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5); - - grad.ZeroInit(); - CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(), - backward, inv_timescale, pools.Pool()); - - Complexify(weights.get(), c_weights.get()); - auto func = [&]() { - return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward); - }; - - TestGradient(grad.get(), c_weights.get(), func, 2e-3f); - } -} - -// NOLINTNEXTLINE(google-readability-namespace-comments) -} // namespace HWY_NAMESPACE -} // namespace gcpp -HWY_AFTER_NAMESPACE(); - -#if HWY_ONCE - -namespace gcpp { -HWY_BEFORE_TEST(BackwardTest); -HWY_EXPORT_AND_TEST_P(BackwardTest, TestMatMulVJP); -HWY_EXPORT_AND_TEST_P(BackwardTest, TestMultiHeadMatMulVJP); -HWY_EXPORT_AND_TEST_P(BackwardTest, TestRMSNormVJP); -HWY_EXPORT_AND_TEST_P(BackwardTest, TestEndToEnd); -HWY_AFTER_TEST(); - -} // namespace gcpp - -#endif diff --git a/backprop/common_scalar.h b/backprop/common_scalar.h deleted file mode 100644 index c61086d..0000000 --- a/backprop/common_scalar.h +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
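The common_scalar.h file deleted next contains the scalar RoPE reference these tests compare against. Per dimension pair d it applies a plain 2-D rotation whose angle grows with the token position:

```latex
% Rope(x, base, N, i): each pair (x_d, x_{d+N/2}) is rotated by
\theta_{d,i} = \frac{i}{\text{base}^{\,2d/N}}, \qquad
\begin{pmatrix} x_d \\ x_{d+N/2} \end{pmatrix} \mapsto
\begin{pmatrix} \cos\theta_{d,i} & -\sin\theta_{d,i} \\
                \sin\theta_{d,i} & \phantom{-}\cos\theta_{d,i} \end{pmatrix}
\begin{pmatrix} x_d \\ x_{d+N/2} \end{pmatrix},
% with base = 10000 in the two convenience overloads below.
```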
- -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_ - -#include - -#include - -#include "compression/compress.h" // MatStorageT - -namespace gcpp { - -template -U DotT(const T* a, const U* b, size_t N) { - U sum = {}; - for (size_t i = 0; i < N; ++i) { - sum += a[i] * b[i]; - } - return sum; -} - -template<> -inline std::complex DotT(const float* a, const std::complex* b, - size_t N) { - std::complex sum = {}; - for (size_t i = 0; i < N; ++i) { - sum += static_cast(a[i]) * b[i]; - } - return sum; -} - -template -void MulByConstT(T c, T* x, size_t N) { - for (size_t i = 0; i < N; ++i) { - x[i] *= c; - } -} - -// out += c * x -template -void MulByConstAndAddT(T c, const T* x, T* out, size_t N) { - for (size_t i = 0; i < N; ++i) { - out[i] += c * x[i]; - } -} - -template -void MulByConstAndAddT(T c, const MatPtrT& x, MatPtrT& out) { - MulByConstAndAddT(c, x.data(), out.data(), x.NumElements()); -} - -template -void AddFromT(const T* a, T* out, size_t N) { - for (size_t i = 0; i < N; ++i) { - out[i] += a[i]; - } -} - -template -T SquaredL2(const T* x, size_t N) { - T sum = {}; - for (size_t i = 0; i < N; ++i) { - sum += x[i] * x[i]; - } - return sum; -} - -template -T Gelu(T x) { - static const T kMul = 0.044715; - static const T kSqrt2OverPi = 0.797884560804236; - - const T x3 = x * x * x; - const T arg = kSqrt2OverPi * (kMul * x3 + x); - const T cdf = T(0.5) * (T(1.0) + std::tanh(arg)); - return x * cdf; -} - -template -void Rope(T* x, U base, size_t N, int i) { - const size_t N2 = N / 2; - for (size_t dim = 0; dim < N2; ++dim) { - const T freq_exponents = T(2 * dim) / T(N); - const T timescale = std::pow(base, freq_exponents); - const T theta = T(i) / timescale; - const T cos_val = std::cos(theta); - const T sin_val = std::sin(theta); - const T x0 = x[dim]; - const T x1 = x[dim + N2]; - x[dim] = x0 * cos_val - x1 * sin_val; - x[dim + N2] = x0 * sin_val + x1 * cos_val; - } -} - -template -void Rope(T* x, size_t N, int i) { - Rope(x, T(10000.0), N, i); -} - -template -void Rope(std::complex* x, size_t N, int i) { - Rope(x, T(10000.0), N, i); -} - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_ diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h deleted file mode 100644 index ca969c4..0000000 --- a/backprop/forward-inl.h +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Include guard for non-SIMD code. -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_ - -#include -#include - -#include -#include - -#include "backprop/activations.h" -#include "gemma/common.h" -#include "gemma/configs.h" -#include "gemma/weights.h" -#include "util/allocator.h" -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_ - -// Include guard for (potentially) SIMD code. 
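The `Gelu` in common_scalar.h above (and the SIMD `Gelu`/`DGelu` used by the forward and backward kernels) implements the usual tanh approximation of GELU; the constants `kMul` and `kSqrt2OverPi` in the deleted code are exactly those in:

```latex
\operatorname{GELU}(x) \approx \frac{x}{2}\left(1 + \tanh\!\Big(
  \sqrt{\tfrac{2}{\pi}}\,\big(x + 0.044715\,x^{3}\big)\Big)\right),
% with kSqrt2OverPi = sqrt(2/pi) = 0.797884560804236 and kMul = 0.044715;
% GeluDerivative in backward_scalar.h is this expression differentiated
% by the product and chain rules.
```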
-#if defined(THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE) -#ifdef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE -#undef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE -#else -#define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE -#endif - -#include "hwy/highway.h" -// After highway.h -#include "ops/matvec-inl.h" -#include "ops/ops-inl.h" - -HWY_BEFORE_NAMESPACE(); -namespace gcpp { -namespace HWY_NAMESPACE { - -template -void InputEmbedding(const ArrayT& weights, const std::vector& prompt, - const float scaling, float* HWY_RESTRICT output, - size_t model_dim, size_t vocab_size) { - const hn::ScalableTag df; - HWY_ASSERT(!prompt.empty()); - for (size_t pos = 0; pos < prompt.size() - 1; ++pos) { - int token = prompt[pos]; - DecompressAndZeroPad(df, MakeSpan(weights.data(), model_dim * vocab_size), - token * model_dim, output + pos * model_dim, - model_dim); - MulByConst(scaling, output + pos * model_dim, model_dim); - } -} - -template -void ApplyRMSNorm(const WT* HWY_RESTRICT weights, const XT* HWY_RESTRICT x, - size_t model_dim, size_t num_tokens, - OutT* HWY_RESTRICT output, - hwy::ThreadPool& pool) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t offset = pos * model_dim; - RMSNorm(x + offset, weights, output + offset, model_dim); - } -} - -static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs, - const std::vector& prompt, - size_t context_size, - size_t vocab_size, - hwy::ThreadPool& pool) { - HWY_ASSERT(!prompt.empty()); - float loss = 0.0f; - for (size_t pos = 0; pos < prompt.size() - 1; ++pos) { - if (pos + 1 < context_size) { - continue; // next token is part of context, don't try to predict it - } - const int next_token = prompt[pos + 1]; - loss += std::log(probs[pos * vocab_size + next_token]); - } - float scaling = -1.0 / std::log(2.0); - return loss * scaling; -} - -template -void ApplyForwardLayer(const LayerWeightsPtrs& weights, - ForwardLayer& activations, size_t num_tokens, - float* HWY_RESTRICT output, - const RowVectorBatch& inv_timescale, - hwy::ThreadPool& pool) { - const LayerConfig& config = weights.layer_config; - const size_t model_dim = config.model_dim; - const size_t kSeqLen = activations.input.Rows(); - const size_t kQKVDim = config.qkv_dim; - const size_t kHeads = config.heads; - static const float query_scale = - static_cast(1.0 / sqrt(static_cast(kQKVDim))); - HWY_ASSERT(num_tokens <= kSeqLen); - - ApplyRMSNorm(weights.pre_attention_norm_scale.data(), - activations.input.data(), model_dim, num_tokens, - activations.pre_att_rms_out.data(), pool); - - for (size_t pos = 0; pos < num_tokens; ++pos) { - MatVec(weights.qkv_einsum_w, 0, (kHeads + 2) * kQKVDim, model_dim, - activations.pre_att_rms_out.data() + pos * model_dim, - activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool); - } - const size_t num_tasks = kHeads * num_tokens; - - for (size_t pos = 0; pos < num_tokens; ++pos) { - float* HWY_RESTRICT k = - activations.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim; - Rope(k, kQKVDim, inv_timescale.Const(), pos); - } - pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { - const size_t head = task % kHeads; - const size_t pos = task / kHeads; - float* HWY_RESTRICT q = - activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim; - Rope(q, kQKVDim, inv_timescale.Const(), pos); - MulByConst(query_scale, q, kQKVDim); - }); - - pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { - const size_t head = task % kHeads; - const size_t pos = task / kHeads; - const float* 
HWY_RESTRICT q = - activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim; - float* HWY_RESTRICT head_att = - activations.att.data() + (pos * kHeads + head) * kSeqLen; - for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - const float* HWY_RESTRICT k2 = - activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim; - const float score = Dot(q, k2, kQKVDim); - head_att[pos2] = score; - } - }); - - pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { - const size_t head = task % kHeads; - const size_t pos = task / kHeads; - float* HWY_RESTRICT head_att = - activations.att.data() + (pos * kHeads + head) * kSeqLen; - Softmax(head_att, pos + 1); - }); - - pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { - const size_t head = task % kHeads; - const size_t pos = task / kHeads; - const float* HWY_RESTRICT head_att = - activations.att.data() + (pos * kHeads + head) * kSeqLen; - float* HWY_RESTRICT att_out = - activations.att_out.data() + (pos * kHeads + head) * kQKVDim; - hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); - for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - float* HWY_RESTRICT v2 = - activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim; - MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim); - } - }); - - activations.attention_out.ZeroInit(); - for (size_t pos = 0; pos < num_tokens; ++pos) { - for (size_t head = 0; head < kHeads; ++head) { - MatVec( - weights.attn_vec_einsum_w, head * model_dim * kQKVDim, model_dim, - kQKVDim, - activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim, - activations.att_post1.data() + pos * model_dim, pool); - AddFrom(activations.att_post1.data() + pos * model_dim, - activations.attention_out.data() + pos * model_dim, model_dim); - } - } - - for (size_t pos = 0; pos < num_tokens; ++pos) { - AddFrom(activations.input.data() + pos * model_dim, - activations.attention_out.data() + pos * model_dim, model_dim); - } - - ApplyRMSNorm(weights.pre_ffw_norm_scale.data(), - activations.attention_out.data(), model_dim, num_tokens, - activations.bf_pre_ffw_rms_out.data(), pool); - const size_t kFFHiddenDim = config.ff_hidden_dim; - for (size_t pos = 0; pos < num_tokens; ++pos) { - MatVec(weights.gating_einsum_w, 0, kFFHiddenDim * 2, model_dim, - activations.bf_pre_ffw_rms_out.data() + pos * model_dim, - activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool); - } - for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t hidden_offset = pos * kFFHiddenDim * 2; - const float* HWY_RESTRICT out = - activations.ffw_hidden.data() + hidden_offset; - const float* HWY_RESTRICT out_mul = out + kFFHiddenDim; - float* HWY_RESTRICT out_gated = - activations.ffw_hidden_gated.data() + pos * kFFHiddenDim; - namespace hn = hwy::HWY_NAMESPACE; - using DF = hn::ScalableTag; - DF df; - for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) { - const auto y = hn::Load(df, out + i); - const auto x = hn::Load(df, out_mul + i); - hn::Store(hn::Mul(x, Gelu(df, y)), df, out_gated + i); - } - } - for (size_t pos = 0; pos < num_tokens; ++pos) { - MatVec(weights.linear_w, 0, model_dim, kFFHiddenDim, - activations.ffw_hidden_gated.data() + pos * kFFHiddenDim, - output + pos * model_dim, pool); - } - for (size_t pos = 0; pos < num_tokens; ++pos) { - AddFrom(activations.attention_out.data() + pos * model_dim, - output + pos * model_dim, model_dim); - } -} - -template -float CrossEntropyLossForwardPass(const std::vector& prompt, - size_t context_size, - const ModelWeightsPtrs& weights, - ForwardPass& 
forward, - const RowVectorBatch& inv_timescale, - hwy::ThreadPool& pool) { - const ModelConfig& config = weights.weights_config; - const size_t vocab_size = config.vocab_size; - const size_t model_dim = config.model_dim; - const size_t layers = config.layer_configs.size(); - const float emb_scaling = EmbeddingScaling(model_dim); - HWY_ASSERT(!config.absolute_pe); - HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None); - HWY_ASSERT(config.layer_configs[0].kv_heads == 1); - - HWY_DASSERT(context_size > 0); - HWY_DASSERT(context_size < prompt.size()); - const size_t num_tokens = prompt.size() - 1; - - InputEmbedding(weights.embedder_input_embedding, prompt, emb_scaling, - forward.layers[0].input.data(), model_dim, vocab_size); - - for (size_t layer = 0; layer < config.layer_configs.size(); ++layer) { - auto type = config.layer_configs[layer].type; - // TODO(szabadka) Implement Griffin layer. - HWY_ASSERT(type == LayerAttentionType::kGemma); - float* HWY_RESTRICT output = layer + 1 < layers - ? forward.layers[layer + 1].input.data() - : forward.final_layer_output.data(); - ApplyForwardLayer(*weights.GetLayer(layer), forward.layers[layer], - num_tokens, output, inv_timescale, pool); - } - - ApplyRMSNorm(weights.final_norm_scale.data(), - forward.final_layer_output.data(), model_dim, num_tokens, - forward.final_norm_output.data(), pool); - - for (size_t pos = 0; pos < num_tokens; ++pos) { - MatVec(weights.embedder_input_embedding, 0, vocab_size, model_dim, - forward.final_norm_output.data() + pos * model_dim, - forward.logits.data() + pos * vocab_size, pool); - } - - if (config.final_cap > 0.0f) { - for (size_t pos = 0; pos < num_tokens; ++pos) { - LogitsSoftCap(config.final_cap, forward.logits.data() + pos * vocab_size, - vocab_size); - } - } - - hwy::CopyBytes(forward.logits.data(), forward.probs.data(), - num_tokens * vocab_size * sizeof(forward.logits.At(0))); - - for (size_t pos = 0; pos < num_tokens; ++pos) { - Softmax(forward.probs.data() + pos * vocab_size, vocab_size); - } - - return CrossEntropyLoss(forward.probs.data(), prompt, context_size, - vocab_size, pool); -} - -// NOLINTNEXTLINE(google-readability-namespace-comments) -} // namespace HWY_NAMESPACE -} // namespace gcpp -HWY_AFTER_NAMESPACE(); - -#endif // NOLINT diff --git a/backprop/forward.cc b/backprop/forward.cc deleted file mode 100644 index 0c6cc5c..0000000 --- a/backprop/forward.cc +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "backprop/forward.h" - -#include "backprop/activations.h" -#include "backprop/prompt.h" -#include "gemma/common.h" -#include "gemma/configs.h" -#include "util/allocator.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -// Compiles this file for multiple architectures via "foreach_target.h", to -// which we pass the filename via macro 'argument'. 
-#undef HWY_TARGET_INCLUDE
-#define HWY_TARGET_INCLUDE "backprop/forward.cc"  // NOLINT
-#include "hwy/foreach_target.h"  // IWYU pragma: keep
-
-#include "hwy/highway.h"
-// After highway.h
-#include "backprop/forward-inl.h"
-#include "gemma/weights.h"
-
-HWY_BEFORE_NAMESPACE();
-namespace gcpp {
-namespace HWY_NAMESPACE {
-
-float CrossEntropyLossForwardPassT(const Prompt& prompt,
-                                   const ModelWeightsPtrs<float>& weights,
-                                   ForwardPass<float>& forward,
-                                   RowVectorBatch<float>& inv_timescale,
-                                   hwy::ThreadPool& pool) {
-  return CrossEntropyLossForwardPass(prompt.tokens, prompt.context_size,
-                                     weights, forward, inv_timescale, pool);
-}
-
-}  // namespace HWY_NAMESPACE
-}  // namespace gcpp
-HWY_AFTER_NAMESPACE();
-
-#if HWY_ONCE
-namespace gcpp {
-
-HWY_EXPORT(CrossEntropyLossForwardPassT);
-
-float CrossEntropyLossForwardPass(const Prompt& prompt,
-                                  const ModelWeightsPtrs<float>& weights,
-                                  ForwardPass<float>& forward,
-                                  RowVectorBatch<float>& inv_timescale,
-                                  hwy::ThreadPool& pool) {
-  return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)(
-      prompt, weights, forward, inv_timescale, pool);
-}
-
-}  // namespace gcpp
-#endif  // HWY_ONCE
diff --git a/backprop/forward.h b/backprop/forward.h
deleted file mode 100644
index 3b42298..0000000
--- a/backprop/forward.h
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright 2024 Google LLC
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
-#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
-
-#include "backprop/activations.h"
-#include "backprop/prompt.h"
-#include "gemma/weights.h"
-#include "util/allocator.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
-
-namespace gcpp {
-
-float CrossEntropyLossForwardPass(const Prompt& prompt,
-                                  const ModelWeightsPtrs<float>& weights,
-                                  ForwardPass<float>& forward,
-                                  RowVectorBatch<float>& inv_timescale,
-                                  hwy::ThreadPool& pool);
-
-}  // namespace gcpp
-
-#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
diff --git a/backprop/forward_scalar.h b/backprop/forward_scalar.h
deleted file mode 100644
index 617d0c3..0000000
--- a/backprop/forward_scalar.h
+++ /dev/null
@@ -1,294 +0,0 @@
-// Copyright 2024 Google LLC
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
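[Editor's note] The deleted backprop/forward.cc above uses Highway's foreach_target pattern: the translation unit is recompiled once per SIMD target, and HWY_EXPORT / HWY_DYNAMIC_DISPATCH select the best variant at runtime. A minimal sketch of the same pattern with a toy function (hypothetical file mylib.cc, not part of gemma.cpp; requires Highway):

// mylib.cc
#include <stddef.h>

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "mylib.cc"  // this file is re-included per target
#include "hwy/foreach_target.h"
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace mylib {
namespace HWY_NAMESPACE {  // distinct namespace per SIMD target

float SumT(const float* x, size_t n) {  // compiled once per target
  float sum = 0.0f;
  for (size_t i = 0; i < n; ++i) sum += x[i];
  return sum;
}

}  // namespace HWY_NAMESPACE
}  // namespace mylib
HWY_AFTER_NAMESPACE();

#if HWY_ONCE  // emitted only once, after all per-target passes
namespace mylib {
HWY_EXPORT(SumT);  // builds the table of per-target function pointers
float Sum(const float* x, size_t n) {
  return HWY_DYNAMIC_DISPATCH(SumT)(x, n);  // picks the best target at runtime
}
}  // namespace mylib
#endif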
-
-#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_
-#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_
-
-#include <stddef.h>
-#include <string.h>
-
-#include <cmath>
-#include <complex>
-#include <vector>
-
-#include "backprop/activations.h"
-#include "backprop/common_scalar.h"
-#include "backprop/prompt.h"
-#include "gemma/common.h"  // EmbeddingScaling
-#include "gemma/weights.h"
-#include "hwy/base.h"
-
-namespace gcpp {
-
-// w is N x M matrix in row-major order, x is M x K matrix in column-major
-// order. y = w * x is N x K matrix in column-major order.
-template <typename T>
-void MatMulT(const T* w, const T* x, T* y, size_t N, size_t M, size_t K) {
-  for (size_t i = 0; i < K; ++i) {
-    for (size_t j = 0; j < N; ++j) {
-      y[i * N + j] = DotT(&w[j * M], &x[i * M], M);
-    }
-  }
-}
-
-// w is H concatenated N x M matrix in row-major order, x is HM x K matrix in
-// column-major order and y = w' * x is N x K matrix in column-major order,
-// where w' is the rearrangement of w into an N x HM matrix.
-template <typename T>
-void MultiHeadMatMul(const T* w, const T* x, T* y, size_t H, size_t N,
-                     size_t M, size_t K) {
-  memset(y, 0, N * K * sizeof(y[0]));
-  for (size_t i = 0; i < K; ++i) {
-    for (size_t h = 0; h < H; ++h) {
-      for (size_t j = 0; j < N; ++j) {
-        y[i * N + j] += DotT(&w[h * N * M + j * M], &x[i * H * M + h * M], M);
-      }
-    }
-  }
-}
-
-template <typename T>
-void RMSNormT(const T* w, const T* x, T* out, size_t N, size_t K) {
-  constexpr T eps(1e-6);
-  for (size_t i = 0; i < K; ++i) {
-    T ss = SquaredL2(x + i * N, N);
-    ss = T(1.0) / std::sqrt(ss / T(N) + eps);
-    for (size_t j = 0; j < N; j++) {
-      out[i * N + j] = (T(1.0) + w[j]) * (ss * x[i * N + j]);
-    }
-  }
-}
-template <typename T>
-void Softmax(T* x, size_t N) {
-  T sum = {};
-  auto maxreal = std::real(x[0]);
-  for (size_t i = 1; i < N; ++i) {
-    if (std::real(x[i]) > maxreal) {
-      maxreal = std::real(x[i]);
-    }
-  }
-  for (size_t i = 0; i < N; ++i) {
-    x[i] = std::exp(x[i] - maxreal);
-    sum += x[i];
-  }
-  T scale = T(1.0) / sum;
-  for (size_t i = 0; i < N; ++i) {
-    x[i] *= scale;
-  }
-}
-template <typename T>
-void Softmax(T* x, size_t N, size_t K) {
-  for (size_t i = 0; i < K; ++i) {
-    Softmax(x + i * N, N);
-  }
-}
-template <typename T>
-void Softcap(float cap, T* x, size_t N) {
-  const T inv_cap = T{1.0} / static_cast<T>(cap);
-  for (size_t i = 0; i < N; ++i) {
-    x[i] = static_cast<T>(cap) * std::tanh(x[i] * inv_cap);
-  }
-}
-
-template <typename T>
-void GatedGelu(const T* in, T* out, size_t N, size_t K) {
-  for (size_t i = 0; i < K; ++i) {
-    const T* x1 = in + i * 2 * N;
-    const T* x2 = x1 + N;
-    T* y = out + i * N;
-    for (size_t j = 0; j < N; ++j) {
-      y[j] = x2[j] * Gelu(x1[j]);
-    }
-  }
-}
-
-template <typename T>
-void InputEmbedding(const T* w, const std::vector<int>& tokens, T scaling,
-                    T* y, size_t N) {
-  HWY_ASSERT(w != nullptr);
-  HWY_ASSERT(y != nullptr);
-  const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
-  for (size_t i = 0; i < num_tokens; ++i) {
-    int token = tokens[i];
-    memcpy(y + i * N, w + token * N, N * sizeof(y[0]));
-    MulByConstT(scaling, y + i * N, N);
-  }
-}
-
-template <typename T>
-void MaskedAttention(const T* qkv, T* output, size_t num_tokens, size_t heads,
-                     size_t qkv_dim, size_t seq_len) {
-  for (size_t pos = 0; pos < num_tokens; ++pos) {
-    for (size_t head = 0; head < heads; ++head) {
-      const size_t qoffset = pos * (heads + 2) * qkv_dim;
-      const size_t aoffset = pos * heads * seq_len + head * seq_len;
-      const T* q = qkv + qoffset + head * qkv_dim;
-      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
-        const T* k = qkv + (pos2 * (heads + 2) + heads) * qkv_dim;
-        output[aoffset + pos2] = DotT(q, k, qkv_dim);
-      }
-    }
-  }
-}
-template <typename T>
-void MaskedSoftmax(T* x, size_t num_tokens, size_t heads, size_t seq_len) {
-  for (size_t pos = 0; pos < num_tokens; ++pos) {
-    for (size_t head = 0; head < heads; ++head) {
-      size_t offset = pos * heads * seq_len + head * seq_len;
-      Softmax(x + offset, pos + 1);
-      memset(x + offset + pos + 1, 0, (seq_len - pos - 1) * sizeof(T));
-    }
-  }
-}
-template <typename T>
-void MixByAttention(const T* qkv, const T* attention, T* output,
-                    size_t num_tokens, size_t heads, size_t qkv_dim,
-                    size_t seq_len) {
-  for (size_t pos = 0; pos < num_tokens; ++pos) {
-    for (size_t head = 0; head < heads; ++head) {
-      const T* att = &attention[pos * heads * seq_len + head * seq_len];
-      T* out = &output[head * qkv_dim + pos * heads * qkv_dim];
-      memset(out, 0, qkv_dim * sizeof(out[0]));
-      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
-        size_t v_offset = (pos2 * (heads + 2) + heads + 1) * qkv_dim;
-        const T* v = &qkv[v_offset];
-        MulByConstAndAddT(att[pos2], v, out, qkv_dim);
-      }
-    }
-  }
-}
-template <typename T>
-void ApplyLayer(const LayerWeightsPtrs<T>& weights,
-                ForwardLayer<T>& activations, size_t num_tokens, T* output) {
-  const LayerConfig& layer_config = weights.layer_config;
-  const size_t model_dim = layer_config.model_dim;
-  const size_t seq_len = activations.input.Rows();
-  const size_t qkv_dim = layer_config.qkv_dim;
-  const size_t heads = layer_config.heads;
-  const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
-  static const T query_scale = T(1.0) / std::sqrt(T(qkv_dim));
-
-  RMSNormT(weights.pre_attention_norm_scale.data(), activations.input.data(),
-           activations.pre_att_rms_out.data(), model_dim, num_tokens);
-
-  MatMulT(weights.qkv_einsum_w.data(), activations.pre_att_rms_out.data(),
-          activations.qkv.data(), (heads + 2) * qkv_dim, model_dim,
-          num_tokens);
-
-  for (size_t pos = 0; pos < num_tokens; ++pos) {
-    T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim;
-    for (size_t h = 0; h <= heads; ++h) {
-      Rope(qkv + h * qkv_dim, qkv_dim, pos);
-    }
-  }
-
-  for (size_t pos = 0; pos < num_tokens; ++pos) {
-    T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim;
-    MulByConstT(query_scale, qkv, heads * qkv_dim);
-  }
-
-  MaskedAttention(activations.qkv.data(), activations.att.data(), num_tokens,
-                  heads, qkv_dim, seq_len);
-
-  MaskedSoftmax(activations.att.data(), num_tokens, heads, seq_len);
-
-  MixByAttention(activations.qkv.data(), activations.att.data(),
                 activations.att_out.data(), num_tokens, heads, qkv_dim,
-                 seq_len);
-
-  MultiHeadMatMul(weights.attn_vec_einsum_w.data(),
-                  activations.att_out.data(), activations.attention_out.data(),
-                  heads, model_dim, qkv_dim, num_tokens);
-
-  AddFromT(activations.input.data(), activations.attention_out.data(),
-           num_tokens * model_dim);
-
-  RMSNormT(weights.pre_ffw_norm_scale.data(),
-           activations.attention_out.data(),
-           activations.bf_pre_ffw_rms_out.data(), model_dim, num_tokens);
-
-  MatMulT(weights.gating_einsum_w.data(),
-          activations.bf_pre_ffw_rms_out.data(), activations.ffw_hidden.data(),
-          ff_hidden_dim * 2, model_dim, num_tokens);
-
-  GatedGelu(activations.ffw_hidden.data(),
-            activations.ffw_hidden_gated.data(), ff_hidden_dim, num_tokens);
-
-  MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(),
-          output, model_dim, ff_hidden_dim, num_tokens);
-
-  AddFromT(activations.attention_out.data(), output, num_tokens * model_dim);
-}
-
-template <typename T>
-T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) {
-  T loss = {};
-  const std::vector<int> tokens = prompt.tokens;
-  const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
-  for (size_t i = 0; i < num_tokens; ++i) {
-    if (i + 1 < prompt.context_size) {
-      continue;  // next token is part of context, don't try to predict it
-    }
-    const int next_token = tokens[i + 1];
-    loss += std::log(x[i * V + next_token]);
-  }
-  T scaling = -1.0 / std::log(2.0);
-  return loss * scaling;
-}
-
-template <typename T>
-T CrossEntropyLossForwardPass(const Prompt& prompt,
-                              const ModelWeightsPtrs<T>& weights,
-                              ForwardPass<T>& forward) {
-  const ModelConfig& config = weights.weights_config;
-  const size_t model_dim = config.model_dim;
-  const size_t vocab_size = config.vocab_size;
-  const size_t layers = config.layer_configs.size();
-  const std::vector<int> tokens = prompt.tokens;
-  const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
-
-  const T kEmbScaling = EmbeddingScaling(model_dim);
-  InputEmbedding(weights.embedder_input_embedding.data(), tokens, kEmbScaling,
-                 forward.layers[0].input.data(), model_dim);
-
-  for (size_t layer = 0; layer < layers; ++layer) {
-    T* output = layer + 1 < layers ? forward.layers[layer + 1].input.data()
-                                   : forward.final_layer_output.data();
-    ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens,
-               output);
-  }
-
-  RMSNormT(weights.final_norm_scale.data(),
-           forward.final_layer_output.data(),
-           forward.final_norm_output.data(), model_dim, num_tokens);
-
-  MatMulT(weights.embedder_input_embedding.data(),
-          forward.final_norm_output.data(), forward.logits.data(), vocab_size,
-          model_dim, num_tokens);
-
-  for (size_t pos = 0; pos < num_tokens; ++pos) {
-    if (config.final_cap > 0.0f) {
-      Softcap(config.final_cap, forward.logits.data() + pos * vocab_size,
-              vocab_size);
-    }
-  }
-
-  memcpy(forward.probs.data(), forward.logits.data(),
-         num_tokens * vocab_size * sizeof(forward.logits.At(0)));
-  Softmax(forward.probs.data(), vocab_size, num_tokens);
-
-  return CrossEntropyLoss(forward.probs.data(), prompt, vocab_size);
-}
-
-}  // namespace gcpp
-
-#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_
diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc
deleted file mode 100644
index 6f08bf0..0000000
--- a/backprop/optimize_test.cc
+++ /dev/null
@@ -1,155 +0,0 @@
-// Copyright 2024 Google LLC
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
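[Editor's note] The MatMulT convention in the forward_scalar.h deletion above (w is N x M row-major, x is M x K column-major, y is N x K column-major) is easy to misread. A standalone numeric check with made-up values, using the same index arithmetic:

#include <cstddef>
#include <cstdio>

int main() {
  const size_t N = 2, M = 3, K = 1;
  // w (2x3, row-major): [[1 2 3], [4 5 6]]
  const float w[] = {1, 2, 3, 4, 5, 6};
  // x (3x1, column-major): single column (1, 1, 1), stored contiguously.
  const float x[] = {1, 1, 1};
  float y[N * K];
  for (size_t i = 0; i < K; ++i) {    // over columns of x / y
    for (size_t j = 0; j < N; ++j) {  // over rows of w
      float dot = 0.0f;
      for (size_t m = 0; m < M; ++m) dot += w[j * M + m] * x[i * M + m];
      y[i * N + j] = dot;  // column i of y is contiguous
    }
  }
  printf("%g %g\n", y[0], y[1]);  // prints 6 15, the row sums of w
}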
- -#include - -#include -#include -#include -#include - -#include "gtest/gtest.h" -#include "backprop/activations.h" -#include "backprop/backward.h" -#include "backprop/forward.h" -#include "backprop/optimizer.h" -#include "backprop/prompt.h" -#include "backprop/sampler.h" -#include "compression/shared.h" -#include "gemma/common.h" -#include "gemma/configs.h" -#include "gemma/gemma.h" -#include "gemma/weights.h" -#include "ops/ops.h" -#include "util/allocator.h" -#include "util/basics.h" -#include "util/threading.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -namespace gcpp { - -TEST(OptimizeTest, GradientDescent) { - const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1)); - Allocator::Init(topology); - NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse); - MatMulEnv env(topology, pools); - hwy::ThreadPool& pool = pools.Pool(); - std::mt19937 gen(42); - - const ModelInfo info = { - .model = Model::GEMMA_TINY, - .wrapping = PromptWrapping::GEMMA_IT, - .weight = Type::kF32, - }; - ModelConfig config = ConfigFromModel(info.model); - ModelWeightsStorage grad, grad_m, grad_v; - grad.Allocate(info.model, info.weight, pool); - grad_m.Allocate(info.model, info.weight, pool); - grad_v.Allocate(info.model, info.weight, pool); - grad_m.ZeroInit(); - grad_v.ZeroInit(); - ForwardPass forward(config), backward(config); - KVCache kv_cache = KVCache::Create(config, /*prefill_tbatch_size=*/16); - - RowVectorBatch inv_timescale = CreateInvTimescale( - config.layer_configs[0].qkv_dim, - config.layer_configs[0].post_qk == PostQKType::HalfRope); - - Gemma gemma(GemmaTokenizer(), info, env); - - const auto generate = [&](const std::vector& prompt) { - std::vector reply; - auto stream_token = [&reply](int token, float) { - reply.push_back(token); - return token != ReverseSequenceSampler::kEndToken; - }; - RuntimeConfig runtime = { - .max_generated_tokens = 16, - .temperature = 1.0f, - .gen = &gen, - .verbosity = 0, - .stream_token = stream_token, - .eos_id = ReverseSequenceSampler::kEndToken, - }; - TimingInfo timing_info; - gemma.Generate(runtime, prompt, 0, kv_cache, timing_info); - return reply; - }; - - // Sanity check of reply tokens. - // 1) Its length should be greater than the prompt. - // 2) The prompt should be a prefix of the reply. 
- auto verify = [&](const Prompt& prompt) { - const std::vector& context = prompt.context(); - std::vector reply = generate(context); - if (reply.size() <= context.size()) return false; - return std::equal(context.begin(), context.end(), reply.begin(), - reply.begin() + context.size()); - }; - - gemma.MutableWeights().RandInit(gen); - gemma.MutableWeights().AllocAndCopyWithTranspose(pool); - - printf("Initial weights:\n"); - gemma.MutableWeights().LogWeightStats(); - - constexpr size_t kBatchSize = 8; - const float alpha = 0.001f; - const float beta1 = 0.9f; - const float beta2 = 0.999f; - const float epsilon = 1e-8f; - - ReverseSequenceSampler training_task({ - 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1}); - size_t steps = 0; - size_t num_ok; - for (; steps < 1000000; ++steps) { - std::mt19937 sgen(42); - grad.ZeroInit(); - float total_loss = 0.0f; - num_ok = 0; - for (size_t i = 0; i < kBatchSize; ++i) { - Prompt prompt = training_task.Sample(sgen); - total_loss += CrossEntropyLossForwardPass( - prompt, *gemma.Weights().GetWeightsOfType(), forward, - inv_timescale, pool); - CrossEntropyLossBackwardPass( - prompt, *gemma.Weights().GetWeightsOfType(), forward, - *grad.GetWeightsOfType(), backward, inv_timescale, pool); - gemma.MutableWeights().CopyWithTranspose(pool); - num_ok += verify(prompt) ? 1 : 0; - } - total_loss /= kBatchSize; - - AdamUpdate(info.weight, grad, alpha, beta1, beta2, epsilon, steps + 1, - gemma.Weights(), grad_m, grad_v, pool); - printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n", - steps, total_loss, num_ok, kBatchSize); - if (steps % 100 == 0) { - printf("Batch gradient:\n"); - grad.LogWeightStats(); - } - if (total_loss < 0.5f) { - break; - } - } - printf("Num steps: %zu\n", steps); - printf("Final weights:\n"); - gemma.MutableWeights().LogWeightStats(); - EXPECT_LT(steps, 300); - EXPECT_EQ(num_ok, kBatchSize); -} - -} // namespace gcpp diff --git a/backprop/optimizer.cc b/backprop/optimizer.cc deleted file mode 100644 index 9187bf7..0000000 --- a/backprop/optimizer.cc +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "backprop/optimizer.h" - -#include - -#include "compression/compress.h" -#include "gemma/common.h" -#include "gemma/weights.h" -#include "util/allocator.h" -#include "hwy/aligned_allocator.h" -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -namespace gcpp { - -namespace { - -class AdamUpdater { - public: - explicit AdamUpdater(float alpha, float beta1, float beta2, float epsilon, - size_t t) - : alpha_(alpha), beta1_(beta1), beta2_(beta2), cbeta1_(1.0f - beta1), - cbeta2_(1.0f - beta2), norm1_(1.0 / (1.0 - std::pow(beta1, t))), - norm2_(1.0 / (1.0 - std::pow(beta2, t))), epsilon_(epsilon) {} - - void operator()(const char* name, const MatPtr& grad, MatPtr& weights, - MatPtr& grad_m, MatPtr& grad_v) { - const float* HWY_RESTRICT g = grad.data(); - float* HWY_RESTRICT w = weights.data(); - float* HWY_RESTRICT m = grad_m.data(); - float* HWY_RESTRICT v = grad_v.data(); - for (size_t i = 0; i < grad.NumElements(); ++i) { - m[i] *= beta1_; - m[i] += cbeta1_ * g[i]; - v[i] *= beta2_; - v[i] += cbeta2_ * g[i] * g[i]; - const float mhat = m[i] * norm1_; - const float vhat = v[i] * norm2_; - w[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_); - } - } - - private: - float alpha_; - float beta1_; - float beta2_; - float cbeta1_; - float cbeta2_; - float norm1_; - float norm2_; - float epsilon_; -}; - -void AdamUpdate(ModelWeightsPtrs* grad, float alpha, float beta1, - float beta2, float epsilon, size_t t, - ModelWeightsPtrs* weights, - ModelWeightsPtrs* grad_m, - ModelWeightsPtrs* grad_v, hwy::ThreadPool& pool) { - AdamUpdater updater(alpha, beta1, beta2, epsilon, t); - ModelWeightsPtrs::ForEachTensor( - {grad, weights, grad_m, grad_v}, ForEachType::kLoadNoToc, - [&updater](const char* name, hwy::Span tensors) { - updater(name, *tensors[0], *tensors[1], *tensors[2], *tensors[3]); - }); -} - -} // namespace - -void AdamUpdate(Type weight_type, const ModelWeightsStorage& grad, float alpha, - float beta1, float beta2, float epsilon, size_t t, - const ModelWeightsStorage& weights, - const ModelWeightsStorage& grad_m, - const ModelWeightsStorage& grad_v, hwy::ThreadPool& pool) { - HWY_ASSERT(weight_type == Type::kF32); - AdamUpdate(grad.GetWeightsOfType(), alpha, beta1, beta2, epsilon, t, - weights.GetWeightsOfType(), - grad_m.GetWeightsOfType(), grad_v.GetWeightsOfType(), - pool); -} - -} // namespace gcpp diff --git a/backprop/optimizer.h b/backprop/optimizer.h deleted file mode 100644 index 8b25c52..0000000 --- a/backprop/optimizer.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
-#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
-
-#include "gemma/common.h"
-#include "gemma/weights.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
-
-namespace gcpp {
-
-void AdamUpdate(Type weight_type, const ModelWeightsStorage& grad, float alpha,
-                float beta1, float beta2, float epsilon, size_t t,
-                const ModelWeightsStorage& weights,
-                const ModelWeightsStorage& grad_m,
-                const ModelWeightsStorage& grad_v, hwy::ThreadPool& pool);
-
-}  // namespace gcpp
-
-#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
diff --git a/backprop/sampler.h b/backprop/sampler.h
deleted file mode 100644
index 17f5762..0000000
--- a/backprop/sampler.h
+++ /dev/null
@@ -1,90 +0,0 @@
-// Copyright 2024 Google LLC
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
-#define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
-
-#include <stddef.h>
-#include <stdio.h>
-
-#include <random>
-#include <vector>
-
-#include "backprop/prompt.h"
-
-namespace gcpp {
-
-class PromptSampler {
- public:
-  virtual Prompt Sample(std::mt19937& gen) = 0;
-  virtual ~PromptSampler() = default;
-
-  std::vector<Prompt> SampleBatch(size_t batch_size, std::mt19937& gen) {
-    std::vector<Prompt> batch;
-    batch.reserve(batch_size);
-    for (size_t i = 0; i < batch_size; ++i) {
-      batch.emplace_back(Sample(gen));
-    }
-    return batch;
-  }
-};
-
-class ReverseSequenceSampler : public PromptSampler {
- public:
-  explicit ReverseSequenceSampler(const std::vector<int>& length_histo)
-      : token_dist_(0, 9) {
-    for (int i = 0; i < length_histo.size(); ++i) {
-      const int count = length_histo[i];
-      for (int j = 0; j < count; ++j) {
-        length_lut_.push_back(i + 1);
-      }
-    }
-    length_dist_ = std::uniform_int_distribution<>(0, length_lut_.size() - 1);
-  }
-  virtual ~ReverseSequenceSampler() = default;
-
-  static constexpr int kReverseToken = 10;
-  static constexpr int kEndToken = 11;
-
-  Prompt Sample(std::mt19937& gen) override {
-    Prompt prompt;
-    int len = length_lut_[length_dist_(gen)];
-    prompt.tokens.resize(2 * len + 2);
-    prompt.tokens[len] = kReverseToken;
-    prompt.tokens[2 * len + 1] = kEndToken;
-    for (size_t i = 0; i < len; ++i) {
-      prompt.tokens[i] = prompt.tokens[2 * len - i] = token_dist_(gen);
-    }
-    prompt.context_size = len + 1;
-    return prompt;
-  }
-
-  static void LogPrompt(const Prompt& prompt) {
-    static const char* kVocab[] = {
-        "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "-->", "|",
-    };
-    for (int token : prompt.tokens) printf("%s", kVocab[token]);
-    printf(" [context_size: %zu]\n", prompt.context_size);
-  }
-
- private:
-  std::uniform_int_distribution<> token_dist_;
-  std::uniform_int_distribution<> length_dist_;
-  std::vector<int> length_lut_;
-};
-
-}  // namespace gcpp
-
-#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
diff --git a/backprop/test_util.h b/backprop/test_util.h
deleted file mode 100644
index a83e3d5..0000000
--- a/backprop/test_util.h
+++ /dev/null
@@ -1,209 +0,0 @@
-// Copyright 2024 Google LLC
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
-#define THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
-
-#include <stddef.h>
-
-#include <cmath>
-#include <complex>
-#include <random>
-#include <vector>
-
-#include "gtest/gtest.h"
-#include "compression/compress.h"
-#include "gemma/configs.h"
-#include "gemma/weights.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
-
-namespace gcpp {
-
-template <typename T>
-void RandInit(MatPtrT<T>& x, T stddev, std::mt19937& gen) {
-  std::normal_distribution<T> dist(0.0, stddev);
-  for (size_t i = 0; i < x.NumElements(); ++i) {
-    x.At(i) = dist(gen);
-  }
-}
-
-// TODO: make a member of Layer.
-template <typename T>
-void RandInit(LayerWeightsPtrs<T>& w, T stddev, std::mt19937& gen) {
-  RandInit(w.pre_attention_norm_scale, stddev, gen);
-  RandInit(w.attn_vec_einsum_w, stddev, gen);
-  RandInit(w.qkv_einsum_w, stddev, gen);
-  RandInit(w.pre_ffw_norm_scale, stddev, gen);
-  RandInit(w.gating_einsum_w, stddev, gen);
-  RandInit(w.linear_w, stddev, gen);
-}
-
-template <typename T>
-void RandInit(ModelWeightsPtrs<T>& w, T stddev, std::mt19937& gen) {
-  const size_t kLayers = w.c_layers.size();
-  RandInit(w.embedder_input_embedding, stddev, gen);
-  RandInit(w.final_norm_scale, stddev, gen);
-  for (size_t i = 0; i < kLayers; ++i) {
-    RandInit(*w.GetLayer(i), stddev, gen);
-  }
-}
-
-template <typename T, typename U>
-void Complexify(const MatPtrT<T>& x, MatPtrT<std::complex<U>>& c_x) {
-  for (size_t i = 0; i < x.NumElements(); ++i) {
-    c_x.At(i) = std::complex<U>(x.At(i), 0.0);
-  }
-}
-
-template <typename T, typename TC>
-void Complexify(const LayerWeightsPtrs<T>& w, LayerWeightsPtrs<TC>& c_w) {
-  Complexify(w.pre_attention_norm_scale, c_w.pre_attention_norm_scale);
-  Complexify(w.attn_vec_einsum_w, c_w.attn_vec_einsum_w);
-  Complexify(w.qkv_einsum_w, c_w.qkv_einsum_w);
-  Complexify(w.pre_ffw_norm_scale, c_w.pre_ffw_norm_scale);
-  Complexify(w.gating_einsum_w, c_w.gating_einsum_w);
-  Complexify(w.linear_w, c_w.linear_w);
-}
-
-template <typename T, typename TC>
-void Complexify(const ModelWeightsPtrs<T>& w, ModelWeightsPtrs<TC>& c_w) {
-  const size_t kLayers = w.c_layers.size();
-  Complexify(w.embedder_input_embedding, c_w.embedder_input_embedding);
-  Complexify(w.final_norm_scale, c_w.final_norm_scale);
-  for (size_t i = 0; i < kLayers; ++i) {
-    Complexify(*w.GetLayer(i), *c_w.GetLayer(i));
-  }
-}
-
-// Somewhat duplicates ModelWeightsStorage, but that has neither double nor
-// complex types allowed and it would cause code bloat to add them there.
-template <typename T>
-class WeightsWrapper {
- public:
-  explicit WeightsWrapper(const ModelConfig& config)
-      : pool_(0), weights_(config) {
-    weights_.Allocate(data_, pool_);
-  }
-
-  const ModelWeightsPtrs<T>& get() const { return weights_; }
-  ModelWeightsPtrs<T>& get() { return weights_; }
-  void ZeroInit() { weights_.ZeroInit(); }
-  void CopyFrom(const WeightsWrapper<T>& other) {
-    weights_.CopyFrom(other.weights_);
-  }
-
- private:
-  hwy::ThreadPool pool_;
-  std::vector<MatStorage> data_;
-  ModelWeightsPtrs<T> weights_;
-};
-
-template <typename T, typename U>
-void TestNear(const MatPtrT<T>& actual, const MatPtrT<U>& expected,
-              double max_abs_err, double max_rel_err, int line) {
-  double sum0 = 0;
-  double sum1 = 0;
-  double sum01 = 0;
-  for (size_t i = 0; i < actual.NumElements(); ++i) {
-    sum0 += actual.At(i) * actual.At(i);
-    sum1 += expected.At(i) * expected.At(i);
-    sum01 += actual.At(i) * expected.At(i);
-    ASSERT_NEAR(actual.At(i), expected.At(i),
-                std::max(max_abs_err, std::abs(expected.At(i)) * max_rel_err))
-        << "line: " << line << " dim=" << expected.NumElements()
-        << " i=" << i;
-  }
-  if (sum0 > 1e-40) {
-    double norm_dot = sum01 / std::sqrt(sum0) / std::sqrt(sum1);
-    ASSERT_NEAR(norm_dot, 1.0, 1e-7)
-        << "line: " << line << " sum0: " << sum0 << " sum1: " << sum1
-        << " sum01: " << sum01;
-  }
-}
-
-// Compute gradient with the finite difference method in the complex plane.
-// If f : R->R is the tested function and F : C->C is its extension on the
-// complex plane so that F is complex differentiable in x, then
-//
-//   F(x + ih) = F(x) + ih F'(x) + O(h^2) F''(x)
-//
-// which means that
-//
-//   F'(x) ~= Imag(F(x + ih)) / h
-//
-// This method is more numerically stable than the real-valued finite
-// difference method since we don't need to subtract floating point numbers
-// that are near to each other.
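[Editor's note] A standalone demo of the complex-step derivative described in the comment above, applied to a toy f(x) = x * exp(x) (not a gemma.cpp function). The exact derivative is (1 + x) * exp(x); the complex step recovers it to full precision because no subtraction of nearby values occurs, so the step can be as small as 1e-50:

#include <cmath>
#include <complex>
#include <cstdio>

int main() {
  const double x = 0.7, h = 1e-50;  // tiny step is fine: no cancellation
  const std::complex<double> zx(x, h);
  const std::complex<double> fx = zx * std::exp(zx);  // F(x + ih)
  const double grad = std::imag(fx) / h;              // ~= f'(x)
  const double exact = (1.0 + x) * std::exp(x);
  printf("complex-step: %.17g exact: %.17g\n", grad, exact);
}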
-template <typename T, typename U, typename FUNC>
-void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<U>>& x,
-                  FUNC func, U step, T max_abs_err, T max_rel_err, int line) {
-  MatStorageT<U> exp_grad("exp_grad", x.Rows(), x.Cols());
-  const U inv_step = 1.0 / step;
-  for (size_t i = 0; i < x.NumElements(); ++i) {
-    const U x0 = std::real(x.At(i));
-    const std::complex<U> x1 = std::complex<U>(x0, step);
-    x.At(i) = x1;
-    const std::complex<U> f1 = func();
-    exp_grad.At(i) = std::imag(f1) * inv_step;
-    x.At(i) = x0;
-  }
-  TestNear(grad, exp_grad, max_abs_err, max_rel_err, line);
-}
-
-template <typename T, typename FUNC>
-void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<float>>& x,
-                  FUNC func, float max_abs_err, float max_rel_error,
-                  int line) {
-  TestGradient(grad, x, func, 1e-30f, max_abs_err, max_rel_error, line);
-}
-
-template <typename T, typename FUNC>
-void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<double>>& x,
-                  FUNC func, T max_abs_err, T max_rel_error, int line) {
-  TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line);
-}
-
-template <typename T, typename TC, typename FUNC>
-void TestGradient(const LayerWeightsPtrs<T>& grad,
-                  LayerWeightsPtrs<TC>& c_weights, FUNC func, T max_err) {
-  TestGradient(grad.pre_attention_norm_scale,
-               c_weights.pre_attention_norm_scale,
-               func, max_err, max_err, __LINE__);
-  TestGradient(grad.attn_vec_einsum_w, c_weights.attn_vec_einsum_w,
-               func, max_err, max_err, __LINE__);
-  TestGradient(grad.qkv_einsum_w, c_weights.qkv_einsum_w,
-               func, max_err, max_err, __LINE__);
-  TestGradient(grad.pre_ffw_norm_scale, c_weights.pre_ffw_norm_scale,
-               func, max_err, max_err, __LINE__);
-  TestGradient(grad.gating_einsum_w, c_weights.gating_einsum_w,
-               func, max_err, max_err, __LINE__);
-  TestGradient(grad.linear_w, c_weights.linear_w,
-               func, max_err, max_err, __LINE__);
-}
-
-template <typename T, typename TC, typename FUNC>
-void TestGradient(const ModelWeightsPtrs<T>& grad,
-                  ModelWeightsPtrs<TC>& c_weights, FUNC func, T max_err) {
-  TestGradient(grad.embedder_input_embedding,
-               c_weights.embedder_input_embedding,
-               func, 2 * max_err, max_err, __LINE__);
-  TestGradient(grad.final_norm_scale, c_weights.final_norm_scale,
-               func, max_err, max_err, __LINE__);
-  for (size_t i = 0; i < grad.c_layers.size(); ++i) {
-    TestGradient(*grad.GetLayer(i), *c_weights.GetLayer(i), func, max_err);
-  }
-}
-
-}  // namespace gcpp
-
-#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
diff --git a/build/.gitignore b/build/.gitignore
deleted file mode 100644
index 3822a0b..0000000
--- a/build/.gitignore
+++ /dev/null
@@ -1,3 +0,0 @@
-*
-!.gitignore
-!.hgignore
\ No newline at end of file
diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel
index f12ca59..0104ef5 100644
--- a/compression/BUILD.bazel
+++ b/compression/BUILD.bazel
@@ -1,4 +1,4 @@
-# Weight compression, I/O and analysis
+# Weight compression and analysis.
package( default_applicable_licenses = [ @@ -20,78 +20,11 @@ config_setting( visibility = ["//visibility:private"], ) -FILE_DEPS = select({ - "//conditions:default": [ - # Placeholder for io deps, do not remove - ], - ":android": [], - # Placeholder for internal build rules, do not remove -}) - -cc_library( - name = "io", - srcs = [ - "io.cc", - # Placeholder for io backend, do not remove - ], - hdrs = ["io.h"], - local_defines = select({ - # Placeholder for internal build rules, do not remove - "//conditions:default": [], - }), - deps = [ - "@highway//:hwy", - ] + FILE_DEPS, -) - -cc_library( - name = "fields", - srcs = ["fields.cc"], - hdrs = ["fields.h"], - deps = [ - "@highway//:hwy", - ], -) - -cc_test( - name = "fields_test", - srcs = ["fields_test.cc"], - deps = [ - ":fields", - "@googletest//:gtest_main", # buildcleaner: keep - "@highway//:hwy_test_util", - ], -) - -cc_library( - name = "blob_store", - srcs = ["blob_store.cc"], - hdrs = ["blob_store.h"], - deps = [ - ":io", - "@highway//:hwy", - "@highway//:thread_pool", - ], -) - -cc_test( - name = "blob_store_test", - srcs = ["blob_store_test.cc"], - deps = [ - ":blob_store", - ":io", - "@googletest//:gtest_main", # buildcleaner: keep - "@highway//:hwy", - "@highway//:hwy_test_util", - "@highway//:thread_pool", - ], -) - cc_library( name = "distortion", hdrs = [ "distortion.h", - "shared.h", + "types.h", ], deps = [ "//:basics", @@ -115,21 +48,29 @@ cc_test( ) cc_library( - name = "sfp", - hdrs = ["shared.h"], - textual_hdrs = ["sfp-inl.h"], + name = "types", + hdrs = ["types.h"], deps = [ "//:basics", "@highway//:hwy", ], ) +cc_library( + name = "sfp", + textual_hdrs = ["sfp-inl.h"], + deps = [ + ":types", + "@highway//:hwy", + ], +) + cc_library( name = "nuq", - hdrs = ["shared.h"], textual_hdrs = ["nuq-inl.h"], deps = [ ":sfp", + ":types", "//:basics", "@highway//:hwy", "@highway//hwy/contrib/sort:vqsort", @@ -144,8 +85,10 @@ cc_library( deps = [ ":compress", ":distortion", + "//:mat", "@highway//:hwy", "@highway//:hwy_test_util", + "@highway//:thread_pool", ], ) @@ -153,7 +96,6 @@ cc_test( name = "sfp_test", size = "small", srcs = ["sfp_test.cc"], - features = ["fully_static_link"], linkstatic = True, local_defines = ["HWY_IS_TEST"], # for test_suite. @@ -174,7 +116,6 @@ cc_test( size = "small", timeout = "long", srcs = ["nuq_test.cc"], - features = ["fully_static_link"], linkstatic = True, local_defines = ["HWY_IS_TEST"], # for test_suite. @@ -182,7 +123,6 @@ cc_test( deps = [ ":distortion", ":nuq", - ":sfp", "@googletest//:gtest_main", # buildcleaner: keep "//:test_util", "@highway//:hwy", @@ -196,21 +136,18 @@ cc_library( srcs = ["compress.cc"], hdrs = [ "compress.h", - "shared.h", + "types.h", ], textual_hdrs = ["compress-inl.h"], deps = [ - ":blob_store", ":distortion", - ":fields", - ":io", ":nuq", ":sfp", - "//:allocator", "//:basics", - "//:common", + "//:mat", "@highway//:hwy", "@highway//:nanobenchmark", + "@highway//:profiler", "@highway//:stats", "@highway//:thread_pool", ], @@ -221,7 +158,6 @@ cc_test( size = "small", timeout = "long", srcs = ["compress_test.cc"], - features = ["fully_static_link"], linkstatic = True, local_defines = ["HWY_IS_TEST"], # for test_suite. 
@@ -245,51 +181,10 @@ cc_library( deps = [ ":nuq", ":sfp", + ":types", "@highway//:hwy", "@highway//:stats", "@highway//:thread_pool", "@highway//hwy/contrib/sort:vqsort", ], ) - -cc_binary( - name = "compress_weights", - srcs = ["compress_weights.cc"], - deps = [ - ":compress", - ":io", - "//:allocator", - "//:args", - "//:common", - "//:tokenizer", - "//:weights", - "@highway//:hwy", - "@highway//:thread_pool", - ], -) - -cc_binary( - name = "blob_compare", - srcs = ["blob_compare.cc"], - deps = [ - ":blob_store", - ":io", - "//:allocator", - "//:basics", - "//:threading", - "@highway//:hwy", - "@highway//:hwy_test_util", - "@highway//:nanobenchmark", - ], -) - -cc_binary( - name = "migrate_weights", - srcs = ["migrate_weights.cc"], - deps = [ - "//:app", - "//:args", - "//:benchmark_helper", - "//:gemma_lib", - ], -) diff --git a/compression/analyze.h b/compression/analyze.h index 38537db..7d41633 100644 --- a/compression/analyze.h +++ b/compression/analyze.h @@ -26,7 +26,7 @@ #include // std::abs #include -#include "compression/shared.h" +#include "compression/types.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/stats.h" diff --git a/compression/blob_compare.cc b/compression/blob_compare.cc deleted file mode 100644 index c0fe63c..0000000 --- a/compression/blob_compare.cc +++ /dev/null @@ -1,230 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include - -#include -#include - -#include "compression/blob_store.h" -#include "compression/io.h" // Path -#include "util/allocator.h" -#include "util/basics.h" // IndexRange -#include "util/threading.h" -#include "hwy/aligned_allocator.h" // Span -#include "hwy/base.h" -#include "hwy/timer.h" - -namespace gcpp { - -using KeySpan = hwy::Span; - -// Returns false if any keys differ, because then blobs are not comparable. -bool CompareKeys(const BlobReader& reader1, const BlobReader& reader2) { - KeySpan keys1 = reader1.Keys(); - KeySpan keys2 = reader2.Keys(); - if (keys1.size() != keys2.size()) { - fprintf(stderr, "#keys mismatch: %zu vs %zu\n", keys1.size(), keys2.size()); - return false; - } - for (size_t i = 0; i < keys1.size(); ++i) { - if (keys1[i] != keys2[i]) { - fprintf(stderr, "key %zu mismatch: %s vs %s\n", i, - StringFromKey(keys1[i]).c_str(), StringFromKey(keys2[i]).c_str()); - return false; - } - } - - return true; -} - -// Total amount to allocate for all blobs. -size_t TotalBytes(BlobReader& reader) { - size_t total_bytes = 0; - for (const hwy::uint128_t key : reader.Keys()) { - total_bytes += reader.BlobSize(key); - } - return total_bytes; -} - -using BytePtr = hwy::AlignedFreeUniquePtr; -using ByteSpan = hwy::Span; // Sections within BytePtr -using BlobVec = std::vector; // in order of keys - -// Allocates memory within the single allocation and updates `pos`. 
-BlobVec ReserveMemory(BlobReader& reader, BytePtr& all_blobs, size_t& pos) { - BlobVec blobs; - for (const hwy::uint128_t key : reader.Keys()) { - const size_t bytes = reader.BlobSize(key); - blobs.push_back(ByteSpan(all_blobs.get() + pos, bytes)); - pos += bytes; - } - return blobs; -} - -// Reads one set of blobs in parallel (helpful if in disk cache). -void ReadBlobs(BlobReader& reader, BlobVec& blobs, hwy::ThreadPool& pool) { - HWY_ASSERT(reader.Keys().size() == blobs.size()); - for (size_t i = 0; i < blobs.size(); ++i) { - reader.Enqueue(reader.Keys()[i], blobs[i].data(), blobs[i].size()); - } - const BlobError err = reader.ReadAll(pool); - if (err != 0) { - HWY_ABORT("Parallel read failed: %d\n", err); - } -} - -// Parallelizes ReadBlobs across (two) packages, if available. -void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2, size_t total_bytes, - BlobVec& blobs1, BlobVec& blobs2, NestedPools& pools) { - const double t0 = hwy::platform::Now(); - fprintf(stderr, "Reading %zu GiB, %zux%zu cores: ", total_bytes >> 30, - pools.AllPackages().NumWorkers(), pools.Pool().NumWorkers()); - pools.AllPackages().Run(0, 2, [&](size_t task, size_t pkg_idx) { - ReadBlobs(task ? reader2 : reader1, task ? blobs2 : blobs1, - pools.Pool(pkg_idx)); - }); - const double t1 = hwy::platform::Now(); - fprintf(stderr, "%.1f GB/s\n", total_bytes / (t1 - t0) * 1E-9); -} - -// Returns number of elements with a mismatch. For float and bf16 blobs, uses -// L1 and relative error, otherwise byte-wise comparison. -size_t BlobDifferences(const ByteSpan& data1, const ByteSpan& data2, - const hwy::uint128_t key) { - if (data1.size() != data2.size() || data1.size() == 0) { - HWY_ABORT("key %s size mismatch: %zu vs %zu\n", StringFromKey(key).c_str(), - data1.size(), data2.size()); - } - - size_t mismatches = 0; - char type; - hwy::CopyBytes(&key, &type, 1); - if (type == 'F') { - HWY_ASSERT(data1.size() % sizeof(float) == 0); - for (size_t j = 0; j < data1.size(); j += sizeof(float)) { - float f1, f2; - hwy::CopyBytes(&data1[j], &f1, sizeof(f1)); - hwy::CopyBytes(&data2[j], &f2, sizeof(f2)); - const float l1 = hwy::ScalarAbs(f1 - f2); - const float rel = hwy::ScalarAbs(f1) == 0.0f ? 0.0f : l1 / f1; - if (l1 > 1E-3f || rel > 1E-2f) { - fprintf(stderr, "key %s %5zu: L1 %.5f rel %.4f\n", - StringFromKey(key).c_str(), j, l1, rel); - ++mismatches; - } - } - } else if (type == 'B') { - for (size_t j = 0; j < data1.size(); j += sizeof(hwy::bfloat16_t)) { - hwy::bfloat16_t b1, b2; - hwy::CopyBytes(&data1[j], &b1, sizeof(b1)); - hwy::CopyBytes(&data2[j], &b2, sizeof(b2)); - const float f1 = hwy::ConvertScalarTo(b1); - const float f2 = hwy::ConvertScalarTo(b2); - const float l1 = hwy::ScalarAbs(f1 - f2); - const float rel = hwy::ScalarAbs(f1) == 0.0f ? 
0.0f : l1 / f1; - if (l1 > 1E-2f || rel > 1E-1f) { - fprintf(stderr, "key %s %5zu: L1 %.5f rel %.4f\n", - StringFromKey(key).c_str(), j, l1, rel); - ++mismatches; - } - } - } else { - for (size_t j = 0; j < data1.size(); ++j) { - if (data1[j] != data2[j]) { - if (mismatches == 0) { - fprintf(stderr, "key %s mismatch at byte %5zu\n", - StringFromKey(key).c_str(), j); - } - ++mismatches; - } - } - } - return mismatches; -} - -void CompareBlobs(const KeySpan& keys, BlobVec& blobs1, BlobVec& blobs2, - size_t total_bytes, NestedPools& pools) { - fprintf(stderr, "Comparing %zu blobs in parallel: ", keys.size()); - const double t0 = hwy::platform::Now(); - std::atomic blobs_equal{}; - std::atomic blobs_diff{}; - const IndexRangePartition ranges = StaticPartition( - IndexRange(0, keys.size()), pools.AllPackages().NumWorkers(), 1); - ParallelizeOneRange( - ranges, pools.AllPackages(), - [&](const IndexRange& range, size_t pkg_idx) { - pools.Pool(pkg_idx).Run( - range.begin(), range.end(), [&](size_t i, size_t /*thread*/) { - const size_t mismatches = - BlobDifferences(blobs1[i], blobs2[i], keys[i]); - if (mismatches != 0) { - fprintf(stderr, "key %s has %zu mismatches in %zu bytes!\n", - StringFromKey(keys[i]).c_str(), mismatches, - blobs1[i].size()); - blobs_diff.fetch_add(1); - } else { - blobs_equal.fetch_add(1); - } - }); - }); - const double t1 = hwy::platform::Now(); - fprintf(stderr, "%.1f GB/s; total blob matches=%zu, mismatches=%zu\n", - total_bytes / (t1 - t0) * 1E-9, blobs_equal.load(), - blobs_diff.load()); -} - -// Compares two sbs files, including blob order. -void ReadAndCompareBlobs(const char* path1, const char* path2) { - // Open files. - BlobReader reader1; - BlobReader reader2; - const BlobError err1 = reader1.Open(Path(path1)); - const BlobError err2 = reader2.Open(Path(path2)); - if (err1 != 0 || err2 != 0) { - HWY_ABORT("Failed to open files: %s %s: %d %d\n", path1, path2, err1, err2); - } - - if (!CompareKeys(reader1, reader2)) return; - - // Single allocation, avoid initializing the memory. - BoundedTopology topology; - Allocator::Init(topology); - NestedPools pools(topology); - const size_t total_bytes = TotalBytes(reader1) + TotalBytes(reader2); - BytePtr all_blobs = hwy::AllocateAligned(total_bytes); - size_t pos = 0; - BlobVec blobs1 = ReserveMemory(reader1, all_blobs, pos); - BlobVec blobs2 = ReserveMemory(reader2, all_blobs, pos); - - ReadBothBlobs(reader1, reader2, total_bytes, blobs1, blobs2, pools); - - CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, pools); -} - -} // namespace gcpp - -int main(int argc, char** argv) { - if (argc != 3) { - HWY_ABORT("Usage: %s \n", argv[0]); - } - if (strcmp(argv[1], argv[2]) == 0) { - HWY_ABORT("Filenames are the same, skipping comparison: %s\n", argv[1]); - } - gcpp::ReadAndCompareBlobs(argv[1], argv[2]); - return 0; -} diff --git a/compression/blob_store.cc b/compression/blob_store.cc deleted file mode 100644 index 06bcb56..0000000 --- a/compression/blob_store.cc +++ /dev/null @@ -1,341 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "compression/blob_store.h" - -#include -#include - -#include -#include -#include -#include -#include - -#include "compression/io.h" -#include "hwy/aligned_allocator.h" -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/detect_compiler_arch.h" - -namespace gcpp { - -hwy::uint128_t MakeKey(const char* string) { - size_t length = 0; - for (size_t i = 0; string[i] != '\0'; ++i) { - ++length; - } - if (length > 16) { - HWY_ABORT("Key %s is too long, please truncate to 16 chars.", string); - } - - hwy::uint128_t ret; - hwy::ZeroBytes(&ret); - hwy::CopyBytes(string, &ret, length); - return ret; -} - -std::string StringFromKey(hwy::uint128_t key) { - std::string name(sizeof(key) + 1, '\0'); - hwy::CopyBytes(&key, name.data(), sizeof(key)); - name.resize(name.find('\0')); - return name; -} - -namespace { -void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data, - std::vector& requests) { - // Split into chunks for load-balancing even if blob sizes vary. - constexpr size_t kChunkSize = 4 * 1024 * 1024; // bytes - - // Split into whole chunks and possibly one remainder. - uint64_t pos = 0; - if (size >= kChunkSize) { - for (; pos <= size - kChunkSize; pos += kChunkSize) { - requests.emplace_back(offset + pos, kChunkSize, data + pos, 0); - } - } - if (pos != size) { - requests.emplace_back(offset + pos, size - pos, data + pos, 0); - } -} -} // namespace - -static_assert(HWY_IS_LITTLE_ENDIAN, "Assumes little endian"); - -// On-disk representation (little-endian). -// -// Deliberately omits a version number because this file format is unchanging. -// Additional data may be added only inside new blobs. Changes to the blob -// contents or type should be handled by renaming keys. -#pragma pack(push, 1) -class BlobStore { - static constexpr uint32_t kMagic = 0x0A534253; // SBS\n - - public: - // NOT including padding, so that we can also use ZeroFillPadding after - // copying the header. - static constexpr size_t HeaderSize(size_t num_blobs) { - // 16-byte fixed fields plus per-blob: 16-byte key, 16-byte offset/size. - return 16 + 32 * num_blobs; - } - - // Returns how many bytes to allocate for the header without the subsequent - // blobs. Requires num_blobs_ to already be set, typically by reading - // sizeof(BlobStore) bytes from disk. - size_t PaddedHeaderSize() const { - return hwy::RoundUpTo(HeaderSize(num_blobs_), kBlobAlign); - } - - // Returns aligned offset and zero-fills between that and `offset`. - uint64_t ZeroFillPadding(uint64_t offset) { - uint8_t* const bytes = reinterpret_cast(this); - const uint64_t padded = hwy::RoundUpTo(offset, kBlobAlign); - hwy::ZeroBytes(bytes + offset, padded - offset); - return padded; - } - - BlobError CheckValidity(const uint64_t file_size) { - if (magic_ != kMagic) return __LINE__; - if (num_blobs_ == 0) return __LINE__; - if (file_size_ != file_size) return __LINE__; - - // Ensure blobs are back to back, and zero-pad. 
- uint64_t offset = ZeroFillPadding(HeaderSize(num_blobs_)); - for (size_t i = 0; i < num_blobs_; ++i) { - const hwy::uint128_t val = keys_[num_blobs_ + i]; - if (val.lo != offset) return __LINE__; - offset = hwy::RoundUpTo(offset + val.hi, kBlobAlign); - } - - if (offset != file_size_) return __LINE__; - - return 0; // all OK - } - - static BlobStorePtr Allocate(uint64_t total_size) { - uint8_t* bytes = - static_cast<uint8_t*>(hwy::AllocateAlignedBytes(total_size)); - if (!bytes) return BlobStorePtr(); - return BlobStorePtr(new (bytes) BlobStore(), hwy::AlignedFreer()); - } - - static std::vector<BlobIO> PrepareWriteRequests( - const hwy::uint128_t keys[], const hwy::Span<const uint8_t> blobs[], - size_t num_blobs, BlobStore* bs) { - // Sanity check and ensure the cast below is safe. - HWY_ASSERT(num_blobs < (1ULL << 20)); - - // Allocate var-length header. - const size_t header_size = HeaderSize(num_blobs); - const size_t padded_header_size = hwy::RoundUpTo(header_size, kBlobAlign); - const uint64_t padded_header_end = bs->ZeroFillPadding(header_size); - HWY_ASSERT(padded_header_end == padded_header_size); - - // All-zero buffer used to write padding to the file without copying the - // input blobs. - static uint8_t zeros[kBlobAlign] = {0}; - - // Total file size will be the header plus all padded blobs. - uint64_t payload = 0; - for (size_t i = 0; i < num_blobs; ++i) { - payload += hwy::RoundUpTo(blobs[i].size(), kBlobAlign); - } - const size_t total_size = padded_header_size + payload; - - // Fill header. - bs->magic_ = kMagic; - bs->num_blobs_ = static_cast<uint32_t>(num_blobs); - bs->file_size_ = total_size; - hwy::CopyBytes(keys, bs->keys_, num_blobs * sizeof(keys[0])); - - // First IO request is for the header (not yet filled!). - std::vector<BlobIO> requests; - requests.reserve(1 + 2 * num_blobs); - requests.emplace_back(/*offset=*/0, padded_header_size, - reinterpret_cast<uint8_t*>(bs), 0); - - // Fill second half of keys_ with offset/size and prepare IO requests. - uint64_t offset = padded_header_end; - for (size_t i = 0; i < num_blobs; ++i) { - bs->keys_[num_blobs + i].lo = offset; - bs->keys_[num_blobs + i].hi = blobs[i].size(); - - EnqueueChunkRequests(offset, blobs[i].size(), - const_cast<uint8_t*>(blobs[i].data()), requests); - offset += blobs[i].size(); - const size_t padded_size = hwy::RoundUpTo(blobs[i].size(), kBlobAlign); - if (padded_size != blobs[i].size()) { - const size_t padding = padded_size - blobs[i].size(); - HWY_ASSERT(padding <= kBlobAlign); - requests.emplace_back(offset, padding, zeros, 0); - offset += padding; - } - } - - HWY_ASSERT(offset == total_size); - return requests; - } - - bool FindKey(const hwy::uint128_t key, uint64_t& offset, size_t& size) const { - for (size_t i = 0; i < num_blobs_; ++i) { - if (keys_[i] == key) { - const hwy::uint128_t val = keys_[num_blobs_ + i]; - offset = val.lo; - size = val.hi; - return true; - } - } - return false; - } - - hwy::Span<const hwy::uint128_t> Keys() const { - return hwy::Span<const hwy::uint128_t>(keys_, num_blobs_); - } - - private: - uint32_t magic_; - uint32_t num_blobs_; // never 0 - uint64_t file_size_; // must match actual size of file - hwy::uint128_t keys_[1]; // length: 2 * num_blobs - // Padding, then the blob identified by keys[0], then padding etc. -}; -#pragma pack(pop) - -BlobError BlobReader::Open(const Path& filename) { - file_ = OpenFileOrNull(filename, "r"); - if (!file_) return __LINE__; - - // Read first part of header to get actual size.
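// (Reading sizeof(BlobStore) bytes suffices to learn num_blobs_, since the
// fixed fields precede keys_; PaddedHeaderSize() then yields the full header
// length for the second read below. A minimal usage sketch of the reader,
// with placeholder key/destination names:
//   BlobReader reader;
//   if (reader.Open(Path("weights.sbs")) == 0) {
//     const hwy::uint128_t key = MakeKey("blob_name");
//     reader.Enqueue(key, dest, reader.BlobSize(key));
//     reader.ReadAll(pool);
//   })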
- BlobStore bs; - if (!file_->Read(0, sizeof(bs), &bs)) return __LINE__; - const size_t padded_size = bs.PaddedHeaderSize(); - HWY_ASSERT(padded_size >= sizeof(bs)); - - // Allocate full header. - blob_store_ = BlobStore::Allocate(padded_size); - if (!blob_store_) return __LINE__; - - // Copy what we already read (more efficient than seek + re-read). - hwy::CopySameSize(&bs, blob_store_.get()); - // Read the rest of the header, but not the full file. - uint8_t* bytes = reinterpret_cast(blob_store_.get()); - if (!file_->Read(sizeof(bs), padded_size - sizeof(bs), bytes + sizeof(bs))) { - return __LINE__; - } - - return blob_store_->CheckValidity(file_->FileSize()); -} - -size_t BlobReader::BlobSize(hwy::uint128_t key) const { - uint64_t offset; - size_t size; - if (!blob_store_->FindKey(key, offset, size)) return 0; - return size; -} - -BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) { - uint64_t offset; - size_t actual_size; - if (!blob_store_->FindKey(key, offset, actual_size)) return __LINE__; - if (actual_size != size) { - fprintf(stderr, - "Mismatch between expected %d and actual %d KiB size of blob %s. " - "Please see README.md on how to update the weights.\n", - static_cast(size >> 10), static_cast(actual_size >> 10), - StringFromKey(key).c_str()); - return __LINE__; - } - - EnqueueChunkRequests(offset, actual_size, reinterpret_cast(data), - requests_); - return 0; -} - -// Parallel synchronous I/O. Alternatives considered: -// - readv is limited to 0x7FFFF000 bytes on Linux (even 64-bit). Note that -// pread calls preadv with a single iovec. -// - O_DIRECT seems undesirable because we do want to use the OS cache -// between consecutive runs. -// - memory-mapped I/O is less predictable and adds noise to measurements. -BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) { - File* pfile = file_.get(); // not owned - const auto& requests = requests_; - std::atomic_flag err = ATOMIC_FLAG_INIT; - // >5x speedup from parallel reads when cached. - pool.Run(0, requests.size(), - [pfile, &requests, &err](uint64_t i, size_t /*thread*/) { - if (!pfile->Read(requests[i].offset, requests[i].size, - requests[i].data)) { - fprintf(stderr, "Failed to read blob %zu\n", - static_cast(i)); - err.test_and_set(); - } - }); - if (err.test_and_set()) return __LINE__; - return 0; -} - -BlobError BlobReader::ReadOne(hwy::uint128_t key, void* data, - size_t size) const { - uint64_t offset; - size_t actual_size; - if (!blob_store_->FindKey(key, offset, actual_size)) return __LINE__; - if (actual_size != size) { - fprintf(stderr, - "Mismatch between expected %d and actual %d KiB size of blob %s. " - "Please see README.md on how to update the weights.\n", - static_cast(size >> 10), static_cast(actual_size >> 10), - StringFromKey(key).c_str()); - return __LINE__; - } - if (!file_->Read(offset, actual_size, data)) { - return __LINE__; - } - return 0; -} - -hwy::Span BlobReader::Keys() const { - return blob_store_->Keys(); -} - -BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) { - HWY_ASSERT(keys_.size() == blobs_.size()); - - // Concatenate blobs in memory. - const size_t header_size = BlobStore::HeaderSize(keys_.size()); - const size_t padded_header_size = hwy::RoundUpTo(header_size, kBlobAlign); - const BlobStorePtr bs = BlobStore::Allocate(padded_header_size); - const std::vector requests = BlobStore::PrepareWriteRequests( - keys_.data(), blobs_.data(), keys_.size(), bs.get()); - - // Create/replace existing file. 
- std::unique_ptr file = OpenFileOrNull(filename, "w+"); - if (!file) return __LINE__; - File* pfile = file.get(); // not owned - - std::atomic_flag err = ATOMIC_FLAG_INIT; - pool.Run(0, requests.size(), - [pfile, &requests, &err](uint64_t i, size_t /*thread*/) { - if (!pfile->Write(requests[i].data, requests[i].size, - requests[i].offset)) { - err.test_and_set(); - } - }); - if (err.test_and_set()) return __LINE__; - return 0; -} - -} // namespace gcpp diff --git a/compression/blob_store.h b/compression/blob_store.h deleted file mode 100644 index d98235c..0000000 --- a/compression/blob_store.h +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_ -#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_ - -#include -#include - -#include -#include -#include - -#include "compression/io.h" -#include "hwy/aligned_allocator.h" -#include "hwy/base.h" // hwy::uint128_t -#include "hwy/contrib/thread_pool/thread_pool.h" - -namespace gcpp { - -// Convenient way to construct a key from a string (<= 16 chars). -hwy::uint128_t MakeKey(const char* string); - -// Returns a string from a key. -std::string StringFromKey(hwy::uint128_t key); - -// Ordered list of opaque blobs (~hundreds), identified by unique opaque -// 128-bit keys. -class BlobStore; - -// Incomplete type, so dtor will not be called. -using BlobStorePtr = hwy::AlignedFreeUniquePtr; - -// 0 if successful, otherwise the line number of the failing check. -using BlobError = int; - -// Blob offsets on disk and memory addresses are a multiple of this, because -// we pad the header and each blob's size. This matches CUDA alignment and the -// maximum SVE vector size, and exceeds typical x86 cache line sizes (64 or -// 128), which can help performance. -static constexpr size_t kBlobAlign = 256; - -// One I/O request, serviced by threads in a pool. -struct BlobIO { - BlobIO(uint64_t offset, size_t size, void* data, uint64_t padding) - : offset(offset), size(size), data(data), padding(padding) {} - - uint64_t offset; - size_t size; // bytes - void* data; - uint64_t padding; -}; - -class BlobReader { - public: - BlobReader() { requests_.reserve(500); } - ~BlobReader() = default; - - // Opens `filename` and reads its header. - BlobError Open(const Path& filename); - - // Returns the size of the blob identified by `key`, or 0 if not found. - size_t BlobSize(hwy::uint128_t key) const; - - // Enqueues read requests if `key` is found and its size matches `size`, which - // is in units of bytes. - BlobError Enqueue(hwy::uint128_t key, void* data, size_t size); - - // Reads all enqueued requests. - BlobError ReadAll(hwy::ThreadPool& pool); - - // Reads one blob directly. - BlobError ReadOne(hwy::uint128_t key, void* data, size_t size) const; - - // Returns all available blob keys. 
- hwy::Span<const hwy::uint128_t> Keys() const; - - private: - BlobStorePtr blob_store_; // holds header, not the entire file - std::vector<BlobIO> requests_; - std::unique_ptr<File> file_; -}; - -class BlobWriter { - public: - // `size` is in bytes. - void Add(hwy::uint128_t key, const void* data, size_t size) { - keys_.push_back(key); - blobs_.emplace_back(static_cast<const uint8_t*>(data), size); - } - - // Stores all blobs to disk in the given order with padding for alignment. - BlobError WriteAll(hwy::ThreadPool& pool, const Path& filename); - - // Returns the number of blobs added. - size_t DebugNumBlobsAdded() const { return keys_.size(); } - - private: - std::vector<hwy::uint128_t> keys_; - std::vector<hwy::Span<const uint8_t>> blobs_; -}; - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_BLOB_STORE_H_ diff --git a/compression/blob_store_test.cc b/compression/blob_store_test.cc deleted file mode 100644 index dbba55f..0000000 --- a/compression/blob_store_test.cc +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "compression/blob_store.h" - -#include - -#include -#include - -#include "compression/io.h" -#include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/tests/hwy_gtest.h" -#include "hwy/tests/test_util-inl.h" // HWY_ASSERT_EQ - -namespace gcpp { -namespace { - -#if !HWY_TEST_STANDALONE -class BlobStoreTest : public testing::Test {}; -#endif - -#if !HWY_OS_WIN -TEST(BlobStoreTest, TestReadWrite) { - static const std::array kOriginalData = {-1, 0, 3.14159, 2.71828}; - - // mkstemp will modify path_str so it holds a newly-created temporary file.
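// (mkstemp replaces the trailing XXXXXX in place with a unique suffix and
// returns an already-open file descriptor, which the test closes and unlinks
// at the end.)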
- char path_str[] = "/tmp/blob_store_test.sbs-XXXXXX"; - const int fd = mkstemp(path_str); - HWY_ASSERT(fd > 0); - - hwy::ThreadPool pool(4); - const Path path(path_str); - std::array buffer = kOriginalData; - - const hwy::uint128_t keyA = MakeKey("0123456789abcdef"); - const hwy::uint128_t keyB = MakeKey("q"); - BlobWriter writer; - writer.Add(keyA, "DATA", 5); - writer.Add(keyB, buffer.data(), sizeof(buffer)); - HWY_ASSERT_EQ(writer.WriteAll(pool, path), 0); - HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size()); - - std::fill(buffer.begin(), buffer.end(), 0); - BlobReader reader; - HWY_ASSERT_EQ(reader.Open(path), 0); - HWY_ASSERT_EQ(reader.BlobSize(keyA), 5); - HWY_ASSERT_EQ(reader.BlobSize(keyB), sizeof(buffer)); - - HWY_ASSERT_EQ(reader.Enqueue(keyB, buffer.data(), sizeof(buffer)), 0); - HWY_ASSERT_EQ(reader.ReadAll(pool), 0); - HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size()); - - { - std::array buffer; - HWY_ASSERT(reader.ReadOne(keyA, buffer.data(), 1) != 0); - HWY_ASSERT_EQ(reader.ReadOne(keyA, buffer.data(), 5), 0); - HWY_ASSERT_STRING_EQ("DATA", buffer.data()); - } - - const hwy::Span keys = reader.Keys(); - HWY_ASSERT_EQ(keys.size(), 2); - HWY_ASSERT_EQ(keys[0], keyA); - HWY_ASSERT_EQ(keys[1], keyB); - - close(fd); - unlink(path_str); -} -#endif - -} // namespace -} // namespace gcpp - -HWY_TEST_MAIN(); diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 8638b5f..512f8fa 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -21,19 +21,20 @@ #include #include -#include // lroundf, only if COMPRESS_STATS -#include +#include #include -#include "compression/blob_store.h" #include "compression/compress.h" // IWYU pragma: export #include "compression/distortion.h" -#include "gemma/configs.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/timer.h" +#if COMPRESS_STATS +#include // lroundf +#endif + #endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_INL_H_ // Include guard for (potentially) SIMD code. @@ -64,8 +65,8 @@ static constexpr bool kIsTest = false; template // primary, must specialize struct CompressTraits {}; -// Used by backprop/, where weights are currently f32; also MatMul for f32 -// weights or activations, if native `ReorderWidenMulAccumulate` is available. +// Used by MatMul for f32 weights or activations, if native +// `ReorderWidenMulAccumulate` is available. template <> struct CompressTraits { using Packed = float; @@ -379,7 +380,7 @@ struct CompressTraits { using Packed = SfpStream; // Callers are responsible for scaling `raw` such that its magnitudes do not - // exceed `SfpStream::kMax`. See CompressedArray::scale(). + // exceed `SfpStream::kMax`. See CompressedArray::Scale(). 
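// (A sketch of the intended call order, assuming ScaleWeights from
// compression/compress.cc later in this change, and `mat` as a placeholder
// for the MatPtr-like owner of `packed`:
//   const float scale = ScaleWeights(raw, num);  // now |raw[i]| <= kMax
//   Compress(df, raw, num, tls, packed, packed_ofs);
//   mat.set_scale(scale);  // decoded values are restored via this multiplier)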
template static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw, size_t num, CompressPerThread& tls, @@ -387,7 +388,7 @@ struct CompressTraits { const size_t packed_ofs) { SfpCodec::Enc(df, raw, num, packed.ptr + packed_ofs); - if (COMPRESS_STATS) { + if constexpr (COMPRESS_STATS) { const hn::Repartition dbf; auto distorted = hwy::AllocateAligned(hwy::RoundUpTo(num, hn::Lanes(dbf))); @@ -431,9 +432,10 @@ struct CompressTraits { size_t num, CompressPerThread& tls, const PackedSpan& packed, const size_t packed_ofs) { - NuqCodec::Enc(df, raw, num, tls.buf, packed, packed_ofs); + if (!tls.buf) tls.buf = std::make_unique(); + NuqCodec::Enc(df, raw, num, *tls.buf, packed, packed_ofs); - if (COMPRESS_STATS) { + if constexpr (COMPRESS_STATS) { for (size_t i = 0; i < num; ++i) { tls.stats.NotifyIn(static_cast(lroundf(raw[i] * 100.0f + 500.0f))); } @@ -477,7 +479,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num, const size_t packed_ofs, hwy::ThreadPool& pool) { packed.BoundsCheck(packed_ofs, num); work.tls.resize(pool.NumWorkers()); - if (COMPRESS_STATS) { + if constexpr (COMPRESS_STATS) { for (auto& tls : work.tls) { tls.stats.Reset(); } @@ -486,7 +488,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num, const bool want_bench = COMPRESS_STATS || !kIsTest; const double t0 = want_bench ? hwy::platform::Now() : 0.0; - using Traits = CompressTraits; + using Traits = CompressTraits>; constexpr size_t kBatch = 8192; const size_t num_batches = hwy::DivCeil(num, kBatch); pool.Run(0, num_batches, @@ -507,7 +509,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num, fprintf(stderr, "Compress %.1f MB/s\n", mbps); } - if (COMPRESS_STATS) { + if constexpr (COMPRESS_STATS) { for (size_t i = 1; i < work.tls.size(); ++i) { work.tls[0].stats.Assimilate(work.tls[i].stats); } @@ -515,26 +517,25 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num, } } -// Adapter that compresses into `MatStorageT`. `raw` must already be scaled -// to fit the value range, if `Packed` is `SfpStream`. +// Same as above, but without parallelization nor benchmarking. template -HWY_INLINE void CompressScaled(const float* HWY_RESTRICT raw, size_t num, - CompressWorkingSet& work, - MatStorageT& compressed, - hwy::ThreadPool& pool) { - Compress(raw, num, work, - MakeSpan(compressed.data(), compressed.NumElements()), - /*packed_ofs=*/0, pool); +HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num, + CompressPerThread& tls, + const PackedSpan& packed, + const size_t packed_ofs) { + packed.BoundsCheck(packed_ofs, num); + using Traits = CompressTraits>; + const hn::ScalableTag df; + Traits::Compress(df, raw, num, tls, packed, packed_ofs); } -// Stores two f32 vectors to f32 or bf16; avoids duplicating RMSNorm and -// RMSNormInplace for the two output types. +// Stores two f32 vectors to f32 or bf16. template > void Compress2(DF df, VF raw0, VF raw1, const PackedSpan& packed, const size_t packed_ofs) { static_assert(hwy::IsSameEither()); packed.BoundsCheck(packed_ofs, 2 * hn::Lanes(df)); - using Traits = CompressTraits; + using Traits = CompressTraits>; Traits::Store2(df, raw0, raw1, packed, packed_ofs); } @@ -566,7 +567,7 @@ HWY_INLINE void Decompress2(DRaw d, const PackedSpan& packed, // Decompresses from any type of `packed`, starting at (any) `packed_ofs`, to // (any) `num` elements in `raw`, then appends `[0, hn::Lanes(d))` zeroes as // required to round `num` up to one vector, if it is not already. 
The caller is -// responsible for scaling `raw` to the original range because `EmbedToken` +// responsible for scaling `raw` to the original range because `EmbedMMToken` // also wants to scale the decompressed elements. // `TRaw` can be `float/BF16`, or `double` if `Packed` is `float`. template > @@ -708,51 +709,6 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan v, comp3); } -// Functor called for each tensor, which compresses and stores them along with -// their scaling factors to BlobStore. -class Compressor { - public: - explicit Compressor(hwy::ThreadPool& pool) : writer_(pool) {} - - template - void operator()(MatPtrT* compressed, const char* decorated_name, - const float* HWY_RESTRICT weights) { - size_t num_weights = compressed->NumElements(); - if (num_weights == 0 || weights == nullptr || compressed->Ptr() == nullptr) - return; - size_t num_compressed = compressed->NumElements(); - PackedSpan packed = MakeSpan(compressed->data(), num_compressed); - fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name, - num_weights / (1000 * 1000)); - Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0, - writer_.pool()); - writer_(compressed, decorated_name); - } - - void AddTokenizer(const std::string& tokenizer) { - writer_.AddTokenizer(tokenizer); - } - - void AddScales(const float* scales, size_t len) { - writer_.AddScales(scales, len); - } - - // Writes all blobs to disk in the given order. The config is optional and - // if given, it is written to the file, along with the TOC, making it - // single-file format. Otherwise, the file is written in the multi-file format - // without a TOC. - BlobError WriteAll(const Path& blob_filename, const ModelConfig* config) { - return writer_.WriteAll(blob_filename, config); - } - - // Returns the number of blobs added. - size_t DebugNumBlobsAdded() const { return writer_.DebugNumBlobsAdded(); } - - private: - CompressWorkingSet work_; - WriteToBlobStore writer_; -}; - // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp diff --git a/compression/compress.cc b/compression/compress.cc index e858e15..6ef8990 100644 --- a/compression/compress.cc +++ b/compression/compress.cc @@ -15,8 +15,34 @@ #include "compression/compress.h" +#include +#include + +#include "util/mat.h" +#include "hwy/base.h" +#include "hwy/profiler.h" + namespace gcpp { -MatPtr::~MatPtr() {} +float ScaleWeights(float* HWY_RESTRICT raw, size_t num) { + PROFILER_FUNC; + + float maxabs = 0.0; + for (size_t i = 0; i < num; ++i) { + maxabs = HWY_MAX(maxabs, hwy::ScalarAbs(raw[i])); + } + if (maxabs <= SfpStream::kMax) { + return 1.0f; + } + const float scale = maxabs / SfpStream::kMax; + const float inv_scale = static_cast(1.0 / static_cast(scale)); + for (size_t i = 0; i < num; ++i) { + // Clamp because kMax may still be exceeded. 
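// (inv_scale is rounded to float, so raw[i] * inv_scale may land a fraction
// of an ulp above kMax; the HWY_MIN below guarantees the encoder's input
// range nonetheless.)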
+ const float magn = + HWY_MIN(SfpStream::kMax, hwy::ScalarAbs(raw[i] * inv_scale)); + raw[i] = hwy::ScalarCopySign(magn, raw[i]); + } + return scale; +} } // namespace gcpp diff --git a/compression/compress.h b/compression/compress.h index d875c4b..811f483 100644 --- a/compression/compress.h +++ b/compression/compress.h @@ -17,31 +17,19 @@ #ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_ #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_ -#include "hwy/base.h" #define COMPRESS_STATS 0 #include #include -#include -#include -#include -#include -#include -#include +#if COMPRESS_STATS +#include +#endif + +#include #include -// IWYU pragma: begin_exports -#include "compression/blob_store.h" -#include "compression/fields.h" -#include "compression/io.h" -#include "compression/shared.h" -#include "gemma/tensor_index.h" -#include "util/basics.h" -// IWYU pragma: end_exports -#include "gemma/configs.h" -#include "util/allocator.h" -#include "hwy/per_target.h" +#include "compression/types.h" // IWYU pragma: export #if COMPRESS_STATS #include "compression/distortion.h" #include "hwy/stats.h" @@ -49,388 +37,6 @@ namespace gcpp { -// Base class for rank-1 or 2 tensors (vector or matrix). -// Supports both dynamic and compile-time sizing. -// Holds metadata and a non-owning pointer to the data, owned by the derived -// MatStorageT class. -// This class also provides easy conversion from/to a table of contents for a -// BlobStore file, and a templated (compile-time) accessor for a 2-d array of -// fixed inner dimension and type. -// It is designed to be put in a vector, and has default copy and operator=, so -// it is easy to read/write a blob_store file. -class MatPtr : public IFields { - public: - // Full constructor for dynamic sizing. - MatPtr(const std::string& name, Type type, size_t element_size, size_t rows, - size_t cols) - : name_(name), - type_(type), - element_size_(element_size), - num_elements_(rows * cols), - rows_(rows), - cols_(cols), - ptr_(nullptr) { - stride_ = cols; - } - // Default is to leave all fields default-initialized. - MatPtr() = default; - virtual ~MatPtr(); - - // Compatibility interface for CompressedArray. - // TODO: remove. - template - T* data() { - return HWY_RCAST_ALIGNED(T*, ptr_); - } - template - const T* data() const { - return HWY_RCAST_ALIGNED(const T*, ptr_); - } - - const void* Ptr() const { return ptr_; } - void* Ptr() { return ptr_; } - // Sets the pointer from another MatPtr. - void SetPtr(const MatPtr& other) { ptr_ = other.ptr_; } - - // Copying allowed as the metadata is small. - MatPtr(const MatPtr& other) = default; - MatPtr& operator=(const MatPtr& other) = default; - - // Returns the name of the blob. - const char* Name() const override { return name_.c_str(); } - void SetName(const std::string& name) { name_ = name; } - - // Returns the type of the blob. - Type GetType() const { return type_; } - - // Returns the size of each element in bytes. - size_t ElementSize() const { return element_size_; } - - // Returns the number of elements in the array. - size_t NumElements() const { return num_elements_; } - - // Returns the number of bytes in the array. - size_t SizeBytes() const { - if (this->GetType() == TypeEnum()) { - return NuqStream::PackedEnd(num_elements_); - } - return num_elements_ * element_size_; - } - - // Returns the number of rows in the 2-d array (outer dimension). - size_t Rows() const { return rows_; } - - // Returns the number of columns in the 2-d array (inner dimension). 
- size_t Cols() const { return cols_; } - - Extents2D Extents() const { return Extents2D(rows_, cols_); } - - // Currently same as cols, but may differ in the future. This is the offset by - // which to advance pointers to the next row. - size_t Stride() const { return stride_; } - - // Decoded elements should be multiplied by this to restore their original - // range. This is required because SfpStream can only encode a limited range - // of magnitudes. - float scale() const { return scale_; } - void set_scale(float scale) { scale_ = scale; } - - std::string LayerName(int layer) const { - std::string name = name_ + std::to_string(layer); - HWY_ASSERT(name.size() <= sizeof(hwy::uint128_t)); - return name; - } - - // Sets all data to zero. - void ZeroInit() { - if (ptr_ == nullptr) - HWY_ABORT("ptr_ is null on tensor %s\n", name_.c_str()); - hwy::ZeroBytes(ptr_, SizeBytes()); - } - - void VisitFields(IFieldsVisitor& visitor) override { - visitor(name_); - visitor(type_); - visitor(element_size_); - visitor(num_elements_); - visitor(rows_); - visitor(cols_); - visitor(scale_); - visitor(stride_); - } - - // Calls func on the upcasted type. Since MatPtr by design is not templated, - // here we provide a way to get to the derived type, provided that `Type()` - // is one of the strings returned by `TypeName()`. - template - decltype(auto) CallUpcasted(FuncT& func, TArgs&&... args); - - protected: - // Arbitrary name for the array of preferably <= 16 characters. - std::string name_; - // Should be the result of TypeEnum for CallUpcasted() to work. - Type type_; - // sizeof(T) - uint32_t element_size_ = 0; - // Number of elements in the array. - uint32_t num_elements_ = 0; // In element_size units. - // Number of rows in the 2-d array (outer dimension). - uint32_t rows_ = 0; - // Number of columns in the 2-d array (inner dimension). - uint32_t cols_ = 0; - // Scaling to apply to each element. - float scale_ = 1.0f; - // Aligned data array. This is always a borrowed pointer. It should never be - // freed. The underlying memory is owned by a subclass or some external class - // and must outlive this object. - void* ptr_ = nullptr; - - uint32_t stride_; -}; - -// MatPtrT adds a single template argument to MatPtr for an explicit type. -// Use this class as a function argument where the type needs to be known. -// Use MatPtr where the type does not need to be known. -template -class MatPtrT : public MatPtr { - public: - // Full constructor for dynamic sizing. - MatPtrT(const std::string& name, size_t rows, size_t cols) - : MatPtr(name, TypeEnum(), sizeof(MatT), rows, cols) {} - // Construction from TensorIndex entry to remove duplication of sizes. - MatPtrT(const std::string& name, const TensorIndex& tensor_index) - : MatPtrT(name, tensor_index.FindName(name)) {} - MatPtrT(const std::string& name, const TensorInfo* tensor) - : MatPtr(name, TypeEnum(), sizeof(MatT), 0, 0) { - if (tensor == nullptr) { - cols_ = 0; - rows_ = 0; - } else { - cols_ = tensor->shape.back(); - rows_ = 1; - if (tensor->cols_take_extra_dims) { - // The columns eat the extra dimensions. - rows_ = tensor->shape[0]; - for (size_t i = 1; i < tensor->shape.size() - 1; ++i) { - cols_ *= tensor->shape[i]; - } - } else { - // The rows eat the extra dimensions. - for (size_t i = 0; i < tensor->shape.size() - 1; ++i) { - rows_ *= tensor->shape[i]; - } - } - } - stride_ = cols_; - num_elements_ = rows_ * cols_; - } - - // Copying allowed as the metadata is small. 
- MatPtrT(const MatPtr& other) : MatPtr(other) {} - MatPtrT& operator=(const MatPtr& other) { - MatPtr::operator=(other); - return *this; - } - MatPtrT(const MatPtrT& other) = default; - MatPtrT& operator=(const MatPtrT& other) = default; - - std::string CacheName(int layer = -1, char separator = ' ', - int index = -1) const { - // Already used/retired: s, S, n, 1 - const char prefix = hwy::IsSame() ? 'F' - : hwy::IsSame() ? 'B' - : hwy::IsSame() ? '$' - : hwy::IsSame() ? '2' - : '?'; - std::string name = std::string(1, prefix) + name_; - if (layer >= 0 || index >= 0) { - name += '_'; - if (layer >= 0) name += std::to_string(layer); - if (index >= 0) { - name += separator + std::to_string(index); - } - } - return name; - } - - // Sets the number of elements in the array. For use when the number of - // elements is != rows * cols ONLY. - void SetNumElements(size_t num_elements) { - num_elements_ = CompressedArrayElements(num_elements); - } - - // 2-d Accessor for a specific type but with a dynamic inner dimension. - template - const T& At(size_t row, size_t col) const { - size_t index = row * cols_ + col; - HWY_DASSERT(index < num_elements_); - return HWY_RCAST_ALIGNED(const T*, ptr_)[index]; - } - - // 1-d Accessor for a specific type. - // TODO: replace this with a Foreach(), or at least a ForEachRow(). - const MatT& At(size_t index) const { - HWY_DASSERT(index < num_elements_); - return HWY_RCAST_ALIGNED(const MatT*, ptr_)[index]; - } - MatT& At(size_t index) { return HWY_RCAST_ALIGNED(MatT*, ptr_)[index]; } - - // Compatibility interface for CompressedArray. - // TODO: remove - template - T* data() { - return HWY_RCAST_ALIGNED(T*, ptr_); - } - template - const T* data() const { - return HWY_RCAST_ALIGNED(const T*, ptr_); - } - // The const accessor data_scale1() asserts (!) that the scale is 1.0f, so - // calling it means "I am sure the scale is 1 and therefore ignore the scale". - // A scale of 0 indicates that the scale has likely never been set, so is - // "implicitly 1". - const MatT* data_scale1() const { - HWY_ASSERT(scale() == 1.f); - return HWY_RCAST_ALIGNED(const MatT*, ptr_); - } -}; - -template -decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) { - if (type_ == TypeEnum()) { - return func(dynamic_cast*>(this), - std::forward(args)...); - } else if (type_ == TypeEnum()) { - return func(dynamic_cast*>(this), - std::forward(args)...); - } else if (type_ == TypeEnum()) { - return func(dynamic_cast*>(this), - std::forward(args)...); - } else if (type_ == TypeEnum()) { - return func(dynamic_cast*>(this), - std::forward(args)...); - } else { - HWY_ABORT("Type %d unknown.", type_); - } -} - -// MatStorageT adds the actual data storage to MatPtrT. -// TODO: use Extents2D instead of rows and cols. -template -class MatStorageT : public MatPtrT { - public: - // Full constructor for dynamic sizing. - MatStorageT(const std::string& name, size_t rows, size_t cols) - : MatPtrT(name, rows, cols) { - Allocate(); - } - // Can copy the metadata, from a MatPtr, and allocate later. - MatStorageT(const MatPtr& other) : MatPtrT(other) {} - ~MatStorageT() = default; - - // Move-only because this contains a unique_ptr. - MatStorageT(const MatStorageT& other) = delete; - MatStorageT& operator=(const MatStorageT& other) = delete; - MatStorageT(MatStorageT&& other) = default; - MatStorageT& operator=(MatStorageT&& other) = default; - - // Allocate the memory and copy the pointer to the MatPtr. - // num_elements is in elements. 
In the default (zero) case, it is computed - // from the current num_elements_ which was set by the constructor from the - // rows and cols. - void Allocate(size_t num_elements = 0) { - if (num_elements == 0) { - num_elements = hwy::DivCeil(this->SizeBytes(), sizeof(MatT)); - } else { - this->num_elements_ = num_elements; - } - // Pad to allow overrunning the last row by 2 BF16 vectors, hence at most - // `2 * VectorBytes / sizeof(BF16)` elements of MatT. - const size_t padding = hwy::VectorBytes(); - data_ = Allocator::Alloc(num_elements + padding); - hwy::ZeroBytes(&data_[num_elements], padding * sizeof(MatT)); - this->ptr_ = data_.get(); - } - - // Zeros the content. - void ZeroInit() { - HWY_ASSERT(data_ != nullptr); - hwy::ZeroBytes(data_.get(), this->SizeBytes()); - } - - private: - AlignedPtr data_; -}; - -// MatStorage allows heterogeneous tensors to be stored in a single vector. -using MatStorage = MatStorageT; - -// Table of contents for a blob store file. Full metadata, but not actual data. -class BlobToc { - public: - BlobToc() = default; - - // Loads the table of contents from the given reader. - BlobError LoadToc(BlobReader& reader) { - hwy::uint128_t toc_key = MakeKey(kTocName); - size_t toc_size = reader.BlobSize(toc_key); - if (toc_size != 0) { - std::vector toc(toc_size / sizeof(uint32_t)); - BlobError err = reader.ReadOne(toc_key, toc.data(), toc_size); - if (err != 0) { - fprintf(stderr, "Failed to read toc (error %d)\n", err); - return err; - } - size_t consumed = 0; - size_t prev_consumed = static_cast(-1); - while (consumed < toc.size() && prev_consumed != consumed) { - MatPtr blob; - const IFields::ReadResult result = - blob.Read(hwy::Span(toc), consumed); - prev_consumed = consumed; - consumed = result.pos; - if (blob.NumElements() > 0) { - AddToToc(blob); - } - } - } - return 0; - } - - bool Empty() const { return toc_map_.empty(); } - - // Returns true if the table of contents contains the given name. - bool Contains(const std::string& name) const { - return toc_map_.find(name) != toc_map_.end(); - } - - // Returns the blob with the given name, or nullptr if not found. - const MatPtr* Get(const std::string& name) const { - auto it = toc_map_.find(name); - if (it == toc_map_.end()) return nullptr; - return &toc_[it->second]; - } - // The name of the toc in the blob store file. - static constexpr char kTocName[] = "toc"; - - // The name of the config in the blob store file. - static constexpr char kConfigName[] = "config"; - - // The name of the tokenizer in the blob store file. - static constexpr char kTokenizerName[] = "tokenizer"; - - private: - // Adds the blob to the table of contents. - void AddToToc(const MatPtr& blob) { - HWY_ASSERT(!Contains(blob.Name())); - toc_map_[blob.Name()] = toc_.size(); - toc_.push_back(blob); - } - - std::unordered_map toc_map_; - std::vector toc_; -}; - #if COMPRESS_STATS class CompressStats { public: @@ -489,7 +95,8 @@ struct CompressStats { #endif // COMPRESS_STATS struct CompressPerThread { - NuqStream::ClusterBuf buf; + // Allocated the first time NUQ is used. + std::unique_ptr buf; CompressStats stats; }; @@ -497,196 +104,11 @@ struct CompressWorkingSet { std::vector tls; }; -// Class to collect and write a set of tensors to a blob store file. 
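// (When WriteAll below is given a ModelConfig, the TOC and serialized config
// are written as extra blobs, yielding the single-file format; callers such
// as compress_weights.cc also add the tokenizer in that case. Without a
// config, only tensors and scales are written, as in the multi-file format.)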
-class WriteToBlobStore { - public: - explicit WriteToBlobStore(hwy::ThreadPool& pool) : pool_(pool) {} - - template - void operator()(MatPtrT* compressed, const char* decorated_name) { - if (compressed->Ptr() == nullptr) return; - writer_.Add(MakeKey(decorated_name), compressed->Ptr(), - compressed->SizeBytes()); - MatPtr renamed_tensor(*compressed); - renamed_tensor.SetName(decorated_name); - renamed_tensor.AppendTo(toc_); - } - - void AddTokenizer(const std::string& tokenizer) { - writer_.Add(MakeKey(BlobToc::kTokenizerName), tokenizer.data(), - tokenizer.size() * sizeof(tokenizer[0])); - } - - void AddScales(const float* scales, size_t len) { - if (len) { - MatPtrT scales_ptr("scales", 0, 1); - writer_.Add(MakeKey(scales_ptr.CacheName().c_str()), scales, - len * sizeof(scales[0])); - } - } - - // Writes all blobs to disk in the given order. The config is optional and - // if given, it is written to the file, along with the TOC, making it - // single-file format. Otherwise, the file is written in the multi-file format - // without a TOC. - BlobError WriteAll(const Path& blob_filename, const ModelConfig* config) { - if (config) { - writer_.Add(MakeKey(BlobToc::kTocName), toc_.data(), - toc_.size() * sizeof(toc_[0])); - config_buffer_ = config->Write(); - writer_.Add(MakeKey(BlobToc::kConfigName), config_buffer_.data(), - config_buffer_.size() * sizeof(config_buffer_[0])); - } - const BlobError err = writer_.WriteAll(pool_, blob_filename); - if (err != 0) { - fprintf(stderr, "Failed to write blobs to %s (error %d)\n", - blob_filename.path.c_str(), err); - } - return err; - } - - // Returns the number of blobs added. - size_t DebugNumBlobsAdded() const { return writer_.DebugNumBlobsAdded(); } - - hwy::ThreadPool& pool() { return pool_; } - - protected: - hwy::ThreadPool& pool_; - - private: - std::vector toc_; - BlobWriter writer_; - std::vector config_buffer_; -}; - -// Functor called for each tensor, which loads them and their scaling factors -// from BlobStore. -class ReadFromBlobStore { - public: - explicit ReadFromBlobStore(const Path& blob_filename) { - err_ = reader_.Open(blob_filename); - if (HWY_UNLIKELY(err_ != 0)) { - fprintf(stderr, "Error %d opening BlobStore %s.\n", err_, - blob_filename.path.c_str()); - return; // avoid overwriting err_ to ensure ReadAll will fail. - } - err_ = file_toc_.LoadToc(reader_); - if (HWY_UNLIKELY(err_ != 0)) { - fprintf(stderr, "Found a TOC, but failed to load it (code %d)\n", err_); - } - } - - // Returns true if there is a TOC. - bool HaveToc() const { return !file_toc_.Empty(); } - - // Reads the config from the blob store file. - BlobError LoadConfig(ModelConfig& config) { - hwy::uint128_t config_key = MakeKey(BlobToc::kConfigName); - size_t config_size = reader_.BlobSize(config_key); - if (config_size == 0) return __LINE__; - std::vector config_buffer(config_size / sizeof(uint32_t)); - BlobError err = - reader_.ReadOne(config_key, config_buffer.data(), config_size); - if (err != 0) { - fprintf(stderr, "Failed to read config (error %d)\n", err); - return err; - } - config.Read(hwy::Span(config_buffer), 0); - return 0; - } - - // Reads the tokenizer from the blob store file. 
- BlobError LoadTokenizer(std::string& tokenizer) { - hwy::uint128_t key = MakeKey(BlobToc::kTokenizerName); - size_t tokenizer_size = reader_.BlobSize(key); - if (tokenizer_size == 0) return __LINE__; - tokenizer.resize(tokenizer_size); - ; - BlobError err = reader_.ReadOne(key, tokenizer.data(), tokenizer_size); - if (err != 0) { - fprintf(stderr, "Failed to read tokenizer (error %d)\n", err); - return err; - } - return 0; - } - - // Called for each tensor, enqueues read requests. - void operator()(const char* name, hwy::Span tensors) { - if (file_toc_.Empty() || file_toc_.Contains(name)) { - model_toc_.push_back(tensors[0]); - file_keys_.push_back(name); - } - } - - BlobError LoadScales(float* scales, size_t len) { - for (size_t i = 0; i < len; ++i) { - scales[i] = 1.0f; - } - MatPtrT scales_ptr("scales", 0, 1); - auto key = MakeKey(scales_ptr.CacheName().c_str()); - if (reader_.BlobSize(key) == 0) return 0; - return reader_.Enqueue(key, scales, len * sizeof(scales[0])); - } - - // Returns whether all tensors are successfully loaded from cache. - BlobError ReadAll(hwy::ThreadPool& pool, - std::vector& model_memory) { - // reader_ invalid or any Enqueue failed - if (err_ != 0) return err_; - // Setup the model_memory. - for (size_t b = 0; b < model_toc_.size(); ++b) { - const std::string& file_key = file_keys_[b]; - MatPtr* blob = model_toc_[b]; - if (!file_toc_.Empty()) { - const MatPtr* toc_blob = file_toc_.Get(file_key); - if (toc_blob == nullptr) { - fprintf(stderr, "Blob %s not found in TOC\n", file_key.c_str()); - return __LINE__; - } - if (toc_blob->Rows() != blob->Rows() || - toc_blob->Cols() != blob->Cols()) { - fprintf(stderr, "Blob %s has size mismatch TOC\n", file_key.c_str()); - return __LINE__; - } - std::string name = blob->Name(); - *blob = *toc_blob; - blob->SetName(name); - } - model_memory.emplace_back(*blob); - model_memory.back().SetName(file_key); - } - // Allocate in parallel using the pool. - pool.Run(0, model_memory.size(), - [this, &model_memory](uint64_t task, size_t /*thread*/) { - model_memory[task].Allocate(); - model_toc_[task]->SetPtr(model_memory[task]); - }); - // Enqueue the read requests. - for (auto& blob : model_memory) { - err_ = - reader_.Enqueue(MakeKey(blob.Name()), blob.data(), blob.SizeBytes()); - if (err_ != 0) { - fprintf(stderr, - "Failed to read blob %s (error %d) of size %zu x %zu x %zu\n", - blob.Name(), err_, blob.Rows(), blob.Cols(), - blob.ElementSize()); - return err_; - } - } - return reader_.ReadAll(pool); - } - - private: - BlobReader reader_; - BlobError err_ = 0; - // Table of contents from the file, if present. - BlobToc file_toc_; - // Table of contents from the model. Pointers to original MatPtrT so the - // data pointers can be updated. - std::vector model_toc_; - // Mangled names of the tensors in model_toc_ for reading from the file. - std::vector file_keys_; -}; +// Returns 1.0f if all magnitudes are <= `SfpStream::kMax`, otherwise scales +// them such that the largest magnitude is `SfpStream::kMax`, and returns the +// multiplier with which to restore the original values. This is only necessary +// before compressing to `SfpStream` and `NuqStream`. 
+float ScaleWeights(float* HWY_RESTRICT raw, size_t num); } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_ diff --git a/compression/compress_test.cc b/compression/compress_test.cc index 13b1982..2270689 100644 --- a/compression/compress_test.cc +++ b/compression/compress_test.cc @@ -13,10 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -// SFP uses ConcatEven/Odd which are not supported; skip SVE for faster tests. +#include "compression/types.h" #ifndef HWY_DISABLED_TARGETS -#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE) -#endif +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS #include "compression/compress.h" @@ -80,7 +80,7 @@ struct TestDecompress2T { stats.Notify(raw[i], hwy::ConvertScalarTo(dec[i])); } - if constexpr (false) { + if constexpr (true) { // leave enabled due to sporadic failures fprintf(stderr, "TypeName() %s TypeName() %s: num %zu: stats.SumL1() " "%f stats.GeomeanValueDivL1() %f stats.WeightedAverageL1() %f " diff --git a/compression/compress_weights.cc b/compression/compress_weights.cc deleted file mode 100644 index cbf7e35..0000000 --- a/compression/compress_weights.cc +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Command line tool to create compressed weights. - -// Compiles this file for multiple architectures via "foreach_target.h", to -// which we pass the filename via macro 'argument'. -#undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE \ - "compression/compress_weights.cc" // NOLINT -#include "hwy/foreach_target.h" // IWYU pragma: keep -#include "hwy/highway.h" -// After highway.h -#include "compression/compress-inl.h" -#include "gemma/configs.h" -#include "gemma/tokenizer.h" - -#ifndef GEMMA_COMPRESS_WEIGHTS_ONCE -#define GEMMA_COMPRESS_WEIGHTS_ONCE - -#include -#include - -#include // std::clamp -#include -#include -#include -#include // NOLINT -#include - -#include "compression/compress.h" -#include "compression/io.h" // Path -#include "compression/shared.h" // PromptWrapping -#include "gemma/common.h" // Model -#include "gemma/weights.h" -#include "util/allocator.h" -#include "util/args.h" -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -namespace gcpp { - -namespace { - -} // namespace - -struct Args : public ArgsBase { - static constexpr size_t kDefaultNumThreads = ~size_t{0}; - - void ChooseNumThreads() { - if (num_threads == kDefaultNumThreads) { - // This is a rough heuristic, replace with something better in the future. - num_threads = static_cast(std::clamp( - static_cast(std::thread::hardware_concurrency()) - 2, 1, 18)); - } - } - - public: - Args(int argc, char* argv[]) { - InitAndParse(argc, argv); - ChooseNumThreads(); - } - - // Returns error string or nullptr if OK. 
- const char* Validate() { - if (const char* err = ParseModelTypeAndWrapping(model_type_str, model_type_, - prompt_wrapping_)) { - return err; - } - if (const char* err = ParseType(weight_type_str, weight_type_)) { - return err; - } - if (weights.path.empty()) { - return "Missing --weights flag, a file for the uncompressed model."; - } - if (compressed_weights.path.empty()) { - return "Missing --compressed_weights flag, a file for the compressed " - "model."; - } - if (!weights.Exists()) { - return "Can't open file specified with --weights flag."; - } - return nullptr; - } - - Path weights; // uncompressed weights file location - Path compressed_weights; // compressed weights file location - std::string model_type_str; - std::string weight_type_str; - size_t num_threads; - // If non-empty, whether to include the config and TOC in the output file, as - // well as the tokenizer. - Path tokenizer; - - template - void ForEach(const Visitor& visitor) { - visitor(weights, "weights", Path(), - "Path to model weights (.bin) file.\n" - " Required argument."); - visitor(model_type_str, "model", std::string(), - "Model type\n 2b-it = 2B parameters, instruction-tuned\n " - "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters " - "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n " - "gr2b-it = griffin 2B parameters, instruction-tuned\n " - "gr2b-pt = griffin 2B parameters, pretrained\n " - " Required argument."); - visitor(weight_type_str, "weight_type", std::string("sfp"), - "Weight type\n f32 = float, bf16 = bfloat16, SFP = 8-bit FP\n" - " Required argument."); - visitor(compressed_weights, "compressed_weights", Path(), - "Path name where compressed weights (.sbs) file will be written.\n" - " Required argument."); - visitor(num_threads, "num_threads", - kDefaultNumThreads, // see ChooseNumThreads - "Number of threads to use.\n Default = Estimate of the " - "number of supported concurrent threads.", - 2); - visitor(tokenizer, "tokenizer", Path(), - "Path to tokenizer file. If given, the config and TOC are also " - "added to the output file."); - } - - // Uninitialized before Validate, must call after that. - gcpp::Model ModelType() const { return model_type_; } - gcpp::PromptWrapping PromptWrappingType() const { return prompt_wrapping_; } - gcpp::Type WeightType() const { return weight_type_; } - - private: - Model model_type_; - PromptWrapping prompt_wrapping_; - Type weight_type_; -}; - -void ShowHelp(gcpp::Args& args) { - std::cerr - << "Usage:\n./compress_weights --weights " - " --model --compressed_weights \n"; - std::cerr << "\n*Arguments*\n\n"; - args.Help(); - std::cerr << "\n"; -} - -} // namespace gcpp -#endif // GEMMA_COMPRESS_WEIGHTS_ONCE - -// SIMD code, compiled once per target. 
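// (foreach_target.h re-includes this file once per SIMD target; the HWY_ONCE
// section at the end then selects the best compiled variant at runtime via
// HWY_EXPORT_AND_DYNAMIC_DISPATCH_T.)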
-HWY_BEFORE_NAMESPACE(); -namespace gcpp { -namespace HWY_NAMESPACE { - -template -void CompressWeights(const Path& weights_path, - const Path& compressed_weights_path, Model model_type, - Type weight_type, PromptWrapping wrapping, - const Path& tokenizer_path, hwy::ThreadPool& pool) { - if (!weights_path.Exists()) { - HWY_ABORT("The model weights file '%s' does not exist.", - weights_path.path.c_str()); - } - printf("Compressing weights from %s to %s\n", weights_path.path.c_str(), - compressed_weights_path.path.c_str()); - ModelConfig config = ConfigFromModel(model_type); - config.weight = weight_type; - config.wrapping = wrapping; - std::vector model_storage; - ModelWeightsPtrs c_weights(config); - c_weights.Allocate(model_storage, pool); - ModelWeightsPtrs uc_weights(config); - uc_weights.Allocate(model_storage, pool); - // Get uncompressed weights, compress, and store. - FILE* fptr = fopen(weights_path.path.c_str(), "rb"); - if (fptr == nullptr) { - HWY_ABORT("Failed to open model file %s - does it exist?", - weights_path.path.c_str()); - } - bool ok = true; - uint64_t total_size = 0; - ModelWeightsPtrs::ForEachTensor( - {&uc_weights}, ForEachType::kLoadNoToc, - [&](const char* name, hwy::Span tensors) { - fprintf(stderr, "Loading Parameters (size %zu): %s\n", - tensors[0]->SizeBytes(), name); - ok &= 1 == fread(tensors[0]->Ptr(), tensors[0]->SizeBytes(), 1, fptr); - total_size += tensors[0]->SizeBytes(); - }); - if (!tokenizer_path.path.empty()) { - uc_weights.AllocAndCopyWithTranspose(pool, model_storage); - } - const bool scale_for_compression = config.num_tensor_scales > 0; - std::vector scales; - if (scale_for_compression) { - uc_weights.GetOrApplyScales(scales); - } - Compressor compressor(pool); - ModelWeightsPtrs::ForEachTensor( - {reinterpret_cast*>(&uc_weights), &c_weights}, - tokenizer_path.path.empty() ? ForEachType::kLoadNoToc - : ForEachType::kLoadWithToc, - [&compressor](const char* name, hwy::Span tensors) { - tensors[1]->CallUpcasted( - compressor, name, - reinterpret_cast(tensors[0]->Ptr())); - }); - if (!tokenizer_path.path.empty()) { - std::string tokenizer_proto = ReadFileToString(tokenizer_path); - compressor.AddTokenizer(tokenizer_proto); - } else { - compressor.AddScales(scales.data(), scales.size() * sizeof(scales[0])); - } - compressor.WriteAll(compressed_weights_path, - tokenizer_path.path.empty() ? 
nullptr : &config); -} - -} // namespace HWY_NAMESPACE -} // namespace gcpp -HWY_AFTER_NAMESPACE(); - -#if HWY_ONCE -namespace gcpp { - -void Run(Args& args) { - hwy::ThreadPool pool(args.num_threads); - if (args.PromptWrappingType() == PromptWrapping::PALIGEMMA) { - HWY_ABORT("PaliGemma is not supported in compress_weights."); - } - const Model model_type = args.ModelType(); - const Type weight_type = args.WeightType(); - switch (weight_type) { - case Type::kF32: - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) - (args.weights, args.compressed_weights, model_type, weight_type, - args.PromptWrappingType(), args.tokenizer, pool); - break; - case Type::kBF16: - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) - (args.weights, args.compressed_weights, model_type, weight_type, - args.PromptWrappingType(), args.tokenizer, pool); - break; - case Type::kSFP: - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) - (args.weights, args.compressed_weights, model_type, weight_type, - args.PromptWrappingType(), args.tokenizer, pool); - break; - case Type::kNUQ: - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) - (args.weights, args.compressed_weights, model_type, weight_type, - args.PromptWrappingType(), args.tokenizer, pool); - break; - default: - HWY_ABORT("Weight type %d unsupported.", static_cast(weight_type)); - } -} - -} // namespace gcpp - -int main(int argc, char** argv) { - gcpp::Args args(argc, argv); - - if (gcpp::HasHelp(argc, argv)) { - gcpp::ShowHelp(args); - return 0; - } - - if (const char* error = args.Validate()) { - gcpp::ShowHelp(args); - HWY_ABORT("\nInvalid args: %s", error); - } - - gcpp::Run(args); - - return 0; -} - -#endif // HWY_ONCE diff --git a/compression/convert_weights.py b/compression/convert_weights.py deleted file mode 100644 index 3ba1642..0000000 --- a/compression/convert_weights.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright 2024 Google LLC -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Converts pytorch to f32 for use by compress_weights.cc.""" - -import argparse -import collections -import os -from gemma import config -from gemma import model as gemma_model -import numpy as np -import torch - -# Requires torch 2.2 and gemma package from -# https://github.com/google/gemma_pytorch - - -def check_file_exists(value): - if not os.path.exists(str(value)): - raise argparse.ArgumentTypeError( - "The file %s does not appear to exist." % value - ) - return value - - -def check_model_types(value): - if str(value).lower() not in ["2b", "7b"]: - raise argparse.ArgumentTypeError( - "Model type value %s is not in [2b, 7b]." 
% value - ) - return value - - -parser = argparse.ArgumentParser() -parser.add_argument( - "--tokenizer", - dest="tokenizer", - default="models/tokenizer.spm", - help="Location of tokenizer file (.model or .spm)", - type=check_file_exists, -) - -parser.add_argument( - "--weights", - dest="weights", - default="models/gemma-2b-it.ckpt", - help="Location of input checkpoint file (.ckpt)", - type=check_file_exists, -) - -parser.add_argument( - "--output_file", - dest="output_file", - default="2bit-f32.sbs", - help="Location to write converted weights", - type=str, -) - -parser.add_argument( - "--model_type", - dest="model_type", - default="2b", - help="Model size / type (2b, 7b)", - type=check_model_types, -) - -args = parser.parse_args() - - -TRANSFORMATIONS = { - "2b": collections.defaultdict( - lambda: lambda x: x, - { - "embedder.weight": lambda x: x, - "self_attn.qkv_proj.weight": lambda x: x.reshape((10, 256, 2048)), - "self_attn.o_proj.weight": lambda x: x.reshape( - (2048, 8, 256) - ).transpose([1, 0, 2]), - "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], - "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :], - "mlp.down_proj.weight": lambda x: x, - }, - ), - "7b": collections.defaultdict( - lambda: lambda x: x, - { - "embedder.weight": lambda x: x, - "self_attn.qkv_proj.weight": lambda x: x.reshape( - (3, 16, 256, 3072) - ).transpose([1, 0, 2, 3]), - "self_attn.o_proj.weight": lambda x: x.reshape( - (3072, 16, 256) - ).transpose([1, 0, 2]), - "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], - "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :], - "mlp.down_proj.weight": lambda x: x, - }, - ), -} - -VALIDATIONS = { - "2b": { - "embedder.weight": lambda x: x.shape == (256000, 2048), - "model.norm.weight": lambda x: x.shape == (2048,), - "self_attn.qkv_proj.weight": lambda x: x.shape == (10, 256, 2048), - "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256), - "mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048), - "mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048), - "mlp.down_proj.weight": lambda x: x.shape == (2048, 16384), - "input_layernorm.weight": lambda x: x.shape == (2048,), - "post_attention_layernorm.weight": lambda x: x.shape == (2048,), - }, - "7b": { - "embedder.weight": lambda x: x.shape == (256000, 3072), - "model.norm.weight": lambda x: x.shape == (3072,), - "self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072), - "self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256), - "mlp.gate_proj.weight": lambda x: x.shape == (1, 24576, 3072), - "mlp.up_proj.weight": lambda x: x.shape == (1, 24576, 3072), - "mlp.down_proj.weight": lambda x: x.shape == (3072, 24576), - "input_layernorm.weight": lambda x: x.shape == (3072,), - "post_attention_layernorm.weight": lambda x: x.shape == (3072,), - }, -} - - -def param_names(num_hidden_layers: int): - """Return parameter names in the order they are expected for deserialization.""" - - # note *weight_scaler params are ignored in the forward computation unless - # quantization is being used. - # - # since we are working with the full precision weights as input, don't - # include these in the parameters being iterated over. 
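#
# For example, with num_hidden_layers=1 the returned order is:
#   embedder.weight, model.norm.weight,
#   model.layers.0.self_attn.o_proj.weight,
#   model.layers.0.self_attn.qkv_proj.weight,
#   model.layers.0.mlp.gate_proj.weight,
#   model.layers.0.mlp.up_proj.weight,
#   model.layers.0.mlp.down_proj.weight,
#   model.layers.0.input_layernorm.weight,
#   model.layers.0.post_attention_layernorm.weight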
- - names = [ - ("embedder.weight",) * 2, # embedder_input_embedding - ("model.norm.weight",) * 2, # final_norm_scale - ] - layer_params = [ - "self_attn.o_proj.weight", # attn_vec_einsum_w - "self_attn.qkv_proj.weight", # qkv_einsum_w - "mlp.gate_proj.weight", # gating_einsum_w - "mlp.up_proj.weight", - "mlp.down_proj.weight", # linear_w - "input_layernorm.weight", # pre_attention_norm_scale - "post_attention_layernorm.weight", # pre_ffw_norm_scale - ] - for layer in range(num_hidden_layers): - for layer_param in layer_params: - names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)] - return names - - -def convert_weights(): - """Main function; loads weights, runs transformations, writes f32.""" - model_type = args.model_type - output_file = args.output_file - - model_config = config.get_model_config(model_type) - model_config.dtype = "float32" - model_config.tokenizer = args.tokenizer - device = torch.device("cpu") - torch.set_default_dtype(torch.float) - model = gemma_model.GemmaForCausalLM(model_config) - - model.load_weights(args.weights) - model.to(device).eval() - - model_dict = dict(model.named_parameters()) - param_order = param_names(model_config.num_hidden_layers) - - all_ok = True - print("Checking transformations ...") - for name, layer_name in param_order: - arr = model_dict[name].detach().numpy() - arr = TRANSFORMATIONS[model_type][layer_name](arr) - check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED" - - if check == "FAILED": - all_ok = False - print(f" {name : <60}{str(arr.shape) : <20}{check}") - - if all_ok: - print("Writing parameters ...") - with open(output_file, "wb") as bin_handle: - for name, layer_name in param_order: - arr = model_dict[name].detach().numpy() - arr = TRANSFORMATIONS[model_type][layer_name](arr) - check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED" - print(f" {name : <60}{str(arr.shape) : <20}{check}") - arr.flatten().astype(np.float32).tofile(bin_handle) - - -if __name__ == "__main__": - convert_weights() - print("Done") diff --git a/compression/distortion_test.cc b/compression/distortion_test.cc index 9350b5b..c52ecca 100644 --- a/compression/distortion_test.cc +++ b/compression/distortion_test.cc @@ -17,7 +17,7 @@ #include -#include "compression/shared.h" // SfpStream::kMax +#include "compression/types.h" // SfpStream::kMax #include "util/test_util.h" #include "hwy/nanobenchmark.h" #include "hwy/tests/hwy_gtest.h" diff --git a/compression/io.cc b/compression/io.cc deleted file mode 100644 index 84e3603..0000000 --- a/compression/io.cc +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Safe to be first, does not include POSIX headers. -#include "hwy/detect_compiler_arch.h" -// Only compile this file on non-Windows; it replaces io_win.cc. It is easier to -// check this in source code because we support multiple build systems. 
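// (The _XOPEN_SOURCE, _POSIX_C_SOURCE and _FILE_OFFSET_BITS definitions below
// must appear before the first system header is included in order to take
// effect, which is why this block precedes all POSIX includes.)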
-#if !HWY_OS_WIN - -// Request POSIX 2008, including `pread()` and `posix_fadvise()`. -#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700 -#undef _XOPEN_SOURCE -#define _XOPEN_SOURCE 700 -#endif -#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809 -#define _POSIX_C_SOURCE 200809 -#endif - -// Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c. -#undef _FILE_OFFSET_BITS -#define _FILE_OFFSET_BITS 64 - -#include // open -#include -#include -#include // SEEK_END - unistd isn't enough for IDE. -#include // O_RDONLY -#include // read, write, close - -#include - -#include "compression/io.h" -#include "hwy/base.h" // HWY_ASSERT - -namespace gcpp { - -class FilePosix : public File { - int fd_ = 0; - - public: - explicit FilePosix(int fd) : fd_(fd) { HWY_ASSERT(fd > 0); } - ~FilePosix() override { - if (fd_ != 0) { - HWY_ASSERT(close(fd_) != -1); - } - } - - uint64_t FileSize() const override { - static_assert(sizeof(off_t) == 8, "64-bit off_t required"); - const off_t size = lseek(fd_, 0, SEEK_END); - if (size < 0) { - return 0; - } - return static_cast(size); - } - - bool Read(uint64_t offset, uint64_t size, void* to) const override { - uint8_t* bytes = reinterpret_cast(to); - uint64_t pos = 0; - for (;;) { - // pread seems to be faster than lseek + read when parallelized. - const auto bytes_read = pread(fd_, bytes + pos, size - pos, offset + pos); - if (bytes_read <= 0) break; - pos += bytes_read; - HWY_ASSERT(pos <= size); - if (pos == size) break; - } - return pos == size; // success if managed to read desired size - } - - bool Write(const void* from, uint64_t size, uint64_t offset) override { - const uint8_t* bytes = reinterpret_cast(from); - uint64_t pos = 0; - for (;;) { - const auto bytes_written = - pwrite(fd_, bytes + pos, size - pos, offset + pos); - if (bytes_written <= 0) break; - pos += bytes_written; - HWY_ASSERT(pos <= size); - if (pos == size) break; - } - return pos == size; // success if managed to write desired size - } -}; // FilePosix - -HWY_MAYBE_UNUSED extern std::unique_ptr OpenFileGoogle( - const Path& filename, const char* mode); - -std::unique_ptr OpenFileOrNull(const Path& filename, const char* mode) { - std::unique_ptr file; // OpenFileGoogle omitted - if (file) return file; - - const bool is_read = mode[0] != 'w'; - const int flags = is_read ? O_RDONLY : O_CREAT | O_RDWR | O_TRUNC; - const int fd = open(filename.path.c_str(), flags, 0644); - if (fd < 0) return file; - -#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21) - if (is_read) { - // Doubles the readahead window, which seems slightly faster when cached. - (void)posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL); - } -#endif - - return std::make_unique(fd); -} - -} // namespace gcpp -#endif // !HWY_OS_WIN diff --git a/compression/nuq-inl.h b/compression/nuq-inl.h index 150ad79..997bb5b 100644 --- a/compression/nuq-inl.h +++ b/compression/nuq-inl.h @@ -23,7 +23,7 @@ #include -#include "compression/shared.h" +#include "compression/types.h" #include "util/basics.h" #include "hwy/base.h" diff --git a/compression/nuq_test.cc b/compression/nuq_test.cc index 6dd5982..df300f4 100644 --- a/compression/nuq_test.cc +++ b/compression/nuq_test.cc @@ -13,20 +13,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -// SFP uses ConcatEven/Odd which are not supported; skip SVE for faster tests. 
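One detail of the POSIX reader deleted above is worth keeping in mind: `pread` may legally return fewer bytes than requested, hence the retry loops in `FilePosix::Read` and `Write`. The same pattern sketched in Python via `os.pread` (POSIX-only, and assuming `fd` is an already-open read descriptor):

    import os

    def read_all(fd: int, offset: int, size: int) -> bytes:
        # Mirrors FilePosix::Read: retry until `size` bytes have arrived;
        # an empty result (EOF or error) aborts, and success means we got
        # exactly the requested size.
        parts, pos = [], 0
        while pos < size:
            chunk = os.pread(fd, size - pos, offset + pos)
            if not chunk:
                break
            parts.append(chunk)
            pos += len(chunk)
        return b"".join(parts) if pos == size else b""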
+#include "compression/types.h" #ifndef HWY_DISABLED_TARGETS -#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE) -#endif +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS #include #include #include #include // std::shuffle +#include #include #include "compression/distortion.h" -#include "compression/shared.h" #include "util/test_util.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" @@ -104,7 +104,7 @@ struct TestPlateaus { HWY_ASSERT(-0.5f <= in[i] && in[i] < 0.5f); } - std::random_device rd; + std::random_device rd; // NOLINT std::mt19937 rng(rd()); std::shuffle(in.get(), in.get() + kGroupSize, rng); @@ -151,7 +151,7 @@ struct TestRamp { HWY_ASSERT(-0.45f <= in[i] && in[i] < 0.55f); } - std::random_device rd; + std::random_device rd; // NOLINT std::mt19937 rng(rd()); std::shuffle(in.get(), in.get() + kGroupSize, rng); @@ -246,7 +246,8 @@ struct TestOffset { auto in = hwy::AllocateAligned(total); // Enc() requires f32 auto dec1 = hwy::AllocateAligned(total); auto dec2 = hwy::AllocateAligned(kMidLen); - auto nuq = hwy::AllocateAligned(NuqStream::PackedEnd(total)); + auto nuq = hwy::AllocateAligned( + hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes())); HWY_ASSERT(in && dec1 && dec2 && nuq); const auto nuq_span = MakeSpan(nuq.get(), total); @@ -296,7 +297,8 @@ struct TestUnalignedOffset { auto in = hwy::AllocateAligned(total); // Enc() requires f32 auto dec1 = hwy::AllocateAligned(total); - auto nuq = hwy::AllocateAligned(NuqStream::PackedEnd(total)); + auto nuq = hwy::AllocateAligned( + hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes())); auto dec2 = hwy::AllocateAligned(num_decompressed); HWY_ASSERT(in && dec1 && dec2 && nuq); const auto nuq_span = MakeSpan(nuq.get(), total); @@ -347,7 +349,8 @@ struct TestDec2 { auto dec0 = hwy::AllocateAligned(total); auto dec1 = hwy::AllocateAligned(total); auto dec2 = hwy::AllocateAligned(kMidLen); - auto nuq = hwy::AllocateAligned(NuqStream::PackedEnd(total)); + auto nuq = hwy::AllocateAligned( + hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes())); HWY_ASSERT(in && dec0 && dec1 && dec2 && nuq); const auto nuq_span = MakeSpan(nuq.get(), total); @@ -449,7 +452,8 @@ struct TestEncDec { const size_t num = 4 * kGroupSize; auto in = hwy::AllocateAligned(num); // Enc() requires f32 auto out = hwy::AllocateAligned(num); // already padded - auto nuq = hwy::AllocateAligned(NuqStream::PackedEnd(num)); + auto nuq = hwy::AllocateAligned( + hwy::RoundUpTo(NuqStream::PackedEnd(num), hwy::VectorBytes())); HWY_ASSERT(in && out && nuq); const auto nuq_span = MakeSpan(nuq.get(), num); @@ -512,6 +516,7 @@ HWY_AFTER_NAMESPACE(); #if HWY_ONCE namespace gcpp { HWY_BEFORE_TEST(NuqTest); +#if GEMMA_ENABLE_NUQ HWY_EXPORT_AND_TEST_P(NuqTest, TestAllFlat); HWY_EXPORT_AND_TEST_P(NuqTest, TestAllPlateaus); HWY_EXPORT_AND_TEST_P(NuqTest, TestAllRamp); @@ -525,6 +530,9 @@ HWY_EXPORT_AND_TEST_P(NuqTest, TestUnalignedOffsetF32); HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNibble); HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecBF16); HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecF32); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(NuqTest); +#endif // GEMMA_ENABLE_NUQ HWY_AFTER_TEST(); } // namespace gcpp #endif // HWY_ONCE diff --git a/compression/python/BUILD.bazel b/compression/python/BUILD.bazel index 8bfb391..474511c 100644 --- a/compression/python/BUILD.bazel +++ b/compression/python/BUILD.bazel @@ -14,11 +14,16 @@ cc_library( hdrs = ["compression_clif_aux.h"], visibility = ["//visibility:private"], deps 
= [ - "@abseil-cpp//absl/types:span", - "//:common", + "//:basics", + "//:configs", + "//:mat", + "//:model_store", + "//:tensor_info", + "//:threading_context", "//:tokenizer", "//compression:compress", - "//compression:io", + "//io", + "//io:blob_store", "@highway//:hwy", "@highway//:thread_pool", ], @@ -29,9 +34,9 @@ pybind_extension( srcs = ["compression_extension.cc"], deps = [ ":compression_clif_aux", - "@abseil-cpp//absl/types:span", - "//:common", - "//compression:sfp", + "//:mat", + "//:tensor_info", + "//compression:types", ], ) diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc index 2705756..2de1b67 100644 --- a/compression/python/compression_clif_aux.cc +++ b/compression/python/compression_clif_aux.cc @@ -15,14 +15,29 @@ #include "compression/python/compression_clif_aux.h" -#include -#include +#include "compression/types.h" // GEMMA_DISABLED_TARGETS +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + +#include +#include +#include + #include #include -#include "compression/compress.h" -#include "compression/shared.h" -#include "hwy/aligned_allocator.h" +#include "compression/compress.h" // ScaleWeights +#include "gemma/configs.h" // ModelConfig +#include "gemma/model_store.h" // ModelStore +#include "gemma/tensor_info.h" // TensorInfo +#include "gemma/tokenizer.h" +#include "io/blob_store.h" // BlobWriter +#include "io/io.h" // Path +#include "util/basics.h" +#include "util/mat.h" +#include "util/threading_context.h" +#include "hwy/contrib/thread_pool/thread_pool.h" #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE \ @@ -32,157 +47,97 @@ // After highway.h #include "compression/compress-inl.h" -// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last -// compile pass, whereas we want this defined in the first. -#ifndef GEMMA_ONCE -#define GEMMA_ONCE - -#include "absl/types/span.h" -#include "compression/io.h" -#include "gemma/configs.h" -#include "gemma/tensor_index.h" -#include "gemma/tokenizer.h" -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -namespace gcpp { - -class WriterInterface { - public: - virtual ~WriterInterface() = default; - - virtual void Insert(std::string name, absl::Span weights, - Type type, const TensorInfo& tensor_info, - float scale) = 0; - virtual void InsertSfp(std::string name, absl::Span weights) = 0; - virtual void InsertNUQ(std::string name, absl::Span weights) = 0; - virtual void InsertBfloat16(std::string name, - absl::Span weights) = 0; - virtual void InsertFloat(std::string name, - absl::Span weights) = 0; - virtual void AddScales(const std::vector& scales) = 0; - virtual void AddTokenizer(const std::string& tokenizer_path) = 0; - - virtual size_t DebugNumBlobsAdded() const = 0; - - virtual int WriteWithConfig(std::string path, const ModelConfig* config) = 0; -}; - -} // namespace gcpp - -#endif // GEMMA_ONCE - // SIMD code, compiled once per target. HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -class SbsWriterImpl : public WriterInterface { +// Implementation for the currently compiled SIMD target. 
+class SbsWriterImpl : public ISbsWriter { template - void AllocateAndCompress(const std::string& name, - absl::Span weights) { - MatPtrT storage(name, 1, weights.size()); - model_memory_.push_back(storage); - model_memory_.back().Allocate(); - storage.SetPtr(model_memory_.back()); - std::string decorated_name = storage.CacheName(); - compressor_(&storage, decorated_name.c_str(), weights.data()); - } - template - void AllocateWithShape(const std::string& name, - absl::Span weights, - const TensorInfo& tensor_info, float scale) { - MatPtrT storage(name, &tensor_info); - storage.set_scale(scale); + void InsertT(const char* name, F32Span weights, + const TensorInfo& tensor_info) { + // TODO(janwas): 1D parallel-for. + hwy::ThreadPool& pool = ctx_.pools.Pool(); - // Don't reset num_elements for NUQ. - if (!hwy::IsSame, NuqStream>()) { - storage.SetNumElements(CompressedArrayElements(weights.size())); + MatPtrT mat(name, ExtentsFromInfo(&tensor_info)); + // SFP and NUQ (which uses SFP for cluster centers) have a limited range + // and depending on the input values may require rescaling. Scaling is + // cheap for matmul and probably not an issue for other ops, but it might be + // beneficial for precision to keep the original data range for other types. + if (mat.GetType() == Type::kSFP || mat.GetType() == Type::kNUQ) { + mat.SetScale(ScaleWeights(weights.data(), weights.size())); } - model_memory_.push_back(storage); - if (mode_ == CompressorMode::kTEST_ONLY) return; - model_memory_.back().Allocate(); - storage.SetPtr(model_memory_.back()); - std::string decorated_name = storage.CacheName(); - compressor_(&storage, decorated_name.c_str(), weights.data()); + if (weights.size() == 0) { + HWY_WARN("Ignoring zero-sized tensor %s.", name); + return; + } + + mat.AppendTo(serialized_mat_ptrs_); + MatOwner mat_owner; + mat_owner.AllocateFor(mat, ctx_.allocator, MatPadding::kPacked); + + // Handle gemma_export_test's MockArray. Write blobs so that the test + // succeeds, but we only have 10 floats, not the full tensor. 
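The scale set above follows the `ScaleWeights` contract (per the include comment above it now lives in compression/compress.h; the scalar reference version is deleted from shared.h later in this diff): if any magnitude exceeds `SfpStream::kMax` (1.875, as implied by the Python test below), all values are divided down in place so the maximum fits, and the divisor is returned and stored as the tensor's scale. A numpy sketch of that contract, reusing the test's input values:

    import numpy as np

    SFP_KMAX = 1.875  # SfpStream::kMax

    def scale_weights(raw: np.ndarray) -> float:
        # Returns 1.0 if everything already fits the SFP range; otherwise
        # rescales in place and returns the multiplier that restores the
        # original magnitudes.
        maxabs = float(np.abs(raw).max())
        if maxabs <= SFP_KMAX:
            return 1.0
        scale = maxabs / SFP_KMAX
        raw *= np.float32(1.0 / scale)
        np.clip(raw, -SFP_KMAX, SFP_KMAX, out=raw)  # clamp residual excess
        return scale

    w = np.array([3.0012] * 128 + [4.001] * 64, dtype=np.float32)
    assert abs(scale_weights(w) - 4.001 / 1.875) < 1e-5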
+ if (weights.size() == 10 && mat.Extents().Area() != 10) { + Compress(weights.data(), weights.size(), working_set_, mat.Span(), + /*packed_ofs=*/0, pool); + writer_.Add(name, mat.Packed(), mat.ElementBytes() * 10); + return; + } + + fprintf(stderr, "Compressing %s (%zu x %zu = %zuM) to %s, please wait\n", + name, mat.Rows(), mat.Cols(), weights.size() / (1000 * 1000), + TypeName(TypeEnum())); + HWY_ASSERT(weights.size() == mat.Extents().Area()); + Compress(weights.data(), weights.size(), working_set_, mat.Span(), + /*packed_ofs=*/0, pool); + writer_.Add(name, mat.Packed(), mat.PackedBytes()); } public: - explicit SbsWriterImpl(CompressorMode mode) - : pool_(0), compressor_(pool_), mode_(mode) {} + SbsWriterImpl(const std::string& sbs_path) + : ctx_(ThreadingArgs()), + writer_(gcpp::Path(sbs_path), ctx_.pools.Pool()) {} - void Insert(std::string name, absl::Span weights, Type type, - const TensorInfo& tensor_info, float scale) override { + void Insert(const char* name, F32Span weights, Type type, + const TensorInfo& tensor_info) override { switch (type) { case Type::kSFP: - AllocateWithShape(name, weights, tensor_info, scale); + InsertT(name, weights, tensor_info); break; case Type::kNUQ: - AllocateWithShape(name, weights, tensor_info, scale); + InsertT(name, weights, tensor_info); break; case Type::kBF16: - AllocateWithShape(name, weights, tensor_info, scale); + InsertT(name, weights, tensor_info); break; case Type::kF32: - AllocateWithShape(name, weights, tensor_info, scale); + InsertT(name, weights, tensor_info); break; default: - HWY_ABORT("Unsupported type"); + HWY_ABORT("Unsupported destination (compressed) type %s", + TypeName(type)); } } - void InsertSfp(std::string name, absl::Span weights) override { - AllocateAndCompress(name, weights); + void Write(const ModelConfig& config, + const std::string& tokenizer_path) override { + const GemmaTokenizer tokenizer( + tokenizer_path.empty() ? kMockTokenizer + : ReadFileToString(Path(tokenizer_path))); + WriteSingleFile(config, tokenizer, serialized_mat_ptrs_, writer_); } - void InsertNUQ(std::string name, absl::Span weights) override { - AllocateAndCompress(name, weights); - } - - void InsertBfloat16(std::string name, - absl::Span weights) override { - AllocateAndCompress(name, weights); - } - - void InsertFloat(std::string name, absl::Span weights) override { - AllocateAndCompress(name, weights); - } - - void AddScales(const std::vector& scales) override { - HWY_ASSERT(scales_.empty()); - scales_ = scales; - compressor_.AddScales(scales_.data(), scales_.size()); - } - - void AddTokenizer(const std::string& tokenizer_path) override { - Path path(tokenizer_path); - GemmaTokenizer tokenizer(path); - std::string tokenizer_proto = tokenizer.Serialize(); - HWY_ASSERT(!tokenizer_proto.empty()); - compressor_.AddTokenizer(tokenizer_proto); - } - - // Returns the number of blobs added. 
- size_t DebugNumBlobsAdded() const { - if (mode_ == CompressorMode::kTEST_ONLY) return model_memory_.size(); - return compressor_.DebugNumBlobsAdded(); - } - - int WriteWithConfig(std::string path, const ModelConfig* config) override { - return compressor_.WriteAll(gcpp::Path(path), config); - } - - hwy::ThreadPool pool_; - Compressor compressor_; + ThreadingContext ctx_; CompressWorkingSet working_set_; - std::vector model_memory_; - std::vector scales_; - CompressorMode mode_; + BlobWriter writer_; + std::vector serialized_mat_ptrs_; }; -WriterInterface* NewSbsWriter(CompressorMode mode) { - return new SbsWriterImpl(mode); +ISbsWriter* NewSbsWriter(const std::string& sbs_path) { + return new SbsWriterImpl(sbs_path); } } // namespace HWY_NAMESPACE @@ -194,43 +149,11 @@ namespace gcpp { HWY_EXPORT(NewSbsWriter); -SbsWriter::SbsWriter(CompressorMode mode) - : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)(mode)) {} -SbsWriter::~SbsWriter() = default; +SbsWriter::SbsWriter(const std::string& path) + : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)(path)) {} -void SbsWriter::Insert(std::string name, absl::Span weights, - Type type, const TensorInfo& tensor_info, float scale) { - impl_->Insert(name, weights, type, tensor_info, scale); -} -void SbsWriter::InsertSfp(std::string name, absl::Span weights) { - impl_->InsertSfp(name, weights); -} -void SbsWriter::InsertNUQ(std::string name, absl::Span weights) { - impl_->InsertNUQ(name, weights); -} -void SbsWriter::InsertBfloat16(std::string name, - absl::Span weights) { - impl_->InsertBfloat16(name, weights); -} -void SbsWriter::InsertFloat(std::string name, absl::Span weights) { - impl_->InsertFloat(name, weights); -} - -void SbsWriter::AddScales(const std::vector& scales) { - impl_->AddScales(scales); -} - -void SbsWriter::AddTokenizer(const std::string& tokenizer_path) { - impl_->AddTokenizer(tokenizer_path); -} - -size_t SbsWriter::DebugNumBlobsAdded() const { - return impl_->DebugNumBlobsAdded(); -} - -int SbsWriter::WriteWithConfig(std::string path, const ModelConfig* config) { - return impl_->WriteWithConfig(path, config); -} +SbsReader::SbsReader(const std::string& path) + : reader_(Path(path)), model_(reader_) {} } // namespace gcpp #endif // HWY_ONCE diff --git a/compression/python/compression_clif_aux.h b/compression/python/compression_clif_aux.h index 4ea5b16..6979865 100644 --- a/compression/python/compression_clif_aux.h +++ b/compression/python/compression_clif_aux.h @@ -16,52 +16,67 @@ #ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_ #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_ -#include +#include + #include #include -#include -#include "absl/types/span.h" -#include "compression/shared.h" +#include "compression/types.h" // Type #include "gemma/configs.h" -#include "gemma/tensor_index.h" +#include "gemma/model_store.h" +#include "gemma/tensor_info.h" +#include "io/blob_store.h" +#include "util/mat.h" +#include "hwy/aligned_allocator.h" // Span namespace gcpp { -// How to process the data. -enum class CompressorMode { - // No compression, no write to file, just for testing. - kTEST_ONLY, - // Old-style compression, no table of contents. - kNO_TOC, - // New-style compression, with table of contents. - kWITH_TOC, +// Can be modified in place by ScaleWeights. +using F32Span = hwy::Span; + +// Interface because we compile one derived implementation per SIMD target, +// because Compress() uses SIMD. 
+class ISbsWriter { + public: + virtual ~ISbsWriter() = default; + + virtual void Insert(const char* name, F32Span weights, Type type, + const TensorInfo& tensor_info) = 0; + + virtual void Write(const ModelConfig& config, + const std::string& tokenizer_path) = 0; }; -class WriterInterface; - +// Non-virtual class used by pybind that calls the interface's virtual methods. +// This avoids having to register the derived types with pybind. class SbsWriter { public: - explicit SbsWriter(CompressorMode mode); - ~SbsWriter(); + explicit SbsWriter(const std::string& sbs_path); - void Insert(std::string name, absl::Span weights, Type type, - const TensorInfo& tensor_info, float scale); - void InsertSfp(std::string name, absl::Span weights); - void InsertNUQ(std::string name, absl::Span weights); - void InsertBfloat16(std::string name, absl::Span weights); - void InsertFloat(std::string name, absl::Span weights); - void AddScales(const std::vector& scales); - void AddTokenizer(const std::string& tokenizer_path); + void Insert(const char* name, F32Span weights, Type type, + const TensorInfo& tensor_info) { + impl_->Insert(name, weights, type, tensor_info); + } - size_t DebugNumBlobsAdded() const; - - int Write(std::string path) { return WriteWithConfig(path, nullptr); } - int WriteWithConfig(std::string path, const ModelConfig* config); + void Write(const ModelConfig& config, const std::string& tokenizer_path) { + impl_->Write(config, tokenizer_path); + } private: - // Isolates Highway-dispatched types and other internals from CLIF. - std::unique_ptr impl_; + std::unique_ptr impl_; +}; + +// Limited metadata-only reader for tests. +class SbsReader { + public: + SbsReader(const std::string& path); + + const ModelConfig& Config() const { return model_.Config(); } + const MatPtr* FindMat(const char* name) const { return model_.FindMat(name); } + + private: + gcpp::BlobReader reader_; + gcpp::ModelStore model_; }; } // namespace gcpp diff --git a/compression/python/compression_extension.cc b/compression/python/compression_extension.cc index c873a23..e3d1556 100644 --- a/compression/python/compression_extension.cc +++ b/compression/python/compression_extension.cc @@ -15,58 +15,54 @@ #include #include -#include -#include #include -#include "absl/types/span.h" #include "compression/python/compression_clif_aux.h" -#include "compression/shared.h" +#include "compression/types.h" // Type +#include "gemma/tensor_info.h" +#include "util/mat.h" -using gcpp::CompressorMode; +using gcpp::MatPtr; +using gcpp::SbsReader; using gcpp::SbsWriter; -namespace py = pybind11; +namespace pybind11 { -namespace { template -void wrap_span(SbsWriter& writer, std::string name, py::array_t data) { +static void CallWithF32Span(SbsWriter& writer, const char* name, + array_t data, gcpp::Type type, + const gcpp::TensorInfo& tensor_info) { if (data.ndim() != 1 || data.strides(0) != sizeof(float)) { - throw std::domain_error("Input array must be 1D and densely packed."); + HWY_ABORT("Input array must be 1D (not %d) and contiguous floats.", + static_cast(data.ndim())); } - std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size())); + std::invoke(Func, writer, name, + gcpp::F32Span(data.mutable_data(0), data.size()), type, + tensor_info); } -template -void wrap_span_typed(SbsWriter& writer, std::string name, - py::array_t data, gcpp::Type type, - gcpp::TensorInfo tensor_info, float scale) { - if (data.ndim() != 1 || data.strides(0) != sizeof(float)) { - throw std::domain_error("Input array must be 1D and densely packed."); 
- } - std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()), - type, tensor_info, scale); -} -} // namespace PYBIND11_MODULE(compression, m) { - py::enum_(m, "CompressorMode") - .value("TEST_ONLY", CompressorMode::kTEST_ONLY) - .value("NO_TOC", CompressorMode::kNO_TOC) - .value("WITH_TOC", CompressorMode::kWITH_TOC); + class_(m, "SbsWriter") + .def(init()) + .def("insert", CallWithF32Span<&SbsWriter::Insert>) + .def("write", &SbsWriter::Write, arg("config"), arg("tokenizer_path")); - py::class_(m, "SbsWriter") - .def(py::init()) - // NOTE: Individual compression backends may impose constraints on the - // array length, such as a minimum of (say) 32 elements. - .def("insert", wrap_span_typed<&SbsWriter::Insert>) - .def("insert_sfp", wrap_span<&SbsWriter::InsertSfp>) - .def("insert_nuq", wrap_span<&SbsWriter::InsertNUQ>) - .def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>) - .def("insert_float", wrap_span<&SbsWriter::InsertFloat>) - .def("add_scales", &SbsWriter::AddScales) - .def("add_tokenizer", &SbsWriter::AddTokenizer) - .def("debug_num_blobs_added", &SbsWriter::DebugNumBlobsAdded) - .def("write", &SbsWriter::Write) - .def("write_with_config", &SbsWriter::WriteWithConfig); + class_(m, "MatPtr") + // No init, only created within C++. + .def_property_readonly("rows", &MatPtr::Rows, "Number of rows") + .def_property_readonly("cols", &MatPtr::Cols, "Number of cols") + .def_property_readonly("type", &MatPtr::GetType, "Element type") + .def_property_readonly("scale", &MatPtr::Scale, "Scaling factor"); + + class_(m, "SbsReader") + .def(init()) + .def_property_readonly("config", &SbsReader::Config, + return_value_policy::reference_internal, + "ModelConfig") + .def("find_mat", &SbsReader::FindMat, + return_value_policy::reference_internal, + "Returns MatPtr for given name."); } + +} // namespace pybind11 diff --git a/compression/python/compression_test.py b/compression/python/compression_test.py index fdf00e3..e8244ed 100644 --- a/compression/python/compression_test.py +++ b/compression/python/compression_test.py @@ -25,46 +25,120 @@ from python import configs class CompressionTest(absltest.TestCase): def test_sbs_writer(self): + info_192 = configs.TensorInfo() + info_192.name = "ignored_192" + info_192.axes = [0] + info_192.shape = [192] + temp_file = self.create_tempfile("test.sbs") - tensor_info = configs.TensorInfo() - tensor_info.name = "foo" - tensor_info.axes = [0] - tensor_info.shape = [192] - - writer = compression.SbsWriter(compression.CompressorMode.NO_TOC) + writer = compression.SbsWriter(temp_file.full_path) writer.insert( - "foo", - np.array([0.0012] * 128 + [0.001] * 64, dtype=np.float32), + "tensor0", + # Large enough to require scaling. + np.array([3.0012] * 128 + [4.001] * 64, dtype=np.float32), configs.Type.kSFP, - tensor_info, - 1.0, + info_192, ) - tensor_info_nuq = configs.TensorInfo() - tensor_info_nuq.name = "fooNUQ" - tensor_info_nuq.axes = [0] - tensor_info_nuq.shape = [256] + # 2D tensor. + info_2d = configs.TensorInfo() + info_2d.name = "ignored_2d" + info_2d.axes = [0, 1] + info_2d.shape = [96, 192] writer.insert( - "fooNUQ", + "tensor_2d", + np.array([i / 1e3 for i in range(96 * 192)], dtype=np.float32), + configs.Type.kBF16, + info_2d, + ) + + # 3D collapsed into rows. + info_3d = configs.TensorInfo() + info_3d.name = "ignored_3d" + info_3d.axes = [0, 1, 2] + info_3d.shape = [10, 12, 192] + info_3d.cols_take_extra_dims = False + writer.insert( + "tensor_3d", + # Verification of scale below depends on the shape and multiplier here. 
+ np.array([i / 1e3 for i in range(10 * 12 * 192)], dtype=np.float32), + configs.Type.kSFP, + info_3d, + ) + + # Exercise all types supported by Compress. + info_256 = configs.TensorInfo() + info_256.name = "ignored_256" + info_256.axes = [0] + info_256.shape = [256] + writer.insert( + "tensor_sfp", np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32), - configs.Type.kNUQ, - tensor_info_nuq, - 1.0, + configs.Type.kSFP, + info_256, ) - writer.insert_sfp( - "bar", np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32) + writer.insert( + "tensor_bf", + np.array([0.000375] * 128 + [0.00007] * 128, dtype=np.float32), + configs.Type.kBF16, + info_256, ) - writer.insert_nuq( - "baz", np.array([0.000125] * 128 + [0.00008] * 128, dtype=np.float32) + writer.insert( + "tensor_f32", + np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32), + configs.Type.kF32, + info_256, ) - writer.insert_bf16( - "qux", np.array([0.000375] * 128 + [0.00007] * 128, dtype=np.float32) + + config = configs.ModelConfig( + configs.Model.GEMMA_TINY, + configs.Type.kSFP, + configs.PromptWrapping.GEMMA_IT, ) - writer.insert_float( - "quux", np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32) - ) - self.assertEqual(writer.debug_num_blobs_added(), 6) - self.assertEqual(writer.write(temp_file.full_path), 0) + tokenizer_path = "" # no tokenizer required for testing + writer.write(config, tokenizer_path) + + print("Ignore next two warnings; test does not enable model deduction.") + reader = compression.SbsReader(temp_file.full_path) + + self.assertEqual(reader.config.model, configs.Model.GEMMA_TINY) + self.assertEqual(reader.config.weight, configs.Type.kSFP) + + mat = reader.find_mat("tensor0") + self.assertEqual(mat.cols, 192) + self.assertEqual(mat.rows, 1) + self.assertEqual(mat.type, configs.Type.kSFP) + self.assertAlmostEqual(mat.scale, 4.001 / 1.875, places=5) + + mat = reader.find_mat("tensor_2d") + self.assertEqual(mat.cols, 192) + self.assertEqual(mat.rows, 96) + self.assertEqual(mat.type, configs.Type.kBF16) + self.assertAlmostEqual(mat.scale, 1.0) + + mat = reader.find_mat("tensor_3d") + self.assertEqual(mat.cols, 192) + self.assertEqual(mat.rows, 10 * 12) + self.assertEqual(mat.type, configs.Type.kSFP) + self.assertAlmostEqual(mat.scale, 192 * 120 / 1e3 / 1.875, places=2) + + mat = reader.find_mat("tensor_sfp") + self.assertEqual(mat.cols, 256) + self.assertEqual(mat.rows, 1) + self.assertEqual(mat.type, configs.Type.kSFP) + self.assertAlmostEqual(mat.scale, 1.0) + + mat = reader.find_mat("tensor_bf") + self.assertEqual(mat.cols, 256) + self.assertEqual(mat.rows, 1) + self.assertEqual(mat.type, configs.Type.kBF16) + self.assertAlmostEqual(mat.scale, 1.0) + + mat = reader.find_mat("tensor_f32") + self.assertEqual(mat.cols, 256) + self.assertEqual(mat.rows, 1) + self.assertEqual(mat.type, configs.Type.kF32) + self.assertAlmostEqual(mat.scale, 1.0) if __name__ == "__main__": diff --git a/compression/sfp-inl.h b/compression/sfp-inl.h index 1be84e9..dad6536 100644 --- a/compression/sfp-inl.h +++ b/compression/sfp-inl.h @@ -20,7 +20,7 @@ #include #include -#include "compression/shared.h" +#include "compression/types.h" #include "hwy/base.h" #endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_ diff --git a/compression/sfp_test.cc b/compression/sfp_test.cc index f79e600..8e49ceb 100644 --- a/compression/sfp_test.cc +++ b/compression/sfp_test.cc @@ -13,10 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
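Taken together, the reworked bindings reduce to an insert/write/read cycle. A minimal usage sketch mirroring the test above; `from python import configs` follows the test, while the compression module path, output path, and tensor name are assumptions for illustration:

    import os
    import tempfile

    import numpy as np
    from compression.python import compression  # module path assumed
    from python import configs

    info = configs.TensorInfo()
    info.name = "ignored"
    info.axes = [0]
    info.shape = [256]

    path = os.path.join(tempfile.gettempdir(), "example.sbs")
    writer = compression.SbsWriter(path)
    writer.insert("w", np.zeros(256, np.float32), configs.Type.kSFP, info)
    config = configs.ModelConfig(configs.Model.GEMMA_TINY, configs.Type.kSFP,
                                 configs.PromptWrapping.GEMMA_IT)
    writer.write(config, "")  # empty tokenizer path selects the mock tokenizer

    reader = compression.SbsReader(path)
    mat = reader.find_mat("w")
    assert (mat.rows, mat.cols) == (1, 256)
    assert mat.type == configs.Type.kSFP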
-// We use ConcatEven/Odd which are not supported. Use HWY_EMU128 instead. +#include "compression/types.h" #ifndef HWY_DISABLED_TARGETS -#define HWY_DISABLED_TARGETS HWY_SCALAR -#endif +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS #include #include @@ -25,7 +25,6 @@ #include #include "compression/distortion.h" -#include "compression/shared.h" #include "util/test_util.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" diff --git a/compression/test_util-inl.h b/compression/test_util-inl.h index 860644a..7c4f854 100644 --- a/compression/test_util-inl.h +++ b/compression/test_util-inl.h @@ -18,10 +18,13 @@ #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_ // IWYU pragma: begin_exports -#include "compression/compress.h" #include "compression/distortion.h" +#include "util/mat.h" // IWYU pragma: end_exports +#include "compression/compress.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + #endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_ // Include guard for (potentially) SIMD code. @@ -59,7 +62,63 @@ void ForeachPackedAndRawType() { ForeachRawType(); ForeachRawType(); ForeachRawType(); - ForeachRawType(); + if constexpr (GEMMA_ENABLE_NUQ) { + ForeachRawType(); + } +} + +// Generates inputs: deterministic, within max SfpStream range. +template +MatStorageT GenerateMat(const Extents2D& extents, + const Allocator& allocator, MatPadding padding, + hwy::ThreadPool& pool) { + gcpp::CompressWorkingSet ws; + ws.tls.resize(pool.NumWorkers()); + MatStorageT raw("raw", extents, allocator, MatPadding::kPacked); + MatStorageT compressed("mat", extents, allocator, padding); + const float scale = SfpStream::kMax / extents.Area(); + pool.Run(0, extents.rows, [&](const size_t r, size_t thread) { + float* HWY_RESTRICT row = raw.Row(r); + for (size_t c = 0; c < extents.cols; c++) { + float f = static_cast(r * extents.cols + c) * scale; + if ((r + c) & 1) f = -f; // Also generate some negative values. + row[c] = f; + } + Compress(raw.Row(r), raw.Cols(), ws.tls[thread], + MakeSpan(compressed.Row(r), compressed.Cols()), + /*packed_ofs=*/0); + }); + + compressed.SetScale(0.6f); // Arbitrary value, different from 1. + return compressed; +} + +// Same, but `extents` describes the transposed matrix. +template +MatStorageT GenerateTransposedMat(const Extents2D extents, + const Allocator& allocator, + MatPadding padding, + hwy::ThreadPool& pool) { + gcpp::CompressWorkingSet ws; + ws.tls.resize(pool.NumWorkers()); + MatStorageT raw("raw", extents, allocator, MatPadding::kPacked); + MatStorageT compressed("trans", extents, allocator, padding); + const float scale = SfpStream::kMax / extents.Area(); + pool.Run(0, extents.rows, [&](const size_t r, size_t thread) { + float* HWY_RESTRICT row = raw.Row(r); + for (size_t c = 0; c < extents.cols; c++) { + float f = static_cast(c * extents.rows + r) * scale; + if ((r + c) & 1) f = -f; // Also generate some negative values. + row[c] = f; + } + Compress(raw.Row(r), raw.Cols(), ws.tls[thread], + MakeSpan(compressed.Row(r), compressed.Cols()), + /*packed_ofs=*/0); + }); + + // Arbitrary value, different from 1, must match `GenerateMat`. 
+  compressed.SetScale(0.6f);
+  return compressed;
 }
 
 // NOLINTNEXTLINE(google-readability-namespace-comments)
diff --git a/compression/shared.h b/compression/types.h
similarity index 76%
rename from compression/shared.h
rename to compression/types.h
index a5c87ae..dc10676 100644
--- a/compression/shared.h
+++ b/compression/types.h
@@ -13,18 +13,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// Definitions shared between the public compress-inl.h interface and the
-// sfp-inl.h and nuq-inl.h implementation details.
+// Types shared between tensor definitions and `compress-inl.h`.
 
-#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
-#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
+#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_TYPES_H_
+#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TYPES_H_
 
 #include <stddef.h>
 #include <stdint.h>
 
-#include <complex>
-#include <string>
-
 // IWYU pragma: begin_exports
 #include "util/basics.h"  // BF16
 #include "hwy/aligned_allocator.h"
@@ -33,6 +29,35 @@
 
 namespace gcpp {
 
+// EMU128 must not be disabled because we disable SCALAR.
+#define HWY_BROKEN_EMU128 0
+
+// Allow user override of disabled targets.
+#ifndef GEMMA_DISABLED_TARGETS
+
+// All platforms: exclude SCALAR because we use ReorderWidenMulAccumulate.
+
+#if HWY_ARCH_ARM_V7
+// No NEON because we require double-precision support.
+#define GEMMA_DISABLED_TARGETS (HWY_SCALAR | HWY_ALL_NEON)
+#elif HWY_ARCH_ARM_A64
+// We do not yet use AES (e.g. for random generation), hence NEON is the same
+// as NEON_WITHOUT_AES. Also skip SVE because SVE2_128 and SVE_256 cover most.
+#define GEMMA_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON | HWY_SVE)
+#elif HWY_ARCH_X86
+// Skip anything older than Haswell (2013); also use Zen4 for recent CPUs,
+// because we do not use anything added by SPR (e.g. FP16) nor AVX 10.2.
+#define GEMMA_DISABLED_TARGETS \
+  (HWY_SCALAR | HWY_SSE2 | HWY_SSSE3 | HWY_SSE4 | HWY_AVX3_SPR | HWY_AVX10_2)
+#endif  // HWY_ARCH_*
+
+#endif  // GEMMA_DISABLED_TARGETS
+
+// Only used in experiments, hence disable in default builds.
+#ifndef GEMMA_ENABLE_NUQ
+#define GEMMA_ENABLE_NUQ 0
+#endif
+
 // Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
 // inputs that combines the advantages of e4m3 and e5m2 into a single format.
 // It supports seeking at a granularity of 1 and decoding to bf16/f32.
@@ -63,30 +88,6 @@ struct SfpStream {
 };
 #pragma pack(pop)
 
-// Returns 1.0f if all magnitudes are <= SfpStream::kMax, otherwise scales them
-// such that the largest magnitude is SfpStream::kMax, and returns the
-// multiplier with which to restore the original values. This is only necessary
-// before compressing to SfpStream.
-// TODO: vectorize
-static inline float ScaleWeights(float* HWY_RESTRICT raw, size_t num) {
-  float maxabs = 0.0;
-  for (size_t i = 0; i < num; ++i) {
-    maxabs = HWY_MAX(maxabs, hwy::ScalarAbs(raw[i]));
-  }
-  if (maxabs <= SfpStream::kMax) {
-    return 1.0f;
-  }
-  const float scale = maxabs / SfpStream::kMax;
-  const float inv_scale = static_cast<float>(1.0 / static_cast<double>(scale));
-  for (size_t i = 0; i < num; ++i) {
-    // Clamp because kMax may still be exceeded.
-    const float magn =
-        HWY_MIN(SfpStream::kMax, hwy::ScalarAbs(raw[i] * inv_scale));
-    raw[i] = hwy::ScalarCopySign(magn, raw[i]);
-  }
-  return scale;
-}
-
 // Non-uniform quantization: a compressed representation of f32 inputs that
 // supports seeking at a granularity of 1 (for `DecompressAndZeroPad`) or
 // two vectors (for `Decompress2`), and decoding to bf16/f32.
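To put the remaining Type choices in perspective: per-value storage cost follows the `kTypeBits` table in the next hunk (NUQ is listed there as 4 bits, with a comment that it is closer to 4.5 once its tables are counted). A back-of-the-envelope sketch, borrowing the 2b embedder shape (256000 x 2048) from the converter's validations earlier in this diff:

    # Bits per value, from the kTypeBits table below (NUQ approximated as 4).
    BITS = {"f32": 32, "bf16": 16, "sfp": 8, "nuq": 4}

    def tensor_gib(rows: int, cols: int, kind: str) -> float:
        return rows * cols * BITS[kind] / 8 / 2**30

    for kind in BITS:
        print(f"{kind:>4}: {tensor_gib(256000, 2048, kind):.2f} GiB")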
@@ -185,31 +186,25 @@ constexpr bool IsNuqStream() {
   return hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>();
 }
 
-// Instruction-tuned models require extra 'turn structure' tokens in prompts.
-enum class PromptWrapping {
-  GEMMA_IT,
-  GEMMA_PT,
-  GEMMA_VLM,
-  PALIGEMMA,
-  kSentinel  // must be last
+// Tensor types for loading weights.
+enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64 };
+// These are used in `ModelConfig.Specifier`, hence the strings will not
+// change, though new ones may be added.
+static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16",
+                                               "sfp", "nuq", "f64"};
+static constexpr size_t kNumTypes =
+    sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
+static constexpr size_t kTypeBits[] = {
+    0,
+    8 * sizeof(float),
+    8 * sizeof(BF16),
+    8 * sizeof(SfpStream),
+    4 /* NuqStream, actually 4.5 */,
+    8 * sizeof(double),
 };
-
-inline bool EnumValid(PromptWrapping type) {
-  return static_cast<int>(type) >= 0 &&
-         static_cast<int>(type) < static_cast<int>(PromptWrapping::kSentinel);
-}
-
-// Tensor types for loading weights. Note that not all types are supported as
-// weights for a model, but can be used for other purposes, such as types for
-// ModelWeightsPtrs. When adding a new type that is supported, also
-// update gemma.cc, weights.*, and add instantiations/new_one.cc.
-enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 };
-constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
-                                        "nuq", "f64", "c64", "u128"};
-
-inline bool EnumValid(Type type) {
-  return static_cast<int>(type) >= 0 &&
-         static_cast<int>(type) <= static_cast<int>(Type::kU128);
+static inline bool EnumValid(Type type) {
+  return static_cast<size_t>(type) < kNumTypes;
 }
 
 // Returns a Type enum for the type of the template parameter.
@@ -226,20 +221,22 @@ Type TypeEnum() {
     return Type::kNUQ;
   } else if constexpr (hwy::IsSame<Packed, double>()) {
     return Type::kF64;
-  } else if constexpr (hwy::IsSame<Packed, std::complex<double>>()) {
-    return Type::kC64;
-  } else if constexpr (hwy::IsSame<Packed, hwy::uint128_t>()) {
-    return Type::kU128;
   } else {
     HWY_DASSERT(false);
     return Type::kUnknown;
   }
 }
 
-// Returns a string name for the type of the template parameter.
+static inline size_t TypeBits(Type type) {
+  return kTypeBits[static_cast<size_t>(type)];
+}
+
+static inline const char* TypeName(Type type) {
+  return kTypeStrings[static_cast<size_t>(type)];
+}
 template <typename Packed>
 const char* TypeName() {
-  return kTypeStrings[static_cast<size_t>(TypeEnum<Packed>())];
+  return TypeName(TypeEnum<Packed>());
 }
 
 template <typename Packed>
@@ -248,7 +245,9 @@ constexpr bool IsCompressed() {
 }
 
 // Returns the number of `MatT` elements required to store `capacity` values,
-// which must not be zero.
+// which must not be zero. This is only intended to support the extra tables
+// required for NUQ. `capacity` includes any padding and is `rows * stride`.
+// Deprecated, replaced by fixup within `MatPtr`. Only used by tests.
template constexpr size_t CompressedArrayElements(size_t capacity) { if constexpr (hwy::IsSame, NuqStream>()) { @@ -304,4 +303,4 @@ HWY_INLINE PackedSpan MakeConst(PackedSpan packed) { } } // namespace gcpp -#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_ +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TYPES_H_ diff --git a/evals/benchmark.cc b/evals/benchmark.cc index 8682189..4dec9ee 100644 --- a/evals/benchmark.cc +++ b/evals/benchmark.cc @@ -6,14 +6,12 @@ #include #include #include -#include // std::pair #include -#include "compression/io.h" // Path #include "evals/benchmark_helper.h" #include "evals/cross_entropy.h" -#include "gemma/common.h" #include "gemma/gemma.h" +#include "io/io.h" // Path #include "util/args.h" #include "hwy/base.h" #include "hwy/timer.h" @@ -27,7 +25,6 @@ class BenchmarkArgs : public ArgsBase { public: BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } - Path goldens; Path summarize_text; Path cross_entropy; Path trivia_qa; @@ -36,8 +33,6 @@ class BenchmarkArgs : public ArgsBase { template void ForEach(const Visitor& visitor) { - visitor(goldens.path, "goldens_dir", std::string(""), - "Directory containing golden files", 2); visitor(summarize_text.path, "summarize_text", std::string(""), "Path to text file to summarize", 2); visitor(cross_entropy.path, "cross_entropy", std::string(""), @@ -53,56 +48,6 @@ class BenchmarkArgs : public ArgsBase { } }; -std::vector> load_goldens( - const std::string& path) { - std::ifstream goldens_file(path); - if (!goldens_file) { - std::cout << "Could not load goldens file: " << path << "\n" << std::flush; - return {}; - } - std::vector> res; - std::string query_separator; - std::string query; - std::string answer_separator; - std::string answer; - while (std::getline(goldens_file, query_separator) && - std::getline(goldens_file, query) && - std::getline(goldens_file, answer_separator) && - std::getline(goldens_file, answer)) { - res.push_back({query, answer}); - } - return res; -} - -int BenchmarkGoldens(GemmaEnv& env, const std::string& golden_path) { - std::vector> queries_answers = - load_goldens(golden_path); - size_t correct_answers = 0; - size_t total_tokens = 0; - const double time_start = hwy::platform::Now(); - for (auto& [question, expected_answer] : queries_answers) { - QueryResult result = env.QueryModel(question); - total_tokens += result.tokens_generated; - if (result.response.find(expected_answer) != std::string::npos) { - correct_answers++; - } else { - std::cout << "Wrong!\n"; - std::cout << "Input: " << question << "\n"; - std::cout << "Expected: " << expected_answer << "\n"; - std::cout << "Output: " << result.response << "\n\n" << std::flush; - } - } - LogSpeedStats(time_start, total_tokens); - - std::cout << "Correct: " << correct_answers << " out of " - << queries_answers.size() << "\n" - << std::flush; - if (correct_answers != queries_answers.size()) { - return EXIT_FAILURE; - } - return EXIT_SUCCESS; -} - int BenchmarkSummary(GemmaEnv& env, const Path& text) { std::string prompt("Here is some text to summarize:\n"); prompt.append(ReadFileToString(text)); @@ -117,6 +62,7 @@ int BenchmarkSummary(GemmaEnv& env, const Path& text) { int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text, size_t batch_tokens) { + const Gemma& gemma = *env.GetGemma(); std::string input = ReadFileToString(text); std::vector prompt = env.Tokenize(input); std::cout << "Number of input tokens: " << prompt.size() << "\n"; @@ -128,10 +74,11 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text, size_t 
num_tokens = std::min(prompt.size() - pos, batch_tokens); std::vector prompt_slice(prompt.begin() + pos, prompt.begin() + pos + num_tokens); - KVCache kv_cache = KVCache::Create(env.GetModel()->GetModelConfig(), - env.MutableConfig().prefill_tbatch_size); - float entropy = ComputeCrossEntropy( - *env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity()); + KVCache kv_cache(gemma.Config(), gemma.Inference(), + env.MutableEnv().ctx.allocator); + float entropy = + ComputeCrossEntropy(*env.GetGemma(), num_tokens, prompt_slice, kv_cache, + env.MutableEnv(), env.Verbosity()); total_entropy += entropy; LogSpeedStats(time_start, pos + num_tokens); std::string text_slice = env.StringFromTokens(prompt_slice); @@ -183,14 +130,7 @@ int main(int argc, char** argv) { gcpp::GemmaEnv env(argc, argv); gcpp::BenchmarkArgs benchmark_args(argc, argv); - if (!benchmark_args.goldens.Empty()) { - const std::string golden_path = - benchmark_args.goldens.path + "/" + - gcpp::ModelString(env.GetModel()->Info().model, - env.GetModel()->Info().wrapping) + - ".txt"; - return BenchmarkGoldens(env, golden_path); - } else if (!benchmark_args.summarize_text.Empty()) { + if (!benchmark_args.summarize_text.Empty()) { return BenchmarkSummary(env, benchmark_args.summarize_text); } else if (!benchmark_args.cross_entropy.Empty()) { return BenchmarkCrossEntropy(env, benchmark_args.cross_entropy, diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index 2daebdf..3b999b4 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -18,27 +18,21 @@ #include #include -#include #include -#include #include #include #include #include -// Placeholder for internal header, do not modify. -#include "compression/compress.h" // TypeName +#include "compression/types.h" // TypeName #include "evals/cross_entropy.h" -#include "gemma/common.h" // StringFromType #include "gemma/gemma.h" -#include "gemma/kv_cache.h" -#include "util/app.h" -#include "util/args.h" -#include "util/threading.h" -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/topology.h" +#include "gemma/gemma_args.h" +#include "ops/matmul.h" // MatMulEnv +#include "util/threading_context.h" #include "hwy/highway.h" -#include "hwy/per_target.h" // VectorBytes +#include "hwy/per_target.h" // DispatchedTarget +#include "hwy/profiler.h" // PROFILER_ENABLED #include "hwy/timer.h" namespace gcpp { @@ -49,53 +43,37 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) { gen.seed(0x12345678); } else { // Depending on the library implementation, this may still be deterministic. - std::random_device rd; + std::random_device rd; // NOLINT gen.seed(rd()); } } -GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference, - const AppArgs& app) - : topology_(CreateTopology(app)), - pools_(CreatePools(topology_, app)), - env_(topology_, pools_) { - InferenceArgs mutable_inference = inference; - AbortIfInvalidArgs(mutable_inference); - LoaderArgs mutable_loader = loader; - if (const char* err = mutable_loader.Validate()) { - mutable_loader.Help(); - fprintf(stderr, "Skipping model load because: %s\n", err); - } else { - fprintf(stderr, "Loading model...\n"); - model_ = AllocateGemma(mutable_loader, env_); - // Only allocate one for starters because GenerateBatch might not be called. 
- kv_caches_.resize(1); - kv_caches_[0] = KVCache::Create(model_->GetModelConfig(), - inference.prefill_tbatch_size); +GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading, + const InferenceArgs& inference) + : ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) { + const ModelConfig& config = gemma_.Config(); + // Only allocate one for starters because GenerateBatch might not be called. + kv_caches_.push_back(KVCache(config, inference, ctx_.allocator)); + + if (inference.verbosity >= 2) { + ShowConfig(loader, threading, inference, config, gemma_.WeightReadMode(), + ctx_); } + InitGenerator(inference, gen_); + runtime_config_ = { .max_generated_tokens = inference.max_generated_tokens, .temperature = inference.temperature, .gen = &gen_, - .verbosity = app.verbosity, + .verbosity = inference.verbosity, }; -} - -// Internal init must run before the GemmaEnv ctor above, hence it cannot occur -// in the argv ctor below because its body runs *after* the delegating ctor. -// This helper function takes care of the init, and could be applied to any of -// the *Args classes, it does not matter which. -static AppArgs MakeAppArgs(int argc, char** argv) { - { // So that indentation matches expectations. - // Placeholder for internal init, do not modify. - } - return AppArgs(argc, argv); + inference.CopyTo(runtime_config_); } GemmaEnv::GemmaEnv(int argc, char** argv) - : GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv), - MakeAppArgs(argc, argv)) {} + : GemmaEnv(LoaderArgs(argc, argv), ThreadingArgs(argc, argv), + InferenceArgs(argc, argv)) {} QueryResult GemmaEnv::QueryModel(const std::vector& tokens) { QueryResult result; @@ -117,8 +95,8 @@ QueryResult GemmaEnv::QueryModel(const std::vector& tokens) { } gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity }; runtime_config_.batch_stream_token = batch_stream_token; - model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], - timing_info); + gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_, + timing_info); return result; } @@ -127,23 +105,25 @@ void GemmaEnv::QueryModel( gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity }; const StreamFunc previous_stream_token = runtime_config_.stream_token; runtime_config_.stream_token = stream_token; - model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], - timing_info); + gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_, + timing_info); runtime_config_.stream_token = previous_stream_token; } std::vector GemmaEnv::BatchQueryModel( - const QueriesPromptTokens& queries_prompt) { + const QueriesPromptTokens& queries_prompt, + const hwy::Span& prefix_end) { const size_t num_queries = queries_prompt.size(); HWY_ASSERT(num_queries != 0); std::vector res(num_queries); - const BatchStreamFunc batch_stream_token = [&res, &queries_prompt, this]( - size_t query_index, size_t pos, - int token, float) { + const BatchStreamFunc batch_stream_token = [&, this](const size_t query_index, + const size_t pos, + const int token, float) { + HWY_ASSERT(query_index < num_queries); std::string token_text; - HWY_ASSERT( - model_->Tokenizer().Decode(std::vector{token}, &token_text)); + HWY_ASSERT(gemma_.Tokenizer().Decode(std::vector{token}, &token_text)); res[query_index].response.append(token_text); + HWY_ASSERT(pos == res[query_index].tokens_generated); res[query_index].tokens_generated += 1; if (res[query_index].tokens_generated == queries_prompt[query_index].size()) { 
@@ -151,6 +131,7 @@ std::vector GemmaEnv::BatchQueryModel( } return true; }; + runtime_config_.batch_stream_token = batch_stream_token; if (runtime_config_.verbosity >= 2) { fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n", runtime_config_.max_generated_tokens, runtime_config_.temperature, @@ -158,23 +139,16 @@ std::vector GemmaEnv::BatchQueryModel( runtime_config_.decode_qbatch_size); } - // Ensure we have one KVCache per query. - if (kv_caches_.size() < num_queries) { - kv_caches_.resize(num_queries); - } - for (size_t i = 1; i < num_queries; ++i) { - if (kv_caches_[i].seq_len == 0) { - kv_caches_[i] = KVCache::Create(model_->GetModelConfig(), - runtime_config_.prefill_tbatch_size); - } + // Ensure we have at least one KVCache per query. + while (kv_caches_.size() < num_queries) { + kv_caches_.push_back( + KVCache(gemma_.Config(), gemma_.Inference(), ctx_.allocator)); } + const hwy::Span kv_caches(&kv_caches_[0], num_queries); + gcpp::AllQueries all_queries(queries_prompt, kv_caches, prefix_end); gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity}; - runtime_config_.batch_stream_token = batch_stream_token; - std::vector queries_pos(num_queries, 0); - model_->GenerateBatch(runtime_config_, queries_prompt, - QueriesPos(queries_pos.data(), num_queries), - KVCaches(&kv_caches_[0], num_queries), timing_info); + gemma_.GenerateBatch(runtime_config_, all_queries, env_, timing_info); return res; } @@ -203,8 +177,8 @@ std::vector GemmaEnv::BatchQueryModel( float GemmaEnv::CrossEntropy(const std::string& input) { std::vector prompt = Tokenize(input); prompt.insert(prompt.begin(), BOS_ID); - return ComputeCrossEntropy(*GetModel(), /*max_generated_tokens=*/3072, prompt, - MutableKVCache(), + return ComputeCrossEntropy(*GetGemma(), /*max_generated_tokens=*/3072, prompt, + MutableKVCache(), env_, /*verbosity=*/0) / static_cast(input.size()); } @@ -236,13 +210,37 @@ std::string CacheString() { return buf; } -void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app, - const BoundedTopology& topology, NestedPools& pools) { - loader.Print(app.verbosity); - inference.Print(app.verbosity); - app.Print(app.verbosity); +static constexpr const char* CompiledConfig() { + if constexpr (HWY_IS_ASAN) { + return "asan"; + } else if constexpr (HWY_IS_MSAN) { + return "msan"; + } else if constexpr (HWY_IS_TSAN) { + return "tsan"; + } else if constexpr (HWY_IS_HWASAN) { + return "hwasan"; + } else if constexpr (HWY_IS_UBSAN) { + return "ubsan"; + } else if constexpr (HWY_IS_DEBUG_BUILD) { + return "dbg"; + } else { + return "opt"; + } +} - if (app.verbosity >= 2) { +void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading, + const InferenceArgs& inference, const ModelConfig& config, + const WeightsPtrs::Mode weight_read_mode, + const ThreadingContext& ctx) { + threading.Print(inference.verbosity); + loader.Print(inference.verbosity); + inference.Print(inference.verbosity); + fprintf( + stderr, "Model : %s, to_bf16 %d, mmap %d => %s\n", + config.Specifier().c_str(), static_cast(loader.to_bf16), + static_cast(loader.map), WeightsPtrs::ToString(weight_read_mode)); + + if (inference.verbosity >= 2) { time_t now = time(nullptr); char* dt = ctime(&now); // NOLINT char cpu100[100] = "unknown"; @@ -250,38 +248,34 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app, fprintf(stderr, "Date & Time : %s" // dt includes \n - "CPU : %s\n" + "CPU : %s, bind %d\n" "CPU topology : %s, %s, %s\n" "Instruction set : %s (%zu bits)\n" - 
"Compiled config : %s\n" - "Weight Type : %s\n" - "EmbedderInput Type : %s\n", - dt, cpu100, topology.TopologyString(), pools.PinString(), + "Compiled config : %s, profiler %d\n" + "Memory MiB : %4zu\n", + dt, cpu100, static_cast(threading.bind), + ctx.topology.TopologyString(), ctx.pools.PinString(), CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()), - hwy::VectorBytes() * 8, CompiledConfig(), - StringFromType(loader.Info().weight), TypeName()); + ctx.allocator.VectorBytes() * 8, CompiledConfig(), PROFILER_ENABLED, + ctx.allocator.TotalMiB()); } } -void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { +void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading, + const InferenceArgs& inference) { std::cerr << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n" "==========================================================\n\n" - "To run gemma.cpp, you need to " - "specify 3 required model loading arguments:\n" - " --tokenizer\n" - " --weights\n" - " --model,\n" - " or with the newer weights format, specify just:\n" - " --weights\n"; + "To run with pre-2025 weights, specify --tokenizer and --weights.\n" + "With the single-file weights format, specify just --weights.\n"; std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm " - "--weights 2b-it-sfp.sbs --model 2b-it\n"; + "--weights gemma2-2b-it-sfp.sbs\n"; std::cerr << "\n*Model Loading Arguments*\n\n"; loader.Help(); + std::cerr << "\n*Threading Arguments*\n\n"; + threading.Help(); std::cerr << "\n*Inference Arguments*\n\n"; inference.Help(); - std::cerr << "\n*Application Arguments*\n\n"; - app.Help(); std::cerr << "\n"; } diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h index f6e32c0..8f4d96f 100644 --- a/evals/benchmark_helper.h +++ b/evals/benchmark_helper.h @@ -18,15 +18,16 @@ #include -#include #include #include #include +#include "gemma/configs.h" #include "gemma/gemma.h" +#include "gemma/gemma_args.h" +#include "gemma/tokenizer.h" // WrapAndTokenize #include "ops/matmul.h" -#include "util/app.h" -#include "util/threading.h" +#include "util/threading_context.h" #include "hwy/base.h" namespace gcpp { @@ -46,19 +47,21 @@ class GemmaEnv { public: // Calls the other constructor with *Args arguments initialized from argv. 
GemmaEnv(int argc, char** argv); - GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference, - const AppArgs& app); + GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading, + const InferenceArgs& inference); + MatMulEnv& Env() { return env_; } size_t MaxGeneratedTokens() const { return runtime_config_.max_generated_tokens; } - void SetMaxGeneratedTokens(size_t max_generated_tokens) { - runtime_config_.max_generated_tokens = max_generated_tokens; + void SetMaxGeneratedTokens(int max_generated_tokens) { + runtime_config_.max_generated_tokens = + static_cast(max_generated_tokens); } std::vector Tokenize(const std::string& input) const { std::vector tokens; - HWY_ASSERT(model_->Tokenizer().Encode(input, &tokens)); + HWY_ASSERT(gemma_.Tokenizer().Encode(input, &tokens)); return tokens; } @@ -69,20 +72,23 @@ class GemmaEnv { } std::vector WrapAndTokenize(std::string& input) const { - return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->Info(), 0, input); + return gcpp::WrapAndTokenize(gemma_.Tokenizer(), gemma_.ChatTemplate(), + gemma_.Config().wrapping, 0, input); } std::string StringFromTokens(const std::vector& tokens) const { std::string string; - HWY_ASSERT(model_->Tokenizer().Decode(tokens, &string)); + HWY_ASSERT(gemma_.Tokenizer().Decode(tokens, &string)); return string; } // Runs inference on the given input and returns the top-1 result string and // the number of tokens that were generated. QueryResult QueryModel(const std::vector& tokens); + // The default prefix_end means "causal attention". std::vector BatchQueryModel( - const QueriesPromptTokens& queries_prompt); + const QueriesPromptTokens& queries_prompt, + const hwy::Span& prefix_end = hwy::Span()); // Adds turn structure to input, tokenizes and calls the above overload. QueryResult QueryModel(std::string& input); std::vector BatchQueryModel( @@ -97,20 +103,19 @@ class GemmaEnv { // number of bits per token. float CrossEntropy(const std::string& input); - // Returns nullptr if the model failed to load. - Gemma* GetModel() const { return model_.get(); } + const Gemma* GetGemma() const { return &gemma_; } int Verbosity() const { return runtime_config_.verbosity; } RuntimeConfig& MutableConfig() { return runtime_config_; } std::mt19937& MutableGen() { return gen_; } KVCache& MutableKVCache() { return kv_caches_[0]; } + MatMulEnv& MutableEnv() { return env_; } private: - BoundedTopology topology_; - NestedPools pools_; // Thread pool. + ThreadingContext ctx_; MatMulEnv env_; - std::mt19937 gen_; // Random number generator. - std::unique_ptr model_; + Gemma gemma_; + std::mt19937 gen_; // Random number generator. std::vector kv_caches_; // Same number as query batch. RuntimeConfig runtime_config_; }; @@ -118,9 +123,12 @@ class GemmaEnv { // Logs the inference speed in tokens/sec. 
@@ -118,9 +123,12 @@
 // Logs the inference speed in tokens/sec.
 void LogSpeedStats(double time_start, size_t total_tokens);

-void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
-                const BoundedTopology& topology, NestedPools& pools);
-void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
+void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
+                const InferenceArgs& inference, const ModelConfig& config,
+                WeightsPtrs::Mode weight_read_mode,
+                const ThreadingContext& ctx);
+void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
+              const InferenceArgs& inference);

 }  // namespace gcpp
diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc
index e4bf1b1..09c3a42 100644
--- a/evals/cross_entropy.cc
+++ b/evals/cross_entropy.cc
@@ -13,6 +13,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include "compression/types.h"  // GEMMA_DISABLED_TARGETS
+#ifndef HWY_DISABLED_TARGETS
+#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
+#endif  // HWY_DISABLED_TARGETS
+
 // Compiles this file for multiple architectures via "foreach_target.h", to
 // which we pass the filename via macro 'argument'.
 // clang-format off
@@ -38,17 +43,12 @@
 #include

 #include "evals/cross_entropy.h"
-#include "gemma/common.h"
 #include "gemma/gemma.h"
 #include "hwy/base.h"

 namespace gcpp {
 namespace {

-template <typename TConfig>
-struct GetVocabSize {
-  int operator()() const { return TConfig::kVocabSize; }
-};

 static std::string TokenString(const GemmaTokenizer& tokenizer, int token) {
   std::string token_str;
@@ -85,7 +85,7 @@ namespace gcpp {
 namespace HWY_NAMESPACE {

 void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size) {
-  Softmax(logits, vocab_size);
+  Softmax(logits, vocab_size, /*worker=*/0);
 }

 }  // namespace HWY_NAMESPACE
@@ -97,12 +97,12 @@ namespace gcpp {

 HWY_EXPORT(CallSoftmax);

-float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
+float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
-                          int verbosity) {
+                          MatMulEnv& env, int verbosity) {
   const StreamFunc stream_token = [](int, float) { return true; };

-  const int vocab_size = gemma.GetModelConfig().vocab_size;
+  const int vocab_size = gemma.Config().vocab_size;
   float cross_entropy = std::log(vocab_size);  // first token; == -log(1/v_s)
   size_t pos = 1;
@@ -145,7 +145,7 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
   };
   TimingInfo timing_info;
-  gemma.Generate(runtime, prompt0, 0, kv_cache, timing_info);
+  gemma.Generate(runtime, prompt0, 0, kv_cache, env, timing_info);

   const float scale = 1.0f / std::log(2.0f);
   return cross_entropy * scale;
diff --git a/evals/cross_entropy.h b/evals/cross_entropy.h
index fed224c..0a143cc 100644
--- a/evals/cross_entropy.h
+++ b/evals/cross_entropy.h
@@ -24,9 +24,9 @@

 namespace gcpp {

-float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
+float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
                           const std::vector<int>& prompt, KVCache& kv_cache,
-                          int verbosity);
+                          MatMulEnv& env, int verbosity);

 }  // namespace gcpp
diff --git a/evals/debug_prompt.cc b/evals/debug_prompt.cc
index 2d02b3a..66fa466 100644
--- a/evals/debug_prompt.cc
+++ b/evals/debug_prompt.cc
@@ -18,9 +18,9 @@
 #include
 #include

-#include "compression/io.h"
 #include "evals/benchmark_helper.h"
 #include "gemma/gemma.h"  // LayersOutputFunc
+#include "io/io.h"
 #include "util/args.h"
 #include "hwy/base.h"
 #include "nlohmann/json.hpp"
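A note on units in `ComputeCrossEntropy`: it accumulates `-log p(token)` in nats, seeds the first token with `log(vocab_size)` (the cost under a uniform prior, since nothing has been predicted yet), and the final `1.0f / std::log(2.0f)` factor converts nats to bits per token. A standalone sketch of just that unit bookkeeping, with an illustrative vocabulary size:

```cpp
// Unit bookkeeping behind ComputeCrossEntropy (values illustrative only).
#include <cmath>
#include <cstdio>

int main() {
  const int vocab_size = 256000;       // example, Gemma-sized vocabulary
  double nats = std::log(vocab_size);  // first token: -log(1/vocab_size)
  // ...ComputeCrossEntropy then adds -log p(token) for each later token...
  const double bits = nats * (1.0 / std::log(2.0));  // same scale as the code
  std::printf("first token costs %.2f nats = %.2f bits\n", nats, bits);
  return 0;
}
```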
diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc
index 44b803f..6d97c61 100644
--- a/evals/gemma_batch_bench.cc
+++ b/evals/gemma_batch_bench.cc
@@ -13,25 +13,18 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "gemma/gemma.h"
-
 #include
 #include
 #include

 #include "evals/benchmark_helper.h"
-#include "gemma/common.h"
+#include "gemma/gemma.h"
 #include "hwy/base.h"
+#include "hwy/nanobenchmark.h"
+#include "hwy/profiler.h"
 #include "hwy/tests/hwy_gtest.h"

-// This test can be run manually with the downloaded gemma weights.
-// To run the test, pass the following flags:
-// --model --tokenizer --weights
-// It should pass for the following models:
-// Gemma1: 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), gr2b-it,
-// Gemma2: gemma2-2b-it, 9b-it, 27b-it,
-
 namespace gcpp {
 namespace {
@@ -40,61 +33,23 @@ namespace {
 // non-local static variables with dtors.
 GemmaEnv* s_env = nullptr;

-class GemmaTest : public ::testing::Test {
+class GemmaBatchBench : public ::testing::Test {
  protected:
   std::vector<std::string> BatchGemmaReply(
       const std::vector<std::string>& inputs) {
-    s_env->SetMaxGeneratedTokens(64);
+    s_env->SetMaxGeneratedTokens(24);
     s_env->MutableConfig().temperature = 0.0f;  // deterministic
-    s_env->MutableConfig().verbosity = 5;
+    s_env->MutableConfig().verbosity = 2;
     std::vector<std::string> replies;
-    // Using the turn structure worsens results sometimes.
-    // However, some models need the turn structure to work.
-    // It would be good to make these tests more consistent.
-    if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
-        s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
-      for (QueryResult result : s_env->BatchQueryModel(inputs)) {
-        replies.push_back(result.response);
-      }
-      return replies;
-    }
-    // Otherwise, do not use turn structure.
-    std::vector<std::vector<int>> prompts_vector;
-    prompts_vector.reserve(inputs.size());
-    for (const auto& input_string : inputs) {
-      prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
-    }
-    std::vector<PromptTokens> prompt_spans;
-    for (const auto& prompt : prompts_vector) {
-      prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
-    }
-    QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
-    for (const QueryResult& result : s_env->BatchQueryModel(prompts)) {
+    for (const QueryResult& result : s_env->BatchQueryModel(inputs)) {
       replies.push_back(result.response);
     }
     return replies;
   }
-
-  void GenerateTokens(std::vector<std::string>& kQA, size_t num_questions) {
-    ASSERT_NE(s_env->GetModel(), nullptr);
-
-    std::vector<std::string> inputs;
-    for (size_t i = 0; i < num_questions; ++i) {
-      inputs.push_back(kQA[i]);
-    }
-    std::vector<std::string> responses = BatchGemmaReply(inputs);
-    for (size_t i = 0; i < num_questions; ++i) {
-      std::string response = responses.at(i);
-      fprintf(stderr, "Batch answer %zu '%s'\n\n", i + 1, response.c_str());
-    }
-  }
 };

-TEST_F(GemmaTest, RandomQuestionsBatched) {
-  s_env->MutableConfig().decode_qbatch_size = 3;
-  s_env->MutableConfig().verbosity = 5;
-
-  static std::vector<std::string> kQA = {
+TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
+  const std::vector<std::string> questions = {
       {"Write me a poem about Australia?"},
       {"What's the history of Denmark?"},
       {"Write me a comedy story about the USA."},
@@ -128,13 +83,27 @@ TEST_F(GemmaTest, RandomQuestionsBatched) {
       {"Tell me about space travel."},
       {"Explain to me how electric cars work."},
   };
-  static const size_t kNum = kQA.size();
-  GenerateTokens(kQA, kNum);
+
+  // Fills prompts round robin from `questions` until the desired batch size.
+  std::vector<std::string> inputs;
+  inputs.reserve(s_env->MutableConfig().decode_qbatch_size);
+  size_t qpos = 0;
+  for (size_t i = 0; i < inputs.capacity(); ++i) {
+    inputs.push_back(questions[qpos++]);
+    if (qpos == questions.size()) qpos = 0;
+  }
+  std::vector<std::string> responses = BatchGemmaReply(inputs);
+  for (size_t i = 0; i < hwy::Unpredictable1() * 3; ++i) {
+    fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
+  }
+
+  PROFILER_PRINT_RESULTS();
 }

 }  // namespace
 }  // namespace gcpp

 int main(int argc, char** argv) {
+  fprintf(stderr, "GemmaEnv setup..\n");
   gcpp::GemmaEnv env(argc, argv);
   gcpp::s_env = &env;
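One subtlety in the round-robin fill above: it iterates to `inputs.capacity()`, which the C++ standard only guarantees to be at least the value passed to `reserve()`, so the batch could in principle exceed `decode_qbatch_size`. A sketch (hypothetical helper name, not part of this change) that pins the count explicitly:

```cpp
// Round-robin batch fill with an explicit count (sketch; avoids relying on
// capacity() being exactly the reserved size, which is not guaranteed).
#include <string>
#include <vector>

std::vector<std::string> FillBatch(const std::vector<std::string>& questions,
                                   const size_t batch_size) {
  std::vector<std::string> inputs;
  inputs.reserve(batch_size);
  for (size_t i = 0, qpos = 0; i < batch_size; ++i) {
    inputs.push_back(questions[qpos]);
    if (++qpos == questions.size()) qpos = 0;  // wrap around
  }
  return inputs;
}
```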
diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc
index 7674c5e..12080f9 100644
--- a/evals/gemma_test.cc
+++ b/evals/gemma_test.cc
@@ -21,7 +21,8 @@
 #include

 #include "evals/benchmark_helper.h"
-#include "gemma/common.h"
+#include "gemma/configs.h"
+#include "io/io.h"
 #include "hwy/base.h"
 #include "hwy/tests/hwy_gtest.h"
@@ -36,132 +37,75 @@ namespace gcpp {
 namespace {

-// Shared state. Requires argc/argv, so construct in main and use the same raw
-// pointer approach as in benchmarks.cc. Note that the style guide forbids
-// non-local static variables with dtors.
-GemmaEnv* s_env = nullptr;
-
 class GemmaTest : public ::testing::Test {
- protected:
-  std::string GemmaReply(const std::string& prompt) {
-    s_env->SetMaxGeneratedTokens(2048);
-    s_env->MutableConfig().temperature = 0.0f;  // deterministic
-    s_env->MutableConfig().verbosity = 0;
-    // Using the turn structure worsens results sometimes.
-    // However, some models need the turn structure to work.
-    // It would be good to make these tests more consistent.
-    if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
-        s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
-      std::string mutable_prompt = prompt;
-      QueryResult result = s_env->QueryModel(mutable_prompt);  // Uses turns.
-      return result.response;
-    }
-    // Otherwise, do not use turn structure.
-    const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
-    QueryResult result = s_env->QueryModel(tokens);
-    return result.response;
+ public:
+  // Requires argc/argv, hence do not use `SetUpTestSuite`.
+  static void InitEnv(int argc, char** argv) {
+    HWY_ASSERT(s_env == nullptr);  // Should only be called once.
+    s_env = new GemmaEnv(argc, argv);
+    const gcpp::ModelConfig& config = s_env->GetGemma()->Config();
+    fprintf(stderr, "Using %s\n", config.Specifier().c_str());
   }

+  static void DeleteEnv() { delete s_env; }
+
+ protected:
   std::vector<std::string> BatchGemmaReply(
       const std::vector<std::string>& inputs) {
+    HWY_ASSERT(s_env);  // must have called InitEnv()
     s_env->SetMaxGeneratedTokens(64);
     s_env->MutableConfig().temperature = 0.0f;  // deterministic
     s_env->MutableConfig().verbosity = 0;
+    // Always use turn structure (WrapAndTokenize).
     std::vector<std::string> replies;
-    // Using the turn structure worsens results sometimes.
-    // However, some models need the turn structure to work.
-    // It would be good to make these tests more consistent.
-    if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
-        s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
-      for (QueryResult result : s_env->BatchQueryModel(inputs)) {
-        replies.push_back(result.response);
-      }
-      return replies;
-    }
-    // Otherwise, do not use turn structure.
-    std::vector<std::vector<int>> prompts_vector;
-    prompts_vector.reserve(inputs.size());
-    for (const auto& input_string : inputs) {
-      prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
-    }
-    std::vector<PromptTokens> prompt_spans;
-    for (const auto& prompt : prompts_vector) {
-      prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
-    }
-    QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
-    for (const QueryResult& result : s_env->BatchQueryModel(prompts)) {
+    for (QueryResult result : s_env->BatchQueryModel(inputs)) {
       replies.push_back(result.response);
     }
     return replies;
   }

-  void TestQuestions(const char* kQA[][2], size_t num_questions, bool batch) {
-    ASSERT_NE(s_env->GetModel(), nullptr);
-    if (batch) {
-      std::vector<std::string> inputs;
-      for (size_t i = 0; i < num_questions; ++i) {
-        fprintf(stderr, "Batch Question %zu\n\n", i + 1);
-        inputs.push_back(kQA[i][0]);
-      }
-      std::vector<std::string> responses = BatchGemmaReply(inputs);
-      for (size_t i = 0; i < num_questions; ++i) {
-        std::string response = responses.at(i);
-        fprintf(stderr, "Batch answer %zu '%s'\n\n", i + 1, response.c_str());
-        EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos);  // NOLINT
-      }
-    } else {
-      for (size_t i = 0; i < num_questions; ++i) {
-        fprintf(stderr, "Question %zu\n\n", i + 1);
-        std::string response = GemmaReply(kQA[i][0]);
-        fprintf(stderr, "'%s'\n\n", response.c_str());
-        EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos);  // NOLINT
-      }
-    }
-  }
+  // Shared state. Requires argc/argv, so construct in main via InitEnv.
+  // Note that the style guide forbids non-local static variables with dtors.
+  static GemmaEnv* s_env;
 };

-TEST_F(GemmaTest, GeographyBatched) {
-  s_env->MutableConfig().decode_qbatch_size = 3;
-  // 6 are enough to test batching and the loop.
+GemmaEnv* GemmaTest::s_env = nullptr;
+
+TEST_F(GemmaTest, Batched) {
+  // Test remainder handling in MatMul (four rows per tile), but avoid a
+  // second batch in debug builds to speed up the test.
+  s_env->MutableConfig().decode_qbatch_size = HWY_IS_DEBUG_BUILD ? 6 : 3;
  static const char* kQA[][2] = {
      {"What is the capital of Australia?", "Canberra"},
-      {"What is the capital of Denmark?", "Copenhagen"},
-      {"Ljubljana is the capital of which country?", "Slovenia"},
-      {"Is Chicago a country?", "city"},
      {"How many states does the US have?", "50"},
      {"What is the Pacific?", "ocean"},
-  };
-  static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
-  TestQuestions(kQA, HWY_MIN(kNum, 3), /*batch=*/false);
-  TestQuestions(kQA, 1, /*batch=*/true);
-  TestQuestions(kQA, kNum, /*batch=*/true);
-}
-
-TEST_F(GemmaTest, History) {
-  static const char* kQA[][2] = {
      {"When was the battle of Hastings?", "1066"},
-  };
-  static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
-  TestQuestions(kQA, kNum, /*batch=*/false);
-}
-
-TEST_F(GemmaTest, Arithmetic) {
-  static const char* kQA[][2] = {
      {"what is 13 + 14?", "27"},
      {"what is 7 * 8?", "56"},
  };
-  static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
-  TestQuestions(kQA, kNum, /*batch=*/false);
+  const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
+  std::vector<std::string> inputs;
+  for (size_t i = 0; i < kNum; ++i) {
+    inputs.push_back(kQA[i][0]);
+  }
+  std::vector<std::string> responses = BatchGemmaReply(inputs);
+  HWY_ASSERT(responses.size() == kNum);
+  for (size_t i = 0; i < kNum; ++i) {
+    fprintf(stderr, "#%zu: '%s'\n\n", i, responses[i].c_str());
+    EXPECT_TRUE(responses[i].find(kQA[i][1]) != std::string::npos);  // NOLINT
+  }
 }

 TEST_F(GemmaTest, Multiturn) {
-  Gemma* model = s_env->GetModel();
-  ASSERT_NE(model, nullptr);
+  const Gemma* model = s_env->GetGemma();
+  const ModelConfig& config = model->Config();
   size_t abs_pos = 0;
   std::string response;
-  auto stream_token = [&](int token, float) {
-    if (token == EOS_ID) return true;
+  auto stream_token = [&](size_t query_idx, size_t pos, int token, float) {
+    HWY_ASSERT(query_idx == 0);
+    HWY_ASSERT(pos == abs_pos);
     ++abs_pos;
+    if (config.IsEOS(token)) return true;
     std::string token_text;
     EXPECT_TRUE(
         model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
@@ -173,83 +117,44 @@ TEST_F(GemmaTest, Multiturn) {
       .temperature = 0.0f,
       .gen = &s_env->MutableGen(),
       .verbosity = 2,
-      .stream_token = stream_token,
+      .batch_stream_token = stream_token,
   };
   TimingInfo timing_info{.verbosity = 0};
   // First "say" something slightly unusual.
   std::string mutable_prompt = "I have a car and its color is turquoise.";
-  std::vector<int> tokens = WrapAndTokenize(model->Tokenizer(), model->Info(),
-                                            abs_pos, mutable_prompt);
+  std::vector<int> tokens =
+      WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(),
+                      config.wrapping, abs_pos, mutable_prompt);
+
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
-                  timing_info);
+                  s_env->MutableEnv(), timing_info);

   // Note: we do not rewind any tokens here. If the model
   // produced one and WrapAndTokenize() inserts another one, it will just be
   // duplicated.
   mutable_prompt = "Please repeat all prior statements.";
-  tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
-                           mutable_prompt);
+  tokens = WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(),
+                           config.wrapping, abs_pos, mutable_prompt);
+
  // Reset the `response` string here, then check that the model actually has
  // access to the previous turn by asking to reproduce.
response.clear(); model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(), - timing_info); - fprintf(stderr, "decoded: %s\n", response.c_str()); + s_env->MutableEnv(), timing_info); + fprintf(stderr, "decoded: '%s'\n", response.c_str()); bool remembered_turquoise = response.find("turquoise") != std::string::npos; // NOLINT bool remembered_car = response.find("car") != std::string::npos; // NOLINT EXPECT_TRUE(remembered_turquoise || remembered_car); } -static const char kJingleBells[] = R"( -Dashing through the snow -In a one-horse open sleigh -O'er the fields we go -Laughing all the way -Bells on bobtails ring -Making spirits bright -What fun it is to ride and sing -A sleighing song tonight -)"; - -// The "Hay Draft" of the Gettysburg Address. -static const char kGettysburg[] = { - "Four score and seven years ago our fathers brought forth, upon this " - "continent, a new nation, conceived in Liberty, and dedicated to the " - "proposition that all men are created equal.\n\nNow we are engaged in a " - "great civil war, testing whether that nation, or any nation, so " - "conceived, and so dedicated, can long endure. We are met here on a great " - "battlefield of that war. We have come to dedicate a portion of it as a " - "final resting place for those who here gave their lives that that nation " - "might live. It is altogether fitting and proper that we should do " - "this.\n\nBut in a larger sense we can not dedicate -- we can not " - "consecrate -- we can not hallow this ground. The brave men, living and " - "dead, who struggled, here, have consecrated it far above our poor power " - "to add or detract. The world will little note, nor long remember, what we " - "say here, but can never forget what they did here. It is for us, the " - "living, rather to be dedicated here to the unfinished work which they " - "have, thus far, so nobly carried on. It is rather for us to be here " - "dedicated to the great task remaining before us -- that from these " - "honored dead we take increased devotion to that cause for which they here " - "gave the last full measure of devotion -- that we here highly resolve " - "that these dead shall not have died in vain; that this nation shall have " - "a new birth of freedom; and that this government of the people, by the " - "people, for the people, shall not perish from the earth.\n"}; - TEST_F(GemmaTest, CrossEntropySmall) { - ASSERT_NE(s_env->GetModel(), nullptr); + HWY_ASSERT(s_env->GetGemma() != nullptr); + const ModelConfig& config = s_env->GetGemma()->Config(); static const char kSmall[] = "The capital of Hungary is Budapest which is located in Europe."; float entropy = s_env->CrossEntropy(kSmall); fprintf(stderr, "per-token entropy: %f\n", entropy); - switch (s_env->GetModel()->Info().model) { - case gcpp::Model::GEMMA_2B: - // 2B v.1 and v.1.1 produce slightly different results. - EXPECT_NEAR(entropy, 2.6f, 0.2f); - break; - case gcpp::Model::GEMMA_7B: - // 7B v.1 and v.1.1 produce slightly different results. - EXPECT_NEAR(entropy, 2.8f, 0.2f); - break; + switch (config.model) { case gcpp::Model::GRIFFIN_2B: EXPECT_NEAR(entropy, 2.61f, 0.02f); break; @@ -268,76 +173,14 @@ TEST_F(GemmaTest, CrossEntropySmall) { } } -TEST_F(GemmaTest, CrossEntropyJingleBells) { - ASSERT_NE(s_env->GetModel(), nullptr); - float entropy = s_env->CrossEntropy(kJingleBells); - fprintf(stderr, "per-token entropy: %f\n", entropy); - switch (s_env->GetModel()->Info().model) { - case gcpp::Model::GEMMA_2B: - // 2B v.1 and v.1.1 produce slightly different results. 
-      EXPECT_NEAR(entropy, 1.9f, 0.2f);
-      break;
-    case gcpp::Model::GEMMA_7B:
-      // 7B v.1 and v.1.1 produce slightly different results.
-      EXPECT_NEAR(entropy, 1.07f, 0.05f);
-      break;
-    case gcpp::Model::GRIFFIN_2B:
-      EXPECT_NEAR(entropy, 1.62f, 0.02f);
-      break;
-    case gcpp::Model::GEMMA2_2B:
-      EXPECT_NEAR(entropy, 0.49f, 0.02f);
-      break;
-    case gcpp::Model::GEMMA2_9B:
-      EXPECT_NEAR(entropy, 0.37f, 0.02f);
-      break;
-    case gcpp::Model::GEMMA2_27B:
-      EXPECT_NEAR(entropy, 0.33f, 0.02f);
-      break;
-    default:
-      FAIL() << "no entropy expectation for this model";
-      break;
-  }
-}
-
-TEST_F(GemmaTest, CrossEntropyGettysburg) {
-  ASSERT_NE(s_env->GetModel(), nullptr);
-  float entropy = s_env->CrossEntropy(kGettysburg);
-  fprintf(stderr, "per-token entropy: %f\n", entropy);
-  switch (s_env->GetModel()->Info().model) {
-    case gcpp::Model::GEMMA_2B:
-      // 2B v.1 and v.1.1 produce slightly different results.
-      EXPECT_NEAR(entropy, 1.1f, 0.1f);
-      break;
-    case gcpp::Model::GEMMA_7B:
-      // 7B v.1 and v.1.1 produce slightly different results.
-      EXPECT_NEAR(entropy, 0.75f, 0.1f);
-      break;
-    case gcpp::Model::GRIFFIN_2B:
-      EXPECT_NEAR(entropy, 0.71f, 0.02f);
-      break;
-    case gcpp::Model::GEMMA2_2B:
-      EXPECT_NEAR(entropy, 0.20f, 0.02f);
-      break;
-    case gcpp::Model::GEMMA2_9B:
-      EXPECT_NEAR(entropy, 0.15f, 0.02f);
-      break;
-    case gcpp::Model::GEMMA2_27B:
-      EXPECT_NEAR(entropy, 0.14f, 0.02f);
-      break;
-    default:
-      FAIL() << "no entropy expectation for this model";
-      break;
-  }
-}
-
 }  // namespace
 }  // namespace gcpp

 int main(int argc, char** argv) {
-  gcpp::GemmaEnv env(argc, argv);
-  gcpp::s_env = &env;
-  testing::InitGoogleTest(&argc, argv);
-
-  return RUN_ALL_TESTS();
+  gcpp::InternalInit();
+  gcpp::GemmaTest::InitEnv(argc, argv);
+  int ret = RUN_ALL_TESTS();
+  gcpp::GemmaTest::DeleteEnv();
+  return ret;
 }
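The `Multiturn` test above hinges on `abs_pos` threading through both turns: the stream callback advances it for every token, so the second prompt is tokenized and generated at positions after the first turn's KV-cache entries. A condensed sketch of that pattern, assuming the post-change `Gemma` API shown in the test and a default-constructible `RuntimeConfig`; `TwoTurns` is a hypothetical helper name:

```cpp
// Two-turn conversation sketch (assumptions noted above; not from the diff).
#include <random>
#include <string>
#include <vector>

#include "gemma/gemma.h"

void TwoTurns(const gcpp::Gemma& model, gcpp::KVCache& kv_cache,
              gcpp::MatMulEnv& env, std::mt19937& gen) {
  size_t abs_pos = 0;  // running KV-cache position across turns
  gcpp::RuntimeConfig runtime_config;
  runtime_config.max_generated_tokens = 64;
  runtime_config.temperature = 0.0f;  // deterministic
  runtime_config.gen = &gen;
  runtime_config.verbosity = 0;
  // Advance abs_pos for every streamed token, as the test's lambda does.
  runtime_config.batch_stream_token = [&](size_t, size_t, int, float) {
    ++abs_pos;
    return true;
  };
  gcpp::TimingInfo timing_info{.verbosity = 0};

  std::vector<std::string> turns = {"I have a car and its color is turquoise.",
                                    "Please repeat all prior statements."};
  for (std::string& turn : turns) {
    // Positions the wrapped prompt after everything already in the KV cache.
    std::vector<int> tokens = gcpp::WrapAndTokenize(
        model.Tokenizer(), model.ChatTemplate(), model.Config().wrapping,
        abs_pos, turn);
    model.Generate(runtime_config, tokens, abs_pos, kv_cache, env,
                   timing_info);
  }
}
```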
"//:threading_context", "//:tokenizer", "@highway//:hwy", - "@highway//:thread_pool", ], ) diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index 4eb8647..96c56bf 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -23,19 +23,15 @@ #include #include -// Placeholder for internal header, do not modify. #include "gemma/gemma.h" +#include "gemma/gemma_args.h" // LoaderArgs #include "gemma/tokenizer.h" -#include "util/app.h" // LoaderArgs #include "util/args.h" -#include "util/threading.h" +#include "util/threading_context.h" #include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" -int main(int argc, char **argv) { { - // Placeholder for internal init, do not modify. - } +int main(int argc, char **argv) { gcpp::LoaderArgs loader(argc, argv); gcpp::InferenceArgs inference(argc, argv); gcpp::AppArgs app(argc, argv); diff --git a/examples/simplified_gemma/BUILD.bazel b/examples/simplified_gemma/BUILD.bazel index bedb322..98c0f5e 100644 --- a/examples/simplified_gemma/BUILD.bazel +++ b/examples/simplified_gemma/BUILD.bazel @@ -10,10 +10,10 @@ cc_library( name = "gemma", hdrs = ["gemma.hpp"], deps = [ - "//:app", + "//:gemma_args", "//:gemma_lib", - "//:ops", - "//:threading", + "//:matmul", + "//:threading_context", "//:tokenizer", "@highway//:hwy", ], @@ -24,15 +24,6 @@ cc_binary( srcs = ["run.cc"], deps = [ ":gemma", - # Placeholder for internal dep, do not remove., - "//:app", - "//:args", - "//:common", - "//:gemma_lib", - "//:ops", - "//:threading", - "//:tokenizer", - "@highway//:hwy", - "@highway//:thread_pool", + "//:gemma_args", ], ) diff --git a/examples/simplified_gemma/CMakeLists.txt b/examples/simplified_gemma/CMakeLists.txt index e7e6653..5595164 100644 --- a/examples/simplified_gemma/CMakeLists.txt +++ b/examples/simplified_gemma/CMakeLists.txt @@ -14,10 +14,11 @@ cmake_minimum_required(VERSION 3.11) project(simplified_gemma) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) include(FetchContent) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9414b48aeec251b69e6cadbfa42bebb5ddae1c34) FetchContent_MakeAvailable(highway) FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) FetchContent_MakeAvailable(sentencepiece) @@ -31,7 +32,7 @@ if (NOT BUILD_MODE) endif() if (BUILD_MODE STREQUAL "local") # Relative path to gemma.cpp from examples/simplified_gemma/build/ - FetchContent_Declare(gemma SOURCE_DIR ../../..) + FetchContent_Declare(gemma SOURCE_DIR ../../..) 
 else()
   FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e)
 endif()
diff --git a/examples/simplified_gemma/README.md b/examples/simplified_gemma/README.md
index d8f9394..37b4f71 100644
--- a/examples/simplified_gemma/README.md
+++ b/examples/simplified_gemma/README.md
@@ -41,7 +41,7 @@ gemma.cpp specifying the tokenizer, compressed weights file, and model type,
 for example:

 ```sh
-./simplified_gemma --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it
+./simplified_gemma --tokenizer tokenizer.spm --weights 2b-it-sfp.sbs --model 2b-it
 ```

 Should print a greeting to the terminal:
diff --git a/examples/simplified_gemma/gemma.hpp b/examples/simplified_gemma/gemma.hpp
index a2a7760..7f6e4c2 100644
--- a/examples/simplified_gemma/gemma.hpp
+++ b/examples/simplified_gemma/gemma.hpp
@@ -24,55 +24,39 @@
 #include

 #include "third_party/gemma_cpp/gemma/gemma.h"
+#include "third_party/gemma_cpp/gemma/gemma_args.h"  // LoaderArgs
 #include "third_party/gemma_cpp/gemma/tokenizer.h"
 #include "third_party/gemma_cpp/ops/matmul.h"
-#include "third_party/gemma_cpp/util/app.h"  // LoaderArgs
-#include "third_party/gemma_cpp/util/threading.h"
+#include "third_party/gemma_cpp/util/threading_context.h"
 #include "third_party/highway/hwy/base.h"

 class SimplifiedGemma {
  public:
   SimplifiedGemma(const gcpp::LoaderArgs& loader,
-                  const gcpp::InferenceArgs& inference = gcpp::InferenceArgs(),
-                  const gcpp::AppArgs& app = gcpp::AppArgs())
-      : loader_(loader),
-        inference_(inference),
-        app_(app),
-        topology_(gcpp::CreateTopology(app_)),
-        pools_(gcpp::CreatePools(topology_, app_)),
-        env_(topology_, pools_),
-        model_(gcpp::CreateGemma(loader_, env_)) {
-    Init();
-  }
-
-  SimplifiedGemma(int argc, char** argv)
-      : loader_(argc, argv, /*validate=*/true),
-        inference_(argc, argv),
-        app_(argc, argv),
-        topology_(gcpp::CreateTopology(app_)),
-        pools_(gcpp::CreatePools(topology_, app_)),
-        env_(topology_, pools_),
-        model_(gcpp::CreateGemma(loader_, env_)) {
-    Init();
-  }
-
-  void Init() {
-    // Instantiate model and KV Cache
-    kv_cache_ = gcpp::KVCache::Create(model_.GetModelConfig(),
-                                      inference_.prefill_tbatch_size);
-
+                  const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
+                  const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
+      : ctx_(UpdateArgs(threading, inference)),
+        env_(ctx_),
+        gemma_(loader, inference, ctx_),
+        kv_cache_(gemma_.Config(), inference, ctx_.allocator) {
     // Initialize random number generator
     std::random_device rd;
     gen_.seed(rd());
   }

+  SimplifiedGemma(int argc, char** argv)
+      : SimplifiedGemma(gcpp::LoaderArgs(argc, argv),
+                        gcpp::ThreadingArgs(argc, argv),
+                        gcpp::InferenceArgs(argc, argv)) {}
+
   void Generate(std::string& prompt, size_t max_generated_tokens = 1024,
                 float temperature = 0.7,
                 const std::set<int>& reject_tokens = {}) {
     size_t generated = 0;
     const std::vector<int> tokens = gcpp::WrapAndTokenize(
-        model_.Tokenizer(), loader_.Info(), generated, prompt);
+        gemma_.Tokenizer(), gemma_.ChatTemplate(),
+        gemma_.Config().wrapping, generated, prompt);
     const size_t prompt_size = tokens.size();

     // This callback function gets invoked every time a token is generated
@@ -80,9 +64,9 @@ class SimplifiedGemma {
       ++generated;
       if (generated < prompt_size) {
         // print feedback
-      } else if (!this->model_.GetModelConfig().IsEOS(token)) {
+      } else if (!gemma_.Config().IsEOS(token)) {
         std::string token_text;
-        HWY_ASSERT(this->model_.Tokenizer().Decode({token}, &token_text));
+
HWY_ASSERT(gemma_.Tokenizer().Decode({token}, &token_text)); std::cout << token_text << std::flush; } return true; @@ -100,19 +84,15 @@ class SimplifiedGemma { return !reject_tokens.contains(token); }, }; - model_.Generate(runtime_config, tokens, 0, kv_cache_, timing_info); + gemma_.Generate(runtime_config, tokens, 0, kv_cache_, env_, timing_info); } ~SimplifiedGemma() = default; private: - gcpp::LoaderArgs loader_; - gcpp::InferenceArgs inference_; - gcpp::AppArgs app_; - gcpp::BoundedTopology topology_; - gcpp::NestedPools pools_; + gcpp::ThreadingContext ctx_; gcpp::MatMulEnv env_; - gcpp::Gemma model_; + gcpp::Gemma gemma_; gcpp::KVCache kv_cache_; std::mt19937 gen_; std::string validation_error_; -}; \ No newline at end of file +}; diff --git a/examples/simplified_gemma/run.cc b/examples/simplified_gemma/run.cc index f73ddb5..b7af134 100644 --- a/examples/simplified_gemma/run.cc +++ b/examples/simplified_gemma/run.cc @@ -17,30 +17,25 @@ #include -// Placeholder for internal header, do not modify. #include "third_party/gemma_cpp/examples/simplified_gemma/gemma.hpp" -#include "util/app.h" // LoaderArgs +#include "gemma/gemma_args.h" // LoaderArgs int main(int argc, char** argv) { - { - // Placeholder for internal init, do not modify. - } - // Standard usage: LoaderArgs takes argc and argv as input, then parses // necessary flags. - gcpp::LoaderArgs loader(argc, argv, /*validate=*/true); + gcpp::LoaderArgs loader(argc, argv); // Optional: LoaderArgs can also take tokenizer and weights paths directly. // // gcpp::LoaderArgs loader("/path/to/tokenizer", "/path/to/weights", // "model_identifier"); - // Optional: InferenceArgs and AppArgs can be passed in as well. If not + // Optional: ThreadingArgs and InferenceArgs can be passed in as well. If not // specified, default values will be used. 
// // gcpp::InferenceArgs inference(argc, argv); - // gcpp::AppArgs app(argc, argv); - // SimplifiedGemma gemma(loader, inference, app); + // gcpp::ThreadingArgs threading(argc, argv); + // SimplifiedGemma gemma(loader, threading, inference); SimplifiedGemma gemma(loader); std::string prompt = "Write a greeting to the world."; diff --git a/gemma/activations.h b/gemma/activations.h index 86345e2..b222bd9 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -16,104 +16,199 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_ +#include // sqrtf #include +#include -#include "compression/shared.h" // BF16 -#include "gemma/configs.h" -#include "ops/matmul.h" // MatMulEnv -#include "ops/ops.h" // CreateInvTimescale -#include "util/allocator.h" // RowVectorBatch -#include "util/threading.h" -#include "hwy/base.h" // HWY_DASSERT -#include "hwy/contrib/thread_pool/thread_pool.h" +#include +#include + +#include "gemma/configs.h" // ModelConfig +#include "ops/matmul.h" // MatMulEnv +#include "ops/ops.h" // CreateInvTimescale +#include "util/allocator.h" // Allocator +#include "util/basics.h" // BF16 +#include "util/mat.h" // MatStorageT namespace gcpp { -struct Activations { - explicit Activations(const ModelConfig& config) - : weights_config(config), - layer_config(config.layer_configs[0]), - seq_len(config.seq_len), - cache_pos_size(config.CachePosSize()) {} +struct GriffinActivations { + GriffinActivations(const ModelConfig& config, size_t batch_size, + const Allocator& allocator) + : griffin_x( + MatFactory("griffin_x", batch_size, config.model_dim, allocator)), + griffin_y( + MatFactory("griffin_y", batch_size, config.model_dim, allocator)), + griffin_gate_x(MatFactory("griffin_gate_x", batch_size, + config.model_dim, allocator)), + griffin_multiplier(MatFactory("griffin_mul", batch_size, + config.model_dim, allocator)) {} - RowVectorBatch x; // input - RowVectorBatch q; // query, also KV if MHA. - RowVectorBatch logits; + void SetBatchSize(size_t batch_size) { + if (griffin_x.Rows() == 0) return; + griffin_x.OverrideRows(batch_size); + griffin_y.OverrideRows(batch_size); + griffin_gate_x.OverrideRows(batch_size); + griffin_multiplier.OverrideRows(batch_size); + } - // Attention - RowVectorBatch pre_att_rms_out; - RowVectorBatch att; // attention vector - RowVectorBatch att_out; // attention output + MatStorageT griffin_x; + MatStorageT griffin_y; + MatStorageT griffin_gate_x; + MatStorageT griffin_multiplier; +}; + +struct AttentionActivations { + // Returns the scale value to use for the query in the attention computation. + // Also called by ops_test. + static inline float ChooseQueryScale(const ModelConfig& config) { + if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads) + return 1.0f / sqrtf(static_cast(config.model_dim / + config.layer_configs[0].heads)); + // QueryScaleType::SqrtKeySize + return 1.0f / sqrtf(static_cast(config.layer_configs[0].qkv_dim)); + } + + AttentionActivations( + const ModelConfig& config, const LayerConfig& layer_config, + size_t batch_size, size_t seq_len, const Allocator& allocator, + std::vector>& row_ptrs) + : config(config), + + // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA + // and does not use an external KV cache. + q(MatFactory("q", batch_size, + config.vocab_size == 0 + ? 
diff --git a/gemma/activations.h b/gemma/activations.h
index 86345e2..b222bd9 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -16,104 +16,199 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_

+#include <math.h>  // sqrtf
 #include
+#include

-#include "compression/shared.h"  // BF16
-#include "gemma/configs.h"
-#include "ops/matmul.h"  // MatMulEnv
-#include "ops/ops.h"  // CreateInvTimescale
-#include "util/allocator.h"  // RowVectorBatch
-#include "util/threading.h"
-#include "hwy/base.h"  // HWY_DASSERT
-#include "hwy/contrib/thread_pool/thread_pool.h"
+#include <atomic>
+#include <vector>
+
+#include "gemma/configs.h"   // ModelConfig
+#include "ops/matmul.h"      // MatMulEnv
+#include "ops/ops.h"         // CreateInvTimescale
+#include "util/allocator.h"  // Allocator
+#include "util/basics.h"     // BF16
+#include "util/mat.h"        // MatStorageT

 namespace gcpp {

-struct Activations {
-  explicit Activations(const ModelConfig& config)
-      : weights_config(config),
-        layer_config(config.layer_configs[0]),
-        seq_len(config.seq_len),
-        cache_pos_size(config.CachePosSize()) {}
+struct GriffinActivations {
+  GriffinActivations(const ModelConfig& config, size_t batch_size,
+                     const Allocator& allocator)
+      : griffin_x(
+            MatFactory("griffin_x", batch_size, config.model_dim, allocator)),
+        griffin_y(
+            MatFactory("griffin_y", batch_size, config.model_dim, allocator)),
+        griffin_gate_x(MatFactory("griffin_gate_x", batch_size,
+                                  config.model_dim, allocator)),
+        griffin_multiplier(MatFactory("griffin_mul", batch_size,
+                                      config.model_dim, allocator)) {}

-  RowVectorBatch<float> x;  // input
-  RowVectorBatch<float> q;  // query, also KV if MHA.
-  RowVectorBatch<float> logits;
+  void SetBatchSize(size_t batch_size) {
+    if (griffin_x.Rows() == 0) return;
+    griffin_x.OverrideRows(batch_size);
+    griffin_y.OverrideRows(batch_size);
+    griffin_gate_x.OverrideRows(batch_size);
+    griffin_multiplier.OverrideRows(batch_size);
+  }

-  // Attention
-  RowVectorBatch<float> pre_att_rms_out;
-  RowVectorBatch<float> att;      // attention vector
-  RowVectorBatch<float> att_out;  // attention output
+  MatStorageT<float> griffin_x;
+  MatStorageT<float> griffin_y;
+  MatStorageT<float> griffin_gate_x;
+  MatStorageT<float> griffin_multiplier;
+};
+
+struct AttentionActivations {
+  // Returns the scale value to use for the query in the attention computation.
+  // Also called by ops_test.
+  static inline float ChooseQueryScale(const ModelConfig& config) {
+    if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
+      return 1.0f / sqrtf(static_cast<float>(config.model_dim /
+                                             config.layer_configs[0].heads));
+    // QueryScaleType::SqrtKeySize
+    return 1.0f / sqrtf(static_cast<float>(config.layer_configs[0].qkv_dim));
+  }
+
+  AttentionActivations(
+      const ModelConfig& config, const LayerConfig& layer_config,
+      size_t batch_size, size_t seq_len, const Allocator& allocator,
+      std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
+      : config(config),
+
+        // `vocab_size == 0` means it is for Vit part, VitAttention is still
+        // MHA and does not use an external KV cache.
+        q(MatFactory("q", batch_size,
+                     config.vocab_size == 0
+                         ? layer_config.heads * 3 * layer_config.qkv_dim
+                         : layer_config.heads * layer_config.qkv_dim,
+                     allocator)),
+
+        pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
+                                   config.model_dim, allocator)),
+        att(MatFactory("att", batch_size, layer_config.heads * seq_len,
+                       allocator)),
+        att_out(MatFactory("att_out", batch_size,
+                           layer_config.heads * layer_config.qkv_dim,
+                           allocator)),
+        att_sums(
+            MatFactory("att_sums", batch_size, config.model_dim, allocator)),
+
+        inv_timescale(
+            CreateInvTimescale(allocator, layer_config.qkv_dim,
+                               layer_config.post_qk == PostQKType::HalfRope)),
+        inv_timescale_global(CreateInvTimescale(
+            allocator, layer_config.qkv_dim,
+            layer_config.post_qk == PostQKType::HalfRope, 1000000.0)),
+
+        div_seq_len(static_cast<uint32_t>(seq_len)),
+        div_heads(static_cast<uint32_t>(layer_config.heads)),
+        query_scale(ChooseQueryScale(config)) {
+    // Batch size can be 0 in experimental code so do not assert.
+    if (batch_size == 0) {
+      static std::atomic_flag warned = ATOMIC_FLAG_INIT;
+      if (!warned.test_and_set()) {
+        HWY_WARN("Creating mostly empty activations with a batch_size of 0.");
+      }
+      return;
+    }
+
+    // For MatMul outputs, precompute their row pointers.
+    // If we forget any MatMul outputs here, debug builds print a warning but
+    // fill them in each MatMul call.
+    q.AllocateAndAttachRowPtrs(row_ptrs);
+    att_sums.AllocateAndAttachRowPtrs(row_ptrs);
+  }
+
+  void SetBatchSize(size_t batch_size) {
+    q.OverrideRows(batch_size);
+
+    pre_att_rms_out.OverrideRows(batch_size);
+    att.OverrideRows(batch_size);
+    att_out.OverrideRows(batch_size);
+    att_sums.OverrideRows(batch_size);
+  }
+
+  const ModelConfig& config;
+
+  MatStorageT<float> q;  // query
+
+  MatStorageT<float> pre_att_rms_out;
+  MatStorageT<float> att;      // attention vector
+  MatStorageT<float> att_out;  // attention output
   // Accumulation of attention outputs over heads
-  RowVectorBatch<float> att_sums;
-
-  // Gated FFW
-  RowVectorBatch<BF16> bf_pre_ffw_rms_out;
-  RowVectorBatch<float> C1;
-  RowVectorBatch<float> C2;
-  RowVectorBatch<float> ffw_out;
-
-  // Griffin
-  RowVectorBatch<float> griffin_x;
-  RowVectorBatch<float> griffin_y;
-  RowVectorBatch<float> griffin_gate_x;
-  RowVectorBatch<float> griffin_multiplier;
+  MatStorageT<float> att_sums;

   // Rope
-  RowVectorBatch<float> inv_timescale;
-  RowVectorBatch<float> inv_timescale_global;
+  MatStorageT<float> inv_timescale;
+  MatStorageT<float> inv_timescale_global;

-  // Dynamic because no default ctor and only initialized in `Allocate`.
-  MatMulEnv* env;
+  hwy::Divisor div_seq_len;
+  // Unfortunately, some models (Griffin) have non-power-of-two heads.
+  hwy::Divisor div_heads;
+  float query_scale;
+};

-  PostQKType post_qk = PostQKType::Rope;
-  // And the config.
-  const ModelConfig& weights_config;
-  const LayerConfig& layer_config;
-  size_t seq_len;
-  size_t cache_pos_size = 0;
+struct Activations {
+  Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
+              const Allocator& allocator,
+              std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
+      : layer_config(config.layer_configs[0]),

-  void Allocate(size_t batch_size, MatMulEnv* env) {
-    post_qk = layer_config.post_qk;
-    const size_t model_dim = weights_config.model_dim;
-    const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
-    const size_t vocab_size = weights_config.vocab_size;
-    const size_t qkv_dim = layer_config.qkv_dim;
-    const size_t heads = layer_config.heads;
+        x(MatFactory("x", batch_size, config.model_dim, allocator)),
+        logits(MatFactory("logits", batch_size, config.vocab_size, allocator)),

-    x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
-    q = RowVectorBatch<float>(
-        Extents2D(batch_size, heads * layer_config.QStride()));
-    if (vocab_size > 0) {
-      logits = RowVectorBatch<float>(Extents2D(batch_size, vocab_size));
-    }
+        pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
+                                   config.model_dim, allocator)),
+        C1(MatFactory("C1", batch_size, layer_config.ff_hidden_dim, allocator)),
+        C2(MatFactory("C2", batch_size, layer_config.ff_hidden_dim, allocator)),
+        ffw_out(MatFactory("ffw_out", batch_size, config.model_dim, allocator)),

-    pre_att_rms_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
-    att = RowVectorBatch<float>(
-        Extents2D(batch_size, heads * weights_config.seq_len));
-    att_out = RowVectorBatch<float>(Extents2D(batch_size, heads * qkv_dim));
-    att_sums = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
+        attention(config, layer_config, batch_size, seq_len, allocator,
+                  row_ptrs),
+        griffin(config, config.model == Model::GRIFFIN_2B ? batch_size : 0,
+                allocator) {
+    HWY_ASSERT(batch_size != 0);

-    bf_pre_ffw_rms_out = RowVectorBatch<BF16>(Extents2D(batch_size, model_dim));
-    C1 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
-    C2 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
-    ffw_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
+    // For MatMul outputs, precompute their row pointers.
+    // If we forget any MatMul outputs here, debug builds print a warning but
+    // fill them in each MatMul call.
+    x.AllocateAndAttachRowPtrs(row_ptrs);
+    logits.AllocateAndAttachRowPtrs(row_ptrs);
+    C1.AllocateAndAttachRowPtrs(row_ptrs);
+    C2.AllocateAndAttachRowPtrs(row_ptrs);
+    ffw_out.AllocateAndAttachRowPtrs(row_ptrs);

-    if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
-      griffin_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
-      griffin_y = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
-      griffin_gate_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
-      griffin_multiplier =
-          RowVectorBatch<float>(Extents2D(batch_size, model_dim));
-    }
-
-    inv_timescale = CreateInvTimescale(layer_config.qkv_dim,
-                                       post_qk == PostQKType::HalfRope);
-    inv_timescale_global =
-        CreateInvTimescale(qkv_dim, post_qk == PostQKType::HalfRope, 1000000.0);
-
-    this->env = env;
+    // Note that BindC on any MatMul output considerably slows down Prefill.
   }
+
+  // Negligible CPU time.
+  void SetBatchSize(size_t batch_size) {
+    x.OverrideRows(batch_size);
+    logits.OverrideRows(batch_size);
+
+    pre_ffw_rms_out.OverrideRows(batch_size);
+    C1.OverrideRows(batch_size);
+    C2.OverrideRows(batch_size);
+    ffw_out.OverrideRows(batch_size);
+
+    attention.SetBatchSize(batch_size);
+    griffin.SetBatchSize(batch_size);
+  }
+
+  const LayerConfig& layer_config;
+
+  MatStorageT<float> x;  // input
+  MatStorageT<float> logits;
+
+  // Gated FFW
+  MatStorageT<BF16> pre_ffw_rms_out;
+  // Norm may be large, so prefer to keep as f32.
+  MatStorageT<float> C1;
+  MatStorageT<float> C2;
+  MatStorageT<float> ffw_out;
+
+  AttentionActivations attention;
+  GriffinActivations griffin;
 };

 }  // namespace gcpp
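The restructured `Activations` allocates all buffers once, and `SetBatchSize` merely overrides the logical row count via `OverrideRows` ("negligible CPU time", per the comment). A sketch of the intended allocate-once pattern, assuming `MatMulEnv` exposes the row-pointer storage that `AllocateAndAttachRowPtrs` expects (a hypothetical `env.row_ptrs` member, consistent with its use in `ComputeQKV` below):

```cpp
// Allocate-once / shrink-logically usage sketch (assumptions noted above).
#include "gemma/activations.h"

void RunDecode(const gcpp::ModelConfig& config, gcpp::MatMulEnv& env,
               const gcpp::Allocator& allocator, size_t max_batch,
               size_t seq_len) {
  // Allocate all activation buffers once, at the maximum batch size.
  gcpp::Activations acts(config, max_batch, seq_len, allocator, env.row_ptrs);
  // Shrink logically per step; OverrideRows does not reallocate.
  acts.SetBatchSize(1);          // decode: one token per query
  // ...run the transformer layers with `acts`...
  acts.SetBatchSize(max_batch);  // back to full-size prefill batches
}
```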
diff --git a/gemma/attention.cc b/gemma/attention.cc
new file mode 100644
index 0000000..74ea77a
--- /dev/null
+++ b/gemma/attention.cc
@@ -0,0 +1,358 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+
+#include
+
+#include "compression/types.h"  // GEMMA_DISABLED_TARGETS
+#ifndef HWY_DISABLED_TARGETS
+#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
+#endif  // HWY_DISABLED_TARGETS
+
+#include "gemma/activations.h"
+#include "gemma/configs.h"  // kMaxQKVDim
+#include "gemma/gemma.h"
+#include "gemma/weights.h"
+#include "util/threading.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/profiler.h"
+
+// Compiles this file for multiple architectures via "foreach_target.h", to
+// which we pass the filename via macro 'argument'.
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "gemma/attention.cc"  // NOLINT
+// clang-format on
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#include "hwy/highway.h"
+// After highway.h
+#include "compression/compress-inl.h"
+#include "ops/ops-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace gcpp {
+namespace HWY_NAMESPACE {
+
+// Computes Q.K scores, which are "logits" (or scores) stored to att.
+// `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
+static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
+                             const hwy::Divisor& div_seq_len,
+                             const float* HWY_RESTRICT q,
+                             const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
+                             const size_t worker) {
+  PROFILER_ZONE2(worker, "Gen.Attention.QDotK");
+  if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
+    // Slightly faster: no wraparound.
+    for (size_t pos = start_pos; pos <= last_pos; ++pos) {
+      const float score = Dot(q, k.Row(pos), k.Cols());
+      att[pos] = score;
+    }
+  } else {
+    for (size_t pos = start_pos; pos <= last_pos; ++pos) {
+      const size_t pos_modulo = div_seq_len.Remainder(pos);
+      const float score = Dot(q, k.Row(pos_modulo), k.Cols());
+      att[pos_modulo] = score;
+    }
+  }
+}
+
+static void PositionalEncodingQK(float* qk, const size_t layer_idx,
+                                 const LayerWeightsPtrs& layer,
+                                 const AttentionActivations& activations,
+                                 const size_t worker, const size_t pos,
+                                 const float mul = 1.0f) {
+  const size_t qkv_dim = layer.layer_config.qkv_dim;
+  const PostQKType& post_qk = layer.layer_config.post_qk;
+  // qk is either q or k, so qkv_dim is the length we operate on.
+  const float* inv_timescale = activations.inv_timescale.PackedScale1();
+  const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx);
+  // TODO: add a config flag instead of hardcoding the model.
+  if (is_global_layer && IsVLM(activations.config.model)) {
+    inv_timescale = activations.inv_timescale_global.PackedScale1();
+  }
+  // PostQKType::Rope
+  if (post_qk == PostQKType::HalfRope) {
+    Rope(qk, qkv_dim / 2, inv_timescale, pos, worker);
+    if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, worker);
+  } else {
+    RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, worker);
+  }
+}
+
+// Accumulates the sum of v (from `kv_cache`) * probability (`att`) into
+// `att_out`. Equivalent in gemma/modules.py:
+// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
+// `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
+static HWY_INLINE void WeightedSumV(
+    const size_t start_pos, const size_t last_pos,
+    const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att,
+    const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, const size_t worker) {
+  if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
+    // Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
+    // we supported non-transposed B.
+    // TODO: 2..4x unroll
+    MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), worker);
+    for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
+      MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), worker);
+    }
+  } else {
+    {
+      const size_t pos_mod = div_seq_len.Remainder(start_pos);
+      MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
+    }
+    for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
+      const size_t pos_mod = div_seq_len.Remainder(pos);
+      MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker);
+    }
+  }
+}
+
+// Calculates the attention outputs for a single q, which may be updated
+// in place for RMSNorm.
+void SingleDotSoftmaxWeightedSum(
+    const size_t pos, const size_t start_pos, const size_t last_pos,
+    float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
+    const size_t layer_idx, const LayerWeightsPtrs& layer,
+    const AttentionActivations& activations, float* HWY_RESTRICT att,
+    float* HWY_RESTRICT att_out, const size_t worker) {
+  const float att_cap = activations.config.att_cap;
+  const float query_scale = activations.query_scale;
+  const size_t seq_len =
+      static_cast<size_t>(activations.div_seq_len.GetDivisor());
+
+  // Apply rope and scaling to Q.
+  if (layer.query_norm_scale.HasPtr()) {
+    CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
+      RMSNormInplace(weights_t->PackedScale1(), 0, q,
+                     layer.layer_config.qkv_dim, worker);
+    });
+  }
+
+  PositionalEncodingQK(q, layer_idx, layer, activations, worker, pos,
+                       query_scale);
+
+  QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, worker);
+
+  // SoftMax with optional SoftCap yields "probabilities" in att.
+  const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
+  MaybeLogitsSoftCap(att_cap, att, att_len, worker);
+  Softmax(att, att_len, worker, /*temperature=*/1.0f);
+
+  WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
+               worker);
+}
+
+// The attention window usually starts at 0 unless `pos` is larger than
+// the attention window size, then it is `pos` - window_size + 1.
+static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
+                                  size_t layer_idx) {
+  const size_t att_window_size = config.attention_window_sizes[layer_idx];
+  return pos - HWY_MIN(att_window_size - 1, pos);
+}
+
+void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
+                           const LayerWeightsPtrs& layer,
+                           AttentionActivations& activations, QBatch& qbatch,
+                           NestedPools& pools) {
+  static const uint32_t HWY_MAYBE_UNUSED zone_id_par =
+      PROFILER_ADD_ZONE("Gen.Attention.DotSoftmax.par");
+
+  const hwy::Divisor div_qbatch(qbatch.Size());
+  const LayerConfig& layer_config = layer.layer_config;
+  const size_t qkv_dim = layer_config.qkv_dim;
+
+  // A "head group" in the context of GQA refers to a collection of query
+  // heads that share the same key and value heads.
+  const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
+
+  const size_t cache_layer_size = layer_config.CacheLayerSize();
+  const size_t seq_len =
+      static_cast<size_t>(activations.div_seq_len.GetDivisor());
+  // All layers should have the same number of heads.
+  HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads);
+
+  // For each head/token/query, compute Q.K, softmax, and weighted V.
+  const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
+    const size_t tq_idx = activations.div_heads.Divide(task);
+    const size_t head = activations.div_heads.Remainder(task);
+#if PROFILER_ENABLED
+    const hwy::Zone zone(worker, zone_id_par);
+#endif
+
+    const size_t qi = div_qbatch.Remainder(tq_idx);
+    const size_t batch_idx = div_qbatch.Divide(tq_idx);
+    auto& kv_cache = qbatch.KV(qi).kv_cache;
+
+    // Find the token position in the query and calculate
+    // the range of cache positions to attend to.
+    const size_t pos = qbatch.Pos(qi) + batch_idx;
+    const size_t start_pos = StartPos(pos, activations.config, layer_idx);
+    size_t last_pos = pos;
+    const size_t prefix_end = qbatch.PrefixEnd(qi);
+    if (prefix_end > 0 && prefix_end - 1 > last_pos) {
+      // last_pos in QDotK and WeightedSumV is inclusive.
+      last_pos = prefix_end - 1;
+    }
+
+    float* HWY_RESTRICT q = activations.q.Row(tq_idx) + head * qkv_dim;
+    float* HWY_RESTRICT att = activations.att.Row(tq_idx) + head * seq_len;
+    float* HWY_RESTRICT att_out =
+        activations.att_out.Row(tq_idx) + head * qkv_dim;
+
+    // Make strided read-only views into the kv cache for
+    // this query and head.
+    const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
+    const size_t kv_head_offset = layer_idx * cache_layer_size + head_offset;
+    MatPtrT<KV_t> k("k_view", Extents2D(seq_len, qkv_dim));
+    k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
+    MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
+    v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
+
+    SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
+                                layer, activations, att, att_out, worker);
+  };
+
+  {
+    PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
+    const size_t pkg_idx = 0;
+    // Full parallelism is helpful, SmallParallelFor is insufficient.
+    ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
+                pools, pkg_idx, func);
+  }
+}
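`SingleDotSoftmaxWeightedSum` fuses three stages per head: `QDotK` (scores), an optionally soft-capped `Softmax` (probabilities), and `WeightedSumV` (the output vector). For reference, here is the same math as a scalar function, with RoPE, the query scaling (already applied by `PositionalEncodingQK` before `QDotK`), the soft-cap, and cache wraparound all omitted for clarity; `ReferenceAttention` is a hypothetical name for illustration only:

```cpp
// Scalar reference for one attention head over positions [0, num_pos).
// k and v are row-major [num_pos, qkv_dim]; out has length qkv_dim.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

void ReferenceAttention(const float* q, const float* k, const float* v,
                        size_t num_pos, size_t qkv_dim, float* out) {
  std::vector<float> att(num_pos);
  float max = -1e30f, sum = 0.0f;
  for (size_t p = 0; p < num_pos; ++p) {  // QDotK: scores = q . k[p]
    float dot = 0.0f;
    for (size_t d = 0; d < qkv_dim; ++d) dot += q[d] * k[p * qkv_dim + d];
    att[p] = dot;
    max = std::max(max, dot);
  }
  for (size_t p = 0; p < num_pos; ++p) {  // Softmax (numerically stable)
    att[p] = std::exp(att[p] - max);
    sum += att[p];
  }
  for (size_t d = 0; d < qkv_dim; ++d) out[d] = 0.0f;
  for (size_t p = 0; p < num_pos; ++p) {  // WeightedSumV: out += prob * v[p]
    const float prob = att[p] / sum;
    for (size_t d = 0; d < qkv_dim; ++d) out[d] += prob * v[p * qkv_dim + d];
  }
}
```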
+
+// Different functions use different naming conventions for the number of
+// tokens. Functions that are query-independent, such as RMSNorm*, call the
+// count `num_interleaved`. Functions that are query-dependent, such as
+// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is
+// the number of tokens from one query: 1 for decode, otherwise
+// prefill_tbatch_size.
+
+// Fills activations.q and writes to KV cache.
+static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
+                                  const LayerWeightsPtrs& layer,
+                                  AttentionActivations& activations,
+                                  const QBatch& qbatch, const int flags,
+                                  MatMulEnv& env) {
+  PROFILER_ZONE("Gen.Attention.QKV");
+  const hwy::Divisor div_qbatch(qbatch.Size());
+  const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
+  const LayerConfig& layer_config = layer.layer_config;
+  const size_t qkv_dim = layer_config.qkv_dim;
+  const size_t kv_heads = layer_config.kv_heads;
+  const size_t cache_layer_size = layer_config.CacheLayerSize();
+
+  // The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim,
+  // model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows.
+  CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w1,
+             /*add=*/nullptr, env, activations.q);
+
+  // Set up MatMul row pointers for writing to KV, which consists of
+  // `kv_heads` pairs of (k, v) vectors. This safely handles wraparound
+  // because rows are computed modulo seq_len.
+  MatPtrT<KV_t> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
+                                        layer.qkv_einsum_w2.Rows()));
+  for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
+       ++interleaved_idx) {
+    const size_t qi = div_qbatch.Remainder(interleaved_idx);
+    const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
+    const size_t cache_pos =
+        activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
+    env.row_ptrs[2][interleaved_idx] = reinterpret_cast<uint8_t*>(
+        qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
+  }
+  kv_rows.AttachRowPtrs(env.row_ptrs[2].get());
+  CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
+             /*add=*/nullptr, env, kv_rows);
+
+  // Apply positional encodings for K.
+  // Note that 2D parallelism is not worth the fork/join overhead because the
+  // tasks are very lightweight.
+  env.ctx.pools.Pool(0).Run(
+      0, kv_heads * num_interleaved,
+      [&](uint64_t task, size_t thread) HWY_ATTR {
+        const size_t head = task % kv_heads;
+        const size_t interleaved_idx = task / kv_heads;
+        const size_t qi = div_qbatch.Remainder(interleaved_idx);
+        const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
+        const size_t pos = qbatch.Pos(qi) + batch_idx;
+        const size_t cache_pos = activations.div_seq_len.Remainder(pos);
+        auto& kv_cache = qbatch.KV(qi).kv_cache;
+        KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
+                                layer_idx * cache_layer_size +
+                                head * qkv_dim * 2;
+
+        HWY_ALIGN float kv_f32[2 * kMaxQKVDim];
+        const hn::ScalableTag<float> df;
+        DecompressAndZeroPad(df, MakeSpan(kv, 2 * qkv_dim), 0, kv_f32,
+                             2 * qkv_dim);
+
+        // Apply further processing to K.
+        if (layer.key_norm_scale.HasPtr()) {
+          CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
+            RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim,
+                           thread);
+          });
+        }
+
+        PositionalEncodingQK(kv_f32, layer_idx, layer, activations, thread,
+                             pos);
+        CompressPerThread tls;
+        Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
+      });
+}
+
+// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
+// head_dim (`qkv_dim`) into output (`layer_out`).
+static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
+                                AttentionActivations& activations,
+                                MatMulEnv& env) {
+  PROFILER_ZONE("Gen.Attention.SumHeads");
+  const LayerConfig& layer_config = layer.layer_config;
+  // att_weights and att_out are concatenated heads, each of length
+  // layer_config.qkv_dim. Thus the [num_interleaved,
+  // layer_config.model_dim] matmul output is the sum over heads. Compare
+  // gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD',
+  // encoded)
+  HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 &&
+              layer_config.qkv_dim != 0);
+  const float* add = layer_config.softmax_attn_output_biases
+                         ? layer.attention_output_biases.PackedScale1()
+                         : nullptr;
+  CallMatMul(activations.att_out, layer.att_weights, add, env,
+             activations.att_sums);
+}
+
+void GemmaAttention(size_t num_tokens, const size_t layer_idx,
+                    const LayerWeightsPtrs& layer,
+                    AttentionActivations& activations, QBatch& qbatch,
+                    MatMulEnv& env, int flags) {
+  const LayerConfig& layer_config = layer.layer_config;
+  HWY_DASSERT(!layer_config.IsMHA());  // No longer supported.
+  HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,
+                "query heads must be a multiple of key-value heads");
+  (void)layer_config;  // only used in HWY_DASSERT
+
+  ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
+  DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
+                        env.ctx.pools);
+  SumHeads(layer, activations, env);
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace gcpp
+HWY_AFTER_NAMESPACE();
diff --git a/gemma/attention.h b/gemma/attention.h
new file mode 100644
index 0000000..42b2be1
--- /dev/null
+++ b/gemma/attention.h
@@ -0,0 +1,59 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ATTENTION_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_ATTENTION_H_
+
+// Declares GemmaAttention for all SIMD targets.
+
+#include
+
+#include "gemma/gemma.h"
+#include "hwy/highway.h"
+
+namespace gcpp {
+
+// Passed to HWY_VISIT_TARGETS; declares for one target.
+#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE)                              \
+  namespace NAMESPACE {                                                      \
+  void SingleDotSoftmaxWeightedSum(                                          \
+      const size_t pos, const size_t start_pos, const size_t last_pos,       \
+      float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
+      size_t layer_idx, const LayerWeightsPtrs& layer,                       \
+      const AttentionActivations& activations, float* HWY_RESTRICT att,      \
+      float* HWY_RESTRICT att_out, size_t worker);                           \
+                                                                             \
+  void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx,      \
+                             const LayerWeightsPtrs& layer,                  \
+                             AttentionActivations& activations,              \
+                             QBatch& qbatch, NestedPools& pools);            \
+                                                                             \
+  void GemmaAttention(size_t num_tokens, const size_t layer_idx,             \
+                      const LayerWeightsPtrs& layer,                         \
+                      AttentionActivations& activations, QBatch& qbatch,     \
+                      MatMulEnv& env, int flags);                            \
+  /* NOLINTNEXTLINE(google-readability-namespace-comments) */                \
+  }  // namespace NAMESPACE
+
+// Function declarations for each SIMD target. Allows direct call from the
+// per-target namespace. We may later replace this with dynamic dispatch if
+// the overhead is acceptable.
+HWY_VISIT_TARGETS(GEMMA_DECL_ATTENTION)
+
+#undef GEMMA_DECL_ATTENTION
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_ATTENTION_H_
diff --git a/gemma/bindings/GemmaInterop.cs b/gemma/bindings/GemmaInterop.cs
new file mode 100644
index 0000000..0fb3ee8
--- /dev/null
+++ b/gemma/bindings/GemmaInterop.cs
@@ -0,0 +1,473 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Diagnostics;
+using System.Runtime.InteropServices;
+using System.Text;
+namespace GemmaCpp
+{
+    public class GemmaException : Exception
+    {
+        public GemmaException(string message) : base(message) { }
+    }
+
+    public class Gemma : IDisposable
+    {
+        private IntPtr _context;
+        private bool _disposed;
+
+        // Optional: Allow setting DLL path
+        public static string DllPath { get; set; } = "gemma.dll";
+
+        [DllImport("kernel32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
+        private static extern IntPtr LoadLibrary(string lpFileName);
+
+        static Gemma()
+        {
+            // Load DLL from specified path
+            if (LoadLibrary(DllPath) == IntPtr.Zero)
+            {
+                throw new DllNotFoundException($"Failed to load {DllPath}.
diff --git a/gemma/bindings/GemmaInterop.cs b/gemma/bindings/GemmaInterop.cs
new file mode 100644
index 0000000..0fb3ee8
--- /dev/null
+++ b/gemma/bindings/GemmaInterop.cs
@@ -0,0 +1,473 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Diagnostics;
+using System.Runtime.InteropServices;
+using System.Text;
+namespace GemmaCpp
+{
+    public class GemmaException : Exception
+    {
+        public GemmaException(string message) : base(message) { }
+    }
+
+    public class Gemma : IDisposable
+    {
+        private IntPtr _context;
+        private bool _disposed;
+
+        // Optional: Allow setting DLL path
+        public static string DllPath { get; set; } = "gemma.dll";
+
+        [DllImport("kernel32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
+        private static extern IntPtr LoadLibrary(string lpFileName);
+
+        static Gemma()
+        {
+            // Load DLL from specified path
+            if (LoadLibrary(DllPath) == IntPtr.Zero)
+            {
+                throw new DllNotFoundException($"Failed to load {DllPath}. Error: {Marshal.GetLastWin32Error()}");
+            }
+        }
+
+        // Matches the three-parameter GemmaCreate in c_api.h; the old
+        // modelType/weightType parameters no longer exist in the C API.
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
+        private static extern IntPtr GemmaCreate(
+            [MarshalAs(UnmanagedType.LPUTF8Str)] string tokenizerPath,
+            [MarshalAs(UnmanagedType.LPUTF8Str)] string weightsPath,
+            int maxGeneratedTokens);
+
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
+        private static extern void GemmaDestroy(IntPtr context);
+
+        // Delegate type for token callbacks
+        public delegate bool TokenCallback(string token);
+
+        // Keep delegate alive for duration of calls
+        private GCHandle _callbackHandle;
+
+        [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
+        private delegate bool GemmaTokenCallback(
+            [MarshalAs(UnmanagedType.LPUTF8Str)] string text,
+            IntPtr userData);
+
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
+        private static extern int GemmaGenerate(
+            IntPtr context,
+            [MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
+            [Out] byte[] output,
+            int maxOutputChars,
+            GemmaTokenCallback callback,
+            IntPtr userData);
+
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
+        private static extern int GemmaGenerateMultimodal(
+            IntPtr context,
+            [MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
+            IntPtr image_data,  // Renamed param to match C API
+            int image_width,    // Added dimension
+            int image_height,   // Added dimension
+            [MarshalAs(UnmanagedType.LPUTF8Str)] StringBuilder output, // Output should be StringBuilder for multimodal
+            int maxOutputChars,
+            GemmaTokenCallback callback,
+            IntPtr userData);
+
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
+        private static extern int GemmaCountTokens(
+            IntPtr context,
+            [MarshalAs(UnmanagedType.LPUTF8Str)] string text);
+
+        // Configuration function imports
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
+        private static extern void GemmaSetMaxGeneratedTokens(IntPtr context, int value);
+
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
+        private static extern void GemmaSetMultiturn(IntPtr context, int value);
+
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
+        private static extern void GemmaSetTemperature(IntPtr context, float value);
+
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
+        private static extern void GemmaSetTopK(IntPtr context, int value);
+
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
+        private static extern void GemmaSetDeterministic(IntPtr context, int value);
+
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
+        private static extern void GemmaSetPrefillTbatchSize(IntPtr context, int value);
+
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaResetConversation")]
+        private static extern void GemmaResetConversation(IntPtr context);
+
+        // Conversation management function imports
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaCreateConversation")]
+        private static extern int GemmaCreateConversation(
+            IntPtr context,
+            [MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
+
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaSwitchConversation")]
+        private static extern int GemmaSwitchConversation(
+            IntPtr context,
+            [MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName);
+
+        [DllImport("gemma",
CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaDeleteConversation")] + private static extern int GemmaDeleteConversation( + IntPtr context, + [MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName); + + [DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaHasConversation")] + private static extern int GemmaHasConversation( + IntPtr context, + [MarshalAs(UnmanagedType.LPUTF8Str)] string conversationName); + + [DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaGetCurrentConversation")] + [return: MarshalAs(UnmanagedType.LPUTF8Str)] // Marshal the const char* return value as a string + private static extern string GemmaGetCurrentConversation(IntPtr context); + + [DllImport("gemma", CallingConvention = CallingConvention.Cdecl, EntryPoint = "GemmaSaveConversation")] + private static extern void GemmaSaveConversation(IntPtr context); + + // Native callback delegate type + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + private delegate void GemmaLogCallback( + [MarshalAs(UnmanagedType.LPUTF8Str)] string message, + IntPtr userData); + + [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)] + private static extern void GemmaSetLogCallback( + IntPtr context, + GemmaLogCallback callback, + IntPtr userData); + + private GCHandle _logCallbackHandle; + private bool _loggingEnabled = false; + + public Gemma(string tokenizerPath, string weightsPath, int maxGeneratedTokens = 8192) + { + _context = GemmaCreate(tokenizerPath, weightsPath, maxGeneratedTokens); + if (_context == IntPtr.Zero) + { + throw new GemmaException("Failed to create Gemma context"); + } + } + + // Enable debug logging + public void EnableLogging(bool enable = true) + { + if (enable && !_loggingEnabled) + { + GemmaLogCallback logCallback = (message, _) => + { + Debug.WriteLine($"Gemma: {message}"); + }; + _logCallbackHandle = GCHandle.Alloc(logCallback); + GemmaSetLogCallback(_context, logCallback, IntPtr.Zero); + _loggingEnabled = true; + } + else if (!enable && _loggingEnabled) + { + if (_logCallbackHandle.IsAllocated) + _logCallbackHandle.Free(); + GemmaSetLogCallback(_context, null, IntPtr.Zero); + _loggingEnabled = false; + } + } + + // Configuration methods + public void SetMultiturn(bool enable) + { + if (_disposed) + throw new ObjectDisposedException(nameof(Gemma)); + + if (_context == IntPtr.Zero) + throw new GemmaException("Gemma context is invalid"); + + GemmaSetMultiturn(_context, enable ? 1 : 0); + Debug.WriteLine($"Gemma: Set multiturn to {(enable ? "enabled" : "disabled")}"); + } + + public void SetTemperature(float temperature) + { + if (_disposed) + throw new ObjectDisposedException(nameof(Gemma)); + + if (_context == IntPtr.Zero) + throw new GemmaException("Gemma context is invalid"); + + GemmaSetTemperature(_context, temperature); + Debug.WriteLine($"Gemma: Set temperature to {temperature}"); + } + + public void SetTopK(int topK) + { + if (_disposed) + throw new ObjectDisposedException(nameof(Gemma)); + + if (_context == IntPtr.Zero) + throw new GemmaException("Gemma context is invalid"); + + GemmaSetTopK(_context, topK); + Debug.WriteLine($"Gemma: Set topK to {topK}"); + } + + public void SetDeterministic(bool deterministic) + { + if (_disposed) + throw new ObjectDisposedException(nameof(Gemma)); + + if (_context == IntPtr.Zero) + throw new GemmaException("Gemma context is invalid"); + + GemmaSetDeterministic(_context, deterministic ? 1 : 0); + Debug.WriteLine($"Gemma: Set deterministic to {(deterministic ? 
"true" : "false")}"); + } + + // Renamed public method + public void ResetConversation() + { + if (_disposed) + throw new ObjectDisposedException(nameof(Gemma)); + + if (_context == IntPtr.Zero) + throw new GemmaException("Gemma context is invalid"); + + GemmaResetConversation(_context); // Call P/Invoke method + Debug.WriteLine("Gemma: Reset active conversation"); + } + + // Conversation management methods + public bool CreateConversation(string conversationName) + { + if (_disposed) + throw new ObjectDisposedException(nameof(Gemma)); + + if (_context == IntPtr.Zero) + throw new GemmaException("Gemma context is invalid"); + + bool result = GemmaCreateConversation(_context, conversationName) != 0; // Call P/Invoke method + Debug.WriteLine($"Gemma: Create conversation '{conversationName}' - {(result ? "succeeded" : "failed")}"); + return result; + } + + public bool SwitchConversation(string conversationName) + { + if (_disposed) + throw new ObjectDisposedException(nameof(Gemma)); + + if (_context == IntPtr.Zero) + throw new GemmaException("Gemma context is invalid"); + + bool result = GemmaSwitchConversation(_context, conversationName) != 0; // Call P/Invoke method + Debug.WriteLine($"Gemma: Switch to conversation '{conversationName}' - {(result ? "succeeded" : "failed")}"); + return result; + } + + public bool DeleteConversation(string conversationName) + { + if (_disposed) + throw new ObjectDisposedException(nameof(Gemma)); + + if (_context == IntPtr.Zero) + throw new GemmaException("Gemma context is invalid"); + + bool result = GemmaDeleteConversation(_context, conversationName) != 0; // Call P/Invoke method + Debug.WriteLine($"Gemma: Delete conversation '{conversationName}' - {(result ? "succeeded" : "failed")}"); + return result; + } + + public bool HasConversation(string conversationName) + { + if (_disposed) + throw new ObjectDisposedException(nameof(Gemma)); + + if (_context == IntPtr.Zero) + throw new GemmaException("Gemma context is invalid"); + + bool result = GemmaHasConversation(_context, conversationName) != 0; // Call P/Invoke method + Debug.WriteLine($"Gemma: Has conversation '{conversationName}' - {result}"); + return result; + } + + public string GetCurrentConversation() + { + if (_disposed) + throw new ObjectDisposedException(nameof(Gemma)); + + if (_context == IntPtr.Zero) + throw new GemmaException("Gemma context is invalid"); + + string currentConversation = GemmaGetCurrentConversation(_context); // Call P/Invoke method + Debug.WriteLine($"Gemma: Current conversation is '{currentConversation}'"); + return currentConversation; + } + + public void SaveConversation() + { + if (_disposed) + throw new ObjectDisposedException(nameof(Gemma)); + + if (_context == IntPtr.Zero) + throw new GemmaException("Gemma context is invalid"); + + GemmaSaveConversation(_context); + Debug.WriteLine($"Gemma: Saved current conversation ('{GetCurrentConversation()}') to prewarmed cache."); + } + + public int CountTokens(string prompt) + { + if (_disposed) + throw new ObjectDisposedException(nameof(Gemma)); + + if (_context == IntPtr.Zero) + throw new GemmaException("Gemma context is invalid"); + int count = GemmaCountTokens(_context, prompt); + return count; + } + + public string Generate(string prompt, int maxOutputChars = 4096) + { + return Generate(prompt, null, maxOutputChars); + } + + public string Generate(string prompt, TokenCallback callback, int maxOutputChars = 4096) + { + if (_disposed) + throw new ObjectDisposedException(nameof(Gemma)); + + if (_context == IntPtr.Zero) + throw new 
GemmaException("Gemma context is invalid"); + + var outputBuffer = new byte[maxOutputChars * 4]; // Allow for worst case UTF-8 size + GemmaTokenCallback nativeCallback = null; + + // Track token count for debugging + int tokenCount = 0; + + if (callback != null) + { + nativeCallback = (text, _) => + { + tokenCount++; + // Log token for debugging + Debug.WriteLine($"Token {tokenCount}: '{text}'"); + + // Pass token to user callback + return callback(text); + }; + _callbackHandle = GCHandle.Alloc(nativeCallback); + } + + try + { + int length = GemmaGenerate(_context, prompt, outputBuffer, maxOutputChars, + nativeCallback, IntPtr.Zero); + + if (length < 0) + throw new GemmaException("Generation failed"); + + Debug.WriteLine($"Generation complete: {tokenCount} tokens processed, result length: {length}"); + + // Convert the byte buffer to a string using UTF-8 encoding + string result = Encoding.UTF8.GetString(outputBuffer, 0, length); + return result; + } + finally + { + if (_callbackHandle.IsAllocated) + _callbackHandle.Free(); + } + } + + public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, int maxOutputChars = 4096) + { + // Pass width and height to the overloaded method + return GenerateMultimodal(prompt, imageData, imageWidth, imageHeight, null, maxOutputChars); + } + + public string GenerateMultimodal(string prompt, float[] imageData, int imageWidth, int imageHeight, TokenCallback callback, int maxOutputChars = 4096) + { + if (_disposed) + throw new ObjectDisposedException(nameof(Gemma)); + + if (_context == IntPtr.Zero) + throw new GemmaException("Gemma context is invalid"); + + if (imageData == null || imageData.Length == 0) + throw new ArgumentException("Image data cannot be null or empty", nameof(imageData)); + + if (imageWidth <= 0 || imageHeight <= 0) + throw new ArgumentException("Image dimensions must be positive"); + + if (imageData.Length < imageWidth * imageHeight * 3) + throw new ArgumentException("Image data array is too small for the specified dimensions"); + + var output = new StringBuilder(maxOutputChars); + GemmaTokenCallback nativeCallback = null; + + if (callback != null) + { + nativeCallback = (text, _) => callback(text); + _callbackHandle = GCHandle.Alloc(nativeCallback); + } + + // Pin the image data so it doesn't move during the native call + GCHandle imageHandle = GCHandle.Alloc(imageData, GCHandleType.Pinned); + + try + { + IntPtr imagePtr = imageHandle.AddrOfPinnedObject(); + + // Pass image dimensions to the native call + int length = GemmaGenerateMultimodal(_context, prompt, imagePtr, imageWidth, imageHeight, output, maxOutputChars, + nativeCallback, IntPtr.Zero); + + if (length < 0) + throw new GemmaException("Multimodal generation failed"); + + return output.ToString(); + } + finally + { + imageHandle.Free(); + + if (_callbackHandle.IsAllocated) + _callbackHandle.Free(); + } + } + + public void Dispose() + { + if (!_disposed) + { + if (_context != IntPtr.Zero) + { + GemmaDestroy(_context); + _context = IntPtr.Zero; + } + if (_logCallbackHandle.IsAllocated) + _logCallbackHandle.Free(); + _disposed = true; + } + } + + ~Gemma() + { + Dispose(); + } + } +} diff --git a/gemma/bindings/c_api.cc b/gemma/bindings/c_api.cc new file mode 100644 index 0000000..cba2ffb --- /dev/null +++ b/gemma/bindings/c_api.cc @@ -0,0 +1,139 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with 
the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GEMMA_EXPORTS +#define GEMMA_EXPORTS +#endif + +#include "gemma/bindings/c_api.h" + +extern "C" { + +GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path, + const char* weights_path, + int max_generated_tokens) { + try { + GemmaContext* ctx = GemmaContext::Create(tokenizer_path, weights_path, + max_generated_tokens); + return ctx; + } catch (...) { + return nullptr; + } +} + +GEMMA_API void GemmaDestroy(GemmaContext* ctx) { + delete ctx; +} + +GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output, + int max_output_chars, GemmaTokenCallback callback, + void* user_data) { + if (!ctx) return -1; + return ctx->Generate(prompt, output, max_output_chars, callback, user_data); +} + +GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt, + const void* image_data, int image_width, + int image_height, char* output, + int max_output_chars, + GemmaTokenCallback callback, + void* user_data) { + if (!ctx) return -1; + + return ctx->GenerateMultimodal(prompt, image_data, image_width, image_height, + output, max_output_chars, callback, user_data); +} + +GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text) { + if (!ctx || !text) return -1; + return ctx->CountTokens(text); +} + +GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback, + void* user_data) { + if (!ctx) return; + ctx->SetLogCallback(callback, user_data); +} + +// Configuration functions implementation +GEMMA_API void GemmaSetMaxGeneratedTokens(GemmaContext* ctx, int value) { + if (!ctx) return; + ctx->SetMaxGeneratedTokens(value); +} + +GEMMA_API void GemmaSetMultiturn(GemmaContext* ctx, int value) { + if (!ctx) return; + ctx->SetMultiturn(value); +} + +GEMMA_API void GemmaSetTemperature(GemmaContext* ctx, float value) { + if (!ctx) return; + ctx->SetTemperature(value); +} + +GEMMA_API void GemmaSetTopK(GemmaContext* ctx, int value) { + if (!ctx) return; + ctx->SetTopK(value); +} + +GEMMA_API void GemmaSetDeterministic(GemmaContext* ctx, int value) { + if (!ctx) return; + ctx->SetDeterministic(value != 0); +} + +GEMMA_API void GemmaSetPrefillTbatchSize(GemmaContext* ctx, int value) { + if (!ctx) return; + ctx->SetPrefillTbatchSize(value); +} + +GEMMA_API void GemmaResetConversation(GemmaContext* ctx) { // Renamed function + if (!ctx) return; + ctx->ResetConversation(); +} + +GEMMA_API int GemmaCreateConversation(GemmaContext* ctx, + const char* conversation_name) { + if (!ctx || !conversation_name) return 0; + return ctx->CreateConversation(conversation_name) ? 1 : 0; +} + +GEMMA_API int GemmaSwitchConversation(GemmaContext* ctx, + const char* conversation_name) { + if (!ctx || !conversation_name) return 0; + return ctx->SwitchConversation(conversation_name) ? 1 : 0; +} + +GEMMA_API int GemmaDeleteConversation(GemmaContext* ctx, + const char* conversation_name) { + if (!ctx || !conversation_name) return 0; + return ctx->DeleteConversation(conversation_name) ? 
1 : 0;
+}
+
+GEMMA_API int GemmaHasConversation(GemmaContext* ctx,
+                                   const char* conversation_name) {
+  if (!ctx || !conversation_name) return 0;
+  return ctx->HasConversation(conversation_name) ? 1 : 0;
+}
+
+GEMMA_API const char* GemmaGetCurrentConversation(GemmaContext* ctx) {
+  if (!ctx) return nullptr;
+  return ctx->GetCurrentConversation();
+}
+
+GEMMA_API void GemmaSaveConversation(GemmaContext* ctx) {
+  if (!ctx) return;
+  ctx->SaveConversation();
+}
+}
diff --git a/gemma/bindings/c_api.h b/gemma/bindings/c_api.h
new file mode 100644
index 0000000..6d369b8
--- /dev/null
+++ b/gemma/bindings/c_api.h
@@ -0,0 +1,88 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_C_API_H_
+#define THIRD_PARTY_GEMMA_C_API_H_
+
+#include "gemma/bindings/context.h"
+
+#ifdef _WIN32
+#ifdef GEMMA_EXPORTS
+#define GEMMA_API __declspec(dllexport)
+#else
+#define GEMMA_API __declspec(dllimport)
+#endif
+#else
+#define GEMMA_API __attribute__((visibility("default")))
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#ifdef __cplusplus
+typedef gcpp::GemmaContext GemmaContext;
+#else
+typedef struct GemmaContext GemmaContext;
+#endif
+
+typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
+typedef void (*GemmaLogCallback)(const char* message, void* user_data);
+
+GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
+                                    const char* weights_path,
+                                    int max_generated_tokens);
+GEMMA_API void GemmaDestroy(GemmaContext* ctx);
+GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
+                            int max_output_chars, GemmaTokenCallback callback,
+                            void* user_data);
+GEMMA_API int GemmaGenerateMultimodal(GemmaContext* ctx, const char* prompt,
+                                      const void* image_data, int image_width,
+                                      int image_height, char* output,
+                                      int max_output_chars,
+                                      GemmaTokenCallback callback,
+                                      void* user_data);
+
+GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text);
+
+GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback,
+                                   void* user_data);
+
+// Configuration functions (one declaration per setter defined in c_api.cc)
+GEMMA_API void GemmaSetMaxGeneratedTokens(GemmaContext* ctx, int value);
+GEMMA_API void GemmaSetMultiturn(GemmaContext* ctx, int value);
+GEMMA_API void GemmaSetTemperature(GemmaContext* ctx, float value);
+GEMMA_API void GemmaSetTopK(GemmaContext* ctx, int value);
+GEMMA_API void GemmaSetDeterministic(GemmaContext* ctx, int value);
+GEMMA_API void GemmaSetPrefillTbatchSize(GemmaContext* ctx, int value);
+GEMMA_API void GemmaResetConversation(GemmaContext* ctx);
+
+// Conversation management functions (renamed)
+GEMMA_API int GemmaCreateConversation(GemmaContext* ctx,
+                                      const char* conversation_name);
+GEMMA_API int GemmaSwitchConversation(GemmaContext* ctx,
+                                      const char* conversation_name);
+GEMMA_API int GemmaDeleteConversation(GemmaContext* ctx,
+                                      const char* conversation_name);
+GEMMA_API int GemmaHasConversation(GemmaContext* ctx,
+                                   const char* conversation_name);
+GEMMA_API const char* GemmaGetCurrentConversation(GemmaContext* ctx);
+GEMMA_API void GemmaSaveConversation(GemmaContext* ctx);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // THIRD_PARTY_GEMMA_C_API_H_
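For orientation, a minimal sketch of driving this C API end to end; the file paths are placeholders and error handling is reduced to early returns:

```cpp
#include <cstdio>

#include "gemma/bindings/c_api.h"

// Streams each token fragment to stdout; returning false would stop generation.
static bool OnToken(const char* text, void* /*user_data*/) {
  std::fputs(text, stdout);
  return true;
}

int main() {
  GemmaContext* ctx =
      GemmaCreate("tokenizer.spm", "weights.sbs", /*max_generated_tokens=*/256);
  if (!ctx) return 1;
  GemmaSetTemperature(ctx, 0.7f);
  char output[4096];
  const int len = GemmaGenerate(ctx, "Explain KV caches briefly.", output,
                                static_cast<int>(sizeof(output)), OnToken,
                                nullptr);
  if (len >= 0) std::printf("\n[%d chars]\n", len);
  GemmaDestroy(ctx);
  return 0;
}
```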
diff --git a/gemma/bindings/context.cc b/gemma/bindings/context.cc
new file mode 100644
index 0000000..76ebe1e
--- /dev/null
+++ b/gemma/bindings/context.cc
@@ -0,0 +1,348 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gemma/bindings/context.h"
+
+#include <stddef.h>
+#include <string.h>  // strncpy
+
+#include <memory>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "evals/benchmark_helper.h"  // InitGenerator
+#include "gemma/gemma.h"
+#include "gemma/gemma_args.h"
+#include "gemma/tokenizer.h"  // WrapAndTokenize
+#include "util/threading.h"
+#include "util/threading_context.h"
+#include "hwy/profiler.h"
+#include "hwy/timer.h"
+
+#ifdef _WIN32
+#include <windows.h>
+#endif
+
+#include "gemma/kv_cache.h"
+#include "paligemma/image.h"
+
+namespace gcpp {
+
+// ConversationData constructor implementation
+ConversationData::ConversationData(const ModelConfig& model_config,
+                                   const InferenceArgs& inference_args,
+                                   const Allocator& allocator)
+    : kv_cache(
+          std::make_unique<KVCache>(model_config, inference_args, allocator)),
+      abs_pos(0) {}
+
+// ConversationData copy constructor implementation
+ConversationData::ConversationData(const ConversationData& other)
+    : kv_cache(nullptr), abs_pos(other.abs_pos) {
+  if (other.kv_cache) {
+    kv_cache = std::make_unique<KVCache>(other.kv_cache->Copy());
+  }
+}
+
+// Initialize static members
+GemmaLogCallback GemmaContext::s_log_callback = nullptr;
+void* GemmaContext::s_log_user_data = nullptr;
+
+GemmaContext* GemmaContext::Create(const char* tokenizer_path,
+                                   const char* weights_path,
+                                   int max_generated_tokens) {
+  std::stringstream ss;
+  ss << "Creating GemmaContext with tokenizer_path: "
+     << (tokenizer_path ? tokenizer_path : "null")
+     << ", weights_path: " << (weights_path ? weights_path : "null")
+     << ", max_generated_tokens: " << max_generated_tokens;
+  LogDebug(ss.str().c_str());
+
+  ThreadingArgs threading_args;
+  threading_args.spin = gcpp::Tristate::kFalse;
+
+  LoaderArgs loader(tokenizer_path, weights_path);
+  LogDebug("LoaderArgs created");
+
+  // Initialize cached args
+  LogDebug("Initializing inference args");
+  InferenceArgs inference_args;
+  inference_args.Init();
+  inference_args.max_generated_tokens = max_generated_tokens;
+  inference_args.temperature = 0.7f;
+  inference_args.top_k = 1;
+  inference_args.deterministic = false;
+
+  ss.str("");
+  ss << "Inference args initialized with max_tokens: " << max_generated_tokens
+     << ", temperature: " << inference_args.temperature
+     << ", top_k: " << inference_args.top_k << ", deterministic: "
+     << (inference_args.deterministic ? "true" : "false");
"true" : "false"); + LogDebug(ss.str().c_str()); + + return new GemmaContext(loader, inference_args, threading_args, + max_generated_tokens); +} + +GemmaContext::GemmaContext(const LoaderArgs& loader, + const InferenceArgs& inference_args, + const ThreadingArgs& threading_args, + int max_generated_tokens) + : inference_args(inference_args), + threading_args(threading_args), + ctx(UpdateArgs(threading_args, inference_args)), + matmul_env(ctx), + active_conversation_name("default"), + model(loader, inference_args, matmul_env.ctx) { + std::stringstream ss; + + LogDebug("Creating initial ConversationData"); + // Create the initial ConversationData object using make_shared + active_conversation = std::make_shared( + model.Config(), inference_args, ctx.allocator); + + LogDebug( + "Storing initial ConversationData in conversation_cache[\"default\"]"); + // Store the shared_ptr in the map under the "default" key + conversation_cache["default"] = active_conversation; + + LogDebug("GemmaContext constructor completed"); +} + +// Internal implementation shared by Generate and GenerateMultimodal +int GemmaContext::GenerateInternal(const char* prompt_string, + const void* image_data, int image_width, + int image_height, char* output, + int max_output_chars, + GemmaTokenCallback callback, + void* user_data) { + PROFILER_ZONE("Gen.Internal"); + size_t tokens_generated_this_turn = 0; // differentiates prefill from reply + size_t prompt_size = 0; + std::stringstream ss; + result_buffer.clear(); + + InitGenerator(inference_args, gen); + + // Ensure we have an active conversation + if (!active_conversation || !active_conversation->kv_cache) { + LogDebug("Generate called with null active_conversation or kv_cache"); + return -1; + } + + // callback function invoked for each generated token. + auto stream_token = [&, callback, user_data](int token, float) { + // Use abs_pos from the active conversation + ++(active_conversation->abs_pos); + const bool in_prompt = tokens_generated_this_turn < prompt_size; + const bool first_response_token = tokens_generated_this_turn == prompt_size; + ++tokens_generated_this_turn; + if (in_prompt || model.Config().IsEOS(token)) { + return true; + } + + std::string token_text; + HWY_ASSERT(model.Tokenizer().Decode(std::vector{token}, &token_text)); + if (first_response_token) { + token_text.erase(0, token_text.find_first_not_of(" \t\n")); + } + + // if we have a managed callback, pass it the token text + if (callback) { + if (!callback(token_text.c_str(), user_data)) { + LogDebug("Callback returned false, stopping generation"); + return false; + } + } + + result_buffer.append(token_text); + return true; + }; + + // set up runtime config + TimingInfo timing_info = {}; + RuntimeConfig runtime_config = {.gen = &gen, + .stream_token = stream_token, + .use_spinning = threading_args.spin}; + inference_args.CopyTo(runtime_config); + size_t prefix_end = 0; + + const ModelConfig& model_config = model.Config(); + + // generate + std::vector prompt; + const size_t pool_dim = model_config.vit_config.pool_dim; + ImageTokens image_tokens( + "image_tokens", + image_data + ? 
+  // generate
+  std::vector<int> prompt;
+  const size_t pool_dim = model_config.vit_config.pool_dim;
+  ImageTokens image_tokens(
+      "image_tokens",
+      image_data
+          ? Extents2D(model_config.vit_config.seq_len / (pool_dim * pool_dim),
+                      model_config.model_dim)
+          : Extents2D(0, 0),
+      ctx.allocator, MatPadding::kOdd);
+  if (image_data != nullptr) {
+    HWY_ASSERT(model_config.wrapping == PromptWrapping::PALIGEMMA ||
+               model_config.wrapping == PromptWrapping::GEMMA_VLM);
+
+    Image image;
+    image.Set(image_width, image_height, static_cast<const float*>(image_data));
+
+    // We may need to resize the supplied image depending on whether we're using
+    // PaliGemma or Gemma 3.
+    const size_t image_size = model_config.vit_config.image_size;
+    image.Resize(image_size, image_size);
+
+    // Reuse the runtime_config defined earlier in this function.
+    double image_tokens_start = hwy::platform::Now();
+    // Pass the populated image object to GenerateImageTokens
+    model.GenerateImageTokens(runtime_config,
+                              active_conversation->kv_cache->SeqLen(), image,
+                              image_tokens, matmul_env);
+    double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
+
+    ss.str("");
+    ss << "\n\n[ Timing info ] Image token generation took: ";
+    ss << static_cast<int>(image_tokens_duration * 1000) << " ms\n";
+    LogDebug(ss.str().c_str());
+
+    prompt = WrapAndTokenize(
+        model.Tokenizer(), model.ChatTemplate(), model_config.wrapping,
+        active_conversation->abs_pos, prompt_string, image_tokens.Rows());
+    runtime_config.image_tokens = &image_tokens;
+    prompt_size = prompt.size();
+    // The end of the prefix for prefix-LM style attention in Paligemma.
+    // See Figure 2 of https://arxiv.org/abs/2407.07726.
+    prefix_end = prompt_size;
+  } else {
+    // Text-only case (original logic)
+    // Use abs_pos from the active conversation
+    prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
+                             model_config.wrapping,
+                             active_conversation->abs_pos, prompt_string);
+    prompt_size = prompt.size();
+  }
+
+  // Check if prompt generation failed (e.g., multimodal not implemented yet)
+  if (prompt.empty() && image_data != nullptr) {
+    // Already logged the error, just ensure we don't proceed.
+    return -1;
+  }
+
+  // Create a span from the prompt vector - Generate() expects a hwy::Span,
+  // which has a different memory layout than std::vector.
+  hwy::Span<const int> prompt_span(prompt.data(), prompt.size());
+
+  // Pass the KVCache object by reference from the active conversation
+  model.Generate(runtime_config, prompt_span, active_conversation->abs_pos,
+                 prefix_end, *active_conversation->kv_cache, matmul_env,
+                 timing_info);
+
+  // prepare for next turn
+  if (!inference_args.multiturn ||
+      model_config.wrapping == PromptWrapping::PALIGEMMA) {
+    // If not multiturn, or Paligemma (which handles turns differently),
+    // reset the *active* conversation's position.
+    active_conversation->abs_pos = 0;
+    InitGenerator(inference_args, gen);
+  } else {
+    // Multi-turn Gemma: Rewind position in the active conversation.
+    // The last token was either EOS, in which case it should be ignored
+    // because it is never part of the dialog; see Table 5 in the Gemma-2
+    // paper: https://arxiv.org/pdf/2408.00118
+    // Or we have hit max_generated_tokens, in which case the last token will
+    // be lost. (We could store it in stream_token and prepend it to the next
+    // turn, but that is not worth the complexity, as multi-turn with
+    // max_generated is not a common use case.)
+    // In either case, we need to rewind the active conversation's abs_pos by
+    // one.
+    HWY_ASSERT(active_conversation->abs_pos > 0);
+    active_conversation->abs_pos--;
+  }
+
+  // Copy result buffer to output C-string (ensure null termination)
+  strncpy(output, result_buffer.c_str(), max_output_chars - 1);
+  output[max_output_chars - 1] = '\0';
+
+  return static_cast<int>(strlen(output));
+}
+
+// Public Generate method (wrapper for text-only)
+int GemmaContext::Generate(const char* prompt_string, char* output,
+                           int max_output_chars, GemmaTokenCallback callback,
+                           void* user_data) {
+  // Call the internal implementation with null image_data and 0 dimensions
+  return GenerateInternal(prompt_string, nullptr, 0, 0, output,
+                          max_output_chars, callback, user_data);
+}
+
+// Public GenerateMultimodal method (wrapper)
+int GemmaContext::GenerateMultimodal(const char* prompt_string,
+                                     const void* image_data, int image_width,
+                                     int image_height, char* output,
+                                     int max_output_chars,
+                                     GemmaTokenCallback callback,
+                                     void* user_data) {
+  if (image_data == nullptr) {
+    LogDebug(
+        "GenerateMultimodal called with null image_data. Use Generate for "
+        "text-only.");
+    // Or potentially call GenerateInternal with null image_data anyway?
+    // Returning an error seems safer.
+    return -1;
+  }
+
+  return GenerateInternal(prompt_string, image_data, image_width, image_height,
+                          output, max_output_chars, callback, user_data);
+}
+
+int GemmaContext::CountTokens(const char* text) {
+  LogDebug("CountTokens method started");
+  std::stringstream ss;
+  ss << "CountTokens called with text: '" << (text ? text : "null") << "'";
+  LogDebug(ss.str().c_str());
+
+  if (!text) {
+    LogDebug("CountTokens failed: text is null");
+    return -1;
+  }
+
+  try {
+    LogDebug("Creating text string");
+    std::string text_str(text);
+
+    LogDebug("Creating tokens vector");
+    std::vector<int> tokens;
+
+    LogDebug("Encoding text to tokens");
+    HWY_ASSERT(model.Tokenizer().Encode(text_str, &tokens));
+
+    ss.str("");
+    ss << "Text tokenized into " << tokens.size() << " tokens";
+    LogDebug(ss.str().c_str());
+
+    LogDebug("CountTokens completed successfully");
+    return static_cast<int>(tokens.size());
+  } catch (...) {
+    LogDebug("Unknown exception in CountTokens");
+    return -1;
+  }
+}
+
+// Get the name of the currently active conversation
+const char* GemmaContext::GetCurrentConversation() {
+  return active_conversation_name.c_str();
+}
+
+}  // namespace gcpp
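The `stream_token` lambda above counts callbacks to tell prefill apart from the reply: the engine streams the prompt's own tokens back first, so the first `prompt_size` invocations are skipped and only later tokens reach the callback and `result_buffer`. A self-contained sketch of that accounting, with a fake token feed in place of the model:

```cpp
#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

int main() {
  // Pretend the engine re-streams 3 prompt tokens, then the reply.
  const std::vector<std::string> stream = {"Write", " a", " poem",
                                           "Roses", " are", " red"};
  const size_t prompt_size = 3;
  size_t seen = 0;
  std::string reply;
  for (const std::string& tok : stream) {
    const bool in_prompt = seen < prompt_size;
    ++seen;
    if (in_prompt) continue;  // prefill echo: not part of the reply
    reply += tok;
  }
  std::printf("reply:%s\n", reply.c_str());  // "Roses are red"
  return 0;
}
```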
diff --git a/gemma/bindings/context.h b/gemma/bindings/context.h
new file mode 100644
index 0000000..859a644
--- /dev/null
+++ b/gemma/bindings/context.h
@@ -0,0 +1,316 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
+
+#include <memory>  // For std::shared_ptr, std::make_shared
+#include <random>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+// Logging
+#ifdef _WIN32
+#include <windows.h>
+#else
+#include <stdio.h>
+#endif
+
+#include "gemma/gemma.h"
+#include "gemma/gemma_args.h"
+#include "gemma/kv_cache.h"
+#include "ops/matmul.h"  // MatMulEnv
+#include "hwy/base.h"
+#include "hwy/highway.h"
+
+namespace gcpp {
+
+// Struct to hold data for a single conversation thread
+struct ConversationData {
+  ConversationData(const ModelConfig& model_config,
+                   const InferenceArgs& inference_args,
+                   const Allocator& allocator);
+  ConversationData(const ConversationData& other);
+
+  std::unique_ptr<KVCache> kv_cache;
+  size_t abs_pos = 0;
+};
+
+typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
+typedef void (*GemmaLogCallback)(const char* message, void* user_data);
+
+class GemmaContext {
+ private:
+  GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args,
+               const ThreadingArgs& threading_args, int max_generated_tokens);
+
+ public:
+  static GemmaContext* Create(const char* tokenizer_path,
+                              const char* weights_path,
+                              int max_generated_tokens);
+
+  // Returns length of generated text, or -1 on error
+  int Generate(const char* prompt_string, char* output, int max_output_chars,
+               GemmaTokenCallback callback, void* user_data);
+  // Returns length of generated text, or -1 on error
+  int GenerateMultimodal(const char* prompt_string, const void* image_data,
+                         int image_width, int image_height, char* output,
+                         int max_output_chars, GemmaTokenCallback callback,
+                         void* user_data);
+
+  // Returns number of tokens in text, or -1 on error
+  int CountTokens(const char* text);
+
+  // Add new method to set logger
+  static void SetLogCallback(GemmaLogCallback callback, void* user_data) {
+    s_log_callback = callback;
+    s_log_user_data = user_data;
+  }
+
+  // Set max generated tokens
+  void SetMaxGeneratedTokens(size_t value) {
+    inference_args.max_generated_tokens = value;
+    LogDebug("Setting max_generated_tokens to configured value");
+  }
+
+  // Set multiturn flag (0 = disabled, 1 = enabled)
+  void SetMultiturn(int value) {
+    inference_args.multiturn = value;
+    LogDebug("Setting multiturn to configured value");
+  }
+
+  // Set temperature for token generation
+  void SetTemperature(float value) {
+    inference_args.temperature = value;
+    LogDebug("Setting temperature to configured value");
+  }
+
+  // Set top_k parameter for sampling
+  void SetTopK(int value) {
+    inference_args.top_k = value;
+    LogDebug("Setting top_k to configured value");
+  }
+
+  // Set deterministic flag
+  void SetDeterministic(bool value) {
+    inference_args.deterministic = value;
+    // Reset the random number generator for deterministic generation
+    if (value) {
+      gen.seed(0x87654321);
+    }
+    LogDebug("Setting deterministic flag to configured value");
+  }
+
+  // Set prefill_tbatch_size
+  void SetPrefillTbatchSize(size_t value) {
+    inference_args.prefill_tbatch_size = value;
+    LogDebug("Setting prefill_tbatch_size to configured value");
+  }
+
Cannot " + "save."); + } + return; + } + std::string log_msg = "SaveConversation: Attempting to save '"; + log_msg += active_conversation_name; + log_msg += "' to prewarmed_cache."; + LogDebug(log_msg.c_str()); + + // Create a deep copy of the active_conversation via copy ctor. + auto conversation_copy = + std::make_shared(*active_conversation); + + // Store the deep copy in prewarmed_cache. + // If a conversation with the same name already exists, it will be + // overwritten. std::shared_ptr will handle the destruction of the old + // object if it's being replaced. + prewarmed_cache[active_conversation_name] = conversation_copy; + + log_msg = "SaveConversation: Successfully saved '"; + log_msg += active_conversation_name; + log_msg += "' to prewarmed_cache."; + LogDebug(log_msg.c_str()); + } + + // Reset the currently active conversation + void ResetConversation() { + if (active_conversation) { + std::string log_prefix = "ResetConversation ('"; + log_prefix += active_conversation_name.empty() ? "[unnamed]" + : active_conversation_name; + log_prefix += "'): "; + LogDebug((log_prefix + "Attempting to reset.").c_str()); + // Attempt to restore from prewarmed_cache first, regardless of name. + auto it = prewarmed_cache.find(active_conversation_name); + if (it != prewarmed_cache.end() && it->second && it->second->kv_cache) { + // Found in prewarmed_cache and the cached entry is valid. + LogDebug((log_prefix + "Found in prewarmed_cache. Restoring state.") + .c_str()); + active_conversation->abs_pos = it->second->abs_pos; + // Perform a deep copy of the KVCache from the prewarmed version. + active_conversation->kv_cache = + std::make_unique(it->second->kv_cache->Copy()); + LogDebug((log_prefix + "Successfully restored from prewarmed_cache.") + .c_str()); + return; + } + + // If not found in prewarmed_cache or prewarmed_cache entry is invalid, + // rewind to initial state. 
+      active_conversation->abs_pos = 0;
+      // Replace the cache within the current ConversationData object
+      active_conversation->kv_cache = std::make_unique<KVCache>(
+          model.Config(), inference_args, ctx.allocator);
+
+      LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
+    } else {
+      LogDebug("Cannot reset conversation: active_conversation is null");
+    }
+  }
+
+  // Create a new named conversation
+  bool CreateConversation(const char* conversation_name) {
+    std::string name(conversation_name);
+    if (conversation_cache.count(name)) {
+      LogDebug("Conversation already exists");
+      return false;
+    }
+    LogDebug("Creating new conversation");
+    // Create a new ConversationData object using make_shared
+    conversation_cache[name] = std::make_shared<ConversationData>(
+        model.Config(), inference_args, ctx.allocator);
+    return true;
+  }
+
+  // Switch to a named conversation
+  bool SwitchConversation(const char* conversation_name) {
+    std::string name(conversation_name);
+    auto it = conversation_cache.find(name);
+    if (it == conversation_cache.end()) {
+      LogDebug("Conversation not found");
+      return false;
+    }
+    LogDebug("Switching active conversation");
+    active_conversation = it->second;
+    active_conversation_name = conversation_name;
+    return true;
+  }
+
+  // Delete a named conversation
+  bool DeleteConversation(const char* conversation_name) {
+    std::string name(conversation_name);
+    auto it = conversation_cache.find(name);
+
+    if (it == conversation_cache.end()) {
+      LogDebug("Conversation not found for deletion");
+      return false;
+    }
+    if (name == "default") {
+      LogDebug("Cannot delete the default conversation");
+      return false;
+    }
+    if (it->second == active_conversation) {
+      LogDebug("Cannot delete the currently active conversation");
+      return false;
+    }
+
+    LogDebug("Deleting conversation");
+    conversation_cache.erase(it);
+
+    auto it2 = prewarmed_cache.find(name);
+    if (it2 != prewarmed_cache.end()) {
+      prewarmed_cache.erase(it2);
+    }
+
+    return true;
+  }
+
+  // Check if a named conversation exists
+  bool HasConversation(const char* conversation_name) {
+    std::string name(conversation_name);
+    return conversation_cache.count(name) != 0;
+  }
+
+  // Get the name of the currently active conversation
+  const char* GetCurrentConversation();
+
+ private:
+  // Internal implementation shared by Generate and GenerateMultimodal
+  int GenerateInternal(const char* prompt_string,
+                       const void* image_data,  // Null for text-only generation
+                       int image_width, int image_height, char* output,
+                       int max_output_chars, GemmaTokenCallback callback,
+                       void* user_data);
+
+  // Pointer to the currently active conversation's data
+  std::shared_ptr<ConversationData> active_conversation;
+
+  // Cache of all named conversations
+  std::unordered_map<std::string, std::shared_ptr<ConversationData>>
+      conversation_cache;
+  std::unordered_map<std::string, std::shared_ptr<ConversationData>>
+      prewarmed_cache;
+
+  // Buffers (potentially could be moved into ConversationData if needed
+  // per-conversation)
+  std::string prompt_buffer;
+  std::string result_buffer;
+  std::vector<int> token_buffer;
+
+  // Cached args (remain global for the context)
+  InferenceArgs inference_args;
+  ThreadingArgs threading_args;
+  ThreadingContext ctx;
+  MatMulEnv matmul_env;
+
+  std::string active_conversation_name;
+
+  // Model itself (don't move this, needs to be below the args above)
+  Gemma model;
+
+  // Random generator (remains global for the context)
+  std::mt19937 gen;
+
+  // Static members for logging
+  static GemmaLogCallback s_log_callback;
+  static void* s_log_user_data;
+
+  // Logging helper: routes messages to the managed callback if one is set,
+  // otherwise falls back to the platform default below.
+  static void LogDebug(const char* message) {
+    if (s_log_callback != nullptr) {
+      s_log_callback(message, s_log_user_data);
+    } else {
+#ifdef _WIN32
+      OutputDebugStringA(message);
+#else
+      printf("%s", message);
+#endif
+    }
+  }
+};
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
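A sketch of the conversation lifecycle these methods expose, written against the C wrappers from c_api.h; the names are illustrative. Each named conversation owns its own KVCache and position, so switching is cheap and histories stay independent; `GemmaSaveConversation` snapshots the active state so a later `GemmaResetConversation` restores that snapshot instead of an empty cache:

```cpp
#include "gemma/bindings/c_api.h"

void ConversationDemo(GemmaContext* ctx) {
  char out[2048];
  GemmaCreateConversation(ctx, "support");
  GemmaSwitchConversation(ctx, "support");
  // Prime the conversation, then snapshot it as the prewarmed state.
  GemmaGenerate(ctx, "You are a terse support agent.", out, 2048, nullptr,
                nullptr);
  GemmaSaveConversation(ctx);
  // Later turns mutate the KV cache...
  GemmaGenerate(ctx, "My order is late.", out, 2048, nullptr, nullptr);
  // ...and reset rolls back to the saved snapshot, not to an empty history.
  GemmaResetConversation(ctx);
  GemmaSwitchConversation(ctx, "default");
}
```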
diff --git a/gemma/common.cc b/gemma/common.cc
deleted file mode 100644
index 0d8977b..0000000
--- a/gemma/common.cc
+++ /dev/null
@@ -1,177 +0,0 @@
-// Copyright 2024 Google LLC
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "gemma/common.h"
-
-#include <math.h>  // sqrtf
-#include <stddef.h>
-#include <string.h>
-
-#include <algorithm>  // std::transform
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "compression/shared.h"
-#include "hwy/base.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
-
-namespace gcpp {
-
-constexpr const char* kModelFlags[] = {
-    "2b-pt", "2b-it",                // Gemma 2B
-    "7b-pt", "7b-it",                // Gemma 7B
-    "gr2b-pt", "gr2b-it",            // RecurrentGemma
-    "tiny",                          // Gemma Tiny (mostly for debugging)
-    "gemma2-2b-pt", "gemma2-2b-it",  // Gemma2 2B
-    "9b-pt", "9b-it",                // Gemma2 9B
-    "27b-pt", "27b-it",              // Gemma2 27B
-    "paligemma-224",                 // PaliGemma 224
-    "paligemma-448",                 // PaliGemma 448
-    "paligemma2-3b-224",             // PaliGemma2 3B 224
-    "paligemma2-3b-448",             // PaliGemma2 3B 448
-    "paligemma2-10b-224",            // PaliGemma2 10B 224
-    "paligemma2-10b-448",            // PaliGemma2 10B 448
-    "gemma3-4b",                     // Gemma3 4B
-    "gemma3-1b",                     // Gemma3 1B
-    "gemma3-12b",                    // Gemma3 12B
-    "gemma3-27b",                    // Gemma3 27B
-};
-constexpr Model kModelTypes[] = {
-    Model::GEMMA_2B, Model::GEMMA_2B,      // Gemma 2B
-    Model::GEMMA_7B, Model::GEMMA_7B,      // Gemma 7B
-    Model::GRIFFIN_2B, Model::GRIFFIN_2B,  // RecurrentGemma
-    Model::GEMMA_TINY,                     // Gemma Tiny
-    Model::GEMMA2_2B, Model::GEMMA2_2B,    // Gemma2 2B
-    Model::GEMMA2_9B, Model::GEMMA2_9B,    // Gemma2 9B
-    Model::GEMMA2_27B, Model::GEMMA2_27B,  // Gemma2 27B
-    Model::PALIGEMMA_224,                  // PaliGemma 224
-    Model::PALIGEMMA_448,                  // PaliGemma 448
-    Model::PALIGEMMA2_3B_224,              // PaliGemma2 3B 224
-    Model::PALIGEMMA2_3B_448,              // PaliGemma2 3B 448
-    Model::PALIGEMMA2_10B_224,             // PaliGemma2 10B 224
-    Model::PALIGEMMA2_10B_448,             // PaliGemma2 10B 448
-    Model::GEMMA3_4B,                      // Gemma3 4B
-    Model::GEMMA3_1B,                      // Gemma3 1B
-    Model::GEMMA3_12B,                     // Gemma3 12B
-    Model::GEMMA3_27B,                     // Gemma3 27B
-};
-constexpr PromptWrapping kPromptWrapping[] = {
-    PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT,    // Gemma 2B
-    PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT,    // Gemma 7B
-    PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT,    // RecurrentGemma
-    PromptWrapping::GEMMA_IT,                              // Gemma Tiny
-    PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT,    // Gemma2 2B
-    PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT,    // Gemma2 9B
-    PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT,    // Gemma2 27B
-    PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA,  // PaliGemma 224/448
-    PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA,  // PG2 3B 224/448
-    PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA,  // PG2 10B 224/448
-    PromptWrapping::GEMMA_VLM,                             // Gemma3 4B
-    PromptWrapping::GEMMA_PT,                              // Gemma3 1B
-    PromptWrapping::GEMMA_VLM,                             // Gemma3 12B
-    PromptWrapping::GEMMA_VLM,                             // Gemma3 27B
-};
-
-constexpr size_t kNumModelFlags = std::size(kModelFlags);
-static_assert(kNumModelFlags == std::size(kModelTypes));
-static_assert(kNumModelFlags == std::size(kPromptWrapping));
-
-const char* ParseModelTypeAndWrapping(const std::string& model_flag,
-                                      Model& model, PromptWrapping& wrapping) {
-  static std::string kErrorMessageBuffer =
-      "Invalid or missing model flag, need to specify one of ";
-  for (size_t i = 0; i + 1 < kNumModelFlags; ++i) {
-    kErrorMessageBuffer.append(kModelFlags[i]);
-    kErrorMessageBuffer.append(", ");
-  }
-  kErrorMessageBuffer.append(kModelFlags[kNumModelFlags - 1]);
-  kErrorMessageBuffer.append(".");
-  std::string model_type_lc = model_flag;
-  std::transform(model_type_lc.begin(), model_type_lc.end(),
-                 model_type_lc.begin(), ::tolower);
-  for (size_t i = 0; i < kNumModelFlags; ++i) {
-    if (kModelFlags[i] == model_type_lc) {
-      model = kModelTypes[i];
-      wrapping = kPromptWrapping[i];
-      HWY_ASSERT(std::string(ModelString(model, wrapping)) == model_type_lc);
-      return nullptr;
-    }
-  }
-  return kErrorMessageBuffer.c_str();
-}
-
-const char* ModelString(Model model, PromptWrapping wrapping) {
-  for (size_t i = 0; i < kNumModelFlags; i++) {
-    if (kModelTypes[i] == model && kPromptWrapping[i] == wrapping)
-      return kModelFlags[i];
-  }
-  HWY_ABORT("Unknown model %d wrapping %d\n", static_cast<int>(model),
-            static_cast<int>(wrapping));
-}
-
-const char* StringFromType(Type type) {
-  return kTypeStrings[static_cast<size_t>(type)];
-}
-
-const char* ParseType(const std::string& type_string, Type& type) {
-  constexpr size_t kNum = std::size(kTypeStrings);
-  static std::string kErrorMessageBuffer =
-      "Invalid or missing type, need to specify one of ";
-  for (size_t i = 0; i + 1 < kNum; ++i) {
-    kErrorMessageBuffer.append(kTypeStrings[i]);
-    kErrorMessageBuffer.append(", ");
-  }
-  kErrorMessageBuffer.append(kTypeStrings[kNum - 1]);
-  kErrorMessageBuffer.append(".");
-  std::string type_lc = type_string;
-  std::transform(type_lc.begin(), type_lc.end(), type_lc.begin(), ::tolower);
-  for (size_t i = 0; i < kNum; ++i) {
-    if (kTypeStrings[i] == type_lc) {
-      type = static_cast<Type>(i);
-      HWY_ASSERT(std::string(StringFromType(type)) == type_lc);
-      return nullptr;
-    }
-  }
-  return kErrorMessageBuffer.c_str();
-}
-
-void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
-  // Instruction-tuned models are trained to expect control tokens.
-  if (info.wrapping == PromptWrapping::GEMMA_IT) {
-    // Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
-    const std::string start = (pos == 0)
-                                  ? "<start_of_turn>user\n"
-                                  : "<end_of_turn>\n<start_of_turn>user\n";
-    prompt = start + prompt + "<end_of_turn>\n<start_of_turn>model\n";
-  }
-}
-
-float EmbeddingScaling(size_t model_dim) {
-  // Round to bf16 to match Gemma's Embedder, which casts before mul.
-  return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
-      sqrtf(static_cast<float>(model_dim))));
-}
-
-float ChooseQueryScale(const ModelConfig& config) {
-  if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
-    return 1.0f / sqrtf(static_cast<float>(config.model_dim /
-                                           config.layer_configs[0].heads));
-  // QueryScaleType::SqrtKeySize
-  return 1.0f / sqrtf(static_cast<float>(config.layer_configs[0].qkv_dim));
-}
-
-}  // namespace gcpp
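The deleted `EmbeddingScaling` rounds `sqrt(model_dim)` through bf16 before use, matching the trained Embedder. A small sketch of the equivalent arithmetic; the bf16 round-trip is emulated here by truncating the low 16 mantissa bits, whereas Highway's `ConvertScalarTo` may round to nearest, so treat this as an approximation:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Keep only the top 16 bits of the f32 representation (truncating bf16).
static float RoundToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits &= 0xFFFF0000u;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  const float exact = std::sqrt(2048.0f);  // e.g. model_dim = 2048
  std::printf("sqrt = %f, bf16-rounded = %f\n", exact, RoundToBF16(exact));
  return 0;
}
```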
diff --git a/gemma/common.h b/gemma/common.h
deleted file mode 100644
index 984b0ba..0000000
--- a/gemma/common.h
+++ /dev/null
@@ -1,57 +0,0 @@
-// Copyright 2024 Google LLC
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
-#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
-
-#include <stddef.h>
-
-#include <string>
-
-#include "compression/shared.h"  // PromptWrapping
-#include "gemma/configs.h"  // IWYU pragma: export
-#include "hwy/base.h"  // ConvertScalarTo
-
-namespace gcpp {
-
-// Struct to bundle model information.
-struct ModelInfo {
-  Model model;
-  PromptWrapping wrapping;
-  Type weight;
-};
-
-// Returns error string or nullptr if OK.
-// Thread-hostile.
-const char* ParseModelTypeAndWrapping(const std::string& model_flag,
-                                      Model& model, PromptWrapping& wrapping);
-const char* ParseType(const std::string& type_string, Type& type);
-
-// Inverse of ParseModelTypeAndWrapping.
-const char* ModelString(Model model, PromptWrapping wrapping);
-const char* StringFromType(Type type);
-
-// Wraps the given prompt using the expected control tokens for IT models.
-void Wrap(const ModelInfo& info, size_t pos, std::string& prompt);
-
-// Returns the scale value to use for the embedding (basically sqrt model_dim).
-float EmbeddingScaling(size_t model_dim);
-
-// Returns the scale value to use for the query in the attention computation.
-float ChooseQueryScale(const ModelConfig& config);
-
-}  // namespace gcpp
-
-#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
diff --git a/gemma/configs.cc b/gemma/configs.cc
index 276c8f9..562500d 100644
--- a/gemma/configs.cc
+++ b/gemma/configs.cc
@@ -15,22 +15,31 @@
 
 #include "gemma/configs.h"
 
-#include <stddef.h>
-#include <stdint.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <string>
+#include <vector>
+
+#include "compression/types.h"  // Type
+#include "io/fields.h"  // IFields
+#include "io/io.h"  // Path
 #include "hwy/base.h"
 
 namespace gcpp {
 
+static constexpr size_t kVocabSize = 256000;
+
+static constexpr size_t kGemmaV3VocabSize = 262144;
+
 static ModelConfig ConfigNoSSM() {
   ModelConfig config;
-  config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
-                        "gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
+  config.scale_base_names = {"att_ein", "qkv_ein", "gr_lin_x_w",
+                             "gr_lin_y_w", "gr_lin_out_w", "gr_gate_w",
+                             "gating_ein", "linear_w"};
   return config;
 }
 
-static ModelConfig ConfigBaseGemmaV1() { return ConfigNoSSM(); }
-
 static ModelConfig ConfigBaseGemmaV2() {
   ModelConfig config = ConfigNoSSM();
   config.att_cap = 50.0f;
@@ -54,17 +63,17 @@ static LayerConfig LayerConfigGemma2_27B(size_t model_dim) {
 
 static ModelConfig ConfigGemma2_27B() {
   ModelConfig config = ConfigBaseGemmaV2();
-  config.model_name = "Gemma2_27B";
+  config.display_name = "Gemma2_27B";
   config.model = Model::GEMMA2_27B;
   config.model_dim = 4608;
   config.vocab_size = kVocabSize;
-  config.seq_len = 8192;
+  config.max_seq_len = 8192;
   LayerConfig layer_config = LayerConfigGemma2_27B(config.model_dim);
-  config.layer_configs = {46, layer_config};
-  config.num_tensor_scales = 4 * config.layer_configs.size();
+  config.num_layers = 46;
+  config.layer_configs = {config.num_layers, layer_config};
   config.query_scale = QueryScaleType::SqrtModelDimDivNumHeads;
   config.attention_window_sizes =
-      RepeatedAttentionWindowSizes<46, 2>({4096, 8192});
+      RepeatedAttentionWindowSizes<46, 2>({4096, config.max_seq_len});
   return config;
 }
 
@@ -82,17 +91,17 @@ static LayerConfig LayerConfigGemma2_9B(size_t model_dim) {
 
 static ModelConfig ConfigGemma2_9B() {
   ModelConfig config = ConfigBaseGemmaV2();
-  config.model_name = "Gemma2_9B";
+  config.display_name = "Gemma2_9B";
   config.model = Model::GEMMA2_9B;
   config.model_dim = 3584;
   config.vocab_size = kVocabSize;
-  config.seq_len = 8192;
+  config.max_seq_len = 8192;
   LayerConfig layer_config = LayerConfigGemma2_9B(config.model_dim);
-  config.layer_configs = {42, layer_config};
-  config.num_tensor_scales = 4 * config.layer_configs.size();
+  config.num_layers = 42;
+  config.layer_configs = {config.num_layers, layer_config};
   config.query_scale = QueryScaleType::SqrtKeySize;
   config.attention_window_sizes =
-      RepeatedAttentionWindowSizes<42, 2>({4096, 8192});
+      RepeatedAttentionWindowSizes<42, 2>({4096, config.max_seq_len});
   return config;
 }
 
@@ -110,66 +119,17 @@ static LayerConfig LayerConfigGemma2_2B(size_t model_dim) {
 
 static ModelConfig ConfigGemma2_2B() {
   ModelConfig config = ConfigBaseGemmaV2();
-  config.model_name = "Gemma2_2B";
+  config.display_name = "Gemma2_2B";
   config.model = Model::GEMMA2_2B;
   config.model_dim = 2304;
   config.vocab_size = kVocabSize;
-  config.seq_len = 8192;
+  config.max_seq_len = 8192;
   LayerConfig layer_config = LayerConfigGemma2_2B(config.model_dim);
-  config.layer_configs = {26, layer_config};
-  config.num_tensor_scales = 4 * config.layer_configs.size();
+  config.num_layers = 26;
+  config.layer_configs = {config.num_layers, layer_config};
   config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes = - RepeatedAttentionWindowSizes<26, 2>({4096, 8192}); - return config; -} - -static LayerConfig LayerConfigGemma7B(size_t model_dim) { - LayerConfig config; - config.model_dim = model_dim; - config.ff_hidden_dim = 16 * 3072 / 2; // = 24576 - config.heads = 16; - config.kv_heads = 16; - config.qkv_dim = 256; - return config; -} - -static ModelConfig ConfigGemma7B() { - ModelConfig config = ConfigBaseGemmaV1(); - config.model_name = "Gemma7B"; - config.model = Model::GEMMA_7B; - config.model_dim = 3072; - config.vocab_size = kVocabSize; - config.seq_len = kSeqLen; - LayerConfig layer_config = LayerConfigGemma7B(config.model_dim); - config.layer_configs = {28, layer_config}; - config.num_tensor_scales = 4 * config.layer_configs.size(); - config.query_scale = QueryScaleType::SqrtKeySize; - config.attention_window_sizes = FixedAttentionWindowSizes<28>(kSeqLen); - return config; -} - -static LayerConfig LayerConfigGemma2B(size_t model_dim) { - LayerConfig config; - config.model_dim = model_dim; - config.ff_hidden_dim = 16 * 2048 / 2; // = 16384 - config.heads = 8; - config.kv_heads = 1; - config.qkv_dim = 256; - return config; -} - -static ModelConfig ConfigGemma2B() { - ModelConfig config = ConfigBaseGemmaV1(); - config.model_name = "Gemma2B"; - config.model = Model::GEMMA_2B; - config.model_dim = 2048; - config.vocab_size = kVocabSize; - config.seq_len = kSeqLen; - LayerConfig layer_config = LayerConfigGemma2B(config.model_dim); - config.layer_configs = {18, layer_config}; - config.num_tensor_scales = 4 * config.layer_configs.size(); - config.attention_window_sizes = FixedAttentionWindowSizes<18>(kSeqLen); + RepeatedAttentionWindowSizes<26, 2>({4096, config.max_seq_len}); return config; } @@ -185,17 +145,18 @@ static LayerConfig LayerConfigGemmaTiny(size_t model_dim) { static ModelConfig ConfigGemmaTiny() { ModelConfig config = ConfigNoSSM(); - config.model_name = "GemmaTiny"; + config.display_name = "GemmaTiny"; config.model = Model::GEMMA_TINY; - config.model_dim = 128; - config.vocab_size = 64; - config.seq_len = 32; + config.wrapping = PromptWrapping::GEMMA_IT; + config.model_dim = 32; + config.vocab_size = 32; // at least two f32 vectors + config.max_seq_len = 32; LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim); - config.layer_configs = {3, layer_config}; - config.num_tensor_scales = 4 * config.layer_configs.size(); + config.num_layers = 2; + config.layer_configs = {config.num_layers, layer_config}; config.query_scale = QueryScaleType::SqrtKeySize; - config.attention_window_sizes = FixedAttentionWindowSizes<3>(32); - // This is required for optimize_test to pass. + config.attention_window_sizes = FixedAttentionWindowSizes<2>(32); + config.att_cap = 50.0f; config.final_cap = 30.0f; config.eos_id = 11; config.secondary_eos_id = 11; @@ -223,23 +184,23 @@ static LayerConfig LayerConfigGriffin2B(size_t model_dim) { static ModelConfig ConfigGriffin2B() { ModelConfig config = ConfigNoSSM(); - config.model_name = "Griffin2B"; + config.display_name = "Griffin2B"; config.model = Model::GRIFFIN_2B; - // Griffin uses local attention, so kSeqLen is actually the local attention - // window. + // Griffin uses local attention, so max_seq_len is actually the local + // attention window. 
config.model_dim = 2560; config.vocab_size = kVocabSize; - config.seq_len = 2048; + config.max_seq_len = 2048; LayerConfig layer_config = LayerConfigGriffin2B(config.model_dim); - config.layer_configs = {26, layer_config}; - for (size_t i = 2; i < config.layer_configs.size(); i += 3) { + config.num_layers = 26; + config.layer_configs = {config.num_layers, layer_config}; + for (size_t i = 2; i < config.num_layers; i += 3) { config.layer_configs[i].type = LayerAttentionType::kGemma; config.layer_configs[i].griffin_dim = 0; } - config.num_tensor_scales = 140; - config.attention_window_sizes = FixedAttentionWindowSizes<26>(config.seq_len); + config.attention_window_sizes = + FixedAttentionWindowSizes<26>(config.max_seq_len); config.use_local_attention = true; - // This is required for optimize_test to pass. config.final_cap = 0.0f; return config; } @@ -273,26 +234,10 @@ static void AddVitConfig(ModelConfig& config, size_t image_size = 224) { config.vit_config.num_scales = 4 * config.vit_config.layer_configs.size(); } -static ModelConfig ConfigPaliGemma_224() { - ModelConfig config = ConfigGemma2B(); - config.model_name = "PaliGemma_224"; - config.model = Model::PALIGEMMA_224; - AddVitConfig(config); - return config; -} - -static ModelConfig ConfigPaliGemma_448() { - ModelConfig config = ConfigGemma2B(); - config.model_name = "PaliGemma_448"; - config.model = Model::PALIGEMMA_448; - AddVitConfig(config, /*image_size=*/448); - return config; -} - ModelConfig GetVitConfig(const ModelConfig& config) { ModelConfig vit_config = ConfigNoSSM(); vit_config.model_dim = config.vit_config.model_dim; - vit_config.seq_len = config.vit_config.seq_len; + vit_config.max_seq_len = config.vit_config.seq_len; vit_config.layer_configs = config.vit_config.layer_configs; vit_config.pool_dim = config.vit_config.pool_dim; vit_config.wrapping = config.wrapping; @@ -303,32 +248,36 @@ ModelConfig GetVitConfig(const ModelConfig& config) { static ModelConfig ConfigPaliGemma2_3B_224() { ModelConfig config = ConfigGemma2_2B(); - config.model_name = "PaliGemma2_3B_224"; + config.display_name = "PaliGemma2_3B_224"; config.model = Model::PALIGEMMA2_3B_224; + config.wrapping = PromptWrapping::PALIGEMMA; AddVitConfig(config); return config; } static ModelConfig ConfigPaliGemma2_3B_448() { ModelConfig config = ConfigGemma2_2B(); - config.model_name = "PaliGemma2_3B_448"; + config.display_name = "PaliGemma2_3B_448"; config.model = Model::PALIGEMMA2_3B_448; + config.wrapping = PromptWrapping::PALIGEMMA; AddVitConfig(config, /*image_size=*/448); return config; } static ModelConfig ConfigPaliGemma2_10B_224() { ModelConfig config = ConfigGemma2_9B(); - config.model_name = "PaliGemma2_10B_224"; + config.display_name = "PaliGemma2_10B_224"; config.model = Model::PALIGEMMA2_10B_224; + config.wrapping = PromptWrapping::PALIGEMMA; AddVitConfig(config); return config; } static ModelConfig ConfigPaliGemma2_10B_448() { ModelConfig config = ConfigGemma2_9B(); - config.model_name = "PaliGemma2_10B_448"; + config.display_name = "PaliGemma2_10B_448"; config.model = Model::PALIGEMMA2_10B_448; + config.wrapping = PromptWrapping::PALIGEMMA; AddVitConfig(config, /*image_size=*/448); return config; } @@ -358,18 +307,19 @@ static LayerConfig LayerConfigGemma3_1B_LM(size_t model_dim) { static ModelConfig ConfigGemma3_1B() { ModelConfig config = ConfigBaseGemmaV3(); - config.model_name = "Gemma3_1B"; + config.display_name = "Gemma3_1B"; config.model = Model::GEMMA3_1B; + config.wrapping = PromptWrapping::GEMMA_VLM; config.model_dim = 1152; - 
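The loop above defines Griffin-2B's layer mix: all 26 layers start as recurrent blocks, then indices 2, 5, 8, ... are switched to regular Gemma attention, i.e. a repeating recurrent-recurrent-attention pattern. The same logic in isolation:

#include <cstddef>
#include <cstdio>
#include <vector>

enum class LayerType { kGriffinRecurrentBlock, kGemma };

int main() {
  const size_t num_layers = 26;
  std::vector<LayerType> types(num_layers, LayerType::kGriffinRecurrentBlock);
  for (size_t i = 2; i < num_layers; i += 3) types[i] = LayerType::kGemma;
  for (size_t i = 0; i < 6; ++i) {
    printf("layer %zu: %s\n", i,
           types[i] == LayerType::kGemma ? "attention" : "recurrent");
  }
  return 0;
}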
config.vocab_size = 262144; // new vocab size / tokenizer - config.seq_len = 32 * 1024; + config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer + config.max_seq_len = 32 * 1024; LayerConfig layer_config = LayerConfigGemma3_1B_LM(config.model_dim); - config.layer_configs = {26, layer_config}; - config.num_tensor_scales = 4 * config.layer_configs.size(); + config.num_layers = 26; + config.layer_configs = {config.num_layers, layer_config}; config.query_scale = QueryScaleType::SqrtKeySize; // interleaved local / global attention config.attention_window_sizes = RepeatedAttentionWindowSizes<26, 6>( - {512, 512, 512, 512, 512, config.seq_len}); + {512, 512, 512, 512, 512, config.max_seq_len}); return config; } @@ -389,27 +339,29 @@ static LayerConfig LayerConfigGemma3_4B_LM(size_t model_dim) { // Until we have the SigLIP checkpoints included, we use the LM config directly. static ModelConfig ConfigGemma3_4B_LM() { ModelConfig config = ConfigBaseGemmaV3(); - config.model_name = "Gemma3_4B"; + config.display_name = "Gemma3_4B"; config.model = Model::GEMMA3_4B; + config.wrapping = PromptWrapping::GEMMA_VLM; config.model_dim = 2560; - config.vocab_size = 262144; // new vocab size / tokenizer - config.seq_len = 32 * 1024; + config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer + config.max_seq_len = 32 * 1024; LayerConfig layer_config = LayerConfigGemma3_4B_LM(config.model_dim); - config.layer_configs = {34, layer_config}; - config.num_tensor_scales = 4 * config.layer_configs.size(); + config.num_layers = 34; + config.layer_configs = {config.num_layers, layer_config}; config.query_scale = QueryScaleType::SqrtKeySize; // interleaved local / global attention config.attention_window_sizes = RepeatedAttentionWindowSizes<34, 6>( - {1024, 1024, 1024, 1024, 1024, config.seq_len}); + {1024, 1024, 1024, 1024, 1024, config.max_seq_len}); return config; } static ModelConfig ConfigGemma3_4B() { ModelConfig config = ConfigGemma3_4B_LM(); - config.model_name = "Gemma3_4B"; + config.display_name = "Gemma3_4B"; config.model = Model::GEMMA3_4B; + config.wrapping = PromptWrapping::GEMMA_VLM; AddVitConfig(config, /*image_size=*/896); - config.vocab_size = 262144; + config.vocab_size = kGemmaV3VocabSize; config.vit_config.pool_dim = 4; const size_t num_patches = config.vit_config.image_size / config.vit_config.patch_width; @@ -436,27 +388,29 @@ static LayerConfig LayerConfigGemma3_12B_LM(size_t model_dim) { static ModelConfig ConfigGemma3_12B_LM() { ModelConfig config = ConfigBaseGemmaV3(); - config.model_name = "Gemma3_12B"; + config.display_name = "Gemma3_12B"; config.model = Model::GEMMA3_12B; + config.wrapping = PromptWrapping::GEMMA_VLM; config.model_dim = 3840; - config.vocab_size = 262144; // new vocab size / tokenizer - config.seq_len = 32 * 1024; + config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer + config.max_seq_len = 32 * 1024; LayerConfig layer_config = LayerConfigGemma3_12B_LM(config.model_dim); - config.layer_configs = {48, layer_config}; - config.num_tensor_scales = 4 * config.layer_configs.size(); + config.num_layers = 48; + config.layer_configs = {config.num_layers, layer_config}; config.query_scale = QueryScaleType::SqrtKeySize; // interleaved local / global attention config.attention_window_sizes = RepeatedAttentionWindowSizes<48, 6>( - {1024, 1024, 1024, 1024, 1024, config.seq_len}); + {1024, 1024, 1024, 1024, 1024, config.max_seq_len}); return config; } static ModelConfig ConfigGemma3_12B() { ModelConfig config = ConfigGemma3_12B_LM(); - 
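A back-of-the-envelope check of the ViT geometry configured above, assuming patch_width remains 14 (the SoViT value from the removed test configs later in this diff): an 896x896 image yields (896/14)^2 = 4096 patches, and pool_dim = 4 pools 4x4 patch groups down to 256 image tokens for the LM.

#include <cstddef>
#include <cstdio>

int main() {
  const size_t image_size = 896, patch_width = 14, pool_dim = 4;
  const size_t patches_per_side = image_size / patch_width;         // 64
  const size_t vit_seq_len = patches_per_side * patches_per_side;   // 4096
  const size_t image_tokens = vit_seq_len / (pool_dim * pool_dim);  // 256
  printf("patches/side=%zu vit_seq_len=%zu image_tokens=%zu\n",
         patches_per_side, vit_seq_len, image_tokens);
  return 0;
}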
config.model_name = "Gemma3_12B"; + config.display_name = "Gemma3_12B"; config.model = Model::GEMMA3_12B; + config.wrapping = PromptWrapping::GEMMA_VLM; AddVitConfig(config, /*image_size=*/896); - config.vocab_size = 262144; + config.vocab_size = kGemmaV3VocabSize; config.vit_config.pool_dim = 4; const size_t num_patches = config.vit_config.image_size / config.vit_config.patch_width; @@ -483,27 +437,29 @@ static LayerConfig LayerConfigGemma3_27B_LM(size_t model_dim) { static ModelConfig ConfigGemma3_27B_LM() { ModelConfig config = ConfigBaseGemmaV3(); - config.model_name = "Gemma3_27B"; + config.display_name = "Gemma3_27B"; config.model = Model::GEMMA3_27B; + config.wrapping = PromptWrapping::GEMMA_VLM; config.model_dim = 5376; - config.vocab_size = 262144; // new vocab size / tokenizer - config.seq_len = 32 * 1024; + config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer + config.max_seq_len = 32 * 1024; LayerConfig layer_config = LayerConfigGemma3_27B_LM(config.model_dim); - config.layer_configs = {62, layer_config}; - config.num_tensor_scales = 4 * config.layer_configs.size(); + config.num_layers = 62; + config.layer_configs = {config.num_layers, layer_config}; config.query_scale = QueryScaleType::SqrtKeySize; // interleaved local / global attention config.attention_window_sizes = RepeatedAttentionWindowSizes<62, 6>( - {1024, 1024, 1024, 1024, 1024, config.seq_len}); + {1024, 1024, 1024, 1024, 1024, config.max_seq_len}); return config; } static ModelConfig ConfigGemma3_27B() { ModelConfig config = ConfigGemma3_27B_LM(); - config.model_name = "Gemma3_27B"; + config.display_name = "Gemma3_27B"; config.model = Model::GEMMA3_27B; + config.wrapping = PromptWrapping::GEMMA_VLM; AddVitConfig(config, /*image_size=*/896); - config.vocab_size = 262144; + config.vocab_size = kGemmaV3VocabSize; config.vit_config.pool_dim = 4; const size_t num_patches = config.vit_config.image_size / config.vit_config.patch_width; @@ -515,12 +471,8 @@ static ModelConfig ConfigGemma3_27B() { return config; } -ModelConfig ConfigFromModel(Model model) { +static ModelConfig ConfigFromModel(Model model) { switch (model) { - case Model::GEMMA_2B: - return ConfigGemma2B(); - case Model::GEMMA_7B: - return ConfigGemma7B(); case Model::GEMMA2_2B: return ConfigGemma2_2B(); case Model::GEMMA2_9B: @@ -531,10 +483,6 @@ ModelConfig ConfigFromModel(Model model) { return ConfigGriffin2B(); case Model::GEMMA_TINY: return ConfigGemmaTiny(); - case Model::PALIGEMMA_224: - return ConfigPaliGemma_224(); - case Model::PALIGEMMA_448: - return ConfigPaliGemma_448(); case Model::PALIGEMMA2_3B_224: return ConfigPaliGemma2_3B_224(); case Model::PALIGEMMA2_3B_448: @@ -556,124 +504,249 @@ ModelConfig ConfigFromModel(Model model) { } } -#define TEST_EQUAL(a, b) \ - if (a != b) { \ - if (debug) \ - std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n"; \ - result = false; \ +const char* ModelPrefix(Model model) { + switch (model) { + case Model::UNKNOWN: + return "unknown"; + case Model::GEMMA2_2B: + return "gemma2-2b"; + case Model::GEMMA2_9B: + return "9b"; + case Model::GEMMA2_27B: + return "27b"; + case Model::GRIFFIN_2B: + return "gr2b"; + case Model::GEMMA_TINY: + return "tiny"; + case Model::PALIGEMMA2_3B_224: + return "paligemma2-3b-224"; + case Model::PALIGEMMA2_3B_448: + return "paligemma2-3b-448"; + case Model::PALIGEMMA2_10B_224: + return "paligemma2-10b-224"; + case Model::PALIGEMMA2_10B_448: + return "paligemma2-10b-448"; + case Model::GEMMA3_4B: + return "gemma3-4b"; + case Model::GEMMA3_1B: + return 
"gemma3-1b"; + case Model::GEMMA3_12B: + return "gemma3-12b"; + case Model::GEMMA3_27B: + return "gemma3-27b"; + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); } - -#define RETURN_IF_NOT_EQUAL(a, b) \ - if (a != b) { \ - if (debug) \ - std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n"; \ - return false; \ - } - -#define WARN_IF_NOT_EQUAL(a, b) \ - if (a != b) { \ - std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n"; \ - } - -bool LayerConfig::TestEqual(const LayerConfig& other, bool partial, - bool debug) const { - bool result = true; - // Optimized gating may not be set correctly in the c++ configs. - if (debug) { - WARN_IF_NOT_EQUAL(optimized_gating, other.optimized_gating) - } - TEST_EQUAL(model_dim, other.model_dim); - TEST_EQUAL(griffin_dim, other.griffin_dim); - TEST_EQUAL(ff_hidden_dim, other.ff_hidden_dim); - TEST_EQUAL(heads, other.heads); - TEST_EQUAL(kv_heads, other.kv_heads); - TEST_EQUAL(qkv_dim, other.qkv_dim); - TEST_EQUAL(conv1d_width, other.conv1d_width); - if (!partial) { - TEST_EQUAL(ff_biases, other.ff_biases); - TEST_EQUAL(softmax_attn_output_biases, other.softmax_attn_output_biases); - } - TEST_EQUAL(static_cast(post_norm), static_cast(other.post_norm)); - TEST_EQUAL(static_cast(type), static_cast(other.type)); - TEST_EQUAL(static_cast(activation), static_cast(other.activation)); - TEST_EQUAL(static_cast(post_qk), static_cast(other.post_qk)); - return result; } -bool VitConfig::TestEqual(const VitConfig& other, bool partial, - bool debug) const { - bool result = true; - TEST_EQUAL(model_dim, other.model_dim); - TEST_EQUAL(seq_len, other.seq_len); - if (!partial) { - TEST_EQUAL(num_scales, other.num_scales); +PromptWrapping ChooseWrapping(const Model model, Tristate wrapping) { + if (IsPaliGemma(model)) { + if (wrapping != Tristate::kDefault) { + HWY_WARN("Ignoring unnecessary --wrapping for PaliGemma models."); + } + return PromptWrapping::PALIGEMMA; } - TEST_EQUAL(patch_width, other.patch_width); - TEST_EQUAL(image_size, other.image_size); - RETURN_IF_NOT_EQUAL(layer_configs.size(), other.layer_configs.size()); - for (size_t i = 0; i < layer_configs.size(); ++i) { - result &= - layer_configs[i].TestEqual(other.layer_configs[i], partial, debug); + if (IsVLM(model)) { + if (wrapping != Tristate::kDefault) { + HWY_WARN("Ignoring unnecessary --wrapping for VLM models."); + } + return PromptWrapping::GEMMA_VLM; } - return result; + // Default to IT unless --wrapping=0. + return wrapping == Tristate::kFalse ? PromptWrapping::GEMMA_PT + : PromptWrapping::GEMMA_IT; } -bool ModelConfig::TestEqual(const ModelConfig& other, bool partial, - bool debug) const { - bool result = true; - TEST_EQUAL(model_family_version, other.model_family_version); - // We don't care about model_name, model, wrapping, or weight being different, - // but will output in debug mode if they are. 
- if (debug) { - WARN_IF_NOT_EQUAL(model_name, other.model_name); - WARN_IF_NOT_EQUAL(static_cast(model), static_cast(other.model)); - WARN_IF_NOT_EQUAL(static_cast(wrapping), - static_cast(other.wrapping)); - WARN_IF_NOT_EQUAL(static_cast(weight), static_cast(other.weight)); +ModelConfig::ModelConfig(const Model model, Type weight, + PromptWrapping wrapping) { + HWY_ASSERT(weight != Type::kUnknown); + HWY_ASSERT(wrapping != PromptWrapping::kSentinel); + this->model = model; + if (model != Model::UNKNOWN) *this = ConfigFromModel(model); + HWY_ASSERT(this->model == model); + this->weight = weight; + this->wrapping = wrapping; +} + +static Model FindModel(const std::string& specifier) { + Model found_model = Model::UNKNOWN; + ForEachModel([&](Model model) { + // Some model names are prefixes of other model names + const std::string prefix = std::string(ModelPrefix(model)) + "-"; + if (specifier.rfind(prefix, 0) == 0) { // Starts with prefix. + // We only expect one match. + HWY_ASSERT_M(found_model == Model::UNKNOWN, specifier.c_str()); + found_model = model; + } + }); + HWY_ASSERT_M(found_model != Model::UNKNOWN, specifier.c_str()); + return found_model; +} + +static Type FindType(const std::string& specifier) { + Type found_type = Type::kUnknown; + for (size_t i = 1; i < kNumTypes; ++i) { + const Type type = static_cast(i); + if (specifier.find(TypeName(type)) != std::string::npos) { // NOLINT + // We only expect one match. + HWY_ASSERT_M(found_type == Type::kUnknown, specifier.c_str()); + found_type = type; + } } - TEST_EQUAL(model_dim, other.model_dim); - TEST_EQUAL(vocab_size, other.vocab_size); - TEST_EQUAL(seq_len, other.seq_len); - if (!partial) { - TEST_EQUAL(num_tensor_scales, other.num_tensor_scales); + HWY_ASSERT_M(found_type != Type::kUnknown, specifier.c_str()); + return found_type; +} + +static PromptWrapping FindWrapping(const std::string& specifier) { + PromptWrapping found_wrapping = PromptWrapping::kSentinel; + for (size_t i = 0; i < static_cast(PromptWrapping::kSentinel); ++i) { + const PromptWrapping w = static_cast(i); + if (specifier.find(WrappingSuffix(w)) != std::string::npos) { // NOLINT + // We expect zero or one match. + HWY_ASSERT_M(found_wrapping == PromptWrapping::kSentinel, + specifier.c_str()); + found_wrapping = w; + } } - TEST_EQUAL(att_cap, other.att_cap); - TEST_EQUAL(final_cap, other.final_cap); - TEST_EQUAL(absolute_pe, other.absolute_pe); - TEST_EQUAL(use_local_attention, other.use_local_attention); - TEST_EQUAL(static_cast(query_scale), - static_cast(other.query_scale)); - RETURN_IF_NOT_EQUAL(layer_configs.size(), other.layer_configs.size()); - for (size_t i = 0; i < layer_configs.size(); ++i) { - result &= - layer_configs[i].TestEqual(other.layer_configs[i], partial, debug); + if (found_wrapping == PromptWrapping::kSentinel) { + return ChooseWrapping(FindModel(specifier)); } - RETURN_IF_NOT_EQUAL(attention_window_sizes.size(), - other.attention_window_sizes.size()); - for (size_t i = 0; i < attention_window_sizes.size(); ++i) { - TEST_EQUAL(attention_window_sizes[i], other.attention_window_sizes[i]); + return found_wrapping; +} + +// Obtains model/weight/wrapping by finding prefix and suffix strings. 
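Concretely, the matching above operates on strings like "gemma3-4b-sfp": FindModel requires the specifier to start with a model prefix plus '-' (since some prefixes are prefixes of others), while FindType looks for a type name anywhere in the string. A toy version with hypothetical tables; the real names come from ModelPrefix and TypeName, and the real code asserts exactly one match:

#include <cstdio>
#include <string>
#include <vector>

int main() {
  const std::string specifier = "gemma3-4b-sfp";
  const std::vector<std::string> prefixes = {"gemma2-2b", "gemma3-1b",
                                             "gemma3-4b"};
  const std::vector<std::string> types = {"f32", "bf16", "sfp"};
  for (const std::string& p : prefixes) {
    if (specifier.rfind(p + "-", 0) == 0) printf("model: %s\n", p.c_str());
  }
  for (const std::string& t : types) {
    if (specifier.find(t) != std::string::npos) printf("type: %s\n", t.c_str());
  }
  return 0;
}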
+ModelConfig::ModelConfig(const std::string& specifier) + : ModelConfig(FindModel(specifier), FindType(specifier), + FindWrapping(specifier)) {} + +std::string ModelConfig::Specifier() const { + HWY_ASSERT(model != Model::UNKNOWN); + HWY_ASSERT(weight != Type::kUnknown); + HWY_ASSERT(wrapping != PromptWrapping::kSentinel); + + std::string base_name = ModelPrefix(model); + + base_name += '-'; + base_name += TypeName(weight); + + if (wrapping != PromptWrapping::GEMMA_VLM && + wrapping != PromptWrapping::PALIGEMMA) { + base_name += WrappingSuffix(wrapping); } - if (!partial) { - if (scale_names != other.scale_names) { - result = false; - if (debug) { - std::cerr << "scale_names mismatch\n"; + + return base_name; +} + +// Returns whether all fields match. +static bool AllEqual(const IFields& a, const IFields& b, bool print) { + const std::vector serialized_a = a.Write(); + const std::vector serialized_b = b.Write(); + if (serialized_a != serialized_b) { + if (print) { + fprintf(stderr, "%s differs. Recommend generating a diff:\n", a.Name()); + a.Print(); + b.Print(); + } + return false; + } + return true; +} + +bool LayerConfig::TestEqual(const LayerConfig& other, bool print) const { + return AllEqual(*this, other, print); +} + +bool VitConfig::TestEqual(const VitConfig& other, bool print) const { + return AllEqual(*this, other, print); +} + +bool ModelConfig::TestEqual(const ModelConfig& other, bool print) const { + // Early out to guard the loop below; a differing number of layers will anyway + // cause a mismatch. + if (layer_configs.size() != other.layer_configs.size()) { + if (print) { + HWY_WARN("Layer configs size mismatch %zu vs %zu", layer_configs.size(), + other.layer_configs.size()); + } + return false; + } + + // Copy so we can 'ignore' fields by setting them to the same value. + ModelConfig a = *this; + ModelConfig b = other; + // Called by `OverwriteWithCanonical`, so ignore the fields it will set. + a.display_name = b.display_name; + a.model = b.model; + + // The following are not yet set by config_converter.py, so we here ignore + // them for purposes of comparison, and there overwrite the converter's config + // with the canonical ModelConfig constructed via (deduced) enum, so that + // these fields will be set. + // `vit_config` is also not yet set, but we must not ignore it because + // otherwise PaliGemma models will be indistinguishable for `configs_test`. + a.pool_dim = b.pool_dim; // ViT + a.eos_id = b.eos_id; + a.secondary_eos_id = b.secondary_eos_id; + a.scale_base_names = b.scale_base_names; + for (size_t i = 0; i < a.layer_configs.size(); ++i) { + a.layer_configs[i].optimized_gating = b.layer_configs[i].optimized_gating; + } + + return AllEqual(a, b, print); +} + +// Constructs the canonical ModelConfig for each model. If there is one for +// which TestEqual returns true, overwrites `*this` with that and returns true. +bool ModelConfig::OverwriteWithCanonical() { + bool found = false; + const bool print = false; + ForEachModel([&](Model model) { + const ModelConfig config(model, weight, wrapping); + if (config.TestEqual(*this, print)) { + HWY_ASSERT(!found); // Should only find one. 
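AllEqual above defines equality as "serializes to the same bytes", so a field appended later participates in comparisons automatically, with no comparison code to update. The same idea with a toy stand-in for IFields:

#include <cstdint>
#include <cstdio>
#include <vector>

struct Toy {
  uint32_t dim = 0, layers = 0;
  std::vector<uint32_t> Write() const { return {dim, layers}; }
};

bool AllEqual(const Toy& a, const Toy& b) { return a.Write() == b.Write(); }

int main() {
  const Toy a{2304, 26}, b{2304, 26}, c{3584, 42};
  printf("%d %d\n", AllEqual(a, b), AllEqual(a, c));  // 1 0
  return 0;
}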
+ found = true; + *this = config; + } + }); + return found; +} + +Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) { + switch (layers) { + case 2: + return Model::GEMMA_TINY; + case 26: + if (layer_types & kDeducedGriffin) return Model::GRIFFIN_2B; + if (layer_types & kDeducedViT) return Model::GEMMA3_1B; + return Model::GEMMA2_2B; + case 27: + return (layer_types & kDeduced448) ? Model::PALIGEMMA2_3B_448 + : Model::PALIGEMMA2_3B_224; + case 34: + return Model::GEMMA3_4B; + case 42: + if (layer_types & kDeducedViT) { + return (layer_types & kDeduced448) ? Model::PALIGEMMA2_10B_448 + : Model::PALIGEMMA2_10B_224; } - } - } - TEST_EQUAL(norm_num_groups, other.norm_num_groups); - result &= vit_config.TestEqual(other.vit_config, partial, debug); - return result; -} + return Model::GEMMA2_9B; + case 46: + return Model::GEMMA2_27B; + case 48: + return Model::GEMMA3_12B; + case 62: + return Model::GEMMA3_27B; -Model ModelFromConfig(const ModelConfig& config) { - for (Model model : kAllModels) { - ModelConfig model_config = ConfigFromModel(model); - if (config.TestEqual(model_config, /*partial=*/true, /*debug=*/false)) { - return model; - } + // TODO: detect these. + /* + return Model::GEMMA2_772M; + return Model::PALIGEMMA2_772M_224; + */ + default: + HWY_WARN("Failed to deduce model type from %s, layer count %zu types %x.", + blob_path.path.c_str(), layers, layer_types); + return Model::UNKNOWN; } - return Model::UNKNOWN; } } // namespace gcpp diff --git a/gemma/configs.h b/gemma/configs.h index 837e067..19e6278 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -19,35 +19,52 @@ // Model configurations #include +#include -#include #include -#include #include -#include #include -#include "compression/fields.h" // IFieldsVisitor -#include "compression/shared.h" // BF16 +#include "compression/types.h" // Type +#include "io/fields.h" // IFieldsVisitor +#include "io/io.h" // Path +#include "util/basics.h" namespace gcpp { -// Allow changing pre-allocated kv cache size as a compiler flag -#ifndef GEMMA_MAX_SEQLEN -#define GEMMA_MAX_SEQLEN 4096 -#endif // !GEMMA_MAX_SEQLEN - -// Allow changing k parameter of `SampleTopK` as a compiler flag -#ifndef GEMMA_TOPK -#define GEMMA_TOPK 1 -#endif // !GEMMA_TOPK - -static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN; -static constexpr size_t kTopK = GEMMA_TOPK; -static constexpr size_t kVocabSize = 256000; static constexpr size_t kMaxConv1DWidth = 4; +static constexpr size_t kMaxQKVDim = 1024; -using EmbedderInputT = BF16; +// Instruction-tuned models require extra 'turn structure' tokens in prompts. +enum class PromptWrapping { + GEMMA_IT, + GEMMA_PT, + GEMMA_VLM, // for >1B Gemma3 + PALIGEMMA, + kSentinel // must be last +}; + +// This is used in `ModelConfig.Specifier`, so the strings will not change, +// though new ones may be added. 
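A worked example of the layer-count deduction in DeduceModel above, restricted to the ambiguous 26-layer case; the enum mirrors the DeducedLayerTypes bits declared in gemma/configs.h:

#include <cstdio>

enum { kDeducedGriffin = 1, kDeducedViT = 2, kDeduced448 = 4 };

// 26 layers alone could be Griffin-2B, Gemma3-1B, or Gemma2-2B; the tensor
// types present in the blob disambiguate, as in DeduceModel's case 26.
const char* Deduce26(int layer_types) {
  if (layer_types & kDeducedGriffin) return "GRIFFIN_2B";
  if (layer_types & kDeducedViT) return "GEMMA3_1B";
  return "GEMMA2_2B";
}

int main() {
  printf("%s\n", Deduce26(0));                // GEMMA2_2B
  printf("%s\n", Deduce26(kDeducedViT));      // GEMMA3_1B
  printf("%s\n", Deduce26(kDeducedGriffin));  // GRIFFIN_2B
  return 0;
}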
+static inline const char* WrappingSuffix(PromptWrapping wrapping) { + switch (wrapping) { + case PromptWrapping::GEMMA_IT: + return "-it"; + case PromptWrapping::GEMMA_PT: + return "-pt"; + case PromptWrapping::GEMMA_VLM: + return "-vlm"; + case PromptWrapping::PALIGEMMA: + return "-pg"; + default: + return "-?"; + } +} + +static inline bool EnumValid(PromptWrapping wrapping) { + return static_cast<size_t>(wrapping) < + static_cast<size_t>(PromptWrapping::kSentinel); +} enum class LayerAttentionType { kGemma, @@ -55,63 +72,68 @@ enum class LayerAttentionType { kVit, }; -inline bool EnumValid(LayerAttentionType type) { - return static_cast<int>(type) >= 0 && - static_cast<int>(type) <= static_cast<int>(LayerAttentionType::kVit); +static inline bool EnumValid(LayerAttentionType type) { + return type == LayerAttentionType::kGemma || + type == LayerAttentionType::kGriffinRecurrentBlock || + type == LayerAttentionType::kVit; } // Post attention and ffw normalization type. enum class PostNormType { None, Scale, + kSentinel // must be last }; -inline bool EnumValid(PostNormType type) { - return static_cast<int>(type) >= 0 && - static_cast<int>(type) <= static_cast<int>(PostNormType::Scale); +static inline bool EnumValid(PostNormType type) { + return static_cast<size_t>(type) < + static_cast<size_t>(PostNormType::kSentinel); } // Post qk projection operation type. enum class PostQKType { Rope, HalfRope, + kSentinel // must be last }; -inline bool EnumValid(PostQKType type) { - return static_cast<int>(type) >= 0 && - static_cast<int>(type) <= static_cast<int>(PostQKType::HalfRope); +static inline bool EnumValid(PostQKType type) { + return static_cast<size_t>(type) < + static_cast<size_t>(PostQKType::kSentinel); } // FFW activation function. enum class ActivationType { Gelu, + kSentinel // must be last }; -inline bool EnumValid(ActivationType type) { - return static_cast<int>(type) >= 0 && - static_cast<int>(type) <= static_cast<int>(ActivationType::Gelu); +static inline bool EnumValid(ActivationType type) { + return static_cast<size_t>(type) < + static_cast<size_t>(ActivationType::kSentinel); } // Attention query scale. enum class QueryScaleType { SqrtKeySize, SqrtModelDimDivNumHeads, + kSentinel // must be last }; -inline bool EnumValid(QueryScaleType type) { - return static_cast<int>(type) >= 0 && - static_cast<int>(type) <= - static_cast<int>(QueryScaleType::SqrtModelDimDivNumHeads); +static inline bool EnumValid(QueryScaleType type) { + return static_cast<size_t>(type) < + static_cast<size_t>(QueryScaleType::kSentinel); } // Residual connection type. enum class ResidualType { Add, + kSentinel // must be last }; -inline bool EnumValid(ResidualType type) { - return static_cast<int>(type) >= 0 && - static_cast<int>(type) <= static_cast<int>(ResidualType::Add); +static inline bool EnumValid(ResidualType type) { + return static_cast<size_t>(type) < + static_cast<size_t>(ResidualType::kSentinel); } template @@ -137,17 +159,15 @@ std::vector RepeatedAttentionWindowSizes( // Model variants: see configs.cc for details. enum class Model { - UNKNOWN, - GEMMA_2B, - GEMMA_7B, - GEMMA2_9B, + UNKNOWN = 0, + // 1 and 2 are obsolete. + GEMMA2_9B = 3, GEMMA2_27B, GRIFFIN_2B, - GEMMA_TINY, + GEMMA_TINY, // for testing only GEMMA2_2B, - PALIGEMMA_224, - PALIGEMMA_448, - PALIGEMMA2_3B_224, + // 8 and 9 are obsolete. + PALIGEMMA2_3B_224 = 10, PALIGEMMA2_3B_448, PALIGEMMA2_10B_224, PALIGEMMA2_10B_448, @@ -155,43 +175,64 @@ enum class Model { GEMMA3_1B, GEMMA3_12B, GEMMA3_27B, + kSentinel, }; -// Allows the Model enum to be iterated over.
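The kSentinel convention introduced above, in miniature: keeping the sentinel as the last enumerator lets one comparison validate the entire enum, and it stays correct when new values are added before the sentinel:

#include <cstddef>
#include <cstdio>

enum class Wrapping { kIT, kPT, kVLM, kPaliGemma, kSentinel /* must be last */ };

static bool EnumValid(Wrapping w) {
  return static_cast<size_t>(w) < static_cast<size_t>(Wrapping::kSentinel);
}

int main() {
  printf("%d %d\n", EnumValid(Wrapping::kVLM),
         EnumValid(static_cast<Wrapping>(99)));  // 1 0
  return 0;
}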
-static constexpr Model kAllModels[] = { - Model::GEMMA_2B, Model::GEMMA_7B, Model::GEMMA2_9B, Model::GEMMA2_27B, - Model::GRIFFIN_2B, Model::GEMMA_TINY, Model::GEMMA2_2B, - Model::PALIGEMMA_224, Model::PALIGEMMA_448, Model::PALIGEMMA2_3B_224, - Model::PALIGEMMA2_3B_448, Model::PALIGEMMA2_10B_224, - Model::PALIGEMMA2_10B_448, Model::GEMMA3_4B, Model::GEMMA3_1B, - Model::GEMMA3_12B, Model::GEMMA3_27B, -}; +// Returns canonical model name without the PromptWrapping suffix. This is used +// in Specifier and thus does not change. +const char* ModelPrefix(Model model); -inline bool EnumValid(Model model) { - for (Model m : kAllModels) { - if (m == model) return true; +// Gemma3 is multimodal and has a different prompt wrapping than PaliGemma. +// This is used for deducing the PromptWrapping for pre-2025 BlobStore. +static inline bool IsVLM(Model model) { + return model == Model::GEMMA3_4B || model == Model::GEMMA3_1B || + model == Model::GEMMA3_12B || model == Model::GEMMA3_27B; +} + +static inline bool IsPaliGemma(Model model) { + if (model == Model::PALIGEMMA2_3B_224 || model == Model::PALIGEMMA2_3B_448 || + model == Model::PALIGEMMA2_10B_224 || + model == Model::PALIGEMMA2_10B_448) { + return true; } return false; } +// Visits every valid model enum, skipping `UNKNOWN` and `kSentinel`. +template +void ForEachModel(const Func& func) { + for (size_t i = static_cast(Model::GEMMA2_9B); + i < static_cast(Model::kSentinel); ++i) { + if (i == 8 || i == 9) continue; + func(static_cast(i)); + } +} + +static inline bool EnumValid(Model model) { + // Valid for purposes of serialization, even if unknown. + if (model == Model::UNKNOWN) return true; + const size_t i = static_cast(model); + if (i >= static_cast(Model::GEMMA2_9B) && + i < static_cast(Model::kSentinel) && i != 8 && i != 9) { + return true; + } + return false; +} + +struct InternalLayerConfig : public IFields { + const char* Name() const override { return "InternalLayerConfig"; } + + // Source of truth for field ordering. + void VisitFields(IFieldsVisitor& visitor) override { + // Append new fields here, then update `python/configs.cc`. + } +}; + +// Per-layer configuration. struct LayerConfig : public IFields { - // Returns true if *this and other are equal. - // If partial is true, then we don't check for items that are only set after - // the tensors are loaded from the checkpoint. - // If debug is true, then we output the mismatched fields to stderr. - bool TestEqual(const LayerConfig& other, bool partial, bool debug) const; - - size_t CacheLayerSize() const { return kv_heads * qkv_dim * 2; } - - // Multi-Head Attention? - bool IsMHA() const { return heads == kv_heads; } - - // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim, - // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V. - size_t QStride() const { return qkv_dim * (IsMHA() ? 3 : 1); } - const char* Name() const override { return "LayerConfig"; } + // Source of truth for field ordering. void VisitFields(IFieldsVisitor& visitor) override { visitor(model_dim); visitor(griffin_dim); @@ -208,35 +249,41 @@ struct LayerConfig : public IFields { visitor(activation); visitor(post_qk); visitor(use_qk_norm); + internal.VisitFields(visitor); + // Append new fields here, then update `python/configs.cc`. } + // Returns whether all fields match. + bool TestEqual(const LayerConfig& other, bool print) const; + + size_t CacheLayerSize() const { return kv_heads * qkv_dim * 2; } + + // Multi-Head Attention? 
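ForEachModel above walks the enum numerically, starting at the first live value, skipping the retired slots 8 and 9, and stopping before kSentinel, so retired values are never visited. The pattern with a hypothetical enum:

#include <cstdio>

enum class Model { kUnknown = 0, kA = 3, kB, kC, kD, kE, /* 8, 9 retired */ kF = 10, kSentinel };

template <class Func>
void ForEach(const Func& func) {
  for (int i = static_cast<int>(Model::kA);
       i < static_cast<int>(Model::kSentinel); ++i) {
    if (i == 8 || i == 9) continue;  // retired values stay retired
    func(static_cast<Model>(i));
  }
}

int main() {
  ForEach([](Model m) { printf("%d ", static_cast<int>(m)); });  // 3 4 5 6 7 10
  printf("\n");
  return 0;
}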
+ bool IsMHA() const { return heads == kv_heads; } + uint32_t model_dim = 0; uint32_t griffin_dim = 0; uint32_t ff_hidden_dim = 0; uint32_t heads = 0; uint32_t kv_heads = 0; - uint32_t qkv_dim = 0; - uint32_t conv1d_width = 0; // griffin only + uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous). + uint32_t conv1d_width = 0; // Griffin only bool ff_biases = false; - bool softmax_attn_output_biases = false; - bool optimized_gating = true; + bool softmax_attn_output_biases = false; // for Griffin + bool optimized_gating = true; // for Gemma3 PostNormType post_norm = PostNormType::None; LayerAttentionType type = LayerAttentionType::kGemma; ActivationType activation = ActivationType::Gelu; PostQKType post_qk = PostQKType::Rope; bool use_qk_norm = false; + InternalLayerConfig internal; }; // Dimensions related to image processing. struct VitConfig : public IFields { - // Returns true if *this and other are equal. - // If partial is true, then we don't check for items that are only set after - // the tensors are loaded from the checkpoint. - // If debug is true, then we output the mismatched fields to stderr. - bool TestEqual(const VitConfig& other, bool partial, bool debug) const; - const char* Name() const override { return "VitConfig"; } + // Source of truth for field ordering. void VisitFields(IFieldsVisitor& visitor) override { visitor(model_dim); visitor(seq_len); @@ -245,8 +292,12 @@ struct VitConfig : public IFields { visitor(image_size); visitor(layer_configs); visitor(pool_dim); + // Append new fields here, then update `python/configs.cc`. } + // Returns whether all fields match. + bool TestEqual(const VitConfig& other, bool print) const; + uint32_t model_dim = 0; uint32_t seq_len = 0; uint32_t num_scales = 0; @@ -256,20 +307,96 @@ struct VitConfig : public IFields { std::vector layer_configs; }; +// Returns a valid `PromptWrapping` for the given `model`, for passing to the +// `ModelConfig` ctor when the caller does not care about the wrapping. The +// wrapping mode is either determined by the model (for PaliGemma and Gemma3), +// or defaults to IT, subject to user override for PT. +PromptWrapping ChooseWrapping(Model model, + Tristate wrapping = Tristate::kDefault); + +struct InternalModelConfig : public IFields { + const char* Name() const override { return "InternalModelConfig"; } + + // Source of truth for field ordering. + void VisitFields(IFieldsVisitor& visitor) override { + // Append new fields here, then update `python/configs.cc`. + } +}; + struct ModelConfig : public IFields { - // Returns true if *this and other are equal. - // If partial is true, then we don't check for items that are only set after - // the tensors are loaded from the checkpoint. - // If debug is true, then we output the mismatched fields to stderr. - bool TestEqual(const ModelConfig& other, bool partial, bool debug) const; + // Preferred usage (single-file format): default-construct, then deserialize + // from a blob. Also used by `config_converter.py`, which sets sufficient + // fields for `TestEqual` and then calls `OverwriteWithCanonical()`. + ModelConfig() = default; + // For use by `model_store.cc` for pre-2025 format after deducing the model + // from tensors plus a user-specified `wrapping` override (`ChooseWrapping`). + ModelConfig(Model model, Type weight, PromptWrapping wrapping); + // Parses a string returned by `Specifier()`. Used by the exporter to select + // the model from command line arguments. 
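A worked example of CacheLayerSize above, which charges kv_heads * qkv_dim * 2 values (K plus V) per cached token per layer. Using the Gemma2-2B layer shape visible in the removed test configs further down (kv_heads = 4, qkv_dim = 256, 26 layers):

#include <cstddef>
#include <cstdio>

int main() {
  const size_t kv_heads = 4, qkv_dim = 256, num_layers = 26;
  const size_t cache_layer_size = kv_heads * qkv_dim * 2;  // 2048: K and V
  printf("per layer: %zu values, per position: %zu values\n",
         cache_layer_size, num_layers * cache_layer_size);  // 2048, 53248
  return 0;
}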
Do not use this elsewhere - the + // second ctor is preferred because it is type-checked. + ModelConfig(const std::string& specifier); + + const char* Name() const override { return "ModelConfig"; } + + // Source of truth for field ordering. + void VisitFields(IFieldsVisitor& visitor) override { + visitor(model_family_version); + visitor(display_name); + visitor(model); + visitor(wrapping); + visitor(weight); + + visitor(num_layers); + visitor(model_dim); + visitor(vocab_size); + visitor(max_seq_len); + + visitor(unused_num_tensor_scales); + + visitor(att_cap); + visitor(final_cap); + + visitor(absolute_pe); + visitor(use_local_attention); + visitor(query_scale); + visitor(layer_configs); + visitor(attention_window_sizes); + visitor(norm_num_groups); + visitor(vit_config); + visitor(pool_dim); + + visitor(eos_id); + visitor(secondary_eos_id); + + visitor(scale_base_names); + + internal.VisitFields(visitor); + + // Append new fields here, then update `python/configs.cc`. + } + + // Returns whether all fields match except `model` and `display_name`, and + // some others that are not yet set by config_converter.py. This is for + // internal use by `OverwriteWithCanonical`, but potentially useful elsewhere. + bool TestEqual(const ModelConfig& other, bool print) const; + + // For each model, constructs its canonical `ModelConfig` and if `TestEqual` + // returns true, overwrites `*this` with that. Otherwise, returns false to + // indicate this is not a known model. Called by `config_converter.py`. + bool OverwriteWithCanonical(); + + // Returns a string encoding of the model family, size, weight, and + // `PromptWrapping`. Stable/unchanging; can be used as the model file name. + // The third ctor also expects a string returned by this. + std::string Specifier() const; void AddLayerConfig(const LayerConfig& layer_config) { layer_configs.push_back(layer_config); + HWY_ASSERT(layer_configs.size() <= num_layers); } - size_t CachePosSize() const { - size_t num_layers = layer_configs.size(); - return num_layers * layer_configs[0].CacheLayerSize(); + bool IsGlobalLayer(size_t layer_idx) const { + return attention_window_sizes[layer_idx] == max_seq_len; } size_t NumLayersOfTypeBefore(LayerAttentionType type, size_t num) const { @@ -287,77 +414,77 @@ struct ModelConfig : public IFields { size_t NumHeads() const { uint32_t num_heads = 0; for (const auto& layer_config : layer_configs) { - num_heads = std::max(num_heads, layer_config.heads); + num_heads = HWY_MAX(num_heads, layer_config.heads); } return num_heads; } - const char* Name() const override { return "ModelConfig"; } + size_t KVCacheCols() const { + size_t num_layers = layer_configs.size(); + return num_layers * layer_configs[0].CacheLayerSize(); + } bool IsEOS(int id) const { return (id == eos_id || id == secondary_eos_id); } - void VisitFields(IFieldsVisitor& visitor) override { - visitor(model_family_version); - visitor(model_name); - visitor(model); - visitor(wrapping); - visitor(weight); - visitor(num_layers); - visitor(model_dim); - visitor(vocab_size); - visitor(seq_len); - visitor(num_tensor_scales); - visitor(att_cap); - visitor(final_cap); - visitor(absolute_pe); - visitor(use_local_attention); - visitor(query_scale); - visitor(layer_configs); - visitor(attention_window_sizes); - visitor(norm_num_groups); - visitor(vit_config); - visitor(pool_dim); - visitor(eos_id); - visitor(secondary_eos_id); - } - - // Major version of the model family. 
It is used as a fallback to distinguish - // between model types when there is no explicit information in the config. + // Major version of the model family, reflecting architecture changes. This is + // more convenient to compare than `Model` because that also includes the + // model size. uint32_t model_family_version = 1; - std::string model_name; - Model model = Model::UNKNOWN; + // For display only, may change. Use `Specifier()` for setting the + // file name. Not checked by `TestEqual` because `config_converter.py` does + // not set this. + std::string display_name; + Model model = Model::UNKNOWN; // Not checked by `TestEqual`, see above. PromptWrapping wrapping = PromptWrapping::GEMMA_PT; Type weight = Type::kUnknown; + uint32_t num_layers = 0; uint32_t model_dim = 0; uint32_t vocab_size = 0; - uint32_t seq_len = 0; - uint32_t num_tensor_scales = 0; + uint32_t max_seq_len = 0; + + // We no longer set nor use this: config_converter is not able to set this, + // and only pre-2025 format stores scales, and we do not require advance + // knowledge of how many there will be. Any scales present will just be + // assigned in order to the tensors matching `scale_base_names`. + uint32_t unused_num_tensor_scales = 0; + float att_cap = 0.0f; float final_cap = 0.0f; + bool absolute_pe = false; - bool use_local_attention = false; // griffin only + bool use_local_attention = false; // Griffin only QueryScaleType query_scale = QueryScaleType::SqrtKeySize; std::vector layer_configs; std::vector attention_window_sizes; - std::unordered_set scale_names; uint32_t norm_num_groups = 1; + // Dimensions related to image processing. VitConfig vit_config; uint32_t pool_dim = 1; // used only for VitConfig copy + int eos_id = 1; int secondary_eos_id = 1; + + // Tensor base names without a layer suffix, used by `ModelStore` only for + // pre-2025 format. + std::vector scale_base_names; + + InternalModelConfig internal; }; -// Returns the config for the given model. -ModelConfig ConfigFromModel(Model model); - -// Returns the model for the given config, if it matches any standard model. -Model ModelFromConfig(const ModelConfig& config); - // Returns the sub-config for the ViT model of the PaliGemma model. ModelConfig GetVitConfig(const ModelConfig& config); +enum DeducedLayerTypes { + kDeducedGriffin = 1, + kDeducedViT = 2, + kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224. +}; + +// layer_types is one or more of `DeducedLayerTypes`. 
+Model DeduceModel(const Path& blob_path, size_t layers, int layer_types); + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_ diff --git a/gemma/configs_test.cc b/gemma/configs_test.cc index 3efd2cb..0ca4a84 100644 --- a/gemma/configs_test.cc +++ b/gemma/configs_test.cc @@ -1,461 +1,44 @@ #include "gemma/configs.h" -#include -#include -#include -#include +#include + +#include #include #include "gtest/gtest.h" -#include "hwy/aligned_allocator.h" +#include "compression/types.h" // Type +#include "io/fields.h" // Type namespace gcpp { -template -constexpr std::array OldFixedLayerConfig( - LayerAttentionType type) { - std::array config = {}; - for (LayerAttentionType& l : config) { - l = type; - } - return config; -} +TEST(ConfigsTest, TestAll) { + ForEachModel([&](Model model) { + ModelConfig config(model, Type::kSFP, ChooseWrapping(model)); + fprintf(stderr, "Testing %s (%s)\n", config.display_name.c_str(), + config.Specifier().c_str()); + HWY_ASSERT(config.model == model); -template -constexpr std::array OldFixedAttentionWindowSizes( - size_t window_size) { - std::array window_size_configs = {}; - for (size_t& l : window_size_configs) { - l = window_size; - } - return window_size_configs; -} + // We can deduce the model/display_name from all other fields. + config.model = Model::UNKNOWN; + const std::string saved_display_name = config.display_name; + config.display_name.clear(); + HWY_ASSERT(config.OverwriteWithCanonical()); + HWY_ASSERT(config.model == model); + HWY_ASSERT(config.display_name == saved_display_name); -// Repeat window_size_pattern for kNum / kPatternSize times. -template -constexpr std::array OldRepeatedAttentionWindowSizes( - const std::array& window_size_pattern) { - static_assert(kNum % kPatternSize == 0, - "kNum must be a multiple of kPatternSize"); - std::array window_size_configs = {}; - for (size_t i = 0; i < kNum; ++i) { - window_size_configs[i] = window_size_pattern[i % kPatternSize]; - } - return window_size_configs; -} - -template -constexpr size_t OldNumLayersOfTypeBefore( - const std::array& layers, - LayerAttentionType type, size_t num) { - size_t count = 0; - for (size_t i = 0; i < num; i++) { - if (layers[i] == type) count++; - } - return count; -} - -template -struct CacheLayerSize { - constexpr size_t operator()() const { - return TConfig::kKVHeads * TConfig::kQKVDim * 2; - } -}; - -template -struct CachePosSize { - constexpr size_t operator()() const { - return TConfig::kGemmaLayers * CacheLayerSize()(); - } -}; - -struct OldConfigNoVit { - struct VitConfig { - // Some of these are needed to make the compiler happy when trying to - // generate code that will actually never be used. - using Weight = float; - static constexpr int kLayers = 0; - static constexpr std::array kLayerConfig = - OldFixedLayerConfig<0>(LayerAttentionType::kVit); - static constexpr int kModelDim = 0; - static constexpr int kFFHiddenDim = 0; - static constexpr int kHeads = 1; // Avoid division by 0 in griffin gate_w. 
- static constexpr int kKVHeads = 0; - static constexpr int kQKVDim = 0; - static constexpr int kSeqLen = 0; - static constexpr ResidualType kResidual = ResidualType::Add; - static constexpr int kGriffinLayers = 0; - static constexpr int kConv1dWidth = 0; - static constexpr bool kFFBiases = false; - static constexpr bool kSoftmaxAttnOutputBiases = false; - static constexpr PostNormType kPostNorm = PostNormType::None; - }; -}; - -struct OldConfigNoSSM : OldConfigNoVit { - static constexpr int kGriffinLayers = 0; - - static constexpr int kConv1dWidth = 0; - static constexpr bool kFFBiases = false; - static constexpr bool kSoftmaxAttnOutputBiases = false; - static constexpr bool kUseHalfRope = false; - static constexpr bool kUseLocalAttention = false; - static constexpr bool kInterleaveQKV = true; - static constexpr PostQKType kPostQK = PostQKType::Rope; - static constexpr ActivationType kActivation = ActivationType::Gelu; - static constexpr ResidualType kResidual = ResidualType::Add; -}; - -struct OldConfigBaseGemmaV1 : OldConfigNoSSM { - static constexpr float kAttCap = 0.0f; - static constexpr float kFinalCap = 0.0f; - static constexpr PostNormType kPostNorm = PostNormType::None; - static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; -}; - -struct OldConfigBaseGemmaV2 : OldConfigNoSSM { - static constexpr float kAttCap = 50.0f; - static constexpr float kFinalCap = 30.0f; - static constexpr PostNormType kPostNorm = PostNormType::Scale; -}; - -template -struct OldConfigGemma2_27B : public OldConfigBaseGemmaV2 { - using Weight = TWeight; // make accessible where we only have a TConfig - - static constexpr int kSeqLen = 8192; - static constexpr int kVocabSize = gcpp::kVocabSize; - static constexpr std::array kLayerConfig = - OldFixedLayerConfig<46>(LayerAttentionType::kGemma); - static constexpr std::array kAttentionWindowSizes = - OldRepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen}); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kNumTensorScales = 4 * kLayers; - static constexpr int kGemmaLayers = kLayers; - static constexpr int kModelDim = 4608; - static constexpr int kFFHiddenDim = 16 * 4608 / 2; // = 36864 - static constexpr int kHeads = 32; - static constexpr int kKVHeads = 16; - static constexpr int kQKVDim = 128; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; - static constexpr QueryScaleType kQueryScale = - QueryScaleType::SqrtModelDimDivNumHeads; -}; - -template -struct OldConfigGemma2_9B : public OldConfigBaseGemmaV2 { - using Weight = TWeight; // make accessible where we only have a TConfig - - static constexpr int kSeqLen = 8192; - static constexpr int kVocabSize = gcpp::kVocabSize; - static constexpr std::array kLayerConfig = - OldFixedLayerConfig<42>(LayerAttentionType::kGemma); - static constexpr std::array kAttentionWindowSizes = - OldRepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen}); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kNumTensorScales = 4 * kLayers; - static constexpr int kGemmaLayers = kLayers; - static constexpr int kModelDim = 3584; - static constexpr int kFFHiddenDim = 8 * 3584 / 2; // = 14336 - static constexpr int kHeads = 16; - static constexpr int kKVHeads = 8; - static constexpr int kQKVDim = 256; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; - static constexpr QueryScaleType kQueryScale = 
QueryScaleType::SqrtKeySize; -}; - -template -struct OldConfigGemma7B : public OldConfigBaseGemmaV1 { - using Weight = TWeight; // make accessible where we only have a TConfig - - static constexpr int kSeqLen = gcpp::kSeqLen; - static constexpr int kVocabSize = gcpp::kVocabSize; - static constexpr std::array kLayerConfig = - OldFixedLayerConfig<28>(LayerAttentionType::kGemma); - static constexpr std::array kAttentionWindowSizes = - OldFixedAttentionWindowSizes<28>(kSeqLen); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kNumTensorScales = 4 * kLayers; - static constexpr int kGemmaLayers = kLayers; - static constexpr int kModelDim = 3072; - static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576 - static constexpr int kHeads = 16; - static constexpr int kKVHeads = 16; // standard MHA - static constexpr int kQKVDim = 256; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; -}; - -template -struct OldConfigGemma2B : public OldConfigBaseGemmaV1 { - using Weight = TWeight; // make accessible where we only have a TConfig - - static constexpr int kSeqLen = gcpp::kSeqLen; - static constexpr int kVocabSize = gcpp::kVocabSize; - static constexpr std::array kLayerConfig = - OldFixedLayerConfig<18>(LayerAttentionType::kGemma); - static constexpr std::array kAttentionWindowSizes = - OldFixedAttentionWindowSizes<18>(kSeqLen); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kNumTensorScales = 4 * kLayers; - static constexpr int kGemmaLayers = kLayers; - static constexpr int kModelDim = 2048; - static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 - static constexpr int kHeads = 8; - static constexpr int kKVHeads = 1; - static constexpr int kQKVDim = 256; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; -}; - -template -struct OldConfigPaliGemma_224 : public OldConfigGemma2B { - // On the LM side, the vocab size is one difference to Gemma1-2B in the - // architecture. PaliGemma adds 1024 and 128 tokens. - static constexpr int kVocabSize = 256000 + 1024 + 128; // = 257152 - - // Sub-config for the Vision-Transformer part. - struct VitConfig : public OldConfigNoSSM { - using Weight = TWeight; - // The ViT parts. https://arxiv.org/abs/2305.13035 - // "SoViT-400m/14 [...] has a width of 1152, depth 27, and MLP dim 4304." - static constexpr std::array kLayerConfig = - OldFixedLayerConfig<27>(LayerAttentionType::kVit); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kNumTensorScales = 4 * kLayers; - static constexpr int kModelDim = 1152; - static constexpr int kFFHiddenDim = 4304; - static constexpr int kHeads = 16; - static constexpr int kKVHeads = 16; // standard MHA - static constexpr int kQKVDim = 72; - static constexpr int kSeqLen = 16 * 16; // 256 - static constexpr bool kFFBiases = true; - // The Vit part does not have a vocabulary, the image patches are embedded. - static constexpr int kVocabSize = 0; - // Dimensions related to image processing. - static constexpr int kPatchWidth = 14; - static constexpr int kImageSize = 224; - // Necessary constant for the layer configuration. 
- static constexpr PostNormType kPostNorm = PostNormType::None; - }; -}; - -template -struct OldConfigGemma2_2B : public OldConfigBaseGemmaV2 { - using Weight = TWeight; // make accessible where we only have a TConfig - - static constexpr int kSeqLen = 8192; - static constexpr int kVocabSize = gcpp::kVocabSize; - static constexpr std::array kLayerConfig = - OldFixedLayerConfig<26>(LayerAttentionType::kGemma); - static constexpr std::array kAttentionWindowSizes = - OldRepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen}); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kNumTensorScales = 4 * kLayers; - static constexpr int kGemmaLayers = kLayers; - static constexpr int kModelDim = 2304; - static constexpr int kFFHiddenDim = 8 * 2304 / 2; // = 9216 - static constexpr int kHeads = 8; - static constexpr int kKVHeads = 4; - static constexpr int kQKVDim = 256; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; - static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; -}; - -template -struct OldConfigGemmaTiny : public OldConfigNoSSM { - using Weight = TWeight; // make accessible where we only have a TConfig - - static constexpr int kSeqLen = 32; - static constexpr int kVocabSize = 64; - static constexpr std::array kLayerConfig = - OldFixedLayerConfig<3>(LayerAttentionType::kGemma); - static constexpr std::array kAttentionWindowSizes = - OldFixedAttentionWindowSizes<3>(kSeqLen); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kNumTensorScales = 4 * kLayers; - static constexpr int kGemmaLayers = kLayers; - static constexpr int kModelDim = 128; - static constexpr int kFFHiddenDim = 256; - static constexpr int kHeads = 4; - static constexpr int kKVHeads = 1; - static constexpr int kQKVDim = 16; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; - static constexpr PostNormType kPostNorm = PostNormType::None; - static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; - - static constexpr float kAttCap = 0.0f; - // This is required for optimize_test to pass. - static constexpr float kFinalCap = 30.0f; -}; - -template -struct OldConfigGriffin2B : OldConfigNoVit { - using Weight = TWeight; // make accessible where we only have a TConfig - - // Griffin uses local attention, so kSeqLen is actually the local attention - // window. 
- static constexpr int kSeqLen = 2048; - static constexpr int kVocabSize = gcpp::kVocabSize; - static constexpr std::array kLayerConfig = { - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - }; - static constexpr std::array kAttentionWindowSizes = - OldFixedAttentionWindowSizes<26>(kSeqLen); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kGemmaLayers = OldNumLayersOfTypeBefore( - kLayerConfig, LayerAttentionType::kGemma, kLayers); - static constexpr int kGriffinLayers = OldNumLayersOfTypeBefore( - kLayerConfig, LayerAttentionType::kGriffinRecurrentBlock, kLayers); - static constexpr int kModelDim = 2560; - static constexpr int kFFHiddenDim = 7680; - static constexpr int kHeads = 10; - static constexpr int kKVHeads = 1; - static constexpr int kQKVDim = 256; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; - static constexpr PostNormType kPostNorm = PostNormType::None; - - // No SoftCap. - static constexpr float kAttCap = 0.0f; - static constexpr float kFinalCap = 0.0f; - - // SSM config. 
- static constexpr int kConv1dWidth = 4; - static constexpr bool kFFBiases = true; - static constexpr bool kSoftmaxAttnOutputBiases = true; - static constexpr bool kUseHalfRope = true; - static constexpr bool kUseLocalAttention = true; - static constexpr bool kInterleaveQKV = false; - static constexpr int kNumTensorScales = 140; - static constexpr PostQKType kPostQK = PostQKType::Rope; - static constexpr ActivationType kActivation = ActivationType::Gelu; - static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; - static constexpr ResidualType kResidual = ResidualType::Add; -}; - -template -void AssertMatch(const ModelConfig& config) { - ASSERT_EQ(TConfig::kModelDim, config.model_dim); - if constexpr (TConfig::VitConfig::kModelDim != 0) { - ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_config.model_dim); - ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_config.seq_len); - ASSERT_EQ(TConfig::VitConfig::kNumTensorScales, - config.vit_config.num_scales); - for (size_t i = 0; i < config.vit_config.layer_configs.size(); ++i) { - ASSERT_EQ(TConfig::VitConfig::kLayerConfig[i], - config.vit_config.layer_configs[i].type); - } - } - ASSERT_EQ(TConfig::kVocabSize, config.vocab_size); - ASSERT_EQ(TConfig::kSeqLen, config.seq_len); - ASSERT_EQ(TConfig::kAttCap, config.att_cap); - ASSERT_EQ(TConfig::kFinalCap, config.final_cap); - ASSERT_EQ(TConfig::kAbsolutePE, config.absolute_pe); - ASSERT_EQ(TConfig::kUseLocalAttention, config.use_local_attention); - ASSERT_EQ(TConfig::kQueryScale, config.query_scale); - ASSERT_EQ(TConfig::kGemmaLayers, - config.NumLayersOfType(LayerAttentionType::kGemma)); - ASSERT_EQ(TConfig::kGriffinLayers, - config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock)); - for (size_t i = 0; i < config.layer_configs.size(); ++i) { - ASSERT_EQ(TConfig::kModelDim, config.layer_configs[i].model_dim); - ASSERT_EQ(TConfig::kFFHiddenDim, config.layer_configs[i].ff_hidden_dim); - ASSERT_EQ(TConfig::kHeads, config.layer_configs[i].heads); - ASSERT_EQ(TConfig::kKVHeads, config.layer_configs[i].kv_heads); - ASSERT_EQ(TConfig::kQKVDim, config.layer_configs[i].qkv_dim); - ASSERT_EQ(TConfig::kConv1dWidth, config.layer_configs[i].conv1d_width); - ASSERT_EQ(TConfig::kFFBiases, config.layer_configs[i].ff_biases); - ASSERT_EQ(TConfig::kSoftmaxAttnOutputBiases, - config.layer_configs[i].softmax_attn_output_biases); - ASSERT_EQ(TConfig::kPostNorm, config.layer_configs[i].post_norm); - ASSERT_EQ(TConfig::kLayerConfig[i], config.layer_configs[i].type); - ASSERT_EQ(TConfig::kActivation, config.layer_configs[i].activation); - PostQKType post_qk = TConfig::kPostQK; - if (TConfig::kUseHalfRope) { - post_qk = PostQKType::HalfRope; - } - ASSERT_EQ(post_qk, config.layer_configs[i].post_qk); - } - - ASSERT_EQ(TConfig::kAttentionWindowSizes.size(), - config.attention_window_sizes.size()); - for (size_t i = 0; i < config.attention_window_sizes.size(); ++i) { - ASSERT_EQ(TConfig::kAttentionWindowSizes[i], - config.attention_window_sizes[i]); - } - ASSERT_EQ(TConfig::kNumTensorScales, config.num_tensor_scales); -} - -ModelConfig RoundTripSerialize(const ModelConfig& config) { - std::vector config_buffer = config.Write(); - ModelConfig deserialized; - deserialized.Read(hwy::Span(config_buffer), 0); - return deserialized; -} - -TEST(ConfigsTest, OldConfigGemma2B) { - AssertMatch>(ConfigFromModel(Model::GEMMA_2B)); - ModelConfig config = RoundTripSerialize(ConfigFromModel(Model::GEMMA_2B)); - AssertMatch>(config); -} - -TEST(ConfigsTest, OldConfigGemma7B) { - 
-
-TEST(ConfigsTest, OldConfigGemma2B) {
-  AssertMatch<OldConfigGemma2B<float>>(ConfigFromModel(Model::GEMMA_2B));
-  ModelConfig config = RoundTripSerialize(ConfigFromModel(Model::GEMMA_2B));
-  AssertMatch<OldConfigGemma2B<float>>(config);
-}
-
-TEST(ConfigsTest, OldConfigGemma7B) {
-  AssertMatch<OldConfigGemma7B<float>>(ConfigFromModel(Model::GEMMA_7B));
-}
-
-TEST(ConfigsTest, OldConfigGemma2_2B) {
-  AssertMatch<OldConfigGemma2_2B<float>>(ConfigFromModel(Model::GEMMA2_2B));
-}
-
-TEST(ConfigsTest, OldConfigGemma2_9B) {
-  AssertMatch<OldConfigGemma2_9B<float>>(ConfigFromModel(Model::GEMMA2_9B));
-}
-
-TEST(ConfigsTest, OldConfigGemma2_27B) {
-  AssertMatch<OldConfigGemma2_27B<float>>(ConfigFromModel(Model::GEMMA2_27B));
-}
-
-TEST(ConfigsTest, OldConfigGriffin2B) {
-  AssertMatch<OldConfigGriffin2B<float>>(ConfigFromModel(Model::GRIFFIN_2B));
-}
-
-TEST(ConfigsTest, OldConfigGemmaTiny) {
-  AssertMatch<OldConfigGemmaTiny<float>>(ConfigFromModel(Model::GEMMA_TINY));
-}
-
-TEST(ConfigsTest, OldConfigPaliGemma_224) {
-  AssertMatch<OldConfigPaliGemma_224<float>>(
-      ConfigFromModel(Model::PALIGEMMA_224));
+    const std::vector<uint32_t> serialized = config.Write();
+    ModelConfig deserialized;
+    const IFields::ReadResult result =
+        deserialized.Read(hwy::Span<const uint32_t>(serialized), /*pos=*/0);
+    HWY_ASSERT(result.pos == serialized.size());
+    // We wrote it, so all fields should be known, and no extra.
+    HWY_ASSERT(result.extra_u32 == 0);
+    HWY_ASSERT(result.missing_fields == 0);
+    // All fields should match.
+    HWY_ASSERT(deserialized.TestEqual(config, /*print=*/true));
+    HWY_ASSERT(deserialized.model == model);
+    HWY_ASSERT(deserialized.display_name == saved_display_name);
+  });
 }
 
 }  // namespace gcpp
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index ccb34f0..e3f9a19 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -13,33 +13,18 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// SIMD functions for Gemma/Griffin transformers.
+// Transformer components shared between vit.cc and attention.cc.
 
-#include <math.h>  // sqrtf
 #include <stddef.h>
-#include <stdio.h>
+#include <stdint.h>
 
-#include <algorithm>  // std::min
-#include <type_traits>
-#include <vector>
-
-#include "compression/compress.h"
 #include "gemma/activations.h"
-#include "gemma/common.h"
 #include "gemma/configs.h"
-#include "gemma/gemma.h"
-#include "gemma/kv_cache.h"
 #include "gemma/weights.h"
-#include "paligemma/image.h"
-#include "util/allocator.h"
-#include "util/basics.h"
+#include "ops/matmul.h"
+#include "util/mat.h"
 #include "util/threading.h"
-#include "hwy/aligned_allocator.h"
-#include "hwy/base.h"
-#include "hwy/bit_set.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/profiler.h"
-#include "hwy/timer.h"
 
 // Include guard (still compiled once per target)
 #if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_) == \
@@ -52,709 +37,17 @@
 #include "hwy/highway.h"
 
 // After highway.h
-#include "ops/matmul-inl.h"
-#include "ops/matvec-inl.h"
 #include "ops/ops-inl.h"
 
-#ifndef GEMMA_TYPE
-#if HWY_IDE
-// Provide a definition so the IDE does not complain.
-#define GEMMA_TYPE float
-#else
-#error "Only include from instantiations/*.cc, which must define GEMMA_TYPE"
-#endif  // HWY_IDE
-#endif  // GEMMA_TYPE
-
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 
-// Different functions use different naming conventions for the number of
-// tokens. Functions that are query-independent, such as RMSNorm*, call the
-// count `num_interleaved`. Functions that are query-dependent, such as
-// `Attention`, use separate `num_tokens` and `num_queries`.
-
-// TODO: add batch query support for Griffin (QueriesPos).
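The naming-convention comment above is worth pinning down: a `num_interleaved` batch packs `num_tokens` tokens from each of `num_queries` queries, with the query index varying fastest. A small helper showing the index arithmetic the attention code below relies on (illustrative, not part of the codebase):

    #include <cstddef>

    struct Interleaved {
      size_t query_idx;  // which query this row belongs to
      size_t batch_idx;  // token offset within that query's batch
    };

    // Matches `interleaved_idx % num_queries_` / `interleaved_idx / num_queries_`
    // as used in ComputeQKV and DotSoftmaxWeightedSum.
    inline Interleaved Split(size_t interleaved_idx, size_t num_queries) {
      return {interleaved_idx % num_queries, interleaved_idx / num_queries};
    }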
-template <typename T>
-HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
-                                   size_t layer, Activations& activations,
-                                   const LayerWeightsPtrs<T>* layer_weights,
-                                   const KVCaches& kv_caches) {
-  PROFILER_ZONE("Gen.Griffin");
-  KVCache& kv_cache = kv_caches[0];
-  hwy::ThreadPool& pool = activations.env->parallel.Pools().Pool(0);
-  namespace hn = hwy::HWY_NAMESPACE;
-  using D = hn::ScalableTag<float>;
-  const size_t model_dim = layer_weights->layer_config.model_dim;
-  const size_t conv_1d_width = layer_weights->layer_config.conv1d_width;
-  const size_t heads = layer_weights->layer_config.heads;
-
-  // X / Y linear layers.
-  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-    float* HWY_RESTRICT y = activations.griffin_y.Batch(batch_idx);
-    float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
-    TwoMatVecAdd(layer_weights->griffin.linear_x_w,
-                 layer_weights->griffin.linear_y_w, 0, model_dim, model_dim,
-                 activations.pre_att_rms_out.Batch(batch_idx),
-                 /*add0=*/layer_weights->griffin.linear_x_biases.data_scale1(),
-                 /*add1=*/layer_weights->griffin.linear_y_biases.data_scale1(),
-                 /*out0=*/x, /*out1=*/y, pool);
-    Gelu(y, model_dim);
-  }
-
-  // Conv1D.
-  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-    const size_t pos = batch_start + batch_idx;
-    float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
-    HWY_FULL(float) df;
-    HWY_DASSERT(model_dim % hn::Lanes(df) == 0);
-    const size_t layer_offset = layer * model_dim * (conv_1d_width - 1);
-
-    // cache[i] = input at time t-i.
-    float* HWY_RESTRICT cache[kMaxConv1DWidth];
-    cache[0] = x;
-    for (size_t i = 1; i < conv_1d_width; i++) {
-      cache[i] =
-          kv_cache.conv1d_cache.get() + layer_offset +
-          ((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
-    }
-    for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
-      auto xv = hn::Load(df, x + i);
-      auto accum0 =
-          hn::Load(df, layer_weights->griffin.conv_biases.data_scale1() + i);
-      auto accum1 = hn::Zero(df);
-      HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
-      for (size_t l = 0; 2 * l < conv_1d_width; l++) {
-        auto wv0 =
-            hn::Load(df, layer_weights->griffin.conv_w.data_scale1() +
-                         (conv_1d_width - 1 - 2 * l) * model_dim + i);
-        auto wv1 =
-            hn::Load(df, layer_weights->griffin.conv_w.data_scale1() +
-                         (conv_1d_width - 2 - 2 * l) * model_dim + i);
-        accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
-        accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
-      }
-      hn::Store(hn::Add(accum0, accum1), df, x + i);
-      hn::Store(xv, df, cache[HWY_MAX(conv_1d_width, 1) - 1] + i);
-    }
-  }
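The RGLRU section that follows implements Griffin's gated linear recurrence with SIMD vectors. A scalar sketch of the per-channel recurrence, assuming `log_a` already contains the gated log-decay that the code derives from the recurrence gate (an interpretation of the `hn::Exp` / `hn::Sqrt(NegMulAdd(...))` sequence below, not the gemma.cpp API):

    #include <cmath>
    #include <cstddef>

    // h[t] = a * h[t-1] + sqrt(1 - a^2) * x[t], with h[-1] = 0. The
    // sqrt(1 - a^2) input scaling keeps the state magnitude roughly
    // independent of the decay; at t == 0 the code uses multiplier 1 instead.
    inline void RglruScanChannel(const float* log_a, const float* x, float* h,
                                 size_t num_steps) {
      float state = 0.0f;
      for (size_t t = 0; t < num_steps; ++t) {
        const float a = std::exp(log_a[t]);
        const float x_mul = (t == 0) ? 1.0f : std::sqrt(1.0f - a * a);
        state = a * state + x_mul * x[t];
        h[t] = state;
      }
    }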
-
-  // RGLRU
-  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-    const size_t pos = batch_start + batch_idx;
-    float* HWY_RESTRICT y = activations.griffin_y.Batch(batch_idx);
-    float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
-    float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Batch(batch_idx);
-    float* HWY_RESTRICT a = activations.griffin_multiplier.Batch(batch_idx);
-    float* HWY_RESTRICT rnn_state =
-        kv_cache.rglru_cache.get() + layer * model_dim;
-
-    pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
-      const size_t kHeadDim = model_dim / heads;
-      const size_t kMatrixSize = kHeadDim * kHeadDim;
-      size_t head_offset = head * kHeadDim;
-      TwoOfsMatVecAddLoop(
-          layer_weights->griffin.gate_w, kMatrixSize * head,
-          kMatrixSize * (heads + head), kHeadDim, kHeadDim, x + head_offset,
-          /*add0=*/layer_weights->griffin.gate_biases.data_scale1() +
-              head_offset,
-          /*add1=*/layer_weights->griffin.gate_biases.data_scale1() +
-              model_dim + head_offset,
-          /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
-      Sigmoid(gate_x + head_offset, kHeadDim);
-      Sigmoid(a + head_offset, kHeadDim);
-      const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
-                              HWY_ATTR { return hn::Mul(x, gate_x); };
-      hn::Transform1(D(), a + head_offset, kHeadDim,
-                     layer_weights->griffin.a.data_scale1() + head_offset,
-                     fn_mul);
-      hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
-                     fn_mul);
-      // RNN scan
-      HWY_FULL(float) df;
-      HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0);
-      for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) {
-        auto log_a = hn::Load(df, a + head_offset + i);
-        auto gated_x = hn::Load(df, x + head_offset + i);
-        auto rnn = hn::Load(df, rnn_state + head_offset + i);
-        auto a = hn::Exp(df, log_a);
-        auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0f)));
-        if (pos == 0) {
-          x_multiplier = hn::Set(df, 1.0f);
-        }
-        auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
-        hn::Store(new_x, df, rnn_state + head_offset + i);
-
-        // Join branches.
-        auto yv = hn::Load(df, y + head_offset + i);
-        auto pre_out = hn::Mul(yv, new_x);
-        hn::Store(pre_out, df, x + head_offset + i);
-      }
-    });
-  }
-
-  // Final linear layer.
-  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-    float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
-    float* out_ptr = activations.att_sums.Batch(batch_idx);
-    MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x,
-              layer_weights->griffin.linear_out_biases.data_scale1(), out_ptr,
-              pool);
-  }
-}
-
-// Wrapper class; holds arguments in member variables to shorten call sites.
-template <typename T>
-class GemmaAttention {
-  // The attention window usually starts at 0 unless `pos` is larger than
-  // the attention window size, then it is `pos` - window_size + 1.
-  HWY_INLINE size_t StartPos(size_t pos, size_t layer) {
-    const size_t att_window_size =
-        activations_.weights_config.attention_window_sizes[layer];
-    return pos - std::min(att_window_size - 1, pos);
-  }
-
-  template <typename U>
-  HWY_INLINE void PositionalEncodingQK(U* qk, size_t pos, size_t layer,
-                                       const float mul) {
-    // qk is either q or k, so qkv_dim is the length we operate on.
-    const size_t qkv_dim = layer_config_.qkv_dim;
-    const float* inv_timescale = activations_.inv_timescale.Const();
-    bool is_global_layer =
-        activations_.weights_config.attention_window_sizes[layer] ==
-        activations_.seq_len;
-    // TODO: add a config flag instead of hardcoding the model.
-    if (is_global_layer &&
-        (activations_.weights_config.model == Model::GEMMA3_4B ||
-         activations_.weights_config.model == Model::GEMMA3_12B ||
-         activations_.weights_config.model == Model::GEMMA3_27B ||
-         activations_.weights_config.model == Model::GEMMA3_1B)) {
-      inv_timescale = activations_.inv_timescale_global.Const();
-    }
-    // PostQKType::Rope
-    (void)layer;
-    if (layer_weights_.layer_config.post_qk == PostQKType::HalfRope) {
-      Rope(qk, qkv_dim / 2, inv_timescale, pos);
-      if (mul != 1.0f) MulByConst(mul, qk, qkv_dim);
-    } else {
-      RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos);
-    }
-  }
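PositionalEncodingQK above delegates to Rope / RopeAndMulBy. A scalar sketch of rotary position embedding under the split-half pairing convention assumed here (dimension i rotates together with dimension i + N/2); for HalfRope, the same rotation is applied to only the first qkv_dim/2 dimensions, with the optional scale applied separately:

    #include <cmath>
    #include <cstddef>

    inline void RopeScalar(float* x, size_t dim, const float* inv_timescale,
                           size_t pos, float mul) {
      const size_t half = dim / 2;
      for (size_t i = 0; i < half; ++i) {
        const float theta = static_cast<float>(pos) * inv_timescale[i];
        const float c = std::cos(theta), s = std::sin(theta);
        const float x0 = x[i], x1 = x[i + half];
        x[i] = mul * (x0 * c - x1 * s);
        x[i + half] = mul * (x0 * s + x1 * c);
      }
    }

Because the rotation is a rigid rotation of each (x0, x1) pair, dot products between rotated q and k depend only on the relative position difference, which is the property RoPE exists to provide.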
-
-  // Fills activations.q and computes KV. For is_mha_, a single MatMul suffices
-  // and we later copy KV from q to KVCache. Otherwise, a second MatMul writes
-  // KV directly to KVCache.
-  HWY_NOINLINE void ComputeQKV(const size_t num_interleaved) {
-    PROFILER_ZONE("Gen.Attention.QKV");
-    const size_t model_dim = layer_config_.model_dim;
-    const size_t qkv_dim = layer_config_.qkv_dim;
-    const size_t heads = layer_config_.heads;
-    const size_t kv_heads = layer_config_.kv_heads;
-
-    const auto pre_att_rms_out =
-        ConstMatFromBatch(num_interleaved, activations_.pre_att_rms_out);
-    auto w_q1 = layer_weights_.qkv_einsum_w.data()
-                    ? ConstMatFromWeights(layer_weights_.qkv_einsum_w)
-                    : ConstMatFromWeights(layer_weights_.qkv_einsum_w1);
-    // The original qkv_einsum_w has shape [(heads + kv_heads * 2), kQKVDim,
-    // model_dim], which we reshaped to (heads + kv_heads * 2) * kQKVDim rows.
-    // We must shrink to the actual size because MatMul verifies
-    // `B.extents.rows == C.Cols()`. If MHA, `QStride() == 3 * qkv_dim` and all
-    // rows are used. Otherwise, `QStride() == qkv_dim` and KV will be
-    // computed in the second MatMul.
-    const size_t w1_rows = heads * layer_config_.QStride();
-    w_q1.ShrinkRows(w1_rows);
-    MatMul(pre_att_rms_out, w_q1,
-           /*add=*/nullptr, *activations_.env,
-           RowPtrFromBatch(activations_.q));
-
-    if (is_mha_) {
-      // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
-    } else {
-      auto w_q2 = layer_weights_.qkv_einsum_w.data()
-                      ? ConstMatFromWeights(layer_weights_.qkv_einsum_w,
-                                            w1_rows * model_dim)
-                      : ConstMatFromWeights(layer_weights_.qkv_einsum_w2);
-      // KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v).
-      const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim;
-      w_q2.ShrinkRows(w_rows_kv_cols);
-
-      // Single query and no wraparound means we can use a matmul and write
-      // directly into the KV cache with a stride of cache_pos_size_.
-      if (num_queries_ == 1 &&
-          queries_pos_[0] + num_tokens_ <= div_seq_len_.GetDivisor()) {
-        const size_t kv_ofs =
-            queries_pos_[0] * cache_pos_size_ + layer_ * cache_layer_size_;
-        float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
-        RowPtrF kv_rows(kv, w_rows_kv_cols);
-        kv_rows.SetStride(cache_pos_size_);
-        MatMul(pre_att_rms_out, w_q2,
-               /*add=*/nullptr, *activations_.env, kv_rows);
-      } else {
-        // Proceed row by row because there will be wraparound.
-        for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
-             ++interleaved_idx) {
-          const float* x = activations_.pre_att_rms_out.Batch(interleaved_idx);
-          const size_t query_idx = interleaved_idx % num_queries_;
-          const size_t batch_idx = interleaved_idx / num_queries_;
-          KVCache& kv_cache = kv_caches_[query_idx];
-          const size_t cache_pos =
-              div_seq_len_.Remainder(queries_pos_[query_idx] + batch_idx);
-          const size_t kv_offset =
-              cache_pos * cache_pos_size_ + layer_ * cache_layer_size_;
-          float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
-          if (layer_weights_.qkv_einsum_w.data()) {
-            MatVec(layer_weights_.qkv_einsum_w, heads * qkv_dim * model_dim,
-                   w_rows_kv_cols, model_dim, x, kv, pool_);
-          } else {
-            MatVec(layer_weights_.qkv_einsum_w2, 0,  //
-                   w_rows_kv_cols, model_dim, x, kv, pool_);
-          }
-        }
-      }
-    }  // !is_mha_
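The code above and the loop below address the KV cache through the same flat layout. As a standalone helper (names mirror the member variables; this is illustrative, and the real code replaces the modulo with a precomputed hwy::Divisor):

    #include <cstddef>

    // The cache is a flat array indexed by (position ring slot, layer,
    // kv head), where each kv head stores an interleaved [K, V] pair.
    inline size_t KVOffset(size_t pos, size_t seq_len, size_t layer,
                           size_t cache_layer_size, size_t cache_pos_size,
                           size_t head, size_t qkv_dim) {
      const size_t cache_pos = pos % seq_len;  // ring-buffer wraparound
      return cache_pos * cache_pos_size +      // row for this position
             layer * cache_layer_size +       // block for this layer
             head * qkv_dim * 2;              // interleaved K then V
    }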
-
-    // Apply positional encodings for K (and copy KV to cache if MHA).
-    pool_.Run(0, kv_heads * num_interleaved,
-              [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
-                const size_t head = task % kv_heads;
-                const size_t interleaved_idx = task / kv_heads;
-                const size_t query_idx = interleaved_idx % num_queries_;
-                const size_t batch_idx = interleaved_idx / num_queries_;
-                const size_t pos = queries_pos_[query_idx] + batch_idx;
-                const size_t cache_pos = div_seq_len_.Remainder(pos);
-                const size_t kv_offset = cache_pos * cache_pos_size_ +
-                                         layer_ * cache_layer_size_ +
-                                         head * qkv_dim * 2;
-                KVCache& kv_cache = kv_caches_[query_idx];
-                float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
-                // If MHA, copy computed K and V into KVCache.
-                if (is_mha_) {
-                  const float* HWY_RESTRICT mha_kv =
-                      activations_.q.Batch(interleaved_idx) + head * q_stride_ +
-                      qkv_dim;
-                  hwy::CopyBytes(mha_kv, kv, 2 * qkv_dim * sizeof(*kv));
-                }
-
-                // Apply further processing to K.
-                if (layer_weights_.key_norm_scale.data()) {
-                  RMSNormInplace(layer_weights_.key_norm_scale.data(), kv,
-                                 qkv_dim);
-                }
-                PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f);
-              });
-  }
-
-  // Computes Q.K scores, which are "logits" (or scores) stored to head_att.
-  HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
-                        const size_t head_offset, const float* HWY_RESTRICT q,
-                        const KVCache& kv_cache,
-                        float* HWY_RESTRICT head_att) {
-    const size_t qkv_dim = layer_config_.qkv_dim;
-    if (HWY_LIKELY(last_pos < activations_.seq_len)) {
-      // Slightly faster: no wraparound.
-      for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const size_t kv_offset =
-            pos * cache_pos_size_ + layer_ * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
-        const float score = Dot(q, k, qkv_dim);
-        head_att[pos] = score;
-      }
-    } else {
-      for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const size_t cache_pos = div_seq_len_.Remainder(pos);
-        const size_t kv_offset = cache_pos * cache_pos_size_ +
-                                 layer_ * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
-        const float score = Dot(q, k, qkv_dim);
-        head_att[pos % activations_.seq_len] = score;
-      }
-    }
-  }
-
-  // Accumulates the sum of v (from `kv_cache`) * probability (`head_att`) into
-  // `att_out`. Equivalent in gemma/modules.py:
-  // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
-  HWY_INLINE void WeightedSumV(const size_t start_pos, const size_t last_pos,
-                               const float* HWY_RESTRICT head_att,
-                               const size_t layer, const size_t head_offset,
-                               const hwy::Divisor& div_seq_len,
-                               const KVCache& kv_cache,
-                               float* HWY_RESTRICT att_out) const {
-    const size_t qkv_dim = layer_config_.qkv_dim;
-    hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
-
-    if (HWY_LIKELY(last_pos < activations_.seq_len)) {
-      // Slightly faster: no wraparound.
-      for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const size_t kv_offset =
-            pos * cache_pos_size_ + layer * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT v =
-            kv_cache.kv_cache.get() + kv_offset + qkv_dim;
-        MulByConstAndAdd(head_att[pos], v, att_out, qkv_dim);
-      }
-    } else {
-      for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-        const size_t cache_pos = div_seq_len.Remainder(pos);
-        const size_t kv_offset = cache_pos * cache_pos_size_ +
-                                 layer * cache_layer_size_ + head_offset;
-        const float* HWY_RESTRICT v =
-            kv_cache.kv_cache.get() + kv_offset + qkv_dim;
-        MulByConstAndAdd(head_att[pos % activations_.seq_len], v, att_out,
-                         qkv_dim);
-      }
-    }
-  }
-
-  HWY_NOINLINE void DotSoftmaxWeightedSum(const size_t num_interleaved) {
-    PROFILER_ZONE("Gen.Attention.DotSoftmax");
-    const float query_scale = ChooseQueryScale(activations_.weights_config);
-
-    // A "head group" in the context of GQA refers to a collection of query
-    // heads that share the same key and value heads.
-    const size_t kHeadGroups = layer_config_.heads / layer_config_.kv_heads;
-
-    // For each head (token, query), compute Q.K, softmax, and weighted V.
-    pool_.Run(0, layer_config_.heads * num_interleaved,
-              [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
-                const size_t head = task % layer_config_.heads;
-                const size_t interleaved_idx = task / layer_config_.heads;
-                const size_t query_idx = interleaved_idx % num_queries_;
-                const size_t batch_idx = interleaved_idx / num_queries_;
-                const size_t qkv_dim = layer_config_.qkv_dim;
-                const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
-                KVCache& kv_cache = kv_caches_[query_idx];
-                float* HWY_RESTRICT q =
-                    activations_.q.Batch(interleaved_idx) + head * q_stride_;
-
-                // Apply rope and scaling to Q.
-                const size_t pos = queries_pos_[query_idx] + batch_idx;
-                if (layer_weights_.query_norm_scale.data()) {
-                  RMSNormInplace(layer_weights_.query_norm_scale.data(), q,
-                                 qkv_dim);
-                }
-                PositionalEncodingQK(q, pos, layer_, query_scale);
-
-                const size_t start_pos = StartPos(pos, layer_);
-                size_t last_pos = pos;
-                const size_t prefix_end = queries_prefix_end_[query_idx];
-                if (prefix_end > 0 && prefix_end - 1 > last_pos) {
-                  // last_pos in QDotK and WeightedSumV is inclusive.
-                  last_pos = prefix_end - 1;
-                }
-
-                float* HWY_RESTRICT head_att =
-                    activations_.att.Batch(interleaved_idx) +
-                    head * activations_.seq_len;
-                QDotK(start_pos, last_pos, head_offset, q, kv_cache, head_att);
-                // SoftMax with optional SoftCap yields "probabilities" in
-                // head_att.
-                const size_t head_att_len =
-                    std::min(last_pos + 1, activations_.seq_len);
-                MaybeLogitsSoftCap(activations_.weights_config.att_cap,
-                                   head_att, head_att_len);
-                Softmax(head_att, head_att_len);
-
-                float* HWY_RESTRICT att_out =
-                    activations_.att_out.Batch(interleaved_idx) +
-                    head * qkv_dim;
-                WeightedSumV(start_pos, last_pos, head_att, layer_, head_offset,
-                             div_seq_len_, kv_cache, att_out);
-              });
-  }
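DotSoftmaxWeightedSum above fuses QDotK, the optional logit soft-cap, Softmax, and WeightedSumV per head. A scalar reference for one head, assuming no cache wraparound (last_pos < seq_len) and contiguous [seq, dim] K and V arrays rather than the strided cache layout (a sketch, not the SIMD implementation):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    inline void AttendOneHead(const float* q, const float* k_rows,
                              const float* v_rows, size_t dim,
                              size_t start_pos, size_t last_pos,
                              float att_cap, float* out) {
      std::vector<float> att(last_pos + 1);
      float max_score = -1e30f;
      for (size_t p = start_pos; p <= last_pos; ++p) {
        float score = 0.0f;
        for (size_t d = 0; d < dim; ++d) score += q[d] * k_rows[p * dim + d];
        // Optional soft cap: squash scores into (-att_cap, att_cap).
        if (att_cap > 0.0f) score = att_cap * std::tanh(score / att_cap);
        att[p] = score;
        max_score = std::max(max_score, score);
      }
      float sum = 0.0f;  // Numerically stable softmax over the window.
      for (size_t p = start_pos; p <= last_pos; ++p) {
        att[p] = std::exp(att[p] - max_score);
        sum += att[p];
      }
      for (size_t d = 0; d < dim; ++d) out[d] = 0.0f;
      for (size_t p = start_pos; p <= last_pos; ++p) {
        const float w = att[p] / sum;
        for (size_t d = 0; d < dim; ++d) out[d] += w * v_rows[p * dim + d];
      }
    }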
-
-  // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
-  // head_dim (`qkv_dim`) into output (`layer_out`).
-  HWY_NOINLINE void SumHeads(const size_t num_interleaved) {
-    PROFILER_ZONE("Gen.Attention.SumHeads");
-    // att_weights and att_out are concatenated heads, each of length
-    // layer_config_.qkv_dim. Thus the [num_interleaved,
-    // layer_config_.model_dim] matmul output is the sum over heads. Compare
-    // gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD',
-    // encoded)
-    HWY_DASSERT(layer_config_.model_dim > 0);
-    HWY_DASSERT(layer_config_.heads > 0);
-    HWY_DASSERT(layer_config_.qkv_dim > 0);
-    HWY_DASSERT(layer_weights_.att_weights.data() != nullptr);
-    HWY_DASSERT(activations_.att_out.All() != nullptr);
-    HWY_DASSERT(activations_.att_sums.All() != nullptr);
-
-    const float* add =
-        layer_weights_.layer_config.softmax_attn_output_biases
-            ? layer_weights_.attention_output_biases.data_scale1()
-            : nullptr;
-    MatMul(ConstMatFromBatch(num_interleaved, activations_.att_out),
-           ConstMatFromWeights(layer_weights_.att_weights), add,
-           *activations_.env, RowPtrFromBatch(activations_.att_sums));
-  }
-
- public:
-  // Constructor with explicit initialization of queries_prefix_end. This is
-  // needed for the Prefix-LM style attention. For standard causal attention,
-  // the other constructor can be used.
-  GemmaAttention(const QueriesPos& queries_pos,
-                 const QueriesPos& queries_prefix_end, size_t num_tokens,
-                 size_t layer, Activations& activations,
-                 const LayerWeightsPtrs<T>* layer_weights,
-                 const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
-      : GemmaAttention(queries_pos, &queries_prefix_end, num_tokens, layer,
-                       activations, layer_weights, div_seq_len, kv_caches) {}
-  // Constructor with default initialization to 0 for queries_prefix_end.
-  GemmaAttention(const QueriesPos& queries_pos, size_t num_tokens,
-                 size_t layer, Activations& activations,
-                 const LayerWeightsPtrs<T>* layer_weights,
-                 const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
-      : GemmaAttention(queries_pos, nullptr, num_tokens, layer, activations,
-                       layer_weights, div_seq_len, kv_caches) {}
-
-  // Full attention computation in three steps.
-  HWY_INLINE void operator()() {
-    const size_t num_interleaved = num_tokens_ * num_queries_;
-    ComputeQKV(num_interleaved);
-    DotSoftmaxWeightedSum(num_interleaved);
-    SumHeads(num_interleaved);
-  }
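For reference, the GQA head grouping used in DotSoftmaxWeightedSum, factored into a standalone helper (illustrative): with `heads` query heads and `kv_heads` KV heads, each group of heads / kv_heads consecutive query heads reads the same (K, V) pair, which is what `head_offset = (head / kHeadGroups) * qkv_dim * 2` encodes.

    #include <cstddef>

    inline size_t KVHeadForQueryHead(size_t head, size_t heads,
                                     size_t kv_heads) {
      const size_t head_groups = heads / kv_heads;  // e.g. 10 / 1 for MQA
      return head / head_groups;
    }

    // Usage: byte/element offset of the shared K for a given query head.
    //   head_offset = KVHeadForQueryHead(head, heads, kv_heads) * qkv_dim * 2;

Multi-query attention (kv_heads == 1, as in the Griffin config above) is the extreme case where every query head shares a single KV pair, shrinking the cache by a factor of `heads`.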
-
- private:
-  // Delegated Constructor that does most of the common work.
-  GemmaAttention(const QueriesPos& queries_pos,
-                 const QueriesPos* queries_prefix_end, size_t num_tokens,
-                 size_t layer, Activations& activations,
-                 const LayerWeightsPtrs<T>* layer_weights,
-                 const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
-      : queries_pos_(queries_pos),
-        num_queries_(queries_pos.size()),
-        num_tokens_(num_tokens),
-        layer_(layer),
-        layer_config_(layer_weights->layer_config),
-        q_stride_(layer_config_.QStride()),
-        cache_layer_size_(layer_weights->layer_config.CacheLayerSize()),
-        cache_pos_size_(activations.cache_pos_size),
-        is_mha_(layer_config_.IsMHA()),
-        activations_(activations),
-        layer_weights_(*layer_weights),
-        div_seq_len_(div_seq_len),
-        kv_caches_(kv_caches),
-        pool_(activations.env->parallel.Pools().Pool(0)) {
-    HWY_DASSERT(num_queries_ <= kv_caches_.size());
-    HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0,
-                  "query heads must be a multiple of key-value heads");
-    if (queries_prefix_end != nullptr) {
-      queries_prefix_end_ = *queries_prefix_end;
-    } else {
-      queries_prefix_end_vec_.assign(num_queries_, 0);
-      queries_prefix_end_ = QueriesPos(queries_prefix_end_vec_.data(),
-                                       queries_prefix_end_vec_.size());
-    }
-  }
-
-  const QueriesPos& queries_pos_;
-  std::vector<size_t> queries_prefix_end_vec_;
-  QueriesPos queries_prefix_end_;
-  const size_t num_queries_;
-  const size_t num_tokens_;
-  const size_t layer_;
-  const LayerConfig& layer_config_;
-  const size_t q_stride_ = 0;
-  const size_t cache_layer_size_ = 0;
-  const size_t cache_pos_size_ = 0;
-  const bool is_mha_ = false;
-
-  Activations& activations_;
-  const LayerWeightsPtrs<T>& layer_weights_;
-  const hwy::Divisor& div_seq_len_;
-  const KVCaches& kv_caches_;
-  hwy::ThreadPool& pool_;
-};
-
-template <typename T>
-HWY_NOINLINE void Attention(
-    LayerAttentionType type, const QueriesPos& queries_pos,
-    const QueriesPos& queries_prefix_end, size_t num_tokens, size_t layer,
-    Activations& activations, const LayerWeightsPtrs<T>* layer_weights,
-    const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) {
-  if (type == LayerAttentionType::kGemma) {
-    GemmaAttention<T>(queries_pos, queries_prefix_end, num_tokens, layer,
-                      activations, layer_weights, div_seq_len, kv_caches)();
-  } else {
-    // Only reached if the model is Griffin.
-    // The kv_caches are allocated only for the griffin layers, so we need to
-    // map the layer index to the griffin layer index.
-    auto type = layer_weights->layer_config.type;
-    size_t layer_of_type =
-        activations.weights_config.NumLayersOfTypeBefore(type, layer);
-    HWY_ASSERT(queries_pos.size() == 1);
-    GriffinRecurrent(queries_pos[0], num_tokens, layer_of_type, activations,
-                     layer_weights, kv_caches);
-  }
-}
-
-// Wrapper class; holds arguments in member variables to shorten call sites.
-// The main differences to GemmaAttention are:
-// - no KV Cache necessary, attention is always all-to-all and not causal.
-// - no potential wrap-around, attention always goes from 0 to kSeqLen.
-// - no need for batching, as we are always computing attention for kSeqLen
-//   tokens.
-// This results in a much simpler implementation. However, to avoid duplicating
-// code, we should still consider merging the two classes.
-// TODO(keysers): Refactor to share code with GemmaAttention.
-template <typename T>
-class VitAttention {
-  // Computes Q, K, V for all heads, stored in activations_.q.
-  HWY_NOINLINE void ComputeQKV() {
-    PROFILER_ZONE("Gen.VitAttention.QKV");
-    auto& qkv = activations_.q;
-    HWY_ASSERT(qkv.BatchSize() == num_tokens_);
-    HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
-    MatMul(ConstMatFromBatch(num_tokens_, activations_.pre_att_rms_out),
-           ConstMatFromWeights(layer_weights_.vit.qkv_einsum_w),
-           layer_weights_.vit.qkv_einsum_b.data_scale1(), *activations_.env,
-           RowPtrFromBatch(qkv));
-  }
-
-  // TODO(philculliton): transition fully to MatMul.
-  HWY_NOINLINE void DotSoftmaxWeightedSumMatrix() {
-    const size_t qkv_dim = layer_config_.qkv_dim;
-    const size_t heads = layer_config_.heads;
-    HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
-    const size_t seq_len = activations_.seq_len;
-    const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
-    PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
-
-    // Shift Q, K, VT to RowVectorBatches with AllocateAlignedRows(extents)
-    RowVectorBatch<float> Q =
-        AllocateAlignedRows<float>(Extents2D(num_tokens_, qkv_dim));
-    RowVectorBatch<float> K =
-        AllocateAlignedRows<float>(Extents2D(seq_len, qkv_dim));
-    RowVectorBatch<float> C(Extents2D(num_tokens_, seq_len));
-
-    // Initialize att_out to zero prior to head loop.
-    hwy::ZeroBytes(activations_.att_out.All(),
-                   num_tokens_ * heads * qkv_dim * sizeof(float));
-
-    for (size_t head = 0; head < heads; ++head) {
-      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
-        const size_t token = task;
-        float* HWY_RESTRICT q =
-            activations_.q.Batch(token) + head * 3 * qkv_dim;
-        // TODO: shift to MatMul with A.scale once MatMul is confirmed working
-        MulByConst(query_scale, q, qkv_dim);
-        hwy::CopyBytes(q, Q.Batch(token), qkv_dim * sizeof(float));
-      });
-
-      pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
-        const size_t seq_idx = task;
-        float* HWY_RESTRICT k =
-            activations_.q.Batch(seq_idx) + head * 3 * qkv_dim + qkv_dim;
-        hwy::CopyBytes(k, K.Batch(seq_idx), qkv_dim * sizeof(float));
-      });
-
-      // this produces C, a (num_tokens_, seq_len) matrix of dot products
-      MatMul(ConstMatFromBatch(Q.BatchSize(), Q),
-             ConstMatFromBatch(K.BatchSize(), K), nullptr, *activations_.env,
-             RowPtrFromBatch(C));
-
-      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
-        float* HWY_RESTRICT c = C.Batch(task);
-        Softmax(c, C.Cols());
-      });
-
-      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
-        size_t token = task;
-        float* HWY_RESTRICT att_out =
-            activations_.att_out.Batch(token) + head * qkv_dim;
-        for (size_t i = 0; i < seq_len; ++i) {
-          float* HWY_RESTRICT v =
-              activations_.q.Batch(i) + head * 3 * qkv_dim + 2 * qkv_dim;
-          MulByConstAndAdd(C.Batch(token)[i], v, att_out, qkv_dim);
-        }
-      });
-    }
-  }
-
-  HWY_NOINLINE void DotSoftmaxWeightedSum() {
-    const size_t qkv_dim = layer_config_.qkv_dim;
-    const size_t heads = layer_config_.heads;
-    HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
-    const size_t seq_len = activations_.seq_len;
-    const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
-    PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
-
-    // Compute Q.K, softmax, and weighted V.
-    pool_.Run(0, layer_config_.heads * num_tokens_,
-              [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
-                const size_t head = task % layer_config_.heads;
-                const size_t token = task / layer_config_.heads;
-                // Compute Q.K scores, which are "logits" stored in head_att.
-                float* HWY_RESTRICT q =
-                    activations_.q.Batch(token) + head * 3 * qkv_dim;
-                MulByConst(query_scale, q, qkv_dim);
-                float* HWY_RESTRICT head_att =
-                    activations_.att.Batch(token) + head * activations_.seq_len;
-                for (size_t i = 0; i < seq_len; ++i) {
-                  float* HWY_RESTRICT k =
-                      activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim;
-                  head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
-                }
-                // SoftMax yields "probabilities" in head_att.
-                Softmax(head_att, seq_len);
-                // Compute weighted sum of v into att_out.
-                float* HWY_RESTRICT att_out =
-                    activations_.att_out.Batch(token) + head * qkv_dim;
-                hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
-                for (size_t i = 0; i < seq_len; ++i) {
-                  float* HWY_RESTRICT v = activations_.q.Batch(i) +
-                                          head * 3 * qkv_dim + 2 * qkv_dim;
-                  MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
-                }
-              });
-  }
-
-  // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
-  // head_dim (`qkv_dim`) into output (`att_sums`).
-  HWY_NOINLINE void SumHeads() {
-    PROFILER_ZONE("Gen.VitAttention.SumHeads");
-    auto* bias = layer_weights_.vit.attn_out_b.data_scale1();
-    // att_weights and att_out are concatenated heads, each of length
-    // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
-    // matmul output is the sum over heads.
-    auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out);
-    auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w);
-    auto att_sums = RowPtrFromBatch(activations_.att_sums);
-    MatMul(att_out, att_weights, bias, *activations_.env, att_sums);
-  }
-
- public:
-  VitAttention(size_t num_tokens, size_t layer, Activations& activations,
-               const LayerWeightsPtrs<T>* layer_weights)
-      : num_tokens_(num_tokens),
-        layer_(layer),
-        activations_(activations),
-        layer_weights_(*layer_weights),
-        layer_config_(layer_weights->layer_config),
-        pool_(activations.env->parallel.Pools().Pool(0)) {}
-
-  HWY_INLINE void operator()() {
-    ComputeQKV();
-    if (activations_.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
-      DotSoftmaxWeightedSumMatrix();
-    } else {
-      DotSoftmaxWeightedSum();
-    }
-    SumHeads();
-  }
-
- private:
-  const size_t num_tokens_;
-  const size_t layer_;
-  Activations& activations_;
-  const LayerWeightsPtrs<T>& layer_weights_;
-  const LayerConfig& layer_config_;
-  hwy::ThreadPool& pool_;
-};
-
-template <typename T>
-HWY_NOINLINE void Activation(ActivationType activation, T* HWY_RESTRICT c1,
-                             T* HWY_RESTRICT c2, size_t count) {
-  PROFILER_ZONE("Gen.Activation");
+template <typename T>
+void Activation(ActivationType activation, T* HWY_RESTRICT c1,
+                const T* HWY_RESTRICT c2, const size_t count,
+                const size_t worker) {
+  PROFILER_ZONE2(worker, "Gen.Activation");
   namespace hn = hwy::HWY_NAMESPACE;
   using DF = hn::ScalableTag<T>;
   using VF = hn::Vec<DF>;
@@ -769,849 +62,88 @@ HWY_NOINLINE void Activation(ActivationType activation, T* HWY_RESTRICT c1,
   });
 }
 
-template <typename T>
-HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
-                           const LayerWeightsPtrs<T>* layer_weights) {
-  PROFILER_ZONE("Gen.FFW");
-  const size_t model_dim = layer_weights->layer_config.model_dim;
-  const size_t ffh_hidden_dim = layer_weights->layer_config.ff_hidden_dim;
-  HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
+// No C2 multiplier.
+template <class Mat>
+void ActivationBatched(ActivationType activation, Mat& c1, NestedPools& pools) {
+  using T = typename Mat::T;
+  const size_t pkg_idx = 0;
+  SmallParallelFor(
      c1.Rows(), pools, pkg_idx, [&](uint64_t task, size_t worker) {
        // Cast to correct type so type deduction works.
        Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
                   c1.Cols(), worker);
      });
+}
 
-  const bool add_bias = layer_weights->layer_config.ff_biases;
+template <class Mat>
+HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1,
+                                    const Mat* c2, NestedPools& pools) {
+  using T = typename Mat::T;
+  HWY_DASSERT(!c2 || c1.SameShape(*c2));
+  const size_t pkg_idx = 0;
+  if (c2 && c2->HasPtr()) {
+    SmallParallelFor(c1.Rows(), pools, pkg_idx,
+                     [&](uint64_t task, size_t worker) {
+                       Activation(activation, c1.Row(task), c2->Row(task),
+                                  c1.Cols(), worker);
+                     });
+  } else {  // No multiplier
+    SmallParallelFor(
        c1.Rows(), pools, pkg_idx, [&](uint64_t task, size_t worker) {
          Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
                     c1.Cols(), worker);
        });
+  }
+}
+
+template <typename T>
+HWY_NOINLINE void ResidualConnection(const MatPtrT<T>& other,
+                                     MatPtrT<T>& HWY_RESTRICT x,
+                                     const LayerWeights& layer,
+                                     bool is_attention, ThreadingContext& ctx) {
+  // ResidualType::Add
+  AddFromBatched(other, x, ctx);
+}
+
+template <typename T>
+void PostNorm(PostNormType post_norm, const MatPtr& weights,
+              MatPtrT<T>& inout, ThreadingContext& ctx) {
+  HWY_DASSERT(weights.Rows() == 1);
+  if (post_norm == PostNormType::Scale) {
+    RMSNormInplaceBatched(weights, inout, ctx);
+  }
+}
+
+static inline void FFWNoVit(const LayerWeightsPtrs& layer,
+                            Activations& activations, MatMulEnv& env) {
+  PROFILER_ZONE("Gen.FFW");
+  const LayerConfig& layer_config = layer.layer_config;
+  const size_t ffh_hidden_dim = layer_config.ff_hidden_dim;
+
+  const bool add_bias = layer_config.ff_biases;
   const float* bias1 =
-      add_bias ? layer_weights->ffw_gating_biases.data_scale1() : nullptr;
+      add_bias ? layer.ffw_gating_biases.PackedScale1() : nullptr;
   const float* bias2 = add_bias ? bias1 + ffh_hidden_dim : nullptr;
   const float* output_bias =
-      add_bias ? layer_weights->ffw_output_biases.data_scale1() : nullptr;
-
-  // Define slightly more readable names for the weights and activations.
-  const auto x =
-      ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
-
-  auto hidden_activations = RowPtrFromBatch(activations.C1);
-  auto multiplier = RowPtrFromBatch(activations.C2);
-  auto ffw_out = RowPtrFromBatch(activations.ffw_out);
-
-  // gating_einsum_w holds two half-matrices. We plan to change the importer to
-  // avoid this confusion by splitting into gating_einsum_w1 and
-  // gating_einsum_w2.
-  const bool split = !!layer_weights->gating_einsum_w.data();
-  auto w1 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w)
-                  : ConstMatFromWeights(layer_weights->gating_einsum_w1);
-  auto w2 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w,
-                                        model_dim * ffh_hidden_dim)
-                  : ConstMatFromWeights(layer_weights->gating_einsum_w2);
-  if (split) {
-    // Ensure that B.Extents().row matches C.Cols() because MatMul checks that.
-    w1.ShrinkRows(ffh_hidden_dim);
-    w2.ShrinkRows(ffh_hidden_dim);
-  }
-  auto w_output = ConstMatFromWeights(layer_weights->linear_w);
+      add_bias ? layer.ffw_output_biases.PackedScale1() : nullptr;
 
   // Compute the hidden layer activations.
-  MatMul(x, w1, bias1, *activations.env, hidden_activations);
-  MatMul(x, w2, bias2, *activations.env, multiplier);
+  CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, bias1, env,
+             activations.C1);
+  CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, bias2, env,
+             activations.C2);
 
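The two CallMatMul products above fill C1 and C2; the activation step that follows combines them into the gated hidden state. A scalar sketch of the gated FFW up to the hidden layer, using the tanh Gelu approximation (plain arrays instead of gemma.cpp matrix types):

    #include <cmath>
    #include <cstddef>

    inline float GeluTanh(float x) {
      const float kSqrt2OverPi = 0.7978845608f;
      return 0.5f * x *
             (1.0f + std::tanh(kSqrt2OverPi * (x + 0.044715f * x * x * x)));
    }

    // hidden = gelu(W1 x) * (W2 x): C1 holds the gelu branch, C2 the gate.
    inline void GatedHidden(const float* w1, const float* w2,  // [hidden, model]
                            const float* x, size_t model_dim, size_t hidden_dim,
                            float* hidden) {
      for (size_t h = 0; h < hidden_dim; ++h) {
        float dot1 = 0.0f, dot2 = 0.0f;
        for (size_t m = 0; m < model_dim; ++m) {
          dot1 += w1[h * model_dim + m] * x[m];
          dot2 += w2[h * model_dim + m] * x[m];
        }
        hidden[h] = GeluTanh(dot1) * dot2;
      }
    }

The final output projection (W_out applied to `hidden`, plus output_bias) then maps back to model_dim, as in the matmul that closes FFWNoVit below.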
   // Activation (Gelu) and maybe multiply by gate. Store activations in act.
-  Activation(layer_weights->layer_config.activation, hidden_activations.Row(0),
-             multiplier.Row(0), ffh_hidden_dim * num_interleaved);
+  ActivationBatched(layer_config.activation, activations.C1, &activations.C2,
+                    env.ctx.pools);
 
   // Hidden layer -> output layer.
-  auto activations_mat = MakeConstMat(
-      hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim),
-      hidden_activations.Stride());
-
-  MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
-}
+  CallMatMul(activations.C1, layer.linear_w, output_bias, env,
+             activations.ffw_out);
+}
 
-// Same as FFWNoVit, but with different layer_weights members and no second
-// gating matrix.
-template <typename T>
-HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
-                         const LayerWeightsPtrs<T>* layer_weights) {
-  PROFILER_ZONE("Gen.FFW");
-  const size_t ff_hidden_dim = layer_weights->layer_config.ff_hidden_dim;
-  HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
-
-  const bool add_bias = layer_weights->layer_config.ff_biases;
-  const float* bias1 =
-      add_bias ? layer_weights->vit.linear_0_b.data_scale1() : nullptr;
-  const float* output_bias =
-      add_bias ? layer_weights->vit.linear_1_b.data_scale1() : nullptr;
-
-  // Define slightly more readable names for the weights and activations.
-  const auto x =
-      ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
-
-  auto hidden_activations = RowPtrFromBatch(activations.C1);
-  auto ffw_out = RowPtrFromBatch(activations.ffw_out);
-
-  auto w1 = ConstMatFromWeights(layer_weights->vit.linear_0_w);
-  auto w_output = ConstMatFromWeights(layer_weights->vit.linear_1_w);
-
-  // Compute the hidden layer activations.
-  MatMul(x, w1, bias1, *activations.env, hidden_activations);
-
-  // Activation (Gelu), store in act.
-  RowPtrF multiplier = RowPtrF(nullptr, 0);
-  Activation(layer_weights->layer_config.activation, hidden_activations.Row(0),
-             multiplier.Row(0), ff_hidden_dim * num_interleaved);
-
-  // Hidden layer -> output layer.
-  auto activations_mat = MakeConstMat(hidden_activations.Row(0),
-                                      Extents2D(num_interleaved, ff_hidden_dim),
-                                      hidden_activations.Stride());
-
-  MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
-}
-
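EmbedMMToken, defined next, falls through to a plain table lookup for text tokens. A minimal float-only sketch of that path, assuming EmbeddingScaling(model_dim) is sqrt(model_dim) as in the Gemma convention, with the tensor's quantization scale folded into the same multiply (names are illustrative):

    #include <cmath>
    #include <cstddef>

    inline void EmbedTokenScalar(const float* embedding_table, size_t model_dim,
                                 int token, float tensor_scale, float* x) {
      // Scale embeddings up by sqrt(model_dim), per the Gemma convention.
      const float emb_scaling = std::sqrt(static_cast<float>(model_dim));
      const float* row = embedding_table + static_cast<size_t>(token) * model_dim;
      for (size_t i = 0; i < model_dim; ++i) {
        x[i] = emb_scaling * tensor_scale * row[i];
      }
    }

The real code additionally decompresses the (possibly quantized) embedding row via DecompressAndZeroPad before scaling, and handles image-token rows by copying from `image_tokens` instead.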
-// `batch_idx` indicates which row of `x` to write to.
-// `pos` is the *token*'s position, not the start of the batch, because this is
-// called for batches of tokens in prefill, but batches of queries in decode.
-//
-// For GEMMA_VLM, image tokens are copied into -2 locations (per the Gemma 3
-// spec) until we run out of image tokens. This allows for a multi-image prompt
-// if -2 locations with appropriate begin/end image tokens are created by the
-// calling application.
-template <typename T>
-HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos,
-                               size_t pos_in_prompt,
-                               const ModelWeightsPtrs<T>& weights,
-                               RowVectorBatch<float>& x,
-                               const ImageTokens* image_tokens,
-                               size_t& image_token_position) {
-  // Image tokens just need to be copied.
-  if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM &&
-      image_tokens != nullptr && token == -2 &&
-      image_token_position < image_tokens->BatchSize()) {
-    hwy::CopyBytes(image_tokens->Batch(image_token_position),
-                   x.Batch(batch_idx), x.Cols() * sizeof(x.Const()[0]));
-    image_token_position++;
-    return;
-  }
-
-  if (weights.weights_config.wrapping == PromptWrapping::PALIGEMMA &&
-      image_tokens != nullptr && pos_in_prompt < image_tokens->BatchSize()) {
-    hwy::CopyBytes(image_tokens->Batch(pos_in_prompt), x.Batch(batch_idx),
-                   x.Cols() * sizeof(x.Const()[0]));
-    return;
-  }
-
-  const size_t model_dim = weights.weights_config.model_dim;
-  const size_t vocab_size = weights.weights_config.vocab_size;
-  const float emb_scaling = EmbeddingScaling(model_dim);
-
-  HWY_DASSERT(token >= 0);
-  HWY_DASSERT(token < static_cast<int>(vocab_size));
-
-  const hn::ScalableTag<float> df;
-  DecompressAndZeroPad(
-      df,
-      MakeSpan(weights.embedder_input_embedding.data(), vocab_size * model_dim),
-      token * model_dim, x.Batch(batch_idx), model_dim);
-  MulByConst(emb_scaling * weights.embedder_input_embedding.scale(),
-             x.Batch(batch_idx), model_dim);
-  if (weights.weights_config.absolute_pe) {
-    AddAbsolutePositionalEmbeddings(x.Batch(batch_idx), model_dim, pos);
-  }
-}
-
-// `batch_idx` indicates which row of `x` to write to.
-// `pos` is the *token*'s position, not the start of the batch, because this is
-// called for batches of tokens in prefill, but batches of queries in decode.
-// This version of the function doesn't track internal image token position.
-template <typename T>
-HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
-                             size_t pos_in_prompt,
-                             const ModelWeightsPtrs<T>& weights,
-                             RowVectorBatch<float>& x,
-                             const ImageTokens* image_tokens) {
-  size_t image_token_position = 0;
-  EmbedMMToken(token, batch_idx, pos, pos_in_prompt, weights, x,
-               image_tokens, image_token_position);
-}
-
-template <typename T>
-HWY_NOINLINE void ResidualConnection(
-    size_t num_interleaved, T* HWY_RESTRICT other, T* HWY_RESTRICT x,
-    const LayerWeightsPtrs<T>* layer_weights, bool is_attention) {
-  // ResidualType::Add
-  AddFromBatched(num_interleaved, other, x,
-                 layer_weights->layer_config.model_dim);
-}
-
-template <typename WeightT, typename InOutT>
-void PostNorm(PostNormType post_norm, size_t num_interleaved,
-              const WeightT& weights, InOutT* inout) {
-  if (post_norm == PostNormType::Scale) {
-    RMSNormInplaceBatched(num_interleaved, weights.data_scale1(), inout,
-                          weights.NumElements());
-  }
-}
-
-template <typename T>
-HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos,
-                                   const QueriesPos& queries_prefix_end,
-                                   size_t num_tokens, size_t cache_layer_idx,
-                                   const LayerWeightsPtrs<T>* layer_weights,
-                                   Activations& activations,
-                                   const hwy::Divisor& div_seq_len,
-                                   const KVCaches& kv_caches) {
-  const size_t model_dim = activations.weights_config.model_dim;
-  const size_t num_interleaved = num_tokens * queries_pos.size();
-  auto type = layer_weights->layer_config.type;
-
-  RMSNormBatched(num_interleaved, activations.x.All(),
-                 layer_weights->pre_attention_norm_scale.data_scale1(),
-                 activations.pre_att_rms_out.All(), model_dim);
-
-  Attention(type, queries_pos, queries_prefix_end, num_tokens, cache_layer_idx,
-            activations, layer_weights, div_seq_len, kv_caches);
-
-  PostNorm(layer_weights->layer_config.post_norm, num_interleaved,
-           layer_weights->post_attention_norm_scale,
-           activations.att_sums.All());
-
-  ResidualConnection(num_interleaved, activations.att_sums.All(),
-                     activations.x.All(), layer_weights, /*is_attention=*/true);
-
-  RMSNormBatched(num_interleaved, activations.x.All(),
-                 layer_weights->pre_ffw_norm_scale.data_scale1(),
-                 activations.bf_pre_ffw_rms_out.All(), model_dim);
-
-  if (layer_weights->layer_config.type == LayerAttentionType::kVit) {
-    FFWVit(activations, num_interleaved, layer_weights);
-  } else {
-    FFWNoVit(activations, num_interleaved, layer_weights);
-  }
-
-  PostNorm(layer_weights->layer_config.post_norm, num_interleaved,
-           layer_weights->post_ffw_norm_scale, activations.ffw_out.All());
-
-  ResidualConnection(num_interleaved, activations.ffw_out.All(),
-                     activations.x.All(), layer_weights,
-                     /*is_attention=*/false);
-}
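TransformerLayer above wires the standard pre-norm residual block. The control flow, stripped of batching and SIMD, with the kernels passed in as stand-ins (an illustrative sketch, not the gemma.cpp API):

    #include <cstddef>
    #include <functional>
    #include <vector>

    using Vec = std::vector<float>;
    using Fn = std::function<Vec(const Vec&)>;

    inline Vec LayerForward(const Vec& x_in, const Fn& rms_norm,
                            const Fn& attention, const Fn& ffw) {
      Vec x = x_in;
      const Vec att = attention(rms_norm(x));  // pre-attention norm
      for (size_t i = 0; i < x.size(); ++i) x[i] += att[i];  // residual add
      const Vec ff = ffw(rms_norm(x));         // pre-FFW norm
      for (size_t i = 0; i < x.size(); ++i) x[i] += ff[i];   // residual add
      return x;
    }

The optional PostNorm (Gemma 2 style) would rescale `att` and `ff` before the residual adds; with PostNormType::None it is the identity, which is why PostNorm above is a no-op unless the config says Scale.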
-
-// Vit transformer layer. Some comments below refer to the Vit implementation
-// in the Big Vision codebase. See
-// github.com/google-research/big_vision/blob/main/big_vision/models/vit.py
-// TODO(keysers): consider adding a wrapper for both LayerNorm with RMSNorm and
-// try merging this with TransformerLayer.
-template <typename T>
-HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
-                                      const LayerWeightsPtrs<T>* layer_weights,
-                                      Activations& activations) {
-  const size_t model_dim = activations.weights_config.model_dim;
-  auto type = layer_weights->layer_config.type;
-  HWY_DASSERT(type == LayerAttentionType::kVit);
-  (void)type;
-
-  auto& x = activations.x;
-  HWY_DASSERT(x.BatchSize() == num_tokens);
-  HWY_DASSERT(x.Cols() == model_dim);
-
-  // y = nn.LayerNorm()(x)
-  // y ~ pre_att_rms_out
-  LayerNormBatched(num_tokens, x.All(),
-                   layer_weights->vit.layer_norm_0_scale.data_scale1(),
-                   layer_weights->vit.layer_norm_0_bias.data_scale1(),
-                   activations.pre_att_rms_out.All(), model_dim);
-
-  // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
-  // y ~ att_sums
-  VitAttention<T>(num_tokens, layer, activations, layer_weights)();
-
-  // x = out["+sa"] = x + y
-  AddFromBatched(num_tokens, activations.att_sums.All(), x.All(), model_dim);
-
-  // y = nn.LayerNorm()(x)
-  // y ~ bf_pre_ffw_rms_out
-  LayerNormBatched(num_tokens, x.All(),
-                   layer_weights->vit.layer_norm_1_scale.data_scale1(),
-                   layer_weights->vit.layer_norm_1_bias.data_scale1(),
-                   activations.bf_pre_ffw_rms_out.All(), model_dim);
-
-  // y = out["mlp"] = MlpBlock(...)(y)
-  // y ~ ffw_out
-  FFWVit(activations, num_tokens, layer_weights);
-
-  // x = out["+mlp"] = x + y
-  AddFromBatched(num_tokens, activations.ffw_out.All(), x.All(), model_dim);
-}
-
-// Prefill() and Transformer() increment positions in-place.
-using QueriesMutablePos = hwy::Span<size_t>;
-
-// Populates KV cache for batches of tokens from one query at a time.
-template <typename T>
-HWY_NOINLINE void Prefill(
-    const QueriesPromptTokens& queries_prompt,
-    const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
-    const size_t query_idx_start, const ModelWeightsPtrs<T>& weights,
-    Activations& activations, const RuntimeConfig& runtime_config,
-    const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) {
-  PROFILER_ZONE("Gen.Prefill");
-  const size_t num_queries = queries_prompt.size();
-  HWY_DASSERT(queries_pos.size() == num_queries);
-  HWY_DASSERT(queries_prefix_end.size() == num_queries);
-  HWY_DASSERT(kv_caches.size() == num_queries);
-
-  // Batches are important for amortizing loading weights over multiple tokens.
-  // This is possible in prefill because we know all tokens beforehand, whereas
-  // decode depends on the previous output token. However, each prefill batch
-  // of a query requires that preceding batches already wrote to the KV cache,
-  // hence we sequentially loop over token batches. We can reduce the number of
-  // iterations by increasing the batch size, but this also increases
-  // arithmetic intensity, and so we are eventually compute-limited. We could
-  // devote some threads to parallelizing over queries, but for simplicity we
-  // assign them all to MatMul.
-  const size_t max_tbatch_size = activations.x.BatchSize();
-
-  // For each query. `qi` is within the batch, not the global query index.
-  for (size_t qi = 0; qi < num_queries; ++qi) {
-    // Single query at a time, so pass slices of the spans because
-    // GemmaAttention will only access the first KV cache and position.
-    QueriesPos single_query_pos(&queries_pos[qi], 1);
-    QueriesPos single_query_prefix_end(&queries_prefix_end[qi], 1);
-    KVCaches single_kv_cache(&kv_caches[qi], 1);
-
-    const size_t prompt_size = queries_prompt[qi].size();
-    // In autoregressive mode, we don't need to prefill the last token, so - 1.
-    size_t prefill_this_query = prompt_size - 1;
-    const size_t prefix_end_this_query = queries_prefix_end[qi];
-    // We can't attend beyond the prompt_size.
-    HWY_ASSERT(prefix_end_this_query <= prompt_size);
-    // Special case: if the prefix includes the last token, we need to prefill
-    // the last token, too. However, we need to rewind this for the generation
-    // of the first token. So we need to keep track of this.
-    // TODO: consider implementing masking instead of this logic?
-    const bool attend_to_last_token =
-        (prefill_this_query < prefix_end_this_query);
-    if (attend_to_last_token) {
-      // The difference can be at most 1.
-      prefill_this_query += 1;
-      HWY_ASSERT(prefill_this_query == prefix_end_this_query);
-    }
-    // In prefix-LM mode, we need to look at all the tokens for the prefix in
-    // one iteration through the layers, so we need a large enough batch size.
-    HWY_ASSERT(prefix_end_this_query == 0 ||
-               max_tbatch_size >= prefill_this_query);
-
-    // For each batch of tokens in the query:
-    for (size_t tbatch_start = 0; tbatch_start < prefill_this_query;
-         tbatch_start += max_tbatch_size) {
-      const size_t tbatch_size =
-          HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start);
-
-      // Fill activations.x (much faster than TransformerLayer).
-      size_t image_token_position = 0;
-      for (size_t ti = 0; ti < tbatch_size; ++ti) {
-        const size_t pos = queries_pos[qi] + ti;
-        const size_t pos_in_prompt = tbatch_start + ti;
-        const int token = queries_prompt[qi][pos_in_prompt];
-        EmbedMMToken(token, ti, pos, pos_in_prompt, weights, activations.x,
-                     runtime_config.image_tokens, image_token_position);
-      }
-
-      // Transformer with one batch of tokens from a single query.
-      for (size_t layer = 0;
-           layer < weights.weights_config.layer_configs.size(); ++layer) {
-        const auto* layer_weights = weights.GetLayer(layer);
-        TransformerLayer(single_query_pos, single_query_prefix_end, tbatch_size,
-                         layer, layer_weights, activations, div_seq_len,
-                         single_kv_cache);
-      }
-
-      // NOTE: we unconditionally call StreamToken, even if EOS.
-      for (size_t ti = 0; ti < tbatch_size; ++ti) {
-        const size_t pos = queries_pos[qi] + ti;
-        const size_t pos_in_prompt = tbatch_start + ti;
-        const int token = queries_prompt[qi][pos_in_prompt];
-        if (pos_in_prompt < prompt_size - 1) {
-          runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f);
-        } else {
-          // The last token will be streamed later and we should only get here
-          // if we need to attend to the last token because it is in the
-          // prefix.
-          HWY_ASSERT(attend_to_last_token);
-        }
-      }
-
-      queries_pos[qi] += tbatch_size;
-    }  // for tbatch_start
-    if (attend_to_last_token) {
-      // We need to rewind the position for the last token that we only
-      // attended to to make sure the prefix LM sees everything.
-      // This means we duplicate work on the last prompt token in
-      // autoregressive decoding. Alternatives: (1) real masking; (2) always
-      // prefill the last token and only generate the next one from the
-      // already prefilled activations.
-      queries_pos[qi] -= 1;
-    }
-  }
-}
-
-// Gets the patches of the image and embeds them with the image embedding
-// kernel. The result is stored in activations.x.
-template <typename T>
-HWY_NOINLINE void EmbedImagePatches(const Image& image,
-                                    const ModelWeightsPtrs<T>& weights,
-                                    Activations& activations) {
-  const size_t model_dim = weights.weights_config.vit_config.model_dim;
-  const size_t patch_width = weights.weights_config.vit_config.patch_width;
-  const size_t seq_len = weights.weights_config.vit_config.seq_len;
-  const size_t patch_size = patch_width * patch_width * 3;
-  HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==
-              patch_size * model_dim);
-  HWY_DASSERT(activations.x.Cols() == model_dim);
-  std::vector<hwy::AlignedFreeUniquePtr<float[]>> image_patches(seq_len);
-  for (size_t i = 0; i < seq_len; ++i) {
-    image_patches[i] = hwy::AllocateAligned<float>(patch_size);
-    image.GetPatch(i, image_patches[i].get());
-  }
-  // img/embedding/kernel has original shape (14, 14, 3, 1152)
-  // H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
-  // image_patches is (256, 14 * 14 * 3)
-  // This could be done as one MatMul like:
-  //   RowVectorBatch<float> image_patches(kSeqLen, kPatchSize);
-  //   [Get patches]
-  //   MatMul(
-  //       MatFromBatch(kVitSeqLen, image_patches),
-  //       MatFromWeights(weights.vit_img_embedding_kernel),
-  //       weights.vit_img_embedding_bias.data_scale1(), *activations.env,
-  //       RowPtrF(activations.x.All(), kVitModelDim));
-  // However, MatMul currently requires that
-  //   A.cols % (2 * hn::Lanes(hn::ScalableTag<float>())) == 0
-  // which is not the case here. We should relax that requirement on MatMul
-  // and then use the above. For now, we rely on MatVecAdd instead.
-  for (size_t i = 0; i < seq_len; ++i) {
-    MatVecAdd(
-        weights.vit_img_embedding_kernel, 0, model_dim, patch_size,
-        image_patches[i].get(), weights.vit_img_embedding_bias.data_scale1(),
-        activations.x.Batch(i), activations.env->parallel.Pools().Pool(0));
-  }
-  // Add position embeddings.
-  AddFrom(weights.vit_img_pos_embedding.data_scale1(), activations.x.All(),
-          seq_len * model_dim);
-}
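The shape comments in EmbedImagePatches imply the following arithmetic for the PaliGemma 224px configuration (patch width 14, RGB), which is where the 256 and 588 in the comments come from:

    #include <cstddef>

    constexpr size_t kImageSize = 224, kPatchWidth = 14, kChannels = 3;
    constexpr size_t kPatchesPerSide = kImageSize / kPatchWidth;          // 16
    constexpr size_t kVitSeqLen = kPatchesPerSide * kPatchesPerSide;      // 256
    constexpr size_t kPatchSize = kPatchWidth * kPatchWidth * kChannels;  // 588
    static_assert(kVitSeqLen == 256 && kPatchSize == 588,
                  "matches the (256, 14 * 14 * 3) layout in the comments");

Each of the 256 patches is a 588-float vector, multiplied by the [1152, 588] embedding kernel via MatVecAdd to produce one ViT token of model_dim 1152.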
-
-// Prefills the image tokens with the ViT encoder.
-template <typename T>
-HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
-                             const RuntimeConfig& runtime_config,
-                             const Image& image, ImageTokens& image_tokens,
-                             Activations& activations) {
-  PROFILER_ZONE("Gen.PrefillVit");
-  const size_t num_tokens = weights.weights_config.vit_config.seq_len;
-  const size_t vit_model_dim = weights.weights_config.vit_config.model_dim;
-  HWY_ASSERT(num_tokens == activations.x.BatchSize());
-  // Embed the image patches.
-  EmbedImagePatches(image, weights, activations);
-  // Go through all layers.
-  for (size_t layer = 0;
-       layer < weights.weights_config.vit_config.layer_configs.size();
-       ++layer) {
-    const auto* layer_weights = weights.GetVitLayer(layer);
-    VitTransformerLayer(num_tokens, layer, layer_weights, activations);
-  }
-  // Final Layernorm.
-  LayerNormBatched(num_tokens, activations.x.All(),
-                   weights.vit_encoder_norm_scale.data_scale1(),
-                   weights.vit_encoder_norm_bias.data_scale1(),
-                   activations.x.All(), vit_model_dim);
-
-  if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
-    activations.x = AvgPool4x4(activations.x);
-
-    // Apply soft embedding norm before input projection.
-    RMSNormInplace(weights.mm_embed_norm.data_scale1(), activations.x.All(),
-                   vit_model_dim);
-  }
-
-  // Apply head embedding into image_tokens of size of the LLM kModelDim.
-  MatMul(ConstMatFromBatch(activations.x.BatchSize(), activations.x),
-         ConstMatFromWeights(weights.vit_img_head_kernel),
-         weights.vit_img_head_bias.data_scale1(), *activations.env,
-         RowPtrFromBatch(image_tokens));
-}
-
-// Generates one token for each query. `queries_token` is the previous token
-// from each query, and `queries_pos` are their position in the sequence.
-template <typename T>
-HWY_NOINLINE void Transformer(
-    const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
-    const QueriesPos& queries_prefix_end, const ModelWeightsPtrs<T>& weights,
-    Activations& activations, const hwy::Divisor& div_seq_len,
-    const KVCaches& kv_caches, const LayersOutputFunc& layers_output,
-    const ActivationsObserverFunc& activations_observer) {
-  const size_t model_dim = weights.weights_config.model_dim;
-  const size_t num_queries = queries_token.size();
-  HWY_DASSERT(queries_pos.size() == num_queries);
-  HWY_DASSERT(queries_prefix_end.size() == num_queries);
-
-  if (layers_output) {
-    for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-      const float token_f = queries_token[query_idx];
-      layers_output(query_idx, queries_pos[query_idx], "tokens", -1, &token_f,
-                    1);
-    }
-  }
-
-  size_t image_token_position = 0;
-  for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-    EmbedMMToken(queries_token[query_idx], query_idx, queries_pos[query_idx],
-                 /*pos_in_prompt=*/0, weights, activations.x,
-                 /*image_tokens=*/nullptr, image_token_position);
-  }
-
-  for (size_t layer = 0; layer < weights.c_layers.size(); ++layer) {
-    const LayerWeightsPtrs<T>* layer_weights = weights.GetLayer(layer);
-    TransformerLayer(queries_pos, queries_prefix_end, /*num_tokens=*/1, layer,
-                     layer_weights, activations, div_seq_len, kv_caches);
-
-    if (activations_observer) {
-      activations_observer(queries_pos, layer, activations);
-    }
-  }
-
-  RMSNormInplaceBatched(num_queries, weights.final_norm_scale.data_scale1(),
-                        activations.x.All(), model_dim);
-
-  if (activations_observer) {
-    activations_observer(queries_pos, -1, activations);
-  }
-  for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-    queries_pos[query_idx] += 1;
-  }
-}
-
-// Placeholder for internal test3, do not remove
-
-// Returns the maximum number of prompt tokens across all queries.
-static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) {
-  size_t max_prompt_size = 0;
-  for (size_t i = 0; i < queries_prompt.size(); ++i) {
-    max_prompt_size = std::max(max_prompt_size, queries_prompt[i].size());
-  }
-  return max_prompt_size;
-}
-
-// Holds "is at end of stream" state for each query.
-class TokenStreamer {
- public:
-  explicit TokenStreamer(const RuntimeConfig& runtime_config,
-                         const ModelConfig& model_config)
-      : runtime_config_(runtime_config), model_config_(model_config) {}
-
-  // Returns whether the query was already at, or has just reached, the end of
-  // the stream: either via token == eos_id, or StreamToken returning false.
-  bool operator()(size_t query_idx, size_t pos, int token, float prob) {
-    if (HWY_UNLIKELY(is_eos_.Get(query_idx))) return true;
-
-    if (!runtime_config_.StreamToken(query_idx, pos, token, prob) ||
-        model_config_.IsEOS(token)) {
-      is_eos_.Set(query_idx);
-      return true;
-    }
-
-    return false;
-  }
-
- private:
-  const RuntimeConfig& runtime_config_;
-  const ModelConfig& model_config_;
-  hwy::BitSet4096<> is_eos_;
-};
-
-HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
-  // If user provided a sample_func, use it.
-  if (runtime_config.sample_func) return runtime_config.sample_func;
-
-  // Fast path for top-1 with no accept_token.
-  if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
-    return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
-      PROFILER_ZONE("Gen.Sample Top1");
-      return Top1OfSoftmax(logits, vocab_size);
-    };
-  }
-
-  // General case: Softmax with top-k sampling.
-  return [&runtime_config](float* logits,
-                           size_t vocab_size) HWY_ATTR -> TokenAndProb {
-    PROFILER_ZONE("Gen.Sample general");
-    return FusedSoftmaxAndSampleTopK(
-        logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
-        runtime_config.temperature, runtime_config.accept_token);
-  };
-}
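ChooseSampleFunc's general case delegates to FusedSoftmaxAndSampleTopK. An unfused, illustrative equivalent (no accept_token filter; assumes k <= vocab_size and temperature > 0), not the library's implementation:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <numeric>
    #include <random>
    #include <vector>

    inline int SampleTopK(const float* logits, size_t vocab_size, size_t k,
                          float temperature, std::mt19937& gen) {
      // Indices of the k largest logits, best first.
      std::vector<size_t> idx(vocab_size);
      std::iota(idx.begin(), idx.end(), 0);
      std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                        [&](size_t a, size_t b) { return logits[a] > logits[b]; });
      // Temperature-scaled softmax over the k best; subtract max for stability.
      std::vector<float> probs(k);
      const float max_logit = logits[idx[0]];
      for (size_t i = 0; i < k; ++i) {
        probs[i] = std::exp((logits[idx[i]] - max_logit) / temperature);
      }
      // discrete_distribution normalizes the weights internally.
      std::discrete_distribution<size_t> dist(probs.begin(), probs.end());
      return static_cast<int>(idx[dist(gen)]);
    }

The top-1 fast path above is the k == 1 special case: argmax of the logits, with the softmax evaluated only to report the chosen token's probability.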
-
-// Generates one continuation for each query in `queries_prompt`, which is one
-// qbatch whose size is at most the `batch_size` passed to
-// `activations.Allocate`.
-//
-// `queries_pos` stores the KV cache position for each query. In the first turn
-// of a chat, pos = 0; we increment each query's position after each token.
-//
-// `query_idx_start` is the query_idx of the first query in the batch, so that
-// `StreamFunc` gets the global query index, not relative to the batch.
-//
-// `kv_caches` is for the batch, size must match `queries_prompt`.
-template <typename T>
-void GenerateT(const ModelWeightsStorage& model, Activations& activations,
-               const RuntimeConfig& runtime_config,
-               const QueriesPromptTokens& queries_prompt,
-               const QueriesPos& queries_pos_in,
-               const QueriesPos& queries_prefix_end,
-               const size_t query_idx_start, const KVCaches& kv_caches,
-               TimingInfo& timing_info) {
-  // Griffin assumes that the recurrent block cache is zero-initialized.
-  for (size_t i = 0; i < kv_caches.size(); ++i) {
-    if (queries_pos_in[i] == 0) {
-      kv_caches[i].ZeroGriffinCache();  // No-op for non-Griffin models.
-    }
-  }
-
-  // Copy so we can increment without requiring users to pass in a mutable span.
-  std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
-                                       queries_pos_in.cend());
-  const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
-                                              queries_pos_copy.size());
-
-  // Sanity check: prompts should not be empty, nor start with EOS.
-  for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
-    const PromptTokens& prompt = queries_prompt[query_idx];
-    HWY_ASSERT(prompt.size() != 0 && prompt[0] != runtime_config.eos_id);
-  }
-
-  const size_t num_queries = queries_prompt.size();
-  HWY_ASSERT(num_queries <= 4096);  // TokenStreamer uses BitSet4096.
-  HWY_ASSERT(num_queries <= activations.x.BatchSize());
-  HWY_ASSERT(queries_pos_in.size() == num_queries);
-  HWY_ASSERT(kv_caches.size() == num_queries);
-  const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
-  const ModelWeightsPtrs<T>& weights = *model.GetWeightsOfType<T>();
-  size_t max_prompt_size = MaxQueryLength(queries_prompt);
-  size_t max_generated_tokens = runtime_config.max_generated_tokens;
-  RangeChecks(weights.weights_config, max_generated_tokens, max_prompt_size);
-  const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
-
-  // Prefill stops before each query's prompt_size - 1 because the last prompt
-  // token is the first input token for generation.
-  timing_info.prefill_start = hwy::platform::Now();
-  // If tbatch is larger than the qbatch we already have in `activations`, then
-  // allocate prefill_activations, otherwise reuse.
-  const bool use_prefill_activations =
-      runtime_config.prefill_tbatch_size > activations.x.BatchSize();
-  Activations prefill_activations(weights.weights_config);
-  if (use_prefill_activations) {
-    prefill_activations.Allocate(runtime_config.prefill_tbatch_size,
-                                 activations.env);
-  }
-  Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end,
-          query_idx_start, weights,
-          use_prefill_activations ? prefill_activations : activations,
-          runtime_config, div_seq_len, kv_caches);
-  // Compute the number of tokens that were prefilled and notify timing_info.
-  size_t prefilled_tokens = 0;
-  for (size_t qi = 0; qi < num_queries; ++qi) {
-    prefilled_tokens += queries_prompt[qi].size() - 1;
-  }
-  timing_info.NotifyPrefill(prefilled_tokens);
-  // queries_pos are incremented by Prefill.
-
-  // Storage for the last generated token from each query, passed to the next
-  // Transformer() call.
-  std::vector<int> gen_tokens(num_queries);
-
-  // Stream the last prompt token from each query and fill gen_tokens.
-  TokenStreamer token_streamer(runtime_config, model.Config());
-  for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-    size_t last_token_pos_in_prompt =
-        queries_mutable_pos[query_idx] - queries_pos_in[query_idx];
-    gen_tokens[query_idx] = queries_prompt[query_idx][last_token_pos_in_prompt];
-    (void)token_streamer(query_idx_start + query_idx,
-                         queries_mutable_pos[query_idx], gen_tokens[query_idx],
-                         0.0f);
-  }
-
-  {
-    const size_t vocab_size = model.Config().vocab_size;
-    timing_info.generate_start = hwy::platform::Now();
-    for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
-      bool all_queries_eos = DecodeStepT(
-          weights, runtime_config, queries_prompt, query_idx_start, kv_caches,
-          queries_prefix_end, div_seq_len, vocab_size, sample_token,
-          activations, token_streamer, gen_tokens,
-          timing_info, queries_mutable_pos);
-      if (all_queries_eos) break;
-    }  // foreach token to generate
-    timing_info.NotifyGenerateDone();
-  }
-}
-
-template <typename T>
-void GenerateSingleT(const ModelWeightsStorage& model,
-                     const RuntimeConfig& runtime_config,
-                     const PromptTokens& prompt, size_t pos, size_t prefix_end,
-                     KVCache& kv_cache, MatMulEnv* env,
-                     TimingInfo& timing_info) {
-  constexpr size_t kNumQueries = 1;
-  const size_t qbatch_start = 0;
-
-  // TODO: move into Gemma?
-  Activations activations(model.Config());
-  activations.Allocate(kNumQueries, env);
-
-  const QueriesPromptTokens queries_prompt(&prompt, kNumQueries);
-  QueriesPos queries_pos(&pos, kNumQueries);
-  const QueriesPos queries_prefix_end(&prefix_end, kNumQueries);
-  const KVCaches kv_caches{&kv_cache, kNumQueries};
-
-  GenerateT<T>(model, activations, runtime_config, queries_prompt, queries_pos,
-               queries_prefix_end, qbatch_start, kv_caches, timing_info);
-}
-
-template <typename T>
-void GenerateBatchT(const ModelWeightsStorage& model,
-                    const RuntimeConfig& runtime_config,
-                    const QueriesPromptTokens& queries_prompt,
-                    const QueriesPos& queries_pos,
-                    const QueriesPos& queries_prefix_end,
-                    const KVCaches& kv_caches, MatMulEnv* env,
-                    TimingInfo& timing_info) {
-  const size_t num_queries = queries_prompt.size();
-  HWY_ASSERT(queries_pos.size() == num_queries);
-  HWY_ASSERT(kv_caches.size() == num_queries);
-  // Griffin does not support query batching.
-  size_t max_qbatch_size = runtime_config.decode_qbatch_size;
-  for (const auto& layer_config : model.Config().layer_configs) {
-    if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
-      max_qbatch_size = 1;
-      break;
-    }
-  }
-
-  Activations activations(model.Config());
-  activations.Allocate(max_qbatch_size, env);
-
-  for (size_t qbatch_start = 0; qbatch_start < num_queries;
-       qbatch_start += max_qbatch_size) {
-    // Generate one batch of tokens from `qbatch_size` queries.
-    const size_t qbatch_size =
-        HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
-    const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start],
-                                             qbatch_size);
-    QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
-    const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
-                                       qbatch_size);
-    const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
-    GenerateT<T>(model, activations, runtime_config, qbatch_prompts, qbatch_pos,
-                 qbatch_prefix_end, qbatch_start, qbatch_kv, timing_info);
-  }
-}
-
-template <typename T>
-void GenerateImageTokensT(const ModelWeightsStorage& model,
-                          const RuntimeConfig& runtime_config,
-                          const Image& image, ImageTokens& image_tokens,
-                          MatMulEnv* env) {
-  if (model.Config().vit_config.layer_configs.empty()) {
-    HWY_ABORT("Model does not support generating image tokens.");
-  }
-  RuntimeConfig prefill_runtime_config = runtime_config;
-  ModelConfig vit_config = GetVitConfig(model.Config());
-  prefill_runtime_config.prefill_tbatch_size =
-      vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
-  Activations prefill_activations(vit_config);
-  prefill_activations.Allocate(vit_config.seq_len, env);
-  // Weights are for the full PaliGemma model, not just the ViT part.
-  PrefillVit(*model.GetWeightsOfType<T>(), prefill_runtime_config, image,
-             image_tokens, prefill_activations);
+  CallMatMul(activations.C1, layer.linear_w, output_bias, env,
+             activations.ffw_out);
 }
 
+// NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
-
-#if HWY_ONCE
-
-// These are extern functions defined by instantiations/*.cc, which include this
-// 'header' after defining GEMMA_CONFIG, which is for function overloading.
-void GenerateSingle(  // NOLINT(misc-definitions-in-headers)
-    GEMMA_TYPE, const ModelWeightsStorage& model,
-    const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
-    size_t prefix_end, KVCache& kv_cache, MatMulEnv* env,
-    TimingInfo& timing_info) {
-  HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT)
-  (model, runtime_config, prompt, pos, prefix_end, kv_cache, env, timing_info);
-}
-
-void GenerateBatch(  // NOLINT(misc-definitions-in-headers)
-    GEMMA_TYPE, const ModelWeightsStorage& model,
-    const RuntimeConfig& runtime_config,
-    const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos,
-    const QueriesPos& queries_prefix_end, const KVCaches& kv_caches,
-    MatMulEnv* env, TimingInfo& timing_info) {
-  HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT)
-  (model, runtime_config, queries_prompt, queries_pos, queries_prefix_end,
-   kv_caches, env, timing_info);
-}
-
-void GenerateImageTokens(  // NOLINT(misc-definitions-in-headers)
-    GEMMA_TYPE, const ModelWeightsStorage& model,
-    const RuntimeConfig& runtime_config, const Image& image,
-    ImageTokens& image_tokens, MatMulEnv* env) {
-  HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImageTokensT)
-  (model, runtime_config, image, image_tokens, env);
-}
-
-#endif  // HWY_ONCE
-
 }  // namespace gcpp
 HWY_AFTER_NAMESPACE();
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index bfc6534..496c21d 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -13,173 +13,654 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// Defines Gemma member functions; the actual implementations are in
-// gemma-inl.h, included from instantiations/*.cc.
+// Defines Gemma member functions which dynamic-dispatch into the SIMD
+// implementations in gemma-inl.h.
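For readers new to Highway's multi-target build: the pattern this file now follows (compile the translation unit once per SIMD target via foreach_target.h, then export a single dispatch table) looks roughly like the sketch below. The kernel name and body are hypothetical; the macros are Highway's real API, used exactly as in this diff.

// demo_mul.cc -- hypothetical kernel; macro usage mirrors gemma.cc above.
#include <stddef.h>

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "demo_mul.cc"
#include "hwy/foreach_target.h"  // re-includes this file once per target
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace demo {
namespace HWY_NAMESPACE {  // unique per target, e.g. N_AVX2
void MulBy2T(float* HWY_RESTRICT data, size_t n) {
  for (size_t i = 0; i < n; ++i) data[i] *= 2.0f;  // real code would use SIMD
}
}  // namespace HWY_NAMESPACE
}  // namespace demo
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace demo {
HWY_EXPORT(MulBy2T);  // one table entry per compiled target
void MulBy2(float* data, size_t n) {
  // Picks the best instantiation for the current CPU, then calls it.
  HWY_DYNAMIC_DISPATCH(MulBy2T)(data, n);
}
}  // namespace demo
#endif  // HWY_ONCE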
#include "gemma/gemma.h" +#include "compression/types.h" // GEMMA_DISABLED_TARGETS +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + +// Compiles this file for multiple architectures via "foreach_target.h", to +// which we pass the filename via macro 'argument'. +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma/gemma.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "gemma/attention.h" // includes highway.h +#include "gemma/gemma-inl.h" +#include "gemma/griffin.h" // includes highway.h +#include "gemma/vit.h" // includes highway.h + +#ifndef GEMMA_CC_ONCE +#define GEMMA_CC_ONCE + +#include // sqrtf #include #include #include #include -#include -#include // std::move #include -#include "compression/io.h" // Path -#include "compression/shared.h" -#include "gemma/common.h" +#include "gemma/configs.h" +#include "gemma/model_store.h" #include "gemma/weights.h" -#include "ops/ops-inl.h" +#include "io/blob_store.h" +#include "io/io.h" // Path +#include "ops/matmul.h" #include "paligemma/image.h" -#include "util/threading.h" -#include "hwy/highway.h" +#include "util/threading_context.h" +#include "hwy/aligned_allocator.h" // Span +#include "hwy/base.h" +#include "hwy/timer.h" +#endif // GEMMA_CC_ONCE + +HWY_BEFORE_NAMESPACE(); namespace gcpp { +namespace HWY_NAMESPACE { -Gemma::Gemma(const Path& tokenizer_path, const Path& weights, - const ModelInfo& info, MatMulEnv& env) - : env_(env), tokenizer_(tokenizer_path) { - model_.Load(weights, info.model, info.weight, info.wrapping, - env_.parallel.Pools().Pool(0), - /*tokenizer_proto=*/nullptr); -} - -Gemma::Gemma(const Path& weights, MatMulEnv& env) : env_(env) { - std::string tokenizer_proto; - model_.Load(weights, Model::UNKNOWN, Type::kUnknown, PromptWrapping::GEMMA_IT, - env_.parallel.Pools().Pool(0), &tokenizer_proto); - tokenizer_.Deserialize(tokenizer_proto); -} - -Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env) - : env_(env), tokenizer_(std::move(tokenizer)) { - HWY_ASSERT(info.weight == Type::kF32); - model_.Allocate(info.model, info.weight, env_.parallel.Pools().Pool(0)); -} - -Gemma::~Gemma() { -} - -// There are >=3 types of the inference code. To reduce compile time, -// we shard them across multiple translation units in instantiations/*.cc. -// This declares the functions defined there. We use overloading because -// explicit instantiations are still too slow to compile. -#define GEMMA_DECLARE(TWEIGHT) \ - extern void GenerateSingle(TWEIGHT, const ModelWeightsStorage& model, \ - const RuntimeConfig& runtime_config, \ - const PromptTokens& prompt, size_t pos, \ - size_t prefix_end, KVCache& kv_cache, \ - MatMulEnv* env, TimingInfo& timing_info); \ - extern void GenerateBatch( \ - TWEIGHT, const ModelWeightsStorage& model, \ - const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \ - const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, \ - const KVCaches& kv_caches, MatMulEnv* env, TimingInfo& timing_info); \ - extern void GenerateImageTokens(TWEIGHT, const ModelWeightsStorage& model, \ - const RuntimeConfig& runtime_config, \ - const Image& image, \ - ImageTokens& image_tokens, MatMulEnv* env); -GEMMA_DECLARE(float) -GEMMA_DECLARE(BF16) -GEMMA_DECLARE(NuqStream) -GEMMA_DECLARE(SfpStream) - -// Adapters to select from the above overloads via CallForModelWeight. 
-template <typename TConfig>
-struct GenerateSingleT {
-  void operator()(const ModelWeightsStorage& model,
-                  const RuntimeConfig& runtime_config,
-                  const PromptTokens& prompt, size_t pos, size_t prefix_end,
-                  KVCache& kv_cache, MatMulEnv* env,
-                  TimingInfo& timing_info) const {
-    GenerateSingle(TConfig(), model, runtime_config, prompt, pos, prefix_end,
-                   kv_cache, env, timing_info);
+void Attention(LayerAttentionType type, const size_t num_tokens,
+               const size_t layer_idx, const LayerWeightsPtrs& layer,
+               Activations& activations, QBatch& qbatch, MatMulEnv& env) {
+  if (type == LayerAttentionType::kGemma) {
+    GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
+                   env,
+                   /*flags=*/0);
+  } else {
+    HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
+    // KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer,
+    // so map `layer` to the Griffin layer index.
+    const size_t griffin_layer =
+        activations.attention.config.NumLayersOfTypeBefore(type, layer_idx);
+    GriffinRecurrent(num_tokens, griffin_layer, &layer, activations, qbatch,
+                     env);
   }
-};
+}
 
-template <typename TConfig>
-struct GenerateBatchT {
-  void operator()(const ModelWeightsStorage& model,
-                  const RuntimeConfig& runtime_config,
-                  const QueriesPromptTokens& queries_prompt,
-                  const QueriesPos& queries_pos,
-                  const QueriesPos& queries_prefix_end,
-                  const KVCaches& kv_caches, MatMulEnv* env,
-                  TimingInfo& timing_info) const {
-    GenerateBatch(TConfig(), model, runtime_config, queries_prompt, queries_pos,
-                  queries_prefix_end, kv_caches, env, timing_info);
-  }
-};
+static HWY_NOINLINE void TransformerLayer(const size_t num_tokens,
+                                          const size_t layer_idx,
+                                          const LayerWeightsPtrs& layer,
+                                          Activations& activations,
+                                          QBatch& qbatch, MatMulEnv& env) {
+  const LayerConfig& layer_config = layer.layer_config;
 
-template <typename TConfig>
-struct GenerateImageTokensT {
-  void operator()(const ModelWeightsStorage& model,
-                  const RuntimeConfig& runtime_config, const Image& image,
-                  ImageTokens& image_tokens, MatMulEnv* env) const {
-    GenerateImageTokens(TConfig(), model, runtime_config, image, image_tokens,
-                        env);
+  RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
+                 activations.attention.pre_att_rms_out, env.ctx);
+
+  Attention(layer_config.type, num_tokens, layer_idx, layer, activations,
+            qbatch, env);
+
+  PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
+           activations.attention.att_sums, env.ctx);
+
+  ResidualConnection(activations.attention.att_sums, activations.x, layer,
+                     /*is_attention=*/true, env.ctx);
+
+  RMSNormBatched(activations.x, layer.pre_ffw_norm_scale,
+                 activations.pre_ffw_rms_out, env.ctx);
+
+  if (layer_config.type == LayerAttentionType::kVit) {
+    FFWVit(layer, activations, env);
+  } else {
+    FFWNoVit(layer, activations, env);
   }
-};
+
+  PostNorm(layer_config.post_norm, layer.post_ffw_norm_scale,
+           activations.ffw_out, env.ctx);
+
+  ResidualConnection(activations.ffw_out, activations.x, layer,
+                     /*is_attention=*/false, env.ctx);
+}
+
+// Returns the scale value to use for the embedding (basically sqrt model_dim).
+static float EmbeddingScaling(size_t model_dim) {
+  // Round to bf16 to match Gemma's Embedder, which casts before mul.
+  return hwy::ConvertScalarTo<float>(
+      hwy::ConvertScalarTo<BF16>(sqrtf(static_cast<float>(model_dim))));
+}
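As a quick sanity check of the rounding above: a standalone sketch that mimics the float -> bf16 -> float round-trip by truncating the mantissa. Highway's ConvertScalarTo rounds to nearest-even, so the last bf16 bit may differ from this stand-in; the model_dim values are illustrative.

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <math.h>

// Truncating stand-in for the ConvertScalarTo pair above: bf16 keeps the
// top 16 bits of an IEEE float.
static float RoundToBF16(float f) {
  uint32_t bits;
  memcpy(&bits, &f, sizeof(bits));
  bits &= 0xFFFF0000u;
  memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  const size_t dims[] = {2048, 3072};  // illustrative model_dim values
  for (size_t i = 0; i < 2; ++i) {
    const float exact = sqrtf((float)dims[i]);
    printf("model_dim=%zu  sqrt=%f  bf16-rounded=%f\n", dims[i], exact,
           RoundToBF16(exact));
  }
  return 0;
}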
+// `batch_idx` indicates which row of `x` to write to.
+// `pos` is the *token*'s position, not the start of the batch, because this is
+// called for batches of tokens in prefill, but batches of queries in decode.
+//
+// For GEMMA_VLM, image tokens are copied into the slots of `-2` placeholder
+// tokens (per the Gemma 3 spec) until we run out of image tokens. This allows
+// for a multi-image prompt if the calling application creates `-2` slots with
+// appropriate begin/end image tokens.
+// Returns the new image_token_position.
+static HWY_NOINLINE size_t
+EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
+             const ModelConfig& model_config, const WeightsPtrs& weights,
+             MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr,
+             size_t image_token_position = 0) {
+  // Image tokens just need to be copied.
+  if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
+      image_tokens != nullptr && token == -2 &&
+      image_token_position < image_tokens->Rows()) {
+    hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(qi),
+                   x.Cols() * x.ElementBytes());
+    return image_token_position + 1;
+  }
+
+  if (model_config.wrapping == PromptWrapping::PALIGEMMA &&
+      image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
+    hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(qi),
+                   x.Cols() * x.ElementBytes());
+    return image_token_position;
+  }
+
+  const size_t model_dim = model_config.model_dim;
+  const float emb_scaling = EmbeddingScaling(model_dim);
+  const size_t worker = 0;  // Not yet parallelized.
+
+  HWY_DASSERT(token >= 0);
+  HWY_DASSERT(token < static_cast<int>(model_config.vocab_size));
+
+  CallUpcasted(&weights.embedder_input_embedding, [&](const auto* weights_t) {
+    // Using `Stride` to compute the offset works for both NUQ (because we use
+    // an offset and NUQ is never padded) and padded, because non-NUQ types are
+    // seekable, hence the offset can also skip any padding.
+    const size_t embedding_ofs = token * weights_t->Stride();
+    HWY_ASSERT(weights_t->Cols() == model_dim);
+    const auto embedding_span =
+        MakeSpan(weights_t->Row(0), embedding_ofs + model_dim);
+    const hn::ScalableTag<float> df;
+    DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
+                         model_dim);
+    MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim, worker);
+  });
+
+  if (model_config.absolute_pe) {
+    AddAbsolutePositionalEmbeddings(x.Row(qi), model_dim, pos);
+  }
+  return image_token_position;
+}
+
+// Populates the KV cache for batches of tokens from one query at a time. This
+// is called if prompts are longer than the query batch size, and also in
+// prefix-LM mode (prefix_end > 0), which must see all prefix tokens in one
+// batch.
+static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
+                                       const RuntimeConfig& runtime_config,
+                                       const WeightsPtrs& weights,
+                                       Activations& activations, QBatch& qbatch,
+                                       MatMulEnv& env,
+                                       hwy::BitSet4096<>& non_eos) {
+  PROFILER_ZONE("Gen.PrefillT");
+
+  // Batches are important for amortizing the cost of loading weights over
+  // multiple tokens. This is possible in prefill because we know all tokens
+  // beforehand, whereas decode depends on the previous output token. However,
+  // each prefill batch of a query requires that preceding batches have already
+  // written to the KV cache, hence we sequentially loop over token batches. We
+  // can reduce the number of iterations by increasing the batch size, but this
+  // also increases arithmetic intensity, and so we are eventually
+  // compute-limited. TransformerLayer uses all available threads, so we do not
+  // also parallelize over queries, but note that PrefillQBatch uses queries as
+  // the batch dimension.
+  const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
+
+  // For each query.
`qi` is within the batch, not the global query index. + for (size_t qi = 0; qi < qbatch.Size(); ++qi) { + non_eos.Set(qi); + + // One query at a time, batching will be the query's prompt tokens. + QBatch qbatch_1 = qbatch.Single(qi); + + const size_t prompt_size = qbatch_1.Prompt(0).size(); + // In autoregressive mode, we don't need to prefill the last token, so - 1. + size_t prefill_this_query = prompt_size - 1; + const size_t prefix_end_this_query = qbatch_1.PrefixEnd(0); + // We can't attend beyond the prompt_size. + HWY_ASSERT(prefix_end_this_query <= prompt_size); + // Special case: if the prefix includes the last token, we need to prefill + // the last token, too. However, we need to rewind this for the generation + // of the first token. So we need to keep track of this. + // TODO: consider implementing masking instead of this logic? + const bool attend_to_last_token = + (prefill_this_query < prefix_end_this_query); + if (attend_to_last_token) { + // The difference can be at most 1. + prefill_this_query += 1; + HWY_ASSERT(prefill_this_query == prefix_end_this_query); + } + // In prefix-LM mode, we need to look at all the tokens for the prefix in + // one iteration through the layers, so we need a large enough batch size. + HWY_ASSERT(prefix_end_this_query == 0 || + max_tbatch_size >= prefill_this_query); + + // For each batch of tokens in the query: + for (size_t tbatch_start = 0; tbatch_start < prefill_this_query; + tbatch_start += max_tbatch_size) { + const size_t tbatch_size = + HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start); + activations.SetBatchSize(tbatch_size); + + // Fill activations.x (much faster than TransformerLayer). + size_t image_token_position = 0; + for (size_t ti = 0; ti < tbatch_size; ++ti) { + const size_t pos = qbatch_1.Pos(0) + ti; + const size_t pos_in_prompt = tbatch_start + ti; + const int token = qbatch_1.Prompt(0)[pos_in_prompt]; + image_token_position = EmbedMMToken( + token, ti, pos, pos_in_prompt, config, weights, activations.x, + runtime_config.image_tokens, image_token_position); + } + + // Transformer with one batch of tokens from a single query. + for (size_t layer_idx = 0; layer_idx < config.layer_configs.size(); + ++layer_idx) { + TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx), + activations, qbatch_1, env); + } + + // NOTE: we unconditionally call StreamToken, even if EOS. + for (size_t ti = 0; ti < tbatch_size; ++ti) { + const size_t pos = qbatch_1.Pos(0) + ti; + const size_t pos_in_prompt = tbatch_start + ti; + const int token = qbatch_1.Prompt(0)[pos_in_prompt]; + if (pos_in_prompt < prompt_size - 1) { + runtime_config.StreamToken(qbatch_1.QueryIdx(0), pos, token, 0.0f); + } else { + // The last token will be streamed later and we should only get here + // if we need to attend to the last token because it is in the prefix. + HWY_ASSERT(attend_to_last_token); + } + } + + qbatch_1.MutablePos(0) += tbatch_size; + } // for tbatch_start + if (attend_to_last_token) { + // We need to rewind the position for the last token that we only + // attended to to make sure the prefix LM sees everything. + // This means we duplicate work on the last prompt token in autoregressive + // decoding. Alternatives: (1) real masking; (2) always prefill the last + // token and only generate the next one from the already prefilled + // activations. + qbatch_1.MutablePos(0) -= 1; + } + } +} + +// Embeds PrevToken (one from each query) and calls each TransformerLayer. 
+// Called by query-batched `PrefillQBatch` and `DecodeStepT`, but not the +// token-batched `PrefillTBatch`. +static HWY_NOINLINE void Transformer(const ModelConfig& config, + const RuntimeConfig& runtime_config, + const WeightsPtrs& weights, + Activations& activations, QBatch& qbatch, + MatMulEnv& env) { + if (HWY_UNLIKELY(runtime_config.layers_output)) { + for (size_t qi = 0; qi < qbatch.Size(); ++qi) { + const float token_f = qbatch.PrevToken(qi); + runtime_config.layers_output(qbatch.QueryIdx(qi), qbatch.Pos(qi), + "tokens", -1, &token_f, 1); + } + } + + // TODO: parallelize? + for (size_t qi = 0; qi < qbatch.Size(); ++qi) { + EmbedMMToken(qbatch.PrevToken(qi), qi, qbatch.Pos(qi), + /*pos_in_prompt=*/0, config, weights, activations.x); + } + + for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) { + TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx), + activations, qbatch, env); + + if (HWY_UNLIKELY(runtime_config.activations_observer)) { + runtime_config.activations_observer( + QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx, + activations); + } + } +} + +// Populates KV cache for the batch queries, one token at a time. Only called +// for autoregressive (non-prefix-LM) prefill, so `queries_prefix_end` == 0. +static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size, + const ModelConfig& config, + const RuntimeConfig& runtime_config, + const WeightsPtrs& weights, + Activations& activations, QBatch& qbatch, + MatMulEnv& env, + hwy::BitSet4096<>& non_eos) { + PROFILER_ZONE("Gen.PrefillQ"); + + for (size_t qi = 0; qi < qbatch.Size(); ++qi) { + non_eos.Set(qi); + HWY_DASSERT(qbatch.PrefixEnd(qi) == 0); + } + + // In autoregressive mode, we don't prefill the last token, hence - 1. + for (size_t pos_in_prompt = 0; pos_in_prompt < max_prompt_size - 1; + ++pos_in_prompt) { + for (size_t qi = 0; qi < qbatch.Size(); ++qi) { + int token = config.eos_id; + if (pos_in_prompt < qbatch.Prompt(qi).size() - 1) { + token = qbatch.Prompt(qi)[pos_in_prompt]; + // Ignore StreamToken return value because requesting to stop does not + // make sense during prefill. + (void)runtime_config.StreamToken(qbatch.QueryIdx(qi), pos_in_prompt, + token, 0.0f); + qbatch.MutablePos(qi) = pos_in_prompt; + } + + qbatch.PrevToken(qi) = token; + } + + // The input (PrevToken) is one token from each query in the batch. + // Do not call DecodeStepT because it computes logits for token + // probabilities, which are not required for the prompt tokens. + Transformer(config, runtime_config, weights, activations, qbatch, env); + } + + for (size_t qi = 0; qi < qbatch.Size(); ++qi) { + qbatch.MutablePos(qi) = qbatch.Prompt(qi).size() - 1; + } +} + +// Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent +// `DecodeStepT`, and increments `MutablePos`. Also updates `non_eos` if the +// query is at the end of its sequence. +static void StreamAndUpdateEOS(const size_t qi, int token, const float prob, + const ModelConfig& config, + const RuntimeConfig& runtime_config, + QBatch& qbatch, hwy::BitSet4096<>& non_eos) { + HWY_DASSERT(non_eos.Get(qi)); // otherwise, should not be called. + + if (HWY_UNLIKELY(!runtime_config.StreamToken(qbatch.QueryIdx(qi), + qbatch.Pos(qi), token, prob))) { + // User decided to stop: set token to primary EOS to trigger IsEOS below. 
+ token = config.eos_id; + HWY_DASSERT(config.IsEOS(token)); + } + + qbatch.PrevToken(qi) = token; + qbatch.MutablePos(qi) += 1; + + // Primary or secondary EOS: mark query as EOS, but still increment (for + // multi-turn, we should still keep the prior EOS). + if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi); +} + +// For a batch of queries, runs Transformer, computes logits, samples and +// streams the token. +static void DecodeStepT(const ModelConfig& config, + const RuntimeConfig& runtime_config, + const WeightsPtrs& weights, + const SampleFunc& sample_token, + Activations& activations, QBatch& qbatch, + MatMulEnv& env, hwy::BitSet4096<>& non_eos, + TimingInfo& timing_info) { + HWY_DASSERT(qbatch.Size() == activations.x.Rows()); + + Transformer(config, runtime_config, weights, activations, qbatch, env); + + RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx); + + if (HWY_UNLIKELY(runtime_config.activations_observer)) { + runtime_config.activations_observer( + QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations); + } + + { + PROFILER_ZONE("Gen.EmbeddingMatmul"); + // Compute logits from last layer activations. + CallMatMul(activations.x, weights.embedder_input_embedding, + /*add=*/nullptr, env, activations.logits); + } + PROFILER_ZONE("Gen.Softcap+Sample+Stream"); + const size_t worker = 0; // TODO: parallelize + non_eos.Foreach([&](size_t qi) { + float* HWY_RESTRICT logits = activations.logits.Row(qi); + MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size, worker); + const TokenAndProb tp = sample_token(logits, config.vocab_size); + timing_info.NotifyGenerated(); + + StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch, + non_eos); + }); +} + +static HWY_INLINE SampleFunc +ChooseSampleFunc(const RuntimeConfig& runtime_config) { + // If user provided a sample_func, use it. + if (runtime_config.sample_func) return runtime_config.sample_func; + + const size_t worker = 0; // TODO: parallelize + + // Fast path for top-1 with no accept_token. + if (runtime_config.top_k == 1 && !runtime_config.accept_token) { + return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb { + PROFILER_ZONE2(worker, "Gen.Sample Top1"); + return Top1OfSoftmax(logits, vocab_size); + }; + } + + // General case: Softmax with top-k sampling. + return [&runtime_config](float* logits, + size_t vocab_size) HWY_ATTR -> TokenAndProb { + PROFILER_ZONE("Gen.Sample general"); + return FusedSoftmaxAndSampleTopK( + logits, runtime_config.top_k, vocab_size, *runtime_config.gen, + runtime_config.temperature, runtime_config.accept_token, worker); + }; +} + +// Decode: generates one continuation token for each query in `qbatch`. +static void GenerateT(const ModelConfig& config, + const RuntimeConfig& runtime_config, + const WeightsPtrs& weights, Activations& activations, + QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) { + // Griffin assumes that the recurrent block cache is zero-initialized. + for (size_t qi = 0; qi < qbatch.Size(); ++qi) { + if (qbatch.MutablePos(qi) == 0) { + qbatch.KV(qi).ZeroGriffinCache(); // No-op for non-Griffin models. + } + } + + size_t max_prompt_size = 0; + bool all_prefix_end_are_zero = true; + size_t total_prefill_tokens = 0; // only for throughput stats. 
+ const size_t seq_len = qbatch.KV(0).SeqLen(); + for (size_t qi = 0; qi < qbatch.Size(); ++qi) { + const PromptTokens& prompt = qbatch.Prompt(qi); + max_prompt_size = HWY_MAX(max_prompt_size, prompt.size()); + + // Prefill stops before size - 1 because the last prompt token is the + // first input token for generation. + total_prefill_tokens += prompt.size() - 1; + + // Sanity check: prompts should not be empty, nor start with EOS. + HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id); + + all_prefix_end_are_zero &= qbatch.PrefixEnd(qi) == 0; + + // We use a single divisor, so all sequence lengths must be the same. + HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len); + } + if (max_prompt_size >= seq_len) { + HWY_ABORT("max_prompt_size = %zu, increase --seq_len to at least that.", + max_prompt_size); + } + HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len); + + // Lacks a constructor to bulk-set, hence initialized by Prefill* which have + // qi loops anyway. + hwy::BitSet4096<> non_eos; // indexed by qi + + timing_info.prefill_start = hwy::platform::Now(); + // Batch over the larger of prompt length, or queries. + if ((qbatch.Size() > max_prompt_size) && all_prefix_end_are_zero) { + activations.SetBatchSize(qbatch.Size()); // required before PrefillQBatch + PrefillQBatch(max_prompt_size, config, runtime_config, weights, activations, + qbatch, env, non_eos); + } else { + PrefillTBatch(config, runtime_config, weights, activations, qbatch, env, + non_eos); + activations.SetBatchSize(qbatch.Size()); // Restore after PrefillTBatch. + } + HWY_DASSERT(non_eos.Count() == qbatch.Size()); + timing_info.NotifyPrefill(total_prefill_tokens); + // queries_pos have been incremented by Prefill. + + // Stream the last prompt token from each query, fill activations.gen_tokens. + for (size_t qi = 0; qi < qbatch.Size(); ++qi) { + const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi); + StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config, + runtime_config, qbatch, non_eos); + } + + size_t max_gen_steps = runtime_config.max_generated_tokens; + if (max_prompt_size + max_gen_steps > seq_len) { + HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.", + max_prompt_size, max_gen_steps, seq_len); + max_gen_steps = seq_len - max_prompt_size; + } + + const SampleFunc sample_token = ChooseSampleFunc(runtime_config); + + { + timing_info.generate_start = hwy::platform::Now(); + for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) { + DecodeStepT(config, runtime_config, weights, sample_token, activations, + qbatch, env, non_eos, timing_info); + } + timing_info.NotifyGenerateDone(); + } +} + +void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end, + const ModelConfig& config, + const RuntimeConfig& runtime_config, + const WeightsPtrs& weights, KVCache& kv_cache, + MatMulEnv& env, TimingInfo& timing_info) { + Activations activations(config, runtime_config.prefill_tbatch_size, + kv_cache.SeqLen(), env.ctx.allocator, env.row_ptrs); + + AllQueries all_queries(prompt, pos, prefix_end, + hwy::Span(&kv_cache, 1)); + QBatch qbatch(/*start=*/0, /*max_size=*/1, all_queries); + GenerateT(config, runtime_config, weights, activations, qbatch, env, + timing_info); +} + +// Splits the input into batches of at most `runtime_config.decode_qbatch_size` +// queries, and calls `GenerateT` on each batch. 
+void GenerateBatchT(const ModelConfig& config, + const RuntimeConfig& runtime_config, + const WeightsPtrs& weights, AllQueries& all_queries, + MatMulEnv& env, TimingInfo& timing_info) { + const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size, + runtime_config.prefill_tbatch_size); + Activations activations(config, max_batch_size, + all_queries[0].kv_cache.SeqLen(), env.ctx.allocator, + env.row_ptrs); + + for (size_t start = 0; start < all_queries.NumQueries(); + start += runtime_config.decode_qbatch_size) { + QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries); + // Generate a batch of one token for each of `qbatch.Size()` queries. + GenerateT(config, runtime_config, weights, activations, qbatch, env, + timing_info); + } +} + +void GenerateImageTokensT(const ModelConfig& config, + const RuntimeConfig& runtime_config, size_t seq_len, + const WeightsPtrs& weights, const Image& image, + ImageTokens& image_tokens, MatMulEnv& env) { + if (config.vit_config.layer_configs.empty()) { + HWY_ABORT("Model does not support generating image tokens."); + } + RuntimeConfig prefill_runtime_config = runtime_config; + const ModelConfig vit_config = GetVitConfig(config); + const size_t num_tokens = vit_config.max_seq_len; + prefill_runtime_config.prefill_tbatch_size = + num_tokens / (vit_config.pool_dim * vit_config.pool_dim); + Activations prefill_activations(vit_config, num_tokens, num_tokens, + env.ctx.allocator, env.row_ptrs); + // Weights are for the full PaliGemma model, not just the ViT part. + PrefillVit(config, weights, prefill_runtime_config, image, image_tokens, + prefill_activations, env); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace gcpp { +HWY_EXPORT(GenerateSingleT); +HWY_EXPORT(GenerateBatchT); +HWY_EXPORT(GenerateImageTokensT); + +Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference, + ThreadingContext& ctx) + : reader_(loader.weights), + model_(reader_, loader.tokenizer, loader.wrapping), + weights_(model_.Config()), + chat_template_(model_.Tokenizer(), model_.Config().model), + inference_(inference) { + weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference, + mat_owners_, ctx); + // Read everything into memory, or `weights_.mapped_` keeps the mapping alive. 
+ reader_.CloseFile(); +} + +Gemma::~Gemma() = default; + +void Gemma::Save(const Path& weights_path, NestedPools& pools) const { + BlobWriter writer(weights_path, pools.Pool()); + const std::vector serialized_mat_ptrs = + weights_.AddTensorDataToWriter(writer); + WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs, + writer); +} void Gemma::Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, size_t prefix_end, - KVCache& kv_cache, TimingInfo& timing_info) { - env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning); + KVCache& kv_cache, MatMulEnv& env, + TimingInfo& timing_info) const { + env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning); - model_.CallForModelWeight( - runtime_config, prompt, pos, prefix_end, kv_cache, &env_, timing_info); + HWY_DYNAMIC_DISPATCH(GenerateSingleT)(prompt, pos, prefix_end, + model_.Config(), runtime_config, + weights_, kv_cache, env, timing_info); - env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning); + env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); } void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, - const QueriesPromptTokens& queries_prompt, - const QueriesPos& queries_pos, - const QueriesPos& queries_prefix_end, - const KVCaches& kv_caches, TimingInfo& timing_info) { - // If we did not get passed prefix ends (size 0), assume 0 and pass that on. - QueriesPos mutable_queries_prefix_end = queries_prefix_end; - std::vector prefix_end_vec; - if (queries_prefix_end.size() == 0) { - prefix_end_vec.resize(queries_prompt.size(), 0); - mutable_queries_prefix_end = - QueriesPos(prefix_end_vec.data(), prefix_end_vec.size()); - } + AllQueries& all_queries, MatMulEnv& env, + TimingInfo& timing_info) const { + env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning); - env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning); + HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config, + weights_, all_queries, env, timing_info); - model_.CallForModelWeight( - runtime_config, queries_prompt, queries_pos, mutable_queries_prefix_end, - kv_caches, &env_, timing_info); - - env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning); + env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); } void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config, - const Image& image, ImageTokens& image_tokens) { - env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning); + size_t seq_len, const Image& image, + ImageTokens& image_tokens, + MatMulEnv& env) const { + env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning); - model_.CallForModelWeight(runtime_config, image, - image_tokens, &env_); + HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(model_.Config(), runtime_config, + seq_len, weights_, image, + image_tokens, env); - env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning); -} - -// Non-template functions moved from gemma-inl.h to avoid ODR violations. 
-
-void RangeChecks(const ModelConfig& weights_config,
-                 size_t& max_generated_tokens, const size_t prompt_size) {
-  if (!weights_config.use_local_attention) {
-    if (max_generated_tokens > weights_config.seq_len) {
-      fprintf(stderr,
-              "WARNING: max_generated_tokens %zu > kSeqLen %u, truncating.\n",
-              max_generated_tokens, weights_config.seq_len);
-      max_generated_tokens = weights_config.seq_len;
-    }
-  }
-  HWY_ASSERT(prompt_size > 0);
+  env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }
 
 }  // namespace gcpp
+#endif  // HWY_ONCE
diff --git a/gemma/gemma.h b/gemma/gemma.h
index ccda69c..5ebd70d 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -16,122 +16,154 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
 
-#include
-#include
-#include
+#include
 
 #include
 
 // IWYU pragma: begin_exports
-#include "compression/io.h"  // Path
 #include "gemma/activations.h"
-#include "gemma/common.h"
 #include "gemma/configs.h"
+#include "gemma/gemma_args.h"
 #include "gemma/kv_cache.h"
-#include "gemma/tokenizer.h"
+#include "gemma/model_store.h"
 #include "gemma/weights.h"
+#include "io/blob_store.h"
+#include "io/io.h"  // Path
 #include "ops/matmul.h"  // MatMulEnv
 #include "paligemma/image.h"
-#include "util/allocator.h"  // RowVectorBatch
-#include "util/basics.h"  // TokenAndProb
+#include "util/basics.h"  // TokenAndProb
+#include "util/threading_context.h"
 #include "hwy/timer.h"
 // IWYU pragma: end_exports
-#include "hwy/aligned_allocator.h"  // Span
 
 namespace gcpp {
 
-using PromptTokens = hwy::Span<const int>;
-// Batches of independent queries have their own prompt, previous token,
-// position in the sequence, and KVCache.
-using QueriesPromptTokens = hwy::Span<const PromptTokens>;
-using QueriesToken = hwy::Span<const int>;
-using QueriesPos = hwy::Span<const size_t>;
-using KVCaches = hwy::Span<KVCache>;
+struct PerQuery {
+  PromptTokens prompt;
 
-// StreamFunc is called with (token, probability). For prompt tokens,
-// probability is 0.0f. StreamFunc should return false to stop generation and
-// true to continue generation.
-using StreamFunc = std::function<bool(int, float)>;
-// BatchStreamFunc is called with (query_idx, pos, token, probability).
-// For prompt tokens, probability is 0.0f.
-// StreamFunc should return false to stop generation and true to continue.
-using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
-// If not empty, AcceptFunc is called with token. It should return false for
-// tokens you don't want to generate and true for tokens you want to generate.
-using AcceptFunc = std::function<bool(int, float)>;
-// If not empty, SampleFunc is called with the logits for the next token, which
-// it may modify/overwrite, and its return value is the next generated token
-// together with its probability.
-using SampleFunc = std::function<TokenAndProb(float*, size_t)>;
-// If not empty, LayersOutputFunc is called for layer outputs, specified with:
-// - index of query within containing batch (if any); zero otherwise.
-// - position in the tokens sequence
-// - name of the data, e.g. "tokens" for token IDs
-// - layer index (or -1 for global outputs)
-// - pointer to the data array
-// - size of the data array
-using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
-                                            int, const float*, size_t)>;
-// If not empty, ActivationsObserverFunc is invoked after each layer with:
-// - per-query position within the tokens sequence
-// - layer index (or -1 for post-norm output)
-// - activations
-using ActivationsObserverFunc =
-    std::function<void(const QueriesPos& queries_pos, int layer,
-                       const Activations& activations)>;
+  // Position in the KV cache: initially zero for the first turn, or when
+  // multi-turn is NOT desired. Incremented by prefill and `StreamAndUpdateEOS`.
+  size_t mutable_pos;
+  // Allows computing the last prefill token as `mutable_pos - initial_pos`,
+  // which might differ from `prompt.size() - 1` for prefix-LM.
+  size_t initial_pos;
+  // Zero for causal attention, or the end of the prefix for prefix-LM style
+  // attention in PaliGemma.
+  size_t prefix_end;
 
-// ImageTokens are represented as a RowVectorBatch, where each "batch" index
-// corresponds to a token for an image patch as computed by the image encoder.
-using ImageTokens = RowVectorBatch<float>;
+  KVCache& kv_cache;
 
-// RuntimeConfig holds configuration for a single generation run.
-struct RuntimeConfig {
-  // If not empty, batch_stream_token is called for each token in the batch,
-  // instead of stream_token.
-  bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
-    if (batch_stream_token) {
-      return batch_stream_token(query_idx, pos, token, prob);
+  // Previous token generated for this query, or the last prompt token. Will be
+  // fed into the next Transformer() call.
+  int prev_token = 0;
+};
+
+// Array of `PerQuery`. Referenced by `QBatch` and passed to `GenerateBatch`.
+struct AllQueries {
+  AllQueries() = default;
+
+  // For `GenerateSingleT`: same prompt/pos, replicated for each KV cache.
+  AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
+             const hwy::Span<KVCache>& kv_caches) {
+    per_query_.reserve(kv_caches.size());
+    for (size_t i = 0; i < kv_caches.size(); ++i) {
+      HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
+      per_query_.push_back(PerQuery{
+          .prompt = prompt,
+          .mutable_pos = pos,
+          .initial_pos = pos,
+          .prefix_end = prefix_end,
+          .kv_cache = kv_caches[i],
+      });
     }
-    return stream_token(token, prob);
   }
 
-  // Limit on the number of tokens generated.
-  size_t max_generated_tokens;
+  // Batch of queries with initial position set to zero. Causal attention
+  // is requested via an empty or all-zero `prefix_end`.
+  AllQueries(
+      const hwy::Span<const PromptTokens>& prompts,
+      const hwy::Span<KVCache>& kv_caches,
+      const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
+    HWY_ASSERT(prompts.size() == kv_caches.size());
+    HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
+    per_query_.reserve(kv_caches.size());
+    for (size_t i = 0; i < kv_caches.size(); ++i) {
+      HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
+      per_query_.push_back(PerQuery{
+          .prompt = prompts[i],
+          .mutable_pos = 0,
+          .initial_pos = 0,
+          .prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i],
+          .kv_cache = kv_caches[i],
+      });
+    }
+  }
 
-  // These defaults are overridden by InferenceArgs::CopyTo(*this):
-  // Max tokens per batch during prefill.
-  size_t prefill_tbatch_size = 256;
-  // Max queries per batch (one token from each) during decode.
-  size_t decode_qbatch_size = 16;
+  void Reserve(size_t size) { per_query_.reserve(size); }
+  void Append(const PerQuery& query) { per_query_.push_back(query); }
 
-  // Sampling-related parameters.
-  float temperature;     // Temperature for sampling.
-  size_t top_k = kTopK;  // Top-k for sampling.
-  std::mt19937* gen;     // Random number generator used for sampling.
+  size_t NumQueries() const { return per_query_.size(); }
 
-  int verbosity;  // Controls verbosity of printed messages.
+  PerQuery& operator[](size_t query_idx) {
+    HWY_DASSERT(query_idx < NumQueries());
+    return per_query_[query_idx];
+  }
+  const PerQuery& operator[](size_t query_idx) const {
+    HWY_DASSERT(query_idx < NumQueries());
+    return per_query_[query_idx];
+  }
 
-  // Functions operating on the generated tokens.
-  StreamFunc stream_token;
-  BatchStreamFunc batch_stream_token;
-  AcceptFunc accept_token;  // if empty, accepts all tokens.
-  SampleFunc sample_func;   // if empty, uses SampleTopK.
+ private:
+  std::vector<PerQuery> per_query_;
+};
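A minimal sketch of how a frontend might build AllQueries from tokenized prompts, using the second constructor above. MakeQueries is a hypothetical helper, not part of this diff; the KV caches are assumed to have been created elsewhere, and the token vectors must outlive the result because PromptTokens is a non-owning span.

#include <vector>

#include "gemma/gemma.h"            // AllQueries, PromptTokens, KVCache
#include "hwy/aligned_allocator.h"  // hwy::Span

gcpp::AllQueries MakeQueries(const std::vector<std::vector<int>>& tokens,
                             hwy::Span<gcpp::KVCache> kv_caches) {
  std::vector<gcpp::PromptTokens> prompts;
  prompts.reserve(tokens.size());
  for (const std::vector<int>& t : tokens) {
    prompts.push_back(gcpp::PromptTokens(t.data(), t.size()));
  }
  // Empty prefix_end span: causal attention for every query.
  return gcpp::AllQueries(
      hwy::Span<const gcpp::PromptTokens>(prompts.data(), prompts.size()),
      kv_caches);
}

A QBatch view over the result is then simply QBatch(/*start=*/0, /*max_size=*/queries.NumQueries(), queries), as GenerateBatchT below does per chunk.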
-  // Observer callbacks for intermediate data.
-  LayersOutputFunc layers_output;  // if not empty, called after each layer.
-  ActivationsObserverFunc activations_observer;  // if set, called per-layer.
+// View into AllQueries: either a batch of queries, or a single query for use
+// in PrefillTBatch or GenerateSingleT. Cheap to create because it holds a
+// reference to AllQueries.
+class QBatch {
+ public:
+  QBatch(size_t start, size_t max_size, AllQueries& queries)
+      : start_(start),
+        max_size_(max_size),
+        queries_(queries),
+        size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) {
+    HWY_ASSERT(max_size_ <= 4096);  // non_eos uses `BitSet4096`.
+    HWY_DASSERT(size_ != 0);
+    HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
+  }
 
-  // If not empty, these point to the image tokens and are used in the
-  // PaliGemma prefix-LM style attention.
-  const ImageTokens *image_tokens = nullptr;
+  // Returns a single-query view starting at `qi` relative to this batch.
+  QBatch Single(size_t qi) const { return QBatch(start_ + qi, 1, queries_); }
 
-  // Whether to use thread spinning to reduce barrier synchronization latency.
-  // Mutable so we can change kDefault to kTrue/kFalse during Generate, because
-  // RuntimeConfig is const there and is not passed to the Gemma ctor. This
-  // default decision is likely sufficient because it is based on whether
-  // threads are successfully pinned.
-  mutable Tristate use_spinning = Tristate::kDefault;
+  // How many queries in this batch: <= `queries_.NumQueries()` and `max_size_`.
+  size_t Size() const { return size_; }
 
-  // End-of-sequence token.
-  int eos_id = EOS_ID;
+  // Returns index for use with `AllQueries` and `BatchStreamToken`.
+  size_t QueryIdx(size_t qi) const {
+    HWY_DASSERT(qi < size_);
+    return start_ + qi;
+  }
+
+  // Accessor functions to bridge the previous SoA and current AoS layout.
+  const PromptTokens& Prompt(size_t qi) const {
+    return queries_[QueryIdx(qi)].prompt;
+  }
+  size_t Pos(size_t qi) const { return queries_[QueryIdx(qi)].mutable_pos; }
+  size_t& MutablePos(size_t qi) { return queries_[QueryIdx(qi)].mutable_pos; }
+  size_t InitialPos(size_t qi) const {
+    return queries_[QueryIdx(qi)].initial_pos;
+  }
+  size_t PrefixEnd(size_t qi) const {
+    return queries_[QueryIdx(qi)].prefix_end;
+  }
+  KVCache& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
+  int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
+
+ private:
+  size_t start_;
+  size_t max_size_;
+  AllQueries& queries_;
+  size_t size_;
 };
 
 struct TimingInfo {
@@ -193,82 +225,58 @@
   size_t tokens_generated = 0;
 };
 
+// After construction, all methods are const and thread-compatible if using
+// separate ThreadingContext for each thread.
 class Gemma {
  public:
-  // Reads old format weights file and tokenizer file.
-  // `env` must remain valid for the lifetime of this Gemma.
-  Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
-        MatMulEnv& env);
-  // Reads new format weights file that contains everything in a single file.
-  // `env` must remain valid for the lifetime of this Gemma.
-  Gemma(const Path& weights, MatMulEnv& env);
-  // Allocates weights, caller is responsible for filling them.
- Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env); + // Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`. + // `ctx` is only used to read tensors, but it is typically also referenced + // by the `MatMulEnv` passed to the Generate* methods. + Gemma(const LoaderArgs& loader, const InferenceArgs& inference, + ThreadingContext& ctx); ~Gemma(); - const ModelConfig& GetModelConfig() const { return model_.Config(); } - ModelInfo Info() const { - return ModelInfo({.model = model_.Config().model, - .wrapping = model_.Config().wrapping, - .weight = model_.Config().weight}); - } - const GemmaTokenizer& Tokenizer() const { return tokenizer_; } - const ModelWeightsStorage& Weights() const { return model_; } - ModelWeightsStorage& MutableWeights() { return model_; } - void Save(const Path& weights, hwy::ThreadPool& pool) { - std::string tokenizer_proto = tokenizer_.Serialize(); - model_.Save(tokenizer_proto, weights, pool); - } + const ModelConfig& Config() const { return model_.Config(); } + const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); } + const WeightsPtrs& Weights() const { return weights_; } + WeightsPtrs::Mode WeightReadMode() const { return weight_read_mode_; } + const GemmaChatTemplate& ChatTemplate() const { return chat_template_; } + const InferenceArgs& Inference() const { return inference_; } + + void Save(const Path& weights_path, NestedPools& pools) const; // `pos` is the position in the KV cache. Users are responsible for // incrementing it in the `*StreamFunc`, or setting to zero for single-turn. void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt, - size_t pos, KVCache& kv_cache, TimingInfo& timing_info) { - Generate(runtime_config, prompt, pos, /*prefix_end=*/0, kv_cache, + size_t pos, KVCache& kv_cache, MatMulEnv& env, + TimingInfo& timing_info) const { + Generate(runtime_config, prompt, pos, /*prefix_end=*/0, kv_cache, env, timing_info); } // For prefix-LM style attention, we can pass the end of the prefix. void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, size_t prefix_end, KVCache& kv_cache, - TimingInfo& timing_info); + MatMulEnv& env, TimingInfo& timing_info) const; - // `queries_pos` are the positions in the KV cache. Users are responsible for - // incrementing them in `BatchStreamFunc`, or setting to zero for single-turn. void GenerateBatch(const RuntimeConfig& runtime_config, - const QueriesPromptTokens& queries_prompt, - const QueriesPos& queries_pos, const KVCaches& kv_caches, - TimingInfo& timing_info) { - GenerateBatch(runtime_config, queries_prompt, queries_pos, - /*queries_prefix_end=*/{}, kv_caches, timing_info); - } - // For prefix-LM style attention, we can pass the ends of the prefixes. - void GenerateBatch(const RuntimeConfig& runtime_config, - const QueriesPromptTokens& queries_prompt, - const QueriesPos& queries_pos, - const QueriesPos& queries_prefix_end, - const KVCaches& kv_caches, TimingInfo& timing_info); + AllQueries& all_queries, MatMulEnv& env, + TimingInfo& timing_info) const; // Generates the image tokens by running the image encoder ViT. - void GenerateImageTokens(const RuntimeConfig& runtime_config, - const Image& image, ImageTokens& image_tokens); + void GenerateImageTokens(const RuntimeConfig& runtime_config, size_t seq_len, + const Image& image, ImageTokens& image_tokens, + MatMulEnv& env) const; private: - MatMulEnv& env_; - - GemmaTokenizer tokenizer_; - // Type-erased so that this can be defined in the header. 
-  ModelWeightsStorage model_;
+  BlobReader reader_;
+  ModelStore model_;
+  std::vector<MatOwner> mat_owners_;
+  WeightsPtrs weights_;
+  WeightsPtrs::Mode weight_read_mode_;
+  GemmaChatTemplate chat_template_;
+  InferenceArgs inference_;
 };
 
-// Adds BOS token and possibly 'turn' annotations, which depend on `info`
-// and `pos`, the number of tokens decoded so far; returns the corresponding
-// tokens. Asserts that tokenization is successful.
-std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
-                                 const ModelInfo& info, size_t pos,
-                                 std::string& prompt);
-void RangeChecks(const ModelConfig& weights_config,
-                 size_t& max_generated_tokens, size_t prompt_size);
-
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h
new file mode 100644
index 0000000..70268c7
--- /dev/null
+++ b/gemma/gemma_args.h
@@ -0,0 +1,273 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Shared between various frontends.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <functional>
+#include <random>
+#include <string>
+
+#include "io/io.h"       // Path
+#include "ops/matmul.h"  // MMStorage::kMax*
+#include "util/args.h"
+#include "util/basics.h"  // Tristate
+#include "util/mat.h"
+#include "hwy/aligned_allocator.h"  // Span
+#include "hwy/base.h"               // HWY_ABORT
+#include "hwy/profiler.h"
+
+namespace gcpp {
+
+struct LoaderArgs : public ArgsBase<LoaderArgs> {
+  LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  LoaderArgs(const std::string& tokenizer_path,
+             const std::string& weights_path) {
+    Init();  // Init sets to defaults, so assignments must come after Init().
+    tokenizer.path = tokenizer_path;
+    weights.path = weights_path;
+  };
+
+  Path tokenizer;
+  Path weights;  // weights file location
+  Tristate map;
+  Tristate to_bf16;
+  Tristate wrapping;
+
+  template <class Visitor>
+  void ForEach(const Visitor& visitor) {
+    visitor(tokenizer, "tokenizer", Path(),
+            "Path name of tokenizer model; only required for pre-2025 format.");
+    visitor(weights, "weights", Path(),
+            "Path name of model weights (.sbs) file.\n    Required argument.\n");
+    visitor(map, "map", Tristate::kDefault,
+            "Enable memory-mapping? -1 = auto, 0 = no, 1 = yes.");
+    visitor(to_bf16, "to_bf16", Tristate::kDefault,
+            "Convert weights to bf16? -1 = auto, 0 = no, 1 = yes.");
+    visitor(wrapping, "wrapping", Tristate::kDefault,
+            "Enable prompt wrapping? Specify 0 for pre-2025 format PT models.");
+  }
+};
+
+using PromptTokens = hwy::Span<const int>;
+
+// Batches of independent queries have their own prompt, previous token,
+// position in the sequence, and KVCache.
+using QueriesPromptTokens = hwy::Span<const PromptTokens>;
+using QueriesToken = hwy::Span<const int>;
+using QueriesPos = hwy::Span<const size_t>;
+
+// ImageTokens are represented as a matrix, where each row corresponds
+// to a token for an image patch as computed by the image encoder.
+using ImageTokens = MatStorageT<float>;
+
+// StreamFunc is called with (token, probability). For prompt tokens,
+// probability is 0.0f. StreamFunc should return false to stop generation and
+// true to continue generation.
+using StreamFunc = std::function<bool(int, float)>;
+// BatchStreamFunc is called with (query_idx, pos, token, probability).
+// For prompt tokens, probability is 0.0f. Generation continues if this returns
+// true and stops if it returns false. Note that query_idx is absolute, not
+// relative to the batch.
+using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
+// If not empty, AcceptFunc is called with token. It should return false for
+// tokens you don't want to generate and true for tokens you want to generate.
+using AcceptFunc = std::function<bool(int, float)>;
+// If not empty, SampleFunc is called with the logits for the next token, which
+// it may modify/overwrite, and its return value is the next generated token
+// together with its probability.
+using SampleFunc = std::function<TokenAndProb(float*, size_t)>;
+// If not empty, LayersOutputFunc is called for layer outputs, specified with:
+// - index of query within containing batch (if any); zero otherwise.
+// - position in the tokens sequence
+// - name of the data, e.g. "tokens" for token IDs
+// - layer index (or -1 for global outputs)
+// - pointer to the data array
+// - size of the data array
+using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
+                                            int, const float*, size_t)>;
+// If not empty, ActivationsObserverFunc is invoked after each layer with:
+// - per-query position within the tokens sequence
+// - layer index (or -1 for post-norm output)
+// - activations
+struct Activations;
+using ActivationsObserverFunc =
+    std::function<void(const QueriesPos& queries_pos, int layer,
+                       const Activations& activations)>;
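Given the SampleFunc shape above, a caller can plug in its own sampler via RuntimeConfig::sample_func (defined just below). A minimal greedy example; TokenAndProb is assumed to be the (token, prob) pair from util/basics.h per the includes above, and the probability formula exploits exp(best - best) / sum == 1 / sum.

#include <stddef.h>
#include <math.h>

// Greedy sampler matching SampleFunc; it is allowed to modify `logits`.
static gcpp::TokenAndProb GreedySample(float* logits, size_t vocab_size) {
  size_t best = 0;
  for (size_t i = 1; i < vocab_size; ++i) {
    if (logits[i] > logits[best]) best = i;
  }
  // Softmax probability of the argmax, without materializing the softmax.
  double sum = 0.0;
  for (size_t i = 0; i < vocab_size; ++i) {
    sum += exp(static_cast<double>(logits[i] - logits[best]));
  }
  return gcpp::TokenAndProb{static_cast<int>(best),
                            static_cast<float>(1.0 / sum)};
}
// Usage: runtime_config.sample_func = GreedySample;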
+  // Mutable so we can change kDefault to kTrue/kFalse during Generate, because
+  // RuntimeConfig is const there and is not passed to the Gemma ctor. This
+  // default decision is likely sufficient because it is based on whether
+  // threads are successfully pinned.
+  mutable Tristate use_spinning = Tristate::kDefault;
+};
+
+struct InferenceArgs : public ArgsBase<InferenceArgs> {
+  InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  InferenceArgs() { Init(); };
+
+  bool IsInteractive() const { return prompt.empty() && prompt_file.Empty(); }
+
+  int verbosity;
+
+  size_t seq_len;
+  size_t max_generated_tokens;
+
+  size_t prefill_tbatch_size;
+  size_t decode_qbatch_size;
+
+  float temperature;
+  size_t top_k;
+  bool deterministic;
+  bool multiturn;
+  Path image_file;
+
+  std::string prompt;  // Bypasses std::getline
+  // For prompts longer than the Linux terminal's 4K line edit buffer.
+  Path prompt_file;
+  std::string eot_line;
+
+  template <class Visitor>
+  void ForEach(const Visitor& visitor) {
+    visitor(verbosity, "verbosity", 1,
+            "Show verbose developer information\n    0 = only print generation "
+            "output\n    1 = standard user-facing terminal ui\n    2 = show "
+            "developer/debug info.\n    Default = 1.",
+            1);
+
+    visitor(seq_len, "seq_len", size_t{8192},
+            "Sequence length, capped by ModelConfig.max_seq_len.");
+    visitor(max_generated_tokens, "max_generated_tokens", size_t{4096},
+            "Maximum number of tokens to generate.");
+
+    visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},
+            "Prefill: max tokens per batch.");
+    visitor(decode_qbatch_size, "decode_qbatch", size_t{16},
+            "Decode: max queries per batch.");
+
+    visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
+    visitor(top_k, "top_k", size_t{1}, "Number of top-K tokens to sample from",
+            2);
+    visitor(deterministic, "deterministic", false,
+            "Make top-k sampling deterministic", 2);
+    visitor(multiturn, "multiturn", false,
+            "Multiturn mode\n    0 = clear KV cache after every "
+            "interaction\n    1 = continue KV cache after every interaction\n"
+            "    Default = 0 (conversation resets every turn)");
+    visitor(image_file, "image_file", Path(), "Image file to load.");
+
+    visitor(prompt, "prompt", std::string(""),
+            "Initial prompt for non-interactive mode. When specified, "
+            "generates a response and exits.",
+            1);
+    visitor(prompt_file, "prompt_file", Path(),
+            "Path to file containing the prompt for non-interactive mode. "
+            "When specified, generates a response and exits.",
+            1);
+
+    visitor(
+        eot_line, "eot_line", std::string(""),
+        "End of turn line. "
+        "When you specify this, the prompt will be all lines "
+        "before the line where only the given string appears.\n    Default = "
+        "When a newline is encountered, that signals the end of the turn.",
+        2);
+  }
+
+  void CopyTo(RuntimeConfig& runtime_config) const {
+    runtime_config.max_generated_tokens = max_generated_tokens;
+    runtime_config.prefill_tbatch_size = prefill_tbatch_size;
+    runtime_config.decode_qbatch_size = decode_qbatch_size;
+    if (prefill_tbatch_size > MMStorage::kMaxM) {
+      HWY_ABORT(
+          "prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, "
+          "or increase the constant in MMStorage.\n",
+          prefill_tbatch_size, MMStorage::kMaxM);
+    }
+    if (decode_qbatch_size > MMStorage::kMaxM) {
+      HWY_ABORT(
+          "decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, "
+          "or increase the constant in MMStorage.\n",
+          decode_qbatch_size, MMStorage::kMaxM);
+    }
+
+    runtime_config.temperature = temperature;
+    runtime_config.top_k = top_k;
+  }
+};
+
+static inline ThreadingArgs UpdateArgs(const ThreadingArgs& threading_args,
+                                       const InferenceArgs& inference_args) {
+  if (inference_args.decode_qbatch_size >= 256) {
+    ThreadingArgs copy = threading_args;
+    copy.max_packages = 1;
+    return copy;
+  }
+  return threading_args;
+}
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
diff --git a/gemma/griffin.cc b/gemma/griffin.cc
new file mode 100644
index 0000000..35bf29a
--- /dev/null
+++ b/gemma/griffin.cc
@@ -0,0 +1,192 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "compression/types.h"  // GEMMA_DISABLED_TARGETS
+#ifndef HWY_DISABLED_TARGETS
+#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
+#endif  // HWY_DISABLED_TARGETS
+
+#include "gemma/activations.h"
+#include "gemma/gemma.h"
+#include "gemma/gemma_args.h"
+#include "gemma/weights.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/profiler.h"
+
+// Compiles this file for multiple architectures via "foreach_target.h", to
+// which we pass the filename via macro 'argument'.
+// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma/griffin.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "ops/matvec-inl.h" +#include "ops/ops-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +void GriffinRecurrent(size_t num_tokens, size_t griffin_layer, + const LayerWeightsPtrs* layer_weights, + Activations& activations, QBatch& qbatch, + MatMulEnv& env) { + PROFILER_ZONE("Gen.Griffin"); + hwy::ThreadPool& pool = env.ctx.pools.Pool(0); + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + const D df; + + const size_t model_dim = layer_weights->layer_config.model_dim; + HWY_DASSERT(model_dim % hn::Lanes(df) == 0); + + const size_t heads = layer_weights->layer_config.heads; + const size_t conv_1d_width = layer_weights->layer_config.conv1d_width; + HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even"); + const size_t kHeadDim = model_dim / heads; + const size_t kMatrixSize = kHeadDim * kHeadDim; + + const size_t num_interleaved = num_tokens * qbatch.Size(); + const hwy::Divisor div_qbatch(static_cast(qbatch.Size())); + GriffinActivations& griffin = activations.griffin; + + // X / Y linear layers. + // TODO: MatMul + HWY_DASSERT(griffin.griffin_y.Rows() == griffin.griffin_x.Rows()); + HWY_DASSERT(num_interleaved == griffin.griffin_y.Rows()); + CallUpcastedSame( + &layer_weights->griffin.linear_x_w, &layer_weights->griffin.linear_y_w, + [&](const auto* wx, const auto* wy) { + for (size_t r = 0; r < num_interleaved; ++r) { + float* HWY_RESTRICT y = griffin.griffin_y.Row(r); + float* HWY_RESTRICT x = griffin.griffin_x.Row(r); + TwoMatVecAdd( + *wx, *wy, 0, model_dim, model_dim, + activations.attention.pre_att_rms_out.Row(r), + /*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(), + /*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(), + /*out0=*/x, /*out1=*/y, pool); + Gelu(y, model_dim); + } + }); + + // Conv1D. + for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; + ++interleaved_idx) { + const size_t qi = div_qbatch.Remainder(interleaved_idx); + const size_t batch_idx = div_qbatch.Divide(interleaved_idx); + const size_t pos = qbatch.Pos(qi) + batch_idx; + float* HWY_RESTRICT x = griffin.griffin_x.Row(qi); + + // cache[i] = input at time t-i. 
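+    // Ring-buffer layout, illustrated for a hypothetical conv_1d_width == 4:
+    // the per-layer cache row stores conv_1d_width - 1 = 3 past inputs of
+    // model_dim floats each, addressed modulo 3 by the formula below:
+    //   cache[1] -> block (pos + 2) % 3  (input at t-1)
+    //   cache[2] -> block (pos + 1) % 3  (input at t-2)
+    //   cache[3] -> block (pos + 0) % 3  (input at t-3)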
+ float* HWY_RESTRICT cache[kMaxConv1DWidth]; + cache[0] = x; + for (size_t i = 1; i < conv_1d_width; i++) { + cache[i] = + qbatch.KV(qi).conv1d_cache.Row(griffin_layer) + + ((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim; + } + for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) { + auto xv = hn::Load(df, x + i); + auto accum0 = + hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i); + auto accum1 = hn::Zero(df); + for (size_t l = 0; 2 * l < conv_1d_width; l++) { + auto wv0 = + hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() + + (conv_1d_width - 1 - 2 * l) * model_dim + i); + auto wv1 = + hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() + + (conv_1d_width - 2 - 2 * l) * model_dim + i); + accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0); + accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1); + } + hn::Store(hn::Add(accum0, accum1), df, x + i); + hn::Store(xv, df, cache[HWY_MAX(conv_1d_width, 1) - 1] + i); + } + } + + // RGLRU + for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; + ++interleaved_idx) { + const size_t qi = div_qbatch.Remainder(interleaved_idx); + const size_t batch_idx = div_qbatch.Divide(interleaved_idx); + const size_t pos = qbatch.Pos(qi) + batch_idx; + + float* HWY_RESTRICT x = griffin.griffin_x.Row(qi); + float* HWY_RESTRICT y = griffin.griffin_y.Row(qi); + float* HWY_RESTRICT gate_x = griffin.griffin_gate_x.Row(qi); + float* HWY_RESTRICT a = griffin.griffin_multiplier.Row(qi); + float* HWY_RESTRICT rnn_state = + qbatch.KV(qi).rglru_cache.Row(griffin_layer); + + pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { + size_t head_offset = head * kHeadDim; + CallUpcasted(&layer_weights->griffin.gate_w, [&](const auto* gate_w) { + TwoOfsMatVecAddLoop( + *gate_w, kMatrixSize * head, kMatrixSize * (heads + head), kHeadDim, + kHeadDim, x + head_offset, + /*add0=*/layer_weights->griffin.gate_biases.PackedScale1() + + head_offset, + /*add1=*/layer_weights->griffin.gate_biases.PackedScale1() + + model_dim + head_offset, + /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset); + }); + Sigmoid(gate_x + head_offset, kHeadDim); + Sigmoid(a + head_offset, kHeadDim); + const auto fn_mul = [](D d, hn::Vec x, hn::Vec gate_x) + HWY_ATTR { return hn::Mul(x, gate_x); }; + hn::Transform1(D(), a + head_offset, kHeadDim, + layer_weights->griffin.a.PackedScale1() + head_offset, + fn_mul); + hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset, + fn_mul); + // RNN scan + HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0); + for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) { + auto log_a = hn::Load(df, a + head_offset + i); + auto gated_x = hn::Load(df, x + head_offset + i); + auto rnn = hn::Load(df, rnn_state + head_offset + i); + auto a = hn::Exp(df, log_a); + auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0f))); + if (pos == 0) { + x_multiplier = hn::Set(df, 1.0f); + } + auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn)); + hn::Store(new_x, df, rnn_state + head_offset + i); + + // Join branches. + auto yv = hn::Load(df, y + head_offset + i); + auto pre_out = hn::Mul(yv, new_x); + hn::Store(pre_out, df, x + head_offset + i); + } + }); + } // interleaved_idx + + // Final linear layer. 
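+  // The scan above is the RG-LRU update; per element, roughly:
+  //   a   = exp(log_a)                         // log_a <= 0, so a in (0, 1]
+  //   h_t = a * h_{t-1} + sqrt(1 - a^2) * x_t  // multiplier is 1 at pos 0
+  // The matmul below projects the joined branches (y * h_t) back to model_dim
+  // via linear_out_w [model_dim, model_dim] plus bias, into att_sums.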
+ CallMatMul(griffin.griffin_x, layer_weights->griffin.linear_out_w, + layer_weights->griffin.linear_out_biases.PackedScale1(), env, + activations.attention.att_sums); +} // GriffinRecurrent + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); diff --git a/gemma/griffin.h b/gemma/griffin.h new file mode 100644 index 0000000..0ba6a23 --- /dev/null +++ b/gemma/griffin.h @@ -0,0 +1,47 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_ + +// Declares GriffinRecurrent for all SIMD targets. + +#include + +#include "gemma/gemma.h" +#include "hwy/highway.h" + +namespace gcpp { + +// Passed to HWY_VISIT_TARGETS; declares for one target. +#define GEMMA_DECL_GRIFFIN(TARGET, NAMESPACE) \ + namespace NAMESPACE { \ + void GriffinRecurrent(size_t num_tokens, size_t griffin_layer, \ + const LayerWeightsPtrs* layer_weights, \ + Activations& activations, QBatch& qbatch, \ + MatMulEnv& env); \ + /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ + } // namespace NAMESPACE + +// Function declarations for each SIMD target. Allows direct call from the +// per-target namespace. We may later replace this with dynamic dispatch if +// the overhead is acceptable. +HWY_VISIT_TARGETS(GEMMA_DECL_GRIFFIN) + +#undef GEMMA_DECL_GRIFFIN + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_ diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index 60ad5dd..9d107e8 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -15,91 +15,75 @@ #include "gemma/kv_cache.h" -#include +#include -#include "gemma/common.h" // CallForModel -#include "hwy/aligned_allocator.h" -#include "hwy/base.h" // ZeroBytes +#include "gemma/configs.h" +#include "gemma/gemma_args.h" +#include "util/mat.h" // ZeroInit +#include "hwy/base.h" // HWY_MAX namespace gcpp { void KVCache::ZeroGriffinCache() { - if (conv1d_cache_size != 0) { - hwy::ZeroBytes(conv1d_cache.get(), - conv1d_cache_size * sizeof(conv1d_cache[0])); - } - if (rglru_cache_size != 0) { - hwy::ZeroBytes(rglru_cache.get(), - rglru_cache_size * sizeof(rglru_cache[0])); - } + if (conv1d_cache.Rows() == 0) return; + ZeroInit(conv1d_cache); + ZeroInit(rglru_cache); } -// prefill_tbatch_size is the maximum number of tokens from one query to -// prefill at a time. -KVCache KVCache::Create(const ModelConfig& weights_config, - size_t prefill_tbatch_size) { - KVCache kv_cache = {}; - - const size_t size_cache_pos = weights_config.CachePosSize(); - if (size_cache_pos != 0) { - // Allocate more so that prefill can always access one batch, even if - // near the end of the sequence. 
- kv_cache.seq_len = weights_config.seq_len + prefill_tbatch_size; - kv_cache.kv_cache = - hwy::AllocateAligned(kv_cache.seq_len * size_cache_pos); - } - - const size_t num_griffin_layers = weights_config.NumLayersOfType( - LayerAttentionType::kGriffinRecurrentBlock); - // TODO(patrickms): Add query batching support for Griffin. - if (num_griffin_layers > 0) { - uint32_t conv1d_width = 0; - for (const auto& layer_config : weights_config.layer_configs) { - conv1d_width = std::max(conv1d_width, layer_config.conv1d_width); - } - const size_t conv1d_cache_size = - num_griffin_layers * (conv1d_width == 0 ? 0 : conv1d_width - 1) * - weights_config.model_dim; - kv_cache.conv1d_cache_size = conv1d_cache_size; - if (conv1d_cache_size != 0) { - kv_cache.conv1d_cache = hwy::AllocateAligned(conv1d_cache_size); - } - - const size_t rglru_cache_size = - num_griffin_layers * weights_config.model_dim; - kv_cache.rglru_cache_size = rglru_cache_size; - if (rglru_cache_size != 0) { - kv_cache.rglru_cache = hwy::AllocateAligned(rglru_cache_size); - } - } // num_griffin_layers - - return kv_cache; +static size_t GriffinLayers(const ModelConfig& config) { + return config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock); } -KVCache KVCache::Copy(const ModelConfig& weights_config, - size_t prefill_tbatch_size) { - KVCache kv_cache_copy = Create(weights_config, prefill_tbatch_size); +static size_t GriffinConv1dCols(const ModelConfig& config) { + size_t conv1d_width = 0; + for (const auto& layer_config : config.layer_configs) { + conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width); + } + // The row offset, in blocks of model_dim is computed mod (conv1d_width - 1), + // hence allocate conv1d_width * model_dim total columns. + return conv1d_width * config.model_dim; +} - const size_t size_cache_pos = weights_config.CachePosSize(); - if (size_cache_pos != 0) { - std::copy(kv_cache.get(), kv_cache.get() + size_cache_pos * seq_len, - kv_cache_copy.kv_cache.get()); +// Number of rows for KV cache. Note that both rows and cols are u32, and +// the total number of elements can exceed 2^32. 
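+// Sizing example with hypothetical shapes: seq_len 8192, 42 layers, 8 KV
+// heads and qkv_dim 256 gives 8192 rows x (42 * 8 * 256 * 2) = 172032 cols,
+// about 1.4e9 floats (~5.6 GB); each extent fits in u32 but the total byte
+// count does not.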
+static size_t CappedSeqLen(const ModelConfig& config, + const InferenceArgs& inference_args) { + if (inference_args.seq_len > config.max_seq_len) { + HWY_WARN("Capping seq_len %zu to config.max_seq_len %u.", + inference_args.seq_len, config.max_seq_len); + return config.max_seq_len; + } + return inference_args.seq_len; +} + +KVCache::KVCache(const Extents2D& conv1d_extents, + const Extents2D& rglru_extents, const Extents2D& kv_extents, + const Allocator& allocator) + : conv1d_cache("conv1d_cache", conv1d_extents, allocator, MatPadding::kOdd), + rglru_cache("rglru_cache", rglru_extents, allocator, MatPadding::kOdd), + kv_cache("kv", kv_extents, allocator, MatPadding::kOdd), + allocator_(allocator) {} + +KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, + const Allocator& allocator) + : KVCache( + Extents2D(GriffinLayers(config), GriffinConv1dCols(config)), + Extents2D(GriffinLayers(config), config.model_dim), + Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()), + allocator) {} + +KVCache KVCache::Copy() { + KVCache copy(conv1d_cache.Extents(), rglru_cache.Extents(), + kv_cache.Extents(), allocator_); + + if (conv1d_cache.Rows() != 0) { + CopyMat(conv1d_cache, copy.conv1d_cache); + CopyMat(rglru_cache, copy.rglru_cache); } - const size_t num_griffin_layers = weights_config.NumLayersOfType( - LayerAttentionType::kGriffinRecurrentBlock); - if (num_griffin_layers > 0) { - if (conv1d_cache_size != 0) { - std::copy(conv1d_cache.get(), conv1d_cache.get() + conv1d_cache_size, - kv_cache_copy.conv1d_cache.get()); - } - if (rglru_cache_size != 0) { - std::copy(rglru_cache.get(), - rglru_cache.get() + rglru_cache_size * sizeof(rglru_cache[0]), - kv_cache_copy.rglru_cache.get()); - } - } - return kv_cache_copy; + CopyMat(kv_cache, copy.kv_cache); + + return copy; } } // namespace gcpp diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index 6052d0b..7b5b88d 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -18,34 +18,41 @@ #include -#include "gemma/common.h" // Model -#include "hwy/aligned_allocator.h" +#include "gemma/configs.h" // ModelConfig +#include "gemma/gemma_args.h" // InferenceArgs +#include "util/basics.h" // BF16 +#include "util/mat.h" namespace gcpp { +using KV_t = float; + struct KVCache { - size_t seq_len = 0; // = kSeqLen + prefill_tbatch_size + KVCache(const ModelConfig& config, const InferenceArgs& inference_args, + const Allocator& allocator); - // seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2 - hwy::AlignedFreeUniquePtr kv_cache; - - // (kConv1dWidth - 1) * kModelDim * kGriffinLayers - hwy::AlignedFreeUniquePtr conv1d_cache; - size_t conv1d_cache_size = 0; - - // kModelDim * kGriffinLayers - hwy::AlignedFreeUniquePtr rglru_cache; - size_t rglru_cache_size = 0; + // Returns a deep copy of the KVCache. Use explicit function instead of + // copy ctor to make the cost explicit. + KVCache Copy(); // Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache // and rglru_cache. void ZeroGriffinCache(); - static KVCache Create(const ModelConfig& weights_config, - size_t prefill_tbatch_size); + size_t SeqLen() const { return kv_cache.Rows(); } - // Returns a deep copy of the KVCache. 
- KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size); + // [griffin_layers, griffin_conv1d_cols * model_dim] + MatStorageT conv1d_cache; + MatStorageT rglru_cache; // [griffin_layers, model_dim] + + MatStorageT kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2] + + private: + const Allocator& allocator_; + + // For use by other ctor and Copy() + KVCache(const Extents2D& conv1d_extents, const Extents2D& rglru_extents, + const Extents2D& kv_extents, const Allocator& allocator); }; } // namespace gcpp diff --git a/gemma/model_store.cc b/gemma/model_store.cc new file mode 100644 index 0000000..8f6c138 --- /dev/null +++ b/gemma/model_store.cc @@ -0,0 +1,464 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gemma/model_store.h" + +#include +#include +#include + +#include +#include +#include +#include // strcmp +#include +#include // std::errc // NOLINT + +#include "compression/types.h" +#include "gemma/configs.h" // ModelConfig, kMaxQKVDim +#include "gemma/tensor_info.h" +#include "gemma/tokenizer.h" +#include "io/blob_store.h" +#include "io/fields.h" +#include "io/io.h" // Path +#include "util/basics.h" +#include "util/threading_context.h" +#include "hwy/base.h" + +namespace gcpp { + +// Single-file format contains blobs with these names: +static constexpr char kConfigName[] = "config"; +static constexpr char kTokenizerName[] = "tokenizer"; +static constexpr char kMatPtrsName[] = "toc"; +// Pre-2025 format has one metadata blob. 'F' denoted f32. +static constexpr char kDecoratedScalesName[] = "Fscales"; + +static void WarnIfExtra(const IFields::ReadResult& result, const char* name) { + // No warning if missing_fields > 0: those fields are default-initialized. + if (result.extra_u32) { + HWY_WARN( + "Serialized blob %s has %u extra fields the code is not aware of. " + "Consider updating to the latest code from GitHub.", + name, result.extra_u32); + } +} + +// Returns the serialized tokenizer (std::string is required for proto). +// Reads it from a blob or from a separate file if pre-2025. +static std::string ReadTokenizer(BlobReader& reader, + const Path& tokenizer_path) { + std::string tokenizer; + // Check prevents `CallWithSpan` from printing a warning. + if (reader.Find(kTokenizerName)) { + if (!reader.CallWithSpan( + kTokenizerName, [&tokenizer](const hwy::Span bytes) { + tokenizer.assign(bytes.data(), bytes.size()); + })) { + HWY_WARN( + "Reading tokenizer blob failed, please raise an issue. You can " + "instead specify a tokenizer file via --tokenizer."); + } + } + + // Read actual tokenizer from blob. + if (!tokenizer.empty() && tokenizer != kMockTokenizer) { + if (!tokenizer_path.Empty()) { + HWY_WARN("--weights has tokenizer but overriding with %s.", + tokenizer_path.path.c_str()); + return ReadFileToString(tokenizer_path); + } + + return tokenizer; + } + + // No blob but user specified path to file: read it or abort. 
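+  // e.g. for a pre-2025 weights file the tokenizer lives in a separate file
+  // (hypothetical paths):
+  //   ./gemma --weights 2b-it-sfp.sbs --tokenizer tokenizer.spm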
+ if (!tokenizer_path.Empty()) { + return ReadFileToString(tokenizer_path); + } + + HWY_WARN( + "BlobStore does not contain a tokenizer and no --tokenizer was " + "specified. Tests may continue but inference will fail."); + return kMockTokenizer; +} + +using KeyVec = std::vector; + +class TypePrefix { + public: + static Type TypeFromChar(char c) { + switch (c) { + case 'F': + return Type::kF32; + case 'B': + return Type::kBF16; + case '$': + return Type::kSFP; + case '2': + return Type::kNUQ; + default: + // The other types were not written to pre-2025 files, hence no need to + // encode and check for them here. + return Type::kUnknown; + } + } + + TypePrefix(const KeyVec& keys, const BlobReader& reader) { + for (size_t key_idx = 0; key_idx < keys.size(); ++key_idx) { + const std::string& key = keys[key_idx]; + const Type type = TypeFromChar(key[0]); + const uint64_t bytes = reader.Range(key_idx).bytes; + bytes_[static_cast(type)] += bytes; + blobs_[static_cast(type)]++; + total_bytes_ += bytes; + } + } + + // Returns true for pre-2025 format, which has type prefixes and thus the + // functions below may be used. + bool HasPrefixes() const { + return bytes_[static_cast(Type::kUnknown)] != total_bytes_; + } + + // Returns the weight type deduced from the histogram of blobs per type. + // Rationale: We expect a mix of types due to varying precision requirements + // for each tensor. The preferred weight type might not even be the most + // common, because we prioritize higher compression for the *large* tensors. + // Ignore types which only have a few blobs (might be metadata), and assume + // that there would be at least 4 of the large tensors (in particular, global + // attention layers). Hence return the smallest type with >= 4 blobs. + Type DeduceWeightType() const { + size_t min_bits = ~size_t{0}; + Type weight_type = Type::kUnknown; + for (size_t i = 0; i < kNumTypes; ++i) { + if (blobs_[i] < 4) continue; + const size_t bits = TypeBits(static_cast(i)); + if (bits < min_bits) { + min_bits = bits; + weight_type = static_cast(i); + } + } + return weight_type; + } + + // Prints statistics on the total size of tensors by type. + void PrintTypeBytes() const { + for (size_t type_idx = 0; type_idx < kNumTypes; ++type_idx) { + const Type type = static_cast(type_idx); + const uint64_t bytes = bytes_[type_idx]; + if (bytes == 0) continue; + const double percent = 100.0 * bytes / total_bytes_; + fprintf(stderr, "%12zu blob bytes (%5.2f%%) of %4s\n", + static_cast(bytes), percent, TypeName(type)); + } + } + + private: + uint64_t total_bytes_ = 0; + std::array bytes_{0}; + std::array blobs_{0}; +}; + +// Returns 0 if the blob does not seem to be a per-layer tensor, otherwise the +// layer index. +static size_t LayerIdxFromKey(const std::string& key) { + const auto parse_num = [&key](size_t begin, size_t end) -> int { + HWY_DASSERT(begin <= end); + HWY_DASSERT(end <= key.size()); + int val = 0; + auto [ptr, ec] = std::from_chars(key.data() + begin, key.data() + end, val); + return (ec == std::errc()) ? val : -1; + }; + + const size_t suffix_pos = key.rfind('_'); + // If there is no digit after the last underscore, it is not a layer name. + if (suffix_pos == std::string::npos) return 0; + if (suffix_pos == key.size() - 1) return 0; + + int layer_idx = parse_num(suffix_pos + 1, key.size()); + + HWY_ASSERT(layer_idx < 999); + return layer_idx == -1 ? 0 : static_cast(layer_idx); +} + +// Returns the number of layers based on the largest blob name suffix seen. 
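+// e.g. keys {"c_embedding", "qkv_ein_0", ..., "qkv_ein_25"} deduce 26 layers,
+// because the largest "_N" suffix is 25.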
+// This works with or without type prefixes because it searches for suffixes. +static size_t DeduceNumLayers(const KeyVec& keys) { + // Built-in self-test. + { + HWY_ASSERT(LayerIdxFromKey("gr_conv_w_2") == 2); // common case + HWY_ASSERT(LayerIdxFromKey("prefix_") == 0); // no number + HWY_ASSERT(LayerIdxFromKey("c_embedding") == 0); // per-model + HWY_ASSERT(LayerIdxFromKey("c_final_norm") == 0); // per-model, two _ + } + + size_t max_layer_idx = 0; + for (const std::string& key : keys) { + max_layer_idx = HWY_MAX(max_layer_idx, LayerIdxFromKey(key)); + } + return max_layer_idx + 1; +} + +// Looks for known tensor names associated with model families. +// This works with or without type prefixes because it searches for substrings. +static int DeduceLayerTypes(const BlobReader& reader) { + int layer_types = 0; + for (size_t key_idx = 0; key_idx < reader.Keys().size(); ++key_idx) { + const std::string& key = reader.Keys()[key_idx]; + if (key.find("gr_conv_w") != std::string::npos) { // NOLINT + return kDeducedGriffin; + } + if (key.find("qkv_ein_w") != std::string::npos) { // NOLINT + layer_types |= kDeducedViT; + } + if (key.find("img_pos_emb") != std::string::npos) { // NOLINT + // About 5.88 elements per pixel; assume at least bf16. + if (reader.Range(key_idx).bytes > 448 * 448 * 5 * sizeof(BF16)) { + layer_types |= kDeduced448; + } + } + } + return layer_types; +} + +// `wrapping_override` is forwarded from the command line. For pre-2025 files +// without `ModelConfig`, it is the only way to force PT. +static ModelConfig ReadOrDeduceConfig(BlobReader& reader, + Tristate wrapping_override) { + const TypePrefix type_prefix(reader.Keys(), reader); + Type deduced_weight = Type::kUnknown; + if (type_prefix.HasPrefixes()) { + deduced_weight = type_prefix.DeduceWeightType(); + type_prefix.PrintTypeBytes(); + } + + // Always deduce so we can verify it against the config we read. + const size_t layers = DeduceNumLayers(reader.Keys()); + const int layer_types = DeduceLayerTypes(reader); + const Model deduced_model = + DeduceModel(reader.blob_path(), layers, layer_types); + + ModelConfig config; + // Check first to prevent `CallWithSpan` from printing a warning. + if (reader.Find(kConfigName)) { + HWY_ASSERT(reader.CallWithSpan( + kConfigName, [&config](const SerializedSpan serialized) { + const IFields::ReadResult result = config.Read(serialized, 0); + WarnIfExtra(result, kConfigName); + HWY_ASSERT_M(result.pos != 0, "Error deserializing config"); + })); + + HWY_ASSERT(config.model != Model::UNKNOWN); + HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel); + HWY_ASSERT(config.weight != Type::kUnknown); + for (const LayerConfig& layer_config : config.layer_configs) { + if (static_cast(layer_config.qkv_dim) > kMaxQKVDim) { + HWY_ABORT("Increase kMaxQKVDim to at least %u.", layer_config.qkv_dim); + } + } + + // We trust the deserialized config, but checking helps to validate the + // deduction, which we rely on below for pre-2025 files. + if (config.model != deduced_model) { + const std::string suffix = WrappingSuffix(config.wrapping); + HWY_WARN("Detected model %s does not match config %s.", + (std::string(ModelPrefix(deduced_model)) + suffix).c_str(), + (std::string(ModelPrefix(config.model)) + suffix).c_str()); + } + return config; + } + + // Pre-2025 format: no config, rely on deduction plus `wrapping_override`. 
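+  // Deduction example for a hypothetical pre-2025 file: any key containing
+  // "gr_conv_w" marks a Griffin model, and passing --wrapping 0 forces PT
+  // (no prompt wrapping), since such files carry no serialized ModelConfig.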
+ return ModelConfig(deduced_model, deduced_weight, + ChooseWrapping(deduced_model, wrapping_override)); +} + +static std::vector ReadScales(BlobReader& reader, + const ModelConfig& config) { + std::vector scales; + // Check first to prevent `CallWithSpan` from printing a warning. This blob is + // optional even in pre-2025 format; Griffin was the first to include it. + if (reader.Find(kDecoratedScalesName)) { + HWY_ASSERT(reader.CallWithSpan( + kDecoratedScalesName, + [&scales](const hwy::Span scales_blob) { + scales.assign(scales_blob.cbegin(), scales_blob.cend()); + })); + } + return scales; +} + +// Single-file format: reads `MatPtr` from the blob; returns false if not found. +bool ModelStore::ReadMatPtrs(BlobReader& reader) { + // Check first to prevent `CallWithSpan` from printing a warning. + if (!reader.Find(kMatPtrsName)) return false; + + // For verifying `config_.weight`. + size_t min_bits = ~size_t{0}; + Type weight_type = Type::kUnknown; + + HWY_ASSERT(reader.CallWithSpan( + kMatPtrsName, [&, this](SerializedSpan serialized) { + for (size_t pos = 0; pos < serialized.size();) { + MatPtr mat; + const IFields::ReadResult result = mat.Read(serialized, pos); + WarnIfExtra(result, mat.Name()); + if (result.pos == 0) { + HWY_ABORT("Deserializing MatPtr %s failed (pos %zu of %zu).", + mat.Name(), pos, serialized.size()); + } + pos = result.pos + result.extra_u32; + + // Retrieve actual key index because a writer may have written other + // blobs before the tensor data. + const BlobRange* range = reader.Find(mat.Name()); + HWY_ASSERT(range); + const size_t key_idx = range->key_idx; + AddMatPtr(key_idx, mat); + + const size_t bits = TypeBits(mat.GetType()); + if (bits < min_bits) { + min_bits = bits; + weight_type = mat.GetType(); + } + } + })); + + HWY_ASSERT(weight_type != Type::kUnknown); + HWY_ASSERT(weight_type == config_.weight); + + return true; +} + +// Pre-2025 format: synthesizes `MatPtr` from the blob names if `!ReadMatPtrs`. +void ModelStore::CreateMatPtrs(BlobReader& reader) { + const TensorInfoRegistry tensors(config_); + + const KeyVec& keys = reader.Keys(); + mat_ptrs_.reserve(keys.size()); + // `key_idx` is the blob index. It is not the same as the index of the + // `MatPtr` in `mat_ptrs_` because not all blobs are tensors. + for (size_t key_idx = 0; key_idx < keys.size(); ++key_idx) { + const Type type = TypePrefix::TypeFromChar(keys[key_idx][0]); + if (type == Type::kUnknown) continue; // likely not a tensor + + // Strip type prefix from the key. Still includes layer suffix. + const std::string name = keys[key_idx].substr(1); + const TensorInfo* info = tensors.Find(name); + if (HWY_UNLIKELY(!info)) { + if (name == "scales") continue; // ignore, not a tensor. + HWY_ABORT("Unknown tensor %s.", name.c_str()); + } + // Unable to set scale already because they are ordered according to + // `ForEachTensor`, which we do not know here. The initial value is 1.0f + // and we set the correct value in `FindAndUpdateMatPtr`. + AddMatPtr(key_idx, MatPtr(name.c_str(), type, ExtentsFromInfo(info))); + } + HWY_ASSERT(mat_ptrs_.size() <= keys.size()); + HWY_ASSERT(mat_ptrs_.size() == key_idx_.size()); +} + +ModelStore::ModelStore(BlobReader& reader, const Path& tokenizer_path, + Tristate wrapping) + : config_(ReadOrDeduceConfig(reader, wrapping)), + tokenizer_(ReadTokenizer(reader, tokenizer_path)) { + if (!ReadMatPtrs(reader)) { // Pre-2025 format. + CreateMatPtrs(reader); + scales_ = ReadScales(reader, config_); + // ModelConfig serialized a vector of strings. 
Unpack into a set for more + // efficient lookup. + for (const std::string& name : config_.scale_base_names) { + scale_base_names_.insert(name); + } + // If the model has scales, the config must know about it. + HWY_ASSERT(scales_.empty() || !scale_base_names_.empty()); + } + + HWY_ASSERT(key_idx_.size() == mat_ptrs_.size()); +} + +ModelStore::~ModelStore() { + // Sanity check: ensure all scales were consumed. + HWY_ASSERT(scales_consumed_ == scales_.size()); +} + +const MatPtr* ModelStore::FindMat(const char* name) const { + auto it = mat_idx_for_name_.find(name); + if (it == mat_idx_for_name_.end()) return nullptr; + const size_t mat_idx = it->second; + const MatPtr* file_mat = &mat_ptrs_[mat_idx]; + HWY_ASSERT(!strcmp(file_mat->Name(), name)); + return file_mat; +} + +bool ModelStore::FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const { + const MatPtr* file_mat = FindMat(mat.Name()); + if (!file_mat) return false; + if (file_mat->Rows() != mat.Rows() || file_mat->Cols() != mat.Cols()) { + HWY_ABORT("Tensor %s shape %zu %zu mismatches file %zu %zu.", mat.Name(), + mat.Rows(), mat.Cols(), file_mat->Rows(), file_mat->Cols()); + } + // `Compress()` output is always packed because it assumes a 1D array. + HWY_ASSERT(mat.IsPacked()); + // Update fields. Name already matched, otherwise we would not find it. + // For MatPtr tensors, the type will be `kUnknown`. If it was a `MatPtrT`, + // ensure the type set via code matches the file. + HWY_ASSERT_M( + mat.GetType() == Type::kUnknown || mat.GetType() == file_mat->GetType(), + mat.Name()); + mat.SetType(file_mat->GetType()); + if (scales_.empty()) { + // `file_mat->Scale()` is either read from file, or we have pre-2025 format + // without the optional scales, and it is default-initialized to 1.0f. + mat.SetScale(file_mat->Scale()); + } else { // Pre-2025 with scaling factors: set next if `mat` wants one. + if (scale_base_names_.find(StripLayerSuffix(mat.Name())) != + scale_base_names_.end()) { + HWY_ASSERT(scales_consumed_ < scales_.size()); + mat.SetScale(scales_[scales_consumed_++]); + } + } + + key_idx = key_idx_[file_mat - mat_ptrs_.data()]; + return true; +} + +static void AddBlob(const char* name, const std::vector& data, + BlobWriter& writer) { + HWY_ASSERT(!data.empty()); + writer.Add(name, data.data(), data.size() * sizeof(data[0])); +} + +void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer, + const std::vector& serialized_mat_ptrs, + BlobWriter& writer) { + HWY_ASSERT(config.model != Model::UNKNOWN); + HWY_ASSERT(config.weight != Type::kUnknown); + HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel); + const std::vector serialized_config = config.Write(); + AddBlob(kConfigName, serialized_config, writer); + + const std::string serialized_tokenizer = tokenizer.Serialize(); + HWY_ASSERT(!serialized_tokenizer.empty()); + writer.Add(kTokenizerName, serialized_tokenizer.data(), + serialized_tokenizer.size()); + + AddBlob(kMatPtrsName, serialized_mat_ptrs, writer); + + writer.WriteAll(); +} + +} // namespace gcpp diff --git a/gemma/model_store.h b/gemma/model_store.h new file mode 100644 index 0000000..42af343 --- /dev/null +++ b/gemma/model_store.h @@ -0,0 +1,111 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Reads/writes model metadata (all but the weights) from/to a `BlobStore`. +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_MODEL_STORE_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_MODEL_STORE_H_ + +#include +#include + +#include +#include +#include +#include +#include + +// IWYU pragma: begin_exports +#include "gemma/configs.h" // ModelConfig +#include "gemma/tokenizer.h" +#include "io/blob_store.h" +#include "io/io.h" // Path +#include "util/basics.h" // Tristate +#include "util/mat.h" // MatPtr +// IWYU pragma: end_exports + +#include "util/allocator.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + +namespace gcpp { + +// Reads and holds the model config, tokenizer and all `MatPtr`: everything +// except the tensor data, which are read/written by `weights.cc`. +// +// As of 2025-04, the `BlobStore` format includes blobs for `ModelConfig`, +// tokenizer, and all `MatPtr` metadata. "Pre-2025" format instead stored the +// tokenizer in a separate file, encoded tensor type in a prefix of the blob +// name, and had a blob for tensor scaling factors. We still support reading +// both, but only write single-file format. +class ModelStore { + public: + // Reads from file(s) or aborts on error. The latter two arguments are only + // used for pre-2025 files. + ModelStore(BlobReader& reader, const Path& tokenizer_path = Path(), + Tristate wrapping = Tristate::kDefault); + ~ModelStore(); + + const ModelConfig& Config() const { + HWY_ASSERT(config_.model != Model::UNKNOWN); + return config_; + } + + const GemmaTokenizer& Tokenizer() const { return tokenizer_; } + + // Returns nullptr if `name` is not available for loading, otherwise the + // metadata of that tensor. + const MatPtr* FindMat(const char* name) const; + + // Returns false if `mat` is not available for loading, otherwise updates + // `mat` with metadata from the file and sets `key_idx` for use by + // `BlobReader`. Called via `ReadOrAllocate` in `weights.cc`. + bool FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const; + + private: + void AddMatPtr(const size_t key_idx, const MatPtr& mat) { + auto pair_ib = mat_idx_for_name_.insert({mat.Name(), mat_ptrs_.size()}); + HWY_ASSERT_M(pair_ib.second, mat.Name()); // Ensure inserted/unique. + mat_ptrs_.push_back(mat); + key_idx_.push_back(key_idx); + } + + bool ReadMatPtrs(BlobReader& reader); + void CreateMatPtrs(BlobReader& reader); // Aborts on error. + + ModelConfig config_; + GemmaTokenizer tokenizer_; + + // All `MatPtr` present in the `BlobStore`, see `ReadMatPtrs`/`CreateMatPtrs`. + std::vector mat_ptrs_; + // For each of `mat_ptrs_`, the index within `BlobReader::Keys()`. This is + // not necessarily iota because some blobs are not tensors, and callers may + // have added blobs before ours. + std::vector key_idx_; + // Index within `mat_ptrs_` and `key_idx_` for each tensor name. + std::unordered_map mat_idx_for_name_; + + // Only used if `!ReadMatPtrs` (pre-2025 format): + std::vector scales_; + std::unordered_set scale_base_names_; + mutable size_t scales_consumed_ = 0; +}; + +// Adds metadata blobs to `writer` and writes everything to `path`. 
This +// produces a single BlobStore file holding everything required for inference. +void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer, + const std::vector& serialized_mat_ptrs, + BlobWriter& writer); + +} // namespace gcpp +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_MODEL_STORE_H_ diff --git a/gemma/run.cc b/gemma/run.cc index 254d13f..7cbc4de 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -15,22 +15,22 @@ // Command line text interface to gemma. +#include + #include #include #include #include #include -// Placeholder for internal header, do not modify. -#include "compression/shared.h" // PromptWrapping +#include "compression/types.h" // PromptWrapping #include "evals/benchmark_helper.h" -#include "gemma/common.h" #include "gemma/gemma.h" // Gemma -#include "ops/matmul.h" // MatMulEnv +#include "gemma/gemma_args.h" +#include "gemma/tokenizer.h" // WrapAndTokenize +#include "ops/matmul.h" // MatMulEnv #include "paligemma/image.h" -#include "util/app.h" #include "util/args.h" // HasHelp -#include "util/threading.h" #include "hwy/base.h" #include "hwy/highway.h" #include "hwy/profiler.h" @@ -55,9 +55,8 @@ static constexpr std::string_view kAsciiArtBanner = R""( |___/ |_| |_| )""; -std::string GetPrompt(std::istream& input, int verbosity, - std::string_view eot_line) { - PROFILER_ZONE("Gen.input"); +std::string GetPromptFromStream(std::istream& input, int verbosity, + std::string_view eot_line) { if (verbosity >= 1) { std::cout << "> " << std::flush; } @@ -77,36 +76,55 @@ std::string GetPrompt(std::istream& input, int verbosity, return prompt_string; } +// Get prompt either from interactive input or command line +std::string GetPrompt(const InferenceArgs& inference) { + PROFILER_ZONE("Gen.input"); + // If prompt is provided via command line, use that + if (!inference.prompt.empty()) { + return inference.prompt; + } + if (!inference.prompt_file.Empty()) { + return ReadFileToString(inference.prompt_file); + } + + return GetPromptFromStream(std::cin, inference.verbosity, inference.eot_line); +} + // The main Read-Eval-Print Loop. -void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, - const InferenceArgs& args, const AcceptFunc& accept_token, - std::string& eot_line) { +void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, + const Gemma& gemma, KVCache& kv_cache, MatMulEnv& env) { PROFILER_ZONE("Gen.misc"); size_t abs_pos = 0; // across turns size_t tokens_generated_this_turn = 0; // differentiates prefill from reply size_t prompt_size = 0; + const ModelConfig& config = gemma.Config(); std::mt19937 gen; - InitGenerator(args, gen); + InitGenerator(inference, gen); - const bool have_image = !args.image_file.path.empty(); + const bool have_image = !inference.image_file.path.empty(); Image image; - ImageTokens image_tokens; + const size_t pool_dim = config.vit_config.pool_dim; + ImageTokens image_tokens( + "image_tokens", + have_image ? 
Extents2D(config.vit_config.seq_len / (pool_dim * pool_dim), + config.model_dim) + : Extents2D(0, 0), + env.ctx.allocator, MatPadding::kOdd); + image_tokens.AllocateAndAttachRowPtrs(env.row_ptrs); if (have_image) { - size_t pool_dim = model.GetModelConfig().vit_config.pool_dim; - image_tokens = ImageTokens(Extents2D( - model.GetModelConfig().vit_config.seq_len / (pool_dim * pool_dim), - model.GetModelConfig().model_dim)); - HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA || - model.Info().wrapping == PromptWrapping::GEMMA_VLM); - HWY_ASSERT(image.ReadPPM(args.image_file.path)); - const size_t image_size = model.GetModelConfig().vit_config.image_size; + HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA || + config.wrapping == PromptWrapping::GEMMA_VLM); + HWY_ASSERT(image.ReadPPM(inference.image_file.path)); + const size_t image_size = config.vit_config.image_size; image.Resize(image_size, image_size); - RuntimeConfig runtime_config = { - .gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin}; + RuntimeConfig runtime_config = {.gen = &gen, + .verbosity = inference.verbosity, + .use_spinning = threading.spin}; double image_tokens_start = hwy::platform::Now(); - model.GenerateImageTokens(runtime_config, image, image_tokens); - if (app.verbosity >= 1) { + gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image, + image_tokens, env); + if (inference.verbosity >= 1) { double image_tokens_duration = hwy::platform::Now() - image_tokens_start; fprintf(stderr, "\n\n[ Timing info ] Image token generation took: %d ms\n", @@ -121,21 +139,21 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, const bool first_response_token = tokens_generated_this_turn == prompt_size; ++tokens_generated_this_turn; if (in_prompt) { - if (app.verbosity >= 1) { - std::cerr << "." << std::flush; + if (inference.verbosity >= 1) { + std::cout << "." << std::flush; } return true; - } else if (model.GetModelConfig().IsEOS(token)) { - if (app.verbosity >= 2) { + } else if (config.IsEOS(token)) { + if (inference.verbosity >= 2) { std::cout << "\n[ End ]\n"; } return true; } std::string token_text; - HWY_ASSERT(model.Tokenizer().Decode(std::vector{token}, &token_text)); + HWY_ASSERT(gemma.Tokenizer().Decode(std::vector{token}, &token_text)); if (first_response_token) { token_text.erase(0, token_text.find_first_not_of(" \t\n")); - if (app.verbosity >= 1) { + if (inference.verbosity >= 1) { std::cout << "\n\n"; } } @@ -146,72 +164,77 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, while (true) { // Loop until user quits. tokens_generated_this_turn = 0; - // Read prompt and handle special commands. - std::string prompt_string = GetPrompt(std::cin, app.verbosity, eot_line); - if (!std::cin) return; - // If !eot_line.empty(), we append \n, so only look at the first 2 chars. - if (prompt_string.size() >= 2 && prompt_string[0] == '%') { - if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return; - if (prompt_string[1] == 'c' || prompt_string[1] == 'C') { - abs_pos = 0; + const std::string prompt_string = GetPrompt(inference); + const bool is_interactive = inference.IsInteractive(); + if (is_interactive) { // handle special commands: + if (!std::cin) return; + + // If !eot_line.empty(), we append \n, so only look at the first 2 chars. 
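+      // Illustrative session (hypothetical): when run with --eot_line END,
+      // a prompt may span several lines and ends at a line containing only
+      // "END"; "%q" quits and "%c" resets the conversation (handled below).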
+ if (prompt_string.size() >= 2 && prompt_string[0] == '%') { + if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return; + if (prompt_string[1] == 'c' || prompt_string[1] == 'C') { + abs_pos = 0; + continue; + } + } + + if (prompt_string.empty()) { + std::cout << "Use '%q' to quit.\n"; continue; } } - if (prompt_string.empty()) { - std::cout << "Use '%q' to quit.\n"; - continue; + + // Set up runtime config. + TimingInfo timing_info = {.verbosity = inference.verbosity}; + RuntimeConfig runtime_config = {.gen = &gen, + .verbosity = inference.verbosity, + .stream_token = stream_token, + .use_spinning = threading.spin}; + inference.CopyTo(runtime_config); + std::vector prompt; + size_t prefix_end = 0; + if (have_image) { + prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(), + config.wrapping, abs_pos, prompt_string, + image_tokens.Rows()); + runtime_config.image_tokens = &image_tokens; + prompt_size = prompt.size(); + if (config.wrapping == PromptWrapping::PALIGEMMA) { + // The end of the prefix for prefix-LM style attention in Paligemma. + // See Figure 2 of https://arxiv.org/abs/2407.07726. + prefix_end = prompt_size; + // We need to look at all the tokens for the prefix. + // NOTE: Online softmax is on the roadmap, after which this requirement + // can be lifted. + runtime_config.prefill_tbatch_size = prompt_size; + } + } else { + prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(), + config.wrapping, abs_pos, prompt_string); + prompt_size = prompt.size(); } - // Wrap, tokenize and maybe log prompt tokens. - std::vector prompt = WrapAndTokenize( - model.Tokenizer(), model.Info(), abs_pos, prompt_string); - prompt_size = prompt.size(); if constexpr (kVerboseLogTokens) { for (int i = 0; i < prompt_size; ++i) { fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]); } } - // Set up runtime config. - TimingInfo timing_info = {.verbosity = app.verbosity}; - RuntimeConfig runtime_config = {.gen = &gen, - .verbosity = app.verbosity, - .stream_token = stream_token, - .accept_token = accept_token, - .use_spinning = app.spin}; - args.CopyTo(runtime_config); - size_t prefix_end = 0; - if (have_image) { - runtime_config.image_tokens = &image_tokens; - if (model.Info().wrapping == PromptWrapping::PALIGEMMA) { - prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0); - } else if (model.Info().wrapping == PromptWrapping::GEMMA_VLM) { - size_t seq_len = model.GetModelConfig().vit_config.seq_len; - size_t pool_dim = model.GetModelConfig().vit_config.pool_dim; - prompt = - WrapVLM(model.Tokenizer(), model.Info(), abs_pos, prompt, - image_tokens.BatchSize(), seq_len / (pool_dim * pool_dim)); - } - prompt_size = prompt.size(); - // The end of the prefix for prefix-LM style attention in Paligemma. - // See Figure 2 of https://arxiv.org/abs/2407.07726. - prefix_end = prompt_size; - // We need to look at all the tokens for the prefix. - runtime_config.prefill_tbatch_size = prompt_size; - } - // Generate until EOS or max_generated_tokens. - if (app.verbosity >= 1) { + if (inference.verbosity >= 1) { std::cerr << "\n[ Reading prompt ] " << std::flush; } - model.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache, + gemma.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache, env, timing_info); std::cout << "\n\n"; + // In non-interactive mode, we only process one prompt/turn. + if (!is_interactive) break; + // Prepare for the next turn. Works only for PaliGemma. 
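+    // Multiturn behavior, illustrated: with --multiturn 0 (the default),
+    // abs_pos rewinds to 0 so each turn starts a fresh conversation; with
+    // --multiturn 1, abs_pos keeps growing and earlier turns stay in the KV
+    // cache. The condition below also always resets for PaliGemma.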
- if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) { + if (!inference.multiturn || config.wrapping == PromptWrapping::PALIGEMMA) { abs_pos = 0; // Start a new turn at position 0. - InitGenerator(args, gen); + InitGenerator(inference, gen); } else { // The last token was either EOS, then it should be ignored because it is // never part of the dialog, see Table 5 in the Gemma-2 paper: @@ -227,20 +250,17 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, } } -void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { +void Run(const LoaderArgs& loader, const ThreadingArgs& threading, + const InferenceArgs& inference) { PROFILER_ZONE("Run.misc"); - // Note that num_threads is an upper bound; we also limit to the number of - // detected and enabled cores. - const BoundedTopology topology = CreateTopology(app); - NestedPools pools = CreatePools(topology, app); - MatMulEnv env(topology, pools); - if (app.verbosity >= 2) env.print_best = true; - Gemma model = CreateGemma(loader, env); - KVCache kv_cache = - KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size); + ThreadingContext ctx(UpdateArgs(threading, inference)); + MatMulEnv env(ctx); + if (inference.verbosity >= 2) env.print_best = true; + const Gemma gemma(loader, inference, ctx); + KVCache kv_cache(gemma.Config(), inference, ctx.allocator); - if (app.verbosity >= 1) { + if (inference.verbosity >= 1) { std::string instructions = "*Usage*\n" " Enter an instruction and press enter (%C resets conversation, " @@ -261,46 +281,37 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { instructions += multiturn; instructions += examples; - std::cout << "\033[2J\033[1;1H" // clear screen - << kAsciiArtBanner << "\n\n"; - ShowConfig(loader, inference, app, topology, pools); - std::cout << "\n" << instructions << "\n"; + // Skip the banner and instructions in non-interactive mode + if (inference.IsInteractive()) { + std::cout << "\033[2J\033[1;1H" // clear screen + << kAsciiArtBanner << "\n\n"; + ShowConfig(loader, threading, inference, gemma.Config(), + gemma.WeightReadMode(), ctx); + std::cout << "\n" << instructions << "\n"; + } } - ReplGemma(model, kv_cache, app, inference, AcceptFunc(), app.eot_line); + ReplGemma(threading, inference, gemma, kv_cache, env); } } // namespace gcpp int main(int argc, char** argv) { + gcpp::InternalInit(); { PROFILER_ZONE("Startup.misc"); - // Placeholder for internal init, do not modify. - gcpp::LoaderArgs loader(argc, argv); + gcpp::ThreadingArgs threading(argc, argv); gcpp::InferenceArgs inference(argc, argv); - gcpp::AppArgs app(argc, argv); if (gcpp::HasHelp(argc, argv)) { std::cerr << gcpp::kAsciiArtBanner; - gcpp::ShowHelp(loader, inference, app); + gcpp::ShowHelp(loader, threading, inference); return 0; } - if (const char* error = loader.Validate()) { - std::cerr << gcpp::kAsciiArtBanner; - gcpp::ShowHelp(loader, inference, app); - HWY_ABORT("\nInvalid args: %s", error); - } - - if (const char* error = inference.Validate()) { - std::cerr << gcpp::kAsciiArtBanner; - gcpp::ShowHelp(loader, inference, app); - HWY_ABORT("\nInvalid args: %s", error); - } - - gcpp::Run(loader, inference, app); + gcpp::Run(loader, threading, inference); } PROFILER_PRINT_RESULTS(); // Must call outside the zone above. 
return 0; diff --git a/gemma/tensor_index.cc b/gemma/tensor_index.cc deleted file mode 100644 index 4308c9d..0000000 --- a/gemma/tensor_index.cc +++ /dev/null @@ -1,607 +0,0 @@ -#include "gemma/tensor_index.h" - -#include - -#include -#include -#include -#include -#include -#include - -#include "compression/shared.h" -#include "gemma/configs.h" - -namespace gcpp { -namespace { - -// Returns the non-layer tensors for the model. -std::vector ModelTensors(const ModelConfig& config) { - return { - TensorInfo{ - .name = "c_embedding", - .source_names = {"embedder/input_embedding"}, - .axes = {0, 1}, - .shape = {config.vocab_size, config.model_dim}, - .min_size = Type::kBF16, - }, - TensorInfo{ - .name = "c_final_norm", - .source_names = {"final_norm/scale"}, - .axes = {0}, - .shape = {config.model_dim}, - .min_size = Type::kBF16, - }, - TensorInfo{ - .name = "enc_norm_bias", - .source_names = {"img/Transformer/encoder_norm/bias"}, - .axes = {0}, - .shape = {config.vit_config.model_dim}, - .min_size = Type::kBF16, - }, - TensorInfo{ - .name = "enc_norm_scale", - .source_names = {"img/Transformer/encoder_norm/scale"}, - .axes = {0}, - .shape = {config.vit_config.model_dim}, - .min_size = Type::kBF16, - }, - TensorInfo{ - .name = "img_emb_bias", - .source_names = {"img/embedding/bias"}, - .axes = {0}, - .shape = {config.vit_config.model_dim}, - .min_size = Type::kF32, - }, - TensorInfo{ - .name = "img_emb_kernel", - .source_names = {"img/embedding/kernel"}, - .axes = {3, 0, 1, 2}, - .shape = {config.vit_config.model_dim, config.vit_config.patch_width, - config.vit_config.patch_width, 3}, - .min_size = Type::kBF16, - .cols_take_extra_dims = true, - }, - TensorInfo{ - .name = "img_head_bias", - .source_names = {"img/head/bias", "embedder/mm_input_projection/b"}, - .axes = {0}, - .shape = {config.model_dim}, - .min_size = Type::kF32, - }, - TensorInfo{ - .name = "img_head_kernel", - .source_names = {"img/head/kernel", "embedder/mm_input_projection/w"}, - .axes = {1, 0}, - .shape = {config.model_dim, config.vit_config.model_dim}, - .min_size = Type::kBF16, - }, - TensorInfo{ - .name = "img_pos_emb", - .source_names = {"img/pos_embedding"}, - .axes = {0, 1}, - .shape = {/*1,*/ config.vit_config.seq_len, - config.vit_config.model_dim}, - .min_size = Type::kF32, - }, - // RMS norm applied to soft tokens prior to pos embedding. - TensorInfo{ - .name = "mm_embed_norm", - .source_names = {"embedder/mm_soft_embedding_norm/scale"}, - .axes = {0}, - .shape = {config.vit_config.model_dim}, - .min_size = Type::kBF16, - }, - }; -} - -// Returns the tensors for the given image layer config. -std::vector ImageLayerTensors(const ModelConfig& config, - const LayerConfig& layer_config, - const int img_layer_idx) { - return { - // Vit layers. 
- TensorInfo{ - .name = "attn_out_w", - .source_names = {"MultiHeadDotProductAttention_0/out/kernel"}, - .axes = {2, 0, 1}, - .shape = {config.vit_config.model_dim, layer_config.heads, - layer_config.qkv_dim}, - .min_size = Type::kBF16, - .cols_take_extra_dims = true, - }, - TensorInfo{ - .name = "attn_out_b", - .source_names = {"MultiHeadDotProductAttention_0/out/bias"}, - .axes = {0}, - .shape = {config.vit_config.model_dim}, - .min_size = Type::kF32, - }, - TensorInfo{ - .name = "q_ein_w", - .source_names = {"MultiHeadDotProductAttention_0/query/kernel"}, - .axes = {1, 2, 0}, - .shape = {layer_config.heads, layer_config.qkv_dim, - config.vit_config.model_dim}, - .concat_names = {"qkv_ein_w", "k_ein_w", "v_ein_w"}, - .concat_axis = 1, - .min_size = Type::kBF16, - }, - TensorInfo{ - .name = "k_ein_w", - .source_names = {"MultiHeadDotProductAttention_0/key/kernel"}, - .axes = {1, 2, 0}, - .shape = {layer_config.heads, layer_config.qkv_dim, - config.vit_config.model_dim}, - .concat_names = {""}, - .min_size = Type::kBF16, - }, - TensorInfo{ - .name = "v_ein_w", - .source_names = {"MultiHeadDotProductAttention_0/value/kernel"}, - .axes = {1, 2, 0}, - .shape = {layer_config.heads, layer_config.qkv_dim, - config.vit_config.model_dim}, - .concat_names = {""}, - .min_size = Type::kBF16, - }, - TensorInfo{ - .name = "qkv_ein_w", - .source_names = {"MultiHeadDotProductAttention_0/qkv/kernel"}, - .axes = {1, 2, 0}, - .shape = {layer_config.heads, 3 * layer_config.qkv_dim, - config.vit_config.model_dim}, - .min_size = Type::kBF16, - }, - TensorInfo{ - .name = "q_ein_b", - .source_names = {"MultiHeadDotProductAttention_0/query/bias"}, - .axes = {0, 1}, - .shape = {layer_config.heads, layer_config.qkv_dim}, - .concat_names = {"qkv_ein_b", "k_ein_b", "v_ein_b"}, - .concat_axis = 1, - .min_size = Type::kF32, - }, - TensorInfo{ - .name = "k_ein_b", - .source_names = {"MultiHeadDotProductAttention_0/key/bias"}, - .axes = {0, 1}, - .shape = {layer_config.kv_heads, layer_config.qkv_dim}, - .concat_names = {""}, - .min_size = Type::kF32, - }, - TensorInfo{ - .name = "v_ein_b", - .source_names = {"MultiHeadDotProductAttention_0/value/bias"}, - .axes = {0, 1}, - .shape = {layer_config.kv_heads, layer_config.qkv_dim}, - .concat_names = {""}, - .min_size = Type::kF32, - }, - TensorInfo{ - .name = "qkv_ein_b", - .source_names = {"MultiHeadDotProductAttention_0/qkv/bias"}, - .axes = {0, 1}, - .shape = {layer_config.heads + layer_config.kv_heads * 2, - layer_config.qkv_dim}, - .min_size = Type::kF32, - }, - TensorInfo{ - .name = "linear_0_w", - .source_names = {"MlpBlock_0/Dense_0/kernel"}, - .axes = {1, 0}, - .shape = {layer_config.ff_hidden_dim, config.vit_config.model_dim}, - .min_size = Type::kBF16, - }, - TensorInfo{ - .name = "linear_0_b", - .source_names = {"MlpBlock_0/Dense_0/bias"}, - .axes = {0}, - .shape = {layer_config.ff_hidden_dim}, - .min_size = Type::kF32, - }, - TensorInfo{ - .name = "linear_1_w", - .source_names = {"MlpBlock_0/Dense_1/kernel"}, - .axes = {1, 0}, - .shape = {config.vit_config.model_dim, layer_config.ff_hidden_dim}, - .min_size = Type::kBF16, - }, - TensorInfo{ - .name = "linear_1_b", - .source_names = {"MlpBlock_0/Dense_1/bias"}, - .axes = {0}, - .shape = {config.vit_config.model_dim}, - .min_size = Type::kF32, - }, - TensorInfo{ - .name = "ln_0_bias", - .source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias", - "img/Transformer/encoderblock_" + - std::to_string(img_layer_idx) + - "/LayerNorm_0/bias"}, - .axes = {0}, - .shape = {config.vit_config.model_dim}, - .min_size = 
Type::kBF16, - }, - TensorInfo{ - .name = "ln_0_scale", - .source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale", - "img/Transformer/encoderblock_" + - std::to_string(img_layer_idx) + - "/LayerNorm_0/scale"}, - .axes = {0}, - .shape = {config.vit_config.model_dim}, - .min_size = Type::kBF16, - }, - TensorInfo{ - .name = "ln_1_bias", - .source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias", - "img/Transformer/encoderblock_" + - std::to_string(img_layer_idx) + - "/LayerNorm_1/bias"}, - .axes = {0}, - .shape = {config.vit_config.model_dim}, - .min_size = Type::kBF16, - }, - TensorInfo{ - .name = "ln_1_scale", - .source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale", - "img/Transformer/encoderblock_" + - std::to_string(img_layer_idx) + - "/LayerNorm_1/scale"}, - .axes = {0}, - .shape = {config.vit_config.model_dim}, - .min_size = Type::kBF16, - }, - }; -} - -// Returns the tensors for the given LLM layer config. -std::vector<TensorInfo> LLMLayerTensors(const ModelConfig& config, - const LayerConfig& layer_config, - bool reshape_att) { - std::vector<TensorInfo> tensors = { - TensorInfo{ - .name = "key_norm", - .source_names = {"attn/_key_norm/scale"}, - .axes = {0}, - .shape = {layer_config.qkv_dim}, - .min_size = Type::kBF16, - }, - TensorInfo{ - .name = "query_norm", - .source_names = {"attn/_query_norm/scale"}, - .axes = {0}, - .shape = {layer_config.qkv_dim}, - .min_size = Type::kBF16, - }, - TensorInfo{ - .name = "qkv1_w", - .source_names = {"attn/q_einsum/w"}, - .axes = {0, 2, 1}, - .shape = {layer_config.heads * layer_config.qkv_dim, - config.model_dim}, - .concat_names = {"qkv_ein", "qkv2_w"}, - }, - TensorInfo{ - .name = "qkv2_w", - .source_names = {"attn/kv_einsum/w"}, - .axes = {1, 0, 3, 2}, - .shape = {2 * layer_config.kv_heads * layer_config.qkv_dim, - config.model_dim}, - .concat_names = {""}, - }, - TensorInfo{ - .name = "q_ein", - .source_names = {"attention_block/proj_q/kernel"}, - .axes = {1, 0}, - .shape = {layer_config.model_dim, layer_config.model_dim}, - .concat_names = {"qkv_ein", "k_ein", "v_ein"}, - }, - TensorInfo{ - .name = "k_ein", - .source_names = {"attention_block/proj_k/kernel"}, - .axes = {1, 0}, - .shape = {layer_config.qkv_dim, layer_config.model_dim}, - .concat_names = {""}, - }, - TensorInfo{ - .name = "v_ein", - .source_names = {"attention_block/proj_v/kernel"}, - .axes = {1, 0}, - .shape = {layer_config.qkv_dim, layer_config.model_dim}, - .concat_names = {""}, - }, - TensorInfo{ - .name = "qkv_ein", - .source_names = {"attn/qkv_einsum/w"}, - .axes = {1, 0, 3, 2}, - .shape = {(layer_config.heads + 2 * layer_config.kv_heads) * - layer_config.qkv_dim, - config.model_dim}, - }, - TensorInfo{ - .name = "attn_ob", - .source_names = {"attention_block/proj_final/bias"}, - .axes = {0}, - .shape = {config.model_dim}, - .min_size = Type::kF32, - }, - // Griffin layers.
- TensorInfo{ - .name = "gr_lin_x_w", - .source_names = {"recurrent_block/linear_x/kernel"}, - .axes = {1, 0}, - .shape = {layer_config.griffin_dim, layer_config.griffin_dim}, - }, - TensorInfo{ - .name = "gr_lin_x_b", - .source_names = {"recurrent_block/linear_x/bias"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .min_size = Type::kF32, - }, - TensorInfo{ - .name = "gr_lin_y_w", - .source_names = {"recurrent_block/linear_y/kernel"}, - .axes = {1, 0}, - .shape = {layer_config.griffin_dim, layer_config.griffin_dim}, - }, - TensorInfo{ - .name = "gr_lin_y_b", - .source_names = {"recurrent_block/linear_y/bias"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .min_size = Type::kF32, - }, - TensorInfo{ - .name = "gr_lin_out_w", - .source_names = {"recurrent_block/linear_out/kernel"}, - .axes = {1, 0}, - .shape = {layer_config.griffin_dim, layer_config.griffin_dim}, - }, - TensorInfo{ - .name = "gr_lin_out_b", - .source_names = {"recurrent_block/linear_out/bias"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .min_size = Type::kF32, - }, - TensorInfo{ - .name = "gr_conv_w", - .source_names = {"recurrent_block/conv_1d/w"}, - .axes = {0, 1}, - .shape = {layer_config.conv1d_width, layer_config.griffin_dim}, - .min_size = Type::kF32, - }, - TensorInfo{ - .name = "gr_conv_b", - .source_names = {"recurrent_block/conv_1d/b"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .min_size = Type::kF32, - }, - TensorInfo{ - .name = "gr1_gate_w", - .source_names = {"recurrent_block/rg_lru/input_gate/w"}, - .axes = {0, 2, 1}, - .shape = {layer_config.heads, - layer_config.griffin_dim / layer_config.heads, - layer_config.griffin_dim / layer_config.heads}, - .concat_names = {"gr_gate_w", "gr2_gate_w"}, - }, - TensorInfo{ - .name = "gr2_gate_w", - .source_names = {"recurrent_block/rg_lru/a_gate/w"}, - .axes = {0, 2, 1}, - .shape = {layer_config.heads, - layer_config.griffin_dim / layer_config.heads, - layer_config.griffin_dim / layer_config.heads}, - .concat_names = {""}, - }, - TensorInfo{ - .name = "gr_gate_w", - .source_names = {"recurrent_block/rg_lru/gate/w"}, - .axes = {0, 2, 1}, - .shape = {2 * layer_config.heads, - layer_config.griffin_dim / layer_config.heads, - layer_config.griffin_dim / layer_config.heads}, - }, - TensorInfo{ - .name = "gr1_gate_b", - .source_names = {"recurrent_block/rg_lru/input_gate/b"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .concat_names = {"gr_gate_b", "gr2_gate_b"}, - .min_size = Type::kF32, - }, - TensorInfo{ - .name = "gr2_gate_b", - .source_names = {"recurrent_block/rg_lru/a_gate/b"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .concat_names = {""}, - .min_size = Type::kF32, - }, - TensorInfo{ - .name = "gr_gate_b", - .source_names = {"recurrent_block/rg_lru/input_gate/b"}, - .axes = {0, 1}, - .shape = {2 * layer_config.griffin_dim}, - .min_size = Type::kF32, - }, - TensorInfo{ - .name = "gr_a", - .source_names = {"recurrent_block/rg_lru/a_param"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .min_size = Type::kF32, - .scaled_softplus = true, - }, - - TensorInfo{ - .name = "gating_ein", - .source_names = {"mlp/gating_einsum/w", "mlp/gating_einsum", - "mlp_block/ffw_up/w"}, - .axes = {0, layer_config.optimized_gating ? 1u : 2u, - layer_config.optimized_gating ? 2u : 1u}, - .shape = {2, layer_config.ff_hidden_dim, config.model_dim}, - }, - TensorInfo{ - .name = "gating1_w", - .source_names = {"none"}, - .axes = {0, layer_config.optimized_gating ? 1u : 2u, - layer_config.optimized_gating ? 
2u : 1u}, - .shape = {layer_config.ff_hidden_dim, config.model_dim}, - }, - TensorInfo{ - .name = "gating2_w", - .source_names = {"none"}, - .axes = {0, layer_config.optimized_gating ? 1u : 2u, - layer_config.optimized_gating ? 2u : 1u}, - .shape = {layer_config.ff_hidden_dim, config.model_dim}, - }, - TensorInfo{ - .name = "linear_w", - .source_names = {"mlp/linear/w", "mlp/linear", - "mlp_block/ffw_down/kernel"}, - .axes = {1, 0}, - .shape = {config.model_dim, layer_config.ff_hidden_dim}, - }, - TensorInfo{ - .name = "pre_att_ns", - .source_names = {"pre_attention_norm/scale", - "temporal_pre_norm/scale"}, - .axes = {0}, - .shape = {config.model_dim}, - .min_size = Type::kBF16, - }, - TensorInfo{ - .name = "pre_ff_ns", - .source_names = {"pre_ffw_norm/scale", "channel_pre_norm/scale"}, - .axes = {0}, - .shape = {config.model_dim}, - .min_size = Type::kBF16, - }, - TensorInfo{ - .name = "post_att_ns", - .source_names = {"post_attention_norm/scale"}, - .axes = {0}, - .shape = {config.model_dim}, - .min_size = Type::kBF16, - }, - TensorInfo{ - .name = "post_ff_ns", - .source_names = {"post_ffw_norm/scale"}, - .axes = {0}, - .shape = {config.model_dim}, - .min_size = Type::kBF16, - }, - TensorInfo{ - .name = "ffw_gat_b", - .source_names = {"mlp_block/ffw_up/b"}, - .axes = {0}, - .shape = {2 * layer_config.ff_hidden_dim}, - .min_size = Type::kF32, - }, - TensorInfo{ - .name = "ffw_out_b", - .source_names = {"mlp_block/ffw_down/bias"}, - .axes = {0}, - .shape = {config.model_dim}, - .min_size = Type::kF32, - }, - }; - if (reshape_att) { - tensors.push_back(TensorInfo{ - .name = "att_w", - .source_names = {"attn/attn_vec_einsum/w", - "attention_block/proj_final/kernel"}, - .preshape = {layer_config.heads, layer_config.qkv_dim, - config.model_dim}, - .axes = {2, 0, 1}, - .shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim}, - .cols_take_extra_dims = true, - }); - tensors.push_back(TensorInfo{ - .name = "att_ein", - .shape = {layer_config.heads, config.model_dim, layer_config.qkv_dim}, - }); - } else { - tensors.push_back(TensorInfo{ - .name = "att_ein", - .source_names = {"attn/attn_vec_einsum/w", - "attention_block/proj_final/kernel"}, - .preshape = {layer_config.heads, layer_config.qkv_dim, - config.model_dim}, - .axes = {0, 2, 1}, - .shape = {layer_config.heads, config.model_dim, layer_config.qkv_dim}, - }); - tensors.push_back(TensorInfo{ - .name = "att_w", - .shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim}, - .cols_take_extra_dims = true, - }); - } - return tensors; -} - -} // namespace - -TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx, - int img_layer_idx, bool reshape_att) - : config_(config), - llm_layer_idx_(llm_layer_idx), - img_layer_idx_(img_layer_idx) { - int layer_idx = std::max(llm_layer_idx_, img_layer_idx_); - std::string suffix; - if (layer_idx >= 0) { - suffix = "_" + std::to_string(layer_idx); - } - if (llm_layer_idx < 0 && img_layer_idx < 0) { - tensors_ = ModelTensors(config); - } else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx && - img_layer_idx < config.vit_config.layer_configs.size()) { - const auto& layer_config = config.vit_config.layer_configs[img_layer_idx]; - tensors_ = ImageLayerTensors(config, layer_config, img_layer_idx); - } else if (0 <= llm_layer_idx && - llm_layer_idx < config.layer_configs.size()) { - const auto& layer_config = config.layer_configs[llm_layer_idx]; - tensors_ = LLMLayerTensors(config, layer_config, reshape_att); - } - for (size_t i = 0; i < tensors_.size(); ++i) { - std::string 
key = tensors_[i].name + suffix; - name_map_.insert({key, i}); - } -} - -TensorInfo TensorIndex::TensorInfoFromSourcePath( - const std::string& path) const { - for (const auto& tensor : tensors_) { - for (const auto& source_name : tensor.source_names) { - auto pos = path.rfind(source_name); - if (pos != std::string::npos && path.size() == pos + source_name.size()) - return tensor; - } - } - return TensorInfo(); -} - -const TensorInfo* TensorIndex::FindName(const std::string& name) const { - std::string name_to_find = name; - if (!std::isdigit(name[name.size() - 1])) { - if (img_layer_idx_ >= 0 && llm_layer_idx_ < 0) { - name_to_find = name + "_" + std::to_string(img_layer_idx_); - } else if (llm_layer_idx_ >= 0) { - name_to_find = name + "_" + std::to_string(llm_layer_idx_); - } - } - auto it = name_map_.find(name_to_find); - if (it == name_map_.end()) { - return nullptr; - } - return &tensors_[it->second]; -} - -} // namespace gcpp \ No newline at end of file diff --git a/gemma/tensor_index.h b/gemma/tensor_index.h deleted file mode 100644 index dc6b86c..0000000 --- a/gemma/tensor_index.h +++ /dev/null @@ -1,101 +0,0 @@ -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_ - -#include <stddef.h> - -#include <string> -#include <unordered_map> -#include <vector> - -#include "compression/shared.h" -#include "gemma/configs.h" - -namespace gcpp { - -// Universal tensor information. Holds enough information to construct a -// tensor in LayerWeightsPtrs/ModelWeightsPtrs, as well as to export the -// tensor from the python model with necessary transpose/reshape info. -struct TensorInfo { - // The name of the tensor in the sbs file - std::string name; - // Strings to match to the end of the name of the tensor in the python model. - std::vector<std::string> source_names; - // Initial reshape shape. Use only as a last resort when input may have - // dimensions combined that need to be split before the transpose, as it - // defeats the post-transpose shape checking. Normally empty. - std::vector<size_t> preshape; - // Transpose axes arg. If the input tensor has more dimensions than axes, - // then leading dimensions are collapsed until the number of axes matches. - std::vector<size_t> axes; - // Expected final shape of the tensor after reshape/transpose. - // Note that this is the shape of the tensor during export, - // not the shape of the tensor in the sbs file, as the sbs file - // is restricted to 2D tensors. With few exceptions, the sbs file - // tensor rows gather all the excess dimensions. See cols_take_extra_dims. - std::vector<size_t> shape; - // List of names to concatenate with, used only if multiple tensors are - // concatenated into one. The first tensor in the concatenation should have - // concat names thus: The first name is the name of the result, and the - // tensors with the remaining names are concatenated after this. - // The remaining tensors to be concatenated should have just a single - // empty string in concat_names to indicate that they have been consumed. - std::vector<std::string> concat_names; - // Axis at which to concatenate. - size_t concat_axis = 0; - // The minimum compression weight type for this tensor. The default is - // kNUQ, which provides maximum compression. Other values such as kBF16 - // or kF32 can be used to limit the compression to a specific type. - Type min_size = Type::kNUQ; - // Whether to apply scaled softplus to the data. - bool scaled_softplus = false; - // Whether the columns or the rows take any extra dimensions. - // If false, then [10, 20, 30] -> [10*20, 30] and [30] -> [1, 30].
- // If true, then [10, 20, 30] -> [10, 20*30] and [30] -> [1, 30]. - bool cols_take_extra_dims = false; -}; - -// Universal index of tensor information, which can be built for a specific -// layer_idx. -class TensorIndex { - public: - // Builds a list of TensorInfo for the given layer_idx. - // If reshape_att is true, the attn_vec_einsum tensor is reshaped. - TensorIndex(const ModelConfig& config, int llm_layer_idx, int img_layer_idx, - bool reshape_att); - ~TensorIndex() = default; - - // Returns the TensorInfo whose source_name matches the end of the given path, - // or an empty TensorInfo if not found. - // NOTE: that the returned TensorInfo is a copy, so that the source - // TensorIndex can be destroyed without affecting the returned TensorInfo. - TensorInfo TensorInfoFromSourcePath(const std::string& path) const; - - // Returns the TensorInfo whose name matches the given name, - // or an empty TensorInfo if not found. - // NOTE: that the returned TensorInfo is a copy, so that the source - // TensorIndex can be destroyed without affecting the returned TensorInfo. - TensorInfo TensorInfoFromName(const std::string& name) const { - const TensorInfo* info = FindName(name); - if (info == nullptr) return TensorInfo(); - return *info; - } - - // Returns the TensorInfo for the given tensor name, for concise construction - // of ModelWeightsPtrs/LayerWeightsPtrs. - const TensorInfo* FindName(const std::string& name) const; - - private: - // Config that was used to build the tensor index. - const ModelConfig& config_; - // Layer that this tensor index is for - either LLM or image. - int llm_layer_idx_; - int img_layer_idx_; - // List of tensor information for this layer. - std::vector<TensorInfo> tensors_; - // Map from tensor name to index in tensors_. - std::unordered_map<std::string, size_t> name_map_; -}; - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_ diff --git a/gemma/tensor_index_test.cc b/gemma/tensor_index_test.cc deleted file mode 100644 index 50ff0b6..0000000 --- a/gemma/tensor_index_test.cc +++ /dev/null @@ -1,72 +0,0 @@ -#include "gemma/tensor_index.h" - -#include <stddef.h> -#include <stdio.h> -#include <string.h> -#include <string> -#include <vector> - -#include "gtest/gtest.h" -#include "compression/compress.h" -#include "compression/shared.h" -#include "gemma/configs.h" -#include "gemma/weights.h" -#include "util/basics.h" -#include "hwy/aligned_allocator.h" - -namespace gcpp { -namespace { - -// Tests that each tensor in the model can be found by exactly one TensorIndex, -// and that the TensorIndex returns the correct shape and name for the tensor, -// for all models. -TEST(TensorIndexTest, FindName) { - for (Model model : kAllModels) { - fprintf(stderr, "Testing model %d\n", static_cast<int>(model)); - ModelConfig config = ConfigFromModel(model); - std::vector<TensorIndex> tensor_indexes; - tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1, - /*img_layer_idx=*/-1, - /*split_and_reshape=*/false); - for (size_t llm_layer_idx = 0; llm_layer_idx < config.layer_configs.size(); - ++llm_layer_idx) { - tensor_indexes.emplace_back(config, static_cast<int>(llm_layer_idx), - /*img_layer_idx=*/-1, - /*split_and_reshape=*/false); - } - for (size_t img_layer_idx = 0; - img_layer_idx < config.vit_config.layer_configs.size(); - ++img_layer_idx) { - tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1, - static_cast<int>(img_layer_idx), - /*split_and_reshape=*/false); - } - // For each tensor in any model, exactly one TensorIndex should find it.
- ModelWeightsPtrs<float> weights(config); - ModelWeightsPtrs<float>::ForEachTensor( - {&weights}, ForEachType::kInitNoToc, - [&tensor_indexes](const char* name, hwy::Span<MatPtr*> tensors) { - int num_found = 0; - const MatPtr& tensor = *tensors[0]; - for (const auto& tensor_index : tensor_indexes) { - // Skip the type marker prefix, but we want the layer index suffix. - std::string name_to_find(name + 1, strlen(name) - 1); - const TensorInfo* info = tensor_index.FindName(name_to_find); - if (info != nullptr) { - // Test that the MatPtr can be constructed from the TensorInfo, - // and that the dimensions match. - MatPtrT<float> mat_ptr(tensor.Name(), tensor_index); - EXPECT_STREQ(tensor.Name(), mat_ptr.Name()) - << "on tensor " << name; - EXPECT_EQ(tensor.Rows(), mat_ptr.Rows()) << "on tensor " << name; - EXPECT_EQ(tensor.Cols(), mat_ptr.Cols()) << "on tensor " << name; - ++num_found; - } - } - EXPECT_EQ(num_found, 1) << " for tensor " << name; - }); - } -} - -} // namespace -} // namespace gcpp diff --git a/gemma/tensor_info.cc b/gemma/tensor_info.cc new file mode 100644 index 0000000..de93cf9 --- /dev/null +++ b/gemma/tensor_info.cc @@ -0,0 +1,593 @@ +#include "gemma/tensor_info.h" + +#include <stddef.h> +#include <stdint.h> + +#include <string> + +#include "compression/types.h" +#include "gemma/configs.h" + +namespace gcpp { + +void TensorInfoRegistry::Add(const std::string& suffix, + const TensorInfo& info) { + const size_t idx = tensors_.size(); + tensors_.push_back(info); + // Also add suffix to `concat_names`. + for (std::string& name : tensors_.back().concat_names) { + name += suffix; + } + + const std::string name = info.base_name + suffix; + // Ensure successful insertion because `suffix` ensures uniqueness for + // per-layer tensors, and per-model should only be inserted once. + HWY_ASSERT_M(idx_from_name_.insert({name, idx}).second, name.c_str()); +} + +// Non-layer tensors.
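+// Illustrative usage (editor's sketch, not an additional API): a per-layer call such as +// Add(LayerSuffix(3), {.base_name = "att_ein", ...}) registers the tensor under the key +// "att_ein_3", which Find("att_ein_3") then returns; the per-model tensors below are +// registered with an empty suffix and keep their base names, e.g. Find("c_embedding").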
+void TensorInfoRegistry::AddModelTensors(const ModelConfig& config) { + const std::string no_suffix; + Add(no_suffix, { + .base_name = "c_embedding", + .source_names = {"embedder/input_embedding"}, + .axes = {0, 1}, + .shape = {config.vocab_size, config.model_dim}, + .min_size = Type::kBF16, + }); + Add(no_suffix, { + .base_name = "c_final_norm", + .source_names = {"final_norm/scale"}, + .axes = {0}, + .shape = {config.model_dim}, + .min_size = Type::kBF16, + }); + Add(no_suffix, { + .base_name = "enc_norm_bias", + .source_names = {"img/Transformer/encoder_norm/bias"}, + .axes = {0}, + .shape = {config.vit_config.model_dim}, + .min_size = Type::kBF16, + }); + Add(no_suffix, { + .base_name = "enc_norm_scale", + .source_names = {"img/Transformer/encoder_norm/scale"}, + .axes = {0}, + .shape = {config.vit_config.model_dim}, + .min_size = Type::kBF16, + }); + Add(no_suffix, { + .base_name = "img_emb_bias", + .source_names = {"img/embedding/bias"}, + .axes = {0}, + .shape = {config.vit_config.model_dim}, + .min_size = Type::kF32, + }); + Add(no_suffix, + { + .base_name = "img_emb_kernel", + .source_names = {"img/embedding/kernel"}, + .axes = {3, 0, 1, 2}, + .shape = {config.vit_config.model_dim, config.vit_config.patch_width, + config.vit_config.patch_width, 3}, + .min_size = Type::kBF16, + .cols_take_extra_dims = true, + }); + Add(no_suffix, + { + .base_name = "img_head_bias", + .source_names = {"img/head/bias", "embedder/mm_input_projection/b"}, + .axes = {0}, + .shape = {config.model_dim}, + .min_size = Type::kF32, + }); + Add(no_suffix, + { + .base_name = "img_head_kernel", + .source_names = {"img/head/kernel", "embedder/mm_input_projection/w"}, + .axes = {1, 0}, + .shape = {config.model_dim, config.vit_config.model_dim}, + .min_size = Type::kBF16, + }); + Add(no_suffix, { + .base_name = "img_pos_emb", + .source_names = {"img/pos_embedding"}, + .axes = {0, 1}, + .shape = {/*1,*/ config.vit_config.seq_len, + config.vit_config.model_dim}, + .min_size = Type::kF32, + }); + // RMS norm applied to soft tokens prior to pos embedding. + Add(no_suffix, { + .base_name = "mm_embed_norm", + .source_names = {"embedder/mm_soft_embedding_norm/scale"}, + .axes = {0}, + .shape = {config.vit_config.model_dim}, + .min_size = Type::kBF16, + }); +} + +// Returns the tensors for the given image layer config. +void TensorInfoRegistry::AddImageLayerTensors(const ModelConfig& config, + const LayerConfig& layer_config, + const size_t img_layer_idx) { + const std::string suffix = LayerSuffix(img_layer_idx); + + // Vit layers. 
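+ // (Editor's note on the concat mechanism: q_ein_w below declares concat_names + // {"qkv_ein_w", "k_ein_w", "v_ein_w"} with concat_axis = 1, so the exporter + // concatenates the query, key and value kernels along axis 1 into a fused qkv_ein_w + // of shape {heads, 3 * qkv_dim, model_dim}; k_ein_w and v_ein_w then carry a single + // empty string in concat_names to mark them as consumed.)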
+ Add(suffix, { + .base_name = "attn_out_w", + .source_names = {"MultiHeadDotProductAttention_0/out/kernel"}, + .axes = {2, 0, 1}, + .shape = {config.vit_config.model_dim, layer_config.heads, + layer_config.qkv_dim}, + .min_size = Type::kBF16, + .cols_take_extra_dims = true, + }); + Add(suffix, { + .base_name = "attn_out_b", + .source_names = {"MultiHeadDotProductAttention_0/out/bias"}, + .axes = {0}, + .shape = {config.vit_config.model_dim}, + .min_size = Type::kF32, + }); + Add(suffix, + { + .base_name = "q_ein_w", + .source_names = {"MultiHeadDotProductAttention_0/query/kernel"}, + .axes = {1, 2, 0}, + .shape = {layer_config.heads, layer_config.qkv_dim, + config.vit_config.model_dim}, + .concat_names = {"qkv_ein_w", "k_ein_w", "v_ein_w"}, + .concat_axis = 1, + .min_size = Type::kBF16, + }); + Add(suffix, { + .base_name = "k_ein_w", + .source_names = {"MultiHeadDotProductAttention_0/key/kernel"}, + .axes = {1, 2, 0}, + .shape = {layer_config.heads, layer_config.qkv_dim, + config.vit_config.model_dim}, + .concat_names = {""}, + .min_size = Type::kBF16, + }); + Add(suffix, + { + .base_name = "v_ein_w", + .source_names = {"MultiHeadDotProductAttention_0/value/kernel"}, + .axes = {1, 2, 0}, + .shape = {layer_config.heads, layer_config.qkv_dim, + config.vit_config.model_dim}, + .concat_names = {""}, + .min_size = Type::kBF16, + }); + Add(suffix, { + .base_name = "qkv_ein_w", + .source_names = {"MultiHeadDotProductAttention_0/qkv/kernel"}, + .axes = {1, 2, 0}, + .shape = {layer_config.heads, 3 * layer_config.qkv_dim, + config.vit_config.model_dim}, + .min_size = Type::kBF16, + }); + Add(suffix, { + .base_name = "q_ein_b", + .source_names = {"MultiHeadDotProductAttention_0/query/bias"}, + .axes = {0, 1}, + .shape = {layer_config.heads, layer_config.qkv_dim}, + .concat_names = {"qkv_ein_b", "k_ein_b", "v_ein_b"}, + .concat_axis = 1, + .min_size = Type::kF32, + }); + Add(suffix, { + .base_name = "k_ein_b", + .source_names = {"MultiHeadDotProductAttention_0/key/bias"}, + .axes = {0, 1}, + .shape = {layer_config.kv_heads, layer_config.qkv_dim}, + .concat_names = {""}, + .min_size = Type::kF32, + }); + Add(suffix, { + .base_name = "v_ein_b", + .source_names = {"MultiHeadDotProductAttention_0/value/bias"}, + .axes = {0, 1}, + .shape = {layer_config.kv_heads, layer_config.qkv_dim}, + .concat_names = {""}, + .min_size = Type::kF32, + }); + Add(suffix, { + .base_name = "qkv_ein_b", + .source_names = {"MultiHeadDotProductAttention_0/qkv/bias"}, + .axes = {0, 1}, + .shape = {layer_config.heads + layer_config.kv_heads * 2, + layer_config.qkv_dim}, + .min_size = Type::kF32, + }); + Add(suffix, + { + .base_name = "linear_0_w", + .source_names = {"MlpBlock_0/Dense_0/kernel"}, + .axes = {1, 0}, + .shape = {layer_config.ff_hidden_dim, config.vit_config.model_dim}, + .min_size = Type::kBF16, + }); + Add(suffix, { + .base_name = "linear_0_b", + .source_names = {"MlpBlock_0/Dense_0/bias"}, + .axes = {0}, + .shape = {layer_config.ff_hidden_dim}, + .min_size = Type::kF32, + }); + Add(suffix, + { + .base_name = "linear_1_w", + .source_names = {"MlpBlock_0/Dense_1/kernel"}, + .axes = {1, 0}, + .shape = {config.vit_config.model_dim, layer_config.ff_hidden_dim}, + .min_size = Type::kBF16, + }); + Add(suffix, { + .base_name = "linear_1_b", + .source_names = {"MlpBlock_0/Dense_1/bias"}, + .axes = {0}, + .shape = {config.vit_config.model_dim}, + .min_size = Type::kF32, + }); + Add(suffix, + { + .base_name = "ln_0_bias", + .source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias", + "img/Transformer/encoderblock_" + 
+ std::to_string(img_layer_idx) + + "/LayerNorm_0/bias"}, + .axes = {0}, + .shape = {config.vit_config.model_dim}, + .min_size = Type::kBF16, + }); + Add(suffix, + { + .base_name = "ln_0_scale", + .source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale", + "img/Transformer/encoderblock_" + + std::to_string(img_layer_idx) + + "/LayerNorm_0/scale"}, + .axes = {0}, + .shape = {config.vit_config.model_dim}, + .min_size = Type::kBF16, + }); + Add(suffix, + { + .base_name = "ln_1_bias", + .source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias", + "img/Transformer/encoderblock_" + + std::to_string(img_layer_idx) + + "/LayerNorm_1/bias"}, + .axes = {0}, + .shape = {config.vit_config.model_dim}, + .min_size = Type::kBF16, + }); + Add(suffix, + { + .base_name = "ln_1_scale", + .source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale", + "img/Transformer/encoderblock_" + + std::to_string(img_layer_idx) + + "/LayerNorm_1/scale"}, + .axes = {0}, + .shape = {config.vit_config.model_dim}, + .min_size = Type::kBF16, + }); +} + +void TensorInfoRegistry::AddGriffinLayerTensors(const LayerConfig& layer_config, + const size_t layer_idx) { + const std::string suffix = LayerSuffix(layer_idx); + Add(suffix, { + .base_name = "gr_lin_x_w", + .source_names = {"recurrent_block/linear_x/kernel"}, + .axes = {1, 0}, + .shape = {layer_config.griffin_dim, layer_config.griffin_dim}, + }); + Add(suffix, { + .base_name = "gr_lin_x_b", + .source_names = {"recurrent_block/linear_x/bias"}, + .axes = {0}, + .shape = {layer_config.griffin_dim}, + .min_size = Type::kF32, + }); + Add(suffix, { + .base_name = "gr_lin_y_w", + .source_names = {"recurrent_block/linear_y/kernel"}, + .axes = {1, 0}, + .shape = {layer_config.griffin_dim, layer_config.griffin_dim}, + }); + Add(suffix, { + .base_name = "gr_lin_y_b", + .source_names = {"recurrent_block/linear_y/bias"}, + .axes = {0}, + .shape = {layer_config.griffin_dim}, + .min_size = Type::kF32, + }); + Add(suffix, { + .base_name = "gr_lin_out_w", + .source_names = {"recurrent_block/linear_out/kernel"}, + .axes = {1, 0}, + .shape = {layer_config.griffin_dim, layer_config.griffin_dim}, + }); + Add(suffix, { + .base_name = "gr_lin_out_b", + .source_names = {"recurrent_block/linear_out/bias"}, + .axes = {0}, + .shape = {layer_config.griffin_dim}, + .min_size = Type::kF32, + }); + Add(suffix, + { + .base_name = "gr_conv_w", + .source_names = {"recurrent_block/conv_1d/w"}, + .axes = {0, 1}, + .shape = {layer_config.conv1d_width, layer_config.griffin_dim}, + .min_size = Type::kF32, + }); + Add(suffix, { + .base_name = "gr_conv_b", + .source_names = {"recurrent_block/conv_1d/b"}, + .axes = {0}, + .shape = {layer_config.griffin_dim}, + .min_size = Type::kF32, + }); + Add(suffix, { + .base_name = "gr1_gate_w", + .source_names = {"recurrent_block/rg_lru/input_gate/w"}, + .axes = {0, 2, 1}, + .shape = {layer_config.heads, + layer_config.griffin_dim / layer_config.heads, + layer_config.griffin_dim / layer_config.heads}, + .concat_names = {"gr_gate_w", "gr2_gate_w"}, + }); + Add(suffix, { + .base_name = "gr2_gate_w", + .source_names = {"recurrent_block/rg_lru/a_gate/w"}, + .axes = {0, 2, 1}, + .shape = {layer_config.heads, + layer_config.griffin_dim / layer_config.heads, + layer_config.griffin_dim / layer_config.heads}, + .concat_names = {""}, + }); + Add(suffix, { + .base_name = "gr_gate_w", + .source_names = {"recurrent_block/rg_lru/gate/w"}, + .axes = {0, 2, 1}, + .shape = {2 * layer_config.heads, + layer_config.griffin_dim / layer_config.heads, + layer_config.griffin_dim 
/ layer_config.heads}, + }); + Add(suffix, { + .base_name = "gr1_gate_b", + .source_names = {"recurrent_block/rg_lru/input_gate/b"}, + .axes = {0}, + .shape = {layer_config.griffin_dim}, + .concat_names = {"gr_gate_b", "gr2_gate_b"}, + .min_size = Type::kF32, + }); + Add(suffix, { + .base_name = "gr2_gate_b", + .source_names = {"recurrent_block/rg_lru/a_gate/b"}, + .axes = {0}, + .shape = {layer_config.griffin_dim}, + .concat_names = {""}, + .min_size = Type::kF32, + }); + Add(suffix, { + .base_name = "gr_gate_b", + .source_names = {"recurrent_block/rg_lru/input_gate/b"}, + .axes = {0, 1}, + .shape = {2 * layer_config.griffin_dim}, + .min_size = Type::kF32, + }); + Add(suffix, { + .base_name = "gr_a", + .source_names = {"recurrent_block/rg_lru/a_param"}, + .axes = {0}, + .shape = {layer_config.griffin_dim}, + .min_size = Type::kF32, + .scaled_softplus = true, + }); +} + +void TensorInfoRegistry::AddLayerTensors(const ModelConfig& config, + const LayerConfig& layer_config, + const size_t layer_idx) { + const std::string suffix = LayerSuffix(layer_idx); + Add(suffix, { + .base_name = "key_norm", + .source_names = {"attn/_key_norm/scale"}, + .axes = {0}, + .shape = {layer_config.qkv_dim}, + .min_size = Type::kBF16, + }); + Add(suffix, { + .base_name = "query_norm", + .source_names = {"attn/_query_norm/scale"}, + .axes = {0}, + .shape = {layer_config.qkv_dim}, + .min_size = Type::kBF16, + }); + Add(suffix, { + .base_name = "qkv1_w", + .source_names = {"attn/q_einsum/w"}, + .axes = {0, 2, 1}, + .shape = {layer_config.heads * layer_config.qkv_dim, + config.model_dim}, + .concat_names = {"qkv_ein", "qkv2_w"}, + }); + Add(suffix, { + .base_name = "qkv2_w", + .source_names = {"attn/kv_einsum/w"}, + .axes = {1, 0, 3, 2}, + .shape = {2 * layer_config.kv_heads * layer_config.qkv_dim, + config.model_dim}, + .concat_names = {""}, + }); + Add(suffix, { + .base_name = "q_ein", + .source_names = {"attention_block/proj_q/kernel"}, + .axes = {1, 0}, + .shape = {layer_config.model_dim, layer_config.model_dim}, + .concat_names = {"qkv_ein", "k_ein", "v_ein"}, + }); + Add(suffix, { + .base_name = "k_ein", + .source_names = {"attention_block/proj_k/kernel"}, + .axes = {1, 0}, + .shape = {layer_config.qkv_dim, layer_config.model_dim}, + .concat_names = {""}, + }); + Add(suffix, { + .base_name = "v_ein", + .source_names = {"attention_block/proj_v/kernel"}, + .axes = {1, 0}, + .shape = {layer_config.qkv_dim, layer_config.model_dim}, + .concat_names = {""}, + }); + Add(suffix, { + .base_name = "qkv_ein", + .source_names = {"attn/qkv_einsum/w"}, + .axes = {1, 0, 3, 2}, + .shape = {(layer_config.heads + 2 * layer_config.kv_heads) * + layer_config.qkv_dim, + config.model_dim}, + }); + Add(suffix, { + .base_name = "attn_ob", + .source_names = {"attention_block/proj_final/bias"}, + .axes = {0}, + .shape = {config.model_dim}, + .min_size = Type::kF32, + }); + + Add(suffix, { + .base_name = "gating_ein", + .source_names = {"mlp/gating_einsum/w", "mlp/gating_einsum", + "mlp_block/ffw_up/w"}, + .axes = {0, layer_config.optimized_gating ? 1u : 2u, + layer_config.optimized_gating ? 2u : 1u}, + .shape = {2, layer_config.ff_hidden_dim, config.model_dim}, + }); + Add(suffix, { + .base_name = "gating1_w", + .source_names = {"none"}, + .axes = {0, layer_config.optimized_gating ? 1u : 2u, + layer_config.optimized_gating ? 2u : 1u}, + .shape = {layer_config.ff_hidden_dim, config.model_dim}, + }); + Add(suffix, { + .base_name = "gating2_w", + .source_names = {"none"}, + .axes = {0, layer_config.optimized_gating ? 
1u : 2u, + layer_config.optimized_gating ? 2u : 1u}, + .shape = {layer_config.ff_hidden_dim, config.model_dim}, + }); + Add(suffix, { + .base_name = "linear_w", + .source_names = {"mlp/linear/w", "mlp/linear", + "mlp_block/ffw_down/kernel"}, + .axes = {1, 0}, + .shape = {config.model_dim, layer_config.ff_hidden_dim}, + }); + Add(suffix, { + .base_name = "pre_att_ns", + .source_names = {"pre_attention_norm/scale", + "temporal_pre_norm/scale"}, + .axes = {0}, + .shape = {config.model_dim}, + .min_size = Type::kBF16, + }); + Add(suffix, + { + .base_name = "pre_ff_ns", + .source_names = {"pre_ffw_norm/scale", "channel_pre_norm/scale"}, + .axes = {0}, + .shape = {config.model_dim}, + .min_size = Type::kBF16, + }); + Add(suffix, { + .base_name = "post_att_ns", + .source_names = {"post_attention_norm/scale"}, + .axes = {0}, + .shape = {config.model_dim}, + .min_size = Type::kBF16, + }); + Add(suffix, { + .base_name = "post_ff_ns", + .source_names = {"post_ffw_norm/scale"}, + .axes = {0}, + .shape = {config.model_dim}, + .min_size = Type::kBF16, + }); + Add(suffix, { + .base_name = "ffw_gat_b", + .source_names = {"mlp_block/ffw_up/b"}, + .axes = {0}, + .shape = {2 * layer_config.ff_hidden_dim}, + .min_size = Type::kF32, + }); + Add(suffix, { + .base_name = "ffw_out_b", + .source_names = {"mlp_block/ffw_down/bias"}, + .axes = {0}, + .shape = {config.model_dim}, + .min_size = Type::kF32, + }); + Add(suffix, + { + .base_name = "att_ein", + .source_names = {"attn/attn_vec_einsum/w", + "attention_block/proj_final/kernel"}, + .preshape = {layer_config.heads, layer_config.qkv_dim, + config.model_dim}, + .axes = {0, 2, 1}, + .shape = {layer_config.heads, config.model_dim, layer_config.qkv_dim}, + }); + Add(suffix, + { + .base_name = "att_w", + .shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim}, + .cols_take_extra_dims = true, + }); + + if (config.model == Model::GRIFFIN_2B) { + AddGriffinLayerTensors(layer_config, layer_idx); + } +} + +TensorInfoRegistry::TensorInfoRegistry(const ModelConfig& config) { + // Upper bound on the number of `Add()` calls in `Add*Tensors()`. Loose bound + // in case those are changed without updating this. Better to allocate a bit + // more than to have the vector repeatedly grow by 1.5-2x if the estimate is + // too low. + tensors_.reserve(10 + 32 * config.layer_configs.size() + + 24 * config.vit_config.layer_configs.size()); + AddModelTensors(config); + for (size_t i = 0; i < config.layer_configs.size(); ++i) { + AddLayerTensors(config, config.layer_configs[i], i); + } + for (size_t i = 0; i < config.vit_config.layer_configs.size(); ++i) { + AddImageLayerTensors(config, config.vit_config.layer_configs[i], i); + } +} + +TensorInfo TensorInfoRegistry::TensorInfoFromSourcePath(const std::string& path, + int layer_idx) const { + for (const TensorInfo& tensor : tensors_) { + for (const std::string& source_name : tensor.source_names) { + // path ends with source_name?
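+ // e.g. (illustrative path) "transformer/layer_5/attn/q_einsum/w" ends with the + // registered source name "attn/q_einsum/w"; with layer_idx = 5 this resolves to the + // suffixed name "qkv1_w_5".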
+ const size_t pos = path.rfind(source_name); + if (pos != std::string::npos && path.size() == pos + source_name.size()) { + std::string name = tensor.base_name; + if (layer_idx >= 0) name += LayerSuffix(static_cast<size_t>(layer_idx)); + return TensorInfoFromName(name); + } + } + } + return TensorInfo(); +} + +} // namespace gcpp diff --git a/gemma/tensor_info.h b/gemma/tensor_info.h new file mode 100644 index 0000000..c8252a4 --- /dev/null +++ b/gemma/tensor_info.h @@ -0,0 +1,141 @@ +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_ + +#include <stddef.h> + +#include <string> +#include <unordered_map> +#include <vector> + +#include "compression/types.h" // Type +#include "gemma/configs.h" +#include "util/basics.h" // Extents2D + +namespace gcpp { + +// Tensor metadata. This is far more than required to construct the `MatPtr` in +// `LayerWeightsPtrs/WeightsPtrs`; they only use `.shape` via `ExtentsFromInfo`. +// This is also bound to Python and filled by the exporter. +struct TensorInfo { + // The base name of the tensor without a layer suffix. + std::string base_name; + // Strings to match to the end of the name of the tensor in the python model. + std::vector<std::string> source_names; + // Initial reshape shape. Use only as a last resort when input may have + // dimensions combined that need to be split before the transpose, as it + // defeats the post-transpose shape checking. Normally empty. + std::vector<size_t> preshape; + // Transpose axes arg. If the input tensor has more dimensions than axes, + // then leading dimensions are collapsed until the number of axes matches. + std::vector<size_t> axes; + // Expected final shape of the tensor after reshape/transpose. + // Note that this is the shape of the tensor during export, + // not the shape of the tensor in the sbs file, as the sbs file + // is restricted to 2D tensors. With few exceptions, the sbs file + // tensor rows gather all the excess dimensions. See cols_take_extra_dims. + std::vector<size_t> shape; + // List of names to concatenate with, used only if multiple tensors are + // concatenated into one. The first tensor in the concatenation should have + // concat names thus: The first name is the name of the result, and the + // tensors with the remaining names are concatenated after this. + // The remaining tensors to be concatenated should have just a single + // empty string in concat_names to indicate that they have been consumed. + std::vector<std::string> concat_names; + // Axis at which to concatenate. + size_t concat_axis = 0; + // The highest permissible compression for this tensor. The default is + // kNUQ, which provides maximum compression. Other values such as kBF16 + // or kF32 can be used to limit the compression to a specific type. + Type min_size = Type::kNUQ; + // Whether to apply scaled softplus to the data. + bool scaled_softplus = false; + // Whether the columns or the rows take any extra dimensions. + // If false, then [10, 20, 30] -> [10*20, 30] and [30] -> [1, 30]. + // If true, then [10, 20, 30] -> [10, 20*30] and [30] -> [1, 30]. + bool cols_take_extra_dims = false; +}; + +// Collapses/expands the tensor dims into 2D extents, which may be 0, 0 for +// not-present tensors such as ViT in a text-only model. Safely handles nullptr +// returned from `TensorInfoRegistry::Find`, hence not a member function.
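+// Worked example (editor's note, using shapes registered in tensor_info.cc): "att_w" +// has shape {model_dim, heads, qkv_dim} with cols_take_extra_dims = true, yielding +// Extents2D(model_dim, heads * qkv_dim); "att_ein" has shape {heads, model_dim, qkv_dim} +// with the default false, yielding Extents2D(heads * model_dim, qkv_dim).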
+static inline Extents2D ExtentsFromInfo(const TensorInfo* tensor) { + if (tensor == nullptr) return Extents2D(0, 0); + + size_t cols = tensor->shape.back(); + size_t rows = 1; + if (tensor->cols_take_extra_dims) { + rows = tensor->shape[0]; + for (size_t i = 1; i < tensor->shape.size() - 1; ++i) { + cols *= tensor->shape[i]; + } + } else { // rows take extra dims + for (size_t i = 0; i < tensor->shape.size() - 1; ++i) { + rows *= tensor->shape[i]; + } + } + // Sometimes only one of rows or cols is zero; set both for consistency. + if (rows == 0 || cols == 0) rows = cols = 0; + return Extents2D(rows, cols); +} + +static inline std::string LayerSuffix(size_t layer_idx) { + return std::string("_") + std::to_string(layer_idx); +} + +// Returns tensor base name without any layer suffix. +static inline std::string StripLayerSuffix(const std::string& name) { + return name.substr(0, name.rfind('_')); +} + +// Holds all `TensorInfo` for a model and retrieves them by (unique) name. +class TensorInfoRegistry { + public: + explicit TensorInfoRegistry(const ModelConfig& config); + ~TensorInfoRegistry() = default; + + // Returns nullptr if not found, otherwise the `TensorInfo` for the given + // `name`, which either lacks a suffix, or is per-layer and ends with + // `LayerSuffix(layer_idx)`. Used in `WeightsPtrs/LayerWeightsPtrs`. + const TensorInfo* Find(const std::string& name) const { + auto it = idx_from_name_.find(name); + if (it == idx_from_name_.end()) return nullptr; + return &tensors_[it->second]; + } + + // Returns a copy of the `TensorInfo` whose name matches the given name, or a + // default-constructed `TensorInfo` if not found. Destroying + // `TensorInfoRegistry` afterward will not invalidate the returned value. + TensorInfo TensorInfoFromName(const std::string& name) const { + const TensorInfo* info = Find(name); + if (info == nullptr) return TensorInfo(); + return *info; + } + + // Returns a copy of the `TensorInfo` whose source_name matches the end of the + // given path, and whose name ends with the given layer_idx, otherwise a + // default-constructed `TensorInfo`. Destroying `TensorInfoRegistry` + // afterward will not invalidate the returned value. + TensorInfo TensorInfoFromSourcePath(const std::string& path, + int layer_idx) const; + + private: + // `suffix` is empty (only) for per-model tensors, otherwise `LayerSuffix`. + void Add(const std::string& suffix, const TensorInfo& info); + void AddModelTensors(const ModelConfig& config); + void AddLayerTensors(const ModelConfig& config, + const LayerConfig& layer_config, size_t layer_idx); + void AddGriffinLayerTensors(const LayerConfig& layer_config, + size_t layer_idx); + + void AddImageLayerTensors(const ModelConfig& config, + const LayerConfig& layer_config, + size_t img_layer_idx); + + std::vector<TensorInfo> tensors_; + // Includes entries for base name *and* the suffixed name for each layer.
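+ // e.g. "c_embedding" (per-model, no suffix) alongside "key_norm_0" ... "key_norm_N" + // (one suffixed entry per layer).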
+ std::unordered_map<std::string, size_t> idx_from_name_; +}; + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_ diff --git a/gemma/tensor_info_test.cc b/gemma/tensor_info_test.cc new file mode 100644 index 0000000..8a95376 --- /dev/null +++ b/gemma/tensor_info_test.cc @@ -0,0 +1,40 @@ +#include "gemma/tensor_info.h" + +#include <stdio.h> + +#include "gtest/gtest.h" +#include "compression/types.h" // SfpStream +#include "gemma/configs.h" +#include "gemma/weights.h" +#include "util/mat.h" +#include "hwy/base.h" // HWY_ASSERT_M + +namespace gcpp { +namespace { + +// Tests for all models that each tensor in the model can be found and that the +// TensorInfoRegistry returns the correct shape and name for the tensor. +TEST(TensorInfoRegistryTest, Find) { + ForEachModel([&](Model model) { + const ModelConfig config(model, Type::kSFP, ChooseWrapping(model)); + fprintf(stderr, "Testing %s (%s)\n", config.display_name.c_str(), + config.Specifier().c_str()); + const TensorInfoRegistry tensors(config); + // Each tensor in the model should be known/found. + WeightsPtrs weights(config); + weights.ForEachTensor(nullptr, nullptr, [&tensors](const TensorArgs& t) { + const TensorInfo* info = tensors.Find(t.mat.Name()); + HWY_ASSERT_M(info, t.mat.Name()); + // Test that the `MatPtr` can be constructed from the TensorInfo, + // and that the dimensions match. + const MatPtr mat_ptr(t.mat.Name(), Type::kUnknown, + ExtentsFromInfo(tensors.Find(t.mat.Name()))); + EXPECT_STREQ(t.mat.Name(), mat_ptr.Name()) << t.mat.Name(); + EXPECT_EQ(t.mat.Rows(), mat_ptr.Rows()) << t.mat.Name(); + EXPECT_EQ(t.mat.Cols(), mat_ptr.Cols()) << t.mat.Name(); + }); + }); +} + +} // namespace +} // namespace gcpp diff --git a/gemma/tokenizer.cc b/gemma/tokenizer.cc index e48abae..6e39f27 100644 --- a/gemma/tokenizer.cc +++ b/gemma/tokenizer.cc @@ -21,9 +21,7 @@ #include <string> #include <vector> -#include "compression/io.h" // Path -#include "compression/shared.h" // PromptWrapping -#include "gemma/common.h" // Wrap +#include "gemma/configs.h" // PromptWrapping #include "hwy/base.h" // HWY_ASSERT #include "hwy/profiler.h" // copybara:import_next_line:sentencepiece @@ -37,24 +35,20 @@ constexpr bool kShowTokenization = false; class GemmaTokenizer::Impl { public: Impl() = default; - explicit Impl(const Path& tokenizer_path) { - PROFILER_ZONE("Startup.tokenizer"); - spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>(); - if (!spp_->Load(tokenizer_path.path).ok()) { - HWY_ABORT("Failed to load the tokenizer file."); - } - } // Loads the tokenizer from a serialized proto. explicit Impl(const std::string& tokenizer_proto) { + if (tokenizer_proto == kMockTokenizer) return; PROFILER_ZONE("Startup.tokenizer"); spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>(); if (!spp_->LoadFromSerializedProto(tokenizer_proto).ok()) { - fprintf(stderr, "serialized proto size=%zu.\n", tokenizer_proto.size()); - HWY_ABORT("Failed to load the tokenizer from serialized proto."); + HWY_ABORT("Failed to load tokenizer from %zu byte serialized proto.", + tokenizer_proto.size()); } } - std::string Serialize() const { return spp_->serialized_model_proto(); } + std::string Serialize() const { + return spp_ ?
spp_->serialized_model_proto() : kMockTokenizer; + } bool Encode(const std::string& input, std::vector<std::string>* pieces) const { @@ -82,22 +76,18 @@ class GemmaTokenizer::Impl { std::unique_ptr<sentencepiece::SentencePieceProcessor> spp_; }; -GemmaTokenizer::GemmaTokenizer(const Path& tokenizer_path) { - impl_ = std::make_unique<Impl>(tokenizer_path); +GemmaTokenizer::GemmaTokenizer(const std::string& tokenizer_proto) + : impl_(std::make_unique<Impl>(tokenizer_proto)) { + HWY_ASSERT(impl_); } // Default suffices, but they must be defined after GemmaTokenizer::Impl. -GemmaTokenizer::GemmaTokenizer() = default; GemmaTokenizer::~GemmaTokenizer() = default; GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default; GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default; std::string GemmaTokenizer::Serialize() const { return impl_->Serialize(); } -void GemmaTokenizer::Deserialize(const std::string& tokenizer_proto) { - impl_ = std::make_unique<Impl>(tokenizer_proto); -} - bool GemmaTokenizer::Encode(const std::string& input, std::vector<std::string>* pieces) const { return impl_->Encode(input, pieces); @@ -114,57 +104,109 @@ bool GemmaTokenizer::Decode(const std::vector<int>& ids, return impl_->Decode(ids, detokenized); } -std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer, - const ModelInfo& info, size_t pos, - std::string& prompt) { - Wrap(info, pos, prompt); +GemmaChatTemplate::GemmaChatTemplate(const GemmaTokenizer& tokenizer, + Model model) { + sot_user_.reserve(3); + if (!tokenizer.Encode("<start_of_turn>user\n", &sot_user_)) return; + sot_model_.reserve(3); + HWY_ASSERT(tokenizer.Encode("<start_of_turn>model\n", &sot_model_)); + eot_.reserve(2); + HWY_ASSERT(tokenizer.Encode("<end_of_turn>\n", &eot_)); - std::vector<int> tokens; - HWY_ASSERT(tokenizer.Encode(prompt, &tokens)); - // Both pre-trained and instruction-tuned require BOS as first token. - if (pos == 0) { - tokens.insert(tokens.begin(), BOS_ID); - } - - // PaliGemma separator. The SEP token "\n" is always tokenized separately.
- if (info.wrapping == PromptWrapping::PALIGEMMA - // || info.wrapping == PromptWrapping::GEMMA_VLM - ) { - std::vector<int> sep_tokens; - HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens)); - tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end()); - } - - return tokens; + HWY_ASSERT(tokenizer.Encode("\n", &pali_sep_)); + vlm_soi_.reserve(2); + HWY_ASSERT(tokenizer.Encode("\n\n<start_of_image>", &vlm_soi_)); + vlm_eoi_.reserve(2); + HWY_ASSERT(tokenizer.Encode("<end_of_image>\n\n", &vlm_eoi_)); } -std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info, - size_t pos, std::vector<int>& tokens, - size_t image_batch_size, size_t max_image_batch_size) { - HWY_ASSERT(info.wrapping == PromptWrapping::GEMMA_VLM); - size_t num_images = hwy::DivCeil(image_batch_size, max_image_batch_size); +std::vector<int> GemmaChatTemplate::Apply(size_t pos, + const std::vector<int>& ids) const { + HWY_ASSERT_M(!sot_user_.empty() && !sot_model_.empty() && !eot_.empty(), + "GemmaChatTemplate has not been initialized."); + std::vector<int> out; + out.reserve(eot_.size() + sot_user_.size() + ids.size() + eot_.size() + + sot_model_.size()); - std::vector<int> sep_tokens; - HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens)); - - std::string begin_image_prompt = "\n\n<start_of_image>"; - std::vector<int> begin_image_tokens = - WrapAndTokenize(tokenizer, info, pos, begin_image_prompt); - - std::string end_image_prompt = "<end_of_image>\n\n"; - std::vector<int> end_image_tokens = - WrapAndTokenize(tokenizer, info, pos, end_image_prompt); - - for (size_t i = 0; i < num_images; ++i) { - tokens.insert(tokens.begin(), begin_image_tokens.begin(), - begin_image_tokens.end()); - tokens.insert(tokens.begin() + begin_image_tokens.size(), image_batch_size, - -2); - tokens.insert(tokens.begin() + begin_image_tokens.size() + image_batch_size, - end_image_tokens.begin(), end_image_tokens.end()); + // Start with BOS, or prepend end_of_turn if this is a continuation. + if (pos == 0) { + out.push_back(BOS_ID); + } else { + out.insert(out.cend(), eot_.cbegin(), eot_.cend()); } + // Start of user turn, user prompt, end of turn; then start of model turn.
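+ // e.g. for pos == 0 (schematic, with the control tokens reconstructed above): + // [BOS, <start_of_turn>user\n, ids..., <end_of_turn>\n, <start_of_turn>model\n]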
+ out.insert(out.cend(), sot_user_.cbegin(), sot_user_.cend()); + out.insert(out.cend(), ids.cbegin(), ids.cend()); + out.insert(out.cend(), eot_.cbegin(), eot_.cend()); + out.insert(out.cend(), sot_model_.cbegin(), sot_model_.cend()); + return out; +} - return tokens; +std::vector<int> GemmaChatTemplate::WrapPali(const std::vector<int>& text_part, + size_t image_batch_size) const { + HWY_ASSERT_M(!pali_sep_.empty(), + "GemmaChatTemplate has not been initialized."); + std::vector<int> out; + out.reserve(image_batch_size + 1 + text_part.size() + pali_sep_.size()); + out.resize(image_batch_size, 0); + out.push_back(BOS_ID); + out.insert(out.cend(), text_part.cbegin(), text_part.cend()); + out.insert(out.cend(), pali_sep_.cbegin(), pali_sep_.cend()); + return out; +} + +std::vector<int> GemmaChatTemplate::WrapVLM(const std::vector<int>& text_part, + size_t image_batch_size) const { + HWY_ASSERT_M(!vlm_soi_.empty() && !vlm_eoi_.empty(), + "GemmaChatTemplate has not been initialized."); + std::vector<int> out; + out.reserve(text_part.size() + vlm_soi_.size() + image_batch_size + + vlm_eoi_.size()); + out.insert(out.cend(), text_part.cbegin(), text_part.cend()); + out.insert(out.cend(), vlm_soi_.cbegin(), vlm_soi_.cend()); + out.insert(out.cend(), image_batch_size, -2); + out.insert(out.cend(), vlm_eoi_.cbegin(), vlm_eoi_.cend()); + return out; +} + +// Text +std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer, + const GemmaChatTemplate& chat_template, + const PromptWrapping wrapping, size_t pos, + const std::string& prompt) { + std::vector<int> tokens; + HWY_ASSERT(tokenizer.Encode(prompt, &tokens)); + + switch (wrapping) { + case PromptWrapping::GEMMA_IT: + case PromptWrapping::GEMMA_VLM: + return chat_template.Apply(pos, tokens); + default: + if (pos == 0) { + tokens.insert(tokens.cbegin(), BOS_ID); + } + return tokens; + } +} + +// Vision +std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer, + const GemmaChatTemplate& chat_template, + const PromptWrapping wrapping, size_t pos, + const std::string& prompt, + size_t image_batch_size) { + std::vector<int> text_part; + HWY_ASSERT(tokenizer.Encode(prompt, &text_part)); + switch (wrapping) { + case PromptWrapping::PALIGEMMA: + HWY_ASSERT(pos == 0); + return chat_template.WrapPali(text_part, image_batch_size); + case PromptWrapping::GEMMA_VLM: + return chat_template.Apply( + pos, chat_template.WrapVLM(text_part, image_batch_size)); + default: + HWY_ASSERT_M(false, "Current variant does not support vision prompt."); + } } } // namespace gcpp diff --git a/gemma/tokenizer.h b/gemma/tokenizer.h index 0bbd8f4..aca01f9 100644 --- a/gemma/tokenizer.h +++ b/gemma/tokenizer.h @@ -22,29 +22,30 @@ #include <string> #include <vector> -#include "compression/io.h" // Path -#include "gemma/common.h" // ModelInfo +#include "gemma/configs.h" // PromptWrapping namespace gcpp { -// The tokenizer's end of sentence and beginning of sentence token ids. -constexpr int EOS_ID = 1; -constexpr int SECONDARY_EOS_ID = 106; // for Gemma 3 -constexpr int BOS_ID = 2; +constexpr int BOS_ID = 2; // beginning of sequence + +// To avoid the complexity of storing the tokenizer into testdata/ or +// downloading from gs://, while still always writing a blob for the tokenizer, +// but also avoiding empty blobs, we store this placeholder string. +constexpr const char* kMockTokenizer = "unavailable"; class GemmaTokenizer { + // These must be defined after the definition of `Impl`.
public: - GemmaTokenizer(); - explicit GemmaTokenizer(const Path& tokenizer_path); - - // must come after definition of Impl + // If unavailable, pass `kMockTokenizer`. + explicit GemmaTokenizer(const std::string& tokenizer_proto); ~GemmaTokenizer(); GemmaTokenizer(GemmaTokenizer&& other); GemmaTokenizer& operator=(GemmaTokenizer&& other); + // Returns `kMockTokenizer` if unavailable. std::string Serialize() const; - void Deserialize(const std::string& tokenizer_proto); + // Returns false on failure or if unavailable. bool Encode(const std::string& input, std::vector<std::string>* pieces) const; bool Encode(const std::string& input, std::vector<int>* ids) const; bool Decode(const std::vector<int>& ids, std::string* detokenized) const; @@ -54,13 +55,38 @@ class GemmaTokenizer { std::unique_ptr<Impl> impl_; }; -std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer, - const ModelInfo& info, size_t pos, - std::string& prompt); +class GemmaChatTemplate { + public: + // No effect if `tokenizer` is unavailable, but any other method may abort. + GemmaChatTemplate(const GemmaTokenizer& tokenizer, Model model); -std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info, - size_t pos, std::vector<int>& tokens, - size_t image_batch_size, size_t max_image_batch_size); + // Given prompt tokens, this returns the wrapped prompt including BOS and + // any "start_of_turn" structure required by the model. + std::vector<int> Apply(size_t pos, const std::vector<int>& ids) const; + std::vector<int> WrapPali(const std::vector<int>& text_part, + size_t image_batch_size) const; + std::vector<int> WrapVLM(const std::vector<int>& text_part, + size_t image_batch_size) const; + + private: + std::vector<int> sot_user_; + std::vector<int> sot_model_; + std::vector<int> eot_; + std::vector<int> pali_sep_; + std::vector<int> vlm_soi_; + std::vector<int> vlm_eoi_; +}; + +std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer, + const GemmaChatTemplate& chat_template, + PromptWrapping wrapping, size_t pos, + const std::string& prompt); + +std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer, + const GemmaChatTemplate& chat_template, + PromptWrapping wrapping, size_t pos, + const std::string& prompt, + size_t image_batch_size); } // namespace gcpp diff --git a/gemma/vit.cc b/gemma/vit.cc new file mode 100644 index 0000000..3549f85 --- /dev/null +++ b/gemma/vit.cc @@ -0,0 +1,349 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <math.h> // sqrtf +#include <stddef.h> +#include <stdint.h> + +#include <vector> + +#include "compression/types.h" // GEMMA_DISABLED_TARGETS +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + +#include "gemma/activations.h" +#include "gemma/gemma.h" +#include "gemma/gemma_args.h" +#include "gemma/weights.h" +#include "paligemma/image.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/profiler.h" + +// Compiles this file for multiple architectures via "foreach_target.h", to +// which we pass the filename via macro 'argument'.
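+// (Editor's note: "foreach_target.h" re-includes this translation unit once per enabled +// SIMD target, redefining HWY_NAMESPACE each time; Highway's dynamic dispatch then +// selects the best compiled variant at runtime.)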
+// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma/vit.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "gemma/gemma-inl.h" +#include "ops/ops-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +// Wrapper class; holds arguments in member variables to shorten call sites. +// The main differences to GemmaAttention are: +// - no KV Cache necessary, attention is always all-to-all and not causal. +// - no potential wrap-around, attention always goes from 0 to kSeqLen. +// - no need for batching, as we are always computing attention for kSeqLen +// tokens. +// This results in a much simpler implementation. However, to avoid duplicating +// code, we should still consider merging the two classes. +// TODO(keysers): Refactor to share code with GemmaAttention. +class VitAttention { + // Computes Q, K, V for all heads, stored in activations_.q. + HWY_NOINLINE void ComputeQKV() { + PROFILER_ZONE("Gen.VitAttention.QKV"); + auto& qkv = activations_.attention.q; + HWY_ASSERT(qkv.Rows() == num_tokens_); + HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim); + CallMatMul(activations_.attention.pre_att_rms_out, layer_.vit.qkv_einsum_w, + layer_.vit.qkv_einsum_b.PackedScale1(), env_, qkv); + } + + // TODO(philculliton): transition fully to MatMul. + HWY_NOINLINE void DotSoftmaxWeightedSumMatrix() { + const size_t qkv_dim = layer_config_.qkv_dim; + const size_t heads = layer_config_.heads; + HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA"); + const size_t seq_len = + static_cast<size_t>(activations_.attention.div_seq_len.GetDivisor()); + const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim)); + PROFILER_ZONE("Gen.VitAttention.DotSoftmax"); + + // Shift Q, K, VT to MatStorageT. + MatStorageT<float> Q("Q2", Extents2D(num_tokens_, qkv_dim), + env_.ctx.allocator, MatPadding::kPacked); + MatStorageT<float> K("K2", Extents2D(seq_len, qkv_dim), env_.ctx.allocator, + MatPadding::kPacked); + MatStorageT<float> C("C2", Extents2D(num_tokens_, seq_len), + env_.ctx.allocator, MatPadding::kPacked); + + // Initialize att_out to zero prior to head loop.
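+ // (Editor's note: each row of q packs heads * 3 * qkv_dim floats, ordered per head + // as [q | k | v]; the head * 3 * qkv_dim, + qkv_dim and + 2 * qkv_dim offsets below + // index into this packed layout.)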
+
+    // Initialize att_out to zero prior to head loop.
+    ZeroInit(activations_.attention.att_out);
+
+    for (size_t head = 0; head < heads; ++head) {
+      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
+        const size_t token = task;
+        float* HWY_RESTRICT q =
+            activations_.attention.q.Row(token) + head * 3 * qkv_dim;
+        // TODO: shift to MatMul with A.scale once MatMul is confirmed working
+        MulByConst(query_scale, q, qkv_dim, worker);
+        hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
+      });
+
+      pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+        const size_t seq_idx = task;
+        float* HWY_RESTRICT k = activations_.attention.q.Row(seq_idx) +
+                                head * 3 * qkv_dim + qkv_dim;
+        hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float));
+      });
+
+      // This produces C, a (num_tokens_, seq_len) matrix of dot products.
+      CallMatMul(Q, K, nullptr, env_, C);
+
+      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
+        float* HWY_RESTRICT c = C.Row(task);
+        Softmax(c, C.Cols(), worker);
+      });
+
+      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
+        size_t token = task;
+        float* HWY_RESTRICT att_out =
+            activations_.attention.att_out.Row(token) + head * qkv_dim;
+        for (size_t i = 0; i < seq_len; ++i) {
+          float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
+                                  head * 3 * qkv_dim + 2 * qkv_dim;
+          MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim, worker);
+        }
+      });
+    }
+  }
+
+  HWY_NOINLINE void DotSoftmaxWeightedSum() {
+    const size_t qkv_dim = layer_config_.qkv_dim;
+    const size_t heads = layer_config_.heads;
+    HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
+    const size_t seq_len =
+        static_cast<size_t>(activations_.attention.div_seq_len.GetDivisor());
+    const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
+    PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
+
+    // Compute Q.K, softmax, and weighted V.
+    pool_.Run(0, layer_config_.heads * num_tokens_,
+              [&](uint64_t task, size_t worker) HWY_ATTR {
+                const size_t head = task % layer_config_.heads;
+                const size_t token = task / layer_config_.heads;
+                // Compute Q.K scores, which are "logits" stored in head_att.
+                float* HWY_RESTRICT q =
+                    activations_.attention.q.Row(token) + head * 3 * qkv_dim;
+                MulByConst(query_scale, q, qkv_dim, worker);
+                float* HWY_RESTRICT head_att =
+                    activations_.attention.att.Row(token) + head * seq_len;
+                for (size_t i = 0; i < seq_len; ++i) {
+                  float* HWY_RESTRICT k = activations_.attention.q.Row(i) +
+                                          head * 3 * qkv_dim + qkv_dim;
+                  head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
+                }
+                // SoftMax yields "probabilities" in head_att.
+                Softmax(head_att, seq_len, worker);
+                // Compute weighted sum of v into att_out.
+                float* HWY_RESTRICT att_out =
+                    activations_.attention.att_out.Row(token) + head * qkv_dim;
+                hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
+                for (size_t i = 0; i < seq_len; ++i) {
+                  float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
+                                          head * 3 * qkv_dim + 2 * qkv_dim;
+                  MulByConstAndAdd(head_att[i], v, att_out, qkv_dim, worker);
+                }
+              });
+  }
+
+  // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
+  // head_dim (`qkv_dim`) into output (`att_sums`).
+  HWY_NOINLINE void SumHeads() {
+    PROFILER_ZONE("Gen.VitAttention.SumHeads");
+    auto* bias = layer_.vit.attn_out_b.PackedScale1();
+    // att_weights and att_out are concatenated heads, each of length
+    // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
+    // matmul output is the sum over heads.
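To see why the single `[num_tokens_, model_dim]` MatMul below also performs the sum over heads, here is a scalar reference for one token, under an assumed row-major `[heads * qkv_dim, model_dim]` layout for the projection weights (a sketch for illustration only):

```cpp
#include <cstddef>

// out[d] = sum over heads h and lanes i of att_out[h*qkv_dim+i] * w[.., d]:
// projecting the concatenated heads in one pass is the per-head projection
// plus the sum over heads.
void SumHeadsReference(const float* att_out,  // [heads * qkv_dim]
                       const float* w,        // [heads * qkv_dim, model_dim]
                       size_t heads, size_t qkv_dim, size_t model_dim,
                       float* out) {          // [model_dim]
  for (size_t d = 0; d < model_dim; ++d) out[d] = 0.0f;
  for (size_t h = 0; h < heads; ++h) {
    for (size_t i = 0; i < qkv_dim; ++i) {
      const float a = att_out[h * qkv_dim + i];
      for (size_t d = 0; d < model_dim; ++d) {
        out[d] += a * w[(h * qkv_dim + i) * model_dim + d];
      }
    }
  }
}
```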
+ CallMatMul(activations_.attention.att_out, layer_.vit.attn_out_w, bias, + env_, activations_.attention.att_sums); + } + + public: + VitAttention(size_t num_tokens, size_t layer_idx, Activations& activations, + const LayerWeightsPtrs& layer, MatMulEnv& env) + : num_tokens_(num_tokens), + activations_(activations), + layer_(layer), + layer_config_(layer.layer_config), + env_(env), + pool_(env_.ctx.pools.Pool(0)) {} + + HWY_INLINE void operator()() { + ComputeQKV(); + if (activations_.attention.config.wrapping == PromptWrapping::GEMMA_VLM) { + DotSoftmaxWeightedSumMatrix(); + } else { + DotSoftmaxWeightedSum(); + } + SumHeads(); + } + + private: + const size_t num_tokens_; + Activations& activations_; + const LayerWeightsPtrs& layer_; + const LayerConfig& layer_config_; + MatMulEnv& env_; + hwy::ThreadPool& pool_; +}; + +// Same as FFWNoVit, but with different layer members and no second +// gating matrix. +void FFWVit(const LayerWeightsPtrs& layer, Activations& activations, + MatMulEnv& env) { + PROFILER_ZONE("Gen.FFW.ViT"); + const LayerConfig& layer_config = layer.layer_config; + + const bool add_bias = layer_config.ff_biases; + const float* bias1 = add_bias ? layer.vit.linear_0_b.PackedScale1() : nullptr; + const float* output_bias = + add_bias ? layer.vit.linear_1_b.PackedScale1() : nullptr; + + // Compute the hidden layer activations. + CallMatMul(activations.pre_ffw_rms_out, layer.vit.linear_0_w, bias1, env, + activations.C1); + + // Activation (Gelu), store in C1. + ActivationBatched(layer_config.activation, activations.C1, env.ctx.pools); + + // Hidden layer -> output layer. + CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias, env, + activations.ffw_out); +} + +// Vit transformer layer. Some comments below refer to the Vit implementation in +// the Big Vision codebase. See +// github.com/google-research/big_vision/blob/main/big_vision/models/vit.py +// TODO(keysers): consider adding a wrapper for both LayerNorm with RMSNorm and +// try merging this with TransformerLayer. +void VitTransformerLayer(size_t num_tokens, const size_t layer_idx, + const LayerWeightsPtrs& layer, + Activations& activations, MatMulEnv& env) { + const size_t model_dim = activations.attention.config.model_dim; + auto type = layer.layer_config.type; + HWY_DASSERT(type == LayerAttentionType::kVit); + (void)type; + (void)model_dim; + + auto& x = activations.x; + HWY_DASSERT(x.Rows() == num_tokens); + HWY_DASSERT(x.Cols() == model_dim); + + // y = nn.LayerNorm()(x) + // y ~ pre_att_rms_out + LayerNormBatched(x, layer.vit.layer_norm_0_scale, layer.vit.layer_norm_0_bias, + activations.attention.pre_att_rms_out); + + // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y) + // y ~ att_sums + VitAttention(num_tokens, layer_idx, activations, layer, env)(); + + // x = out["+sa"] = x + y + AddFromBatched(activations.attention.att_sums, x, env.ctx); + + // y = nn.LayerNorm()(x) + // y ~ pre_ffw_rms_out + LayerNormBatched(x, layer.vit.layer_norm_1_scale, layer.vit.layer_norm_1_bias, + activations.pre_ffw_rms_out); + + // y = out["mlp"] = MlpBlock(...)(y) + // y ~ ffw_out + FFWVit(layer, activations, env); + + // x = out["+mlp"] = x + y + AddFromBatched(activations.ffw_out, x, env.ctx); +} + +// Gets the patches of the image and embeds them with the image embedding +// kernel. The result is stored in activations.x. 
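`VitTransformerLayer` above follows the standard pre-LayerNorm residual pattern from the Big Vision `vit.py`. Condensed into pseudocode (the `Matrix` type and helper names are stand-ins for the batched ops actually called, not real APIs):

```cpp
// Pseudocode summary of VitTransformerLayer.
void VitLayerSketch(Matrix& x) {
  Matrix y = LayerNorm0(x);  // layer_norm_0_{scale,bias} -> pre_att_rms_out
  y = Attention(y);          // VitAttention: QKV, softmax, SumHeads
  x += y;                    // first residual add (AddFromBatched)
  y = LayerNorm1(x);         // layer_norm_1_{scale,bias} -> pre_ffw_rms_out
  y = MlpBlock(y);           // FFWVit: linear_0, Gelu, linear_1
  x += y;                    // second residual add
}
```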
+static HWY_NOINLINE void EmbedImagePatches(const Image& image, + const ModelConfig& model_config, + const WeightsPtrs& weights, + Activations& activations, + MatMulEnv& env) { + const size_t model_dim = model_config.vit_config.model_dim; + const size_t patch_width = model_config.vit_config.patch_width; + const size_t num_tokens = model_config.vit_config.seq_len; + const size_t patch_size = patch_width * patch_width * 3; + HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim); + HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size); + HWY_DASSERT(activations.x.Cols() == model_dim); + (void)model_dim; + // img/embedding/kernel has original shape (14, 14, 3, 1152) + // H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3) + // image_patches is (256, 14 * 14 * 3) + // Must be padded, see `DoDecompressA`. + MatStorageT image_patches("patches", Extents2D(num_tokens, patch_size), + env.ctx.allocator, MatPadding::kOdd); + for (size_t i = 0; i < num_tokens; ++i) { + image.GetPatch(i, image_patches.Row(i)); + } + CallMatMul(image_patches, weights.vit_img_embedding_kernel, + weights.vit_img_embedding_bias.PackedScale1(), env, activations.x); + // Add position embeddings. + CallUpcastedActivation(&weights.vit_img_pos_embedding, + [&](const auto* weights_t) { + AddFromBatched(*weights_t, activations.x, env.ctx); + }); +} + +// Prefills the image tokens with the ViT encoder. +void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights, + const RuntimeConfig& runtime_config, const Image& image, + ImageTokens& image_tokens, Activations& activations, + MatMulEnv& env) { + PROFILER_ZONE("Gen.PrefillVit"); + const size_t num_tokens = model_config.vit_config.seq_len; + const size_t vit_model_dim = model_config.vit_config.model_dim; + HWY_ASSERT(num_tokens == activations.x.Rows()); + // Embed the image patches. + EmbedImagePatches(image, model_config, weights, activations, env); + // Go through all layers. + for (size_t layer_idx = 0; + layer_idx < model_config.vit_config.layer_configs.size(); ++layer_idx) { + VitTransformerLayer(num_tokens, layer_idx, *weights.VitLayer(layer_idx), + activations, env); + } + // Final Layernorm. + LayerNormBatched(activations.x, weights.vit_encoder_norm_scale, + weights.vit_encoder_norm_bias, activations.x); + + if (model_config.wrapping == PromptWrapping::GEMMA_VLM) { + activations.x = AvgPool4x4(activations.x, env.ctx.allocator); + + // Apply soft embedding norm before input projection. + CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) { + RMSNormInplace(weights_t->PackedScale1(), 0, activations.x.Row(0), + vit_model_dim, /*worker=*/0); + }); + } + + // Apply head embedding into image_tokens of size of the LLM kModelDim. + CallMatMul(activations.x, weights.vit_img_head_kernel, + weights.vit_img_head_bias.PackedScale1(), env, image_tokens); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); diff --git a/gemma/vit.h b/gemma/vit.h new file mode 100644 index 0000000..d6562f6 --- /dev/null +++ b/gemma/vit.h @@ -0,0 +1,50 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_VIT_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_VIT_H_ + +// Declares vision transformer FFW/Prefill for all SIMD targets. + +#include + +#include "gemma/gemma.h" +#include "hwy/highway.h" + +namespace gcpp { + +// Passed to HWY_VISIT_TARGETS; declares for one target. +#define GEMMA_DECL_VIT(TARGET, NAMESPACE) \ + namespace NAMESPACE { \ + void FFWVit(const LayerWeightsPtrs& layer, Activations& activations, \ + MatMulEnv& env); \ + \ + void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights, \ + const RuntimeConfig& runtime_config, const Image& image, \ + ImageTokens& image_tokens, Activations& activations, \ + MatMulEnv& env); \ + /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ + } // namespace NAMESPACE + +// Function declarations for each SIMD target. Allows direct call from the +// per-target namespace. We may later replace this with dynamic dispatch if +// the overhead is acceptable. +HWY_VISIT_TARGETS(GEMMA_DECL_VIT) + +#undef GEMMA_DECL_VIT + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_VIT_H_ diff --git a/gemma/weights.cc b/gemma/weights.cc index d281391..3418acf 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -15,272 +15,172 @@ #include "gemma/weights.h" -#include -#include -#include -#include +#include +#include +#include +#include + +#include // NOLINT #include #include -#include "compression/blob_store.h" -#include "compression/compress-inl.h" #include "compression/compress.h" -#include "compression/io.h" // Path -#include "compression/shared.h" -#include "gemma/common.h" +#include "compression/types.h" #include "gemma/configs.h" -#include "hwy/aligned_allocator.h" -#include "hwy/base.h" // HWY_ABORT +#include "gemma/gemma_args.h" +#include "gemma/model_store.h" +#include "io/blob_store.h" +#include "ops/matmul.h" // MMParallel +#include "util/mat.h" +#include "util/threading_context.h" +#include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" #include "hwy/profiler.h" -#include "hwy/stats.h" + +// TODO: move into foreach_target +#include "compression/compress-inl.h" namespace gcpp { -template -struct TensorLoader { - void operator()(ModelWeightsPtrs& weights, ForEachType fet, - ReadFromBlobStore& loader) { - weights.ForEachTensor( - {&weights}, fet, - [&loader](const char* name, hwy::Span tensors) { - loader(name, tensors); - }); - } -}; +// Copies att_weights from `attn_vec_einsum_w`. +void LayerWeightsPtrs::InitAttWeights(std::vector& mat_owners, + const Allocator& allocator) { + // We only use this tensor for Gemma layers. + if (layer_config.type != LayerAttentionType::kGemma) return; -BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type, - Type weight_type, PromptWrapping wrapping, - hwy::ThreadPool& pool, - std::string* tokenizer_proto) { - PROFILER_ZONE("Startup.LoadModelWeightsPtrs"); - if (!weights.Exists()) { - HWY_ABORT("The model weights file '%s' does not exist.", - weights.path.c_str()); + // Files must have one or the other. 
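For reference, `HWY_VISIT_TARGETS(GEMMA_DECL_VIT)` in `vit.h` above stamps out one declaration block per compiled SIMD target. For a single target it expands to roughly the following (the target namespace name is illustrative), which is what lets callers already inside that namespace call `FFWVit` directly without dynamic dispatch:

```cpp
namespace N_AVX3 {  // one such namespace per enabled target
void FFWVit(const LayerWeightsPtrs& layer, Activations& activations,
            MatMulEnv& env);
void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
                const RuntimeConfig& runtime_config, const Image& image,
                ImageTokens& image_tokens, Activations& activations,
                MatMulEnv& env);
}  // namespace N_AVX3
```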
+ HWY_ASSERT(attn_vec_einsum_w.HasPtr() ^ att_weights.HasPtr()); + // Done if we already read the transposed tensor. + if (att_weights.HasPtr() && !attn_vec_einsum_w.HasPtr()) return; + + // NUQ is handled by a specialization in weights.cc. + HWY_ASSERT(attn_vec_einsum_w.GetType() != Type::kNUQ); + + const size_t model_dim = layer_config.model_dim; + const size_t heads = layer_config.heads; + const size_t qkv_dim = layer_config.qkv_dim; + + // Reshape [heads, model_dim, qkv_dim] to [model_dim, heads * qkv_dim]. + att_weights.SetType(attn_vec_einsum_w.GetType()); + HWY_ASSERT(att_weights.Rows() == model_dim); + HWY_ASSERT(att_weights.Cols() == heads * qkv_dim); + HWY_ASSERT(attn_vec_einsum_w.Rows() == heads * model_dim); + HWY_ASSERT(attn_vec_einsum_w.Cols() == qkv_dim); + + { + static std::mutex m; + std::lock_guard lock(m); + mat_owners.push_back(MatOwner()); + mat_owners.back().AllocateFor(att_weights, allocator, MatPadding::kOdd); } - ReadFromBlobStore loader(weights); - ForEachType fet = - loader.HaveToc() ? ForEachType::kLoadWithToc : ForEachType::kLoadNoToc; - std::vector scales; - if (fet == ForEachType::kLoadWithToc) { - BlobError err = loader.LoadConfig(config_); - if (err != 0 || config_.model_dim == 0) { - fprintf(stderr, "Failed to load model config: %d\n", err); - return err; + + const size_t T_bytes = att_weights.ElementBytes(); + for (size_t m = 0; m < model_dim; ++m) { + uint8_t* HWY_RESTRICT out_row = att_weights.RowBytes(m); + for (size_t h = 0; h < heads; ++h) { + hwy::CopyBytes(attn_vec_einsum_w.RowBytes(h * model_dim + m), + out_row + h * qkv_dim * T_bytes, qkv_dim * T_bytes); } - if (tokenizer_proto != nullptr) { - err = loader.LoadTokenizer(*tokenizer_proto); - if (err != 0) { - fprintf(stderr, "Failed to load tokenizer: %d\n", err); - return err; - } - } - } else { - if (weight_type == Type::kUnknown || model_type == Model::UNKNOWN) { - fprintf(stderr, - "weight type (%d) and model type (%d) must be specified when " - "no config is present in weights file\n", - static_cast(weight_type), static_cast(model_type)); - return __LINE__; - } - // No Toc-> no config. - config_ = ConfigFromModel(model_type); - config_.weight = weight_type; - config_.wrapping = wrapping; - scales.resize(config_.num_tensor_scales + config_.vit_config.num_scales); } - CreateForType(config_.weight, pool); - CallForModelWeightT(fet, loader); - if (!scales.empty()) { - loader.LoadScales(scales.data(), scales.size()); - } - BlobError err = loader.ReadAll(pool, model_storage_); - if (err != 0) { - fprintf(stderr, "Failed to load model weights: %d\n", err); - return err; - } - if (!scales.empty()) { - GetOrApplyScales(scales); - } - if (fet == ForEachType::kLoadNoToc) { - PROFILER_ZONE("Startup.Reshape"); - AllocAndCopyWithTranspose(pool); - } - return 0; + att_weights.SetScale(attn_vec_einsum_w.Scale()); } -template -struct TensorSaver { - // Adds all the tensors to the blob writer. - void operator()(ModelWeightsPtrs& weights, ForEachType fet, - WriteToBlobStore& writer) { - weights.ForEachTensor( - {&weights}, fet, - [&writer](const char* name, hwy::Span tensors) { - tensors[0]->CallUpcasted(writer, name); - }); - } -}; +// For FFN. Fast, only updates pointers. +void LayerWeightsPtrs::SplitW1() { + // Used for Gemma and Griffin layers; FFWVit uses different tensors. 
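The byte-wise copy in `InitAttWeights` above implements the following reshape; a plain-float reference (the real code copies `ElementBytes()`-sized rows of whatever compressed type the tensor uses):

```cpp
#include <algorithm>
#include <cstddef>

// [heads, model_dim, qkv_dim] -> [model_dim, heads * qkv_dim]: each output
// row m interleaves the per-head slices for that model dimension.
void ReshapeAttWeights(const float* attn_vec_einsum_w, float* att_weights,
                       size_t heads, size_t model_dim, size_t qkv_dim) {
  for (size_t m = 0; m < model_dim; ++m) {
    for (size_t h = 0; h < heads; ++h) {
      const float* src = attn_vec_einsum_w + (h * model_dim + m) * qkv_dim;
      float* dst = att_weights + m * heads * qkv_dim + h * qkv_dim;
      std::copy(src, src + qkv_dim, dst);
    }
  }
}
```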
+ if (layer_config.type == LayerAttentionType::kVit) return; -BlobError ModelWeightsStorage::Save(const std::string& tokenizer, - const Path& weights, - hwy::ThreadPool& pool) { - WriteToBlobStore writer(pool); - ForEachType fet = ForEachType::kLoadWithToc; - CallForModelWeightT(fet, writer); - writer.AddTokenizer(tokenizer); - int err = writer.WriteAll(weights, &config_); - if (err != 0) { - fprintf(stderr, "Failed to write model weights: %d\n", err); - return err; - } - return 0; + // Files have both or neither of w1 and w2. + HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr()); + // w is mutually exclusive with w1 and w2 in the file. + HWY_ASSERT(gating_einsum_w.HasPtr() ^ gating_einsum_w1.HasPtr()); + // Done if we already read split tensors. Note that they are not + // necessarily the same type. + if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return; + + const size_t ff_hidden_dim = layer_config.ff_hidden_dim; + HWY_ASSERT(gating_einsum_w.Rows() == 2 * ff_hidden_dim); + HWY_ASSERT(gating_einsum_w1.Rows() == ff_hidden_dim); + HWY_ASSERT(gating_einsum_w2.Rows() == ff_hidden_dim); + // Cols are the model_dim but we don't have ModelConfig here. + HWY_ASSERT(gating_einsum_w1.Cols() == gating_einsum_w.Cols()); + HWY_ASSERT(gating_einsum_w2.Cols() == gating_einsum_w.Cols()); + + const size_t stride = gating_einsum_w.Stride(); + gating_einsum_w1.SetPtr(gating_einsum_w.RowBytes(0), stride); + gating_einsum_w2.SetPtr(gating_einsum_w.RowBytes(ff_hidden_dim), stride); + gating_einsum_w1.SetType(gating_einsum_w.GetType()); + gating_einsum_w2.SetType(gating_einsum_w.GetType()); + gating_einsum_w1.SetScale(gating_einsum_w.Scale()); + gating_einsum_w2.SetScale(gating_einsum_w.Scale()); + gating_einsum_w.SetPtr(nullptr, gating_einsum_w.Cols()); } -void ModelWeightsStorage::Allocate(const ModelConfig& config, Type weight_type, - hwy::ThreadPool& pool) { - PROFILER_ZONE("Startup.AllocateModelWeightsPtrs"); - config_ = config; - config_.weight = weight_type; - CreateForType(weight_type, pool); - if (float_weights_) float_weights_->Allocate(model_storage_, pool); - if (bf16_weights_) bf16_weights_->Allocate(model_storage_, pool); - if (sfp_weights_) sfp_weights_->Allocate(model_storage_, pool); - if (nuq_weights_) nuq_weights_->Allocate(model_storage_, pool); +// For attention, which might not have a w2. Fast, only updates pointers. +void LayerWeightsPtrs::SplitAttW1() { + // We only use this tensor for Gemma layers. + if (layer_config.type != LayerAttentionType::kGemma) return; + + // w is mutually exclusive with w1 in the file. + HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr()); + // Done if we already read split tensors. Note that w2 does not exist for + // MHA, and otherwise might not be the same type. + if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return; + + const size_t w1_rows = layer_config.heads * layer_config.qkv_dim; + const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim; + HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows); + HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows); + HWY_ASSERT(qkv_einsum_w2.Rows() == w2_rows); + // Cols are the model_dim but we don't have ModelConfig here. 
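Both `SplitW1` and `SplitAttW1` avoid copying: because the stacked tensor's rows are contiguous with a fixed stride, the two halves are just views starting at row 0 and row `rows1`. A sketch with an illustrative view type (the real code goes through `MatPtr::SetPtr` and `Stride`, which counts elements):

```cpp
#include <cstddef>
#include <cstdint>

struct MatView {
  uint8_t* ptr;
  size_t rows, cols, stride;  // stride in elements, as in MatPtr
};

// Splits a [rows1 + rows2, cols] stack into two row-offset views and marks
// the stacked tensor as consumed, as Fixup does.
void SplitStacked(MatView& w, size_t rows1, size_t element_bytes,
                  MatView& w1, MatView& w2) {
  w1 = {w.ptr, rows1, w.cols, w.stride};
  w2 = {w.ptr + rows1 * w.stride * element_bytes, w.rows - rows1, w.cols,
        w.stride};
  w.ptr = nullptr;
}
```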
+ HWY_ASSERT(qkv_einsum_w1.Cols() == qkv_einsum_w.Cols()); + HWY_ASSERT(qkv_einsum_w2.Cols() == qkv_einsum_w.Cols()); + + const size_t stride = qkv_einsum_w.Stride(); + qkv_einsum_w1.SetPtr(qkv_einsum_w.RowBytes(0), stride); + qkv_einsum_w2.SetPtr(qkv_einsum_w.RowBytes(w1_rows), stride); + qkv_einsum_w1.SetType(qkv_einsum_w.GetType()); + qkv_einsum_w2.SetType(qkv_einsum_w.GetType()); + qkv_einsum_w1.SetScale(qkv_einsum_w.Scale()); + qkv_einsum_w2.SetScale(qkv_einsum_w.Scale()); + qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols()); } -class WeightInitializer { - public: - WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {} - - void operator()(const char* name, hwy::Span tensors) { - float* data = tensors[0]->data(); - for (size_t i = 0; i < tensors[0]->NumElements(); ++i) { - data[i] = dist_(gen_); - } - tensors[0]->set_scale(1.0f); - } - - private: - std::normal_distribution dist_; - std::mt19937& gen_; -}; - -void ModelWeightsStorage::RandInit(std::mt19937& gen) { - HWY_ASSERT(float_weights_); - WeightInitializer init(gen); - ModelWeightsPtrs::ForEachTensor({float_weights_.get()}, - ForEachType::kLoadNoToc, init); +// Must be called after reading weights via `ForEachTensor`. +// TODO: exporters should bake this into the weights already. +// WARNING: called from multiple threads; `mat_owners` requires a lock. +void LayerWeightsPtrs::Fixup(std::vector& mat_owners, + const Allocator& allocator) { + // TODO(janwas): handle NUQ + InitAttWeights(mat_owners, allocator); + SplitW1(); + SplitAttW1(); } -void ModelWeightsStorage::ZeroInit() { - if (float_weights_) float_weights_->ZeroInit(); - if (bf16_weights_) bf16_weights_->ZeroInit(); - if (sfp_weights_) sfp_weights_->ZeroInit(); - if (nuq_weights_) nuq_weights_->ZeroInit(); -} +static void HWY_MAYBE_UNUSED InitAttWeightsNUQ( + const LayerConfig& layer_config, MatPtrT& attn_vec_einsum_w, + MatPtrT& att_weights, std::vector& mat_owners) { + if (!attn_vec_einsum_w.HasPtr()) return; + HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kNUQ); -void ModelWeightsStorage::GetOrApplyScales(std::vector& scales) { - if (float_weights_) float_weights_->GetOrApplyScales(scales); - if (bf16_weights_) bf16_weights_->GetOrApplyScales(scales); - if (sfp_weights_) sfp_weights_->GetOrApplyScales(scales); - if (nuq_weights_) nuq_weights_->GetOrApplyScales(scales); -} - -void ModelWeightsStorage::AllocAndCopyWithTranspose(hwy::ThreadPool& pool) { - if (float_weights_) - float_weights_->AllocAndCopyWithTranspose(pool, model_storage_); - if (bf16_weights_) - bf16_weights_->AllocAndCopyWithTranspose(pool, model_storage_); - if (sfp_weights_) - sfp_weights_->AllocAndCopyWithTranspose(pool, model_storage_); - if (nuq_weights_) - nuq_weights_->AllocAndCopyWithTranspose(pool, model_storage_); -} - -void ModelWeightsStorage::CopyWithTranspose(hwy::ThreadPool& pool) { - if (float_weights_) float_weights_->CopyWithTranspose(pool); - if (bf16_weights_) bf16_weights_->CopyWithTranspose(pool); - if (sfp_weights_) sfp_weights_->CopyWithTranspose(pool); - if (nuq_weights_) nuq_weights_->CopyWithTranspose(pool); -} - -namespace { - -void LogVec(const char* name, const float* data, size_t len) { - hwy::Stats stats; - for (size_t i = 0; i < len; ++i) { - stats.Notify(data[i]); - } - printf("%-20s %12zu %13.10f %8.5f %13.10f\n", - name, len, stats.Min(), stats.Mean(), stats.Max()); -} - -} // namespace - -void ModelWeightsStorage::LogWeightStats() { - size_t total_weights = 0; - // Only for float weights. 
- ModelWeightsPtrs::ForEachTensor( - {float_weights_.get()}, ForEachType::kInitNoToc, - [&total_weights](const char* name, hwy::Span tensors) { - const MatPtr& tensor = *tensors[0]; - if (tensor.scale() != 1.0f) { - printf("[scale=%f] ", tensor.scale()); - } - LogVec(name, tensor.data(), tensor.NumElements()); - total_weights += tensor.NumElements(); - }); - printf("%-20s %12zu\n", "Total", total_weights); -} - -void ModelWeightsStorage::CreateForType(Type weight_type, - hwy::ThreadPool& pool) { - switch (weight_type) { - case Type::kF32: - float_weights_ = std::make_unique>(config_); - break; - case Type::kBF16: - bf16_weights_ = std::make_unique>(config_); - break; - case Type::kSFP: - sfp_weights_ = - std::make_unique>(config_); - break; - case Type::kNUQ: - nuq_weights_ = - std::make_unique>(config_); - break; - default: - HWY_ABORT("Weight type %d unsupported.", static_cast(weight_type)); - } -} - -template <> -void LayerWeightsPtrs::Reshape(MatStorage* storage) { - if (attn_vec_einsum_w.data() == nullptr) return; + HWY_ASSERT(att_weights.HasPtr()); + att_weights.SetType(Type::kNUQ); const size_t model_dim = layer_config.model_dim; const size_t heads = layer_config.heads; const size_t qkv_dim = layer_config.qkv_dim; // Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim]. - if (storage != nullptr) { - storage->Allocate(); - att_weights.SetPtr(*storage); - } - - const hwy::HWY_NAMESPACE::ScalableTag df; - hwy::AlignedFreeUniquePtr attn_vec_einsum_w_tmp = hwy::AllocateAligned(model_dim * heads * qkv_dim); hwy::AlignedFreeUniquePtr att_weights_tmp = hwy::AllocateAligned(model_dim * heads * qkv_dim); - HWY_NAMESPACE::DecompressAndZeroPad( - df, MakeSpan(attn_vec_einsum_w.data(), model_dim * heads * qkv_dim), 0, - attn_vec_einsum_w_tmp.get(), model_dim * heads * qkv_dim); + const hwy::HWY_NAMESPACE::ScalableTag df; + HWY_NAMESPACE::DecompressAndZeroPad(df, attn_vec_einsum_w.Span(), 0, + attn_vec_einsum_w_tmp.get(), + model_dim * heads * qkv_dim); for (size_t m = 0; m < model_dim; ++m) { float* HWY_RESTRICT out_row = att_weights_tmp.get() + m * heads * qkv_dim; @@ -293,13 +193,360 @@ void LayerWeightsPtrs::Reshape(MatStorage* storage) { CompressWorkingSet work; hwy::ThreadPool pool(0); + HWY_NAMESPACE::Compress(att_weights_tmp.get(), model_dim * heads * qkv_dim, + work, att_weights.Span(), + /*packed_ofs=*/0, pool); - HWY_NAMESPACE::Compress( - att_weights_tmp.get(), model_dim * heads * qkv_dim, work, - MakeSpan(att_weights.data(), model_dim * heads * qkv_dim), - /*packed_ofs=*/0, pool); + att_weights.SetScale(attn_vec_einsum_w.Scale()); +} - att_weights.set_scale(attn_vec_einsum_w.scale()); +static void HWY_MAYBE_UNUSED SplitW1NUQ(const LayerConfig& layer_config) { + // TODO(janwas): implement. +} + +// Zero-initializes only the allocated tensors in `*this`. +void WeightsPtrs::ZeroInit() { + ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) { + if (!t.mat.HasPtr()) return; + gcpp::ZeroInit(t.mat); + }); +} + +// Copies only the allocated tensors in `*this` from tensors in `other`. +void WeightsPtrs::CopyFrom(const WeightsPtrs& other) { + ForEachTensor(const_cast(&other), nullptr, + [](const TensorArgs& t) { + if (!t.mat.HasPtr()) return; + HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr()); + CopyMat(*t.other_mat1, t.mat); + }); +} + +// For reshaping file tensors to the shape expected by the code. This would +// ideally already happen in the importer. Called by `ReadFromBlobs`. 
+void WeightsPtrs::Fixup(std::vector& mat_owners, + ThreadingContext& ctx) { + // TODO: use 1D parallel-for helper function + hwy::ThreadPool& pool = ctx.pools.Pool(); + pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) { + GetLayer(layer)->Fixup(mat_owners, ctx.allocator); + }); + + pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) { + VitLayer(layer)->Fixup(mat_owners, ctx.allocator); + }); +} + +std::vector WeightsPtrs::AddTensorDataToWriter( + BlobWriter& writer) const { + std::vector serialized_mat_ptrs; + // ForEachTensor is non-const but the lambda does not modify *this. + const_cast(this)->ForEachTensor( + nullptr, nullptr, [&](const TensorArgs& t) { + if (t.flags & TensorArgs::kMaybeRead && !t.mat.HasPtr()) return; + HWY_ASSERT_M(t.mat.HasPtr(), t.mat.Name()); + writer.Add(t.mat.Name(), t.mat.Packed(), t.mat.PackedBytes()); + t.mat.AppendTo(serialized_mat_ptrs); + }); + return serialized_mat_ptrs; +} + +// Decides whether to read or map based on heuristics and user override. +static WeightsPtrs::Mode ChooseMode(uint64_t file_bytes, + const LoaderArgs& loader, + const InferenceArgs& inference, + const Allocator& allocator) { + Tristate to_bf16 = loader.to_bf16; + Tristate map = loader.map; + + // Disable mapping if not padded to the base page size. + if (file_bytes % allocator.BasePageBytes() != 0) { + if (map == Tristate::kTrue) { // Only complain if explicitly requested. + HWY_WARN("Unable to map non-padded file (%zu, %zu), reading instead.", + static_cast(file_bytes >> 10), + allocator.BasePageBytes()); + } + map = Tristate::kFalse; + } + + // Check for user override: + if (to_bf16 == Tristate::kTrue && map == Tristate::kTrue) { + HWY_WARN("Cannot have to_bf16 && map, to_bf16 takes precedence."); + } + if (to_bf16 == Tristate::kTrue) return WeightsPtrs::Mode::kReadBF16; + if (map == Tristate::kTrue) return WeightsPtrs::Mode::kMap; + + if (to_bf16 == Tristate::kDefault) { + // Heuristic: sub-bf16 compression is not helpful if compute-bound. + const size_t batch_size = + HWY_MAX(inference.prefill_tbatch_size, inference.decode_qbatch_size); + to_bf16 = (batch_size >= 128) ? Tristate::kTrue : Tristate::kFalse; + } + + if (map == Tristate::kDefault) { + // Heuristic: map if large fraction of total. Do not decide based on + // `FreeMiB` because it is generally low. + const size_t file_mib = file_bytes >> 20; + const size_t total_mib = allocator.TotalMiB(); + if (file_mib > total_mib) { + HWY_WARN("Weight file %zu MiB > detected memory %zu MiB.", + static_cast(file_mib), total_mib); + } + // Large fraction of total. + map = (file_mib >= total_mib / 3) ? Tristate::kTrue : Tristate::kFalse; + } + + // If the `map` heuristic triggers, use that for safety. + if (map == Tristate::kTrue) return WeightsPtrs::Mode::kMap; + return (to_bf16 == Tristate::kTrue) ? WeightsPtrs::Mode::kReadBF16 + : WeightsPtrs::Mode::kRead; +} + +struct TensorToRead { + MatPtr* mat; + BlobRange range; + // Some tensors opt out of padding via kPacked flags. + MatPadding padding; + + // only for kReadBF16 + bool keep_type = false; + Type prev_type; +}; + +// Allocates multiple in parallel and binds to NUMA nodes. +static void AllocateAndBindAll(std::vector& tensors, + const WeightsPtrs::Mode mode, + std::vector& owners, + ThreadingContext& ctx) { + const size_t start = owners.size(); + owners.resize(start + tensors.size()); + + MMParallel parallel(ctx); + + // Allocate in parallel because faulting in large tensors is slow. 
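The precedence in `ChooseMode` above is easy to misread; restated as a compact sketch (same logic, minus the warnings, with the page-padding check folded into `can_map`): explicit flags win, BF16 beats an explicit map request, defaults are resolved by the batch-size and memory-fraction heuristics, and the map heuristic wins over the BF16 heuristic.

```cpp
#include <cstddef>

WeightsPtrs::Mode ChooseModeSketch(bool can_map, Tristate to_bf16,
                                   Tristate map, size_t batch_size,
                                   size_t file_mib, size_t total_mib) {
  if (!can_map) map = Tristate::kFalse;  // file not padded to page size
  if (to_bf16 == Tristate::kTrue) return WeightsPtrs::Mode::kReadBF16;
  if (map == Tristate::kTrue) return WeightsPtrs::Mode::kMap;
  if (to_bf16 == Tristate::kDefault)  // compute-bound => BF16 suffices
    to_bf16 = (batch_size >= 128) ? Tristate::kTrue : Tristate::kFalse;
  if (map == Tristate::kDefault)  // large fraction of total memory => map
    map = (file_mib >= total_mib / 3) ? Tristate::kTrue : Tristate::kFalse;
  if (map == Tristate::kTrue) return WeightsPtrs::Mode::kMap;
  return (to_bf16 == Tristate::kTrue) ? WeightsPtrs::Mode::kReadBF16
                                      : WeightsPtrs::Mode::kRead;
}
```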
+ ctx.pools.Pool().Run( + 0, tensors.size(), [&](uint64_t task, size_t /*thread*/) { + TensorToRead& tensor = tensors[task]; + MatPtr& mat = *tensor.mat; + + tensor.prev_type = mat.GetType(); + // We only care about MatMul inputs; skip F32 or small tensors. + if (tensor.prev_type == Type::kF32 || mat.Rows() < 1024) { + tensor.keep_type = true; + tensor.padding = MatPadding::kPacked; // single I/O for simplicity + } else if (mode == WeightsPtrs::Mode::kReadBF16) { + mat.SetType(Type::kBF16); + } + + owners[start + task].AllocateFor(*tensor.mat, ctx.allocator, + tensor.padding); + BindB(*tensor.mat, tensor.mat->ElementBytes(), parallel); + }); +} + +// Mode == kMap +static void MapAll(const std::vector& tensors, + const MapPtr& mapped, uint64_t file_bytes) { + PROFILER_ZONE("Startup.Weights.Map"); + for (size_t i = 0; i < tensors.size(); ++i) { + // SetPtr does not change the stride, but it is expected to be packed + // because that is what Compress() writes to the file. + const size_t mat_bytes = tensors[i].mat->PackedBytes(); + // Ensure blob size matches that computed from metadata. + HWY_ASSERT_M(mat_bytes == tensors[i].range.bytes, tensors[i].mat->Name()); + // Ensure the blob lies within the file mapping. + const uint64_t offset = tensors[i].range.offset; + HWY_ASSERT_M(offset + mat_bytes <= file_bytes, tensors[i].mat->Name()); + + tensors[i].mat->SetPtr(const_cast(mapped.get() + offset), + tensors[i].mat->Stride()); + } +} + +// Mode == kReadBF16: + +template +static void DecompressToBF16(MatPtr& mat, + const hwy::AlignedFreeUniquePtr& buf) { + hwy::HWY_NAMESPACE::ScalableTag dbf; + const size_t cols = mat.Cols(); + + const size_t num_packed = CompressedArrayElements(mat.Extents().Area()); + const PackedSpan packed{HWY_RCAST_ALIGNED(T*, buf.get()), num_packed}; + + size_t packed_ofs = 0; + for (size_t r = 0; r < mat.Rows(); ++r, packed_ofs += cols) { + HWY_NAMESPACE::DecompressAndZeroPad( + dbf, packed, packed_ofs, HWY_RCAST_ALIGNED(BF16*, mat.RowBytes(r)), + cols); + } +} + +static void ReadAllToBF16(const std::vector& tensors, + const BlobReader& reader, hwy::ThreadPool& pool) { + pool.Run(0, tensors.size(), [&](uint64_t task, size_t thread) { + PROFILER_ZONE2(thread, "Startup.Weights.ReadBF16"); + const TensorToRead& tensor = tensors[task]; + MatPtr& mat = *tensor.mat; + + if (tensor.keep_type) { + HWY_ASSERT(reader.file().Read(tensor.range.offset, tensor.range.bytes, + mat.Packed())); + return; + } + + // Read to a temporary buffer. + const hwy::AlignedFreeUniquePtr buf = + hwy::AllocateAligned(tensor.range.bytes); + HWY_ASSERT( + reader.file().Read(tensor.range.offset, tensor.range.bytes, buf.get())); + + if constexpr (GEMMA_ENABLE_NUQ) { + if (tensor.prev_type == Type::kNUQ) { + return DecompressToBF16(*tensor.mat, buf); + } + } + switch (tensor.prev_type) { + case Type::kF32: + return DecompressToBF16(*tensor.mat, buf); + case Type::kBF16: + return DecompressToBF16(*tensor.mat, buf); + case Type::kSFP: + return DecompressToBF16(*tensor.mat, buf); + default: + HWY_ABORT("Unsupported type %s", TypeName(tensor.prev_type)); + } + }); +} + +// Mode == kRead: + +static std::vector MakeBatches( + const std::vector& tensors, const uint64_t file_bytes) { + PROFILER_ZONE("Startup.Weights.MakeBatches"); + // Batches must be contiguous but blobs are padded, hence at least one + // batch per tensor, and more when tensor rows exceed the batch size. 
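What `MakeBatches` and `ReadBatches` accomplish, stated as a single-threaded sketch: packed rows from a contiguous file range land in strided in-memory rows, and the per-row padding is zeroed as MatMul requires. The `RowDest` type is illustrative; the real `IOBatch` instead records destination pointers and issues one positioned read per batch.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

struct RowDest {
  uint8_t* ptr;
  size_t bytes;  // file_bytes_per_row (packed, no padding in the file)
};

void ScatterRows(const uint8_t* batch_data, const std::vector<RowDest>& rows,
                 size_t mem_stride_bytes) {
  size_t offset = 0;
  for (const RowDest& row : rows) {
    std::memcpy(row.ptr, batch_data + offset, row.bytes);
    // Zero the in-memory row padding between Cols() and Stride().
    std::memset(row.ptr + row.bytes, 0, mem_stride_bytes - row.bytes);
    offset += row.bytes;
  }
}
```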
+ std::vector batches; + batches.reserve(tensors.size()); + + for (size_t i = 0; i < tensors.size(); ++i) { + const BlobRange& range = tensors[i].range; + MatPtr& mat = *tensors[i].mat; + uint64_t offset = range.offset; + HWY_ASSERT(range.End() <= file_bytes); + + batches.emplace_back(offset, range.key_idx); + const size_t file_bytes_per_row = mat.Cols() * mat.ElementBytes(); + const size_t mem_stride_bytes = mat.Stride() * mat.ElementBytes(); + uint8_t* row_bytes = mat.RowBytes(0); + for (size_t r = 0; r < mat.Rows(); ++r) { + if (!batches.back().Add(row_bytes, file_bytes_per_row)) { // Full batch. + batches.emplace_back(offset, range.key_idx); + // Adding to an empty batch is always successful. + HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row)); + } + offset += file_bytes_per_row; + // Must zero-initialize the in-memory row padding, see MatMul. + hwy::ZeroBytes(row_bytes + file_bytes_per_row, + mem_stride_bytes - file_bytes_per_row); + row_bytes += mem_stride_bytes; + } + HWY_ASSERT(offset == range.End()); + } + + HWY_ASSERT(batches.size() >= tensors.size()); + return batches; +} + +// Parallel synchronous I/O. Note that O_DIRECT seems undesirable because we +// want to use the OS cache between consecutive runs. +static void ReadBatches(const BlobReader& reader, + const std::vector& batches, + hwy::ThreadPool& pool) { + // >5x speedup from parallel reads when cached. + pool.Run(0, batches.size(), [&](uint64_t i, size_t thread) { + PROFILER_ZONE2(thread, "Startup.Weights.Read"); + const IOBatch& batch = batches[i]; + const std::string& key = reader.Keys()[batch.KeyIdx()]; + const uint64_t bytes_read = batch.Read(reader.file()); + if (bytes_read != batch.TotalBytes()) { + HWY_ABORT("Read failed for %s from %zu, %zu bytes; got %zu.", key.c_str(), + static_cast(batch.Offset()), + static_cast(batch.TotalBytes()), + static_cast(bytes_read)); + } + }); +} + +// Aborts on error. Updates `mode` to the actual mode used. Returns mapped +// memory or nullptr if `kMap` was not used. +static MapPtr MapOrReadAll(std::vector& tensors, + BlobReader& reader, WeightsPtrs::Mode* mode, + std::vector& mat_owners, + ThreadingContext& ctx) { + if (*mode == WeightsPtrs::Mode::kMap) { + if (MapPtr mapped = reader.Map()) { + MapAll(tensors, mapped, reader.file().FileSize()); + return mapped; + } + HWY_WARN("Failed to map file (%zu KiB), reading instead.", + static_cast(reader.file_bytes() >> 10)); + // If we wanted to map but failed, memory is probably not plentiful, so + // fall through to kRead because kReadBF16 requires more memory. + *mode = WeightsPtrs::Mode::kRead; + } + + { + PROFILER_ZONE("Startup.Weights.Allocate"); + // NOTE: this changes the stride of `mats`! + AllocateAndBindAll(tensors, *mode, mat_owners, ctx); + } + + hwy::ThreadPool& pool = ctx.pools.Pool(); + + if (*mode == WeightsPtrs::Mode::kReadBF16) { + ReadAllToBF16(tensors, reader, pool); + return MapPtr(); + } + + const std::vector batches = + MakeBatches(tensors, reader.file_bytes()); + ReadBatches(reader, batches, pool); + return MapPtr(); +} + +WeightsPtrs::Mode WeightsPtrs::ReadFromBlobs(const ModelStore& model, + BlobReader& reader, + const LoaderArgs& loader, + const InferenceArgs& inference, + std::vector& mat_owners, + ThreadingContext& ctx) { + // List of tensors to read/map, and where from. + std::vector tensors; + + // Enumerate all weights (negligible cost). + ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) { + const MatPadding padding = (t.flags & TensorArgs::kPacked) + ? 
MatPadding::kPacked + : MatPadding::kOdd; + size_t key_idx; + if (model.FindAndUpdateMatPtr(t.mat, key_idx)) { + tensors.push_back( + {.mat = &t.mat, .range = reader.Range(key_idx), .padding = padding}); + return; + } + if (t.flags & TensorArgs::kMaybeRead) return; // optional and not found. + HWY_ABORT("Tensor %s is required but not found in file.", t.mat.Name()); + }); + + Mode mode = ChooseMode(reader.file_bytes(), loader, inference, ctx.allocator); + mapped_ = MapOrReadAll(tensors, reader, &mode, mat_owners, ctx); + + { + PROFILER_ZONE("Startup.Fixup"); + Fixup(mat_owners, ctx); + } + return mode; } } // namespace gcpp diff --git a/gemma/weights.h b/gemma/weights.h index 5fd544b..de3652a 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -17,607 +17,455 @@ #define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_ #include +#include -#include -#include -#include -#include #include -#include #include -#include "compression/compress.h" -#include "compression/shared.h" -#include "gemma/common.h" -#include "gemma/configs.h" -#include "gemma/tensor_index.h" -#include "hwy/aligned_allocator.h" -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" +#include "compression/types.h" +#include "gemma/configs.h" // ModelConfig +#include "gemma/gemma_args.h" // InferenceArgs +#include "gemma/model_store.h" // ModelStore +#include "gemma/tensor_info.h" // TensorInfoRegistry +#include "io/blob_store.h" // BlobWriter +#include "util/mat.h" // MatPtr +#include "util/threading_context.h" namespace gcpp { -// Different tensors need to appear in a ForEachTensor, according to what is -// happening. -enum class ForEachType { - // Under normal circumstances, when not initializing or loading, we can - // include all tensors and ignore the null ones. - kIgnoreNulls, - // If there is a table of contents, we can include all tensors. - kLoadWithToc, - // There is no table of contents, so we have to be careful to only include - // tensors that are actually present. - kLoadNoToc, - // We need to initialize all tensors needed when there is no table of - // contents. This differs from kLoadNoToc in that we need to include any - // tensor that is allocated but not loaded directly from file. - kInitNoToc, +// Argument passed to the `ForEachTensor` callback. +struct TensorArgs { + // `other_mat1` and `other_mat2` can be nullptr, or tensor(s) of the same + // name/type from another `LayerWeightsPtrs` for iterating over tensor pairs + // (for copying) or triples (for `AdamUpdateMV`). Set by `TENSOR_ARGS`. + // `flags` is a combination of zero or more `Flags`. + TensorArgs(MatPtr& mat, MatPtr* other_mat1, MatPtr* other_mat2, int flags) + : mat(mat), + other_mat1(other_mat1), + other_mat2(other_mat2), + flags(flags) {} + + MatPtr& mat; + MatPtr* other_mat1; // either/both can be nullptr. + MatPtr* other_mat2; + + enum Flags { + // Default: Read the tensor from the file and abort if it is not found. + kMustRead = 0, + + // Not an error if the tensor is not present in the file. For example, + // the _w1/_w2 tensors are not always present. + kMaybeRead = 1, + + // Avoid padding tensor rows when reading. Used for some Griffin tensors + // whose index computations do not use Row() accessors. + kPacked = 2, + }; + const int flags; }; -template +// Shorthand for creating the argument to the `ForEachTensor` callback. A macro +// seems less bad than member pointer syntax. +#define TENSOR_ARGS(mat, flag) \ + TensorArgs(mat, other1 ? &other1->mat : nullptr, \ + other2 ? 
&other2->mat : nullptr, TensorArgs::flag) + +// Finds tensors by name in `TensorInfoRegistry` (constructed from +// `ModelConfig`) and constructs `MatPtr` metadata with those shapes. +class MatFinder { + public: + MatFinder(const std::string& suffix, const TensorInfoRegistry& tensors) + : suffix_(suffix), tensors_(tensors) {} + + // Retrieves shape by name via `TensorInfo` from `TensorInfoRegistry`. + MatPtr operator()(const std::string& base_name) const { + const std::string name = std::string(base_name) + suffix_; + return MatPtr(name.c_str(), Type::kUnknown, + ExtentsFromInfo(tensors_.Find(name))); + } + + private: + const std::string suffix_; + const TensorInfoRegistry& tensors_; +}; + +// Per-layer weight metadata and pointers. The tensor data is owned by +// `MatOwner`. struct LayerWeightsPtrs { - // Large data is constructed separately. - explicit LayerWeightsPtrs(const LayerConfig& config, - const TensorIndex& tensor_index) - : attn_vec_einsum_w("att_ein", tensor_index), - qkv_einsum_w("qkv_ein", tensor_index), - qkv_einsum_w1("qkv1_w", tensor_index), - qkv_einsum_w2("qkv2_w", tensor_index), - attention_output_biases("attn_ob", tensor_index), - griffin({.linear_x_w = {"gr_lin_x_w", tensor_index}, - .linear_x_biases = {"gr_lin_x_b", tensor_index}, - .linear_y_w = {"gr_lin_y_w", tensor_index}, - .linear_y_biases = {"gr_lin_y_b", tensor_index}, - .linear_out_w = {"gr_lin_out_w", tensor_index}, - .linear_out_biases = {"gr_lin_out_b", tensor_index}, - .conv_w = {"gr_conv_w", tensor_index}, - .conv_biases = {"gr_conv_b", tensor_index}, - .gate_w = {"gr_gate_w", tensor_index}, - .gate_biases = {"gr_gate_b", tensor_index}, - .a = {"gr_a", tensor_index}}), + // Initializes tensor metadata without allocating. + // NOTE: do not store layer_idx, TransformerLayer and Attention may use + // other values for purposes of the KV cache. + LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config, + const TensorInfoRegistry& tensors) + : finder_(LayerSuffix(layer_idx), tensors), + qkv_einsum_w(finder_("qkv_ein")), + qkv_einsum_w1(finder_("qkv1_w")), + qkv_einsum_w2(finder_("qkv2_w")), + attention_output_biases(finder_("attn_ob")), + griffin({.linear_x_w = finder_("gr_lin_x_w"), + .linear_x_biases = finder_("gr_lin_x_b"), + .linear_y_w = finder_("gr_lin_y_w"), + .linear_y_biases = finder_("gr_lin_y_b"), + .linear_out_w = finder_("gr_lin_out_w"), + .linear_out_biases = finder_("gr_lin_out_b"), + .conv_w = finder_("gr_conv_w"), + .conv_biases = finder_("gr_conv_b"), + .gate_w = finder_("gr_gate_w"), + .gate_biases = finder_("gr_gate_b"), + .a = finder_("gr_a")}), // MultiHeadDotProductAttention. 
- vit({.attn_out_w = {"attn_out_w", tensor_index}, - .attn_out_b = {"attn_out_b", tensor_index}, - .qkv_einsum_w = {"qkv_ein_w", tensor_index}, - .qkv_einsum_b = {"qkv_ein_b", tensor_index}, - .linear_0_w = {"linear_0_w", tensor_index}, - .linear_0_b = {"linear_0_b", tensor_index}, - .linear_1_w = {"linear_1_w", tensor_index}, - .linear_1_b = {"linear_1_b", tensor_index}, - .layer_norm_0_bias = {"ln_0_bias", tensor_index}, - .layer_norm_0_scale = {"ln_0_scale", tensor_index}, - .layer_norm_1_bias = {"ln_1_bias", tensor_index}, - .layer_norm_1_scale = {"ln_1_scale", tensor_index}}), - gating_einsum_w("gating_ein", tensor_index), - gating_einsum_w1("gating1_w", tensor_index), - gating_einsum_w2("gating2_w", tensor_index), - linear_w("linear_w", tensor_index), - pre_attention_norm_scale("pre_att_ns", tensor_index), - pre_ffw_norm_scale("pre_ff_ns", tensor_index), - post_attention_norm_scale("post_att_ns", tensor_index), - post_ffw_norm_scale("post_ff_ns", tensor_index), - ffw_gating_biases("ffw_gat_b", tensor_index), - ffw_output_biases("ffw_out_b", tensor_index), - att_weights("att_w", tensor_index), - key_norm_scale("key_norm", tensor_index), - query_norm_scale("query_norm", tensor_index), - layer_config(config) {} + vit({.attn_out_w = finder_("attn_out_w"), + .attn_out_b = finder_("attn_out_b"), + .qkv_einsum_w = finder_("qkv_ein_w"), + .qkv_einsum_b = finder_("qkv_ein_b"), + .linear_0_w = finder_("linear_0_w"), + .linear_0_b = finder_("linear_0_b"), + .linear_1_w = finder_("linear_1_w"), + .linear_1_b = finder_("linear_1_b"), + .layer_norm_0_bias = finder_("ln_0_bias"), + .layer_norm_0_scale = finder_("ln_0_scale"), + .layer_norm_1_bias = finder_("ln_1_bias"), + .layer_norm_1_scale = finder_("ln_1_scale")}), + gating_einsum_w(finder_("gating_ein")), + gating_einsum_w1(finder_("gating1_w")), + gating_einsum_w2(finder_("gating2_w")), + linear_w(finder_("linear_w")), + pre_attention_norm_scale(finder_("pre_att_ns")), + pre_ffw_norm_scale(finder_("pre_ff_ns")), + post_attention_norm_scale(finder_("post_att_ns")), + post_ffw_norm_scale(finder_("post_ff_ns")), + ffw_gating_biases(finder_("ffw_gat_b")), + ffw_output_biases(finder_("ffw_out_b")), + + attn_vec_einsum_w(finder_("att_ein")), + att_weights(finder_("att_w")), + + key_norm_scale(finder_("key_norm")), + query_norm_scale(finder_("query_norm")), + + layer_config(config) { + } ~LayerWeightsPtrs() = default; - // If weights are f32, also f32; otherwise at least bf16. Useful for ops that - // do not yet support smaller compressed types, or require at least bf16. When - // weights are f32, we also want such tensors to be f32. - // If weights are complex, this is also complex. - using WeightF32OrBF16 = - hwy::If>(), std::complex, - hwy::If(), double, - hwy::If(), float, BF16>>>; + const MatFinder finder_; - template - using ArrayT = MatPtrT; - - ArrayT attn_vec_einsum_w; - // qkv_einsum_w holds 2 different matrices, which may be separated out. - // On loading, which is used depends on what is in the file. - // At inference, the one with a non-null ptr is used. - ArrayT qkv_einsum_w; - ArrayT qkv_einsum_w1; - ArrayT qkv_einsum_w2; - ArrayT attention_output_biases; + // Files either have qkv_einsum_w with 2 stacked matrices or separate + // w1/w2 tensors. Fixup ensures w1/w2 are ready for use by gemma-inl.h. 
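`MatFinder` above keeps the constructor's initializer list declarative: each base name plus the layer suffix is looked up in `TensorInfoRegistry`, yielding an unallocated `MatPtr` with the correct extents. A usage sketch (the layer index and base name are examples):

```cpp
#include <cstddef>
#include <string>

// E.g. MakeLayerTensor(tensors, 7, "qkv_ein") resolves the tensor named
// "qkv_ein" + LayerSuffix(7) and returns metadata-only MatPtr whose type
// stays kUnknown until the weights are read.
MatPtr MakeLayerTensor(const TensorInfoRegistry& tensors, size_t layer_idx,
                       const std::string& base_name) {
  const MatFinder finder(LayerSuffix(layer_idx), tensors);
  return finder(base_name);
}
```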
+ MatPtr qkv_einsum_w; + MatPtr qkv_einsum_w1; + MatPtr qkv_einsum_w2; + MatPtrT attention_output_biases; struct { - ArrayT linear_x_w; - ArrayT linear_x_biases; - ArrayT linear_y_w; - ArrayT linear_y_biases; - ArrayT linear_out_w; - ArrayT linear_out_biases; - ArrayT conv_w; - ArrayT conv_biases; - ArrayT gate_w; - ArrayT gate_biases; - ArrayT a; + MatPtr linear_x_w; + MatPtrT linear_x_biases; + MatPtr linear_y_w; + MatPtrT linear_y_biases; + MatPtr linear_out_w; + MatPtrT linear_out_biases; + MatPtrT conv_w; + MatPtrT conv_biases; + MatPtr gate_w; + MatPtrT gate_biases; + MatPtrT a; } griffin; struct { // MultiHeadDotProductAttention. - ArrayT attn_out_w; - ArrayT attn_out_b; - ArrayT qkv_einsum_w; - ArrayT qkv_einsum_b; + MatPtr attn_out_w; // at least BF16. + MatPtrT attn_out_b; + MatPtr qkv_einsum_w; // at least BF16. + MatPtrT qkv_einsum_b; // MlpBlock. - ArrayT linear_0_w; - ArrayT linear_0_b; - ArrayT linear_1_w; - ArrayT linear_1_b; + MatPtr linear_0_w; // at least BF16. + MatPtrT linear_0_b; + MatPtr linear_1_w; // at least BF16. + MatPtrT linear_1_b; // LayerNorm. - ArrayT layer_norm_0_bias; - ArrayT layer_norm_0_scale; - ArrayT layer_norm_1_bias; - ArrayT layer_norm_1_scale; + MatPtr layer_norm_0_bias; // at least BF16. + MatPtr layer_norm_0_scale; // at least BF16. + MatPtr layer_norm_1_bias; // at least BF16. + MatPtr layer_norm_1_scale; // at least BF16. } vit; - // gating_einsum_w holds 2 different matrices, which may be separated out. - // On loading, which is used depends on what is in the file. - // At inference, the one with a non-null ptr is used. - ArrayT gating_einsum_w; - ArrayT gating_einsum_w1; - ArrayT gating_einsum_w2; - ArrayT linear_w; - // We don't yet have an RMSNorm that accepts all Weight. - ArrayT pre_attention_norm_scale; - ArrayT pre_ffw_norm_scale; - ArrayT post_attention_norm_scale; - ArrayT post_ffw_norm_scale; + // Files either have gating_einsum_w with 2 stacked matrices or separate + // w1/w2 tensors. `Fixup` ensures w1/w2 are ready for use by gemma-inl.h. + MatPtr gating_einsum_w; + MatPtr gating_einsum_w1; + MatPtr gating_einsum_w2; + MatPtr linear_w; + MatPtr pre_attention_norm_scale; // at least BF16. + MatPtr pre_ffw_norm_scale; // at least BF16. + MatPtr post_attention_norm_scale; // at least BF16. + MatPtr post_ffw_norm_scale; // at least BF16. - ArrayT ffw_gating_biases; - ArrayT ffw_output_biases; + MatPtrT ffw_gating_biases; + MatPtrT ffw_output_biases; - // Reshaped attention; not loaded from disk via ForEachTensor. - ArrayT att_weights; + MatPtr attn_vec_einsum_w; // Use att_weights instead of this. + MatPtr att_weights; // Use this instead of attn_vec_einsum_w. + + MatPtr key_norm_scale; // at least BF16. + MatPtr query_norm_scale; // at least BF16. const LayerConfig& layer_config; - // Initializes att_weights from attn_vec_einsum_w, hence this must be called - // after loading weights via ForEachTensor. - // TODO: update compression/convert_weights to bake this in. - void Reshape(MatStorage* storage) { - static_assert(!hwy::IsSame()); - - if (attn_vec_einsum_w.data() == nullptr) return; - - const size_t model_dim = layer_config.model_dim; - const size_t heads = layer_config.heads; - const size_t qkv_dim = layer_config.qkv_dim; - - // Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim]. 
- if (storage != nullptr) { - storage->Allocate(); - att_weights.SetPtr(*storage); - } - for (size_t m = 0; m < model_dim; ++m) { - Weight* HWY_RESTRICT out_row = att_weights.data() + m * heads * qkv_dim; - for (size_t h = 0; h < heads; ++h) { - hwy::CopyBytes( - attn_vec_einsum_w.data() + h * model_dim * qkv_dim + m * qkv_dim, - out_row + h * qkv_dim, qkv_dim * sizeof(Weight)); - } - } - att_weights.set_scale(attn_vec_einsum_w.scale()); - } - - ArrayT key_norm_scale; - ArrayT query_norm_scale; - -// Used by ForEachTensor for per-layer tensors. -#define GEMMA_CALL_FUNC(member) \ - { \ - for (int i = 0; i < ptrs.size(); ++i) { \ - tensors[i] = &ptrs[i]->member; \ - } \ - if (tensors[0]->Ptr() != nullptr || fet != ForEachType::kIgnoreNulls) { \ - func(ptrs[0]->member.CacheName(layer_idx, sep, sep_index).c_str(), \ - hwy::Span(tensors.data(), ptrs.size())); \ - } \ - } - + // Calls `func(TensorArgs)` for each tensor which is in use for the + // current `layer_config`. `other1` and `other2` are optional arguments so we + // can also iterate over pairs or triples of tensors for `AdamUpdateMV`. + // Public because also called by `WeightsPtrs`. template - static void ForEachTensor(const std::vector*>& ptrs, - int layer_idx, ForEachType fet, Func func, - char sep = ' ', int sep_index = -1) { - std::vector tensors(ptrs.size(), nullptr); - auto type = ptrs[0]->layer_config.type; - if (type == LayerAttentionType::kVit) { + void ForEachTensor(LayerWeightsPtrs* other1, LayerWeightsPtrs* other2, + Func func) { + if (layer_config.type == LayerAttentionType::kVit) { // MHA. - GEMMA_CALL_FUNC(vit.attn_out_w); - GEMMA_CALL_FUNC(vit.attn_out_b); - GEMMA_CALL_FUNC(vit.qkv_einsum_w); - GEMMA_CALL_FUNC(vit.qkv_einsum_b); + func(TENSOR_ARGS(vit.attn_out_w, kMustRead)); + func(TENSOR_ARGS(vit.attn_out_b, kMustRead)); + func(TENSOR_ARGS(vit.qkv_einsum_w, kMustRead)); + // Used as 1D MatMul bias, but has `heads + 2 * kv_heads` rows, hence + // must not be padded. + func(TENSOR_ARGS(vit.qkv_einsum_b, kMustRead | TensorArgs::kPacked)); // MlpBlock. - GEMMA_CALL_FUNC(vit.linear_0_w); - GEMMA_CALL_FUNC(vit.linear_0_b); - GEMMA_CALL_FUNC(vit.linear_1_w); - GEMMA_CALL_FUNC(vit.linear_1_b); + func(TENSOR_ARGS(vit.linear_0_w, kMustRead)); + func(TENSOR_ARGS(vit.linear_0_b, kMustRead)); + func(TENSOR_ARGS(vit.linear_1_w, kMustRead)); + func(TENSOR_ARGS(vit.linear_1_b, kMustRead)); // LayerNorm. - GEMMA_CALL_FUNC(vit.layer_norm_0_bias); - GEMMA_CALL_FUNC(vit.layer_norm_0_scale); - GEMMA_CALL_FUNC(vit.layer_norm_1_bias); - GEMMA_CALL_FUNC(vit.layer_norm_1_scale); + func(TENSOR_ARGS(vit.layer_norm_0_bias, kMustRead)); + func(TENSOR_ARGS(vit.layer_norm_0_scale, kMustRead)); + func(TENSOR_ARGS(vit.layer_norm_1_bias, kMustRead)); + func(TENSOR_ARGS(vit.layer_norm_1_scale, kMustRead)); return; } - if (type == LayerAttentionType::kGemma) { - if (fet != ForEachType::kLoadNoToc) { - GEMMA_CALL_FUNC(att_weights); - } - if (fet == ForEachType::kInitNoToc || fet == ForEachType::kLoadNoToc || - fet == ForEachType::kIgnoreNulls) { - GEMMA_CALL_FUNC(attn_vec_einsum_w); - } - GEMMA_CALL_FUNC(qkv_einsum_w); - if (fet == ForEachType::kIgnoreNulls || - fet == ForEachType::kLoadWithToc) { - // The unwanted ones will be null or not in the toc. - GEMMA_CALL_FUNC(qkv_einsum_w1); - GEMMA_CALL_FUNC(qkv_einsum_w2); - } + if (layer_config.type == LayerAttentionType::kGemma) { + // Either read from file, or allocated during Fixup(). 
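For clarity, here is one of the new call sites with the `TENSOR_ARGS` macro written out; this is the mechanical preprocessor expansion of the definition above, where `other1`/`other2` are the optional peer weights passed to `ForEachTensor`:

```cpp
// func(TENSOR_ARGS(att_weights, kMaybeRead)); expands to:
func(TensorArgs(att_weights, other1 ? &other1->att_weights : nullptr,
                other2 ? &other2->att_weights : nullptr,
                TensorArgs::kMaybeRead));
```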
+ func(TENSOR_ARGS(att_weights, kMaybeRead)); + func(TENSOR_ARGS(attn_vec_einsum_w, kMaybeRead)); + func(TENSOR_ARGS(qkv_einsum_w, kMaybeRead)); + func(TENSOR_ARGS(qkv_einsum_w1, kMaybeRead)); + func(TENSOR_ARGS(qkv_einsum_w2, kMaybeRead)); } else { - GEMMA_CALL_FUNC(griffin.linear_x_w); - GEMMA_CALL_FUNC(griffin.linear_x_biases); - GEMMA_CALL_FUNC(griffin.linear_y_w); - GEMMA_CALL_FUNC(griffin.linear_y_biases); - GEMMA_CALL_FUNC(griffin.linear_out_w); - GEMMA_CALL_FUNC(griffin.linear_out_biases); - GEMMA_CALL_FUNC(griffin.conv_w); - GEMMA_CALL_FUNC(griffin.conv_biases); - GEMMA_CALL_FUNC(griffin.gate_w); - GEMMA_CALL_FUNC(griffin.gate_biases); - GEMMA_CALL_FUNC(griffin.a); + func(TENSOR_ARGS(griffin.linear_x_w, kMustRead)); + func(TENSOR_ARGS(griffin.linear_x_biases, kMustRead)); + func(TENSOR_ARGS(griffin.linear_y_w, kMustRead)); + func(TENSOR_ARGS(griffin.linear_y_biases, kMustRead)); + func(TENSOR_ARGS(griffin.linear_out_w, kMustRead)); + func(TENSOR_ARGS(griffin.linear_out_biases, kMustRead)); + // conv_w and gate_w are not accessed via Row(), hence must not be padded. + // Note that *biases are 1D, hence packing/padding does not matter. + func(TENSOR_ARGS(griffin.conv_w, kMustRead | TensorArgs::kPacked)); + func(TENSOR_ARGS(griffin.conv_biases, kMustRead)); + func(TENSOR_ARGS(griffin.gate_w, kMustRead | TensorArgs::kPacked)); + func(TENSOR_ARGS(griffin.gate_biases, kMustRead)); + func(TENSOR_ARGS(griffin.a, kMustRead)); } - GEMMA_CALL_FUNC(gating_einsum_w); - if (fet == ForEachType::kIgnoreNulls || fet == ForEachType::kLoadWithToc) { - // The unwanted ones will be null or not in the toc. - GEMMA_CALL_FUNC(gating_einsum_w1); - GEMMA_CALL_FUNC(gating_einsum_w2); - } - GEMMA_CALL_FUNC(linear_w); - GEMMA_CALL_FUNC(pre_attention_norm_scale); - GEMMA_CALL_FUNC(pre_ffw_norm_scale); - - if (ptrs[0]->layer_config.post_norm == PostNormType::Scale) { - GEMMA_CALL_FUNC(post_attention_norm_scale); - GEMMA_CALL_FUNC(post_ffw_norm_scale); - } - if (ptrs[0]->layer_config.use_qk_norm) { - GEMMA_CALL_FUNC(key_norm_scale); - GEMMA_CALL_FUNC(query_norm_scale); + { + func(TENSOR_ARGS(gating_einsum_w, kMaybeRead)); + func(TENSOR_ARGS(gating_einsum_w1, kMaybeRead)); + func(TENSOR_ARGS(gating_einsum_w2, kMaybeRead)); + func(TENSOR_ARGS(linear_w, kMaybeRead)); + func(TENSOR_ARGS(pre_attention_norm_scale, kMustRead)); + func(TENSOR_ARGS(pre_ffw_norm_scale, kMustRead)); } - if (ptrs[0]->layer_config.ff_biases) { - GEMMA_CALL_FUNC(ffw_gating_biases); - GEMMA_CALL_FUNC(ffw_output_biases); + if (layer_config.post_norm == PostNormType::Scale) { + func(TENSOR_ARGS(post_attention_norm_scale, kMustRead)); + func(TENSOR_ARGS(post_ffw_norm_scale, kMustRead)); + } + if (layer_config.use_qk_norm) { + func(TENSOR_ARGS(key_norm_scale, kMustRead)); + func(TENSOR_ARGS(query_norm_scale, kMustRead)); } - if (ptrs[0]->layer_config.softmax_attn_output_biases && - type == LayerAttentionType::kGemma) { - GEMMA_CALL_FUNC(attention_output_biases); + if (layer_config.ff_biases) { + func(TENSOR_ARGS(ffw_gating_biases, kMustRead)); + func(TENSOR_ARGS(ffw_output_biases, kMustRead)); } - } - // Sets all the tensors in the layer to zero. Memory must have been allocated. - void ZeroInit(int layer_idx) { - ForEachTensor({this}, layer_idx, ForEachType::kIgnoreNulls, - [](const char*, hwy::Span tensors) { - tensors[0]->ZeroInit(); - }); - } - - // Allocates memory for all the tensors in the layer. - // Note that this is slow and only used for a stand-alone layer. 
-  void Allocate(std::vector<MatStorage>& layer_storage) {
-    ForEachTensor(
-        {this}, /*layer_idx=*/0, ForEachType::kInitNoToc,
-        [&layer_storage](const char* name, hwy::Span<MatPtr*> tensors) {
-          layer_storage.emplace_back(*tensors[0]);
-          layer_storage.back().Allocate();
-          tensors[0]->SetPtr(layer_storage.back());
-        });
-  }
-};
-
-template <typename Weight>
-struct ModelWeightsPtrs {
-  explicit ModelWeightsPtrs(const ModelConfig& config)
-      : ModelWeightsPtrs(
-            config,
-            TensorIndex(config, /*llm_layer_idx=*/-1, /*vit_layer_idx=*/-1,
-                        /*reshape_att=*/false)) {}
-  ModelWeightsPtrs(const ModelConfig& config, const TensorIndex& tensor_index)
-      : embedder_input_embedding("c_embedding", tensor_index),
-        final_norm_scale("c_final_norm", tensor_index),
-        vit_encoder_norm_bias("enc_norm_bias", tensor_index),
-        vit_encoder_norm_scale("enc_norm_scale", tensor_index),
-        vit_img_embedding_bias("img_emb_bias", tensor_index),
-        vit_img_embedding_kernel("img_emb_kernel", tensor_index),
-        vit_img_pos_embedding("img_pos_emb", tensor_index),
-        vit_img_head_bias("img_head_bias", tensor_index),
-        vit_img_head_kernel("img_head_kernel", tensor_index),
-        mm_embed_norm("mm_embed_norm", tensor_index),
-        scale_names(config.scale_names),
-        weights_config(config) {
-    c_layers.reserve(config.layer_configs.size());
-    for (int index = 0; index < static_cast<int>(config.layer_configs.size());
-         ++index) {
-      const auto& layer_config = config.layer_configs[index];
-      TensorIndex tensor_index(config, index, /*vit_layer_idx=*/-1,
-                               /*reshape_att=*/false);
-      c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config, tensor_index));
+    if (layer_config.softmax_attn_output_biases &&
+        layer_config.type == LayerAttentionType::kGemma) {
+      func(TENSOR_ARGS(attention_output_biases, kMustRead));
     }
-    for (int index = 0;
-         index < static_cast<int>(config.vit_config.layer_configs.size());
-         ++index) {
-      const auto& layer_config = config.vit_config.layer_configs[index];
-      TensorIndex tensor_index(config, /*llm_layer_idx=*/-1, index,
-                               /*reshape_att=*/false);
-      vit_layers.push_back(
-          LayerWeightsPtrs<WeightF32OrBF16>(layer_config, tensor_index));
-    }
-  }
+  }  // `ForEachTensor`
-  ~ModelWeightsPtrs() = default;
-  using WeightF32OrBF16 = typename LayerWeightsPtrs<Weight>::WeightF32OrBF16;
-  using WeightF32OrInputT = hwy::If<hwy::IsSame<Weight, float>(),
-                                    EmbedderInputT, WeightF32OrBF16>;
-
-  MatPtrT<WeightF32OrInputT> embedder_input_embedding;
-  MatPtrT<WeightF32OrBF16> final_norm_scale;
-
-  // Vit parts.
-  MatPtrT<WeightF32OrBF16> vit_encoder_norm_bias;
-  MatPtrT<WeightF32OrBF16> vit_encoder_norm_scale;
-  MatPtrT<float> vit_img_embedding_bias;
-  MatPtrT<WeightF32OrBF16> vit_img_embedding_kernel;
-  MatPtrT<float> vit_img_pos_embedding;
-  // The head maps from VitConfig::kModelDim (Vit final layer) to
-  // kModelDim (LLM input).
-  MatPtrT<float> vit_img_head_bias;
-  MatPtrT<WeightF32OrBF16> vit_img_head_kernel;
-
-  MatPtrT<WeightF32OrBF16> mm_embed_norm;
-
-  std::unordered_set<std::string> scale_names;
-
-  const ModelConfig& weights_config;
-
-  std::vector<LayerWeightsPtrs<Weight>> c_layers;
-  std::vector<LayerWeightsPtrs<WeightF32OrBF16>> vit_layers;
-
-  // Called by weights.cc after loading, before att_weights has been
-  // allocated.
-  void AllocAndCopyWithTranspose(hwy::ThreadPool& pool,
-                                 std::vector<MatStorage>& model_storage) {
-    size_t storage_index = model_storage.size();
-    for (auto& layer : c_layers) {
-      model_storage.emplace_back(layer.att_weights);
-    }
-    pool.Run(0, c_layers.size(),
-             [this, &model_storage, storage_index](uint64_t layer,
-                                                   size_t /*thread*/) {
-               GetLayer(layer)->Reshape(&model_storage[storage_index + layer]);
-             });
-  }
-  // For when the storage has already been allocated.
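Both the deleted `Allocate` and `AllocAndCopyWithTranspose` follow the same collect-then-parallelize shape: record the work items single-threaded, then hand indices to `hwy::ThreadPool::Run` exactly as the `pool.Run` calls above do. A self-contained sketch, with `Buffer` standing in for `MatStorage`:

// Sketch of the collect-then-parallel-allocate pattern used here. `Buffer`
// is a stand-in for MatStorage; the pool API is Highway's real ThreadPool.
#include <cstdint>
#include <cstdio>
#include <memory>
#include <vector>

#include "hwy/contrib/thread_pool/thread_pool.h"

struct Buffer {
  size_t bytes = 0;
  std::unique_ptr<uint8_t[]> data;
  void Allocate() { data = std::make_unique<uint8_t[]>(bytes); }
};

int main() {
  // First pass: record what to allocate (cheap, single-threaded).
  std::vector<Buffer> storage(8);
  for (size_t i = 0; i < storage.size(); ++i) storage[i].bytes = (i + 1) << 10;

  // Second pass: allocate in parallel. Each task touches a distinct element,
  // so no synchronization is needed.
  hwy::ThreadPool pool(4);
  pool.Run(0, storage.size(), [&storage](uint64_t task, size_t /*thread*/) {
    storage[task].Allocate();
  });

  std::printf("allocated %zu buffers\n", storage.size());
  return 0;
}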
-  void CopyWithTranspose(hwy::ThreadPool& pool) {
-    pool.Run(0, c_layers.size(), [this](uint64_t layer, size_t /*thread*/) {
-      GetLayer(layer)->Reshape(nullptr);
+  // Zero-initializes all allocated tensors in the layer.
+  void ZeroInit() {
+    ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) {
+      if (!t.mat.HasPtr()) return;
+      gcpp::ZeroInit(t.mat);
     });
   }
-  void ZeroInit() {
-    embedder_input_embedding.ZeroInit();
-    final_norm_scale.ZeroInit();
-    for (size_t i = 0; i < c_layers.size(); ++i) {
-      c_layers[i].ZeroInit(i);
-    }
-  }
-
-  const LayerWeightsPtrs<Weight>* GetLayer(size_t layer) const {
-    return &c_layers[layer];
-  }
-  LayerWeightsPtrs<Weight>* GetLayer(size_t layer) { return &c_layers[layer]; }
-  const LayerWeightsPtrs<WeightF32OrBF16>* GetVitLayer(size_t layer) const {
-    return &vit_layers[layer];
-  }
-  LayerWeightsPtrs<WeightF32OrBF16>* GetVitLayer(size_t layer) {
-    return &vit_layers[layer];
-  }
-
-  void Allocate(std::vector<MatStorage>& model_storage,
-                hwy::ThreadPool& pool) {
-    std::vector<MatPtr*> model_toc;
-    ForEachTensor(
-        {this}, ForEachType::kInitNoToc,
-        [&model_toc, &model_storage](const char*, hwy::Span<MatPtr*> tensors) {
-          model_toc.push_back(tensors[0]);
-          model_storage.emplace_back(*tensors[0]);
-        });
-    // Allocate in parallel using the pool.
-    pool.Run(0, model_toc.size(),
-             [&model_toc, &model_storage](uint64_t task, size_t /*thread*/) {
-               // model_storage may have had content before we started.
-               size_t idx = task + model_storage.size() - model_toc.size();
-               model_storage[idx].Allocate();
-               model_toc[task]->SetPtr(model_storage[idx]);
-             });
-  }
-
-  // Copies the data from other to *this.
-  void CopyFrom(const ModelWeightsPtrs& other) {
-    ForEachTensor({this, const_cast<ModelWeightsPtrs<Weight>*>(&other)},
-                  ForEachType::kIgnoreNulls,
-                  [](const char*, hwy::Span<MatPtr*> tensors) {
-                    hwy::CopyBytes(tensors[1]->Ptr(), tensors[0]->Ptr(),
-                                   tensors[1]->SizeBytes());
-                  });
-  }
-
-  // If scales is empty, computes and returns the scale factors for the
-  // tensors, otherwise applies the scale factors to the tensors.
-  void GetOrApplyScales(std::vector<float>& scales) {
-    int scale_pos = 0;
-    ForEachTensor(
-        {this}, ForEachType::kIgnoreNulls,
-        [&scales, &scale_pos, this](const char*, hwy::Span<MatPtr*> tensors) {
-          if (this->scale_names.count(tensors[0]->Name())) {
-            if (scale_pos < scales.size()) {
-              tensors[0]->set_scale(scales[scale_pos]);
-            } else {
-              float scale = ScaleWeights(tensors[0]->data(),
-                                         tensors[0]->NumElements());
-              scales.push_back(scale);
-            }
-            ++scale_pos;
-          }
-        });
-    HWY_ASSERT(scale_pos == weights_config.num_tensor_scales);
-  }
-
-  template <class Func>
-  static void ForEachTensor(const std::vector<ModelWeightsPtrs<Weight>*>& ptrs,
-                            ForEachType fet, Func func) {
-    std::vector<LayerWeightsPtrs<Weight>*> layers(ptrs.size());
-    std::vector<LayerWeightsPtrs<WeightF32OrBF16>*> vit_layers(ptrs.size());
-    std::vector<MatPtr*> tensors(ptrs.size(), nullptr);
-    // Variables used by GEMMA_CALL_FUNC.
-    int layer_idx = -1;
-    char sep = ' ';
-    int sep_index = -1;
-    GEMMA_CALL_FUNC(embedder_input_embedding);
-    GEMMA_CALL_FUNC(final_norm_scale);
-    if (ptrs[0]->weights_config.vit_config.layer_configs.size() > 0) {
-      // Vit parts.
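The deleted `GetOrApplyScales` is dual-purpose: given an empty vector it records one scale factor per scaled tensor, and given a filled vector it applies them instead. A sketch of that contract; this `ScaleWeights` (normalize to unit max-abs and return the factor) is an assumed stand-in, not gemma.cpp's implementation:

// Sketch of the get-or-apply contract. ScaleWeights is a hypothetical
// stand-in that normalizes in place and returns the removed factor.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

static float ScaleWeights(float* data, size_t n) {
  float max_abs = 0.0f;
  for (size_t i = 0; i < n; ++i) max_abs = std::max(max_abs, std::fabs(data[i]));
  if (max_abs == 0.0f) return 1.0f;
  for (size_t i = 0; i < n; ++i) data[i] /= max_abs;
  return max_abs;
}

static void GetOrApplyScales(std::vector<std::vector<float>>& tensors,
                             std::vector<float>& scales) {
  size_t scale_pos = 0;
  for (auto& t : tensors) {
    if (scale_pos < scales.size()) {
      for (float& x : t) x *= scales[scale_pos];  // Apply an existing scale.
    } else {
      scales.push_back(ScaleWeights(t.data(), t.size()));  // Compute + record.
    }
    ++scale_pos;
  }
}

int main() {
  std::vector<std::vector<float>> tensors = {{2.0f, -4.0f}, {0.5f, 0.25f}};
  std::vector<float> scales;
  GetOrApplyScales(tensors, scales);  // Empty vector: computes {4.0, 0.5}.
  GetOrApplyScales(tensors, scales);  // Filled vector: multiplies them back,
                                      // restoring the original values.
  std::printf("scales: %f %f\n", scales[0], scales[1]);
  return 0;
}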
-      GEMMA_CALL_FUNC(vit_encoder_norm_bias);
-      GEMMA_CALL_FUNC(vit_encoder_norm_scale);
-      GEMMA_CALL_FUNC(vit_img_embedding_bias);
-      GEMMA_CALL_FUNC(vit_img_embedding_kernel);
-      GEMMA_CALL_FUNC(vit_img_pos_embedding);
-      GEMMA_CALL_FUNC(vit_img_head_bias);
-      GEMMA_CALL_FUNC(vit_img_head_kernel);
-
-      if (ptrs[0]->weights_config.wrapping == PromptWrapping::GEMMA_VLM)
-        GEMMA_CALL_FUNC(mm_embed_norm);
-    }
-
-    for (int layer_idx = 0; layer_idx < ptrs[0]->c_layers.size();
-         ++layer_idx) {
-      for (int i = 0; i < ptrs.size(); ++i) {
-        layers[i] = ptrs[i]->GetLayer(layer_idx);
-      }
-      LayerWeightsPtrs<Weight>::ForEachTensor(layers, layer_idx, fet, func);
-    }
-
-    // Vit layers. Not supported for compress_weights.
-    if (ptrs[0]->weights_config.vit_config.layer_configs.size() > 0) {
-      for (int layer_idx = 0; layer_idx < ptrs[0]->vit_layers.size();
-           ++layer_idx) {
-        auto type = ptrs[0]->vit_layers[layer_idx].layer_config.type;
-        HWY_ASSERT(type == LayerAttentionType::kVit);
-        for (int i = 0; i < ptrs.size(); ++i) {
-          vit_layers[i] = ptrs[i]->GetVitLayer(layer_idx);
-        }
-        LayerWeightsPtrs<WeightF32OrBF16>::ForEachTensor(vit_layers, layer_idx,
-                                                         fet, func);
-      }
-    }
-  }
-};
-#undef GEMMA_CALL_FUNC
-
-// ----------------------------------------------------------------------------
-// Interface
-
-class ModelWeightsStorage {
- public:
-  ModelWeightsStorage() = default;
-  ~ModelWeightsStorage() = default;
-
-  // Loads the weights from a blob store file. Supports multi-file or
-  // single-file format. If the weights file contains a TOC, then it is in
-  // single-file format, and model_type, weight_type, wrapping are ignored,
-  // and tokenizer_proto is required and written to.
-  // With a multi-file format, file, model_type, weight_type, wrapping are
-  // required and tokenizer_proto is ignored.
-  BlobError Load(const Path& weights, Model model_type, Type weight_type,
-                 PromptWrapping wrapping, hwy::ThreadPool& pool,
-                 std::string* tokenizer_proto);
-  // Writes the weights to a blob store file, using the single-file format
-  // with a TOC and config included.
-  BlobError Save(const std::string& tokenizer, const Path& weights,
-                 hwy::ThreadPool& pool);
-  void Allocate(Model model_type, Type weight_type, hwy::ThreadPool& pool) {
-    Allocate(ConfigFromModel(model_type), weight_type, pool);
-  }
-  void Allocate(const ModelConfig& config, Type weight_type,
-                hwy::ThreadPool& pool);
-  void RandInit(std::mt19937& gen);
-  void ZeroInit();
-  void GetOrApplyScales(std::vector<float>& scales);
-  void AllocAndCopyWithTranspose(hwy::ThreadPool& pool);
-  void CopyWithTranspose(hwy::ThreadPool& pool);
-  void LogWeightStats();
-  const ModelConfig& Config() const { return config_; }
-
-  template <typename T>
-  ModelWeightsPtrs<T>* GetWeightsOfType() const {
-    if constexpr (IsSfpStream<T>()) {
-      return sfp_weights_.get();
-    } else if constexpr (IsF32<T>()) {
-      return float_weights_.get();
-    } else if constexpr (IsBF16<T>()) {
-      return bf16_weights_.get();
-    } else if constexpr (IsNuqStream<T>()) {
-      return nuq_weights_.get();
-    } else {
-      return HWY_ABORT("Unsupported type.");
-    }
-  }
-
-  template <typename T>
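The deleted `GetWeightsOfType` relies on `if constexpr`, so only the branch matching the requested weight type is ever instantiated and each branch may return a differently typed pointer. A stand-alone sketch of the same dispatch; `SfpTag` and the members are hypothetical stand-ins for the per-type weight holders:

// Sketch of compile-time type dispatch with if constexpr. SfpTag and the
// members are illustrative stand-ins, not gemma.cpp's types.
#include <memory>
#include <type_traits>

struct SfpTag {};

struct WeightsOwner {
  std::unique_ptr<int> sfp_weights_;    // Stand-in for compressed weights.
  std::unique_ptr<float> f32_weights_;  // Stand-in for float weights.

  template <typename T>
  auto* GetWeightsOfType() const {
    if constexpr (std::is_same_v<T, SfpTag>) {
      return sfp_weights_.get();
    } else if constexpr (std::is_same_v<T, float>) {
      return f32_weights_.get();
    } else {
      // Dependent on T, so this only fires for unsupported instantiations.
      static_assert(!sizeof(T), "Unsupported weight type.");
    }
  }
};

int main() {
  WeightsOwner owner;
  owner.f32_weights_ = std::make_unique<float>(1.0f);
  float* w = owner.GetWeightsOfType<float>();  // Picks the float branch.
  return w ? 0 : 1;
}

Because discarded branches are never compiled, each getter returns the exact member type with no virtual dispatch or runtime switch.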