From 6ab3ff5bde71f7e44b5cf6351628d036bb70dd27 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 10 Oct 2024 09:04:19 -0700 Subject: [PATCH] Minor cleanup, Windows+Bazel build fixes add app.h comment compress-inl: remove unused typedef gemma-inl: add missing HWY_ATTR and cast separate sum-inl.h and basics.h headers replace more hwy::bfloat16_t with BF16 update include pragmas update dot_test thresholds update Highway version in Bazel for HWY_RCAST_ALIGNED fix PiperOrigin-RevId: 684464326 --- BUILD.bazel | 174 +++++++++++++------------ CMakeLists.txt | 4 +- MODULE.bazel | 54 +++----- bazel/sentencepiece.bazel | 14 +- compression/BUILD.bazel | 68 +++++----- compression/compress-inl.h | 11 +- compression/nuq-inl.h | 2 +- compression/python/BUILD | 4 +- compression/shared.h | 5 + examples/hello_world/BUILD.bazel | 4 +- gemma/activations.h | 9 +- gemma/common.h | 16 +-- gemma/configs.h | 4 +- gemma/gemma-inl.h | 7 +- gemma/gemma.h | 4 +- gemma/weights.h | 13 +- ops/dot_test.cc | 2 +- ops/matmul.h | 5 +- ops/ops-inl.h | 180 +------------------------ ops/sum-inl.h | 217 +++++++++++++++++++++++++++++++ paligemma/BUILD | 6 +- util/allocator.h | 21 +-- util/app.h | 3 + util/basics.h | 49 +++++++ util/threading.h | 2 +- 25 files changed, 475 insertions(+), 403 deletions(-) create mode 100644 ops/sum-inl.h create mode 100644 util/basics.h diff --git a/BUILD.bazel b/BUILD.bazel index eeac7cd..eef2ff1 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -20,11 +20,19 @@ licenses(["notice"]) exports_files(["LICENSE"]) +cc_library( + name = "basics", + hdrs = ["util/basics.h"], + deps = [ + "@highway//:hwy", + ], +) + cc_library( name = "allocator", hdrs = ["util/allocator.h"], deps = [ - "@hwy//:hwy", + "@highway//:hwy", ], ) @@ -32,9 +40,9 @@ cc_library( name = "test_util", hdrs = ["util/test_util.h"], deps = [ - "@hwy//:hwy", - "@hwy//:hwy_test_util", - "@hwy//:stats", + "@highway//:hwy", + "@highway//:hwy_test_util", + "@highway//:stats", ], ) @@ -42,9 +50,9 @@ cc_library( name = "threading", hdrs = ["util/threading.h"], deps = [ - "@hwy//:hwy", - "@hwy//:thread_pool", - "@hwy//:topology", + "@highway//:hwy", + "@highway//:thread_pool", + "@highway//:topology", ], ) @@ -54,8 +62,8 @@ cc_test( deps = [ ":threading", "@googletest//:gtest_main", - "@hwy//:hwy", - "@hwy//:hwy_test_util", + "@highway//:hwy", + "@highway//:hwy_test_util", ], ) @@ -66,6 +74,7 @@ cc_library( ], textual_hdrs = [ "ops/dot-inl.h", + "ops/sum-inl.h", "ops/fp_arith-inl.h", "ops/matmul-inl.h", "ops/matvec-inl.h", @@ -73,14 +82,15 @@ cc_library( ], deps = [ ":allocator", + ":basics", ":threading", "//compression:compress", - "@hwy//:algo", - "@hwy//:hwy", - "@hwy//:math", - "@hwy//:matvec", - "@hwy//:profiler", - "@hwy//:thread_pool", + "@highway//:algo", + "@highway//:hwy", + "@highway//:math", + "@highway//:matvec", + "@highway//:profiler", + "@highway//:thread_pool", ], ) @@ -102,12 +112,12 @@ cc_test( "@googletest//:gtest_main", # buildcleaner: keep "//compression:compress", "//compression:test_util", - "@hwy//:hwy", - "@hwy//:hwy_test_util", - "@hwy//:nanobenchmark", #buildcleaner: keep - "@hwy//:profiler", - "@hwy//:stats", - "@hwy//:thread_pool", + "@highway//:hwy", + "@highway//:hwy_test_util", + "@highway//:nanobenchmark", #buildcleaner: keep + "@highway//:profiler", + "@highway//:stats", + "@highway//:thread_pool", ], ) @@ -127,9 +137,9 @@ cc_test( ":test_util", "@googletest//:gtest_main", # buildcleaner: keep "//compression:compress", - "@hwy//:hwy", - "@hwy//:hwy_test_util", - "@hwy//:nanobenchmark", #buildcleaner: keep + "@highway//:hwy", + "@highway//:hwy_test_util", + "@highway//:nanobenchmark", #buildcleaner: keep ], ) @@ -145,9 +155,9 @@ cc_test( ":ops", "@googletest//:gtest_main", # buildcleaner: keep "//compression:compress", - "@hwy//:hwy", - "@hwy//:hwy_test_util", - "@hwy//:thread_pool", + "@highway//:hwy", + "@highway//:hwy_test_util", + "@highway//:thread_pool", ], ) @@ -164,10 +174,10 @@ cc_test( ":threading", "@googletest//:gtest_main", # buildcleaner: keep "//compression:compress", - "@hwy//:hwy", - "@hwy//:hwy_test_util", - "@hwy//:nanobenchmark", - "@hwy//:thread_pool", + "@highway//:hwy", + "@highway//:hwy_test_util", + "@highway//:nanobenchmark", + "@highway//:thread_pool", ], ) @@ -180,8 +190,8 @@ cc_library( ], deps = [ "//compression:compress", - "@hwy//:hwy", # base.h - "@hwy//:thread_pool", + "@highway//:hwy", # base.h + "@highway//:thread_pool", ], ) @@ -194,10 +204,10 @@ cc_library( ":common", "//compression:compress", "//compression:io", - "@hwy//:hwy", - "@hwy//:profiler", - "@hwy//:stats", - "@hwy//:thread_pool", + "@highway//:hwy", + "@highway//:profiler", + "@highway//:stats", + "@highway//:thread_pool", ], ) @@ -208,9 +218,9 @@ cc_library( deps = [ ":common", "//compression:io", - "@hwy//:hwy", - "@hwy//:nanobenchmark", # timer - "@hwy//:profiler", + "@highway//:hwy", + "@highway//:nanobenchmark", # timer + "@highway//:profiler", "@com_google_sentencepiece//:sentencepiece_processor", ], ) @@ -221,7 +231,7 @@ cc_library( hdrs = ["gemma/kv_cache.h"], deps = [ ":common", - "@hwy//:hwy", + "@highway//:hwy", ], ) @@ -268,6 +278,7 @@ cc_library( ], deps = [ ":allocator", + ":basics", ":common", ":ops", ":tokenizer", @@ -275,12 +286,13 @@ cc_library( ":weights", ":threading", "//compression:io", + "//compression:sfp", "//paligemma:image", - "@hwy//:hwy", - "@hwy//:bit_set", - "@hwy//:nanobenchmark", # timer - "@hwy//:profiler", - "@hwy//:thread_pool", + "@highway//:hwy", + "@highway//:bit_set", + "@highway//:nanobenchmark", # timer + "@highway//:profiler", + "@highway//:thread_pool", ], ) @@ -292,7 +304,7 @@ cc_library( ":common", ":gemma_lib", ":ops", - "@hwy//:hwy", + "@highway//:hwy", ], ) @@ -301,7 +313,7 @@ cc_library( hdrs = ["util/args.h"], deps = [ "//compression:io", - "@hwy//:hwy", + "@highway//:hwy", ], ) @@ -314,9 +326,9 @@ cc_library( ":gemma_lib", ":threading", "//compression:io", - "@hwy//:hwy", - "@hwy//:thread_pool", - "@hwy//:topology", + "@highway//:hwy", + "@highway//:thread_pool", + "@highway//:topology", ], ) @@ -333,12 +345,12 @@ cc_library( ":kv_cache", ":threading", # Placeholder for internal dep, do not remove., - "@benchmark//:benchmark", + "@google_benchmark//:benchmark", "//compression:compress", - "@hwy//:hwy", - "@hwy//:nanobenchmark", - "@hwy//:thread_pool", - "@hwy//:topology", + "@highway//:hwy", + "@highway//:nanobenchmark", + "@highway//:thread_pool", + "@highway//:topology", ], ) @@ -357,8 +369,8 @@ cc_test( ":gemma_lib", ":tokenizer", "@googletest//:gtest_main", - "@hwy//:hwy", - "@hwy//:hwy_test_util", + "@highway//:hwy", + "@highway//:hwy_test_util", ], ) @@ -374,9 +386,9 @@ cc_binary( ":threading", # Placeholder for internal dep, do not remove., "//paligemma:image", - "@hwy//:hwy", - "@hwy//:profiler", - "@hwy//:thread_pool", + "@highway//:hwy", + "@highway//:profiler", + "@highway//:thread_pool", ], ) @@ -391,9 +403,9 @@ cc_binary( ":cross_entropy", ":gemma_lib", "//compression:io", - "@hwy//:hwy", - "@hwy//:nanobenchmark", - "@hwy//:thread_pool", + "@highway//:hwy", + "@highway//:nanobenchmark", + "@highway//:thread_pool", "@nlohmann_json//:json", ], ) @@ -403,7 +415,7 @@ cc_binary( srcs = ["evals/benchmarks.cc"], deps = [ ":benchmark_helper", - "@benchmark//:benchmark", + "@google_benchmark//:benchmark", ], ) @@ -418,8 +430,8 @@ cc_binary( ":benchmark_helper", ":gemma_lib", "//compression:io", - "@hwy//:hwy", - "@hwy//:thread_pool", + "@highway//:hwy", + "@highway//:thread_pool", "@nlohmann_json//:json", ], ) @@ -433,9 +445,9 @@ cc_binary( ":benchmark_helper", ":gemma_lib", "//compression:io", - "@hwy//:hwy", - "@hwy//:profiler", - "@hwy//:thread_pool", + "@highway//:hwy", + "@highway//:profiler", + "@highway//:thread_pool", "@nlohmann_json//:json", ], ) @@ -477,9 +489,9 @@ cc_library( ":prompt", ":weights", "//compression:compress", - "@hwy//:dot", - "@hwy//:hwy", # base.h - "@hwy//:thread_pool", + "@highway//:dot", + "@highway//:hwy", # base.h + "@highway//:thread_pool", ], ) @@ -497,7 +509,7 @@ cc_library( ":prompt", ":weights", "//compression:compress", - "@hwy//:hwy", + "@highway//:hwy", ], ) @@ -517,7 +529,7 @@ cc_test( ":weights", "@googletest//:gtest_main", "//compression:compress", - "@hwy//:thread_pool", + "@highway//:thread_pool", ], ) @@ -544,9 +556,9 @@ cc_test( ":weights", "@googletest//:gtest_main", "//compression:compress", - "@hwy//:hwy", - "@hwy//:hwy_test_util", - "@hwy//:thread_pool", + "@highway//:hwy", + "@highway//:hwy_test_util", + "@highway//:thread_pool", ], ) @@ -559,8 +571,8 @@ cc_library( ":common", ":weights", "//compression:compress", - "@hwy//:hwy", - "@hwy//:thread_pool", + "@highway//:hwy", + "@highway//:thread_pool", ], ) @@ -583,6 +595,6 @@ cc_test( ":threading", ":weights", "@googletest//:gtest_main", - "@hwy//:thread_pool", + "@highway//:thread_pool", ], ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 51e8891..51ab2e4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG bb6c3f36b0c8dde8a8ef98b0f0884f4de820a7ca EXCLUDE_FROM_ALL) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 0ca297227a373710e76dd45e0ad4d68adb6928fe EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(highway) ## Note: absl needs to be installed by sentencepiece. This will only happen if @@ -106,11 +106,13 @@ set(SOURCES ops/matmul-inl.h ops/matvec-inl.h ops/ops-inl.h + ops/sum-inl.h paligemma/image.cc paligemma/image.h util/allocator.h util/app.h util/args.h + util/basics.h util/test_util.h util/threading.h ) diff --git a/MODULE.bazel b/MODULE.bazel index 43b33a5..58faa0d 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -3,37 +3,33 @@ module( version = "0.1.0", ) -bazel_dep(name = "rules_license", version = "0.0.7") -bazel_dep(name = "googletest", version = "1.14.0") - -# Copied from Highway because Bazel does not load them transitively -bazel_dep(name = "bazel_skylib", version = "1.4.1") +bazel_dep(name = "abseil-cpp", version = "20240722.0") +bazel_dep(name = "bazel_skylib", version = "1.6.1") +bazel_dep(name = "googletest", version = "1.15.2") +bazel_dep(name = "highway", version = "1.1.0") +bazel_dep(name = "nlohmann_json", version = "3.11.3") +bazel_dep(name = "platforms", version = "0.0.10") bazel_dep(name = "rules_cc", version = "0.0.9") -bazel_dep(name = "platforms", version = "0.0.7") +bazel_dep(name = "rules_license", version = "0.0.7") +bazel_dep(name = "google_benchmark", version = "1.8.5") + +# Require a more recent version. +git_override( + module_name = "highway", + commit = "0ca297227a373710e76dd45e0ad4d68adb6928fe", + remote = "https://github.com/google/highway", +) http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -http_archive( - name = "hwy", - urls = ["https://github.com/google/highway/archive/refs/tags/1.2.0.zip"], - integrity = "sha256-fbtKAGj5hhhBr5Bggtsrj4aIodC2OHb1njB8LGfom8A=", strip_prefix = "highway-1.2.0", -) - -http_archive( - name = "nlohmann_json", - urls = ["https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.zip"], - integrity = "sha256-BAIrBdgG61/3MCPCgLaGl9Erk+G3JnoLIqGjnsdXgGk=", - strip_prefix = "json-3.11.3", -) - http_archive( name = "com_google_sentencepiece", + build_file = "@//bazel:sentencepiece.bazel", + patch_args = ["-p1"], + patches = ["@//bazel:sentencepiece.patch"], sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754", strip_prefix = "sentencepiece-0.1.96", urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"], - build_file = "@//bazel:sentencepiece.bazel", - patches = ["@//bazel:sentencepiece.patch"], - patch_args = ["-p1"], ) # For sentencepiece @@ -56,17 +52,3 @@ cc_library( "https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip", ], ) -# ABSL on 2023-10-18 -http_archive( - name = "com_google_absl", - sha256 = "f841f78243f179326f2a80b719f2887c38fe226d288ecdc46e2aa091e6aa43bc", - strip_prefix = "abseil-cpp-9687a8ea750bfcddf790372093245a1d041b21a3", - urls = ["https://github.com/abseil/abseil-cpp/archive//9687a8ea750bfcddf790372093245a1d041b21a3.tar.gz"], -) -# Benchmark -http_archive( - name = "benchmark", - urls = ["https://github.com/google/benchmark/archive/refs/tags/v1.8.2.tar.gz"], - integrity = "sha256-KqspgNA3YTf5adkoSPu2gharsHYzA0U0/IxlzE56DpM=", - strip_prefix = "benchmark-1.8.2", -) diff --git a/bazel/sentencepiece.bazel b/bazel/sentencepiece.bazel index a08e76e..ab72887 100644 --- a/bazel/sentencepiece.bazel +++ b/bazel/sentencepiece.bazel @@ -42,7 +42,7 @@ cc_library( "src/common.h", ], deps = [ - "@com_google_absl//absl/base", + "@abseil-cpp//absl/base", ], ) @@ -86,12 +86,12 @@ cc_library( ":common", ":sentencepiece_cc_proto", ":sentencepiece_model_cc_proto", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/memory", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/strings:str_format", "@darts_clone", ], ) diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel index 32af763..2e6b293 100644 --- a/compression/BUILD.bazel +++ b/compression/BUILD.bazel @@ -33,7 +33,7 @@ cc_library( ], hdrs = ["io.h"], deps = [ - "@hwy//:hwy", + "@highway//:hwy", ] + FILE_DEPS, ) @@ -43,8 +43,8 @@ cc_library( hdrs = ["blob_store.h"], deps = [ ":io", - "@hwy//:hwy", - "@hwy//:thread_pool", + "@highway//:hwy", + "@highway//:thread_pool", ], ) @@ -55,9 +55,9 @@ cc_library( "shared.h", ], deps = [ - "@hwy//:hwy", - "@hwy//:stats", - "@hwy//hwy/contrib/sort:vqsort", + "@highway//:hwy", + "@highway//:stats", + "@highway//hwy/contrib/sort:vqsort", ], ) @@ -69,8 +69,8 @@ cc_test( ":distortion", "@googletest//:gtest_main", # buildcleaner: keep "//:test_util", - "@hwy//:hwy_test_util", - "@hwy//:nanobenchmark", # Unpredictable1 + "@highway//:hwy_test_util", + "@highway//:nanobenchmark", # Unpredictable1 ], ) @@ -79,7 +79,7 @@ cc_library( hdrs = ["shared.h"], textual_hdrs = ["sfp-inl.h"], deps = [ - "@hwy//:hwy", + "@highway//:hwy", ], ) @@ -89,9 +89,9 @@ cc_library( textual_hdrs = ["nuq-inl.h"], deps = [ ":sfp", - "//:allocator", - "@hwy//:hwy", - "@hwy//hwy/contrib/sort:vqsort", + "//:basics", + "@highway//:hwy", + "@highway//hwy/contrib/sort:vqsort", ], ) @@ -103,8 +103,8 @@ cc_library( deps = [ ":compress", ":distortion", - "@hwy//:hwy", - "@hwy//:hwy_test_util", + "@highway//:hwy", + "@highway//:hwy_test_util", ], ) @@ -122,9 +122,9 @@ cc_test( ":sfp", "@googletest//:gtest_main", # buildcleaner: keep "//:test_util", - "@hwy//:hwy", - "@hwy//:hwy_test_util", - "@hwy//:nanobenchmark", + "@highway//:hwy", + "@highway//:hwy_test_util", + "@highway//:nanobenchmark", ], ) @@ -144,9 +144,9 @@ cc_test( ":sfp", "@googletest//:gtest_main", # buildcleaner: keep "//:test_util", - "@hwy//:hwy", - "@hwy//:hwy_test_util", - "@hwy//:nanobenchmark", + "@highway//:hwy", + "@highway//:hwy_test_util", + "@highway//:nanobenchmark", ], ) @@ -164,11 +164,11 @@ cc_library( ":io", ":nuq", ":sfp", - "@hwy//:hwy", - "@hwy//:nanobenchmark", - "@hwy//:profiler", - "@hwy//:stats", - "@hwy//:thread_pool", + "@highway//:hwy", + "@highway//:nanobenchmark", + "@highway//:profiler", + "@highway//:stats", + "@highway//:thread_pool", ], ) @@ -188,9 +188,9 @@ cc_test( ":test_util", "@googletest//:gtest_main", # buildcleaner: keep "//:test_util", - "@hwy//:hwy", - "@hwy//:hwy_test_util", - "@hwy//:thread_pool", + "@highway//:hwy", + "@highway//:hwy_test_util", + "@highway//:thread_pool", ], ) @@ -201,10 +201,10 @@ cc_library( deps = [ ":nuq", ":sfp", - "@hwy//:hwy", - "@hwy//:stats", - "@hwy//:thread_pool", - "@hwy//hwy/contrib/sort:vqsort", + "@highway//:hwy", + "@highway//:stats", + "@highway//:thread_pool", + "@highway//hwy/contrib/sort:vqsort", ], ) @@ -218,7 +218,7 @@ cc_binary( "//:args", "//:common", "//:weights", - "@hwy//:hwy", - "@hwy//:thread_pool", + "@highway//:hwy", + "@highway//:thread_pool", ], ) diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 79f8b40..ef30033 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -24,7 +24,7 @@ #include // lroundf, only if COMPRESS_STATS #include "compression/blob_store.h" -#include "compression/compress.h" +#include "compression/compress.h" // IWYU pragma: export #include "compression/distortion.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" @@ -494,11 +494,6 @@ void Compress2(DF df, VF raw0, VF raw1, const PackedSpan& packed, Traits::Store2(df, raw0, raw1, packed, packed_ofs); } -template -constexpr bool IsF32() { - return hwy::IsSame, float>(); -} - namespace detail { // Compile-time-only check that `DRaw` and `Packed` are compatible. This makes @@ -678,8 +673,8 @@ class Compressor { template void operator()(MatPtrT* compressed, const char* decorated_name, const float* HWY_RESTRICT weights) { - int num_weights = compressed->NumElements(); - int num_compressed = compressed->NumElements(); + size_t num_weights = compressed->NumElements(); + size_t num_compressed = compressed->NumElements(); PackedSpan packed = MakeSpan(compressed->data(), num_compressed); fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name, num_weights / (1000 * 1000)); diff --git a/compression/nuq-inl.h b/compression/nuq-inl.h index faa5ba7..63c4255 100644 --- a/compression/nuq-inl.h +++ b/compression/nuq-inl.h @@ -22,7 +22,7 @@ #include #include "compression/shared.h" -#include "util/allocator.h" +#include "util/basics.h" #include "hwy/base.h" #endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_ diff --git a/compression/python/BUILD b/compression/python/BUILD index bdf67f9..89e2222 100644 --- a/compression/python/BUILD +++ b/compression/python/BUILD @@ -16,8 +16,8 @@ cc_library( "//third_party/absl/types:span", "//compression:compress", "//compression:io", - "@hwy//:hwy", - "@hwy//:thread_pool", + "@highway//:hwy", + "@highway//:thread_pool", ], ) diff --git a/compression/shared.h b/compression/shared.h index b79a067..c216d24 100644 --- a/compression/shared.h +++ b/compression/shared.h @@ -32,6 +32,11 @@ namespace gcpp { using BF16 = hwy::bfloat16_t; +template +constexpr bool IsF32() { + return hwy::IsSame, float>(); +} + // Switching Floating Point: a hybrid 8-bit float representation of bf16/f32 // inputs that combines the advantages of e4m3 and e5m2 into a single format. // It supports seeking at a granularity of 1 and decoding to bf16/f32. diff --git a/examples/hello_world/BUILD.bazel b/examples/hello_world/BUILD.bazel index ca7f426..52af610 100644 --- a/examples/hello_world/BUILD.bazel +++ b/examples/hello_world/BUILD.bazel @@ -17,7 +17,7 @@ cc_binary( "//:gemma_lib", "//:threading", "//:tokenizer", - "@hwy//:hwy", - "@hwy//:thread_pool", + "@highway//:hwy", + "@highway//:thread_pool", ], ) diff --git a/gemma/activations.h b/gemma/activations.h index d819d4a..b10b562 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -20,8 +20,9 @@ #include -#include "ops/matmul.h" // MatMulEnv -#include "util/allocator.h" // RowVectorBatch +#include "compression/shared.h" // BF16 +#include "ops/matmul.h" // MatMulEnv +#include "util/allocator.h" // RowVectorBatch #include "util/threading.h" #include "hwy/base.h" // HWY_DASSERT #include "hwy/contrib/thread_pool/thread_pool.h" @@ -41,7 +42,7 @@ struct Activations { RowVectorBatch att_sums; // Gated FFW - RowVectorBatch bf_pre_ffw_rms_out; + RowVectorBatch bf_pre_ffw_rms_out; RowVectorBatch C1; RowVectorBatch C2; RowVectorBatch ffw_out; @@ -106,7 +107,7 @@ struct Activations { att_out = RowVectorBatch(batch_size, kHeads * kQKVDim); att_sums = RowVectorBatch(batch_size, kModelDim); - bf_pre_ffw_rms_out = RowVectorBatch(batch_size, kModelDim); + bf_pre_ffw_rms_out = RowVectorBatch(batch_size, kModelDim); C1 = RowVectorBatch(batch_size, kFFHiddenDim); C2 = RowVectorBatch(batch_size, kFFHiddenDim); ffw_out = RowVectorBatch(batch_size, kModelDim); diff --git a/gemma/common.h b/gemma/common.h index aa5bc52..18ac5d1 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -118,8 +118,8 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight, return CallForModel( // model, std::forward(args)...); case Type::kBF16: - return CallForModel( - model, std::forward(args)...); + return CallForModel(model, + std::forward(args)...); case Type::kSFP: return CallForModel( model, std::forward(args)...); @@ -130,7 +130,7 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight, #define GEMMA_FOREACH_WEIGHT(X, CONFIGT) \ X(CONFIGT, float) \ - X(CONFIGT, hwy::bfloat16_t) \ + X(CONFIGT, BF16) \ X(CONFIGT, SfpStream) #define GEMMA_FOREACH_CONFIG_AND_WEIGHT(X) \ @@ -205,7 +205,7 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight, GEMMA_DISPATCH_MODEL(MODEL, float, FUNC, ARGS); \ break; \ case Type::kBF16: \ - GEMMA_DISPATCH_MODEL(MODEL, hwy::bfloat16_t, FUNC, ARGS); \ + GEMMA_DISPATCH_MODEL(MODEL, BF16, FUNC, ARGS); \ break; \ case Type::kSFP: \ GEMMA_DISPATCH_MODEL(MODEL, SfpStream, FUNC, ARGS); \ @@ -239,15 +239,15 @@ static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); } template GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() { // Round to bf16 to match Gemma's Embedder, which casts before mul. - return hwy::ConvertScalarTo(hwy::ConvertScalarTo( - Sqrt(static_cast(TConfig::kModelDim)))); + return hwy::ConvertScalarTo( + hwy::ConvertScalarTo(Sqrt(static_cast(TConfig::kModelDim)))); } static HWY_INLINE GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling( size_t model_dim) { // Round to bf16 to match Gemma's Embedder, which casts before mul. - return hwy::ConvertScalarTo(hwy::ConvertScalarTo( - Sqrt(static_cast(model_dim)))); + return hwy::ConvertScalarTo( + hwy::ConvertScalarTo(Sqrt(static_cast(model_dim)))); } template diff --git a/gemma/configs.h b/gemma/configs.h index 51df334..7c1ce88 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -22,7 +22,7 @@ #include -#include "hwy/base.h" // hwy::bfloat16_t +#include "compression/shared.h" // BF16 namespace gcpp { @@ -40,7 +40,7 @@ static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN; static constexpr size_t kTopK = GEMMA_TOPK; static constexpr size_t kVocabSize = 256000; -using EmbedderInputT = hwy::bfloat16_t; +using EmbedderInputT = BF16; enum class LayerAttentionType { kGemma, diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 9b9a0c4..9e38490 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -763,7 +763,7 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos, EmbeddingScaling(); HWY_DASSERT(token >= 0); - HWY_DASSERT(token < kVocabSize); + HWY_DASSERT(token < static_cast(kVocabSize)); const hn::ScalableTag df; DecompressAndZeroPad( @@ -1193,14 +1193,15 @@ SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) { // Fast path for top-1 with no accept_token. if (kTopK == 1 && !runtime_config.accept_token) { - return [](float* logits, size_t vocab_size) -> TokenAndProb { + return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb { PROFILER_ZONE("Gen.Sample Top1"); return Top1OfSoftmax(logits, vocab_size); }; } // General case: Softmax with top-k sampling. - return [&runtime_config](float* logits, size_t vocab_size) -> TokenAndProb { + return [&runtime_config](float* logits, + size_t vocab_size) HWY_ATTR -> TokenAndProb { PROFILER_ZONE("Gen.Sample general"); Softmax(logits, vocab_size); const int token = SampleTopK(logits, vocab_size, *runtime_config.gen, diff --git a/gemma/gemma.h b/gemma/gemma.h index ea25281..654871d 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -28,13 +28,13 @@ #include "gemma/kv_cache.h" #include "gemma/tokenizer.h" #include "paligemma/image.h" -#include "util/allocator.h" +#include "util/allocator.h" // RowVectorBatch +#include "util/basics.h" // TokenAndProb #include "util/threading.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/timer.h" // IWYU pragma: end_exports #include "hwy/aligned_allocator.h" // Span -#include "hwy/base.h" // hwy::bfloat16_t namespace gcpp { using PromptTokens = hwy::Span; diff --git a/gemma/weights.h b/gemma/weights.h index 0c97253..e8655ed 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -108,10 +108,10 @@ struct CompressedLayer { // do not yet support smaller compressed types, or require at least bf16. When // weights are f32, we also want such tensors to be f32. // If weights are complex, this is also complex. - using WeightF32OrBF16 = hwy::If< - hwy::IsSame>(), std::complex, - hwy::If(), double, - hwy::If(), float, hwy::bfloat16_t>>>; + using WeightF32OrBF16 = + hwy::If>(), std::complex, + hwy::If(), double, + hwy::If(), float, BF16>>>; static constexpr size_t kHeads = TConfig::kHeads; static constexpr size_t kKVHeads = TConfig::kKVHeads; @@ -363,9 +363,8 @@ struct CompressedWeights { using Weight = typename TConfig::Weight; using WeightF32OrBF16 = typename CompressedLayer::WeightF32OrBF16; - using WeightF32OrInputT = - hwy::If(), EmbedderInputT, - WeightF32OrBF16>; + using WeightF32OrInputT = hwy::If(), + EmbedderInputT, WeightF32OrBF16>; MatPtrT embedder_input_embedding; MatPtrT final_norm_scale; diff --git a/ops/dot_test.cc b/ops/dot_test.cc index 4a29540..6c9f8c9 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -840,7 +840,7 @@ class DotStats { ASSERT_INSIDE(kPairwise, 4.5E-4, s_rels[kPairwise].GeometricMean(), 1.5E-2); // Extremely high error on aarch64. - ASSERT_INSIDE(kPairwise, 1.1E-3f, s_rels[kPairwise].Max(), 1250.f); + ASSERT_INSIDE(kPairwise, 1.1E-3f, s_rels[kPairwise].Max(), 2E3f); } // Backward relative error, lower is better. diff --git a/ops/matmul.h b/ops/matmul.h index 4ef63bc..34851f5 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -19,9 +19,10 @@ #include #include "util/allocator.h" // RowVectorBatch -#include "util/threading.h" // PerClusterPools +#include "util/threading.h" +#include "hwy/aligned_allocator.h" // IWYU pragma: export #include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/contrib/thread_pool/thread_pool.h" // IWYU pragma: export #include "hwy/per_target.h" namespace gcpp { diff --git a/ops/ops-inl.h b/ops/ops-inl.h index f03159a..79f77bb 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -28,7 +28,7 @@ #include // std::enable_if_t #include "compression/compress.h" -#include "util/allocator.h" // TokenAndProb +#include "util/basics.h" // TokenAndProb #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/detect_targets.h" @@ -44,6 +44,7 @@ #include "compression/compress-inl.h" #include "ops/dot-inl.h" +#include "ops/sum-inl.h" #include "hwy/contrib/algo/transform-inl.h" #include "hwy/contrib/math/math-inl.h" #include "hwy/profiler.h" // also uses SIMD @@ -507,183 +508,6 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( MulByConstAndAdd(c, x, out, size, size); } -// f64 Add, called for f32 inputs promoted to f64. Runs at about half the speed -// of f32 sums. -struct SumKernelDouble { - // Only `CompressTraits` can `Decompress2` to `double`, so both have - // to be `float` in order to have `Raw = double`. Note that if either type is - // smaller than `float`, we may demote the other type from `float` to `BF16`. - template - using Raw = hwy::If() && IsF32(), double, BF16>; - using State = double; - - // Raw = double - template , HWY_IF_F64_D(DRaw)> - HWY_INLINE void Update4(DRaw /*dd*/, const VR w0, const VR w1, const VR w2, - const VR w3, VR, VR, VR, VR, VR& sum0, VR& sum1, - VR& sum2, VR& sum3, VR&, VR&, VR&, VR&) const { - sum0 = hn::Add(sum0, w0); - sum1 = hn::Add(sum1, w1); - sum2 = hn::Add(sum2, w2); - sum3 = hn::Add(sum3, w3); - } - - // Raw = BF16 - template , HWY_IF_BF16_D(DRaw), - class DS = hn::Repartition, class VS = hn::Vec> - HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2, - const VR w3, VR, VR, VR, VR, VS& sum0, VS& sum1, - VS& sum2, VS& sum3, VS&, VS&, VS&, VS&) const { - const hn::Repartition df; - using VF = hn::Vec; - // Reduce to two f32 sums so we can promote them to four f64 vectors. - VF sum02, sum13; - if constexpr (HWY_NATIVE_DOT_BF16) { - const VR k1 = hn::Set(dr, hwy::ConvertScalarTo(1.0f)); - const VF prod0 = hn::WidenMulPairwiseAdd(df, w0, k1); - const VF prod1 = hn::WidenMulPairwiseAdd(df, w1, k1); - // Fuse WidenMulPairwiseAdd plus Add into ReorderWidenMulAccumulate. - VF unused0 = hn::Zero(df); - VF unused1 = hn::Zero(df); - sum02 = hn::ReorderWidenMulAccumulate(df, w2, k1, prod0, unused0); - sum13 = hn::ReorderWidenMulAccumulate(df, w3, k1, prod1, unused1); - } else { - // If not native, the multiplication costs extra, so convert to f32. - // PromoteEvenTo is cheaper than PromoteUpperTo especially on `SVE`. - const VF fe0 = hn::PromoteEvenTo(df, w0); - const VF fe1 = hn::PromoteEvenTo(df, w1); - const VF fe2 = hn::PromoteEvenTo(df, w2); - const VF fe3 = hn::PromoteEvenTo(df, w3); - const VF fo0 = hn::PromoteOddTo(df, w0); - const VF fo1 = hn::PromoteOddTo(df, w1); - const VF fo2 = hn::PromoteOddTo(df, w2); - const VF fo3 = hn::PromoteOddTo(df, w3); - const VF fe01 = hn::Add(fe0, fe1); - const VF fe23 = hn::Add(fe2, fe3); - const VF fo01 = hn::Add(fo0, fo1); - const VF fo23 = hn::Add(fo2, fo3); - sum02 = hn::Add(fe01, fe23); - sum13 = hn::Add(fo01, fo23); - } - - const DS ds; - const VS d0 = hn::PromoteLowerTo(ds, sum02); - const VS d1 = hn::PromoteUpperTo(ds, sum02); - const VS d2 = hn::PromoteLowerTo(ds, sum13); - const VS d3 = hn::PromoteUpperTo(ds, sum13); - - sum0 = hn::Add(sum0, d0); - sum1 = hn::Add(sum1, d1); - sum2 = hn::Add(sum2, d2); - sum3 = hn::Add(sum3, d3); - } - - // Raw = double - template , HWY_IF_F64_D(DRaw)> - HWY_INLINE void Update1(DRaw /*dd*/, const VR w0, const VR v0, VR& sum0, - VR& comp0) const { - sum0 = hn::Add(sum0, w0); - } - - // Raw = BF16 - template , HWY_IF_BF16_D(DRaw), - class DS = hn::Repartition, class VS = hn::Vec> - HWY_INLINE void Update1(DRaw dr, const VR w0, VR, VS& sum0, - VS& extra0) const { - const hn::Repartition df; - using VF = hn::Vec; - VF f0; - if constexpr (HWY_NATIVE_DOT_BF16) { - const VR k1 = hn::Set(dr, hwy::ConvertScalarTo(1.0f)); - f0 = hn::WidenMulPairwiseAdd(df, w0, k1); - } else { - const VF fe0 = hn::PromoteEvenTo(df, w0); - const VF fo0 = hn::PromoteOddTo(df, w0); - f0 = hn::Add(fe0, fo0); - } - - const DS ds; - const VS d0 = hn::PromoteLowerTo(ds, f0); - const VS d1 = hn::PromoteUpperTo(ds, f0); - - sum0 = hn::Add(sum0, d0); - extra0 = hn::Add(extra0, d1); - } - - template > - HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3, - VS& extra0, VS&, VS&, VS&) const { - // Reduction tree: sum of all accumulators by pairs, then across lanes. - sum0 = hn::Add(sum0, sum1); - sum2 = hn::Add(sum2, sum3); - sum0 = hn::Add(sum0, extra0); // from Update1 - sum0 = hn::Add(sum0, sum2); - return static_cast(hn::ReduceSum(dd, sum0)); - } -}; - -// ORO Cascaded Summation, algorithm 6.11 from Handbook of Floating-Point -// Arithmetic. Note that Algorithm 6.7 (KBN) appears erroneous. We use TwoSums -// instead of FastTwoSums because the magnitude of the initial sum is not -// always greater than the next input, and this does actually change the e2e -// generation results. Note that Kahan summation differs in that it first adds -// comp* to w*, so each operation is serially dependent. By contrast, the sum* -// and comp* here have shorter dependency chains. -// -// This about as accurate as SumKernelDouble but slower, hence we only use this -// if f64 is not supported on this target. -struct SumKernelCascaded { - template - using Raw = float; - using State = float; - - template , HWY_IF_F32_D(DF)> - HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2, - const VF w3, VF, VF, VF, VF, VF& sum0, VF& sum1, - VF& sum2, VF& sum3, VF& comp0, VF& comp1, VF& comp2, - VF& comp3) const { - VF serr0, serr1, serr2, serr3; - sum0 = TwoSums(df, sum0, w0, serr0); - sum1 = TwoSums(df, sum1, w1, serr1); - sum2 = TwoSums(df, sum2, w2, serr2); - sum3 = TwoSums(df, sum3, w3, serr3); - - comp0 = hn::Add(comp0, serr0); - comp1 = hn::Add(comp1, serr1); - comp2 = hn::Add(comp2, serr2); - comp3 = hn::Add(comp3, serr3); - } - - template , HWY_IF_F32_D(DF)> - HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0, - VF& comp0) const { - VF serr0; - sum0 = TwoSums(df, sum0, w0, serr0); - - comp0 = hn::Add(comp0, serr0); - } - - template > - HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3, - VF& comp0, VF& comp1, VF& comp2, VF& comp3) const { - // Reduction tree: sum of all accumulators by pairs, then across lanes. - AssimilateCascadedSums(df, sum1, comp1, sum0, comp0); - AssimilateCascadedSums(df, sum3, comp3, sum2, comp2); - AssimilateCascadedSums(df, sum2, comp2, sum0, comp0); - return ReduceCascadedSums(df, sum0, comp0); - } -}; - -using SumKernelDefault = - hwy::If; - -template -HWY_INLINE float Sum(D d, const VT* HWY_RESTRICT vec, size_t num) { - using Raw = hwy::If; - const hn::Repartition d_raw; - return DecompressAndCall(d_raw, MakeSpan(vec, num), SumKernelDefault()); -} - // See below for a specialized version for top-1 sampling. static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, const size_t mask_pos) { diff --git a/ops/sum-inl.h b/ops/sum-inl.h new file mode 100644 index 0000000..3f5d1de --- /dev/null +++ b/ops/sum-inl.h @@ -0,0 +1,217 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "hwy/base.h" + +// Include guard for SIMD code. +#if defined(THIRD_PARTY_GEMMA_CPP_SUM_TOGGLE) == defined(HWY_TARGET_TOGGLE) +#ifdef THIRD_PARTY_GEMMA_CPP_SUM_TOGGLE +#undef THIRD_PARTY_GEMMA_CPP_SUM_TOGGLE +#else +#define THIRD_PARTY_GEMMA_CPP_SUM_TOGGLE +#endif + +#include "compression/compress-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +// f64 Add, called for f32 inputs promoted to f64. Runs at about half the speed +// of f32 sums. +struct SumKernelDouble { + // Only `CompressTraits` can `Decompress2` to `double`, so both have + // to be `float` in order to have `Raw = double`. Note that if either type is + // smaller than `float`, we may demote the other type from `float` to `BF16`. + template + using Raw = hwy::If() && IsF32(), double, BF16>; + using State = double; + + // Raw = double + template , HWY_IF_F64_D(DRaw)> + HWY_INLINE void Update4(DRaw /*dd*/, const VR w0, const VR w1, const VR w2, + const VR w3, VR, VR, VR, VR, VR& sum0, VR& sum1, + VR& sum2, VR& sum3, VR&, VR&, VR&, VR&) const { + sum0 = hn::Add(sum0, w0); + sum1 = hn::Add(sum1, w1); + sum2 = hn::Add(sum2, w2); + sum3 = hn::Add(sum3, w3); + } + + // Raw = BF16 + template , HWY_IF_BF16_D(DRaw), + class DS = hn::Repartition, class VS = hn::Vec> + HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2, + const VR w3, VR, VR, VR, VR, VS& sum0, VS& sum1, + VS& sum2, VS& sum3, VS&, VS&, VS&, VS&) const { + const hn::Repartition df; + using VF = hn::Vec; + // Reduce to two f32 sums so we can promote them to four f64 vectors. + VF sum02, sum13; + if constexpr (HWY_NATIVE_DOT_BF16) { + const VR k1 = hn::Set(dr, hwy::ConvertScalarTo(1.0f)); + const VF prod0 = hn::WidenMulPairwiseAdd(df, w0, k1); + const VF prod1 = hn::WidenMulPairwiseAdd(df, w1, k1); + // Fuse WidenMulPairwiseAdd plus Add into ReorderWidenMulAccumulate. + VF unused0 = hn::Zero(df); + VF unused1 = hn::Zero(df); + sum02 = hn::ReorderWidenMulAccumulate(df, w2, k1, prod0, unused0); + sum13 = hn::ReorderWidenMulAccumulate(df, w3, k1, prod1, unused1); + } else { + // If not native, the multiplication costs extra, so convert to f32. + // PromoteEvenTo is cheaper than PromoteUpperTo especially on `SVE`. + const VF fe0 = hn::PromoteEvenTo(df, w0); + const VF fe1 = hn::PromoteEvenTo(df, w1); + const VF fe2 = hn::PromoteEvenTo(df, w2); + const VF fe3 = hn::PromoteEvenTo(df, w3); + const VF fo0 = hn::PromoteOddTo(df, w0); + const VF fo1 = hn::PromoteOddTo(df, w1); + const VF fo2 = hn::PromoteOddTo(df, w2); + const VF fo3 = hn::PromoteOddTo(df, w3); + const VF fe01 = hn::Add(fe0, fe1); + const VF fe23 = hn::Add(fe2, fe3); + const VF fo01 = hn::Add(fo0, fo1); + const VF fo23 = hn::Add(fo2, fo3); + sum02 = hn::Add(fe01, fe23); + sum13 = hn::Add(fo01, fo23); + } + + const DS ds; + const VS d0 = hn::PromoteLowerTo(ds, sum02); + const VS d1 = hn::PromoteUpperTo(ds, sum02); + const VS d2 = hn::PromoteLowerTo(ds, sum13); + const VS d3 = hn::PromoteUpperTo(ds, sum13); + + sum0 = hn::Add(sum0, d0); + sum1 = hn::Add(sum1, d1); + sum2 = hn::Add(sum2, d2); + sum3 = hn::Add(sum3, d3); + } + + // Raw = double + template , HWY_IF_F64_D(DRaw)> + HWY_INLINE void Update1(DRaw /*dd*/, const VR w0, const VR v0, VR& sum0, + VR& comp0) const { + sum0 = hn::Add(sum0, w0); + } + + // Raw = BF16 + template , HWY_IF_BF16_D(DRaw), + class DS = hn::Repartition, class VS = hn::Vec> + HWY_INLINE void Update1(DRaw dr, const VR w0, VR, VS& sum0, + VS& extra0) const { + const hn::Repartition df; + using VF = hn::Vec; + VF f0; + if constexpr (HWY_NATIVE_DOT_BF16) { + const VR k1 = hn::Set(dr, hwy::ConvertScalarTo(1.0f)); + f0 = hn::WidenMulPairwiseAdd(df, w0, k1); + } else { + const VF fe0 = hn::PromoteEvenTo(df, w0); + const VF fo0 = hn::PromoteOddTo(df, w0); + f0 = hn::Add(fe0, fo0); + } + + const DS ds; + const VS d0 = hn::PromoteLowerTo(ds, f0); + const VS d1 = hn::PromoteUpperTo(ds, f0); + + sum0 = hn::Add(sum0, d0); + extra0 = hn::Add(extra0, d1); + } + + template > + HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3, + VS& extra0, VS&, VS&, VS&) const { + // Reduction tree: sum of all accumulators by pairs, then across lanes. + sum0 = hn::Add(sum0, sum1); + sum2 = hn::Add(sum2, sum3); + sum0 = hn::Add(sum0, extra0); // from Update1 + sum0 = hn::Add(sum0, sum2); + return static_cast(hn::ReduceSum(dd, sum0)); + } +}; + +// ORO Cascaded Summation, algorithm 6.11 from Handbook of Floating-Point +// Arithmetic. Note that Algorithm 6.7 (KBN) appears erroneous. We use TwoSums +// instead of FastTwoSums because the magnitude of the initial sum is not +// always greater than the next input, and this does actually change the e2e +// generation results. Note that Kahan summation differs in that it first adds +// comp* to w*, so each operation is serially dependent. By contrast, the sum* +// and comp* here have shorter dependency chains. +// +// This about as accurate as SumKernelDouble but slower, hence we only use this +// if f64 is not supported on this target. +struct SumKernelCascaded { + template + using Raw = float; + using State = float; + + template , HWY_IF_F32_D(DF)> + HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2, + const VF w3, VF, VF, VF, VF, VF& sum0, VF& sum1, + VF& sum2, VF& sum3, VF& comp0, VF& comp1, VF& comp2, + VF& comp3) const { + VF serr0, serr1, serr2, serr3; + sum0 = TwoSums(df, sum0, w0, serr0); + sum1 = TwoSums(df, sum1, w1, serr1); + sum2 = TwoSums(df, sum2, w2, serr2); + sum3 = TwoSums(df, sum3, w3, serr3); + + comp0 = hn::Add(comp0, serr0); + comp1 = hn::Add(comp1, serr1); + comp2 = hn::Add(comp2, serr2); + comp3 = hn::Add(comp3, serr3); + } + + template , HWY_IF_F32_D(DF)> + HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0, + VF& comp0) const { + VF serr0; + sum0 = TwoSums(df, sum0, w0, serr0); + + comp0 = hn::Add(comp0, serr0); + } + + template > + HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3, + VF& comp0, VF& comp1, VF& comp2, VF& comp3) const { + // Reduction tree: sum of all accumulators by pairs, then across lanes. + AssimilateCascadedSums(df, sum1, comp1, sum0, comp0); + AssimilateCascadedSums(df, sum3, comp3, sum2, comp2); + AssimilateCascadedSums(df, sum2, comp2, sum0, comp0); + return ReduceCascadedSums(df, sum0, comp0); + } +}; + +using SumKernelDefault = + hwy::If; + +template +HWY_INLINE float Sum(D d, const VT* HWY_RESTRICT vec, size_t num) { + using Raw = hwy::If; + const hn::Repartition d_raw; + return DecompressAndCall(d_raw, MakeSpan(vec, num), SumKernelDefault()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#endif // NOLINT diff --git a/paligemma/BUILD b/paligemma/BUILD index 335e531..6b303c8 100644 --- a/paligemma/BUILD +++ b/paligemma/BUILD @@ -11,7 +11,7 @@ cc_library( name = "image", srcs = ["image.cc"], hdrs = ["image.h"], - deps = ["@hwy//:hwy"], + deps = ["@highway//:hwy"], ) cc_test( @@ -39,7 +39,7 @@ cc_test( "//:common", "//:gemma_lib", "//:tokenizer", - "@hwy//:hwy", - "@hwy//:hwy_test_util", + "@highway//:hwy", + "@highway//:hwy_test_util", ], ) diff --git a/util/allocator.h b/util/allocator.h index 9e664b5..821268c 100644 --- a/util/allocator.h +++ b/util/allocator.h @@ -19,30 +19,11 @@ #include #include -#include "hwy/aligned_allocator.h" +#include "hwy/aligned_allocator.h" // IWYU pragma: export #include "hwy/base.h" -#if HWY_IS_MSAN -#include -#endif - namespace gcpp { -static inline void MaybeCheckInitialized(const void* ptr, size_t size) { -#if HWY_IS_MSAN - __msan_check_mem_is_initialized(ptr, size); -#else - (void)ptr; - (void)size; -#endif -} - -// Shared between gemma.h and ops-inl.h. -struct TokenAndProb { - int token; - float prob; -}; - using ByteStorageT = hwy::AlignedFreeUniquePtr; template diff --git a/util/app.h b/util/app.h index b3786c0..69a1f88 100644 --- a/util/app.h +++ b/util/app.h @@ -82,6 +82,9 @@ class AppArgs : public ArgsBase { visitor(max_threads, "num_threads", size_t{0}, "Maximum number of threads to use; default 0 = unlimited.", 2); visitor(pin, "pin", -1, "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2); + // These can be used to partition CPU sockets/packages and their + // clusters/CCXs across several program instances. The default is to use + // all available resources. visitor(skip_packages, "skip_packages", size_t{0}, "Index of the first socket to use; default 0 = unlimited.", 2); visitor(max_packages, "max_packages", size_t{0}, diff --git a/util/basics.h b/util/basics.h new file mode 100644 index 0000000..7f433e2 --- /dev/null +++ b/util/basics.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_ +#define THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_ + +// IWYU pragma: begin_exports +#include +#include + +#include "hwy/base.h" // HWY_IS_MSAN +// IWYU pragma: end_exports + +#if HWY_IS_MSAN +#include +#endif + +namespace gcpp { + +static inline void MaybeCheckInitialized(const void* ptr, size_t size) { +#if HWY_IS_MSAN + __msan_check_mem_is_initialized(ptr, size); +#else + (void)ptr; + (void)size; +#endif +} + +// Shared between gemma.h and ops-inl.h. +struct TokenAndProb { + int token; + float prob; +}; + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_ diff --git a/util/threading.h b/util/threading.h index f3dab6a..bf26ca0 100644 --- a/util/threading.h +++ b/util/threading.h @@ -381,7 +381,7 @@ class BoundedTopology { LPS enabled_lps; // LPs not disabled via OS, taskset, or numactl. bool missing_cluster = false; - if (HWY_LIKELY(have_threading_support)) { + if (HWY_LIKELY(have_threading_support && !topology_.packages.empty())) { (void)GetThreadAffinity(enabled_lps); // failure = all disabled // No effect if topology is unknown or `enabled_lps` is empty.