diff --git a/BUILD.bazel b/BUILD.bazel index 38d79cf..a9631dc 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -66,6 +66,7 @@ cc_library( srcs = ["util/topology.cc"], hdrs = ["util/topology.h"], deps = [ + "@highway//:bit_set", "@highway//:hwy", "@highway//:topology", ], @@ -111,6 +112,7 @@ cc_library( ":threading", ":topology", ":zones", + "//io", "@highway//:hwy", "@highway//:hwy_test_util", "@highway//:profiler", @@ -139,7 +141,7 @@ cc_test( ":kv_cache", ":mat", ":matmul", - ":ops", + ":test_util", ":threading_context", ":weights", "@googletest//:gtest_main", # buildcleaner: keep @@ -172,8 +174,11 @@ cc_library( name = "test_util", hdrs = ["util/test_util.h"], deps = [ + ":basics", + ":mat", "@highway//:hwy", "@highway//:hwy_test_util", + "@highway//:nanobenchmark", "@highway//:stats", ], ) @@ -440,9 +445,9 @@ cc_test( ":gemma_lib", ":mat", ":ops", + ":query", ":test_util", ":threading_context", - ":zones", "@googletest//:gtest_main", # buildcleaner: keep "//compression:test_util", "//compression:types", @@ -519,32 +524,70 @@ cc_library( ], ) +cc_test( + name = "kv_cache_test", + srcs = ["gemma/kv_cache_test.cc"], + deps = [ + ":configs", + ":gemma_args", + ":kv_cache", + ":threading_context", + "//testing/base/public:gunit_main", + "@highway//:hwy", + ], +) + +cc_library( + name = "query", + hdrs = ["gemma/query.h"], + deps = [ + ":basics", + ":gemma_args", + ":kv_cache", + "@highway//:hwy", + ], +) + cc_library( name = "gemma_args", hdrs = ["gemma/gemma_args.h"], deps = [ ":args", ":basics", + ":configs", ":mat", + ":threading_context", "//io", "@highway//:hwy", "@highway//:profiler", ], ) +cc_test( + name = "gemma_args_test", + srcs = ["gemma/gemma_args_test.cc"], + deps = [ + ":gemma_args", + "@googletest//:gtest_main", # buildcleaner: keep + ], +) + cc_library( name = "gemma_lib", srcs = [ "gemma/attention.cc", "gemma/flash_attention.cc", "gemma/gemma.cc", + "gemma/tensor_stats.cc", "gemma/vit.cc", ], hdrs = [ "gemma/activations.h", "gemma/attention.h", "gemma/flash_attention.h", + "gemma/flash_structs.h", "gemma/gemma.h", + "gemma/tensor_stats.h", "gemma/vit.h", ], exec_properties = { @@ -555,6 +598,7 @@ cc_library( "gemma/gemma-inl.h", ], deps = [ + ":allocator", ":basics", ":configs", ":gemma_args", @@ -564,6 +608,7 @@ cc_library( ":matmul_env", ":model_store", ":ops", + ":query", ":threading", ":threading_context", ":weights", @@ -577,8 +622,34 @@ cc_library( "@highway//:hwy", "@highway//:nanobenchmark", # timer "@highway//:profiler", + "@highway//:stats", "@highway//:thread_pool", "@highway//hwy/contrib/sort:vqsort", + ] + + [ + ], +) + +cc_test( + name = "gemma_lib_test", + srcs = ["gemma/attention_test.cc"], + # MatMulEnvs are up to 20GB large. 
+ tags = ["requires-mem:28g"], + deps = [ + ":configs", + ":gemma_args", + ":gemma_lib", + ":kv_cache", + ":mat", + ":matmul_env", + ":test_util", + ":threading_context", + ":weights", + "@googletest//:gtest_main", # buildcleaner: keep + "//compression:compress", + "//compression:types", + "@highway//:hwy", + "@highway//:hwy_test_util", ], ) @@ -604,7 +675,6 @@ cc_library( ":gemma_args", ":gemma_lib", ":matmul_env", - ":ops", ":threading_context", ":tokenizer", "@google_benchmark//:benchmark", diff --git a/CMakeLists.txt b/CMakeLists.txt index a707078..47d7c4c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579 EXCLUDE_FROM_ALL) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5 EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(highway) ## Note: absl needs to be installed by sentencepiece. This will only happen if @@ -82,6 +82,7 @@ set(SOURCES gemma/configs.h gemma/flash_attention.cc gemma/flash_attention.h + gemma/flash_structs.h gemma/gemma_args.h gemma/gemma-inl.h gemma/gemma.cc @@ -221,6 +222,7 @@ set(GEMMA_TEST_FILES compression/nuq_test.cc compression/sfp_test.cc evals/gemma_test.cc + gemma/gemma_args_test.cc gemma/flash_attention_test.cc gemma/tensor_info_test.cc io/blob_store_test.cc diff --git a/MODULE.bazel b/MODULE.bazel index 861daba..d60c0f4 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5") # Require a more recent version. git_override( module_name = "highway", - commit = "2a16a50ff61071bb25ddef0ce35d92b0e2b9c579", + commit = "3b680cde3a556bead9cc23c8f595d07a44d5a0d5", remote = "https://github.com/google/highway", ) diff --git a/README.md b/README.md index 6294920..0aedf38 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,6 @@ Guidelines](https://opensource.google.com/conduct/). - CPU-only inference for: Gemma 2-3, PaliGemma 2. - Sampling with TopK and temperature. - - Backward pass (VJP) and Adam optimizer for Gemma research. - Optimizations @@ -452,7 +451,7 @@ FetchContent_MakeAvailable(sentencepiece) FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main) FetchContent_MakeAvailable(gemma) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5) FetchContent_MakeAvailable(highway) ``` @@ -520,13 +519,19 @@ Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas Fischbacher and Zoltan Szabadka. It was removed in 2025-09. -Gemma-2 support was implemented in June/July 2024 with the help of several -people. +Gemma 2 support was implemented in June/July 2024 with the help of several +people including Daniel Keysers and Phil Culliton. PaliGemma support was implemented in September 2024 with contributions from Daniel Keysers. +Gemma 3 support was implemented in January-March 2025 with contributions from +Daniel Keysers and Phil Culliton. 
+ [Jan Wassenberg](mailto:janwas@google.com) has continued to contribute many improvements, including major gains in efficiency, since the initial release. +[Phil Culliton](mailto:philculliton@google.com) has worked on model releases, +eval and validation, GTM, and quantization, since the initial release. + This is not an officially supported Google product. diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel index c04bd08..4221a8d 100644 --- a/compression/BUILD.bazel +++ b/compression/BUILD.bazel @@ -1,4 +1,4 @@ -# Weight compression and analysis. +# Compressed tensor types. load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_cc//cc:cc_test.bzl", "cc_test") @@ -101,13 +101,11 @@ cc_test( # for test_suite. tags = ["hwy_ops_test"], deps = [ - ":distortion", ":int", "@googletest//:gtest_main", # buildcleaner: keep "//:test_util", "@highway//:hwy", "@highway//:hwy_test_util", - "@highway//:nanobenchmark", ], ) @@ -135,8 +133,8 @@ cc_test( # for test_suite. tags = ["hwy_ops_test"], deps = [ + ":compress", ":distortion", - ":sfp", "@googletest//:gtest_main", # buildcleaner: keep "//:test_util", "@highway//:hwy", @@ -182,7 +180,6 @@ cc_library( "//:mat", "//:threading_context", "@highway//:hwy", - "@highway//:nanobenchmark", "@highway//:profiler", "@highway//:stats", "@highway//:thread_pool", @@ -209,19 +206,3 @@ cc_test( "@highway//:hwy_test_util", ], ) - -# For internal experimentation -cc_library( - name = "analyze", - textual_hdrs = ["analyze.h"], - deps = [ - ":int", - ":nuq", - ":sfp", - ":types", - "@highway//:hwy", - "@highway//:stats", - "@highway//:thread_pool", - "@highway//hwy/contrib/sort:vqsort", - ], -) diff --git a/compression/analyze.h b/compression/analyze.h deleted file mode 100644 index 7d41633..0000000 --- a/compression/analyze.h +++ /dev/null @@ -1,238 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Normal include guard to placate lint. -#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_ -#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_ - -#include -#include -#include -#include // memcpy - -#include // std::signbit -#include // std::abs -#include - -#include "compression/types.h" -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/stats.h" - -#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_ - -// Actual per-target include guard. 
-#if defined(THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE) == defined(HWY_TARGET_TOGGLE) -#ifdef THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE -#undef THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE -#else -#define THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE -#endif - -#include "compression/nuq-inl.h" -#include "compression/sfp-inl.h" -#include "hwy/contrib/sort/vqsort-inl.h" -#include "hwy/highway.h" - -HWY_BEFORE_NAMESPACE(); -namespace gcpp { -namespace HWY_NAMESPACE { - -class PerThread { - public: - void NotifyGroup(const float* group) { - constexpr size_t kGroupSize = NuqStream::kGroupSize; - hwy::Stats s_group; - for (size_t i = 0; i < kGroupSize; ++i) { - // Skip zero so we can see the lowest actual magnitude - if (group[i] == 0.0f || group[i] == -0.0f) continue; - s_all_.Notify(group[i]); - s_group.Notify(group[i]); - - num_tiny_ += std::abs(group[i]) < 1e-3f; - - // b_magn100_.Notify(group[i] * 40.0f + 20.0f); - const uint32_t binary32 = - hwy::BitCastScalar(std::abs(group[i])); - - // const int32_t exp = (binary32 >> 23) - 127; - b_exp256_.Notify(binary32 >> 23); - const uint32_t m4 = (binary32 & 0x7FFFFF) >> (23 - 4); - b_m4_.Notify(m4); - } - s_group_ranges_.Notify(s_group.Max() - s_group.Min()); - s_group_mins_.Notify(s_group.Min()); - s_group_maxs_.Notify(s_group.Max()); - - float desc[kGroupSize]; - memcpy(desc, group, kGroupSize * sizeof(group[0])); - hn::VQSortStatic(desc, kGroupSize, hwy::SortDescending()); - - // Find largest |max/min| (dynamic range) - float max_ratio = 0.0f; - for (size_t i = 0; i < kGroupSize; ++i) { - if (desc[i] != 0.0f && desc[i] != -0.0f) { - max_ratio = std::max(max_ratio, std::abs(desc[0] / desc[i])); - } - } - s_group_max_vs_min_.Notify(max_ratio); - - // Relative errors - float diffs[kGroupSize]; - for (size_t i = 0; i < kGroupSize - 1; ++i) { - // was in descending order. Avoid div by 0. Ignore sign changes. - diffs[i] = std::abs(desc[i]) < 1e-5 - ? 
0 - : std::abs((desc[i] - desc[i + 1]) / desc[i]); - } - hn::VQSortStatic(diffs, kGroupSize, hwy::SortDescending()); - s_cut15_.Notify(diffs[15]); - } - - void Assimilate(const PerThread& other) { - num_tiny_ += other.num_tiny_; - s_all_.Assimilate(other.s_all_); - s_group_ranges_.Assimilate(other.s_group_ranges_); - s_group_mins_.Assimilate(other.s_group_mins_); - s_group_maxs_.Assimilate(other.s_group_maxs_); - s_group_max_vs_min_.Assimilate(other.s_group_max_vs_min_); - s_erange_.Assimilate(other.s_erange_); - s_km_1_.Assimilate(other.s_km_1_); - s_km_2_.Assimilate(other.s_km_2_); - s_cut15_.Assimilate(other.s_cut15_); - b_magn100_.Assimilate(other.b_magn100_); - b_exp256_.Assimilate(other.b_exp256_); - b_m4_.Assimilate(other.b_m4_); - } - - void PrintAll() { - const int skip = hwy::Stats::kNoGeomean; - fprintf(stderr, "num tiny %zu\n", num_tiny_); - fprintf(stderr, "weights %s\n", s_all_.ToString(skip).c_str()); - fprintf(stderr, " ranges %s\n", s_group_ranges_.ToString(skip).c_str()); - fprintf(stderr, " mins %s\n", s_group_mins_.ToString(skip).c_str()); - fprintf(stderr, " maxs %s\n", s_group_maxs_.ToString(skip).c_str()); - fprintf(stderr, " Mvm %s\n", s_group_max_vs_min_.ToString(skip).c_str()); - fprintf(stderr, " cut15 %s\n", s_cut15_.ToString(skip).c_str()); - fprintf(stderr, " erange %s\n", s_erange_.ToString(skip).c_str()); - fprintf(stderr, " km1 %s\n", s_km_1_.ToString(skip).c_str()); - fprintf(stderr, " km2 %s\n", s_km_2_.ToString(skip).c_str()); - - // b_magn100_.Print("magn100"); - // b_exp256_.Print("exp"); - // b_m4_.Print("mantissa bits4"); - - fprintf(stderr, "\n"); - } - - private: - size_t num_tiny_ = 0; - hwy::Stats s_all_; - hwy::Stats s_group_ranges_; - hwy::Stats s_group_mins_; - hwy::Stats s_group_maxs_; - hwy::Stats s_group_max_vs_min_; - hwy::Stats s_erange_; - hwy::Stats s_km_1_; - hwy::Stats s_km_2_; - hwy::Stats s_cut15_; - hwy::Bins<100> b_magn100_; - hwy::Bins<256> b_exp256_; - hwy::Bins<16> b_m4_; - uint8_t padding_[64]; // prevent false sharing -}; - -class PerLayer { - public: - void NotifyGroup(const float* group) { - for (size_t i = 0; i < NuqStream::kGroupSize; ++i) { - s_layer_.Notify(group[i]); - } - } - - void UpdateOutliers(const float* layer, size_t weights_per_layer) { - const float layer_mean = s_layer_.Mean(); - const float layer_sd = s_layer_.StandardDeviation(); - for (size_t i = 0; i < weights_per_layer; ++i) { - num_outliers_ += - std::abs(std::abs(layer[i]) - layer_mean) >= 3.0f * layer_sd; - } - } - - const hwy::Stats& GetStats() const { return s_layer_; } - size_t Outliers() const { return num_outliers_; } - - private: - hwy::Stats s_layer_; - size_t num_outliers_ = 0; - uint8_t padding[64]; // prevent false sharing -}; - -static HWY_NOINLINE void Analyze(const char* caption, float* mat, size_t layers, - size_t weights_per_layer, - hwy::ThreadPool& pool) { - std::vector tls; - std::vector per_layer(layers); - const auto init = [&](size_t num_threads) { - tls.resize(num_threads); - return true; - }; - - pool.Run(0, static_cast(layers), init, - [&](uint32_t idx_layer, size_t idx_thread) { - PerThread& self = tls[idx_thread]; - const float* layer = &mat[idx_layer * weights_per_layer]; - // For each whole group in the layer - for (size_t group_start = 0; - group_start + NuqStream::kGroupSize <= weights_per_layer; - group_start += NuqStream::kGroupSize) { - const float* group = layer + group_start; - per_layer[idx_layer].NotifyGroup(group); - self.NotifyGroup(group); - } - - per_layer[idx_layer].UpdateOutliers(layer, weights_per_layer); - 
}); - - const int skip = hwy::Stats::kNoGeomean; - fprintf(stderr, "\n------------%s\n", caption); - - for (size_t i = 1; i < pool.NumWorkers(); ++i) { - tls[0].Assimilate(tls[i]); - } - tls[0].PrintAll(); - - hwy::Stats s_layer_ranges; - hwy::Stats s_layer_outliers; - for (size_t i = 0; i < layers; ++i) { - fprintf(stderr, " %02zu %s\n", i, - per_layer[i].GetStats().ToString(skip).c_str()); - const float range = - per_layer[i].GetStats().Max() - per_layer[i].GetStats().Min(); - s_layer_ranges.Notify(range); - s_layer_outliers.Notify((100.0 * per_layer[i].Outliers()) / - weights_per_layer); - } - fprintf(stderr, "layer outliers%% %s\n", - s_layer_outliers.ToString(skip).c_str()); - fprintf(stderr, "layer ranges %s\n", s_layer_ranges.ToString(skip).c_str()); -} - -// NOLINTNEXTLINE(google-readability-namespace-comments) -} // namespace HWY_NAMESPACE -} // namespace gcpp -HWY_AFTER_NAMESPACE(); - -#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_ diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 10ce57c..e7bb9d6 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -82,6 +82,8 @@ struct CompressTraits { hn::StoreU(raw1, df, packed.ptr + packed_ofs + NF); } + static float ToFloatSlow(const Packed x) { return x; } + template > static HWY_INLINE void Load2(DBF16 dbf16, const PackedSpan& packed, @@ -254,6 +256,10 @@ struct CompressTraits { packed.ptr + packed_ofs); } + static float ToFloatSlow(const Packed x) { + return hwy::ConvertScalarTo(x); + } + template static HWY_INLINE void Load2(DBF16 dbf16, const PackedSpan& packed, @@ -397,6 +403,27 @@ struct CompressTraits { } } + // NOTE: this does not take into account the per-tensor scale. + static float ToFloatSlow(const Packed x) { + uint32_t sfp = x.byte; + HWY_ASSERT(sfp != 0x80); // -0 is reserved + + const uint32_t sign32 = (sfp & 0x80) << 24; + sfp &= 0x7F; + const bool large_e = sfp >= 64; + const size_t m_bits = large_e ? 3 : 2; + uint32_t m = sfp & ((1u << m_bits) - 1u); + size_t e = sfp >> m_bits; + if (sfp == 0) return 0.0f; + const uint32_t e_bias = large_e ? 15 : 23; + const uint32_t exp32 = static_cast(127 + e - e_bias) << 23; + const uint32_t mnt32 = m << (23 - m_bits); + const uint32_t binary32 = sign32 | exp32 | mnt32; + float result; + hwy::CopySameSize(&binary32, &result); + return result; + } + template // Caller checks this is f32 or bf16 static HWY_INLINE void Load2(D d, const PackedSpan& packed, const size_t packed_ofs, hn::Vec& raw0, @@ -437,6 +464,12 @@ struct CompressTraits { IntCodec::Dec2(d, packed, packed_ofs, raw0, raw1); } + static float ToFloatSlow(const Packed x) { + HWY_DASSERT(!"Not supported - requires a stream"); + return 0.0f; + } + // Store2 is not yet implemented. + template static HWY_INLINE void DecompressAndZeroPad( D d, const PackedSpan& packed, const size_t packed_ofs, @@ -483,6 +516,10 @@ struct CompressTraits { NuqCodec::Dec2(d, packed, packed_ofs, raw0, raw1); } + static float ToFloatSlow(const Packed x) { + HWY_DASSERT(!"Not supported - requires a stream"); + return 0.0f; + } // Store2 is not yet implemented. template @@ -604,6 +641,13 @@ HWY_INLINE void DecompressAndZeroPad(DRaw d, const PackedSpan& packed, Traits::DecompressAndZeroPad(d, MakeConst(packed), packed_ofs, raw, num); } +// NOTE: the following are the recommended way to iterate over arrays of +// potentially compressed elements, including remainder handling. Prefer them +// over calling `Decompress2` directly, which does not handle remainders. 
+// `DecompressAndCall` is for algorithms expressed as `Kernel` objects, such as +// `Dot`. `Decompress*AndCompress*` are for varying numbers of input arrays and +// user code expressed as lambdas. + // Invokes `kernel` for the `v.num` elements of `w` and `v`. Decompresses from // both into groups of four vectors with lane type `Kernel::Raw`, passes them to // `kernel.Update4`; loads the final vector(s) with zero-padding, then passes @@ -733,8 +777,8 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan v, comp3); } -// Similar to `hn::Transform*`, but for compressed `T`. Used by ops-inl.h. -// `DF` is the decompressed type, typically `float`. +// Similar to `hn::Transform*`, but for compressed `T`. Used by `ops-inl.h`. +// `DF` is the decompressed type, typically `float`. Calls `func(df, v_inout)`. template HWY_INLINE void DecompressAndCompressInplace(DF df, T* HWY_RESTRICT inout, size_t num, Func&& func) { @@ -773,6 +817,7 @@ HWY_INLINE void DecompressAndCompressInplace(DF df, T* HWY_RESTRICT inout, } // One extra argument. `DF` is the decompressed type, typically `float`. +// Calls `func(df, v_inout, v1)`. template HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout, size_t num, @@ -821,8 +866,64 @@ HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout, } } +// Two extra arguments. `DF` is the decompressed type, typically `float`. +// Calls `func(df, v_inout, v1, v2)`. +template +HWY_INLINE void Decompress2AndCompressInplace( + DF df, T* HWY_RESTRICT inout, size_t num, const T1* HWY_RESTRICT p1, + const T2* HWY_RESTRICT p2, const size_t p2_ofs, Func&& func) { + const auto packed_inout = MakeSpan(inout, num); + const auto packed1 = MakeSpan(p1, num); + const auto packed2 = MakeSpan(p2, p2_ofs + num); + + using VF = hn::Vec; + HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df); + size_t i = 0; + if (num >= 2 * NF) { + for (; i <= num - 2 * NF; i += 2 * NF) { + VF v0, v1; + Decompress2(df, packed_inout, i, v0, v1); + VF v10, v11; + Decompress2(df, packed1, i, v10, v11); + VF v20, v21; + Decompress2(df, packed2, p2_ofs + i, v20, v21); + const VF out0 = func(df, v0, v10, v20); + const VF out1 = func(df, v1, v11, v21); + Compress2(df, out0, out1, packed_inout, i); + } + } + + const size_t remaining = num - i; + HWY_DASSERT(remaining < 2 * NF); + if (HWY_UNLIKELY(remaining != 0)) { + HWY_ALIGN float buf_inout[2 * hn::MaxLanes(df)]; + HWY_ALIGN float buf1[2 * hn::MaxLanes(df)]; + HWY_ALIGN float buf2[2 * hn::MaxLanes(df)]; + // Ensure the second vector is zeroed even if remaining <= NF. + hn::Store(hn::Zero(df), df, buf_inout + NF); + hn::Store(hn::Zero(df), df, buf1 + NF); + hn::Store(hn::Zero(df), df, buf2 + NF); + DecompressAndZeroPad(df, packed_inout, i, buf_inout, remaining); + DecompressAndZeroPad(df, packed1, i, buf1, remaining); + DecompressAndZeroPad(df, packed2, p2_ofs + i, buf2, remaining); + const VF v0 = hn::Load(df, buf_inout); + const VF v1 = hn::Load(df, buf_inout + NF); + const VF v10 = hn::Load(df, buf1); + const VF v11 = hn::Load(df, buf1 + NF); + const VF v20 = hn::Load(df, buf2); + const VF v21 = hn::Load(df, buf2 + NF); + const VF out0 = func(df, v0, v10, v20); + const VF out1 = func(df, v1, v11, v21); + Compress2(df, out0, out1, MakeSpan(buf_inout, 2 * NF), 0); + // Clang generates incorrect code for CopyBytes if num = 2. + for (size_t j = 0; j < remaining; ++j) { + inout[i + j] = hwy::ConvertScalarTo(buf_inout[j]); + } + } +} + // Single input, separate output. `DF` is the decompressed type, typically -// `float`. 
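// Usage sketch for the new `Decompress2AndCompressInplace` above: like the
// other `Decompress*AndCompress*` helpers, the caller supplies a lambda that
// receives the decompressed in-out vector plus one vector per extra input.
// The helper name and array arguments below are illustrative placeholders,
// not part of this header.
template <typename T>
HWY_INLINE void ExampleFusedMulAddInplace(T* HWY_RESTRICT inout,
                                          const T* HWY_RESTRICT p1,
                                          const T* HWY_RESTRICT p2,
                                          size_t num) {
  const hn::ScalableTag<float> df;
  using VF = hn::Vec<decltype(df)>;
  // Computes inout[i] += p1[i] * p2[i] for i < num; the helper decompresses
  // both inputs, applies the lambda, recompresses, and handles remainders.
  Decompress2AndCompressInplace(
      df, inout, num, p1, p2, /*p2_ofs=*/0,
      [](decltype(df), VF v, VF v1, VF v2) HWY_ATTR -> VF {
        return hn::MulAdd(v1, v2, v);
      });
}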
+// `float`. Calls `func(df, v1)`. template HWY_INLINE void Decompress1AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num, const T1* HWY_RESTRICT p1, @@ -863,7 +964,8 @@ HWY_INLINE void Decompress1AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num, } } -// Two inputs. `DF` is the decompressed type, typically `float`. +// Two inputs, separate output. `DF` is the decompressed type, typically +// `float`. Calls `func(df, v1, v2)`. template HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num, const T1* HWY_RESTRICT p1, @@ -912,7 +1014,8 @@ HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num, } } -// Three inputs. `DF` is the decompressed type, typically `float`. +// Three inputs, separate output. `DF` is the decompressed type, typically +// `float`. Calls `func(df, v1, v2, v3)`. template HWY_INLINE void Decompress3AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num, diff --git a/compression/compress_test.cc b/compression/compress_test.cc index 987f409..421492e 100644 --- a/compression/compress_test.cc +++ b/compression/compress_test.cc @@ -259,6 +259,13 @@ class TestDecompressAndCompress { [](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); }); HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num); + // `out` already contains v + v1. + Decompress2AndCompressInplace( + df, out.get(), num, p1.get(), p2.get(), /*p2_ofs=*/0, + [](DF, VF v, VF /*v1*/, VF v2) + HWY_ATTR -> VF { return hn::Add(v, v2); }); + HWY_ASSERT_ARRAY_EQ(expected3.get(), out.get(), num); + Decompress1AndCompressTo(df, out.get(), num, p.get(), [](DF, VF v) HWY_ATTR -> VF { return v; }); HWY_ASSERT_ARRAY_EQ(expected1.get(), out.get(), num); diff --git a/compression/nuq-inl.h b/compression/nuq-inl.h index 997bb5b..ce387f9 100644 --- a/compression/nuq-inl.h +++ b/compression/nuq-inl.h @@ -480,9 +480,12 @@ class NibbleCodec { static_assert(kHalf <= 1); const size_t N = hn::Lanes(d8); constexpr size_t kMaxN = hn::MaxLanes(d8); + constexpr bool kPermuteAcrossBlocks = + HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86; // For kHalf=1 and 512-bit vectors, kAdd would be 16, which is out of // bounds for TableLookupBytes. We instead BroadcastBlock<1> there. - constexpr uint8_t kAdd = kMaxN < 64 ? kHalf * kMaxN / 4 : 0; + constexpr uint8_t kAdd = + kMaxN < 64 || kPermuteAcrossBlocks ? kHalf * kMaxN / 4 : 0; // The only performance-portable op to replicate bytes is TableLookupBytes, // but this only works if vectors are 128-bit or we first BroadcastBlock, // which only works for <= 512-bit vectors. For scalable vectors, we @@ -506,7 +509,7 @@ class NibbleCodec { } else if constexpr (kMaxN <= 16) { // <= 128-bit // No BroadcastBlock, we anyway only have one block. return hn::TableLookupBytes(bytes, hn::Load(d8, kRep4)); - } else if constexpr (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) { + } else if constexpr (kPermuteAcrossBlocks) { // No BroadcastBlock, can directly permute across blocks. 
return hn::TableLookupLanes(bytes, hn::SetTableIndices(d8, kRep4)); } else { // 256..512-bit, no efficient TableLookupLanes diff --git a/compression/python/BUILD.bazel b/compression/python/BUILD.bazel index 4d9b2ac..e3b7e36 100644 --- a/compression/python/BUILD.bazel +++ b/compression/python/BUILD.bazel @@ -26,7 +26,6 @@ cc_library( "//io", "//io:blob_store", "@highway//:hwy", - "@highway//:thread_pool", ], ) diff --git a/compression/sfp_test.cc b/compression/sfp_test.cc index 8e49ceb..df3e846 100644 --- a/compression/sfp_test.cc +++ b/compression/sfp_test.cc @@ -37,37 +37,23 @@ #include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/highway.h" // After highway.h -#include "compression/sfp-inl.h" +#include "compression/compress-inl.h" #include "hwy/tests/test_util-inl.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -// Decode -float F32FromSFP8(uint32_t sfp) { - HWY_ASSERT(sfp < 256); - HWY_ASSERT(sfp != 0x80); // -0 is reserved +HWY_INLINE_VAR constexpr bool kPrint = false; - const uint32_t sign32 = (sfp & 0x80) << 24; - sfp &= 0x7F; - const bool large_e = sfp >= 64; - const size_t m_bits = large_e ? 3 : 2; - uint32_t m = sfp & ((1u << m_bits) - 1u); - size_t e = sfp >> m_bits; - if (sfp == 0) return 0.0f; - const uint32_t e_bias = large_e ? 15 : 23; - const uint32_t exp32 = static_cast(127 + e - e_bias) << 23; - const uint32_t mnt32 = m << (23 - m_bits); - const uint32_t binary32 = sign32 | exp32 | mnt32; - float result; - hwy::CopySameSize(&binary32, &result); - return result; +static float F32FromSFP8(uint32_t sfp) { + return CompressTraits::ToFloatSlow( + SfpStream{static_cast(sfp)}); } // Used for HWY_AVX3_DL and newer. void PrintTables() { - if (HWY_ONCE && false) { + if (HWY_ONCE && kPrint) { uint8_t hi[128]; fprintf(stderr, "lo\n"); for (uint32_t sfp = 0; sfp < 128; ++sfp) { @@ -92,7 +78,7 @@ void TestAllUnique() { unique.insert(F32FromSFP8(sfp)); } HWY_ASSERT_EQ(size_t{255}, unique.size()); - if (false) { + if (kPrint) { for (float f : unique) { fprintf(stderr, "%e\n", f); } @@ -163,7 +149,7 @@ HWY_INLINE uint32_t SFP8FromF32(float f) { if (m == 0) m = 1; } - if (false) { + if (kPrint) { fprintf(stderr, "in %x round %x rounded %x e %d m %x large_e %d\n", org_binary32, round, rounded, e, m, large_e); } diff --git a/compression/test_util-inl.h b/compression/test_util-inl.h index 81442f7..f2c8b8c 100644 --- a/compression/test_util-inl.h +++ b/compression/test_util-inl.h @@ -105,7 +105,7 @@ MatStorageT GenerateMat(const Extents2D& extents, MatPadding padding, MatStorageT raw("raw", extents, ctx.allocator, MatPadding::kPacked); MatStorageT compressed("mat", extents, ctx.allocator, padding); const float scale = SfpStream::kMax / extents.Area(); - ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0, + ParallelFor(Parallelism::kFlat, extents.rows, ctx, /*cluster_idx=*/0, Callers::kTest, [&](size_t r, size_t thread) { float* HWY_RESTRICT row = raw.Row(r); for (size_t c = 0; c < extents.cols; c++) { @@ -134,7 +134,7 @@ MatStorageT GenerateTransposedMat(const Extents2D extents, MatStorageT raw("raw", extents, ctx.allocator, MatPadding::kPacked); MatStorageT compressed("trans", extents, ctx.allocator, padding); const float scale = SfpStream::kMax / extents.Area(); - ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0, + ParallelFor(Parallelism::kFlat, extents.rows, ctx, /*cluster_idx=*/0, Callers::kTest, [&](size_t r, size_t thread) { float* HWY_RESTRICT row = raw.Row(r); for (size_t c = 0; c < extents.cols; 
c++) { diff --git a/evals/benchmark.cc b/evals/benchmark.cc index 4dec9ee..69cd644 100644 --- a/evals/benchmark.cc +++ b/evals/benchmark.cc @@ -23,7 +23,9 @@ using json = nlohmann::json; class BenchmarkArgs : public ArgsBase { public: - BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + BenchmarkArgs(int argc, char* argv[], ConsumedArgs& consumed) { + InitAndParse(argc, argv, consumed); + } Path summarize_text; Path cross_entropy; @@ -127,9 +129,16 @@ int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file, } // namespace gcpp int main(int argc, char** argv) { - gcpp::GemmaEnv env(argc, argv); - gcpp::BenchmarkArgs benchmark_args(argc, argv); + gcpp::ConsumedArgs consumed(argc, argv); + gcpp::GemmaArgs args(argc, argv, consumed); + gcpp::BenchmarkArgs benchmark_args(argc, argv, consumed); + if (gcpp::HasHelp(argc, argv)) { + args.Help(); + return 0; + } + consumed.AbortIfUnconsumed(); + gcpp::GemmaEnv env(args); if (!benchmark_args.summarize_text.Empty()) { return BenchmarkSummary(env, benchmark_args.summarize_text); } else if (!benchmark_args.cross_entropy.Empty()) { diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index a495dea..30d364f 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -36,30 +36,29 @@ namespace gcpp { -GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading, - const InferenceArgs& inference) - : ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) { +GemmaEnv::GemmaEnv(const GemmaArgs& args) + : initializer_value_(gcpp::InternalInit()), + ctx_(args.threading), + env_(ctx_), + gemma_(args, ctx_) { const ModelConfig& config = gemma_.Config(); // Only allocate one for starters because GenerateBatch might not be called. - kv_caches_.push_back(KVCache(config, inference, ctx_.allocator)); + kv_caches_.push_back(KVCache(config, args.inference, ctx_.allocator)); - if (inference.verbosity >= 2) { - ShowConfig(loader, threading, inference, config, gemma_.WeightReadMode(), - ctx_); + if (args.inference.verbosity >= 2) { + ShowConfig(args, config, gemma_.WeightReadMode(), ctx_); } + if (args.inference.verbosity >= 3) env_.print_best = true; + if (args.inference.verbosity >= 4) env_.print_config = true; runtime_config_ = { - .max_generated_tokens = inference.max_generated_tokens, - .temperature = inference.temperature, - .verbosity = inference.verbosity, + .max_generated_tokens = args.inference.max_generated_tokens, + .temperature = args.inference.temperature, + .verbosity = args.inference.verbosity, }; - inference.CopyTo(runtime_config_); + args.inference.CopyTo(runtime_config_); } -GemmaEnv::GemmaEnv(int argc, char** argv) - : GemmaEnv(LoaderArgs(argc, argv), ThreadingArgs(argc, argv), - InferenceArgs(argc, argv)) {} - QueryResult GemmaEnv::QueryModel(const std::vector& tokens) { QueryResult result; @@ -229,19 +228,19 @@ static constexpr const char* CompiledConfig() { } } -void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading, - const InferenceArgs& inference, const ModelConfig& config, +void ShowConfig(const GemmaArgs& args, const ModelConfig& config, const WeightsPtrs::Mode weight_read_mode, const ThreadingContext& ctx) { - threading.Print(inference.verbosity); - loader.Print(inference.verbosity); - inference.Print(inference.verbosity); - fprintf( - stderr, "Model : %s, to_bf16 %d, mmap %d => %s\n", - config.Specifier().c_str(), static_cast(loader.to_bf16), - static_cast(loader.map), WeightsPtrs::ToString(weight_read_mode)); + 
args.threading.Print(args.inference.verbosity); + args.loader.Print(args.inference.verbosity); + args.inference.Print(args.inference.verbosity); + fprintf(stderr, + "Model : %s, to_bf16 %d, mmap %d => %s\n", + config.Specifier().c_str(), static_cast(args.loader.to_bf16), + static_cast(args.loader.map), + WeightsPtrs::ToString(weight_read_mode)); - if (inference.verbosity >= 2) { + if (args.inference.verbosity >= 2) { time_t now = time(nullptr); char* dt = ctime(&now); // NOLINT char cpu100[100] = "unknown"; @@ -254,7 +253,7 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading, "Instruction set : %s (%zu bits)\n" "Compiled config : %s, profiler %d\n" "Memory MiB : %4zu\n", - dt, cpu100, static_cast(threading.bind), + dt, cpu100, static_cast(args.threading.bind), ctx.topology.TopologyString(), ctx.pools.PinString(), CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()), ctx.cache_info.VectorBytes() * 8, CompiledConfig(), @@ -262,22 +261,4 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading, } } -void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading, - const InferenceArgs& inference) { - std::cerr - << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n" - "==========================================================\n\n" - "To run with pre-2025 weights, specify --tokenizer and --weights.\n" - "With the single-file weights format, specify just --weights.\n"; - std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm " - "--weights gemma2-2b-it-sfp.sbs\n"; - std::cerr << "\n*Model Loading Arguments*\n\n"; - loader.Help(); - std::cerr << "\n*Threading Arguments*\n\n"; - threading.Help(); - std::cerr << "\n*Inference Arguments*\n\n"; - inference.Help(); - std::cerr << "\n"; -} - } // namespace gcpp diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h index 2380dbf..203174c 100644 --- a/evals/benchmark_helper.h +++ b/evals/benchmark_helper.h @@ -23,7 +23,7 @@ #include "gemma/configs.h" #include "gemma/gemma.h" -#include "gemma/gemma_args.h" +#include "gemma/gemma_args.h" // IWYU pragma: export #include "gemma/tokenizer.h" // WrapAndTokenize #include "ops/matmul.h" #include "util/threading_context.h" @@ -50,10 +50,8 @@ struct QueryResultAndMetrics { // Convenience class to load a model and run inference. class GemmaEnv { public: - // Calls the other constructor with *Args arguments initialized from argv. - GemmaEnv(int argc, char** argv); - GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading, - const InferenceArgs& inference); + explicit GemmaEnv(const GemmaArgs& args); + MatMulEnv& Env() { return env_; } size_t MaxGeneratedTokens() const { @@ -125,6 +123,8 @@ class GemmaEnv { MatMulEnv& MutableEnv() { return env_; } private: + // This is used to ensure that InternalInit is called before anything else. + int initializer_value_ = 0; ThreadingContext ctx_; MatMulEnv env_; Gemma gemma_; @@ -135,12 +135,9 @@ class GemmaEnv { // Logs the inference speed in tokens/sec. 
void LogSpeedStats(double time_start, size_t total_tokens); -void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading, - const InferenceArgs& inference, const ModelConfig& config, +void ShowConfig(const GemmaArgs& args, const ModelConfig& config, WeightsPtrs::Mode weight_read_mode, const ThreadingContext& ctx); -void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading, - const InferenceArgs& inference); } // namespace gcpp diff --git a/evals/benchmarks.cc b/evals/benchmarks.cc index 3cb3d3f..f44c62b 100644 --- a/evals/benchmarks.cc +++ b/evals/benchmarks.cc @@ -98,7 +98,11 @@ BENCHMARK(BM_coding_prompt) ->UseRealTime(); int main(int argc, char** argv) { - gcpp::GemmaEnv env(argc, argv); + gcpp::ConsumedArgs consumed(argc, argv); + gcpp::GemmaArgs args(argc, argv, consumed); + consumed.AbortIfUnconsumed(); + + gcpp::GemmaEnv env(args); env.SetMaxGeneratedTokens(256); gcpp::s_env = &env; diff --git a/evals/debug_prompt.cc b/evals/debug_prompt.cc index 66fa466..a6cf8c4 100644 --- a/evals/debug_prompt.cc +++ b/evals/debug_prompt.cc @@ -31,7 +31,9 @@ namespace gcpp { class PromptArgs : public ArgsBase { public: - PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + PromptArgs(int argc, char* argv[], ConsumedArgs& consumed) { + InitAndParse(argc, argv, consumed); + } Path layers_output; // optional std::string prompt; @@ -51,11 +53,15 @@ class PromptArgs : public ArgsBase { }; int Run(int argc, char** argv) { - PromptArgs prompt_args(argc, argv); + ConsumedArgs consumed(argc, argv); + const GemmaArgs args(argc, argv, consumed); + const PromptArgs prompt_args(argc, argv, consumed); AbortIfInvalidArgs(prompt_args); + consumed.AbortIfUnconsumed(); json json_output; - GemmaEnv env(argc, argv); + GemmaEnv env(args); + env.MutableConfig().layers_output = prompt_args.layers_output.Empty() ? LayersOutputFunc() diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc index 4a6f5ea..dd9cb45 100644 --- a/evals/gemma_batch_bench.cc +++ b/evals/gemma_batch_bench.cc @@ -48,7 +48,7 @@ class GemmaBatchBench : public ::testing::Test { } }; -TEST_F(GemmaBatchBench, RandomQuestionsBatched) { +std::vector GenerateInputs() { std::vector prompts = { {"Describe dynamic programming."}, {"Explain how electric cars work."}, @@ -122,33 +122,38 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) { inputs.push_back(prompts[qpos++]); if (qpos == prompts.size()) qpos = 0; } - s_env->SetMaxGeneratedTokens(24); - std::vector responses = BatchGemmaReply(inputs); - for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size()); - ++i) { - fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str()); - } - - PROFILER_PRINT_RESULTS(); - - // Run again: prefill will be faster due to autotuning. Fewer decode steps - // because those are already fast. - s_env->SetMaxGeneratedTokens(2); - responses = BatchGemmaReply(inputs); - - PROFILER_PRINT_RESULTS(); + return inputs; } + +TEST_F(GemmaBatchBench, RandomQuestionsBatched) { + s_env->SetMaxGeneratedTokens(12); + const std::vector inputs = GenerateInputs(); + + // Run multiple times so that auto-tuning is closer to complete. 
+ for (size_t rep = 0; rep < 4; ++rep) { + std::vector responses = BatchGemmaReply(inputs); + for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size()); + ++i) { + fprintf(stderr, "Rep %zu batch answer %zu '%s'\n\n", rep, i, + responses[i].c_str()); + } + PROFILER_PRINT_RESULTS(); + } +} + } // namespace } // namespace gcpp int main(int argc, char** argv) { fprintf(stderr, "GemmaEnv setup..\n"); - gcpp::GemmaEnv env(argc, argv); + gcpp::ConsumedArgs consumed(argc, argv); + gcpp::GemmaArgs args(argc, argv, consumed); + consumed.AbortIfUnconsumed(); + + gcpp::GemmaEnv env(args); gcpp::s_env = &env; testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } - - diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc index 04eb20a..a581561 100644 --- a/evals/gemma_test.cc +++ b/evals/gemma_test.cc @@ -22,7 +22,6 @@ #include "evals/benchmark_helper.h" #include "gemma/configs.h" -#include "io/io.h" #include "hwy/base.h" #include "hwy/tests/hwy_gtest.h" @@ -42,7 +41,11 @@ class GemmaTest : public ::testing::Test { // Requires argc/argv, hence do not use `SetUpTestSuite`. static void InitEnv(int argc, char** argv) { HWY_ASSERT(s_env == nullptr); // Should only be called once. - s_env = new GemmaEnv(argc, argv); + ConsumedArgs consumed(argc, argv); + GemmaArgs args(argc, argv, consumed); + consumed.AbortIfUnconsumed(); + + s_env = new GemmaEnv(args); const gcpp::ModelConfig& config = s_env->GetGemma()->Config(); fprintf(stderr, "Using %s\n", config.Specifier().c_str()); } @@ -130,7 +133,7 @@ TEST_F(GemmaTest, Multiturn) { // Note: we do not rewind any tokens here. If the model // produced one and WrapAndTokenize() inserts another one, it will just be // duplicated. - mutable_prompt = "Please repeat all prior statements."; + mutable_prompt = "Please repeat what I just told you."; tokens = WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(), config.wrapping, abs_pos, mutable_prompt); @@ -167,6 +170,9 @@ TEST_F(GemmaTest, CrossEntropySmall) { case gcpp::Model::GEMMA2_27B: EXPECT_NEAR(entropy, 1.30f, 0.02f); break; + case gcpp::Model::GEMMA3_270M: + EXPECT_NEAR(entropy, 1.41f, 0.02f); + break; default: FAIL() << "no entropy expectation for this model"; break; @@ -178,7 +184,6 @@ TEST_F(GemmaTest, CrossEntropySmall) { int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); - gcpp::InternalInit(); gcpp::GemmaTest::InitEnv(argc, argv); int ret = RUN_ALL_TESTS(); gcpp::GemmaTest::DeleteEnv(); diff --git a/evals/run_mmlu.cc b/evals/run_mmlu.cc index 04a6e00..c6ce972 100644 --- a/evals/run_mmlu.cc +++ b/evals/run_mmlu.cc @@ -31,7 +31,9 @@ namespace gcpp { struct JsonArgs : public ArgsBase { - JsonArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + JsonArgs(int argc, char* argv[], ConsumedArgs& consumed) { + InitAndParse(argc, argv, consumed); + } Path input; @@ -151,10 +153,14 @@ void Run(GemmaEnv& env, JsonArgs& json) { int main(int argc, char** argv) { { PROFILER_ZONE("Startup.all"); - gcpp::GemmaEnv env(argc, argv); - gcpp::JsonArgs json(argc, argv); - gcpp::AbortIfInvalidArgs(json); - gcpp::Run(env, json); + gcpp::ConsumedArgs consumed(argc, argv); + gcpp::GemmaArgs args(argc, argv, consumed); + gcpp::JsonArgs json_args(argc, argv, consumed); + gcpp::AbortIfInvalidArgs(json_args); + consumed.AbortIfUnconsumed(); + + gcpp::GemmaEnv env(args); + gcpp::Run(env, json_args); } PROFILER_PRINT_RESULTS(); // Must call outside the zone above. 
return 0; diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt index 1ff827e..2fd94b4 100644 --- a/examples/hello_world/CMakeLists.txt +++ b/examples/hello_world/CMakeLists.txt @@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) include(FetchContent) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5) FetchContent_MakeAvailable(highway) FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9) FetchContent_MakeAvailable(sentencepiece) diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index f67324d..9cd8b0e 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -24,20 +24,20 @@ #include #include "gemma/gemma.h" -#include "gemma/gemma_args.h" // LoaderArgs +#include "gemma/gemma_args.h" // GemmaArgs #include "gemma/tokenizer.h" #include "util/args.h" #include "util/threading_context.h" #include "hwy/base.h" int main(int argc, char** argv) { - gcpp::LoaderArgs loader(argc, argv); - gcpp::ThreadingArgs threading(argc, argv); - gcpp::InferenceArgs inference(argc, argv); + gcpp::ConsumedArgs consumed(argc, argv); + gcpp::GemmaArgs args(argc, argv, consumed); if (gcpp::HasHelp(argc, argv)) { - loader.Help(); + args.Help(); return 0; } + consumed.AbortIfUnconsumed(); // Demonstrate constrained decoding by never outputting certain tokens. std::set reject_tokens; @@ -49,10 +49,10 @@ int main(int argc, char** argv) { } // Instantiate model and KV Cache - gcpp::ThreadingContext ctx(threading); + gcpp::ThreadingContext ctx(args.threading); gcpp::MatMulEnv env(ctx); - gcpp::Gemma gemma(loader, inference, ctx); - gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator); + gcpp::Gemma gemma(args, ctx); + gcpp::KVCache kv_cache(gemma.Config(), args.inference, ctx.allocator); size_t generated = 0; // Tokenize instructions. 
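A condensed sketch of the initialization flow that binaries migrate to in this change, mirroring examples/hello_world/run.cc above (the constrained-decoding demo and generation loop are omitted):

#include "gemma/gemma.h"
#include "gemma/gemma_args.h"  // GemmaArgs
#include "util/args.h"
#include "util/threading_context.h"

int main(int argc, char** argv) {
  // Each *Args struct now parses through a shared ConsumedArgs tracker so
  // that flags left unclaimed by every parser are caught by
  // AbortIfUnconsumed() after all parsers have run.
  gcpp::ConsumedArgs consumed(argc, argv);
  gcpp::GemmaArgs args(argc, argv, consumed);
  if (gcpp::HasHelp(argc, argv)) {
    args.Help();
    return 0;
  }
  consumed.AbortIfUnconsumed();

  // Gemma and KVCache now take GemmaArgs / args.inference instead of the
  // former LoaderArgs + ThreadingArgs + InferenceArgs triple.
  gcpp::ThreadingContext ctx(args.threading);
  gcpp::MatMulEnv env(ctx);
  gcpp::Gemma gemma(args, ctx);
  gcpp::KVCache kv_cache(gemma.Config(), args.inference, ctx.allocator);
  return 0;
}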
diff --git a/examples/simplified_gemma/CMakeLists.txt b/examples/simplified_gemma/CMakeLists.txt index 2fd4228..ca2e405 100644 --- a/examples/simplified_gemma/CMakeLists.txt +++ b/examples/simplified_gemma/CMakeLists.txt @@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) include(FetchContent) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5) FetchContent_MakeAvailable(highway) FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) FetchContent_MakeAvailable(sentencepiece) diff --git a/examples/simplified_gemma/gemma.hpp b/examples/simplified_gemma/gemma.hpp index e5bb1d8..4a69923 100644 --- a/examples/simplified_gemma/gemma.hpp +++ b/examples/simplified_gemma/gemma.hpp @@ -23,7 +23,7 @@ #include #include "third_party/gemma_cpp/gemma/gemma.h" -#include "third_party/gemma_cpp/gemma/gemma_args.h" // LoaderArgs +#include "third_party/gemma_cpp/gemma/gemma_args.h" // GemmaArgs #include "third_party/gemma_cpp/gemma/tokenizer.h" #include "third_party/gemma_cpp/ops/matmul.h" #include "third_party/gemma_cpp/util/threading_context.h" @@ -31,18 +31,11 @@ class SimplifiedGemma { public: - SimplifiedGemma(const gcpp::LoaderArgs& loader, - const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(), - const gcpp::InferenceArgs& inference = gcpp::InferenceArgs()) - : ctx_(threading), + SimplifiedGemma(const gcpp::GemmaArgs& args) + : ctx_(args.threading), env_(ctx_), - gemma_(loader, inference, ctx_), - kv_cache_(gemma_.Config(), inference, ctx_.allocator) {} - - SimplifiedGemma(int argc, char** argv) - : SimplifiedGemma(gcpp::LoaderArgs(argc, argv), - gcpp::ThreadingArgs(argc, argv), - gcpp::InferenceArgs(argc, argv)) {} + gemma_(args, ctx_), + kv_cache_(gemma_.Config(), args.inference, ctx_.allocator) {} void Generate(std::string& prompt, size_t max_generated_tokens = 1024, float temperature = 0.7, diff --git a/examples/simplified_gemma/run.cc b/examples/simplified_gemma/run.cc index b7af134..58356d2 100644 --- a/examples/simplified_gemma/run.cc +++ b/examples/simplified_gemma/run.cc @@ -18,28 +18,18 @@ #include #include "third_party/gemma_cpp/examples/simplified_gemma/gemma.hpp" -#include "gemma/gemma_args.h" // LoaderArgs +#include "gemma/gemma_args.h" int main(int argc, char** argv) { - // Standard usage: LoaderArgs takes argc and argv as input, then parses - // necessary flags. - gcpp::LoaderArgs loader(argc, argv); + // Sets arguments from argc and argv. Note that you can instead pass in + // LoaderArgs, ThreadingArgs, and InferenceArgs directly. + gcpp::ConsumedArgs consumed(argc, argv); + gcpp::GemmaArgs args(argc, argv, consumed); + consumed.AbortIfUnconsumed(); - // Optional: LoaderArgs can also take tokenizer and weights paths directly. - // - // gcpp::LoaderArgs loader("/path/to/tokenizer", "/path/to/weights", - // "model_identifier"); - - // Optional: ThreadingArgs and InferenceArgs can be passed in as well. If not - // specified, default values will be used. 
- // - // gcpp::InferenceArgs inference(argc, argv); - // gcpp::ThreadingArgs threading(argc, argv); - // SimplifiedGemma gemma(loader, threading, inference); - - SimplifiedGemma gemma(loader); + SimplifiedGemma gemma(args); std::string prompt = "Write a greeting to the world."; gemma.Generate(prompt, 256, 0.6); return 0; -} \ No newline at end of file +} diff --git a/gemma/activations.h b/gemma/activations.h index f474c84..adb6d02 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -23,44 +23,54 @@ #include #include -#include "gemma/configs.h" // ModelConfig -#include "ops/ops.h" // CreateInvTimescale -#include "util/basics.h" // BF16 -#include "util/mat.h" // MatStorageT +#include "gemma/configs.h" // ModelConfig +#include "gemma/gemma_args.h" // AttentionImpl +#include "gemma/kv_cache.h" +#include "gemma/tensor_stats.h" +#include "ops/ops.h" // CreateInvTimescale +#include "util/basics.h" // BF16 +#include "util/mat.h" // MatStorageT #include "util/threading_context.h" namespace gcpp { -struct AttentionActivations { - // Returns the scale value to use for the query in the attention computation. - // Also called by ops_test. - static inline float ChooseQueryScale(const ModelConfig& config) { - const LayerConfig& layer_config = config.layer_configs[0]; - if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads) - return 1.0f / - sqrtf(static_cast(config.model_dim / layer_config.heads)); - // QueryScaleType::SqrtKeySize - return 1.0f / sqrtf(static_cast(layer_config.qkv_dim)); - } +// Returns the scale value to use for the query in the attention computation. +// Also called by ops_test. +static inline float ChooseQueryScale(const ModelConfig& config) { + const LayerConfig& layer_config = config.layer_configs[0]; + if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads) + return 1.0f / + sqrtf(static_cast(config.model_dim / layer_config.heads)); + // QueryScaleType::SqrtKeySize + return 1.0f / sqrtf(static_cast(layer_config.qkv_dim)); +} +struct AttentionActivations { AttentionActivations( const ModelConfig& config, const LayerConfig& layer_config, - size_t batch_size, size_t seq_len, const Allocator& allocator, + size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config, + const Allocator& allocator, std::vector>& row_ptrs) - : config(config), - - // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA - // and does not use an external KV cache. + : // `vocab_size == 0` means it is for Vit part, VitAttention is still + // MHA and does not use an external KV cache. q(MatFactory("q", batch_size, config.vocab_size == 0 ? layer_config.heads * 3 * layer_config.qkv_dim : layer_config.heads * layer_config.qkv_dim, allocator)), + q_bf(MatFactory("q_bf", batch_size, + config.vocab_size == 0 + ? layer_config.heads * 3 * layer_config.qkv_dim + : layer_config.heads * layer_config.qkv_dim, + allocator)), q_T(MatFactory("q_T", layer_config.qkv_dim, config.vocab_size == 0 ? 
batch_size * layer_config.heads * 3 : batch_size * layer_config.heads, allocator)), + vit_Q(MatFactory("Q2", batch_size, layer_config.qkv_dim, allocator)), + vit_K(MatFactory("K2", seq_len, layer_config.qkv_dim, allocator)), + vit_C(MatFactory("C2", batch_size, seq_len, allocator)), pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size, config.model_dim, allocator)), att(MatFactory("att", batch_size, layer_config.heads * seq_len, @@ -68,6 +78,10 @@ struct AttentionActivations { att_out(MatFactory("att_out", batch_size, layer_config.heads * layer_config.qkv_dim, allocator)), + softmax_max(MatFactory("softmax_max", batch_size, layer_config.heads, + allocator)), + softmax_d( + MatFactory("softmax_d", batch_size, layer_config.heads, allocator)), att_sums( MatFactory("att_sums", batch_size, config.model_dim, allocator)), @@ -76,11 +90,7 @@ struct AttentionActivations { layer_config.post_qk == PostQKType::HalfRope)), inv_timescale_global(CreateInvTimescale( allocator, layer_config.qkv_dim, - layer_config.post_qk == PostQKType::HalfRope, 1000000.0)), - - div_seq_len(static_cast(seq_len)), - div_heads(static_cast(layer_config.heads)), - query_scale(ChooseQueryScale(config)) { + layer_config.post_qk == PostQKType::HalfRope, 1000000.0)) { // Batch size can be 0 in experimental code so do not assert. if (batch_size == 0) { static std::atomic_flag warned = ATOMIC_FLAG_INIT; @@ -94,44 +104,153 @@ struct AttentionActivations { // If we forget any MatMul outputs here, debug builds print a warning but // fill them in each MatMul call. q.AllocateAndAttachRowPtrs(row_ptrs); + q_bf.AllocateAndAttachRowPtrs(row_ptrs); q_T.AllocateAndAttachRowPtrs(row_ptrs); + vit_C.AllocateAndAttachRowPtrs(row_ptrs); att_sums.AllocateAndAttachRowPtrs(row_ptrs); } void SetBatchSize(size_t batch_size) { q.OverrideRows(batch_size); + q_bf.OverrideRows(batch_size); // q_T rows are always qkv_dim! + vit_Q.OverrideRows(batch_size); + // vit_K stays seq_len! + vit_C.OverrideRows(batch_size); + pre_att_rms_out.OverrideRows(batch_size); att.OverrideRows(batch_size); att_out.OverrideRows(batch_size); + softmax_max.OverrideRows(batch_size); + softmax_d.OverrideRows(batch_size); att_sums.OverrideRows(batch_size); + + // `inv_timescale*` are not batched. } - const ModelConfig& config; - MatStorageT q; // query - MatStorageT q_T; // Transposed to maximize attention speed. + MatStorageT q_bf; + MatStorageT q_T; // Transposed to maximize attention speed. + + MatStorageT vit_Q; + MatStorageT vit_K; + MatStorageT vit_C; MatStorageT pre_att_rms_out; - MatStorageT att; // attention vector - MatStorageT att_out; // attention output + MatStorageT att; // attention vector + MatStorageT att_out; // attention output + MatStorageT softmax_max; // see OnlineSoftmaxState + MatStorageT softmax_d; // see OnlineSoftmaxState // Accumulation of attention outputs over heads MatStorageT att_sums; // Rope MatStorageT inv_timescale; MatStorageT inv_timescale_global; +}; +// A non-owning view of AttentionActivations. 
+struct AttentionActivationsPtrs { + AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len) + : config(config), + div_seq_len(static_cast(seq_len)), + div_heads(static_cast(config.layer_configs[0].heads)), + query_scale(ChooseQueryScale(config)) {} + + AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len, + const AttentionActivations& activations) + : AttentionActivationsPtrs(config, seq_len) { + q = activations.q; + q_bf = activations.q_bf; + q_T = activations.q_T; + vit_Q = activations.vit_Q; + vit_K = activations.vit_K; + vit_C = activations.vit_C; + pre_att_rms_out = activations.pre_att_rms_out; + att = activations.att; + att_out = activations.att_out; + softmax_max = activations.softmax_max; + softmax_d = activations.softmax_d; + att_sums = activations.att_sums; + inv_timescale = activations.inv_timescale; + inv_timescale_global = activations.inv_timescale_global; + } + + void SetBatchSize(size_t batch_size) { + q.OverrideRows(batch_size); + q_bf.OverrideRows(batch_size); + // q_T rows are always qkv_dim! + + vit_Q.OverrideRows(batch_size); + // vit_K stays seq_len! + vit_C.OverrideRows(batch_size); + + pre_att_rms_out.OverrideRows(batch_size); + att.OverrideRows(batch_size); + att_out.OverrideRows(batch_size); + softmax_max.OverrideRows(batch_size); + softmax_d.OverrideRows(batch_size); + att_sums.OverrideRows(batch_size); + // `inv_timescale*` are not batched. + } + + size_t SeqLen() const { + return static_cast(div_seq_len.GetDivisor()); + } + + const ModelConfig& config; + + // For the matrices below, the batch_size dimension is really qbatch.Size() * + // token_batch_size, but in all known uses, one of those is 1. Specifically, + // during PrefillTBatch, it is prompt length (up to some max batch size) + // and otherwise it's qbatch.Size(). + + // Query matrix of size batch_size x (q_heads * qkv_dim). + MatPtrT q; + // Query matrix of size batch_size x (q_heads * qkv_dim). + MatPtrT q_bf; + // Transposed query matrix for faster Q*K^T. + MatPtrT q_T; + + MatPtrT vit_Q; + MatPtrT vit_K; + MatPtrT vit_C; + + // Output of RMSNorm before attention, size batch_size x model_dim. + MatPtrT pre_att_rms_out; + // Attention scores computed from Q*K^T, size batch_size x (q_heads * + // seq_len). + MatPtrT att; + // Attention output computed from att * V, size batch_size x (q_heads * + // qkv_dim). + MatPtrT att_out; + // The maximum logit value encountered when computing att_out from att, + // size batch_size x q_heads . See OnlineSoftmaxState for details. + // WARNING: Only filled in for AttentionImpl::kOld. + MatPtrT softmax_max; + // The sum of scaled exponentials when computing att_out from att, + // size batch_size x q_heads . See OnlineSoftmaxState for details. + // WARNING: Only filled in for AttentionImpl::kOld. + MatPtrT softmax_d; + // Accumulation of attention outputs over heads, size batch_size x + // model_dim. + MatPtrT att_sums; + // Inverse timescales for RoPE computation. + MatPtrT inv_timescale; + // Inverse timescales for global RoPE computation. + MatPtrT inv_timescale_global; + // Divisor for faster division by sequence length. hwy::Divisor div_seq_len; - // Unfortunately, some models have had non-power-of-two heads. + // Divisor for faster division by number of heads. hwy::Divisor div_heads; + // Query scaling factor for attention computation. 
float query_scale; }; struct Activations { - Activations(const ModelConfig& config, size_t batch_size, size_t seq_len, - ThreadingContext& ctx, + Activations(const RuntimeConfig& runtime_config, const ModelConfig& config, + size_t batch_size, size_t seq_len, ThreadingContext& ctx, std::vector>& row_ptrs) : layer_config(config.layer_configs[0]), @@ -150,8 +269,18 @@ struct Activations { ffw_out( MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)), - attention(config, layer_config, batch_size, seq_len, ctx.allocator, - row_ptrs) { + max_workers(ctx.pools.MaxWorkers()), + s_ffw_in(config.num_layers, max_workers), + s_ffw_hidden(config.num_layers, max_workers), + s_ffw_out(config.num_layers, max_workers), + + s_w_gating_einsum_w1(config.num_layers, max_workers), + s_w_gating_einsum_w2(config.num_layers, max_workers), + s_w_linear_w(config.num_layers, max_workers), + attention_impl(runtime_config.attention_impl), + attention_storage(config, layer_config, batch_size, seq_len, + runtime_config, ctx.allocator, row_ptrs), + attention(config, seq_len, attention_storage) { HWY_ASSERT(batch_size != 0); // For MatMul outputs, precompute their row pointers. @@ -167,6 +296,12 @@ struct Activations { // Note that BindC on any MatMul output considerably slows down Prefill. } + ~Activations() { + s_ffw_in.ReduceAndPrint("ffw_in"); + s_ffw_hidden.ReduceAndPrint("ffw_hidden"); + s_ffw_out.ReduceAndPrint("ffw_out"); + } + // Negligible CPU time. void SetBatchSize(size_t batch_size) { x.OverrideRows(batch_size); @@ -179,12 +314,15 @@ struct Activations { C2.OverrideRows(batch_size); ffw_out.OverrideRows(batch_size); + attention_storage.SetBatchSize(batch_size); + // `AttentionActivationsPtrs` holds `MatPtrT` which also require updating; + // their row override is not updated when the underlying storage changes. attention.SetBatchSize(batch_size); } const LayerConfig& layer_config; - MatStorageT x; // input + MatStorageT x; // input MatStorageT x_bf; // output of final RMSNorm, input to EmbeddingMatmul MatStorageT logits; // TODO: BF16 after Softmax supports that. 
MatStorageT sampled; // batch_size x 3 (padded) @@ -195,7 +333,19 @@ struct Activations { MatStorageT C2; MatStorageT ffw_out; - AttentionActivations attention; + const size_t max_workers; + TensorStats s_ffw_in; + TensorStats s_ffw_hidden; // after Activation+gating + TensorStats s_ffw_out; + + TensorStats s_w_gating_einsum_w1; + TensorStats s_w_gating_einsum_w2; + TensorStats s_w_linear_w; + + AttentionImpl attention_impl; + + AttentionActivations attention_storage; + AttentionActivationsPtrs attention; }; } // namespace gcpp diff --git a/gemma/api_client.cc b/gemma/api_client.cc index 1f64d96..e6ce191 100644 --- a/gemma/api_client.cc +++ b/gemma/api_client.cc @@ -15,18 +15,22 @@ // Test client for API server -#include -#include -#include +#include + #include #include +#include +#include +#include #include "httplib.h" -#include "nlohmann/json.hpp" #include "gemma/gemma_args.h" +#include "nlohmann/json.hpp" using json = nlohmann::json; +namespace gcpp { + // ANSI color codes const std::string RESET = "\033[0m"; const std::string BOLD = "\033[1m"; @@ -37,9 +41,15 @@ const std::string YELLOW = "\033[33m"; const std::string RED = "\033[31m"; class APIClient { -public: - APIClient(const std::string& host, int port, const std::string& api_key = "", const std::string& model = "gemma3-4b") - : host_(host), port_(port), api_key_(api_key), model_(model), use_https_(port == 443), interactive_mode_(false) { + public: + APIClient(const std::string& host, int port, const std::string& api_key = "", + const std::string& model = "gemma3-4b") + : host_(host), + port_(port), + api_key_(api_key), + model_(model), + use_https_(port == 443), + interactive_mode_(false) { if (use_https_) { ssl_client_ = std::make_unique(host, port); ssl_client_->set_read_timeout(60, 0); @@ -55,22 +65,25 @@ public: // Unified request processing for both public and local APIs json ProcessRequest(const json& request, bool stream = true) { bool is_public_api = !api_key_.empty(); - + std::string endpoint; if (is_public_api) { - endpoint = stream ? "/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse" - : "/v1beta/models/gemini-2.0-flash:generateContent"; + endpoint = + stream + ? "/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse" + : "/v1beta/models/gemini-2.0-flash:generateContent"; } else { - endpoint = stream ? "/v1beta/models/" + model_ + ":streamGenerateContent" + endpoint = stream ? 
"/v1beta/models/" + model_ + ":streamGenerateContent" : "/v1beta/models/" + model_ + ":generateContent"; } - + // Only show verbose output in non-interactive mode if (!interactive_mode_) { - std::cout << "\n" << BOLD << BLUE << "📤 POST " << endpoint << RESET << std::endl; + std::cout << "\n" + << BOLD << BLUE << "📤 POST " << endpoint << RESET << std::endl; std::cout << "Request: " << request.dump(2) << std::endl; } - + if (stream) { return ProcessStreamingRequest(request, endpoint); } else { @@ -81,21 +94,24 @@ public: void TestGenerateContent(const std::string& prompt, bool stream = true) { json request = CreateAPIRequest(prompt); json response = ProcessRequest(request, stream); - + if (response.contains("error")) { - std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl; + std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET + << std::endl; } } void TestListModels() { - std::cout << "\n" << BOLD << BLUE << "📤 GET /v1beta/models" << RESET << std::endl; - + std::cout << "\n" + << BOLD << BLUE << "📤 GET /v1beta/models" << RESET << std::endl; + httplib::Headers headers; if (!api_key_.empty()) { headers.emplace("X-goog-api-key", api_key_); } - auto res = use_https_ ? ssl_client_->Get("/v1beta/models", headers) : client_->Get("/v1beta/models", headers); - + auto res = use_https_ ? ssl_client_->Get("/v1beta/models", headers) + : client_->Get("/v1beta/models", headers); + if (res && res->status == 200) { json response = json::parse(res->body); std::cout << GREEN << "✅ Available models:" << RESET << std::endl; @@ -106,49 +122,53 @@ public: } void InteractiveChat() { - std::cout << "\n" << BOLD << CYAN << "💬 Interactive Chat Mode (with session)" << RESET << std::endl; + std::cout << "\n" + << BOLD << CYAN << "💬 Interactive Chat Mode (with session)" + << RESET << std::endl; std::cout << "Type ':gemma %q' to end.\n" << std::endl; - + interactive_mode_ = true; json messages; - + while (true) { std::cout << BOLD << BLUE << "You: " << RESET; std::string input; std::getline(std::cin, input); - + if (input == ":gemma %q") { std::cout << BOLD << YELLOW << "👋 Goodbye!" 
<< RESET << std::endl; break; } - + if (input.empty()) continue; - + // Add user message with proper role json user_message = {{"parts", {{{"text", input}}}}}; if (!api_key_.empty()) { user_message["role"] = "user"; } messages.push_back(user_message); - + // Create request using unified logic json request = CreateAPIRequest("", messages); - + std::cout << BOLD << GREEN << "Assistant: " << RESET; - + // Use unified processing - streaming for real-time output json response = ProcessRequest(request, true); - + if (response.contains("candidates") && !response["candidates"].empty()) { auto& candidate = response["candidates"][0]; - if (candidate.contains("content") && candidate["content"].contains("parts")) { + if (candidate.contains("content") && + candidate["content"].contains("parts")) { for (const auto& part : candidate["content"]["parts"]) { if (part.contains("text")) { std::string assistant_response = part["text"].get(); - + // For streaming, the response is already displayed in real-time // Just add to message history for context - json assistant_message = {{"parts", {{{"text", assistant_response}}}}}; + json assistant_message = { + {"parts", {{{"text", assistant_response}}}}}; if (!api_key_.empty()) { assistant_message["role"] = "model"; } @@ -157,23 +177,21 @@ public: } } } else if (response.contains("error")) { - std::cerr << RED << "❌ Error: " << response["error"]["message"] << RESET << std::endl; + std::cerr << RED << "❌ Error: " << response["error"]["message"] + << RESET << std::endl; } - + std::cout << std::endl; } } -private: - json CreateAPIRequest(const std::string& prompt, const json& messages = json::array()) { + private: + json CreateAPIRequest(const std::string& prompt, + const json& messages = json::array()) { json request = { - {"generationConfig", { - {"temperature", 0.9}, - {"topK", 1}, - {"maxOutputTokens", 1024} - }} - }; - + {"generationConfig", + {{"temperature", 0.9}, {"topK", 1}, {"maxOutputTokens", 1024}}}}; + if (messages.empty()) { // Single prompt json user_message = {{"parts", {{{"text", prompt}}}}}; @@ -185,44 +203,48 @@ private: // Use provided message history request["contents"] = messages; } - + return request; } - json ProcessNonStreamingRequest(const json& request, const std::string& endpoint) { + json ProcessNonStreamingRequest(const json& request, + const std::string& endpoint) { httplib::Headers headers = {{"Content-Type", "application/json"}}; if (!api_key_.empty()) { headers.emplace("X-goog-api-key", api_key_); } - - auto res = use_https_ ? ssl_client_->Post(endpoint, headers, request.dump(), "application/json") - : client_->Post(endpoint, headers, request.dump(), "application/json"); - + + auto res = use_https_ ? ssl_client_->Post(endpoint, headers, request.dump(), + "application/json") + : client_->Post(endpoint, headers, request.dump(), + "application/json"); + if (res && res->status == 200) { json response = json::parse(res->body); if (!interactive_mode_) { - std::cout << "\n" << BOLD << GREEN << "📥 Response:" << RESET << std::endl; + std::cout << "\n" + << BOLD << GREEN << "📥 Response:" << RESET << std::endl; std::cout << response.dump(2) << std::endl; } return response; } else { - json error_response = { - {"error", { - {"message", "Request failed"}, - {"status", res ? res->status : -1} - }} - }; + json error_response = {{"error", + {{"message", "Request failed"}, + {"status", res ? res->status : -1}}}}; if (res && !res->body.empty()) { error_response["error"]["details"] = res->body; } - std::cerr << RED << "❌ Request failed. 
Status: " << (res ? res->status : -1) << RESET << std::endl; + std::cerr << RED + << "❌ Request failed. Status: " << (res ? res->status : -1) + << RESET << std::endl; return error_response; } } - json ProcessStreamingRequest(const json& request, const std::string& endpoint) { + json ProcessStreamingRequest(const json& request, + const std::string& endpoint) { std::string accumulated_response; - + // Use same SSE logic for both public and local APIs httplib::Request req; req.method = "POST"; @@ -232,72 +254,73 @@ private: req.set_header("X-goog-api-key", api_key_); } req.body = request.dump(); - - req.content_receiver = [&accumulated_response, this](const char* data, size_t data_length, uint64_t offset, uint64_t total_length) -> bool { - std::string chunk(data, data_length); - std::istringstream stream(chunk); - std::string line; - - while (std::getline(stream, line)) { - if (line.substr(0, 6) == "data: ") { - std::string event_data = line.substr(6); - - if (event_data == "[DONE]") { - if (!interactive_mode_) { - std::cout << "\n\n" << GREEN << "✅ Generation complete!" << RESET << std::endl; - } - } else { - try { - json event = json::parse(event_data); - if (event.contains("candidates") && !event["candidates"].empty()) { - auto& candidate = event["candidates"][0]; - if (candidate.contains("content") && candidate["content"].contains("parts")) { - for (const auto& part : candidate["content"]["parts"]) { - if (part.contains("text")) { - std::string text = part["text"].get(); - std::cout << text << std::flush; - accumulated_response += text; - } + + req.content_receiver = [&accumulated_response, this]( + const char* data, size_t data_length, + uint64_t offset, uint64_t total_length) -> bool { + std::string chunk(data, data_length); + std::istringstream stream(chunk); + std::string line; + + while (std::getline(stream, line)) { + if (line.substr(0, 6) == "data: ") { + std::string event_data = line.substr(6); + + if (event_data == "[DONE]") { + if (!interactive_mode_) { + std::cout << "\n\n" + << GREEN << "✅ Generation complete!" << RESET + << std::endl; + } + } else { + try { + json event = json::parse(event_data); + if (event.contains("candidates") && + !event["candidates"].empty()) { + auto& candidate = event["candidates"][0]; + if (candidate.contains("content") && + candidate["content"].contains("parts")) { + for (const auto& part : candidate["content"]["parts"]) { + if (part.contains("text")) { + std::string text = part["text"].get(); + std::cout << text << std::flush; + accumulated_response += text; } } } - } catch (const json::exception& e) { - // Skip parse errors } + } catch (const json::exception& e) { + // Skip parse errors } } } - return true; - }; - + } + return true; + }; + httplib::Response res; httplib::Error error; - bool success = use_https_ ? ssl_client_->send(req, res, error) : client_->send(req, res, error); - + bool success = use_https_ ? 
ssl_client_->send(req, res, error) + : client_->send(req, res, error); + if (res.status == 200 && !accumulated_response.empty()) { return json{ - {"candidates", {{ - {"content", { - {"parts", {{{"text", accumulated_response}}}} - }} - }}} - }; + {"candidates", + {{{"content", {{"parts", {{{"text", accumulated_response}}}}}}}}}}; } else { json error_response = { - {"error", { - {"message", "Streaming request failed"}, - {"status", res.status} - }} - }; + {"error", + {{"message", "Streaming request failed"}, {"status", res.status}}}}; if (!res.body.empty()) { error_response["error"]["details"] = res.body; } - std::cerr << RED << "❌ Streaming request failed. Status: " << res.status << RESET << std::endl; + std::cerr << RED << "❌ Streaming request failed. Status: " << res.status + << RESET << std::endl; return error_response; } } -private: + private: std::unique_ptr client_; std::unique_ptr ssl_client_; std::string host_; @@ -308,19 +331,55 @@ private: bool interactive_mode_; }; +struct ClientArgs : public ArgsBase { + ClientArgs(int argc, char* argv[], ConsumedArgs& consumed) { + InitAndParse(argc, argv, consumed); + } + ClientArgs() { Init(); }; + + std::string host; + int port; + std::string api_key; + std::string model; + std::string prompt; + bool interactive; + + template + void ForEach(const Visitor& visitor) { + visitor(host, "host", std::string("localhost"), + "Server host (default: localhost)"); + visitor(port, "port", 8080, "Server port (default: 8080)"); + visitor(api_key, "api_key", std::string(""), + "Use public API with key (changes host to " + "generativelanguage.googleapis.com:443)"); + visitor(model, "model", std::string("gemma3-4b"), + "Model name to use (default: gemma3-4b)"); + visitor(prompt, "prompt", std::string("Hello! How are you?"), + "Prompt for generation (default: 'Hello! 
How are you?')"); + visitor(interactive, "interactive", false, + "Start interactive chat mode (0 = no, 1 = yes)"); + } +}; + +} // namespace gcpp + int main(int argc, char* argv[]) { - gcpp::ClientArgs client_args(argc, argv); - + gcpp::ConsumedArgs consumed(argc, argv); + gcpp::ClientArgs client_args(argc, argv, consumed); + if (gcpp::HasHelp(argc, argv)) { - std::cout << "\nAPI Client for gemma.cpp\n"; - std::cout << "========================\n\n"; + fprintf(stderr, + "\nAPI Client for gemma.cpp\n" + "========================\n\n"); client_args.Help(); - std::cout << std::endl; - std::cout << "Environment Variables:" << std::endl; - std::cout << " GOOGLE_API_KEY : Automatically use public Google API if set" << std::endl; + fprintf(stderr, + "\n*Environment Variables:\n" + " GOOGLE_API_KEY : Automatically use public Google API if set\n"); return 0; } - + + consumed.AbortIfUnconsumed(); + // Check for GOOGLE_API_KEY environment variable const char* env_api_key = std::getenv("GOOGLE_API_KEY"); if (env_api_key != nullptr && strlen(env_api_key) > 0) { @@ -328,32 +387,34 @@ int main(int argc, char* argv[]) { client_args.host = "generativelanguage.googleapis.com"; client_args.port = 443; } - + // Handle API key override if (!client_args.api_key.empty()) { client_args.host = "generativelanguage.googleapis.com"; client_args.port = 443; } - - std::cout << BOLD << YELLOW << "🚀 Testing API Server at " - << client_args.host << ":" << client_args.port << RESET << std::endl; - + + std::cout << BOLD << YELLOW << "🚀 Testing API Server at " << client_args.host + << ":" << client_args.port << RESET << std::endl; + try { - APIClient client(client_args.host, client_args.port, client_args.api_key, client_args.model); - + APIClient client(client_args.host, client_args.port, client_args.api_key, + client_args.model); + if (client_args.interactive) { client.InteractiveChat(); } else { client.TestListModels(); client.TestGenerateContent(client_args.prompt, true); } - } catch (const std::exception& e) { std::cerr << RED << "❌ Error: " << e.what() << RESET << std::endl; std::cerr << "Make sure the API server is running:" << std::endl; - std::cerr << " ./build/gemma_api_server --tokenizer --weights " << std::endl; + std::cerr + << " ./build/gemma_api_server --tokenizer --weights " + << std::endl; return 1; } - + return 0; } diff --git a/gemma/api_server.cc b/gemma/api_server.cc index f05447b..8f71043 100644 --- a/gemma/api_server.cc +++ b/gemma/api_server.cc @@ -15,22 +15,19 @@ // HTTP API server for gemma.cpp with SSE support -#include #include +#include -#include -#include -#include -#include -#include -#include -#include #include #include -#include -#include +#include +#include #include +#include +#include +#include // NOLINT #include +#include // HTTP server library #undef CPPHTTPLIB_OPENSSL_SUPPORT @@ -38,16 +35,12 @@ #include "httplib.h" // JSON library -#include "nlohmann/json.hpp" - -#include "compression/types.h" #include "gemma/gemma.h" #include "gemma/gemma_args.h" -#include "gemma/tokenizer.h" #include "ops/matmul.h" #include "util/args.h" #include "hwy/base.h" -#include "hwy/profiler.h" +#include "nlohmann/json.hpp" using json = nlohmann::json; @@ -90,7 +83,8 @@ struct ServerState { std::lock_guard lock(sessions_mutex); auto& session = sessions[session_id]; if (!session.kv_cache) { - session.kv_cache = std::make_unique(gemma->Config(), InferenceArgs(), env->ctx.allocator); + session.kv_cache = std::make_unique( + gemma->Config(), InferenceArgs(), env->ctx.allocator); } session.last_access = 
std::chrono::steady_clock::now(); return session; @@ -107,7 +101,8 @@ std::string GenerateSessionId() { return ss.str(); } -// Wraps messages with start_of_turn markers - handles both with and without roles +// Wraps messages with start_of_turn markers - handles both with and without +// roles std::string WrapMessagesWithTurnMarkers(const json& contents) { std::string prompt; @@ -121,12 +116,14 @@ std::string WrapMessagesWithTurnMarkers(const json& contents) { std::string text = part["text"]; if (role == "user") { - prompt += "user\n" + text + "\nmodel\n"; + prompt += + "user\n" + text + "\nmodel\n"; } else if (role == "model") { prompt += text + "\n"; } else if (role.empty()) { // Local format without roles - for now, treat as user input - prompt += "user\n" + text + "\nmodel\n"; + prompt += + "user\n" + text + "\nmodel\n"; } } } @@ -163,18 +160,15 @@ RuntimeConfig ParseGenerationConfig(const json& request) { return config; } -// Unified response formatter - creates consistent format regardless of request type -json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false) { +// Unified response formatter - creates consistent format regardless of request +// type +json CreateAPIResponse(const std::string& text, + bool is_streaming_chunk = false) { json response = { - {"candidates", {{ - {"content", { - {"parts", {{{"text", text}}}}, - {"role", "model"} - }}, - {"index", 0} - }}}, - {"promptFeedback", {{"safetyRatings", json::array()}}} - }; + {"candidates", + {{{"content", {{"parts", {{{"text", text}}}}, {"role", "model"}}}, + {"index", 0}}}}, + {"promptFeedback", {{"safetyRatings", json::array()}}}}; // Only add finishReason for non-streaming chunks if (!is_streaming_chunk) { @@ -185,7 +179,9 @@ json CreateAPIResponse(const std::string& text, bool is_streaming_chunk = false) } // Handle generateContent endpoint (non-streaming) -void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) { +void HandleGenerateContentNonStreaming(ServerState& state, + const httplib::Request& req, + httplib::Response& res) { try { json request = json::parse(req.body); @@ -199,7 +195,9 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques prompt = WrapMessagesWithTurnMarkers(request["contents"]); } else { res.status = 400; - res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json"); + res.set_content( + json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), + "application/json"); return; } @@ -209,12 +207,7 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques // Set up runtime config RuntimeConfig runtime_config = ParseGenerationConfig(request); - // Collect full response - std::string full_response; - runtime_config.stream_token = [&full_response](int token, float) { - // Skip EOS token - return true; - }; + runtime_config.stream_token = [](int token, float) { return true; }; // Tokenize prompt std::vector tokens = WrapAndTokenize( @@ -227,7 +220,8 @@ void HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques // Temporarily redirect output to capture response std::stringstream output; - runtime_config.stream_token = [&output, &state, &session, &tokens](int token, float) { + runtime_config.stream_token = [&output, &state, &session, &tokens]( + int token, float) { // Skip prompt tokens if (session.abs_pos < tokens.size()) { session.abs_pos++; @@ -279,7 +273,9 @@ void 
HandleGenerateContentNonStreaming(ServerState& state, const httplib::Reques } // Handle streamGenerateContent endpoint with SSE) -void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& req, httplib::Response& res) { +void HandleGenerateContentStreaming(ServerState& state, + const httplib::Request& req, + httplib::Response& res) { try { json request = json::parse(req.body); @@ -293,7 +289,9 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& prompt = WrapMessagesWithTurnMarkers(request["contents"]); } else { res.status = 400; - res.set_content(json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), "application/json"); + res.set_content( + json{{"error", {{"message", "Missing 'contents' field"}}}}.dump(), + "application/json"); return; } @@ -305,88 +303,85 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& // Set up chunked content provider for SSE res.set_chunked_content_provider( - "text/event-stream", - [&state, request, prompt, session_id](size_t offset, httplib::DataSink& sink) { - try { - // Lock for inference - std::lock_guard lock(state.inference_mutex); - auto& session = state.GetOrCreateSession(session_id); + "text/event-stream", [&state, request, prompt, session_id]( + size_t offset, httplib::DataSink& sink) { + try { + // Lock for inference + std::lock_guard lock(state.inference_mutex); + auto& session = state.GetOrCreateSession(session_id); - // Set up runtime config - RuntimeConfig runtime_config = ParseGenerationConfig(request); + // Set up runtime config + RuntimeConfig runtime_config = ParseGenerationConfig(request); - // Tokenize prompt - std::vector tokens = WrapAndTokenize( - state.gemma->Tokenizer(), state.gemma->ChatTemplate(), - state.gemma->Config().wrapping, session.abs_pos, prompt); + // Tokenize prompt + std::vector tokens = WrapAndTokenize( + state.gemma->Tokenizer(), state.gemma->ChatTemplate(), + state.gemma->Config().wrapping, session.abs_pos, prompt); + + // Stream token callback + std::string accumulated_text; + auto stream_token = [&](int token, float) { + // Skip prompt tokens + if (session.abs_pos < tokens.size()) { + session.abs_pos++; + return true; + } - // Stream token callback - std::string accumulated_text; - auto stream_token = [&](int token, float) { - // Skip prompt tokens - if (session.abs_pos < tokens.size()) { session.abs_pos++; + + // Check for EOS + if (state.gemma->Config().IsEOS(token)) { + return true; + } + + // Decode token + std::string token_text; + state.gemma->Tokenizer().Decode(std::vector{token}, + &token_text); + accumulated_text += token_text; + + // Send SSE event using unified formatter + json event = CreateAPIResponse(token_text, true); + + std::string sse_data = "data: " + event.dump() + "\n\n"; + sink.write(sse_data.data(), sse_data.size()); + return true; - } + }; - session.abs_pos++; + runtime_config.stream_token = stream_token; - // Check for EOS - if (state.gemma->Config().IsEOS(token)) { - return true; - } + // Run inference with KV cache + TimingInfo timing_info = {.verbosity = 0}; + size_t prefix_end = 0; - // Decode token - std::string token_text; - state.gemma->Tokenizer().Decode(std::vector{token}, &token_text); - accumulated_text += token_text; + state.gemma->Generate(runtime_config, tokens, session.abs_pos, + prefix_end, *session.kv_cache, *state.env, + timing_info); - // Send SSE event using unified formatter - json event = CreateAPIResponse(token_text, true); + // Send final event using unified formatter + json 
final_event = CreateAPIResponse("", false); + final_event["usageMetadata"] = { + {"promptTokenCount", tokens.size()}, + {"candidatesTokenCount", session.abs_pos - tokens.size()}, + {"totalTokenCount", session.abs_pos}}; - std::string sse_data = "data: " + event.dump() + "\n\n"; - sink.write(sse_data.data(), sse_data.size()); + std::string final_sse = "data: " + final_event.dump() + "\n\n"; + sink.write(final_sse.data(), final_sse.size()); - return true; - }; - - runtime_config.stream_token = stream_token; - - // Run inference with KV cache - TimingInfo timing_info = {.verbosity = 0}; - size_t prefix_end = 0; - - state.gemma->Generate(runtime_config, tokens, session.abs_pos, - prefix_end, *session.kv_cache, *state.env, - timing_info); - - // Send final event using unified formatter - json final_event = CreateAPIResponse("", false); - final_event["usageMetadata"] = { - {"promptTokenCount", tokens.size()}, - {"candidatesTokenCount", session.abs_pos - tokens.size()}, - {"totalTokenCount", session.abs_pos} - }; - - std::string final_sse = "data: " + final_event.dump() + "\n\n"; - sink.write(final_sse.data(), final_sse.size()); - - // Send done event - sink.write("data: [DONE]\n\n", 15); - - // Ensure all data is sent - sink.done(); - return false; // End streaming - - } catch (const std::exception& e) { - json error_event = {{"error", {{"message", e.what()}}}}; - std::string error_sse = "data: " + error_event.dump() + "\n\n"; - sink.write(error_sse.data(), error_sse.size()); - return false; - } - } - ); + // Send done event + sink.write("data: [DONE]\n\n", 15); + // Ensure all data is sent + sink.done(); + return false; // End streaming + } catch (const std::exception& e) { + json error_event = {{"error", {{"message", e.what()}}}}; + std::string error_sse = "data: " + error_event.dump() + "\n\n"; + sink.write(error_sse.data(), error_sse.size()); + return false; + } + }); } catch (const json::exception& e) { res.status = 400; res.set_content( @@ -398,20 +393,20 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request& } // Handle models list endpoint -void HandleListModels(ServerState& state, const InferenceArgs& inference, const httplib::Request& req, httplib::Response& res) { +void HandleListModels(ServerState& state, const InferenceArgs& inference, + const httplib::Request& req, httplib::Response& res) { json response = { - {"models", {{ - {"name", "models/" + inference.model}, - {"version", "001"}, - {"displayName", inference.model}, - {"description", inference.model + " model running locally"}, - {"inputTokenLimit", 8192}, - {"outputTokenLimit", 8192}, - {"supportedGenerationMethods", json::array({"generateContent", "streamGenerateContent"})}, - {"temperature", 1.0}, - {"topK", 1} - }}} - }; + {"models", + {{{"name", "models/" + inference.model}, + {"version", "001"}, + {"displayName", inference.model}, + {"description", inference.model + " model running locally"}, + {"inputTokenLimit", 8192}, + {"outputTokenLimit", 8192}, + {"supportedGenerationMethods", + json::array({"generateContent", "streamGenerateContent"})}, + {"temperature", 1.0}, + {"topK", 1}}}}}; res.set_content(response.dump(), "application/json"); } @@ -421,39 +416,45 @@ void HandleListModels(ServerState& state, const InferenceArgs& inference, const // server_running = false; // } -void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading, - const InferenceArgs& inference) { +void RunServer(const GemmaArgs& args) { std::cerr << "Loading model..." 
<< std::endl; // Initialize model - ThreadingContext ctx(threading); + ThreadingContext ctx(args.threading); MatMulEnv env(ctx); - ServerState state; - state.gemma = std::make_unique(loader, inference, ctx); + state.gemma = std::make_unique(args, ctx); state.env = &env; state.ctx = &ctx; + const InferenceArgs& inference = args.inference; + httplib::Server server; // Set up routes - server.Get("/", [&inference](const httplib::Request&, httplib::Response& res) { - res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" + inference.model + ":generateContent", "text/plain"); - }); + server.Get( + "/", [&inference](const httplib::Request&, httplib::Response& res) { + res.set_content("API Server (gemma.cpp) - Use POST /v1beta/models/" + + inference.model + ":generateContent", + "text/plain"); + }); // API endpoints - server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req, httplib::Response& res) { + server.Get("/v1beta/models", [&state, &inference](const httplib::Request& req, + httplib::Response& res) { HandleListModels(state, inference, req, res); }); std::string model_endpoint = "/v1beta/models/" + inference.model; - server.Post(model_endpoint + ":generateContent", [&state](const httplib::Request& req, httplib::Response& res) { - HandleGenerateContentNonStreaming(state, req, res); - }); + server.Post(model_endpoint + ":generateContent", + [&state](const httplib::Request& req, httplib::Response& res) { + HandleGenerateContentNonStreaming(state, req, res); + }); - server.Post(model_endpoint + ":streamGenerateContent", [&state](const httplib::Request& req, httplib::Response& res) { - HandleGenerateContentStreaming(state, req, res); - }); + server.Post(model_endpoint + ":streamGenerateContent", + [&state](const httplib::Request& req, httplib::Response& res) { + HandleGenerateContentStreaming(state, req, res); + }); // Periodic cleanup of old sessions std::thread cleanup_thread([&state]() { @@ -466,12 +467,15 @@ void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading, std::cerr << "Starting API server on port " << inference.port << std::endl; std::cerr << "Model loaded successfully" << std::endl; std::cerr << "Endpoints:" << std::endl; - std::cerr << " POST /v1beta/models/" << inference.model << ":generateContent" << std::endl; - std::cerr << " POST /v1beta/models/" << inference.model << ":streamGenerateContent (SSE)" << std::endl; + std::cerr << " POST /v1beta/models/" << inference.model << ":generateContent" + << std::endl; + std::cerr << " POST /v1beta/models/" << inference.model + << ":streamGenerateContent (SSE)" << std::endl; std::cerr << " GET /v1beta/models" << std::endl; if (!server.listen("0.0.0.0", inference.port)) { - std::cerr << "Failed to start server on port " << inference.port << std::endl; + std::cerr << "Failed to start server on port " << inference.port + << std::endl; } cleanup_thread.join(); @@ -482,35 +486,27 @@ void RunServer(const LoaderArgs& loader, const ThreadingArgs& threading, int main(int argc, char** argv) { gcpp::InternalInit(); - gcpp::LoaderArgs loader(argc, argv); - gcpp::ThreadingArgs threading(argc, argv); - gcpp::InferenceArgs inference(argc, argv); + gcpp::ConsumedArgs consumed(argc, argv); + gcpp::GemmaArgs args(argc, argv, consumed); if (gcpp::HasHelp(argc, argv)) { - std::cerr << "\n\nAPI server for gemma.cpp\n"; - std::cout << "========================\n\n"; - std::cerr << "Usage: " << argv[0] << " --weights --tokenizer [options]\n"; - std::cerr << "\nOptions:\n"; - std::cerr << " --port PORT Server port 
(default: 8080)\n"; - std::cerr << " --model MODEL Model name for endpoints (default: gemma3-4b)\n"; - std::cerr << "\n"; - std::cerr << "\n*Model Loading Arguments*\n\n"; - loader.Help(); - std::cerr << "\n*Threading Arguments*\n\n"; - threading.Help(); - std::cerr << "\n*Inference Arguments*\n\n"; - inference.Help(); - std::cerr << "\n"; + fprintf( + stderr, + "\n\nAPI server for gemma.cpp\n" + "========================\n\n" + " --port PORT Server port (default: 8080)\n" + " --model MODEL Model name for endpoints (default: gemma3-4b)\n"); + args.Help(); return 0; } - // Arguments are now handled by InferenceArgs + consumed.AbortIfUnconsumed(); // // Set up signal handler // signal(SIGINT, gcpp::HandleShutdown); // signal(SIGTERM, gcpp::HandleShutdown); - gcpp::RunServer(loader, threading, inference); + gcpp::RunServer(args); return 0; } diff --git a/gemma/attention.cc b/gemma/attention.cc index 117b533..eccfd25 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -43,6 +43,7 @@ // After highway.h #include "compression/compress-inl.h" #include "gemma/flash_attention.h" +#include "gemma/gemma-inl.h" #include "ops/ops-inl.h" HWY_BEFORE_NAMESPACE(); @@ -57,33 +58,35 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, const MatPtrT& k, float* HWY_RESTRICT att, ThreadingContext& ctx, const size_t worker) { GCPP_ZONE(ctx, worker, Zones::kGenAttentionQDotK); - if (HWY_LIKELY(last_pos < static_cast(div_seq_len.GetDivisor()))) { - // Slightly faster: no wraparound. - for (size_t pos = start_pos; pos <= last_pos; ++pos) { - const float score = Dot(q, k.Row(pos), k.Cols()); - att[pos] = score; - } - } else { - for (size_t pos = start_pos; pos <= last_pos; ++pos) { - const size_t pos_modulo = div_seq_len.Remainder(pos); - const float score = Dot(q, k.Row(pos_modulo), k.Cols()); - att[pos_modulo] = score; - } + const hn::ScalableTag dbf; + const size_t qkv_dim = k.Cols(); + HWY_ALIGN BF16 q_bf[kMaxQKVDim]; + + CompressPerThread tls; + const hn::ScalableTag df; + CompressTraits::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim), + 0); + + // --seq_len must be large enough to avoid wraparound. + HWY_DASSERT(last_pos < static_cast(div_seq_len.GetDivisor())); + for (size_t pos = start_pos; pos <= last_pos; ++pos) { + const float score = + Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim); + att[pos] = score; } } void PositionalEncodingQK(float* qk, const size_t layer_idx, - const LayerWeightsPtrs& layer, - const AttentionActivations& activations, + const AttentionActivationsPtrs& activations, ThreadingContext& ctx, const size_t worker, const size_t pos, const float mul) { - const size_t qkv_dim = layer.layer_config.qkv_dim; - const PostQKType& post_qk = layer.layer_config.post_qk; + const LayerConfig& layer_config = activations.config.layer_configs[layer_idx]; + const size_t qkv_dim = layer_config.qkv_dim; + const PostQKType& post_qk = layer_config.post_qk; // qk is either q or k, so qkv_dim is the length we operate on. const float* inv_timescale = activations.inv_timescale.PackedScale1(); const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx); - // TODO: add a config flag instead of hardcoding the model. 
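Annotation: PositionalEncodingQK above now derives the layer config from the model config and selects inv_timescale or inv_timescale_global per layer (the hardcoded IsVLM check is replaced by a config.use_global_timescale flag just below). As a scalar reference for what the RoPE-and-scale step does to one q or k row, consider the sketch below; the first-half/second-half pairing and the base-10000 timescale are assumptions about the underlying Rope kernel in ops-inl.h, not code from this patch.

#include <cmath>
#include <cstddef>
#include <vector>

// Inverse timescales as precomputed once into inv_timescale (with a separate
// table, inv_timescale_global, selected for global layers):
// base^(-i / (qkv_dim/2)).
std::vector<float> InvTimescalesSketch(size_t qkv_dim, float base = 10000.0f) {
  const size_t half = qkv_dim / 2;
  std::vector<float> inv(half);
  for (size_t i = 0; i < half; ++i) {
    inv[i] = std::pow(base, -static_cast<float>(i) / static_cast<float>(half));
  }
  return inv;
}

// Rotates one q or k vector by position `pos` and scales it by `mul`
// (1.0f for K, the query scale for Q), mirroring what PositionalEncodingQK
// does per row.
void RopeAndMulSketch(float* qk, size_t qkv_dim, size_t pos,
                      const float* inv_timescale, float mul) {
  const size_t half = qkv_dim / 2;
  for (size_t i = 0; i < half; ++i) {
    const float theta = static_cast<float>(pos) * inv_timescale[i];
    const float c = std::cos(theta);
    const float s = std::sin(theta);
    const float x0 = qk[i];
    const float x1 = qk[i + half];
    qk[i] = mul * (x0 * c - x1 * s);
    qk[i + half] = mul * (x0 * s + x1 * c);
  }
}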
- if (is_global_layer && IsVLM(activations.config.model)) { + if (is_global_layer && activations.config.use_global_timescale) { inv_timescale = activations.inv_timescale_global.PackedScale1(); } // PostQKType::Rope @@ -104,62 +107,52 @@ static HWY_INLINE void WeightedSumV( const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att, const MatPtrT& v, float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) { - if (HWY_LIKELY(last_pos < static_cast(div_seq_len.GetDivisor()))) { - // Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if - // we supported non-transposed B. - // TODO: 2..4x unroll - MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx, - worker); - for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { - MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols()); - } - } else { - { - const size_t pos_mod = div_seq_len.Remainder(start_pos); - MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), ctx, - worker); - } - for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { - const size_t pos_mod = div_seq_len.Remainder(pos); - MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols()); - } + // --seq_len must be large enough to avoid wraparound. + HWY_DASSERT(last_pos < static_cast(div_seq_len.GetDivisor())); + // TODO: replace with MatMul(att, v) after it supports non-transposed B. + MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx, + worker); + for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { + MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols()); } } // Calculates the attention outputs for a single q, which may be updated // in place for RMSNorm. void SingleDotSoftmaxWeightedSum( - const size_t pos, const size_t start_pos, const size_t last_pos, + const size_t q_pos, const size_t kv_start_pos, const size_t kv_last_pos, float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, - const size_t layer_idx, const LayerWeightsPtrs& layer, - const AttentionActivations& activations, float* HWY_RESTRICT att, - float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) { + const MatPtr& query_norm_scale, const size_t layer_idx, + const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, + float* HWY_RESTRICT att_out, const SMOptions& sm_options, + ThreadingContext& ctx, const size_t worker) { const float att_cap = activations.config.att_cap; const float query_scale = activations.query_scale; - const size_t seq_len = - static_cast(activations.div_seq_len.GetDivisor()); + // --seq_len must be large enough to avoid wraparound. + HWY_DASSERT(kv_last_pos < activations.SeqLen()); + const LayerConfig& layer_config = activations.config.layer_configs[layer_idx]; // Apply rope and scaling to Q. - if (layer.query_norm_scale.HasPtr()) { - CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { + if (query_norm_scale.HasPtr()) { + CallUpcasted(&query_norm_scale, [&](const auto* weights_t) { RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q, - layer.layer_config.qkv_dim, ctx, worker); + layer_config.qkv_dim, ctx, worker); }); } - PositionalEncodingQK(q, layer_idx, layer, activations, ctx, worker, pos, + PositionalEncodingQK(q, layer_idx, activations, ctx, worker, q_pos, query_scale); - QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker); + QDotK(kv_start_pos, kv_last_pos, activations.div_seq_len, q, k, att, ctx, + worker); // SoftMax with optional SoftCap yields "probabilities" in att. 
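Annotation: SingleDotSoftmaxWeightedSum above now takes query_norm_scale and SMOptions directly, renames its positions to q_pos / kv_start_pos / kv_last_pos, and asserts that kv_last_pos fits within the sequence length instead of wrapping. A scalar reference of the chain it performs for one query and one KV head is sketched below, with RMSNorm and RoPE assumed to have been applied already; the cap * tanh(x / cap) soft cap reflects the intent of MaybeLogitsSoftCap rather than copying ops-inl.h.

#include <algorithm>
#include <cmath>
#include <cstddef>

// Scalar reference for the score -> soft cap -> softmax -> weighted-V chain.
// k and v are row-major [seq_len, qkv_dim] views of one KV head; positions
// [kv_start_pos, kv_last_pos] are attended to.
void DotSoftmaxWeightedSumSketch(const float* q, const float* k,
                                 const float* v, size_t qkv_dim,
                                 size_t kv_start_pos, size_t kv_last_pos,
                                 float att_cap,
                                 float* att /* size kv_last_pos + 1 */,
                                 float* att_out /* size qkv_dim */) {
  // Q dot K for every attended position (QDotK).
  for (size_t pos = kv_start_pos; pos <= kv_last_pos; ++pos) {
    float score = 0.0f;
    for (size_t d = 0; d < qkv_dim; ++d) score += q[d] * k[pos * qkv_dim + d];
    if (att_cap > 0.0f) score = att_cap * std::tanh(score / att_cap);
    att[pos] = score;
  }
  // Numerically stable softmax over the attended range; the maximum and the
  // denominator are the quantities softmax_max / softmax_d record in the old
  // attention path.
  float max = att[kv_start_pos];
  for (size_t pos = kv_start_pos; pos <= kv_last_pos; ++pos)
    max = std::max(max, att[pos]);
  float denom = 0.0f;
  for (size_t pos = kv_start_pos; pos <= kv_last_pos; ++pos) {
    att[pos] = std::exp(att[pos] - max);
    denom += att[pos];
  }
  for (size_t pos = kv_start_pos; pos <= kv_last_pos; ++pos) att[pos] /= denom;
  // Weighted sum of V rows (WeightedSumV).
  for (size_t d = 0; d < qkv_dim; ++d) att_out[d] = 0.0f;
  for (size_t pos = kv_start_pos; pos <= kv_last_pos; ++pos) {
    for (size_t d = 0; d < qkv_dim; ++d) {
      att_out[d] += att[pos] * v[pos * qkv_dim + d];
    }
  }
}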
- const size_t att_len = HWY_MIN(last_pos + 1, seq_len); - const Logits logits(att, att_len); + const Logits logits(att, kv_last_pos + 1); MaybeLogitsSoftCap(att_cap, logits, ctx, worker); - Softmax(logits, ctx, worker, /*temperature=*/1.0f); + Softmax(logits, ctx, worker, /*temperature=*/1.0f, sm_options); - WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, - ctx, worker); + WeightedSumV(kv_start_pos, kv_last_pos, activations.div_seq_len, att, v, + att_out, ctx, worker); } // The attention window usually starts at 0 unless `pos` is larger than @@ -170,13 +163,13 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) { } void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, - const LayerWeightsPtrs& layer, - AttentionActivations& activations, QBatch& qbatch, - ThreadingContext& ctx) { + const MatPtr& query_norm_scale, + AttentionActivationsPtrs& activations, + QBatch& qbatch, ThreadingContext& ctx) { GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive); const hwy::Divisor div_qbatch(qbatch.Size()); - const LayerConfig& layer_config = layer.layer_config; + const LayerConfig& layer_config = activations.config.layer_configs[layer_idx]; const size_t qkv_dim = layer_config.qkv_dim; // A "head group" in the context of GQA refers to a collection of query @@ -184,8 +177,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; const size_t cache_layer_size = layer_config.CacheLayerSize(); - const size_t seq_len = - static_cast(activations.div_seq_len.GetDivisor()); + const size_t seq_len = activations.SeqLen(); // All layers should have the same number of heads. HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads); @@ -196,12 +188,12 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, GCPP_ZONE(ctx, worker, Zones::kGenAttentionDotSoftmaxWeightedSumPar); const size_t qi = div_qbatch.Remainder(tq_idx); - const size_t batch_idx = div_qbatch.Divide(tq_idx); + const size_t token_idx = div_qbatch.Divide(tq_idx); auto& kv_cache = qbatch.KV(qi).kv_cache; // Find the token position in the query and calculate // the range of cache positions to attend to. - const size_t pos = qbatch.Pos(qi) + batch_idx; + const size_t pos = qbatch.Pos(qi) + token_idx; const size_t start_pos = StartPos(pos, activations.config, layer_idx); size_t last_pos = pos; const size_t prefix_end = qbatch.PrefixEnd(qi); @@ -214,6 +206,8 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, float* HWY_RESTRICT att = activations.att.Row(tq_idx) + head * seq_len; float* HWY_RESTRICT att_out = activations.att_out.Row(tq_idx) + head * qkv_dim; + SMOptions sm_options{.max_out = activations.softmax_max.Row(tq_idx) + head, + .d_out = activations.softmax_d.Row(tq_idx) + head}; // Make strided read-only views into the kv cache for // this query and head. 
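Annotation: the parallel body above decomposes tq_idx into (token_idx, qi), maps each query head to its shared KV head via kHeadGroups, and then builds strided read-only K/V views into that query's KV cache. The index arithmetic, reconstructed from this hunk and from CompareKVCacheWithGolden in the test added later in this patch, is roughly the following; the assumption that CacheLayerSize() equals kv_heads * 2 * qkv_dim is mine.

#include <cstddef>

// GQA: with `heads` query heads sharing `kv_heads` KV heads, query head
// `head` reads from KV head `head / (heads / kv_heads)`, matching the
// kHeadGroups division above.
size_t KVHeadForQueryHead(size_t head, size_t heads, size_t kv_heads) {
  const size_t head_groups = heads / kv_heads;
  return head / head_groups;
}

// Offsets (in elements) of K and V within one KV-cache row, assuming the
// layout exercised by the test's golden checks:
// [layer][kv_head][K then V], each segment of length qkv_dim.
struct KVOffsets {
  size_t k;
  size_t v;
};

KVOffsets KVOffsetsSketch(size_t layer_idx, size_t kv_head, size_t kv_heads,
                          size_t qkv_dim) {
  const size_t cache_layer_size = kv_heads * 2 * qkv_dim;
  const size_t k = layer_idx * cache_layer_size + kv_head * 2 * qkv_dim;
  return KVOffsets{k, k + qkv_dim};
}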
@@ -224,8 +218,10 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, MatPtrT v("v_view", Extents2D(seq_len, qkv_dim)); v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride()); - SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx, - layer, activations, att, att_out, ctx, worker); + constexpr size_t offset = 0; // placeholder, do not remove + SingleDotSoftmaxWeightedSum(pos + offset, start_pos, last_pos, q, k, v, + query_norm_scale, layer_idx, activations, att, + att_out, sm_options, ctx, worker); }; { @@ -246,7 +242,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, // Fills activations.q and writes to KV cache. static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, - AttentionActivations& activations, + AttentionActivationsPtrs& activations, const QBatch& qbatch, const int flags, MatMulEnv& env) { GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), @@ -271,10 +267,14 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, layer.qkv_einsum_w2.Rows())); for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; ++interleaved_idx) { + // Index into qbatch, within [0, qbatch.Size()] const size_t qi = div_qbatch.Remainder(interleaved_idx); - const size_t batch_idx = div_qbatch.Divide(interleaved_idx); - const size_t cache_pos = - activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx); + // Index along token sequence, within [0, num_tokens) + const size_t token_idx = div_qbatch.Divide(interleaved_idx); + const size_t cache_pos = qbatch.Pos(qi) + token_idx; + // --seq_len must be large enough to avoid wraparound. + HWY_DASSERT(cache_pos < activations.SeqLen()); + env.row_ptrs[0][interleaved_idx] = reinterpret_cast( qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size); } @@ -286,15 +286,16 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, // Note that 2D parallelism is not worth the fork/join overhead because the // tasks are very lightweight. ParallelFor( - ParallelismStrategy::kFlat, kv_heads * num_interleaved, env.ctx, + Parallelism::kFlat, kv_heads * num_interleaved, env.ctx, /*cluster_idx=*/0, Callers::kAttComputeQKV, [&](size_t task, size_t worker) HWY_ATTR { const size_t head = task % kv_heads; const size_t interleaved_idx = task / kv_heads; const size_t qi = div_qbatch.Remainder(interleaved_idx); - const size_t batch_idx = div_qbatch.Divide(interleaved_idx); - const size_t pos = qbatch.Pos(qi) + batch_idx; - const size_t cache_pos = activations.div_seq_len.Remainder(pos); + const size_t token_idx = div_qbatch.Divide(interleaved_idx); + const size_t cache_pos = qbatch.Pos(qi) + token_idx; + // --seq_len must be large enough to avoid wraparound. 
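Annotation: ComputeQKV above renames batch_idx to token_idx and replaces the Remainder-based cache position with a direct qbatch.Pos(qi) + token_idx guarded by a debug assert, so --seq_len now has to cover the longest sequence outright. The interleaved index decomposition it relies on amounts to the following sketch (div_qbatch.Divide / Remainder are the divisor-accelerated equivalents of the plain / and % used here).

#include <cstddef>

// interleaved_idx enumerates (token, query) pairs with the query index
// varying fastest; Divide/Remainder recover the two parts.
struct InterleavedIdx {
  size_t token_idx;  // index along the token sequence, [0, num_tokens)
  size_t qi;         // index into the query batch, [0, qbatch_size)
};

InterleavedIdx SplitInterleaved(size_t interleaved_idx, size_t qbatch_size) {
  return InterleavedIdx{interleaved_idx / qbatch_size,
                        interleaved_idx % qbatch_size};
}

// With wraparound removed, the cache row is simply the absolute position;
// callers assert it is below --seq_len instead of taking a remainder.
size_t CachePosSketch(size_t query_pos, size_t token_idx) {
  return query_pos + token_idx;
}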
+ HWY_DASSERT(cache_pos < activations.SeqLen()); auto& kv_cache = qbatch.KV(qi).kv_cache; KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) + layer_idx * cache_layer_size + @@ -313,35 +314,18 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, }); } - PositionalEncodingQK(kv_f32, layer_idx, layer, activations, env.ctx, - worker, pos, /*mul=*/1.0f); + constexpr size_t offset = 0; // placeholder, do not remove + PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker, + cache_pos + offset, + /*mul=*/1.0f); CompressPerThread tls; Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0); }); } -// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and -// head_dim (`qkv_dim`) into output (`layer_out`). -static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer, - AttentionActivations& activations, - MatMulEnv& env) { - GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads); - const LayerConfig& layer_config = layer.layer_config; - (void)layer_config; // For HWY_DASSERT - // att_weights and att_out are concatenated heads, each of length - // layer_config.qkv_dim. Thus the [num_interleaved, - // layer_config.model_dim] matmul output is the sum over heads. Compare - // gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', - // encoded) - HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 && - layer_config.qkv_dim != 0); - CallMatMul(activations.att_out, layer.att_weights, /*add=*/nullptr, env, - activations.att_sums); -} - void GemmaAttention(size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, - AttentionActivations& activations, QBatch& qbatch, + AttentionActivationsPtrs& activations, QBatch& qbatch, MatMulEnv& env, int flags) { GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttention); @@ -353,13 +337,14 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx, ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env); if (flags & kAttentionUseOld) { - DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch, - env.ctx); + DotSoftmaxWeightedSum(num_tokens, layer_idx, layer.query_norm_scale, + activations, qbatch, env.ctx); } else { // * 2 does not help on Turin. 
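Annotation: GemmaAttention now dispatches either to the old DotSoftmaxWeightedSum path (kAttentionUseOld) or to the FlashAttention call that follows, passing layer.query_norm_scale explicitly. The softmax_max / softmax_d activations (see OnlineSoftmaxState) record a running maximum and denominator; the sketch below shows that one-pass softmax formulation in scalar form, as an illustration only, since the real kernel in flash_attention.cc is blocked and vectorized.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// One-pass (online) softmax combined with a weighted sum of V: a running max
// `m` and denominator `d` let previously accumulated sums be rescaled instead
// of requiring a second pass over the scores.
void OnlineSoftmaxWeightedVSketch(const float* scores, const float* v,
                                  size_t num_pos, size_t qkv_dim,
                                  float* att_out /* size qkv_dim */,
                                  float* max_out, float* d_out) {
  if (num_pos == 0) return;
  float m = -INFINITY;
  float d = 0.0f;
  std::vector<float> acc(qkv_dim, 0.0f);
  for (size_t pos = 0; pos < num_pos; ++pos) {
    const float m_new = std::max(m, scores[pos]);
    const float rescale = std::exp(m - m_new);  // shrinks previous sums
    const float w = std::exp(scores[pos] - m_new);
    d = d * rescale + w;
    for (size_t j = 0; j < qkv_dim; ++j) {
      acc[j] = acc[j] * rescale + w * v[pos * qkv_dim + j];
    }
    m = m_new;
  }
  for (size_t j = 0; j < qkv_dim; ++j) att_out[j] = acc[j] / d;
  *max_out = m;  // analogous to softmax_max
  *d_out = d;    // analogous to softmax_d
}

Because the state is just (m, d, acc), partial results from independent blocks of positions can be merged later, which is what makes the blocked flash-attention formulation parallelizable.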
FlashAttention(num_tokens, /*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1, - layer_idx, layer, activations, qbatch, env.ctx); + layer_idx, layer.query_norm_scale, activations, qbatch, + env.ctx); } SumHeads(layer, activations, env); } diff --git a/gemma/attention.h b/gemma/attention.h index 6c4a48e..60e6823 100644 --- a/gemma/attention.h +++ b/gemma/attention.h @@ -29,8 +29,7 @@ namespace gcpp { #define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \ namespace NAMESPACE { \ void PositionalEncodingQK(float* qk, size_t layer_idx, \ - const LayerWeightsPtrs& layer, \ - const AttentionActivations& activations, \ + const AttentionActivationsPtrs& activations, \ ThreadingContext& ctx, size_t worker, size_t pos, \ float mul); \ \ @@ -39,18 +38,18 @@ namespace gcpp { void SingleDotSoftmaxWeightedSum( \ const size_t pos, const size_t start_pos, const size_t last_pos, \ float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, \ - size_t layer_idx, const LayerWeightsPtrs& layer, \ - const AttentionActivations& activations, float* HWY_RESTRICT att, \ + const MatPtr& query_norm_scale, size_t layer_idx, \ + const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, \ float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker); \ \ void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \ - const LayerWeightsPtrs& layer, \ - AttentionActivations& activations, \ + const MatPtr& query_norm_scale, \ + AttentionActivationsPtrs& activations, \ QBatch& qbatch, ThreadingContext& ctx); \ \ void GemmaAttention(size_t num_tokens, const size_t layer_idx, \ const LayerWeightsPtrs& layer, \ - AttentionActivations& activations, QBatch& qbatch, \ + AttentionActivationsPtrs& activations, QBatch& qbatch, \ MatMulEnv& env, int flags); \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE diff --git a/gemma/attention_test.cc b/gemma/attention_test.cc new file mode 100644 index 0000000..53f1d01 --- /dev/null +++ b/gemma/attention_test.cc @@ -0,0 +1,570 @@ +#include +#include // strcmp +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "compression/types.h" // GEMMA_DISABLED_TARGETS +#include "gemma/activations.h" +#include "gemma/gemma.h" +#include "gemma/gemma_args.h" +#include "gemma/kv_cache.h" +#include "gemma/weights.h" +#include "ops/matmul.h" +#include "util/mat.h" +#include "util/threading_context.h" +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#ifndef HWY_DISABLED_TARGETS +// These tests aren't designed to suss out instruction set specific problems. +// Disable most targets to keep the tests fast and simple and not have to +// worry about tolerances on floating point results. 
+#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma/attention_test.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "compression/compress-inl.h" +#include "gemma/attention.h" +#include "gemma/configs.h" +#include "util/test_util.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +void FillRandom(MatPtrT& mat, uint64_t seed) { + hwy::RandomState rng(seed); + for (size_t r = 0; r < mat.Rows(); ++r) { + float* row = mat.Row(r); + for (size_t c = 0; c < mat.Cols(); ++c) { + row[c] = static_cast(RandomGaussian(rng)); + } + } +} + +void AllocateAndFillRandom(MatPtr& mat, const Allocator& allocator, + std::vector& mat_owners, uint64_t seed) { + if (mat.IsEmpty()) return; + if (mat.GetType() == Type::kUnknown) { + mat.SetType(Type::kF32); + } + mat_owners.emplace_back(); + mat_owners.back().AllocateFor(mat, allocator, MatPadding::kPacked); + MatPtrT mat_f32(mat); + FillRandom(mat_f32, seed); +} + +struct TestState { + TestState() : ctx({}), env(ctx) {} + ThreadingContext ctx; + std::vector mat_owners; + MatMulEnv env; +}; + +struct TestModelState { + TestModelState(TestState& state) + : config(Model::GEMMA2_2B, Type::kF32, PromptWrapping::GEMMA_PT), + tensor_info_registry(config), + layer_config(config.layer_configs[0]), + layer(0, layer_config, tensor_info_registry) { + config.att_cap = 1024.0f; + AllocateAndFillRandom(layer.qkv_einsum_w, state.ctx.allocator, + state.mat_owners, 42); + AllocateAndFillRandom(layer.attn_vec_einsum_w, state.ctx.allocator, + state.mat_owners, 43); + AllocateAndFillRandom(layer.gating_einsum_w, state.ctx.allocator, + state.mat_owners, 44); + AllocateAndFillRandom(layer.linear_w, state.ctx.allocator, state.mat_owners, + 45); + layer.Fixup(state.mat_owners, state.ctx); + } + + ModelConfig config; + TensorInfoRegistry tensor_info_registry; + const LayerConfig& layer_config; + LayerWeightsPtrs layer; +}; + +struct TestAttentionState { + TestAttentionState(TestState& state, TestModelState& model_state, + size_t num_tokens, size_t qbatch_size, + AttentionImpl attention_impl) + : num_tokens(num_tokens), + qbatch_size(qbatch_size), + batch_size(qbatch_size * num_tokens), + runtime_config{.attention_impl = attention_impl}, + tokens(num_tokens), + attention_storage_(model_state.config, model_state.layer_config, + batch_size, num_tokens, runtime_config, + state.ctx.allocator, row_ptrs_), + attention(model_state.config, num_tokens, attention_storage_) { + for (size_t i = 0; i < qbatch_size; ++i) { + kv_caches.emplace_back(model_state.config, inference_args, + state.ctx.allocator); + } + activations.emplace( + runtime_config, model_state.config, runtime_config.prefill_tbatch_size, + kv_caches[0].SeqLen(), state.env.ctx, state.env.row_ptrs); + // Tokens don't matter, since we fill in pre_att_rms_out before calling + // GemmaAttention. 
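Annotation: the test fixtures above fill weights and pre_att_rms_out with seeded Gaussian noise and, further down, compare attention outputs and KV-cache contents against golden arrays using hwy::CompareArraySimilar with a per-target tolerance from GetTolerance(). The helper below is only a stand-in for that comparison style (the Highway helper's exact rule may differ); looser tolerances per target account for BF16 storage and differing FMA/reduction orders across instruction sets.

#include <algorithm>
#include <cmath>
#include <cstddef>

// Tolerance-based element-wise comparison in the spirit of the golden checks.
bool SimilarArraysSketch(const float* expected, const float* actual, size_t n,
                         double tolerance) {
  for (size_t i = 0; i < n; ++i) {
    const double e = expected[i];
    const double a = actual[i];
    const double scale = std::max(1.0, std::abs(e));
    if (std::abs(e - a) > tolerance * scale) return false;
  }
  return true;
}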
+ std::iota(tokens.begin(), tokens.end(), 1); + for (size_t i = 0; i < qbatch_size; ++i) { + prompts.emplace_back(tokens); + } + all_queries.emplace(prompts, + hwy::Span(kv_caches.data(), kv_caches.size())); + qbatch.emplace(/*start=*/0, /*max_size=*/qbatch_size, *all_queries); + FillRandom(attention.pre_att_rms_out, 46); + } + + const size_t num_tokens; + const size_t qbatch_size; + const size_t batch_size; + InferenceArgs inference_args; + RuntimeConfig runtime_config; + std::vector kv_caches; + std::optional activations; + std::vector tokens; + std::vector prompts; + std::optional all_queries; + std::optional qbatch; + std::vector> row_ptrs_; + AttentionActivations attention_storage_; + AttentionActivationsPtrs attention; +}; + +double GetTolerance() { + const char* target_name = hwy::TargetName(HWY_TARGET); + if (strncmp(target_name, "AVX2", 4) == 0) { + return 2e-2; + } else if (strncmp(target_name, "AVX3", 4) == 0) { + return 3e-4; + } else if (strncmp(target_name, "NEON", 4) == 0) { + return 5e-3; + } else { + return 1e-7; + } +} + +template +void CompareAttSumsWithGolden( + const AttentionActivationsPtrs& attention, + const float (&golden)[kNumTokens][kQBatchSize][kDims]) { + ASSERT_EQ(attention.att_sums.Rows(), kNumTokens * kQBatchSize); + ASSERT_LE(kDims, attention.att_sums.Cols()); + + hwy::AlignedFreeUniquePtr actual_row = + hwy::AllocateAligned(kDims); + for (size_t token_idx = 0; token_idx < kNumTokens; ++token_idx) { + for (size_t qi = 0; qi < kQBatchSize; ++qi) { + const size_t i = token_idx * kQBatchSize + qi; + for (size_t j = 0; j < kDims; ++j) { + actual_row[j] = hwy::F32FromBF16(attention.att_sums.Row(i)[j]); + } + EXPECT_TRUE(hwy::CompareArraySimilar( + golden[token_idx][qi], actual_row.get(), kDims, GetTolerance(), + hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) + << "att_sums mismatch for token_idx=" << token_idx << " qi=" << qi; + } + } +} + +template +void CompareKVCacheWithGolden( + const ModelConfig& config, hwy::Span kv_caches, const size_t layer, + const size_t kv_head, + const float (&k_golden)[kNumTokens][kQBatchSize][kDims], + const float (&v_golden)[kNumTokens][kQBatchSize][kDims]) { + const size_t qbatch_size = kv_caches.size(); + ASSERT_EQ(kQBatchSize, qbatch_size); + const size_t start_offset = 0; + const size_t qkv_dim = config.layer_configs[0].qkv_dim; + + hwy::AlignedFreeUniquePtr actual_k_row = + hwy::AllocateAligned(kDims); + hwy::AlignedFreeUniquePtr actual_v_row = + hwy::AllocateAligned(kDims); + + const size_t cache_layer_size = config.layer_configs[layer].CacheLayerSize(); + const size_t head_offset = kv_head * qkv_dim * 2; + const size_t kv_offset = layer * cache_layer_size + head_offset; + + for (size_t token_idx = 0; token_idx < kNumTokens; ++token_idx) { + for (size_t qi = 0; qi < kQBatchSize; ++qi) { + const float* cache_row = + kv_caches[qi].kv_cache.Row(start_offset + token_idx); + for (size_t j = 0; j < kDims; ++j) { + actual_k_row[j] = cache_row[kv_offset + j]; + actual_v_row[j] = cache_row[kv_offset + qkv_dim + j]; + } + EXPECT_TRUE(hwy::CompareArraySimilar( + k_golden[token_idx][qi], actual_k_row.get(), kDims, GetTolerance(), + hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) + << "K cache mismatch for token_idx=" << token_idx << " qi=" << qi + << " kv_head=" << kv_head; + EXPECT_TRUE(hwy::CompareArraySimilar( + v_golden[token_idx][qi], actual_v_row.get(), kDims, GetTolerance(), + hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) + << "V cache mismatch for token_idx=" << token_idx << " qi=" << qi + << " kv_head=" << kv_head; + } 
+ } +} + +template +void CompareQVecsWithGolden( + const ModelConfig& config, const AttentionActivationsPtrs& attention, + const size_t q_head, + const float (&q_golden)[kNumTokens][kQBatchSize][kDims]) { + const size_t qkv_dim = config.layer_configs[0].qkv_dim; + + hwy::AlignedFreeUniquePtr actual_q_row = + hwy::AllocateAligned(kDims); + + const size_t head_offset = q_head * qkv_dim; + for (size_t token_idx = 0; token_idx < kNumTokens; ++token_idx) { + for (size_t qi = 0; qi < kQBatchSize; ++qi) { + const size_t i = token_idx * kQBatchSize + qi; + const float* q_row = attention.q.Row(i); + for (size_t j = 0; j < kDims; ++j) { + actual_q_row[j] = q_row[head_offset + j]; + } + EXPECT_TRUE(hwy::CompareArraySimilar( + q_golden[token_idx][qi], actual_q_row.get(), kDims, GetTolerance(), + hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) + << "Q vec mismatch for token_idx=" << token_idx << " qi=" << qi + << " q_head=" << q_head; + } + } +} + +constexpr size_t kNumTokens = 10; +constexpr size_t kQBatchSize = 2; + +// Only compare the first 10 dimensions of the attention output to keep the +// test manageable. It's unlikely that the first few dimensions would have +// different behavior than the later dimensions. +const size_t kDimsToCompare = 17; // greater than AVX-512 vector of floats + +// Use a Coding Agent to update the golden values from test +// failure output (which includes the new values). +// Instructions to the Coding Agent: You must never update the +// EMU128 tolerance. Always use the value from the EMU128 test to update the +// Goldens. If necessary, add relaxed tolerance for other instruction sets. + +// Layer 0 +const float kGoldenAttSums[kNumTokens][kQBatchSize][kDimsToCompare] = { + {{46.5, 56.5, 10.0625, 65.5, -2.109375, 135, 15.8125, 51, -100, 52.5, + 26.875, 63, 3.34375, -67.5, 31.125, -190, 125}, + {-30.375, -17.875, 51.75, -78, -84, 6.40625, 15.375, 70, -22.875, 20.125, + -14.9375, -109.5, 76, 9.25, -142, 29.5, -105}}, + {{-32.75, 38.25, 78.5, 107.5, 20.25, 197, -136, 42.5, -84, 25.625, 4.96875, + 128, 27.25, -161, 19.125, -58, 97.5}, + {-18.5, -18, 135, -13.4375, -6.625, -45.75, 29.625, 93, 18.625, 75.5, + 102.5, -184, 52.75, 83.5, -71, 46.5, -52}}, + {{-16.375, -61.5, -58.25, -27.375, -28, 71, -109.5, 60.25, 3.125, -29.125, + 6.90625, 150, 144, -155, -47.25, -98.5, 3.5625}, + {-19, -16.75, 129, 0.59765625, -82, 123.5, 60.75, -36.75, -77, 26.625, 51, + -66.5, -0.84765625, -46.5, -152, -2.9375, -81}}, + {{3.984375, 83, -41.75, 39.5, -203, 110, -76, 131, 0.4609375, -44.5, -63.75, + -46, -22, -19.375, -16.125, -148, 20.875}, + {-47, -19.5, 58, 81.5, 21.75, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33, + 10.9375, -52.5, 23.25, 75}}, + {{64, -31, -89, -92.5, -11.1875, -54.75, -302, 3.453125, -108, 39.25, + -34.75, 18, -52, 100, -186, -75.5, 50.75}, + {7.6875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.4375, 82.5, + 39.25, 65, 47.25, -89.5, -34.25, 137}}, + {{39.75, 17.875, 115, 38.75, -44, 139, -53.25, -23.875, -13.0625, 38.5, + 32.5, 53.75, 109, 4.09375, 57.5, -20.5, 132}, + {143, 249, 5.09375, 0.83984375, 27.875, -5.84375, 30.25, -101.5, 65.5, + 13.5, 195, -10.0625, 97.5, 2.203125, -97.5, -100, -19.25}}, + {{-30.125, -169, -150, 58, -35.75, 22.75, 36.5, -32.25, -8.9375, 55.25, + -117, 26.375, 39.5, 125, 66, 48.75, 20.75}, + {137, 5.25, 61.25, 37, -42.75, 240, 62, -164, 11.3125, 173, 174, 23.5, + 88.5, 48.5, -46.25, -36.75, 101.5}}, + {{-103, -47.5, 39, -48, -67.5, 121, -136, 99, 80, -47.5, 107.5, 48.75, 97.5, + 125, -53.5, -14.625, 262}, + {29.875, 7.34375, -36.75, 
-14.5, -27.5, 44.75, -67.5, -40.75, 71.5, 172, + 81, -27.25, -3.03125, 111, -167, 59, 176}}, + {{-37.25, 109.5, -26.125, -115.5, 108, 57.25, 1.3671875, 72, -122.5, 59.25, + -52, -12.625, 43.25, 16.25, -41.75, 26.5, 70.5}, + {40.25, 53.25, -142, 78.5, 38, 4.3125, -27.75, -134, -85, 107.5, 2.5, 93.5, + 58.25, 173, -53.5, 25.125, 4.8125}}, + {{-8.4375, -35, -35.5, 131, -33.25, 106, 109.5, -92, -135, 80, 21.5, + -17.125, 15.25, 143, -27, 103, 101}, + {-77, 40.75, -10.125, 33.25, -33, 104, -7.6875, 85.5, -40, 93, 61, 14.5625, + 8.125, -99.5, 13.6875, -11.6875, 33}}, +}; + +// Layer 0, *K*V Head 0 +const float kGoldenK[kNumTokens][kQBatchSize][kDimsToCompare] = { + {{-4.51717567, 6.93118095, 6.48003578, 9.12825584, 2.38755274, 11.8121576, + 1.65376127, 5.04456615, -7.19549274, 2.57609844, 3.55331731, -3.48494458, + -8.90498638, 9.66047478, -0.379868984, 6.37043715, -2.24351144}, + {0.152208567, 3.14520073, -8.35154343, 5.44226503, -6.74000502, + -1.43484437, -4.72092056, -9.48932, -6.12409401, -1.55352509, -3.90701318, + 2.12124252, 3.93649936, -8.09877586, -3.30277514, -0.898857355, + 1.76684189}}, + {{4.378829, 5.05565643, -7.63948059, -5.74608946, 2.90109587, 0.155819178, + 4.56115055, 1.37885749, 1.48427355, -1.07145202, 2.82399392, -1.20864201, + 3.05434561, -2.65185618, -0.0731391, -8.2279253, 7.63228416}, + {-0.702698231, 1.49563932, 6.42149782, -6.68306589, 1.85317755, + -7.70267582, 2.07357907, -7.60303402, -0.514724255, 0.308567047, + 5.99250412, -4.67359257, -3.49322176, -2.62086344, -3.18411255, + 2.04027057, -4.29057407}}, + {{-1.20844436, 4.14724302, 6.04515219, 8.7753458, -0.975198627, 0.564640105, + 5.39941597, 4.64036179, 0.366614938, 3.48258138, -0.470701456, 15.2267399, + 4.63302803, 9.12662697, -5.89148045, 2.25731587, 5.24449492}, + {4.57078934, -4.60315752, -3.3364439, 1.29875994, -3.40833569, -6.95262, + -6.39040232, -6.60212612, 6.63269806, -0.815209687, -5.0346446, + -4.13564968, 8.25674057, -6.0910182, -8.21130085, -8.91020393, + 10.6188011}}, + {{0.602011144, 2.22505236, 3.62411499, -4.07026958, 12.8036356, 3.76139069, + 6.99502087, 7.02500725, -2.51568675, 4.2489934, 0.00210827589, + -1.43267739, -2.10394144, -0.0506809056, -1.54883039, 4.3740139, + -1.61869526}, + {-6.37204599, -3.34989691, 2.10935307, 4.23634195, 5.79134035, 13.502944, + -2.19158888, -1.55771351, -1.22244942, 3.36499929, -2.11375904, + -4.5448761, 1.0611912, -2.47849369, -0.212709218, 0.363292456, + 7.91467094}}, + {{-8.85739231, -4.08585882, -0.618261, 6.52911091, 5.14922285, 7.6869874, + 0.750387549, -0.812200725, 2.7509625, 6.29693508, -1.77248931, 5.68896484, + -6.9369607, -4.61359406, 0.184977874, -1.27769828, -2.1619854}, + {-8.2555, 2.84032059, -1.03791106, 2.07648611, -4.94546843, 1.76888537, + -1.75901175, 11.2628574, 1.41086221, -3.58669901, -2.85925198, 2.29133463, + 1.55509436, -0.0553357825, -10.0363655, 1.94261, -2.95691729}}, + {{0.919141412, 1.97533965, -11.3202848, -3.3137629, -4.7161727, 5.07012081, + 1.76256621, 8.20588207, 6.05700159, -3.89765406, -1.13639557, -1.32326794, + -3.01544905, -0.585309267, 2.60637712, 2.83708405, -3.39202118}, + {9.11918, 2.11261511, -5.87290621, 11.6033278, -4.66597795, -7.13774204, + -9.10563755, -2.48294282, 3.35282946, -3.75122213, 0.404774547, + -9.11625195, 4.85711479, 1.43184578, 1.47673059, -4.75093, -3.45323014}}, + {{4.17705393, -4.95192289, -10.5068378, 3.90004015, -3.51306129, 5.38068056, + 0.901511431, 11.222868, 2.67285442, 9.18779, 5.61346769, 3.06534624, + -3.78898215, 0.767340839, 15.8207836, -4.14079094, -4.63177109}, + {3.61795235, 
-7.00262165, 2.08284521, -6.70515728, 1.93205631, 2.84467721, + 3.94591737, -6.18882942, -1.78465152, -9.39100933, -10.8780289, + 6.32468653, 6.53142738, -3.30765963, 2.89132166, 4.53347206, 1.89792418}}, + {{-0.361971855, -1.57735932, 5.07296801, -1.55669761, -1.44996238, + 7.29838896, 5.23075104, -0.512441278, -3.59834242, 2.38584423, 6.48518324, + -1.48220074, -2.4264791, 10.7237988, 5.64735842, 5.6251297, -7.04244423}, + {-0.795628309, 7.30230665, -1.71035647, -16.6999454, 3.05102086, + -4.9243927, 4.28508186, -0.694577456, 6.58464718, 4.40330124, 3.3250041, + 1.90579033, -6.29048729, 2.55308104, -4.9746747, -0.681708, -5.98152351}}, + {{2.57555652, -3.5651083, 0.784440041, -4.7043705, 2.37520599, -3.62385964, + -3.48913693, -7.28049421, -5.48726082, 1.95519221, 7.25192928, 3.07074118, + -11.9897156, 5.92244673, 5.07564354, 0.162699938, -6.00809956}, + {5.56260443, -5.7683115, 1.26402235, -17.507719, 4.18873024, -3.20694613, + -4.42512083, 1.78077614, -3.25167561, 0.864362717, 0.474019766, + -7.92327404, -2.27795148, -0.436354101, -3.15722394, 0.415780187, + 2.60931611}}, + {{-9.43858051, 0.391518891, -2.74012518, 4.9842453, 7.48263216, -16.3434925, + -4.75156116, -1.99114823, 3.99918842, -5.95400572, 10.8700314, 1.07596064, + 0.30389142, 8.39548779, -5.11913681, 5.45641088, -5.63240337}, + {-1.22347319, 9.57339382, -1.31736016, -5.02770805, -4.81617355, + -1.96618557, -0.456317186, 12.6451035, -1.50221801, 6.7991147, + -5.97842169, 1.85410941, -8.44729, 0.378282309, 0.0442156792, 17.6773052, + -7.43491}}, +}; + +// Layer 0, K*V* Head 0 +const float kGoldenV[kNumTokens][kQBatchSize][kDimsToCompare] = { + {{2.77553034, -7.67514181, -1.60433948, 4.67795134, -1.75084186, 8.57896423, + -1.15065813, -3.75088787, -4.7442131, -1.68890858, -10.0202332, + -4.20167446, 9.36844635, 13.7364845, 11.5634, 2.95288706, 2.89380026}, + {-4.79950905, -1.66658688, 4.14471292, -4.95649052, -5.4200325, 3.52626801, + -10.9432049, 0.338347554, -1.53204226, 0.473476171, -0.58271, 1.42195463, + 0.301399827, -4.40214968, -2.12298298, 9.27825642, -0.690600872}}, + {{-10.6566734, 4.12785721, 4.54053593, -1.39667869, -1.55028772, 0.20508635, + -0.00620913506, 2.93214, -0.788117647, 2.78032446, -2.68898249, 9.5985508, + -10.6630878, -11.9006901, 0.851743698, 0.581826329, 5.21927929}, + {-0.322291255, 2.63848567, -2.30808377, -13.0153809, 2.74378228, + 3.21460533, 0.688529968, 2.37544608, 6.06825066, 4.57566404, 1.17124248, + -7.96587658, -2.65279341, 4.75271225, -4.09937954, -10.3570251, + 3.30500841}}, + {{-3.34342527, 6.03099537, 6.335958, 0.993818045, 0.905343294, 6.93058586, + 3.9635396, 10.8044815, 7.8620863, -10.1157322, -3.92666101, -0.183003783, + -5.27309418, -1.45110512, -8.96734, -2.63866425, 2.19913912}, + {16.416317, -1.62025332, 2.3161006, 3.32571959, -1.79581594, -10.2925539, + -5.86338425, -6.36642933, 9.18872166, 5.95524168, 6.38640785, 8.23832, + -6.57342291, -14.2017632, 1.10925388, 4.27255058, -2.65661311}}, + {{6.58254147, -6.96165133, -4.97437, -2.33467388, 5.83671236, -0.794236898, + -2.03117108, -3.93387103, -5.96872902, 5.83316422, 3.01795, -4.05260706, + -4.39556885, 3.24399853, 10.1573639, 4.71967888, 0.274738848}, + {7.13243389, -8.04649162, 2.53055143, 2.0771277, -0.667295456, -13.0285645, + 0.960428238, -2.11983275, 8.18105602, -6.72609901, -5.46944714, + 0.204244614, 0.0900330544, 8.86620903, 4.63697529, 3.19756651, + 2.99392676}}, + {{9.52539158, -4.3840766, -6.94514465, -2.75913763, -10.8364506, + -3.95606327, 2.43603897, -5.78482246, -0.801304817, 8.23436832, + -7.11484337, 
2.53943753, -0.652261257, 9.77392, 3.53345847, -9.62052822, + 16.0471916}, + {6.89768124, 2.36394405, -2.08569574, -0.682706833, 3.38872, -6.28313875, + 4.79594612, 4.93417454, -6.40791416, -10.7355442, -5.66094208, 2.44881392, + 1.99794042, -9.19855404, -4.02383137, -3.63013959, -5.65853405}}, + {{1.64614546, -3.93421197, -0.48935914, 5.48284435, -7.69781828, 11.8203125, + 1.81672478, -1.42535269, -5.26496315, -5.31612349, -4.19499826, + 7.06049395, 0.18029356, -0.0519902706, 10.317358, 2.19345617, 3.5296216}, + {7.52353811, 3.56836724, 0.414305687, 0.340799928, 2.44263697, 7.52111912, + 0.246491909, -11.1172791, -3.82061529, 3.24794388, 0.751524329, + 3.14019632, 6.33881855, -0.169233799, 7.82640171, 1.5389179, 8.15851307}}, + {{-2.48950672, -8.55112743, 8.04663277, -5.77116871, -0.637019753, + -7.65882111, -7.49037457, 3.8041625, -3.57038307, 9.37715435, -6.42604256, + 1.62610793, -1.54000568, 2.52110147, 5.30775261, -4.10454893, + -4.96251774}, + {-2.95554614, -5.18210888, 1.00015664, -4.03864431, -7.14954519, + 5.99929142, 5.86350155, 2.03810191, -4.23009968, 9.39885902, -5.68198299, + 2.72845244, 11.7133255, 0.838779449, -13.2235403, 2.94607735, + -2.7902379}}, + {{2.86876941, -0.836064458, -0.374509573, -0.277966499, 3.20654631, + -3.68510771, -7.76134634, 2.23905277, -8.35530376, 5.25071716, + -1.38490796, -2.93542218, 0.509032726, -3.57361269, -2.82580233, + -4.49954033, 2.91235542}, + {-4.37938213, 4.78577232, 2.03453469, 5.48564529, -1.05589461, -1.65940428, + 4.0130887, 5.26074123, 4.67537832, 0.791350365, 6.3880868, 2.50402451, + 7.6603322, -3.16343474, -2.71949649, 4.61576128, 1.3817997}}, + {{0.289200783, 7.06031752, -1.15099299, -5.29136801, -1.343642, -8.36283112, + 4.13158274, -1.93137062, 3.16199875, 2.21854591, 2.18270063, 0.77002573, + 6.90393353, -0.644045949, -5.62211609, -1.09085155, 1.07821059}, + {-3.04716778, -2.52233481, -5.99031925, 2.80152273, 0.340899587, + 0.667474508, -2.39674735, 8.83768654, -5.45613146, -1.55994594, -2.216362, + 1.49354, -4.27255821, -9.05310917, 5.90691471, -1.29772806, -8.50278}}, + {{-3.1383903, -7.71573353, 3.38072681, 6.07642221, -2.39587545, -7.84178352, + -1.60108304, -8.6121521, -5.151721, 4.17612457, -2.86532378, 1.64645958, + -0.37970829, -4.34561253, -0.454322815, 0.331385136, -5.74550819}, + {4.77026033, -5.51171303, -7.38155365, -5.38462543, 2.95842505, 5.18372536, + 0.521988213, 7.23966122, -4.90852165, 7.18465281, 2.99289083, 10.0519466, + -2.09695673, 7.34368706, -2.40495348, 3.61603308, 0.131510735}}, +}; + +// Layer 0, QHead 0 +const float kGoldenQ[kNumTokens][kQBatchSize][kDimsToCompare] = { + {{-0.574401975, 0.370210886, -0.426894158, -0.543187439, -0.0266762674, + -0.177960411, -0.00839618221, 0.411925405, 0.536462784, 0.528389931, + -0.499812007, -0.123897657, -0.0170236826, 0.266041577, -0.0781469196, + -0.44081074, 0.185976267}, + {0.270543516, -0.109283224, -0.58602041, -0.358663559, -0.393124342, + -0.0895933211, -0.632167816, 0.386703, 0.314152211, 0.0554139167, + 0.0241559595, -0.194484815, 0.143893063, 0.103837147, -0.384245932, + -0.00418212265, 0.385817379}}, + {{-0.0331106335, -0.100827977, 0.322449774, 0.225943685, -0.384854138, + -0.208085626, 0.0206767023, 0.287796348, -0.139513299, 0.255447835, + -0.0845065042, -0.0619940236, 0.477489054, 0.517492294, -0.0172665715, + -0.0302075297, 0.365989387}, + {-0.0266781822, -0.453293771, 0.560033202, 0.105156079, -0.35259968, + 0.711447716, -0.253611088, 0.0487165749, -0.086192511, -0.0338740349, + -0.655441046, 0.00413730741, -0.510472536, -0.0748229772, 
-0.29113093, + -0.0432077348, 0.09223634}}, + {{-0.321974993, -0.466039479, 0.207254037, -0.126807183, -0.192775592, + -0.0953654051, 0.209789664, 0.405356169, -0.00627984107, -0.0590961352, + 0.0907663852, -0.190793216, -0.730463982, 0.340142608, -0.295675993, + -0.165913597, -0.233714506}, + {-0.345578939, 0.394073665, 0.299743414, -0.0075177839, -0.288939595, + 0.127782941, -0.207550645, 0.0655022636, -0.705084503, -0.241842598, + 0.333820701, 0.217911497, 0.29735288, 0.0147881694, -0.152306199, + -0.589594781, -0.373093933}}, + {{0.216089666, 0.0918798149, 0.0560657382, -0.157523662, -0.00141695142, + 0.51770103, 0.596379519, -0.271057904, 0.241035417, -0.275827706, + 0.112851456, 0.026878573, -0.579843462, -0.5116328, 0.192026839, + 0.125176072, 0.34234497}, + {-0.0744233653, 0.180814236, 0.170143247, -0.337861449, -0.175804421, + 0.213403732, -0.173699334, 0.109528325, -0.385727316, 0.109683953, + 0.475667775, 0.253016889, 0.477347463, 0.111096457, 0.394625545, + 0.0172286481, -0.357992649}}, + {{-0.350524545, -0.142550975, -0.212269634, -0.0589753427, -0.434021264, + 0.384472728, 0.445421219, -0.635599554, -0.246593416, 0.120986834, + 0.623568773, -0.161932915, -0.702406883, 0.44038102, 0.268234134, + 0.480264157, 0.103595078}, + {-0.227436215, 0.357608706, -0.25339672, -0.0683218762, -0.179259315, + 0.23657614, 0.559984326, 0.165754288, -0.0402980596, -0.101906747, + -0.278261065, -0.16327399, 0.235923961, -0.428657919, -0.290629387, + 0.579215467, -0.0717103705}}, + {{-0.246389642, -0.266164362, -0.0967710763, -0.4011603, 0.242542207, + 0.0869855583, 0.20158039, 0.207793877, -0.0875666738, -0.242263764, + -0.0462955758, -0.617374003, 0.454443514, 0.207072973, -0.0235372931, + -0.0193868056, -0.660622239}, + {0.703284621, 0.0382430181, 0.43997851, -0.858277559, 0.342218578, + 0.414044619, 0.403636098, -0.579880178, -1.12243, -0.112913512, + 0.629238605, -0.0285760984, -0.152203664, -0.088969171, -0.0681343, + 0.476349175, 0.283238202}}, + {{0.138267457, 0.483219147, 0.230450034, -0.568304598, 0.204461277, + -0.286731184, -0.416590065, -0.483460307, -0.561008453, 0.395195067, + 0.104367018, -0.196090236, -0.324770749, -0.0881370157, -0.626873195, + 0.0936089084, 0.262185335}, + {0.282603383, 0.0723766163, -0.206548154, 0.561849833, 0.482716829, + 0.135281503, -0.438841999, 0.472577304, -0.346201897, -0.0211652666, + -0.0905084163, -0.168639392, -0.154975936, -0.303443581, -0.41771856, + 0.400717318, 0.426146686}}, + {{-0.0537007451, -0.227346331, -0.2871463, 0.247746795, -0.0975416005, + -0.0123391449, 0.0612513907, -0.374673814, 0.283457696, 0.40945363, + 0.137944818, -0.0119741419, 0.775918365, -0.308365196, 0.230615795, + -0.440364927, 0.218536288}, + {0.0688965544, -0.149037778, -0.246169299, 0.0599289536, -0.456733435, + 0.0808929354, 0.115154952, 0.0997388735, -0.408117741, 0.576600909, + -0.193775773, 0.0340575948, -0.29254055, 0.695465446, 0.373336494, + 0.421431482, 0.00197479129}}, + {{0.402076721, -0.118151993, 0.542394996, 0.0382412486, -0.614983976, + 0.28617692, 0.318540633, -0.299300969, -0.177486539, 0.394140214, + 0.0644133314, -0.0321308076, 0.671587527, -0.0173831787, -0.219400048, + -0.340277791, 0.5130288}, + {0.105372488, -0.145784974, 0.0695323348, -0.106080391, -0.755512118, + 0.975362539, -0.15056029, 0.58882606, -0.059625227, -0.810613, + -0.321623206, 0.193939567, 0.0340242684, -0.626081824, 0.109950632, + -0.141072854, 0.0177994221}}, + {{0.243249148, 0.0904035419, -0.472183734, -0.176162, 0.314925164, + -0.191137731, 0.492265761, -0.0120046511, 
0.824757636, 0.298175, + 0.148151726, -0.0197859108, -0.64297086, 0.432318538, -0.555079758, + 0.101636633, 0.155741245}, + {0.0523641109, 0.224086404, 0.0143201668, 0.0090854, 0.304901183, + -0.391372293, 0.267655343, 0.117368169, 0.645064473, 0.336050332, + -0.282133281, -0.231817603, 0.376230389, -0.575031936, -0.628365576, + 0.484799922, 0.0824087635}}, +}; + +void RunAttentionTest(AttentionImpl attention_impl) { + TestState state; + TestModelState model_state(state); + TestAttentionState attention_state(state, model_state, kNumTokens, + kQBatchSize, attention_impl); + + GemmaAttention(attention_state.tokens.size(), 0, model_state.layer, + attention_state.attention, *attention_state.qbatch, state.env, + AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16)); + + CompareAttSumsWithGolden(attention_state.attention, kGoldenAttSums); + CompareKVCacheWithGolden(model_state.config, + hwy::Span(attention_state.kv_caches.data(), + attention_state.kv_caches.size()), + /*layer=*/0, /*kv_head=*/0, kGoldenK, kGoldenV); + CompareQVecsWithGolden(model_state.config, attention_state.attention, + /*q_head=*/0, kGoldenQ); +} + +void TestGemmaAttentionOld() { RunAttentionTest(AttentionImpl::kOld); } + +void TestGemmaAttentionFlash() { RunAttentionTest(AttentionImpl::kFlash); } + +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace gcpp { +HWY_BEFORE_TEST(AttentionTest); +HWY_EXPORT_AND_TEST_P(AttentionTest, TestGemmaAttentionOld); +HWY_EXPORT_AND_TEST_P(AttentionTest, TestGemmaAttentionFlash); +HWY_AFTER_TEST(); + +} // namespace gcpp + +#endif diff --git a/gemma/bindings/context.cc b/gemma/bindings/context.cc index 5741d70..5db3adc 100644 --- a/gemma/bindings/context.cc +++ b/gemma/bindings/context.cc @@ -73,45 +73,38 @@ GemmaContext* GemmaContext::Create(const char* tokenizer_path, ThreadingArgs threading_args; threading_args.spin = gcpp::Tristate::kFalse; - LoaderArgs loader(tokenizer_path, weights_path); - LogDebug("LoaderArgs created"); + threading_args.spin = gcpp::Tristate::kFalse; + GemmaArgs args(LoaderArgs(tokenizer_path, weights_path), threading_args); // Initialize cached args LogDebug("Initializing inference args"); - InferenceArgs inference_args; - inference_args.Init(); - inference_args.max_generated_tokens = max_generated_tokens; - inference_args.temperature = 0.7f; - inference_args.top_k = 1; - inference_args.deterministic = false; + args.inference.max_generated_tokens = max_generated_tokens; + args.inference.temperature = 0.7f; + args.inference.top_k = 1; + args.inference.deterministic = false; ss.str(""); ss << "Inference args initialized with max_tokens: " << max_generated_tokens - << ", temperature: " << inference_args.temperature - << ", top_k: " << inference_args.top_k << ", deterministic: " - << (inference_args.deterministic ? "true" : "false"); + << ", temperature: " << args.inference.temperature + << ", top_k: " << args.inference.top_k << ", deterministic: " + << (args.inference.deterministic ? 
"true" : "false"); LogDebug(ss.str().c_str()); - return new GemmaContext(loader, inference_args, threading_args, - max_generated_tokens); + return new GemmaContext(args, max_generated_tokens); } -GemmaContext::GemmaContext(const LoaderArgs& loader, - const InferenceArgs& inference_args, - const ThreadingArgs& threading_args, - int max_generated_tokens) - : inference_args(inference_args), - threading_args(threading_args), - ctx(threading_args), +GemmaContext::GemmaContext(const GemmaArgs& args, int max_generated_tokens) + : args(args), + ctx(args.threading), matmul_env(ctx), active_conversation_name("default"), - model(loader, inference_args, matmul_env.ctx) { + model(args, matmul_env.ctx) { std::stringstream ss; LogDebug("Creating initial ConversationData"); // Create the initial ConversationData object using make_shared active_conversation = std::make_shared( - model.Config(), inference_args, ctx.allocator); + model.Config(), args.inference, ctx.allocator); LogDebug( "Storing initial ConversationData in conversation_cache[\"default\"]"); @@ -172,8 +165,8 @@ int GemmaContext::GenerateInternal(const char* prompt_string, // set up runtime config TimingInfo timing_info = {}; RuntimeConfig runtime_config = {.stream_token = stream_token, - .use_spinning = threading_args.spin}; - inference_args.CopyTo(runtime_config); + .use_spinning = args.threading.spin}; + args.inference.CopyTo(runtime_config); size_t prefix_end = 0; const ModelConfig& model_config = model.Config(); @@ -247,7 +240,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string, timing_info); // prepare for next turn - if (!inference_args.multiturn || + if (!args.inference.multiturn || model_config.wrapping == PromptWrapping::PALIGEMMA) { // If not multiturn, or Paligemma (which handles turns differently), // reset the *active* conversation's position. 
diff --git a/gemma/bindings/context.h b/gemma/bindings/context.h index 00648fc..5aa3412 100644 --- a/gemma/bindings/context.h +++ b/gemma/bindings/context.h @@ -53,8 +53,7 @@ typedef void (*GemmaLogCallback)(const char* message, void* user_data); class GemmaContext { private: - GemmaContext(const LoaderArgs& loader, const InferenceArgs& inference_args, - const ThreadingArgs& threading_args, int max_generated_tokens); + GemmaContext(const GemmaArgs& args, int max_generated_tokens); public: static GemmaContext* Create(const char* tokenizer_path, @@ -81,37 +80,37 @@ class GemmaContext { // Set max generated tokens void SetMaxGeneratedTokens(size_t value) { - inference_args.max_generated_tokens = value; + args.inference.max_generated_tokens = value; LogDebug("Setting max_generated_tokens to configured value"); } // Set multiturn flag (0 = disabled, 1 = enabled) void SetMultiturn(int value) { - inference_args.multiturn = value; + args.inference.multiturn = value; LogDebug("Setting multiturn to configured value"); } // Set temperature for token generation void SetTemperature(float value) { - inference_args.temperature = value; + args.inference.temperature = value; LogDebug("Setting temperature to configured value"); } // Set top_k parameter for sampling void SetTopK(int value) { - inference_args.top_k = value; + args.inference.top_k = value; LogDebug("Setting top_k to configured value"); } // Set deterministic flag void SetDeterministic(bool value) { - inference_args.deterministic = value; + args.inference.deterministic = value; LogDebug("Setting deterministic flag to configured value"); } // Set prefill_tbatch_size void SetPrefillTbatchSize(size_t value) { - inference_args.prefill_tbatch_size = value; + args.inference.prefill_tbatch_size = value; LogDebug("Setting prefill_tbatch_size to configured value"); } @@ -175,7 +174,7 @@ class GemmaContext { active_conversation->abs_pos = 0; // Replace the cache within the current ConversationData object active_conversation->kv_cache = std::make_unique( - model.Config(), inference_args, ctx.allocator); + model.Config(), args.inference, ctx.allocator); LogDebug((log_prefix + "Successfully rewound to initial state.").c_str()); } else { @@ -193,7 +192,7 @@ class GemmaContext { LogDebug("Creating new conversation"); // Create a new ConversationData object using make_shared conversation_cache[name] = std::make_shared( - model.Config(), inference_args, ctx.allocator); + model.Config(), args.inference, ctx.allocator); return true; } @@ -274,8 +273,7 @@ class GemmaContext { std::vector token_buffer; // Cached args (remain global for the context) - InferenceArgs inference_args; - ThreadingArgs threading_args; + GemmaArgs args; ThreadingContext ctx; MatMulEnv matmul_env; diff --git a/gemma/configs.cc b/gemma/configs.cc index 8856203..cb508e8 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -22,8 +22,8 @@ #include #include "compression/types.h" // Type -#include "io/fields.h" // IFields -#include "io/io.h" // Path +#include "io/fields.h" // IFields +#include "io/io.h" // Path #include "hwy/base.h" namespace gcpp { @@ -238,6 +238,7 @@ static ModelConfig ConfigGemma3_1B() { config.display_name = "Gemma3_1B"; config.model = Model::GEMMA3_1B; config.wrapping = PromptWrapping::GEMMA_VLM; + config.use_global_timescale = true; config.model_dim = 1152; config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer config.max_seq_len = 32 * 1024; @@ -288,6 +289,7 @@ static ModelConfig ConfigGemma3_4B() { config.display_name = "Gemma3_4B"; config.model = 
Model::GEMMA3_4B; config.wrapping = PromptWrapping::GEMMA_VLM; + config.use_global_timescale = true; AddVitConfig(config, /*image_size=*/896); config.vocab_size = kGemmaV3VocabSize; config.vit_config.pool_dim = 4; @@ -337,6 +339,7 @@ static ModelConfig ConfigGemma3_12B() { config.display_name = "Gemma3_12B"; config.model = Model::GEMMA3_12B; config.wrapping = PromptWrapping::GEMMA_VLM; + config.use_global_timescale = true; AddVitConfig(config, /*image_size=*/896); config.vocab_size = kGemmaV3VocabSize; config.vit_config.pool_dim = 4; @@ -386,6 +389,7 @@ static ModelConfig ConfigGemma3_27B() { config.display_name = "Gemma3_27B"; config.model = Model::GEMMA3_27B; config.wrapping = PromptWrapping::GEMMA_VLM; + config.use_global_timescale = true; AddVitConfig(config, /*image_size=*/896); config.vocab_size = kGemmaV3VocabSize; config.vit_config.pool_dim = 4; @@ -495,19 +499,19 @@ const char* ModelPrefix(Model model) { } PromptWrapping ChooseWrapping(const Model model, Tristate wrapping) { - if (IsPaliGemma(model)) { + const PromptWrapping config_wrapping = ConfigFromModel(model).wrapping; + + // For models with a fixed wrapping mode, ignore user override. + if (config_wrapping == PromptWrapping::PALIGEMMA || + config_wrapping == PromptWrapping::GEMMA_VLM) { if (wrapping != Tristate::kDefault) { - HWY_WARN("Ignoring unnecessary --wrapping for PaliGemma models."); + HWY_WARN("Ignoring unnecessary --wrapping for model %s.", + ModelPrefix(model)); } - return PromptWrapping::PALIGEMMA; + return config_wrapping; } - if (IsVLM(model)) { - if (wrapping != Tristate::kDefault) { - HWY_WARN("Ignoring unnecessary --wrapping for VLM models."); - } - return PromptWrapping::GEMMA_VLM; - } - // Default to IT unless --wrapping=0. + + // For other models, default to IT unless --wrapping=0 is passed. return wrapping == Tristate::kFalse ? PromptWrapping::GEMMA_PT : PromptWrapping::GEMMA_IT; } @@ -674,7 +678,9 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) { return Model::GEMMA3_270M; case 26: - if (layer_types & kDeducedViT) return Model::GEMMA3_1B; + if (layer_types & (kDeducedViT|kDeducedKqNorm)) { + return Model::GEMMA3_1B; + } return Model::GEMMA2_2B; case 27: return (layer_types & kDeduced448) ? Model::PALIGEMMA2_3B_448 @@ -706,4 +712,11 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) { } } +AttentionImpl GetAttentionImpl(const std::string& impl) { + if (impl == "old") return AttentionImpl::kOld; + if (impl == "flash") return AttentionImpl::kFlash; + HWY_WARN("Unknown attention implementation: %s. Using kOld.\n", impl.c_str()); + return AttentionImpl::kOld; +} + } // namespace gcpp diff --git a/gemma/configs.h b/gemma/configs.h index 275f374..b727480 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -80,6 +80,38 @@ static inline bool EnumValid(LayerAttentionType type) { return type == LayerAttentionType::kGemma || type == LayerAttentionType::kVit; } +enum class AttentionImpl { + kOld, + kFlash, + kSentinel, +}; + +AttentionImpl GetAttentionImpl(const std::string& impl); + +/* + * Returns a bitmask of flags to pass to attention functions based on the + * attention implementation selected. + * + * If `hwy_native_dot_bf16` is true, the function will use the old attention + * implementation, ignoring `impl`. + * + * `hwy_native_dot_bf16` needs to be passed in, because the HWY_NATIVE_DOT_BF16 + * macro is not available outside of highway instrumented translation units and + * cannot be made accessible from .h files. 
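+ *
+ * Typical use, mirroring the attention tests in this change (only possible
+ * from a Highway-instrumented translation unit, where HWY_NATIVE_DOT_BF16 is
+ * defined):
+ *   GemmaAttention(..., AttentionImplToFlags(impl, HWY_NATIVE_DOT_BF16));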
+ */ +static inline int AttentionImplToFlags(AttentionImpl impl, + int hwy_native_dot_bf16) { + if (hwy_native_dot_bf16) return kAttentionUseOld; + + switch (impl) { + case AttentionImpl::kOld: + return kAttentionUseOld; + case AttentionImpl::kFlash: + default: + return 0; + } +} + // Post attention and ffw normalization type. enum class PostNormType { None, @@ -184,13 +216,6 @@ enum class Model { // in Specifier and thus does not change. const char* ModelPrefix(Model model); -// Gemma3 is multimodal and has a different prompt wrapping than PaliGemma. -// This is used for deducing the PromptWrapping for pre-2025 BlobStore. -static inline bool IsVLM(Model model) { - return model == Model::GEMMA3_4B || model == Model::GEMMA3_1B || - model == Model::GEMMA3_12B || model == Model::GEMMA3_27B; -} - static inline bool IsPaliGemma(Model model) { if (model == Model::PALIGEMMA2_3B_224 || model == Model::PALIGEMMA2_3B_448 || model == Model::PALIGEMMA2_10B_224 || @@ -280,7 +305,7 @@ struct LayerConfig : public IFields { uint32_t kv_heads = 0; uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous). bool ff_biases = false; - bool optimized_gating = true; // for Gemma3 + bool optimized_gating = true; // for Gemma3 PostNormType post_norm = PostNormType::None; LayerAttentionType type = LayerAttentionType::kGemma; ActivationType activation = ActivationType::Gelu; @@ -383,6 +408,8 @@ struct ModelConfig : public IFields { internal.VisitFields(visitor); + visitor(use_global_timescale); + // Append new fields here, then update `python/configs.cc`. } @@ -481,6 +508,7 @@ struct ModelConfig : public IFields { std::vector scale_base_names; InternalModelConfig internal; + bool use_global_timescale = false; // for Gemma 3 }; // Returns the sub-config for the ViT model of the PaliGemma model. @@ -489,6 +517,7 @@ ModelConfig GetVitConfig(const ModelConfig& config); enum DeducedLayerTypes { kDeducedViT = 2, kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224. + kDeducedKqNorm = 8, }; // layer_types is one or more of `DeducedLayerTypes`. diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index bf3aede..671efb4 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -17,12 +17,18 @@ #include #include +#include #include +#include +#include #include +#include #include "compression/types.h" // GEMMA_DISABLED_TARGETS +#include "gemma/flash_structs.h" #include "util/threading_context.h" #include "util/zones.h" +#include "hwy/base.h" #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS @@ -30,7 +36,6 @@ #include "gemma/activations.h" #include "gemma/configs.h" // kMaxQKVDim #include "gemma/gemma.h" -#include "gemma/weights.h" #include "util/threading.h" #include "hwy/profiler.h" @@ -59,7 +64,7 @@ static constexpr size_t kNFx8HTileSize = 8; // q has shape [batch, qbatch][head, qkv_dim]. // q_t has shape [qkv_dim][qbatch, head, batch] in order to make the maximum // possible consecutive elements have the same KV. -static void TransposeQ(const MatPtrT& q, MatPtrT& q_t, +static void TransposeQ(const MatPtrT& q, MatPtrT& q_t, const size_t qbatch_size, ThreadingContext& ctx) { // Group floats by the number of floats in a cache line. 
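+  // Each task below converts kNF consecutive rows of q_t to BF16, gathering
+  // the source values from the [batch, qbatch][head, qkv_dim] layout of q.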
const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float); @@ -70,12 +75,13 @@ static void TransposeQ(const MatPtrT& q, MatPtrT& q_t, for (size_t lane = 0; lane < kNF; ++lane) { size_t q_row = task * kNF + lane; if (q_row >= q_t.Rows()) break; - float* HWY_RESTRICT qt_row = q_t.Row(q_row); + BF16* HWY_RESTRICT qt_row = q_t.Row(q_row); for (size_t qi = 0; qi < qbatch_size; ++qi) { for (size_t h = 0; h < num_heads; ++h) { for (size_t b = 0; b < batch_size; ++b) { qt_row[(qi * num_heads + h) * batch_size + b] = - q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row]; + hwy::ConvertScalarTo( + q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row]); } } } @@ -84,45 +90,48 @@ static void TransposeQ(const MatPtrT& q, MatPtrT& q_t, { const size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF); // Better than kFlat. - ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx, + ParallelFor(Parallelism::kHierarchical, num_tasks, ctx, /*cluster_idx=*/0, Callers::kFlashTransposeQ, func); } } // Updates q in place for RMSNorm and positional encoding. void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, - MatPtrT& q, const size_t layer_idx, - const LayerWeightsPtrs& layer, - const AttentionActivations& activations, + MatPtrT& q, + const MatPtr& query_norm_scale, + const size_t layer_idx, + const AttentionActivationsPtrs& activations, ThreadingContext& ctx) { + const LayerConfig& layer_config = activations.config.layer_configs[layer_idx]; const float query_scale = activations.query_scale; const hwy::Divisor div_qbatch(qbatch.Size()); const auto func = [&](const size_t task, size_t worker) HWY_ATTR { GCPP_ZONE(ctx, worker, Zones::kFlashAttentionRmsNormAndPositionalEncoding); size_t qi = div_qbatch.Remainder(task); size_t batch_idx = div_qbatch.Divide(task); - for (size_t h = 0; h < layer.layer_config.heads; ++h) { + for (size_t h = 0; h < layer_config.heads; ++h) { const size_t tq_idx = qbatch.Size() * batch_idx + qi; // Find the token position in the query and calculate // the range of cache positions to attend to. - const size_t pos = qbatch.Pos(qi) + batch_idx; - float* HWY_RESTRICT q_row = - q.Row(tq_idx) + h * layer.layer_config.qkv_dim; + constexpr size_t offset = 0; // placeholder, do not remove + const size_t pos = + qbatch.Pos(qi) + batch_idx + offset; + float* HWY_RESTRICT q_row = q.Row(tq_idx) + h * layer_config.qkv_dim; // Apply rope and scaling to Q. - if (layer.query_norm_scale.HasPtr()) { - CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { + if (query_norm_scale.HasPtr()) { + CallUpcasted(&query_norm_scale, [&](const auto* weights_t) { RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q_row, - layer.layer_config.qkv_dim, ctx, worker); + layer_config.qkv_dim, ctx, worker); }); } - PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx, worker, - pos, query_scale); + PositionalEncodingQK(q_row, layer_idx, activations, ctx, worker, pos, + query_scale); } }; { // kHierarchical is not worth the extra sync overhead because the tasks are // very lightweight. - ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx, + ParallelFor(Parallelism::kFlat, num_tokens * qbatch.Size(), ctx, /*cluster_idx=*/0, Callers::kFlashRMSNormAndPositionalEncoding, func); } @@ -152,15 +161,20 @@ void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max, // Calculates the complete attention outputs for a single row of q. 
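+// Uses the streaming ("online") softmax update (the recurrence documented in
+// flash_structs.h), so the full row of attention scores is never
+// materialized; only a running maximum and denominator are carried across
+// timesteps.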
void SingleFlashAttention(const size_t start_pos, const size_t last_pos, - const float* HWY_RESTRICT q, const MatPtrT& k, + const BF16* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, const size_t layer_idx, - const LayerWeightsPtrs& layer, - const AttentionActivations& activations, + const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) { GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention); + const hn::ScalableTag dbf; + const size_t qkv_dim = k.Cols(); + const size_t pos_mod = activations.div_seq_len.Remainder(start_pos); - float m = Dot(q, k.Row(pos_mod), k.Cols()); + // TODO: Mixed-mode can be further improved for Turin: we can demote right + // before we do the dot product instruction, rather than promote both to f32. + // But some potential accuracy loss there, needs evaluation first. + float m = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim); if (float cap = activations.config.att_cap; cap > 0.0f) { // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. m = cap * std::tanh(m / cap); @@ -170,7 +184,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos, MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker); for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { const size_t pos_mod = activations.div_seq_len.Remainder(pos); - float x = Dot(q, k.Row(pos_mod), k.Cols()); + float x = Dot(dbf, MakeConstSpan(q, qkv_dim), 0, k.Row(pos_mod), qkv_dim); SingleFlashAttentionStep(x, activations.config.att_cap, m, d, v.Row(pos_mod), v.Cols(), att_out); } @@ -180,25 +194,27 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos, // the dot products of NF rows of Q for a single K timestep. template > VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets, - const size_t k_pos, const MatPtrT& q, + const size_t k_pos, const MatPtrT& q, const MatPtrT& k) { + const hn::ScalableTag dbf; + const size_t qkv_dim = k.Cols(); + hn::TFromD results[hn::MaxLanes(df)]; for (size_t i = 0; i < hn::Lanes(df); ++i) { - results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols()); + results[i] = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[i], qkv_dim), 0, + k.Row(k_pos), qkv_dim); } return hn::LoadU(df, results); } -// Returns an NF Q rows by 8 K rows tile of Q.K dot products, in single -// precision. +// Returns an NF Q rows by 8 K rows tile of Q.K dot products. // This is the result of NF rows of Q against 8 K timesteps, with positions // given by k_pos[0..7]. Q has been transposed so that the NF rows are read in // consecutive elements, and other columns by adding q_stride. 
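+// q is stored as BF16; each loaded vector of q is promoted to f32 right
+// before the MulAdd against the broadcast K element, so accumulation stays
+// in f32.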
template > -void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride, - const MatPtrT& k, const size_t* k_pos, VF& sum0, - VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, - VF& sum7) { +void QDotKTile(DF df, const BF16* HWY_RESTRICT q, const size_t q_stride, + const MatPtrT& k, const size_t* k_pos, VF& sum0, VF& sum1, + VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7) { constexpr size_t kHTileSize = kNFx8HTileSize; sum0 = hn::Zero(df); sum1 = hn::Zero(df); @@ -209,11 +225,16 @@ void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride, sum6 = hn::Zero(df); sum7 = hn::Zero(df); const float* HWY_RESTRICT k_row[kHTileSize]; - for (int i = 0; i < kHTileSize; ++i) { + for (size_t i = 0; i < kHTileSize; ++i) { k_row[i] = k.Row(k_pos[i]); } + + const hn::Rebind dbfh; + using VBF = hn::Vec; + for (size_t i = 0; i < k.Cols(); ++i) { - VF q_vec = hn::Load(df, q); + const VBF q_vec_bf = hn::Load(dbfh, q); + const VF q_vec = hn::PromoteTo(df, q_vec_bf); VF k_0 = hn::Set(df, k_row[0][i]); sum0 = hn::MulAdd(q_vec, k_0, sum0); VF k_1 = hn::Set(df, k_row[1][i]); @@ -266,31 +287,30 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2, // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos, // max_last_pos]. void TileFlashAttention( - const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, - const StridedView& qT, const MatPtrT& k, - const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos, - const size_t min_last_pos, const size_t max_last_pos, - const MatPtrT& v, const size_t layer_idx, - const LayerWeightsPtrs& layer, const AttentionActivations& activations, - MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, - ThreadingContext& ctx, const size_t worker) { + const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, + const StridedView& qT, const MatPtrT& k, const size_t start_pos, + const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos, + const size_t max_last_pos, const MatPtrT& v, const size_t layer_idx, + const AttentionActivationsPtrs& activations, MatPtrT& att_out, + const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx, + const size_t worker) { GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention); - constexpr int kHTileSize = kNFx8HTileSize; + constexpr size_t kHTileSize = kNFx8HTileSize; using DF = hn::ScalableTag; const DF df; using VF = hn::Vec; using DI = hn::ScalableTag; const DI di; using VI = hn::Vec; - const int kVTileSize = hn::Lanes(df); - for (int i = 0; i < kVTileSize; ++i) { + const size_t kVTileSize = hn::Lanes(df); + for (size_t i = 0; i < kVTileSize; ++i) { hwy::ZeroBytes(att_out.Row(0) + out_offsets[i], v.Cols() * sizeof(att_out.Row(0)[0])); } VI lasts = hn::LoadU(di, last_pos); VF old_m = hn::Set(df, -std::numeric_limits::max() / 2.0f); VF old_d = hn::Zero(df); - const float* HWY_RESTRICT qT_row = qT.Row(0); + const BF16* HWY_RESTRICT qT_row = qT.Row(0); const size_t qT_stride = qT.Stride(); size_t position = start_pos; while (position + kHTileSize - 1 <= min_last_pos) { @@ -299,8 +319,7 @@ void TileFlashAttention( k_pos[i] = activations.div_seq_len.Remainder(position + i); } VF x0, x1, x2, x3, x4, x5, x6, x7; - QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6, - x7); + QDotKTile(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6, x7); if (activations.config.att_cap > 0.0f) { // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile. 
VF cap = hn::Set(df, activations.config.att_cap); @@ -374,7 +393,7 @@ void TileFlashAttention( // This is the result of 4 rows of Q against NF K timesteps, with positions // given by k_offsets[0..NF]. template > -void QDotKTilex4(DF df, const float* HWY_RESTRICT q, +void QDotKTilex4(DF df, const BF16* HWY_RESTRICT q, const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT& k, const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1, VF& sum2, VF& sum3) { @@ -389,13 +408,13 @@ void QDotKTilex4(DF df, const float* HWY_RESTRICT q, VI k_offsets_vec = hn::LoadU(di, k_offsets); for (size_t i = 0; i < k.Cols(); ++i) { VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec); - VF q_0 = hn::Set(df, q[q_offsets[0] + i]); + VF q_0 = hn::Set(df, hwy::ConvertScalarTo(q[q_offsets[0] + i])); sum0 = hn::MulAdd(q_0, k_vec, sum0); - VF q_1 = hn::Set(df, q[q_offsets[1] + i]); + VF q_1 = hn::Set(df, hwy::ConvertScalarTo(q[q_offsets[1] + i])); sum1 = hn::MulAdd(q_1, k_vec, sum1); - VF q_2 = hn::Set(df, q[q_offsets[2] + i]); + VF q_2 = hn::Set(df, hwy::ConvertScalarTo(q[q_offsets[2] + i])); sum2 = hn::MulAdd(q_2, k_vec, sum2); - VF q_3 = hn::Set(df, q[q_offsets[3] + i]); + VF q_3 = hn::Set(df, hwy::ConvertScalarTo(q[q_offsets[3] + i])); sum3 = hn::MulAdd(q_3, k_vec, sum3); } } @@ -410,23 +429,202 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max, float scale = old_d * std::exp(old_max - m); old_d = hn::ReduceSum(df, x) + scale; old_max = m; - float one_over_d = 1.0f / old_d; - scale *= one_over_d; - x = hn::Mul(x, hn::Set(df, one_over_d)); + if (old_d > 0.0f) { + const float one_over_d = 1.0f / old_d; + scale *= one_over_d; + x = hn::Mul(x, hn::Set(df, one_over_d)); + } else { + scale = 0.0f; + x = hn::Zero(df); + } return scale; } -// Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to -// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos, -// max_last_pos]. -void TileFlashAttention4( - const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, +// Reduces each of x and stores in following lanes of max (tested with float32) +template , + class DF4 = hn::CappedTag, class VF4 = hn::Vec, + class VF = hn::Vec, typename F> +static VF4 HWY_INLINE Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3, + F reducer) { + const DF4 df4; + constexpr size_t kMaxLanes = hn::MaxLanes(df); + HWY_LANES_CONSTEXPR size_t kLanes = hn::Lanes(df); + HWY_ALIGN T x_transposed[4 * kMaxLanes]; + hn::StoreInterleaved4(x_0, x_1, x_2, x_3, df, x_transposed); + VF4 result = hn::Load(df4, x_transposed); + for (int i = 1; i < kLanes; ++i) { + result = reducer(result, hn::Load(df4, x_transposed + i * 4)); + } + return result; +} + +// Handles Up to 4 Q rows by NF*2 timesteps of flash attention. 
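+// For each of the kNumQueries query rows, this folds the 2*NF new logits into
+// the running online-softmax state (old_max, old_d), exponentiates and
+// normalizes the logits in place, and writes to `scales` the factor by which
+// previously accumulated att_out contributions must be rescaled.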
+template > +static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap( + DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1, + VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1, + float* HWY_RESTRICT old_max, float* HWY_RESTRICT old_d, + float* HWY_RESTRICT scales) { + using DF4 = hn::CappedTag; + const DF4 df4; + using VF4 = hn::Vec; + static_assert(kNumQueries >= 1 && kNumQueries <= 4); + VF4 new_max = hn::Set(df4, -std::numeric_limits::max() / 2.0f); + VF max_0, max_1, max_2, max_3 = hn::Zero(df); + max_0 = hn::Max(x_0_p0, x_0_p1); + if constexpr (kNumQueries >= 2) { + max_1 = hn::Max(x_1_p0, x_1_p1); + } + if constexpr (kNumQueries >= 3) { + max_2 = hn::Max(x_2_p0, x_2_p1); + } + if constexpr (kNumQueries >= 4) { + max_3 = hn::Max(x_3_p0, x_3_p1); + } + if constexpr (kNumQueries == 1) { + new_max = hn::InsertLane(new_max, 0, hn::ReduceMax(df, max_0)); + } else { + new_max = Reduce4(df, max_0, max_1, max_2, max_3, + [](auto a, auto b) { return hn::Max(a, b); }); + } + if (att_cap > 0.0f) { + VF4 cap = hn::Set(df4, att_cap); + VF4 one_over_cap = hn::Set(df4, one_over_att_cap); + new_max = hn::Mul(cap, hn::Tanh(df4, hn::Mul(new_max, one_over_cap))); + } + VF4 old_max_vf = hn::Set(df4, -std::numeric_limits::max() / 2.0f); + old_max_vf = hn::LoadU(df4, old_max); + new_max = hn::Max(new_max, old_max_vf); + // TODO figure out what was wrong with broadcasts and change to that. + HWY_ALIGN float tmp_max[4]; + hn::Store(new_max, df4, tmp_max); + if constexpr (kNumQueries >= 1) { + const VF new_max_0 = hn::Set(df, tmp_max[0]); + x_0_p0 = hn::Exp(df, hn::Sub(x_0_p0 , new_max_0)); + x_0_p1 = hn::Exp(df, hn::Sub(x_0_p1, new_max_0)); + } + if constexpr (kNumQueries >= 2) { + const VF new_max_0 = hn::Set(df, tmp_max[1]); + x_1_p0 = hn::Exp(df, hn::Sub(x_1_p0, new_max_0)); + x_1_p1 = hn::Exp(df, hn::Sub(x_1_p1, new_max_0)); + } + if constexpr (kNumQueries >= 3) { + const VF new_max_0 = hn::Set(df, tmp_max[2]); + x_2_p0 = hn::Exp(df, hn::Sub(x_2_p0, new_max_0)); + x_2_p1 = hn::Exp(df, hn::Sub(x_2_p1, new_max_0)); + } + if constexpr (kNumQueries >= 4) { + const VF new_max_0 = hn::Set(df, tmp_max[3]); + x_3_p0 = hn::Exp(df, hn::Sub(x_3_p0, new_max_0)); + x_3_p1 = hn::Exp(df, hn::Sub(x_3_p1, new_max_0)); + } + VF4 old_d_vf = hn::Set(df4, 0.0f); + old_d_vf = hn::LoadU(df4, old_d); + VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max))); + + hn::StoreU(new_max, df4, old_max); + + VF4 x_sum = hn::Zero(df4); + if constexpr (kNumQueries == 1) { + x_sum = hn::Set(df4, hn::ReduceSum(df, x_0_p0) + hn::ReduceSum(df, x_0_p1)); + } else { + VF x_0_sum = hn::Add(x_0_p0, x_0_p1); + VF x_1_sum = hn::Add(x_1_p0, x_1_p1); + VF x_2_sum = hn::Add(x_2_p0, x_2_p1); + VF x_3_sum = hn::Add(x_3_p0, x_3_p1); + x_sum = Reduce4(df, x_0_sum, x_1_sum, x_2_sum, x_3_sum, + [](auto a, auto b) { return hn::Add(a, b); }); + } + old_d_vf = hn::Add(scale, x_sum); + auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df4, 0.0f)); + const VF zero = hn::Zero(df); + const VF4 zero4 = hn::Zero(df4); + const VF4 one_over_d = + hn::MaskedDivOr(zero4, non_zero_mask, hn::Set(df4, 1.0f), old_d_vf); + float tmp_one_over_d[4]; + hn::Store(one_over_d, df4, tmp_one_over_d); + hn::Store(old_d_vf, df4, old_d); + scale = hn::Mul(scale, one_over_d); + hn::Store(scale, df4, scales); + if (hn::ExtractLane(old_d_vf, 0) > 0.0f) { + const VF one_over_d_0 = hn::Set(df, tmp_one_over_d[0]); + x_0_p0 = hn::Mul(x_0_p0, one_over_d_0); + x_0_p1 = hn::Mul(x_0_p1, one_over_d_0); + } else { + x_0_p0 = zero; + x_0_p1 
= zero; + } + if constexpr (kNumQueries >= 2) { + if (hn::ExtractLane(old_d_vf, 1) > 0.0f) { + const VF one_over_d_1 = hn::Set(df, tmp_one_over_d[1]); + x_1_p0 = hn::Mul(x_1_p0, one_over_d_1); + x_1_p1 = hn::Mul(x_1_p1, one_over_d_1); + } else { + x_1_p0 = zero; + x_1_p1 = zero; + } + } + if constexpr (kNumQueries >= 3) { + if (hn::ExtractLane(old_d_vf, 2) > 0.0f) { + const VF one_over_d_2 = hn::Set(df, tmp_one_over_d[2]); + x_2_p0 = hn::Mul(x_2_p0, one_over_d_2); + x_2_p1 = hn::Mul(x_2_p1, one_over_d_2); + } else { + x_2_p0 = zero; + x_2_p1 = zero; + } + } + if constexpr (kNumQueries >= 4) { + if (hn::ExtractLane(old_d_vf, 3) > 0.0f) { + const VF one_over_d_3 = hn::Set(df, tmp_one_over_d[3]); + x_3_p0 = hn::Mul(x_3_p0, one_over_d_3); + x_3_p1 = hn::Mul(x_3_p1, one_over_d_3); + } else { + x_3_p0 = zero; + x_3_p1 = zero; + } + } +} + +// Implements flash attention for a strip of 4 query vectors. +// It iterates through timesteps in K from `start_pos` up to `max_last_pos`. +// Timesteps up to `min_last_pos` (*) are processed in tiles of shape 4 Q rows +// by NF timesteps in K for efficiency while timesteps between `min_last_pos + +// 1` and `max_last_pos` are processed one-by-one to handle differing `last_pos` +// values within the strip. +// (*) Actually, it only iterates through +// `min_last_pos - (min_last_pos + 1 - start_pos) % NF` in tiles, as the tiled +// computation can, for obvious reasons, only process an integer number of +// tiles. +// +// @param q The query matrix [batch_size * q_heads, qkv_dim] in BF16 format. +// @param q_offsets Offsets from `q.Row(0)` to the start of the 4 query +// vectors to be processed in this tile. +// @param k Key matrix [seq_len, qkv_dim] from KV cache. +// @param start_pos The first token position in the KV cache to attend to. +// @param last_pos An array of 4 indices giving the last token position +// (inclusive) that each of the 4 queries may attend to. +// @param min_last_pos The minimum value in `last_pos`. Timesteps up to this +// position can be processed efficiently in batches. +// @param max_last_pos The maximum value in `last_pos`. Timesteps between +// `min_last_pos + 1` and this position are processed individually to +// respect each query's `last_pos` limit. +// @param v Value matrix [seq_len, qkv_dim] from KV cache. +// @param layer_idx The index of the current transformer layer. +// @param activations Attention configurations and buffers. +// @param att_out Output buffer for attention results. +// @param out_offsets Offsets from `att_out.Row(0)` to store the 4 output +// vectors. +// @param ctx Threading context. +// @param worker Worker thread index. 
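+// @return Per-row online-softmax state (running max and denominator) for the
+//     four query rows after all timesteps have been processed.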
+Tile4FlashState TileFlashAttention4( + const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT& k, const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos, const size_t max_last_pos, const MatPtrT& v, const size_t layer_idx, - const LayerWeightsPtrs& layer, const AttentionActivations& activations, - MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, - ThreadingContext& ctx, const size_t worker) { + const AttentionActivationsPtrs& activations, MatPtrT& att_out, + const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx, + const size_t worker) { GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4); using DF = hn::ScalableTag; const DF df; @@ -440,14 +638,7 @@ void TileFlashAttention4( hwy::ZeroBytes(att_out.Row(0) + out_offsets[i], v.Cols() * sizeof(att_out.Row(0)[0])); } - float old_m0 = -std::numeric_limits::max() / 2.0f; - float old_m1 = -std::numeric_limits::max() / 2.0f; - float old_m2 = -std::numeric_limits::max() / 2.0f; - float old_m3 = -std::numeric_limits::max() / 2.0f; - float old_d0 = 0.0f; - float old_d1 = 0.0f; - float old_d2 = 0.0f; - float old_d3 = 0.0f; + Tile4FlashState state; size_t position = start_pos; while (position + kHTileSize - 1 <= min_last_pos) { int32_t k_offsets[kMaxNF]; @@ -467,46 +658,62 @@ void TileFlashAttention4( x2 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x2, one_over_cap))); x3 = hn::Mul(cap, hn::Tanh(df, hn::Mul(x3, one_over_cap))); } - scales[0] = SingleFlashAttentionRowVector(df, x0, old_m0, old_d0); - scales[1] = SingleFlashAttentionRowVector(df, x1, old_m1, old_d1); - scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2); - scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3); + scales[0] = SingleFlashAttentionRowVector(df, x0, state.row_states[0].max, + state.row_states[0].d); + scales[1] = SingleFlashAttentionRowVector(df, x1, state.row_states[1].max, + state.row_states[1].d); + scales[2] = SingleFlashAttentionRowVector(df, x2, state.row_states[2].max, + state.row_states[2].d); + scales[3] = SingleFlashAttentionRowVector(df, x3, state.row_states[3].max, + state.row_states[3].d); MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0), out_offsets, v.Cols()); position += kHTileSize; } + const hn::ScalableTag dbf; + const size_t qkv_dim = k.Cols(); + while (position <= max_last_pos) { size_t k_pos = activations.div_seq_len.Remainder(position); if (position <= last_pos[0]) { // Past the last position, x0 doesn't count. - float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols()); - SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0, + float x0 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[0], qkv_dim), 0, + k.Row(k_pos), qkv_dim); + SingleFlashAttentionStep(x0, activations.config.att_cap, + state.row_states[0].max, state.row_states[0].d, v.Row(k_pos), v.Cols(), att_out.Row(0) + out_offsets[0]); } if (position <= last_pos[1]) { // Past the last position, x1 doesn't count. - float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols()); - SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1, + float x1 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[1], qkv_dim), 0, + k.Row(k_pos), qkv_dim); + SingleFlashAttentionStep(x1, activations.config.att_cap, + state.row_states[1].max, state.row_states[1].d, v.Row(k_pos), v.Cols(), att_out.Row(0) + out_offsets[1]); } if (position <= last_pos[2]) { // Past the last position, x2 doesn't count. 
- float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols()); - SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2, + float x2 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[2], qkv_dim), 0, + k.Row(k_pos), qkv_dim); + SingleFlashAttentionStep(x2, activations.config.att_cap, + state.row_states[2].max, state.row_states[2].d, v.Row(k_pos), v.Cols(), att_out.Row(0) + out_offsets[2]); } if (position <= last_pos[3]) { // Past the last position, x3 doesn't count. - float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols()); - SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3, + float x3 = Dot(dbf, MakeConstSpan(q.Row(0) + q_offsets[3], qkv_dim), 0, + k.Row(k_pos), qkv_dim); + SingleFlashAttentionStep(x3, activations.config.att_cap, + state.row_states[3].max, state.row_states[3].d, v.Row(k_pos), v.Cols(), att_out.Row(0) + out_offsets[3]); } ++position; } + return state; } // Rounds n to a number that can be used as the number of Q rows in a tile @@ -589,14 +796,25 @@ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, // grouped together so that mode 1 or 2 can be used, and choosing which of the // 3 modes to use for best efficiency. void FlashAttention(const size_t num_tokens, const size_t target_parallelism, - const size_t layer_idx, const LayerWeightsPtrs& layer, - AttentionActivations& activations, QBatch& qbatch, + const size_t layer_idx, const MatPtr& query_norm_scale, + AttentionActivationsPtrs& activations, QBatch& qbatch, ThreadingContext& ctx) { GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive); - RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx, - layer, activations, ctx); + RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, + query_norm_scale, layer_idx, activations, ctx); const hwy::Divisor div_qbatch(qbatch.Size()); - const LayerConfig& layer_config = layer.layer_config; + // Compress q to q_bf. + ParallelFor( + Parallelism::kWithinCluster, activations.q.Rows(), ctx, + /*cluster_idx=*/0, Callers::kFlashAttention, + [&](size_t row, size_t worker) { + CompressPerThread tls; + const hn::ScalableTag df; + CompressTraits::Compress( + df, activations.q.Row(row), activations.q.Cols(), tls, + MakeSpan(activations.q_bf.Row(row), activations.q_bf.Cols()), 0); + }); + const LayerConfig& layer_config = activations.config.layer_configs[layer_idx]; const size_t qkv_dim = layer_config.qkv_dim; // A "head group" in the context of GQA refers to a collection of query @@ -684,14 +902,14 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, size_t last = pos; const size_t prefix_end = qbatch.PrefixEnd(qi); if (prefix_end > 0 && prefix_end - 1 > last) { - // last_pos in QDotK and WeightedSumV is inclusive. + // last_pos in `TileFlashAttention` is inclusive. 
last = prefix_end - 1; } last_pos[offset] = last; min_last_pos = HWY_MIN(min_last_pos, last); max_last_pos = HWY_MAX(max_last_pos, last); - q_offsets[offset] = - activations.q.Row(tq_idx) + head * qkv_dim - activations.q.Row(0); + q_offsets[offset] = activations.q_bf.Row(tq_idx) + head * qkv_dim - + activations.q_bf.Row(0); out_offsets[offset] = activations.att_out.Row(tq_idx) + head * qkv_dim - activations.att_out.Row(0); const size_t kv_index = head / kHeadGroups; @@ -719,9 +937,9 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, // To avoid duplicating the code to setup K and V, the call to // TileFlashAttention is inside the loop over tasks, even though it // handles all rows in the task at once. - StridedView qT = - StridedView(activations.q_T.Row(0) + first_task, kVTileSize, - activations.q_T.Stride()); + StridedView qT = + StridedView(activations.q_T.Row(0) + first_task, kVTileSize, + activations.q_T.Stride()); if (kVTileSize == kNF) { // We can still use TileFlashAttention even if we didn't transpose Q // above. The condition used for transposing Q above is more general @@ -730,14 +948,14 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, // kNFx8HTileSize. In this case, qT is never used. Some tasks might // use qT and some might not, which is why the more general condition // is used above to catch all cases where qT will be used. - TileFlashAttention(activations.q, q_offsets, qT, k, + TileFlashAttention(activations.q_bf, q_offsets, qT, k, start_positions[offset], last_pos, min_last_pos, - max_last_pos, v, layer_idx, layer, activations, + max_last_pos, v, layer_idx, activations, activations.att_out, out_offsets, ctx, worker); } else if (kVTileSize == 4) { - TileFlashAttention4(activations.q, q_offsets, k, + TileFlashAttention4(activations.q_bf, q_offsets, k, start_positions[offset], last_pos, min_last_pos, - max_last_pos, v, layer_idx, layer, activations, + max_last_pos, v, layer_idx, activations, activations.att_out, out_offsets, ctx, worker); } else { HWY_UNREACHABLE; @@ -745,8 +963,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, break; } else { SingleFlashAttention(start_positions[offset], last_pos[offset], - activations.q.Row(0) + q_offsets[offset], k, v, - layer_idx, layer, activations, + activations.q_bf.Row(0) + q_offsets[offset], k, v, + layer_idx, activations, activations.att_out.Row(0) + out_offsets[offset], ctx, worker); } diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h index 959b227..b8a70ea 100644 --- a/gemma/flash_attention.h +++ b/gemma/flash_attention.h @@ -20,36 +20,48 @@ #include +#include + +#include "gemma/flash_structs.h" #include "gemma/gemma.h" #include "hwy/highway.h" namespace gcpp { // Passed to HWY_VISIT_TARGETS; declares for one target. 
-#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \ - namespace NAMESPACE { \ - void RMSNormAndPositionalEncoding(size_t num_tokens, const QBatch& qbatch, \ - MatPtrT& q, size_t layer_idx, \ - const LayerWeightsPtrs& layer, \ - const AttentionActivations& activations, \ - ThreadingContext& ctx); \ - \ - void SingleFlashAttention(size_t start_pos, size_t last_pos, \ - const float* HWY_RESTRICT q, \ - const MatPtrT& k, const MatPtrT& v, \ - size_t layer_idx, const LayerWeightsPtrs& layer, \ - const AttentionActivations& activations, \ - float* HWY_RESTRICT att_out, \ - ThreadingContext& ctx, size_t worker); \ - \ - size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \ - size_t total_tasks, size_t target_parallelism); \ - \ - void FlashAttention(size_t num_tokens, size_t target_parallelism, \ - size_t layer_idx, const LayerWeightsPtrs& layer, \ - AttentionActivations& activations, QBatch& qbatch, \ - ThreadingContext& ctx); \ - /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ +#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \ + namespace NAMESPACE { \ + void RMSNormAndPositionalEncoding( \ + size_t num_tokens, const QBatch& qbatch, MatPtrT& q, \ + const MatPtr& query_norm_scale, size_t layer_idx, \ + const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \ + \ + void SingleFlashAttention(size_t start_pos, size_t last_pos, \ + const BF16* HWY_RESTRICT q, \ + const MatPtrT& k, const MatPtrT& v, \ + size_t layer_idx, \ + const AttentionActivationsPtrs& activations, \ + float* HWY_RESTRICT att_out, \ + ThreadingContext& ctx, size_t worker); \ + \ + Tile4FlashState TileFlashAttention4( \ + const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, \ + const MatPtrT& k, size_t start_pos, \ + const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \ + size_t max_last_pos, const MatPtrT& v, size_t layer_idx, \ + const LayerWeightsPtrs& layer, const AttentionActivations& activations, \ + MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, \ + ThreadingContext& ctx, const size_t worker); \ + \ + size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \ + size_t total_tasks, size_t target_parallelism); \ + \ + void FlashAttention(size_t num_tokens, size_t target_parallelism, \ + size_t layer_idx, const MatPtr& query_norm_scale, \ + AttentionActivationsPtrs& activations, QBatch& qbatch, \ + ThreadingContext& ctx); \ + \ + /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE // Function declarations for each SIMD target. Allows direct call from the diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index 4147e38..f0a90fa 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -14,6 +14,8 @@ // limitations under the License. 
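
The declaration macro above exists because flash attention is compiled once per SIMD target. For readers unfamiliar with that pattern, here is a hedged, self-contained toy (`toy.cc`, `mylib` and `MulBy2` are all invented) of the usual Highway `foreach_target.h` / `HWY_EXPORT` / `HWY_DYNAMIC_DISPATCH` skeleton such declarations plug into:

```
// toy.cc: recompiled once per SIMD target; HWY_ONCE code picks one at runtime.
#include <cstddef>

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "toy.cc"  // path of this file
#include "hwy/foreach_target.h"      // IWYU pragma: keep
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace mylib {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

// Doubles x[0..n); assumes n is a multiple of the vector length for brevity.
void MulBy2(float* HWY_RESTRICT x, size_t n) {
  const hn::ScalableTag<float> d;
  for (size_t i = 0; i < n; i += hn::Lanes(d)) {
    const auto v = hn::Load(d, x + i);
    hn::Store(hn::Add(v, v), d, x + i);
  }
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace mylib
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace mylib {
HWY_EXPORT(MulBy2);
void CallMulBy2(float* x, size_t n) { HWY_DYNAMIC_DISPATCH(MulBy2)(x, n); }
}  // namespace mylib
#endif  // HWY_ONCE
```
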
#include +#include +#include #include #include @@ -24,6 +26,7 @@ #include "gemma/kv_cache.h" #include "gemma/weights.h" #include "ops/matmul.h" +#include "util/test_util.h" #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS @@ -109,11 +112,14 @@ void TestFlashAttention(size_t target_parallelism) { const LayerConfig& layer_config = config.layer_configs[0]; const LayerWeightsPtrs layers(0, layer_config, tensor_info_registry); InferenceArgs inference_args; + inference_args.attention_impl = "flash"; RuntimeConfig runtime_config; + inference_args.CopyTo(runtime_config); KVCache kv_cache(config, inference_args, ctx.allocator); MatMulEnv env(ctx); - Activations activations(config, runtime_config.prefill_tbatch_size, - kv_cache.SeqLen(), env.ctx, env.row_ptrs); + Activations activations(runtime_config, config, + runtime_config.prefill_tbatch_size, kv_cache.SeqLen(), + env.ctx, env.row_ptrs); std::vector tokens(kOuter); std::iota(tokens.begin(), tokens.end(), 1); PromptTokens prompt(tokens); @@ -122,8 +128,10 @@ void TestFlashAttention(size_t target_parallelism) { QBatch qbatch(/*start=*/0, /*max_size=*/kOuter, all_queries); const size_t batch_size = kOuter; std::vector> row_ptrs; - AttentionActivations attention(config, layer_config, batch_size, kOuter, - ctx.allocator, row_ptrs); + AttentionActivations attention_storage(config, layer_config, batch_size, + kOuter, runtime_config, ctx.allocator, + row_ptrs); + AttentionActivationsPtrs attention(config, kOuter, attention_storage); const size_t qkv_dim = layer_config.qkv_dim; ASSERT_EQ(qkv_dim, kInner); const hwy::Divisor div_qbatch(qbatch.Size()); @@ -145,7 +153,8 @@ void TestFlashAttention(size_t target_parallelism) { SetMat(h + layer_config.heads * 2, v); } SetMat(1, attention.q); - DotSoftmaxWeightedSum(tokens.size(), 0, layers, attention, qbatch, ctx); + DotSoftmaxWeightedSum(tokens.size(), 0, layers.query_norm_scale, attention, + qbatch, ctx); // Copy the output to saved_att to allow for comparison. auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator); SetMat(1, attention.q); @@ -158,8 +167,8 @@ void TestFlashAttention(size_t target_parallelism) { total_tasks, target_parallelism); printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n", target_parallelism, kNF, kVTileSize); - FlashAttention(tokens.size(), target_parallelism, 0, layers, attention, - qbatch, ctx); + FlashAttention(tokens.size(), target_parallelism, 0, layers.query_norm_scale, + attention, qbatch, ctx); AssertClose(attention.att_out, *saved_att); ctx.profiler.PrintResults(); } diff --git a/gemma/flash_structs.h b/gemma/flash_structs.h new file mode 100644 index 0000000..73563fe --- /dev/null +++ b/gemma/flash_structs.h @@ -0,0 +1,31 @@ +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_ + +#include + +#include + +namespace gcpp { + +// State for computing softmax in a streaming ("online") manner, +// avoiding large intermediate values by subtracting the running maximum. +// For a sequence x_1, ..., x_n: +// m_i = max(m_{i-1}, x_i) +// d_i = d_{i-1} * exp(m_{i-1} - m_i) + exp(x_i - m_i) +// softmax_i = exp(x_i - m_i) / d_i +struct OnlineSoftmaxState { + // Maximum logit value encountered so far. + float max = -std::numeric_limits::max() / 2.0f; + // Sum of exponentials scaled by exp(-max). 
+ float d = 0.0f; +}; + +static constexpr size_t kVTileSize4 = 4; + +struct Tile4FlashState { + OnlineSoftmaxState row_states[kVTileSize4]; +}; + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_STRUCTS_H_ diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index dc7efea..0cd364a 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -20,6 +20,7 @@ #include "gemma/activations.h" #include "gemma/configs.h" +#include "gemma/tensor_stats.h" #include "gemma/weights.h" #include "ops/matmul.h" #include "util/mat.h" @@ -70,7 +71,7 @@ template void ActivationBatched( ActivationType activation, Mat& c1, ThreadingContext& ctx, size_t cluster_idx = 0, - ParallelismStrategy parallelism = ParallelismStrategy::kFlat) { + Parallelism parallelism = Parallelism::kFlat) { using T = typename Mat::T; ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, Callers::kActivationBatched, [&](uint64_t task, size_t worker) { @@ -115,7 +116,7 @@ template HWY_NOINLINE void ActivationBatched( ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx, size_t cluster_idx = 0, - ParallelismStrategy parallelism = ParallelismStrategy::kFlat) { + Parallelism parallelism = Parallelism::kFlat) { HWY_DASSERT(c1.SameShape(*c2)); if (c2 && c2->HasPtr()) { ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, @@ -158,6 +159,9 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer, HWY_DASSERT(!layer_config.ff_biases); // Only used in Vit. + activations.s_ffw_in.Notify(layer.layer_idx, activations.pre_ffw_rms_out, + env.ctx); + #if GEMMA_FUSED_FFN const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c, StridedViewBF C2, size_t worker) { @@ -179,8 +183,31 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer, env.ctx); #endif + activations.s_ffw_hidden.Notify(layer.layer_idx, activations.C1, env.ctx); + // Hidden layer -> output layer. CallMatMul(activations.C1, layer.linear_w, nullptr, env, activations.ffw_out); + + activations.s_ffw_out.Notify(layer.layer_idx, activations.ffw_out, env.ctx); +} + +// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and +// head_dim (`qkv_dim`) into output (`layer_out`). +static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer, + AttentionActivationsPtrs& activations, + MatMulEnv& env) { + GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads); + const LayerConfig& layer_config = layer.layer_config; + (void)layer_config; // For HWY_DASSERT + // att_weights and att_out are concatenated heads, each of length + // layer_config.qkv_dim. Thus the [num_interleaved, + // layer_config.model_dim] matmul output is the sum over heads. 
Compare + // gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', + // encoded) + HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 && + layer_config.qkv_dim != 0); + CallMatMul(activations.att_out, layer.att_weights, /*add=*/nullptr, env, + activations.att_sums); } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 7991c35..0ce6ab3 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -18,12 +18,16 @@ #include "gemma/gemma.h" +#include + #include "compression/types.h" // GEMMA_DISABLED_TARGETS -#include "util/zones.h" #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS +#include "gemma/tensor_stats.h" +#include "util/zones.h" + // Compiles this file for multiple architectures via "foreach_target.h", to // which we pass the filename via macro 'argument'. // clang-format off @@ -35,7 +39,7 @@ // After highway.h #include "gemma/attention.h" // includes highway.h #include "gemma/gemma-inl.h" -#include "gemma/vit.h" // includes highway.h +#include "gemma/vit.h" // includes highway.h #ifndef GEMMA_CC_ONCE #define GEMMA_CC_ONCE @@ -73,10 +77,12 @@ namespace HWY_NAMESPACE { void Attention(LayerAttentionType type, const size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, Activations& activations, QBatch& qbatch, MatMulEnv& env) { + if (type == LayerAttentionType::kGemma) { // TODO: remove flag to enable FlashAttention. - GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch, - env, HWY_NATIVE_DOT_BF16 ? kAttentionUseOld : 0); + GemmaAttention( + num_tokens, layer_idx, layer, activations.attention, qbatch, env, + AttentionImplToFlags(activations.attention_impl, HWY_NATIVE_DOT_BF16)); } } @@ -355,6 +361,10 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size, (void)runtime_config.StreamToken(qbatch.QueryIdx(qi), pos_in_prompt, token, 0.0f); qbatch.MutablePos(qi) = pos_in_prompt; + } else { + // This prevents the kv cache of eos_id to be written to last prefilled + // token. + qbatch.MutablePos(qi) = qbatch.Prompt(qi).size(); } qbatch.PrevToken(qi) = token; @@ -426,7 +436,7 @@ static void SampleAndStream(const ModelConfig& config, timing_info.NotifyGenerated(non_eos.Count()); ParallelFor( - ParallelismStrategy::kFlat, qbatch.Size(), env.ctx, + Parallelism::kFlat, qbatch.Size(), env.ctx, /*cluster_idx=*/0, Callers::kSampleAndStream, [&](size_t qi, size_t worker) { if (!non_eos.Get(qi)) return; @@ -484,12 +494,11 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config, }; } -// Decode: generates one continuation token for each query in `qbatch`. -static void GenerateT(const ModelConfig& config, - const RuntimeConfig& runtime_config, - const AesCtrEngine& engine, const WeightsPtrs& weights, - Activations& activations, QBatch& qbatch, MatMulEnv& env, - TimingInfo& timing_info) { +static size_t PrefillTBatchOrQBatch(const ModelConfig& config, + const RuntimeConfig& runtime_config, + const WeightsPtrs& weights, + Activations& activations, QBatch& qbatch, + MatMulEnv& env, TimingInfo& timing_info) { size_t max_prompt_size = 0; bool all_prefix_end_are_zero = true; size_t total_prefill_tokens = 0; // only for throughput stats. @@ -511,14 +520,14 @@ static void GenerateT(const ModelConfig& config, // We use a single divisor, so all sequence lengths must be the same. 
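
The `SumHeads` comment above states that one `[num_interleaved, model_dim]` matmul over concatenated heads equals the per-head einsum `BTNH,NHD->BTD`. A tiny self-contained check of that equivalence for a single token, with made-up sizes and values:

```
// Compares the sum over heads of per-head matmuls against one matmul over the
// concatenated head vector; both produce the same model_dim output.
#include <cstdio>

int main() {
  const int kHeads = 2, kQKV = 3, kModel = 4;
  float encoded[kHeads][kQKV] = {{1, 2, 3}, {4, 5, 6}};  // att_out, one token
  float w[kHeads][kQKV][kModel];                         // att_weights
  for (int h = 0; h < kHeads; ++h)
    for (int k = 0; k < kQKV; ++k)
      for (int d = 0; d < kModel; ++d) w[h][k][d] = 0.1f * (h + k + d);

  float ref[kModel] = {0};  // explicit sum over heads of per-head matmuls
  for (int h = 0; h < kHeads; ++h)
    for (int d = 0; d < kModel; ++d)
      for (int k = 0; k < kQKV; ++k) ref[d] += encoded[h][k] * w[h][k][d];

  float out[kModel] = {0};  // single matmul over the concatenated vector
  for (int d = 0; d < kModel; ++d)
    for (int i = 0; i < kHeads * kQKV; ++i)
      out[d] += encoded[i / kQKV][i % kQKV] * w[i / kQKV][i % kQKV][d];

  for (int d = 0; d < kModel; ++d) printf("%.3f == %.3f\n", ref[d], out[d]);
  return 0;
}
```
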
HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len); } - if (max_prompt_size >= seq_len) { - HWY_ABORT("max_prompt_size = %zu, increase --seq_len to at least that.", - max_prompt_size); + if (max_prompt_size > seq_len) { + HWY_ABORT( + "max_prompt_size = %zu, seq_len = %zu, increase --seq_len to at least " + "that.", + max_prompt_size, seq_len); } HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len); - // Lacks a constructor to bulk-set, hence initialized by Prefill* which have - // qi loops anyway. hwy::BitSet4096<> non_eos; // indexed by qi timing_info.prefill_start = hwy::platform::Now(); @@ -536,18 +545,6 @@ static void GenerateT(const ModelConfig& config, timing_info.NotifyPrefill(total_prefill_tokens); // queries_pos have been incremented by Prefill. - // Stream the last prompt token from each query, fill activations.gen_tokens. - for (size_t qi = 0; qi < qbatch.Size(); ++qi) { - const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi); - - const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct. - // In autoregressive mode, we have not prefilled the last token, so do - // not advance. - const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi)); - StreamAndUpdateEOS(qi, pos, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, - config, runtime_config, qbatch, update_pos, non_eos); - } - size_t max_gen_steps = runtime_config.max_generated_tokens; if (max_prompt_size + max_gen_steps > seq_len) { HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.", @@ -555,6 +552,55 @@ static void GenerateT(const ModelConfig& config, max_gen_steps = seq_len - max_prompt_size; } + return max_gen_steps; +} + +static void StreamAndUpdateEOSAfterPrefill(const ModelConfig& config, + const RuntimeConfig& runtime_config, + QBatch& qbatch, + hwy::BitSet4096<>& non_eos, + size_t qi) { + const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi); + + const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct. + // In autoregressive mode, we have not prefilled the last token, so do + // not advance. + const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi)); + StreamAndUpdateEOS(qi, pos, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, + config, runtime_config, qbatch, update_pos, non_eos); +} + +void SetWeightStats(const LayerWeightsPtrs& layer, Activations& a, + ThreadingContext& ctx) { + const size_t layer_idx = layer.layer_idx; + a.s_w_gating_einsum_w1.Notify(layer_idx, layer.gating_einsum_w1, ctx, + kTensorStatsIsWeight); + a.s_w_gating_einsum_w2.Notify(layer_idx, layer.gating_einsum_w2, ctx, + kTensorStatsIsWeight); + a.s_w_linear_w.Notify(layer_idx, layer.linear_w, ctx, kTensorStatsIsWeight); +} + +// Decode: generates one continuation token for each query in `qbatch`. +static void GenerateT(const ModelConfig& config, + const RuntimeConfig& runtime_config, + const AesCtrEngine& engine, const WeightsPtrs& weights, + Activations& activations, QBatch& qbatch, MatMulEnv& env, + TimingInfo& timing_info) { + for (const LayerWeightsPtrs& layer : weights.c_layers) { + SetWeightStats(layer, activations, env.ctx); + } + + const size_t max_gen_steps = PrefillTBatchOrQBatch( + config, runtime_config, weights, activations, qbatch, env, timing_info); + + hwy::BitSet4096<> non_eos; // indexed by qi + + // Stream the last prompt token from each query, fill activations.gen_tokens. 
+ for (size_t qi = 0; qi < qbatch.Size(); ++qi) { + non_eos.Set(qi); + StreamAndUpdateEOSAfterPrefill(config, runtime_config, qbatch, non_eos, qi); + } + const SampleFunc sample_token = ChooseSampleFunc(runtime_config, engine, env.ctx); @@ -567,14 +613,66 @@ static void GenerateT(const ModelConfig& config, timing_info.NotifyGenerateDone(); } +// Same as GenerateT, but uses ContinuousQBatch. +static void GenerateTWithContinuousBatching( + const ModelConfig& config, const RuntimeConfig& runtime_config, + const AesCtrEngine& engine, const WeightsPtrs& weights, + Activations& activations, AllQueries& all_queries, MatMulEnv& env, + TimingInfo& timing_info) { + const size_t qbatch_size = runtime_config.decode_qbatch_size; + + QBatch qbatch(0, qbatch_size, all_queries); + ContinuousQBatch prefill_batch(qbatch_size, all_queries); + + hwy::BitSet4096<> non_eos; + const SampleFunc sample_token = + ChooseSampleFunc(runtime_config, engine, env.ctx); + + size_t query_inserted = 0; + while (non_eos.Any() || query_inserted < all_queries.NumQueries()) { + for (size_t qi = 0; qi < qbatch.Size(); ++qi) { + // Continue if qi slot is still processing. + if (non_eos.Get(qi)) continue; + // Collect the kv_cache from the qi slot in the qbatch to the + // available_kv_caches_ in the prefill_batch. + prefill_batch.MaybeReleaseKV(qbatch.Single(qi)); + + // Prefill if no available prefilled queries to insert. + if (prefill_batch.ShouldPrefill()) { + prefill_batch.SetupNextBatchForPrefill(); + PrefillTBatchOrQBatch(config, runtime_config, weights, activations, + prefill_batch, env, timing_info); + activations.SetBatchSize(qbatch.Size()); + } + + // Get the next query to insert to the generate batch. + std::optional qi_to_insert = prefill_batch.GetNextToInsert(); + if (qi_to_insert) { + qbatch.Insert(qi_to_insert.value(), qi); + query_inserted++; + + non_eos.Set(qi); + StreamAndUpdateEOSAfterPrefill(config, runtime_config, qbatch, non_eos, + qi); + } + } + + Transformer(config, runtime_config, weights, activations, qbatch, env); + SampleAndStream(config, runtime_config, weights, sample_token, activations, + qbatch, env, non_eos, timing_info); + } + timing_info.NotifyGenerateDone(); +} + void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end, const ModelConfig& config, const RuntimeConfig& runtime_config, const AesCtrEngine& engine, const WeightsPtrs& weights, KVCache& kv_cache, MatMulEnv& env, TimingInfo& timing_info) { - Activations activations(config, runtime_config.prefill_tbatch_size, - kv_cache.SeqLen(), env.ctx, env.row_ptrs); + Activations activations(runtime_config, config, + runtime_config.prefill_tbatch_size, kv_cache.SeqLen(), + env.ctx, env.row_ptrs); AllQueries all_queries(prompt, pos, prefix_end, hwy::Span(&kv_cache, 1)); @@ -592,16 +690,21 @@ void GenerateBatchT(const ModelConfig& config, TimingInfo& timing_info) { const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size, runtime_config.prefill_tbatch_size); - Activations activations(config, max_batch_size, + Activations activations(runtime_config, config, max_batch_size, all_queries[0].kv_cache.SeqLen(), env.ctx, env.row_ptrs); - for (size_t start = 0; start < all_queries.NumQueries(); - start += runtime_config.decode_qbatch_size) { - QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries); - // Generate a batch of one token for each of `qbatch.Size()` queries. 
- GenerateT(config, runtime_config, engine, weights, activations, qbatch, env, - timing_info); + if (runtime_config.use_continuous_batching) { + GenerateTWithContinuousBatching(config, runtime_config, engine, weights, + activations, all_queries, env, timing_info); + } else { + for (size_t start = 0; start < all_queries.NumQueries(); + start += runtime_config.decode_qbatch_size) { + QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries); + // Generate a batch of one token for each of `qbatch.Size()` queries. + GenerateT(config, runtime_config, engine, weights, activations, qbatch, + env, timing_info); + } } } @@ -617,8 +720,8 @@ void GenerateImageTokensT(const ModelConfig& config, const size_t num_tokens = vit_config.max_seq_len; prefill_runtime_config.prefill_tbatch_size = num_tokens / (vit_config.pool_dim * vit_config.pool_dim); - Activations prefill_activations(vit_config, num_tokens, num_tokens, env.ctx, - env.row_ptrs); + Activations prefill_activations(runtime_config, vit_config, num_tokens, + num_tokens, env.ctx, env.row_ptrs); // Weights are for the full PaliGemma model, not just the ViT part. PrefillVit(config, weights, prefill_runtime_config, image, image_tokens, prefill_activations, env); @@ -635,17 +738,16 @@ HWY_EXPORT(GenerateSingleT); HWY_EXPORT(GenerateBatchT); HWY_EXPORT(GenerateImageTokensT); -Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference, - ThreadingContext& ctx) - : reader_(loader.weights), - model_(reader_, loader.tokenizer, loader.wrapping), +Gemma::Gemma(const GemmaArgs& args, ThreadingContext& ctx) + : reader_(args.loader.weights), + model_(reader_, args.loader.tokenizer, args.loader.wrapping), weights_(model_.Config()), chat_template_(model_.Tokenizer(), model_.Config().model), - inference_(inference), - aes_ctr_engine_(inference.deterministic) { + inference_(args.inference), + aes_ctr_engine_(args.inference.deterministic) { // Negligible CPU time in the ctor body (except ReadFromBlobs). - weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference, - mat_owners_, ctx); + weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, args.loader, + args.inference, mat_owners_, ctx); // Read everything into memory, or `weights_.mapped_` keeps the mapping alive. reader_.CloseFile(); } @@ -698,5 +800,64 @@ void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config, env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); } +ContinuousQBatch::ContinuousQBatch(size_t max_size, AllQueries& queries) + : QBatch(0, max_size, queries) { + for (size_t i = start_; i < queries_.NumQueries(); ++i) { + if (!queries_[i].kv_cache.IsEmpty()) { + // Put the kv_cache to the available_kv_caches_ instead; leaving the + // kv_cache in the queries_ is very confusing. This simplifies the logic + // of kv_cache management. 
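
`GenerateBatchT` above routes to `GenerateTWithContinuousBatching` when `RuntimeConfig::use_continuous_batching` is set. A hedged sketch of how an embedder might opt in, assuming the existing `Gemma::GenerateBatch` entry point and previously constructed `Gemma`, `AllQueries`, `MatMulEnv` and `TimingInfo` objects:

```
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"

// Sketch only: enables the continuous-batching decode path for one call.
void GenerateContinuously(const gcpp::Gemma& gemma,
                          const gcpp::InferenceArgs& inference,
                          gcpp::AllQueries& all_queries, gcpp::MatMulEnv& env,
                          gcpp::TimingInfo& timing_info) {
  gcpp::RuntimeConfig runtime_config = {.verbosity = 1};
  inference.CopyTo(runtime_config);  // sampling settings and attention_impl
  runtime_config.use_continuous_batching = true;
  gemma.GenerateBatch(runtime_config, all_queries, env, timing_info);
}
```
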
+ available_kv_caches_.push_back(queries_[i].kv_cache); + queries_[i].kv_cache = KVCachePtr(); + } + } +} + +bool ContinuousQBatch::ShouldPrefill() const { + const bool no_available_to_insert = next_to_insert_ == next_to_prefill_; + const int more_queries_to_prefill = next_to_prefill_ < queries_.NumQueries(); + return no_available_to_insert && more_queries_to_prefill; +} + +void ContinuousQBatch::SetupNextBatchForPrefill() { + start_ = next_to_prefill_; + size_ = HWY_MIN(max_size_, queries_.NumQueries() - start_); + HWY_DASSERT(size_ != 0); + HWY_DASSERT(start_ + size_ <= queries_.NumQueries()); + query_idx_.clear(); + query_idx_.reserve(size_); + for (size_t i = 0; i < size_; ++i) { + const size_t next_query_idx = start_ + i; + query_idx_.push_back(next_query_idx); + HWY_ASSERT(queries_[next_query_idx].kv_cache.IsEmpty()); + queries_[next_query_idx].kv_cache = available_kv_caches_.back(); + available_kv_caches_.pop_back(); + } + next_to_prefill_ += size_; +} + +std::optional ContinuousQBatch::GetNextToInsert() { + if (next_to_insert_ == next_to_prefill_) { + return std::nullopt; + } + next_to_insert_++; + return next_to_insert_ - 1; +} + +void ContinuousQBatch::MaybeReleaseKV(const QBatch& from) { + const int query_to_collect = from.QueryIdx(0); + // Only collect if the query to collect is not the same as the next query to + // insert. This happens at the beginning of each Generate call. + if (query_to_collect != next_to_insert_) { + // Only clear the KV cache if there are more queries to insert; Otherwise + // we get a crash because Transformer will still access that KV cache. + if (next_to_insert_ < queries_.NumQueries()) { + available_kv_caches_.push_back(from.KV(0)); + ZeroInit(from.KV(0).kv_cache); + from.KV(0) = KVCachePtr(); + } + } +} + } // namespace gcpp #endif // HWY_ONCE diff --git a/gemma/gemma.h b/gemma/gemma.h index 771cd1c..b630a8c 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -18,6 +18,7 @@ #include +#include #include // IWYU pragma: begin_exports @@ -26,6 +27,7 @@ #include "gemma/gemma_args.h" #include "gemma/kv_cache.h" #include "gemma/model_store.h" +#include "gemma/query.h" #include "gemma/weights.h" #include "io/blob_store.h" #include "io/io.h" // Path @@ -38,132 +40,28 @@ namespace gcpp { -struct PerQuery { - PromptTokens prompt; - - // Position in the KV cache: initially zero for the first turn, or when - // multi-turn is NOT desired. Incremented by prefill and `StreamAndUpdateEOS`. - size_t mutable_pos; - // Allows computing the last prefill token as `mutable_pos - initial_pos`, - // which might differ from `prompt.size() - 1` for prefix-LM. - size_t initial_pos; - // Zero for causal attention, or the end of the prefix for prefix-LM style - // attention in Paligemma. - size_t prefix_end; - - KVCache& kv_cache; - - // Previous token generated for this query, or the last prompt token. Will be - // fed into the next Transformer() call. - int prev_token = 0; -}; - -// Array of `PerQuery`. Referenced by `QBatch` and passed to `GenerateBatch`. -struct AllQueries { - AllQueries() = default; - - // For `GenerateSingleT`: same prompt/pos, replicated for each KV cache. 
- AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end, - const hwy::Span& kv_caches) { - per_query_.reserve(kv_caches.size()); - for (size_t i = 0; i < kv_caches.size(); ++i) { - HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen()); - per_query_.push_back(PerQuery{ - .prompt = prompt, - .mutable_pos = pos, - .initial_pos = pos, - .prefix_end = prefix_end, - .kv_cache = kv_caches[i], - }); - } - } - - // Batch of queries with initial position set to zero. Causal attention - // is requested via empty or all-zero `prefix_end`. - AllQueries( - const hwy::Span& prompts, - const hwy::Span& kv_caches, - const hwy::Span& prefix_end = hwy::Span()) { - HWY_ASSERT(prompts.size() == kv_caches.size()); - HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0); - per_query_.reserve(kv_caches.size()); - for (size_t i = 0; i < kv_caches.size(); ++i) { - HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen()); - per_query_.push_back(PerQuery{ - .prompt = prompts[i], - .mutable_pos = 0, - .initial_pos = 0, - .prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i], - .kv_cache = kv_caches[i], - }); - } - } - - void Reserve(size_t size) { per_query_.reserve(size); } - void Append(const PerQuery& query) { per_query_.push_back(query); } - - size_t NumQueries() const { return per_query_.size(); } - - PerQuery& operator[](size_t query_idx) { - HWY_DASSERT(query_idx < NumQueries()); - return per_query_[query_idx]; - } - const PerQuery& operator[](size_t query_idx) const { - HWY_DASSERT(query_idx < NumQueries()); - return per_query_[query_idx]; - } - - private: - std::vector per_query_; -}; - -// View into AllQueries: either a batch of queries, or a single query for use -// in PrefillTBatch or GenerateSingleT. Cheap to create because it holds a -// reference to AllQueries. -class QBatch { +// Used for continuous batching. +class ContinuousQBatch : public QBatch { public: - QBatch(size_t start, size_t max_size, AllQueries& queries) - : start_(start), - max_size_(max_size), - queries_(queries), - size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) { - HWY_ASSERT(max_size_ <= kMaxBatchSize); - HWY_DASSERT(size_ != 0); - HWY_DASSERT(start_ + size_ <= queries_.NumQueries()); - } + ContinuousQBatch(size_t max_size, AllQueries& queries); - // Returns a single-query view starting at `qi` relative to this batch. - QBatch Single(size_t qi) const { return QBatch(start_ + qi, 1, queries_); } + // Whether we should prefill the next batch, i.e. next_to_insert_ == + // next_to_prefill_. + bool ShouldPrefill() const; - // How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`. - size_t Size() const { return size_; } + // Setup the query_idx_ to point to the next group of queries to prefill. + void SetupNextBatchForPrefill(); - // Returns index for use with `AllQueries` and `BatchStreamToken`. - size_t QueryIdx(size_t qi) const { - HWY_DASSERT(qi < size_); - return start_ + qi; - } + // Get the next query to insert to the generate batch. + std::optional GetNextToInsert(); - // Accessor functions to bridge the previous SoA and current AoS layout. 
- const PromptTokens& Prompt(size_t qi) const { - return queries_[QueryIdx(qi)].prompt; - } - size_t Pos(size_t qi) const { return queries_[QueryIdx(qi)].mutable_pos; } - size_t& MutablePos(size_t qi) { return queries_[QueryIdx(qi)].mutable_pos; } - size_t InitialPos(size_t qi) const { - return queries_[QueryIdx(qi)].initial_pos; - } - size_t PrefixEnd(size_t qi) const { - return queries_[QueryIdx(qi)].prefix_end; - } - KVCache& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; } - int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; } + // Collect the kv_cache from QBatch to available_kv_caches_. + void MaybeReleaseKV(const QBatch& from); - private: - size_t start_; - size_t max_size_; - AllQueries& queries_; - size_t size_; + public: + int next_to_prefill_ = 0; + int next_to_insert_ = 0; + std::vector available_kv_caches_; }; struct TimingInfo { @@ -232,11 +130,16 @@ struct TimingInfo { // separate `ThreadingContext` and `MatMulEnv` for each concurrent `Generate`. class Gemma { public: - // Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`. + // Reads weights/config/tokenizer from `BlobStore` at `args.loader.weights`. // `ctx` is only used to read tensors and not stored. Calls to `Generate*` // may reference the same, or other `ThreadingContext` via `MatMulEnv`. + Gemma(const GemmaArgs& args, ThreadingContext& ctx); + + // Deprecated prior interface for backwards compatibility. Gemma(const LoaderArgs& loader, const InferenceArgs& inference, - ThreadingContext& ctx); + ThreadingContext& ctx) + : Gemma(GemmaArgs(loader, ThreadingArgs(), inference), ctx) {} + ~Gemma(); const ModelConfig& Config() const { return model_.Config(); } diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index 3135f50..ba72db6 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -24,10 +24,12 @@ #include #include -#include "io/io.h" // Path -#include "util/args.h" +#include "gemma/configs.h" +#include "io/io.h" // Path +#include "util/args.h" // IWYU pragma: export #include "util/basics.h" // Tristate #include "util/mat.h" +#include "util/threading_context.h" #include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" // HWY_ABORT #include "hwy/profiler.h" @@ -35,7 +37,9 @@ namespace gcpp { struct LoaderArgs : public ArgsBase { - LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + LoaderArgs(int argc, char* argv[], ConsumedArgs& consumed) { + InitAndParse(argc, argv, consumed); + } LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path) { Init(); // Init sets to defaults, so assignments must come after Init(). @@ -139,6 +143,9 @@ struct RuntimeConfig { int verbosity; // Controls verbosity of printed messages. + // Which attention implementation to use. + AttentionImpl attention_impl = AttentionImpl::kFlash; + // Functions operating on the generated tokens. StreamFunc stream_token; BatchStreamFunc batch_stream_token; @@ -159,10 +166,15 @@ struct RuntimeConfig { // default decision is likely sufficient because it is based on whether // threads are successfully pinned. mutable Tristate use_spinning = Tristate::kDefault; + + // Whether to use continuous batching. 
+ bool use_continuous_batching = false; }; struct InferenceArgs : public ArgsBase { - InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + InferenceArgs(int argc, char* argv[], ConsumedArgs& consumed) { + InitAndParse(argc, argv, consumed); + } InferenceArgs() { Init(); }; bool IsInteractive() const { return prompt.empty() && prompt_file.Empty(); } @@ -187,6 +199,7 @@ struct InferenceArgs : public ArgsBase { // For prompts longer than the Linux terminal's 4K line edit buffer. Path prompt_file; std::string eot_line; + std::string attention_impl; template void ForEach(const Visitor& visitor) { @@ -240,6 +253,8 @@ struct InferenceArgs : public ArgsBase { "before the line where only the given string appears.\n Default = " "When a newline is encountered, that signals the end of the turn.", 2); + visitor(attention_impl, "attention_impl", std::string("flash"), + "Attention implementation to use. See configs.cc for options.", 2); } void CopyTo(RuntimeConfig& runtime_config) const { @@ -261,36 +276,39 @@ struct InferenceArgs : public ArgsBase { runtime_config.temperature = temperature; runtime_config.top_k = top_k; + runtime_config.attention_impl = GetAttentionImpl(attention_impl); } }; -struct ClientArgs : public ArgsBase { - ClientArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } - ClientArgs() { Init(); }; +// Bundles all args required to construct a `GemmaEnv` or the equivalent. +struct GemmaArgs { + // For callers that do not parse command line args. + GemmaArgs(const LoaderArgs& loader, + const ThreadingArgs& threading = ThreadingArgs(), + const InferenceArgs& inference = InferenceArgs()) + : loader(loader), threading(threading), inference(inference) {} - std::string host; - int port; - std::string api_key; - std::string model; - std::string prompt; - bool interactive; + GemmaArgs(int argc, char** argv, ConsumedArgs& consumed) + : loader(argc, argv, consumed), + threading(argc, argv, consumed), + inference(argc, argv, consumed) {} - template - void ForEach(const Visitor& visitor) { - visitor(host, "host", std::string("localhost"), - "Server host (default: localhost)"); - visitor(port, "port", 8080, - "Server port (default: 8080)"); - visitor(api_key, "api_key", std::string(""), - "Use public API with key (changes host to " - "generativelanguage.googleapis.com:443)"); - visitor(model, "model", std::string("gemma3-4b"), - "Model name to use (default: gemma3-4b)"); - visitor(prompt, "prompt", std::string("Hello! How are you?"), - "Prompt for generation (default: 'Hello! 
How are you?')"); - visitor(interactive, "interactive", false, - "Start interactive chat mode (0 = no, 1 = yes)"); + void Help() { + fprintf(stderr, + "To run with pre-2025 weights, specify --tokenizer and --weights.\n" + "With the single-file weights format, specify just --weights.\n" + "\n*Model Loading Arguments*\n"); + loader.Help(); + fprintf(stderr, "\n*Threading Arguments*\n"); + threading.Help(); + fprintf(stderr, "\n*Inference Arguments*\n"); + inference.Help(); + fprintf(stderr, "\n"); } + + LoaderArgs loader; + ThreadingArgs threading; + InferenceArgs inference; }; } // namespace gcpp diff --git a/gemma/gemma_args_test.cc b/gemma/gemma_args_test.cc new file mode 100644 index 0000000..d9ee8b7 --- /dev/null +++ b/gemma/gemma_args_test.cc @@ -0,0 +1,74 @@ +#include "gemma/gemma_args.h" + +#include + +#include +#include + +#include "gtest/gtest.h" + +namespace gcpp { + +void FillPtrs(const std::vector& args, std::vector& ptrs) { + ptrs.reserve(args.size()); + for (const std::string& arg : args) { + ptrs.push_back(const_cast(arg.data())); + } +} + +static void CheckAllConsumed(const std::vector& args) { + std::vector ptrs; + FillPtrs(args, ptrs); + const int argc = static_cast(args.size()); + char** argv = const_cast(ptrs.data()); + + ConsumedArgs consumed(argc, argv); + GemmaArgs gemma_args(argc, argv, consumed); + consumed.AbortIfUnconsumed(); +} + +static void CheckUnconsumed(const std::vector& args, + size_t expected) { + std::vector ptrs; + FillPtrs(args, ptrs); + const int argc = static_cast(args.size()); + char** argv = const_cast(ptrs.data()); + + ConsumedArgs consumed(argc, argv); + GemmaArgs gemma_args(argc, argv, consumed); + ASSERT_EQ(expected, consumed.FirstUnconsumed()); +} + +// Note: do not use --help because that is not actually consumed; it is actually +// special-cased in `HasHelp`. 
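
For embedders that construct arguments programmatically rather than from `argv`, the new bundle can be built from the individual Args structs. A hedged sketch (the paths are placeholders) using the constructors shown in this diff:

```
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "util/threading_context.h"

void LoadGemma() {
  gcpp::LoaderArgs loader("/tmp/tokenizer.spm", "/tmp/gemma2-2b-it-sfp.sbs");
  gcpp::InferenceArgs inference;   // defaults, e.g. attention_impl == "flash"
  gcpp::GemmaArgs args(loader, gcpp::ThreadingArgs(), inference);

  gcpp::ThreadingContext ctx(args.threading);
  gcpp::Gemma gemma(args, ctx);    // new ctor taking the bundled args
}
```
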
+TEST(GemmaArgsTest, AllConsumedArgs) { + // Single arg + CheckAllConsumed({"gemma", "--weights=x"}); + // Two args, one with = + CheckAllConsumed({"gemma", "--weights=x", "--verbosity=1"}); + // Two args, one with extra value + CheckAllConsumed({"gemma", "--weights=x", "--verbosity", "2"}); + // Two args with values + CheckAllConsumed({"gemma", "--verbosity", "2", "--deterministic=true"}); +} + +TEST(GemmaArgsTest, UnconsumedArgs) { + // Single unconsumed arg + CheckUnconsumed({"gemma", "--UNDEFINED"}, 1); + // Single unconsumed arg, no -- + CheckUnconsumed({"gemma", "UNDEFINED"}, 1); + // Single unconsumed arg after valid arg + CheckUnconsumed({"gemma", "--weights=x", "--UNDEFINED"}, 2); + // Single unconsumed arg before valid arg + CheckUnconsumed({"gemma", "--UNDEFINED", "--weights=x"}, 1); + // Single unconsumed arg with = after valid arg + CheckUnconsumed({"gemma", "--weights=x", "--UNDEFINED=1"}, 2); + // Single unconsumed arg with = before valid arg + CheckUnconsumed({"gemma", "--UNDEFINED=false", "--weights=x"}, 1); + // Multiple unconsumed args + CheckUnconsumed({"gemma", "--UNDEFINED", "--XXX"}, 1); + // Multiple unconsumed args with valid arg between + CheckUnconsumed({"gemma", "--UNDEFINED", "--weights=x", "--XXX"}, 1); +} + +} // namespace gcpp diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index ca814f4..2fe6885 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -16,6 +16,7 @@ #include "gemma/kv_cache.h" #include +#include #include "gemma/configs.h" #include "gemma/gemma_args.h" @@ -50,8 +51,16 @@ KVCache KVCache::Copy() { KVCache copy(kv_cache.Extents(), allocator_); CopyMat(kv_cache, copy.kv_cache); - return copy; } +std::vector ToKVCachePtrs(const hwy::Span& kv_caches) { + std::vector ptrs; + ptrs.reserve(kv_caches.size()); + for (size_t i = 0; i < kv_caches.size(); ++i) { + ptrs.push_back(kv_caches[i].ToPtr()); + } + return ptrs; +} + } // namespace gcpp diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index 31e964b..bad66fa 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -18,7 +18,10 @@ #include -#include "gemma/configs.h" // ModelConfig +#include +#include + +#include "gemma/configs.h" // ModelConfig #include "gemma/gemma_args.h" // InferenceArgs #include "util/basics.h" // BF16 #include "util/mat.h" @@ -27,18 +30,33 @@ namespace gcpp { using KV_t = float; +// A non-owning view of a KVCache. +struct KVCachePtr { + bool IsEmpty() const { return kv_cache.Rows() == 0; } + size_t SeqLen() const; + + MatPtrT kv_cache; +}; + struct KVCache { KVCache(const ModelConfig& config, const InferenceArgs& inference_args, const Allocator& allocator); - // Returns a deep copy of the KVCache. Use explicit function instead of // copy ctor to make the cost explicit. KVCache Copy(); - size_t SeqLen() const { return kv_cache.Rows(); } + size_t SeqLen() const { + return kv_cache.Rows(); + } MatStorageT kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2] + KVCachePtr ToPtr() { + return KVCachePtr{ + .kv_cache = kv_cache, + }; + } + private: const Allocator& allocator_; @@ -46,6 +64,13 @@ struct KVCache { KVCache(const Extents2D& kv_extents, const Allocator& allocator); }; +inline size_t KVCachePtr::SeqLen() const { + return kv_cache.Rows(); +} + +// Convenience function to create views into KVCaches. 
+std::vector ToKVCachePtrs(const hwy::Span& kv_caches); + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_ diff --git a/gemma/kv_cache_test.cc b/gemma/kv_cache_test.cc new file mode 100644 index 0000000..157b3d9 --- /dev/null +++ b/gemma/kv_cache_test.cc @@ -0,0 +1,43 @@ +#include "gemma/kv_cache.h" + +#include +#include + +#include "gtest/gtest.h" +#include "gemma/configs.h" +#include "gemma/gemma_args.h" +#include "util/threading_context.h" +#include "hwy/aligned_allocator.h" +namespace gcpp { +namespace { + +TEST(KVCacheTest, KVCacheToPtrs) { + ModelConfig model_config; + model_config.max_seq_len = 1024; + model_config.num_layers = 2; + for (int i = 0; i < model_config.num_layers; ++i) { + model_config.layer_configs.push_back(LayerConfig()); + model_config.layer_configs.back().kv_heads = 4; + model_config.layer_configs.back().qkv_dim = 256; + } + InferenceArgs inference_args; + inference_args.seq_len = 1024; + RuntimeConfig runtime_config; + runtime_config.attention_impl = AttentionImpl::kFlash; + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); + std::vector caches; + caches.emplace_back(model_config, inference_args, runtime_config, + ctx.allocator); + inference_args.seq_len = 512; + caches.emplace_back(model_config, inference_args, runtime_config, + ctx.allocator); + + std::vector ptrs = ToKVCachePtrs({caches.data(), caches.size()}); + ASSERT_EQ(ptrs.size(), 2); + EXPECT_EQ(ptrs[0].kv_cache.Row(0), caches[0].kv_cache.Row(0)); + EXPECT_EQ(ptrs[1].kv_cache.Row(0), caches[1].kv_cache.Row(0)); +} + +} // namespace +} // namespace gcpp diff --git a/gemma/model_store.cc b/gemma/model_store.cc index 2f3e1ec..204dee9 100644 --- a/gemma/model_store.cc +++ b/gemma/model_store.cc @@ -221,6 +221,8 @@ static size_t DeduceNumLayers(const KeyVec& keys) { // This works with or without type prefixes because it searches for substrings. static int DeduceLayerTypes(const BlobReader& reader) { int layer_types = 0; + bool has_key_norm = false; + bool has_query_norm = false; for (size_t key_idx = 0; key_idx < reader.Keys().size(); ++key_idx) { const std::string& key = reader.Keys()[key_idx]; if (key.find("qkv_ein_w") != std::string::npos) { // NOLINT @@ -232,6 +234,15 @@ static int DeduceLayerTypes(const BlobReader& reader) { layer_types |= kDeduced448; } } + if (key.find("key_norm") != std::string::npos) { // NOLINT + has_key_norm = true; + } + if (key.find("query_norm") != std::string::npos) { // NOLINT + has_query_norm = true; + } + } + if (has_key_norm && has_query_norm) { + layer_types |= kDeducedKqNorm; } return layer_types; } diff --git a/gemma/query.h b/gemma/query.h new file mode 100644 index 0000000..36e8ee5 --- /dev/null +++ b/gemma/query.h @@ -0,0 +1,186 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_QUERY_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_QUERY_H_ + +#include + +#include "gemma/gemma_args.h" +#include "gemma/kv_cache.h" +#include "util/basics.h" +#include "hwy/aligned_allocator.h" // Span +#include "hwy/base.h" + +namespace gcpp { + +struct PerQuery { + PromptTokens prompt; + + // Position in the KV cache: initially zero for the first turn, or when + // multi-turn is NOT desired. Incremented by prefill and `StreamAndUpdateEOS`. + size_t mutable_pos; + // Allows computing the last prefill token as `mutable_pos - initial_pos`, + // which might differ from `prompt.size() - 1` for prefix-LM. + size_t initial_pos; + // Zero for causal attention, or the end of the prefix for prefix-LM style + // attention in Paligemma. + size_t prefix_end; + + KVCachePtr kv_cache; + + // Previous token generated for this query, or the last prompt token. Will be + // fed into the next Transformer() call. + int prev_token = 0; +}; + +// Array of `PerQuery`. Referenced by `QBatch` and passed to `GenerateBatch`. +struct AllQueries { + AllQueries() = default; + + // For `GenerateSingleT`: same prompt/pos, replicated for each KV cache. + AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end, + const hwy::Span& kv_caches) { + per_query_.reserve(kv_caches.size()); + for (size_t i = 0; i < kv_caches.size(); ++i) { + HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen()); + per_query_.push_back(PerQuery{ + .prompt = prompt, + .mutable_pos = pos, + .initial_pos = pos, + .prefix_end = prefix_end, + .kv_cache = kv_caches[i], + }); + } + } + + AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end, + const hwy::Span& kv_caches) + : AllQueries(prompt, pos, prefix_end, + hwy::Span(ToKVCachePtrs(kv_caches))) {} + + // Batch of queries with initial position set to zero. Causal attention + // is requested via empty or all-zero `prefix_end`. + AllQueries( + const hwy::Span& prompts, + const hwy::Span& kv_caches, + const hwy::Span& prefix_end = hwy::Span()) { + HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0); + per_query_.reserve(prompts.size()); + for (size_t i = 0; i < prompts.size(); ++i) { + HWY_ASSERT(kv_caches.size() == 0 || + kv_caches[i].SeqLen() == kv_caches[0].SeqLen()); + per_query_.push_back(PerQuery{ + .prompt = prompts[i], + .mutable_pos = 0, + .initial_pos = 0, + .prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i], + .kv_cache = kv_caches.size() == 0 ? KVCachePtr() : kv_caches[i], + }); + } + } + + AllQueries( + const hwy::Span& prompts, + const hwy::Span& kv_caches, + const hwy::Span& prefix_end = hwy::Span()) + : AllQueries(prompts, hwy::Span(ToKVCachePtrs(kv_caches)), + prefix_end) {} + + void Reserve(size_t size) { per_query_.reserve(size); } + void Append(const PerQuery& query) { per_query_.push_back(query); } + + size_t NumQueries() const { return per_query_.size(); } + + PerQuery& operator[](size_t query_idx) { + HWY_DASSERT(query_idx < NumQueries()); + return per_query_[query_idx]; + } + const PerQuery& operator[](size_t query_idx) const { + HWY_DASSERT(query_idx < NumQueries()); + return per_query_[query_idx]; + } + + private: + std::vector per_query_; +}; + +// View into AllQueries: either a batch of queries, or a single query for use +// in PrefillTBatch or GenerateSingleT. Cheap to create because it holds a +// reference to AllQueries. 
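
A hedged usage sketch of the relocated query types: one prompt replicated over a single KV cache, then viewed through a `QBatch`. The token values are placeholders, and `PromptTokens` is assumed to be span-like (constructible from pointer and length):

```
#include <cstddef>
#include <vector>

#include "gemma/kv_cache.h"
#include "gemma/query.h"
#include "hwy/aligned_allocator.h"  // hwy::Span

void BuildSingleQueryBatch(gcpp::KVCache& kv_cache) {
  std::vector<int> tokens = {2, 651, 6575};  // placeholder token ids
  const gcpp::PromptTokens prompt(tokens.data(), tokens.size());

  // Replicates the prompt across the given KV caches (here just one).
  gcpp::AllQueries all_queries(prompt, /*pos=*/0, /*prefix_end=*/0,
                               hwy::Span<gcpp::KVCache>(&kv_cache, 1));
  gcpp::QBatch qbatch(/*start=*/0, /*max_size=*/1, all_queries);
  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
    // QueryIdx maps batch slot qi back to the index in AllQueries.
    const size_t query_idx = qbatch.QueryIdx(qi);
    (void)query_idx;
    (void)qbatch.Prompt(qi);  // fed to prefill/generation in real code
  }
}
```
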
+class QBatch { + public: + QBatch(size_t start, size_t max_size, AllQueries& queries) + : start_(start), + max_size_(max_size), + queries_(queries), + size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) { + HWY_ASSERT(max_size_ <= kMaxBatchSize); + HWY_DASSERT(size_ != 0); + HWY_DASSERT(start_ + size_ <= queries_.NumQueries()); + query_idx_.reserve(size_); + for (size_t i = 0; i < size_; ++i) { + query_idx_.push_back(start_ + i); + } + } + + // Returns a single-query view starting at `qi` relative to this batch. + QBatch Single(size_t qi) const { return QBatch(QueryIdx(qi), 1, queries_); } + + // How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`. + size_t Size() const { return size_; } + + // Returns index for use with `AllQueries` and `BatchStreamToken`. + size_t QueryIdx(size_t qi) const { + HWY_DASSERT(qi < size_); + return query_idx_[qi]; + } + + // Accessor functions to bridge the previous SoA and current AoS layout. + const PromptTokens& Prompt(size_t qi) const { + return queries_[QueryIdx(qi)].prompt; + } + size_t Pos(size_t qi) const { return queries_[QueryIdx(qi)].mutable_pos; } + size_t& MutablePos(size_t qi) { return queries_[QueryIdx(qi)].mutable_pos; } + size_t InitialPos(size_t qi) const { + return queries_[QueryIdx(qi)].initial_pos; + } + size_t PrefixEnd(size_t qi) const { + return queries_[QueryIdx(qi)].prefix_end; + } + KVCachePtr& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; } + int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; } + + // let query_idx_[to] point to the from in the queries_; this is only used if + // the slot in the QBatch is less than the number of queries. + void Insert(size_t from, size_t to) { + if (from == to) return; + HWY_ASSERT(!queries_[from].kv_cache.IsEmpty()); + HWY_ASSERT(queries_[to].kv_cache.IsEmpty()); + // Conceptually, insert from.query to location to. + query_idx_[to] = from; + } + + protected: + size_t start_; + size_t max_size_; + AllQueries& queries_; + std::vector query_idx_; + size_t size_; +}; + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_QUERY_H_ diff --git a/gemma/run.cc b/gemma/run.cc index 7e2059f..6c6f4d0 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -89,9 +89,11 @@ std::string GetPrompt(const InferenceArgs& inference) { } // The main Read-Eval-Print Loop. 
-void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, - const Gemma& gemma, KVCache& kv_cache, MatMulEnv& env) { +void ReplGemma(const GemmaArgs& args, const Gemma& gemma, KVCache& kv_cache, + MatMulEnv& env) { PROFILER_ZONE("Gen.misc"); + const InferenceArgs& inference = args.inference; + const int verbosity = inference.verbosity; size_t abs_pos = 0; // across turns size_t tokens_generated_this_turn = 0; // differentiates prefill from reply size_t prompt_size = 0; @@ -113,12 +115,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, HWY_ASSERT(image.ReadPPM(inference.image_file.path)); const size_t image_size = config.vit_config.image_size; image.Resize(image_size, image_size); - RuntimeConfig runtime_config = {.verbosity = inference.verbosity, - .use_spinning = threading.spin}; + RuntimeConfig runtime_config = {.verbosity = verbosity, + .use_spinning = args.threading.spin}; double image_tokens_start = hwy::platform::Now(); gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image, image_tokens, env); - if (inference.verbosity >= 1) { + if (verbosity >= 1) { double image_tokens_duration = hwy::platform::Now() - image_tokens_start; fprintf(stderr, "\n\n[ Timing info ] Image token generation took: %d ms\n", @@ -129,11 +131,15 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, // callback function invoked for each generated token. auto batch_stream_token = [&](size_t query_idx, size_t pos, int token, float) { - std::string token_text; - HWY_ASSERT(gemma.Tokenizer().Decode(std::vector{token}, &token_text)); - HWY_ASSERT(pos == abs_pos); ++abs_pos; + + std::string token_text; + if (!gemma.Tokenizer().Decode(std::vector{token}, &token_text)) { + if (token == -2) return true; // Gemma 3 ViT? 
+ HWY_WARN("Failed to decode token %d.", token); + } + const bool in_prompt = tokens_generated_this_turn < prompt_size; const bool first_response_token = tokens_generated_this_turn == prompt_size; ++tokens_generated_this_turn; @@ -185,7 +191,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, TimingInfo timing_info = {.verbosity = inference.verbosity}; RuntimeConfig runtime_config = {.verbosity = inference.verbosity, .batch_stream_token = batch_stream_token, - .use_spinning = threading.spin}; + .use_spinning = args.threading.spin}; inference.CopyTo(runtime_config); std::vector prompt; size_t prefix_end = 0; @@ -248,14 +254,14 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, } } -void Run(const LoaderArgs& loader, const ThreadingArgs& threading, - const InferenceArgs& inference) { +void Run(const GemmaArgs& args) { PROFILER_ZONE("Run.misc"); - ThreadingContext ctx(threading); + ThreadingContext ctx(args.threading); MatMulEnv env(ctx); + const InferenceArgs& inference = args.inference; if (inference.verbosity >= 3) env.print_best = true; - const Gemma gemma(loader, inference, ctx); + const Gemma gemma(args, ctx); KVCache kv_cache(gemma.Config(), inference, ctx.allocator); if (inference.verbosity >= 1) { @@ -283,13 +289,12 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading, if (inference.IsInteractive()) { std::cout << "\033[2J\033[1;1H" // clear screen << kAsciiArtBanner << "\n\n"; - ShowConfig(loader, threading, inference, gemma.Config(), - gemma.WeightReadMode(), ctx); + ShowConfig(args, gemma.Config(), gemma.WeightReadMode(), ctx); std::cout << "\n" << instructions << "\n"; } } - ReplGemma(threading, inference, gemma, kv_cache, env); + ReplGemma(args, gemma, kv_cache, env); } } // namespace gcpp @@ -298,17 +303,24 @@ int main(int argc, char** argv) { gcpp::InternalInit(); { // Negligible CPU time. - gcpp::LoaderArgs loader(argc, argv); - gcpp::ThreadingArgs threading(argc, argv); - gcpp::InferenceArgs inference(argc, argv); + gcpp::ConsumedArgs consumed(argc, argv); + gcpp::GemmaArgs args(argc, argv, consumed); if (gcpp::HasHelp(argc, argv)) { std::cerr << gcpp::kAsciiArtBanner; - gcpp::ShowHelp(loader, threading, inference); + fprintf(stderr, + "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n" + "==========================================================\n\n" + "*Example Usage*\n\n./gemma --tokenizer tokenizer.spm " + "--weights gemma2-2b-it-sfp.sbs\n\n"); + args.Help(); return 0; } - gcpp::Run(loader, threading, inference); + // After `HasHelp` so that we print --help even if unconsumed args remain. + consumed.AbortIfUnconsumed(); + + gcpp::Run(args); } PROFILER_PRINT_RESULTS(); // Must call outside the zone above. return 0; diff --git a/gemma/tensor_info.cc b/gemma/tensor_info.cc index 05f829b..2810307 100644 --- a/gemma/tensor_info.cc +++ b/gemma/tensor_info.cc @@ -1,3 +1,18 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + #include "gemma/tensor_info.h" #include diff --git a/gemma/tensor_info.h b/gemma/tensor_info.h index 6becb29..60decf8 100644 --- a/gemma/tensor_info.h +++ b/gemma/tensor_info.h @@ -1,3 +1,18 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_ diff --git a/gemma/tensor_stats.cc b/gemma/tensor_stats.cc new file mode 100644 index 0000000..53203b6 --- /dev/null +++ b/gemma/tensor_stats.cc @@ -0,0 +1,205 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gemma/tensor_stats.h" + +#if GCPP_TENSOR_STATS +#include +#include +#include + +#include +#include +#include + +#include "io/io.h" +#include "util/mat.h" +#include "util/threading_context.h" +#include "util/zones.h" +#include "hwy/profiler.h" // StringTable + +// Compiles this file for multiple architectures via "foreach_target.h", to +// which we pass the filename via macro 'argument'. +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma/tensor_stats.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "compression/compress-inl.h" +#include "ops/dot-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +float Correlation(const float* x, size_t num) { + double sum = 0.0; + for (size_t i = 0; i < num; ++i) { + sum += x[i]; + } + const double mean = sum / static_cast(num); + + double numerator = 0.0; + double sum_sq_current = 0.0; + double sum_sq_next = 0.0; + + for (size_t i = 0; i < num - 1; ++i) { + const double diff_current = static_cast(x[i]) - mean; + const double diff_next = static_cast(x[i + 1]) - mean; + + numerator += diff_current * diff_next; + sum_sq_current += diff_current * diff_current; + sum_sq_next += diff_next * diff_next; + } + + if (sum_sq_current == 0.0 || sum_sq_next == 0.0) return 0.0f; + const double denominator = std::sqrt(sum_sq_current * sum_sq_next); + const float corr = static_cast(numerator / denominator); + HWY_DASSERT(-1.0f <= corr && corr <= 1.0f); + return corr; +} + +// Only write tensor data the first time it is encountered per layer. 
This is +// a concurrent string+layer -> flag map which avoids std::mutex (incompatible +// with fibers). We use a string table to index into per-layer atomic flags. +static bool ShouldWrite(const char* name, size_t layer_idx) { + constexpr size_t kMaxNames = 128; + constexpr size_t kMaxLayers = 128; + HWY_DASSERT(layer_idx < kMaxLayers); + static hwy::StringTable s_table; + const size_t name_idx = s_table.Add(name); + static std::atomic_flag flags[kMaxNames * kMaxLayers] = {}; + return !flags[name_idx * kMaxLayers + layer_idx].test_and_set( + std::memory_order_acq_rel); +} + +std::unique_ptr MaybeOpenFile(size_t layer_idx, const MatPtr& type_erased, + const Path& tensor_output) { + if (tensor_output.Empty()) return nullptr; + if (!ShouldWrite(type_erased.Name(), layer_idx)) return nullptr; + char path[1024]; + snprintf(path, sizeof(path), "%s/%s_L%02zu_%zux%zu_%s.bin", + tensor_output.path.c_str(), type_erased.Name(), layer_idx, + type_erased.Rows(), type_erased.Cols(), + TypeName(type_erased.GetType())); + return OpenFileOrAbort(Path(path), "wb"); +} + +void MaybeWriteRow(const std::unique_ptr& file, const MatPtr& type_erased, + size_t row_idx) { + if (!file) return; + const size_t bytes_per_row = type_erased.Cols() * type_erased.ElementBytes(); + file->Write(type_erased.RowBytes(row_idx), bytes_per_row, + bytes_per_row * row_idx); +} + +// First dispatch to the type, then parallel over rows, then vectorized +// decompress and Notify for each value. +void UpdateStatsT(TensorStats& stats, size_t layer_idx, + const MatPtr& type_erased, ThreadingContext& ctx, int flags, + size_t cluster_idx, Parallelism parallelism) { + std::unique_ptr file = + MaybeOpenFile(layer_idx, type_erased, ctx.tensor_output); + + if ((flags & kTensorStatsIsWeight) && layer_idx != 0) { + // Still compute stats, but remember not to print them. + stats.Get(layer_idx, 0).DoNotPrint(); + } + + CallUpcasted(&type_erased, [&](const auto* mat) { + const size_t cols = mat->Cols(); + + ParallelFor( + parallelism, mat->Rows(), ctx, cluster_idx, Callers::kTensorStats, + [&](size_t row_idx, size_t global_idx) { + GCPP_ZONE(ctx, global_idx, Zones::kGenStats); + + auto* HWY_RESTRICT row = mat->Row(row_idx); + MaybeWriteRow(file, type_erased, row_idx); + + using Packed = hwy::RemoveCvRef; + PackedSpan packed(const_cast(row), cols); + + TensorStatsAccumulator& my_stats = stats.Get(layer_idx, global_idx); + my_stats.NotifyCond(ConditionNumber(row, cols)); + + namespace hn = hwy::HWY_NAMESPACE; + hn::ScalableTag df; + using VF = hn::Vec; + HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); + HWY_ALIGN float buf[2 * hn::MaxLanes(df)]; + + size_t packed_ofs = 0; + if (cols >= 2 * NF) { + for (; packed_ofs <= cols - 2 * NF; packed_ofs += 2 * NF) { + VF v0, v1; + Decompress2(df, packed, packed_ofs, v0, v1); + hn::Store(v0, df, buf); + hn::Store(v1, df, buf + NF); + const VF min_mag = hn::Min(hn::Abs(v0), hn::Abs(v1)); + const VF max_mag = hn::Max(hn::Abs(v0), hn::Abs(v1)); + const float min = hn::ReduceMin(df, min_mag); + if (min != 0.0f) { // Avoid division by zero. + my_stats.NotifyGroup(min, hn::ReduceMax(df, max_mag)); + } + + for (size_t i = 0; i < 2 * NF; ++i) { + my_stats.Notify(buf[i], row_idx, packed_ofs + i); + } + my_stats.NotifyCorr(Correlation(buf, 2 * NF)); + } + } + + // Zero to two vectors remaining. + for (; packed_ofs < cols; packed_ofs += NF) { + const size_t remaining = HWY_MIN(NF, cols - packed_ofs); + DecompressAndZeroPad(df, packed, packed_ofs, buf, remaining); + // Skip NotifyGroup for this partial group. 
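+              // `DecompressAndZeroPad` zero-fills past `remaining`, so only
+              // the first `remaining` values are notified; the correlation
+              // below uses the same prefix of `buf`.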
+ for (size_t i = 0; i < remaining; ++i) { + my_stats.Notify(buf[i], row_idx, packed_ofs + i); + } + my_stats.NotifyCorr(Correlation(buf, remaining)); + } + }); + }); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace gcpp { +HWY_EXPORT(UpdateStatsT); + +// Must reside in .cc file so that we can #include compress-inl.h. +void TensorStats::Notify(size_t layer_idx, const MatPtr& type_erased, + ThreadingContext& ctx, int flags, size_t cluster_idx, + Parallelism parallelism) { + // Ignore empty tensors. + if (type_erased.GetType() == Type::kUnknown || type_erased.Cols() == 0) { + return; + } + HWY_DYNAMIC_DISPATCH(UpdateStatsT)(*this, layer_idx, type_erased, ctx, flags, + cluster_idx, parallelism); +} + +} // namespace gcpp +#endif // HWY_ONCE + +#endif // GCPP_TENSOR_STATS diff --git a/gemma/tensor_stats.h b/gemma/tensor_stats.h new file mode 100644 index 0000000..6975ab5 --- /dev/null +++ b/gemma/tensor_stats.h @@ -0,0 +1,347 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_STATS_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_STATS_H_ + +#include +#include +#include + +#include "util/basics.h" +#include "hwy/base.h" + +#ifndef GCPP_TENSOR_STATS +#define GCPP_TENSOR_STATS 0 +#endif + +#include "util/mat.h" +#include "util/threading_context.h" + +#if GCPP_TENSOR_STATS +#include +#include + +#include "hwy/stats.h" +#endif // GCPP_TENSOR_STATS + +namespace gcpp { + +// For flags. Used to inhibit printing per-layer stats for weights. +HWY_INLINE_VAR constexpr int kTensorStatsIsWeight = 1; + +#if GCPP_TENSOR_STATS + +HWY_INLINE_VAR constexpr size_t kStatsMaxCols = 8192; + +// Separate summary of the per-layer stats, updated by `TensorStatsAccumulator`. +// We pass per-layer statistics such as the mean value to `hwy::Stats::Notify`` +// to see the distribution of per-layer means. 
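+// For example, `s_val_avg` receives one value per layer (that layer's mean
+// value), so its printed summary shows how the mean varies across layers.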
+struct TensorStatsAcrossLayers { + bool IsEmpty() const { return s_frobenius.Count() == 0; } + + void Print() { + const int skip = hwy::Stats::kNoGeomean; + fprintf(stderr, "frob %s\n", s_frobenius.ToString(skip).c_str()); + fprintf(stderr, "cnd.min %s\n", s_cond_min.ToString(skip).c_str()); + fprintf(stderr, "cnd.avg %s\n", s_cond_avg.ToString(skip).c_str()); + fprintf(stderr, "cnd.max %s\n", s_cond_max.ToString(skip).c_str()); + fprintf(stderr, "val.min %s\n", s_val_min.ToString(skip).c_str()); + fprintf(stderr, "val.avg %s\n", s_val_avg.ToString(skip).c_str()); + fprintf(stderr, "val.krt %s\n", s_val_kurt.ToString(skip).c_str()); + fprintf(stderr, "mag.min %s\n", s_mag_min.ToString(skip).c_str()); + fprintf(stderr, "mag.avg %s\n", s_mag_avg.ToString(skip).c_str()); + fprintf(stderr, "mag.max %s\n", s_mag_max.ToString(skip).c_str()); + if (hwy::ScalarAbs(s_corr_avg.Max()) > 0.05f) { + fprintf(stderr, "cor.avg %s\n", s_corr_avg.ToString(skip).c_str()); + } + fprintf(stderr, "cor.max %s\n", s_corr_max.ToString(skip).c_str()); + fprintf(stderr, "rng_avg %s\n", s_range_avg.ToString(skip).c_str()); + fprintf(stderr, "exp.min %s\n", s_exp_min.ToString(skip).c_str()); + fprintf(stderr, "exp.max %s\n", s_exp_max.ToString(skip).c_str()); + fprintf(stderr, "exp.mod %s\n", s_exp_mode.ToString(skip).c_str()); + if (s_exp_subnormal.Min() != 0.0f) { + fprintf(stderr, "exp.sub %s\n", s_exp_subnormal.ToString(skip).c_str()); + } + if (s_big_cols.Count() != 0) { + fprintf(stderr, "bigCols %s\n", s_big_cols.ToString(skip).c_str()); + const size_t modal_col = b_big_cols.ModalBinIdx(); + const size_t num_outlier_cols = b_big_cols.NumNonzero(); + if (num_outlier_cols > 256) { + fprintf(stderr, "bigCols: all up to %zu (max at %zu: %u layers):\n", + b_big_cols.LastNonzero(), modal_col, b_big_cols.Bin(modal_col)); + } else { + fprintf(stderr, "bigCols (max at %zu: %u layers):\n", modal_col, + b_big_cols.Bin(modal_col)); + for (size_t i = 0; i < kStatsMaxCols; ++i) { + if (b_big_cols.Bin(i) > 2) { + fprintf(stderr, " %3zu: %u\n", i, b_big_cols.Bin(i)); + } + } + } + } + fprintf(stderr, "\n"); + } + + hwy::Stats s_frobenius; + + hwy::Stats s_cond_min; + hwy::Stats s_cond_avg; + hwy::Stats s_cond_max; + + hwy::Stats s_val_min; + hwy::Stats s_val_avg; + hwy::Stats s_val_kurt; + + hwy::Stats s_mag_min; + hwy::Stats s_mag_avg; + hwy::Stats s_mag_max; + + hwy::Stats s_corr_avg; + hwy::Stats s_corr_max; + + hwy::Stats s_range_avg; + + hwy::Stats s_exp_min; + hwy::Stats s_exp_max; + hwy::Stats s_exp_mode; + hwy::Stats s_exp_subnormal; + + hwy::Stats s_big_cols; // total number of outlier cols + hwy::Bins b_big_cols; // # layers with outlier per col +}; + +// Per-thread and layer. +class TensorStatsAccumulator { + public: + void Notify(float val, size_t row_idx, size_t col_idx) { + const double dval = static_cast(val); + sum_sq_ += dval * dval; + + s_val_.Notify(val); + const float mag = hwy::ScalarAbs(val); + + if (HWY_UNLIKELY(mag >= 64.0f)) { + if (row_idx < kMaxBatchSize) b_big_row_.Notify(row_idx); + if (col_idx < kStatsMaxCols) b_big_col_.Notify(col_idx); + } + + // Skip zero so we can see the lowest actual magnitude + if (mag != 0.0f && mag != -0.0f) s_mag_.Notify(mag); + + const uint32_t binary32 = hwy::BitCastScalar(mag); + // Use biased exponent because Bins wants unsigned values. 
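+    // binary32 layout: sign (already cleared by `ScalarAbs`), 8 exponent
+    // bits, 23 mantissa bits; shifting right by 23 yields the biased exponent
+    // in [0, 255]. Bin 0 thus collects zeros and subnormals.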
+ const uint32_t biased_exp = binary32 >> 23; + HWY_DASSERT(biased_exp < 256); // already cleared sign bit + b_exp256_.Notify(biased_exp); + } + + void DoNotPrint() { skip_.fetch_or(1); } + bool ShouldPrint() const { return skip_.load() == 0; } + + // Vector code computed the min/max of a group (= two vectors); this is + // faster than doing it in `Notify`. + void NotifyGroup(float min, float max) { + s_group_min_.Notify(min); + s_group_max_.Notify(max); + // Caller ensures min != 0. + s_group_range_.Notify(max / min); + } + + void NotifyCorr(float corr) { s_corr_.Notify(corr); } + void NotifyCond(double cond) { s_cond_.Notify(cond); } + + void Assimilate(const TensorStatsAccumulator& other) { + skip_.fetch_or(other.skip_.load()); + + sum_sq_ += other.sum_sq_; + b_exp256_.Assimilate(other.b_exp256_); + b_big_row_.Assimilate(other.b_big_row_); + b_big_col_.Assimilate(other.b_big_col_); + s_val_.Assimilate(other.s_val_); + s_mag_.Assimilate(other.s_mag_); + s_corr_.Assimilate(other.s_corr_); + s_group_min_.Assimilate(other.s_group_min_); + s_group_max_.Assimilate(other.s_group_max_); + s_group_range_.Assimilate(other.s_group_range_); + } + + // Called on the per-layer representative after reducing across threads. + void NotifyAcrossLayer(TensorStatsAcrossLayers& s) { + s.s_frobenius.Notify(std::sqrt(sum_sq_)); + + s.s_cond_min.Notify(s_cond_.Min()); + s.s_cond_avg.Notify(s_cond_.Mean()); + s.s_cond_max.Notify(s_cond_.Max()); + + s.s_val_min.Notify(s_val_.Min()); + s.s_val_avg.Notify(s_val_.Mean()); + s.s_val_kurt.Notify(s_val_.Kurtosis()); + + s.s_mag_min.Notify(s_mag_.Min()); + s.s_mag_avg.Notify(s_mag_.Mean()); + s.s_mag_max.Notify(s_mag_.Max()); + + s.s_corr_avg.Notify(s_corr_.Mean()); + s.s_corr_max.Notify(s_corr_.Max()); + + s.s_range_avg.Notify(s_group_range_.Mean()); + + const uint32_t subnormals = b_exp256_.Bin(0); + // Prevent subnormals from hiding the min exponent. + b_exp256_.ResetBin(0); + s.s_exp_min.Notify(b_exp256_.FirstNonzero()); + s.s_exp_max.Notify(b_exp256_.LastNonzero()); + s.s_exp_mode.Notify(b_exp256_.ModalBinIdx()); + s.s_exp_subnormal.Notify(subnormals); + + const uint32_t num_outliers = b_big_col_.NumNonzero(); + if (num_outliers != 0) { + s.s_big_cols.Notify(num_outliers); + // For each col, count the number of layers that have an outlier there. + for (size_t i = 0; i < kStatsMaxCols; ++i) { + if (b_big_col_.Bin(i) != 0) s.b_big_cols.Notify(i); + } + } + } + + bool IsEmpty() const { return s_val_.Count() == 0; } + + void PrintAll() { + fprintf(stderr, "Frob %.2E\n", std::sqrt(sum_sq_)); + const int skip = hwy::Stats::kNoGeomean; + fprintf(stderr, "cnd %s\n", s_cond_.ToString(skip).c_str()); + fprintf(stderr, "val %s\n", s_val_.ToString(skip).c_str()); + fprintf(stderr, "mag %s\n", s_mag_.ToString(skip).c_str()); + fprintf(stderr, "corr %s\n", s_corr_.ToString(skip).c_str()); + fprintf(stderr, "group_min %s\n", s_group_min_.ToString(skip).c_str()); + fprintf(stderr, "group_max %s\n", s_group_max_.ToString(skip).c_str()); + fprintf(stderr, "group_range %s\n", s_group_range_.ToString(skip).c_str()); + b_exp256_.Print("exp"); + PrintBinRanges(b_big_row_, "big row"); + PrintBinRanges(b_big_col_, "big col"); + fprintf(stderr, "\n"); + } + + private: + template + void PrintBinRanges(const hwy::Bins& b, const char* name) { + uint64_t total = 0; + for (size_t i = 0; i < N; ++i) { + total += b.Bin(i); + } + if (total == 0) return; + + // If all bins are at least 10% of a uniform distribution, print the range + // to vastly reduce the log size. 
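+    // A uniform distribution would put `total / N` in every bin; `min` is a
+    // tenth of that, clamped to at least 1.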
+ const size_t min = HWY_MAX(1, total / (N * 10)); + size_t last = 0; + for (; last < N; ++last) { + if (b.Bin(last) < min) break; + } + if (last >= N / 2) { + // Also require all subsequent bins to be zero, otherwise we should + // print the outlier bins. + bool all_zero = true; + for (size_t i = last + 1; i < N; ++i) { + if (b.Bin(last) != 0) { + all_zero = false; + break; + } + } + if (all_zero) { + fprintf(stderr, "%s: uniform up to %zu\n", name, last); + return; + } + } + + b.Print(name, /*skip_zero=*/true); + } + + double sum_sq_ = 0.0; // for Frobenius norm + hwy::Bins<256> b_exp256_; // exponent + hwy::Bins b_big_row_; + hwy::Bins b_big_col_; + hwy::Stats s_val_; + hwy::Stats s_mag_; + hwy::Stats s_cond_; // condition number + hwy::Stats s_corr_; // lag-1 autocorrelation + hwy::Stats s_group_min_; + hwy::Stats s_group_max_; + hwy::Stats s_group_range_; + std::atomic skip_{0}; +}; + +class TensorStats { + public: + TensorStats(size_t num_layers, size_t max_workers) + : num_layers_(num_layers), + max_workers_(max_workers), + acc_(num_layers * max_workers) {} + + // Parallelized across rows. If `ctx.tensor_output` is not empty, writes + // tensor data to disk for offline analysis, once per tensor and layer. + void Notify(size_t layer_idx, const MatPtr& type_erased, + ThreadingContext& ctx, int flags = 0, size_t cluster_idx = 0, + Parallelism parallelism = Parallelism::kFlat); + + // For use by `UpdateStatsT`. + TensorStatsAccumulator& Get(size_t layer_idx, size_t global_idx) { + const size_t idx = layer_idx * max_workers_ + global_idx; + HWY_DASSERT(idx < acc_.size()); + return acc_[idx]; + } + + void ReduceAndPrint(const char* prefix) { + for (size_t layer_idx = 0; layer_idx < num_layers_; ++layer_idx) { + TensorStatsAccumulator& per_layer = Get(layer_idx, 0); + for (size_t global_idx = 1; global_idx < max_workers_; ++global_idx) { + per_layer.Assimilate(Get(layer_idx, global_idx)); + } + if (per_layer.IsEmpty()) continue; + per_layer.NotifyAcrossLayer(across_layers_); + if (per_layer.ShouldPrint()) { + fprintf(stderr, "-------------------- %s %zu\n", prefix, layer_idx); + per_layer.PrintAll(); + } + } + + if (!across_layers_.IsEmpty()) { + fprintf(stderr, "================= across layers %s\n", prefix); + across_layers_.Print(); + } + } + + private: + size_t num_layers_; + size_t max_workers_; + std::vector acc_; + TensorStatsAcrossLayers across_layers_; +}; + +#else // GCPP_TENSOR_STATS + +class TensorStats { + public: + TensorStats(size_t, size_t) {} + void Notify(size_t, const MatPtr&, ThreadingContext&, int = 0, size_t = 0, + Parallelism = Parallelism::kFlat) {} + void ReduceAndPrint(const char*) {} +}; + +#endif // GCPP_TENSOR_STATS +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_STATS_H_ diff --git a/gemma/vit.cc b/gemma/vit.cc index b00efda..31c6f0f 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -78,13 +78,9 @@ class VitAttention { const float query_scale = 1.0f / sqrtf(static_cast(qkv_dim)); PROFILER_ZONE("Gen.VitAttention.DotSoftmax"); - // Shift Q, K, VT to MatStorageT. - MatStorageT Q("Q2", Extents2D(num_tokens_, qkv_dim), - env_.ctx.allocator, MatPadding::kPacked); - MatStorageT K("K2", Extents2D(seq_len, qkv_dim), env_.ctx.allocator, - MatPadding::kPacked); - MatStorageT C("C2", Extents2D(num_tokens_, seq_len), - env_.ctx.allocator, MatPadding::kPacked); + MatPtrT& Q = activations_.attention.vit_Q; + MatPtrT& K = activations_.attention.vit_K; + MatPtrT& C = activations_.attention.vit_C; // Initialize att_out to zero prior to head loop. 
ZeroInit(activations_.attention.att_out); @@ -295,19 +291,20 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image, const size_t model_dim = model_config.vit_config.model_dim; const size_t patch_width = model_config.vit_config.patch_width; const size_t num_tokens = model_config.vit_config.seq_len; - const size_t patch_size = patch_width * patch_width * 3; + const size_t patch_area = patch_width * patch_width * 3; + const hwy::Divisor div_patch_dim(patch_width); HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim); - HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size); + HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_area); HWY_DASSERT(activations.x.Cols() == model_dim); (void)model_dim; // img/embedding/kernel has original shape (14, 14, 3, 1152) // H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3) // image_patches is (256, 14 * 14 * 3) // Must be padded, see `DoDecompressA`. - MatStorageT image_patches("patches", Extents2D(num_tokens, patch_size), + MatStorageT image_patches("patches", Extents2D(num_tokens, patch_area), env.ctx.allocator, MatPadding::kOdd); for (size_t i = 0; i < num_tokens; ++i) { - image.GetPatch(i, image_patches.Row(i)); + image.GetPatch(i, div_patch_dim, image_patches.Row(i)); } CallMatMul(image_patches, weights.vit_img_embedding_kernel, weights.vit_img_embedding_bias.PackedScale1(), env, activations.x); diff --git a/gemma/weights.cc b/gemma/weights.cc index e1e01bf..00c12c6 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -431,12 +431,12 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) { void WeightsPtrs::Fixup(std::vector& mat_owners, ThreadingContext& ctx) { const size_t cluster_idx = 0; - ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx, + ParallelFor(Parallelism::kFlat, c_layers.size(), ctx, cluster_idx, Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) { GetLayer(layer)->Fixup(mat_owners, ctx); }); - ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx, + ParallelFor(Parallelism::kFlat, vit_layers.size(), ctx, cluster_idx, Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) { VitLayer(layer)->Fixup(mat_owners, ctx); }); @@ -527,7 +527,7 @@ static void AllocateAndBindAll(std::vector& tensors, // Allocate in parallel because faulting in large tensors is slow. ParallelFor( - ParallelismStrategy::kFlat, tensors.size(), ctx, /*cluster_idx=*/0, + Parallelism::kFlat, tensors.size(), ctx, /*cluster_idx=*/0, Callers::kAllocateAndBindAll, [&](uint64_t task, size_t /*thread*/) { TensorToRead& tensor = tensors[task]; MatPtr& mat = *tensor.mat; @@ -586,10 +586,9 @@ static void DecompressToBF16(MatPtr& mat, static void ReadAllToBF16(const std::vector& tensors, const BlobReader& reader, ThreadingContext& ctx) { // Especially TSAN is slow enough to warrant hierarchical parallelism. - const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD - ? ParallelismStrategy::kHierarchical - : ParallelismStrategy::kFlat; - ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0, + const Parallelism parallelism = + HWY_IS_DEBUG_BUILD ? 
Parallelism::kHierarchical : Parallelism::kFlat; + ParallelFor(parallelism, tensors.size(), ctx, /*cluster_idx=*/0, Callers::kReadAllToBF16, [&](uint64_t task, size_t thread) { GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadAllToBF16); const TensorToRead& tensor = tensors[task]; @@ -677,7 +676,7 @@ static void ReadBatches(const BlobReader& reader, const std::vector& batches, ThreadingContext& ctx) { // >5x speedup from parallel reads when cached. - ParallelFor(ParallelismStrategy::kHierarchical, batches.size(), ctx, + ParallelFor(Parallelism::kHierarchical, batches.size(), ctx, /*cluster_idx=*/0, Callers::kReadBatches, [&](uint64_t task, size_t thread) { GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadBatches); diff --git a/gemma/weights.h b/gemma/weights.h index 3661869..4476e22 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -96,7 +96,8 @@ struct LayerWeightsPtrs { // other values for purposes of the KV cache. LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config, const TensorInfoRegistry& tensors) - : finder_(LayerSuffix(layer_idx), tensors), + : layer_idx(layer_idx), + finder_(LayerSuffix(layer_idx), tensors), qkv_einsum_w(finder_("qkv_ein")), qkv_einsum_w1(finder_("qkv1_w")), qkv_einsum_w2(finder_("qkv2_w")), @@ -135,6 +136,7 @@ struct LayerWeightsPtrs { } ~LayerWeightsPtrs() = default; + const size_t layer_idx; const MatFinder finder_; // Files either have qkv_einsum_w with 2 stacked matrices or separate diff --git a/io/blob_compare.cc b/io/blob_compare.cc index 30a2199..9bb860e 100644 --- a/io/blob_compare.cc +++ b/io/blob_compare.cc @@ -106,7 +106,7 @@ void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs, ThreadingContext& ctx, size_t cluster_idx) { HWY_ASSERT(reader.Keys().size() == blobs.size()); HWY_ASSERT(ranges.size() == blobs.size()); - ParallelFor(ParallelismStrategy::kWithinCluster, blobs.size(), ctx, + ParallelFor(Parallelism::kWithinCluster, blobs.size(), ctx, cluster_idx, Callers::kTest, [&](size_t i, size_t /*thread*/) { HWY_ASSERT(ranges[i].bytes == blobs[i].size()); reader.file().Read(ranges[i].offset, ranges[i].bytes, @@ -122,7 +122,7 @@ void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2, const double t0 = hwy::platform::Now(); HWY_WARN("Reading %zu GiB, %zu clusters: ", total_bytes >> 30, ctx.pools.NumClusters()); - ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0, Callers::kTest, + ParallelFor(Parallelism::kAcrossClusters, 2, ctx, 0, Callers::kTest, [&](const size_t task, size_t cluster_idx) { ReadBlobs(task ? reader1 : reader2, task ? ranges1 : ranges2, task ? blobs1 : blobs2, ctx, cluster_idx); @@ -189,7 +189,7 @@ void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2, const double t0 = hwy::platform::Now(); std::atomic blobs_equal{}; std::atomic blobs_diff{}; - ParallelFor(ParallelismStrategy::kHierarchical, keys.size(), ctx, 0, + ParallelFor(Parallelism::kHierarchical, keys.size(), ctx, 0, Callers::kTest, [&](size_t i, size_t /*thread*/) { const size_t mismatches = BlobDifferences(blobs1[i], blobs2[i], keys[i]); diff --git a/io/blob_store.cc b/io/blob_store.cc index af9f81d..8346e4b 100644 --- a/io/blob_store.cc +++ b/io/blob_store.cc @@ -488,11 +488,10 @@ void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) { EnqueueChunks(keys_.size() - 1, curr_offset_, bytes, static_cast(data), writes); - const ParallelismStrategy strategy = file_->IsAppendOnly() - ? 
ParallelismStrategy::kNone - : ParallelismStrategy::kFlat; + const Parallelism parallelism = + file_->IsAppendOnly() ? Parallelism::kNone : Parallelism::kFlat; ParallelFor( - strategy, writes.size(), ctx_, + parallelism, writes.size(), ctx_, /*cluster_idx=*/0, Callers::kBlobWriter, [this, &writes](uint64_t i, size_t /*thread*/) { const BlobRange& range = writes[i].range; diff --git a/io/blob_store.h b/io/blob_store.h index e5f2221..82c2357 100644 --- a/io/blob_store.h +++ b/io/blob_store.h @@ -131,7 +131,7 @@ class BlobWriter { std::vector blob_sizes_; ThreadingContext& ctx_; // Current offset in the file used for writing. - int64_t curr_offset_ = 0; + uint64_t curr_offset_ = 0; }; } // namespace gcpp diff --git a/io/blob_store_test.cc b/io/blob_store_test.cc index bb41c7e..cf96684 100644 --- a/io/blob_store_test.cc +++ b/io/blob_store_test.cc @@ -130,7 +130,7 @@ TEST(BlobStoreTest, TestNumBlobs) { HWY_ASSERT_EQ(reader.Keys().size(), num_blobs); ParallelFor( - ParallelismStrategy::kFlat, num_blobs, ctx, /*cluster_idx=*/0, + Parallelism::kFlat, num_blobs, ctx, /*cluster_idx=*/0, Callers::kTest, [&](uint64_t i, size_t /*thread*/) { HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(), std::to_string(i).c_str()); diff --git a/io/fields_test.cc b/io/fields_test.cc index 37bb942..f720c15 100644 --- a/io/fields_test.cc +++ b/io/fields_test.cc @@ -20,6 +20,8 @@ #include #include +#include +#include #include #include "hwy/tests/hwy_gtest.h" diff --git a/io/io.cc b/io/io.cc index 9363b07..2f479b2 100644 --- a/io/io.cc +++ b/io/io.cc @@ -110,7 +110,8 @@ class FilePosix : public File { HWY_WARN( "Read failure at pos %zu within size %zu with offset %zu and " "errno %d\n", - pos, size, offset, errno); + static_cast(pos), static_cast(size), + static_cast(offset), errno); break; } pos += bytes_read; @@ -130,7 +131,8 @@ class FilePosix : public File { HWY_WARN( "Write failure at pos %zu within size %zu with offset %zu and " "errno %d\n", - pos, size, offset, errno); + static_cast(pos), static_cast(size), + static_cast(offset), errno); break; } pos += bytes_written; @@ -194,9 +196,9 @@ std::unique_ptr OpenFileOrNull(const Path& filename, const char* mode) { namespace gcpp { std::unique_ptr OpenFileOrAbort(const Path& filename, const char* mode) { - std::unique_ptr file = OpenFileOrNull(filename, "r"); + std::unique_ptr file = OpenFileOrNull(filename, mode); if (!file) { - HWY_ABORT("Failed to open %s", filename.path.c_str()); + HWY_ABORT("Failed to open %s, errno %d", filename.path.c_str(), errno); } return file; } @@ -234,7 +236,9 @@ bool IOBatch::Add(void* mem, size_t bytes) { return true; } -void InternalInit() { +int InternalInit() { + // currently unused, except for init list ordering in GemmaEnv. + return 0; } uint64_t IOBatch::Read(const File& file) const { diff --git a/io/io.h b/io/io.h index f90a636..d051715 100644 --- a/io/io.h +++ b/io/io.h @@ -150,7 +150,7 @@ std::string ReadFileToString(const Path& path); // No-op in open-source. Must be called at the beginning of a binary, before // any I/O or flag usage. 
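+// Returns int (0 in this open-source version) so that callers can invoke it
+// from a member initializer list to enforce initialization order (see
+// `GemmaEnv`).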
-void InternalInit(); +int InternalInit(); } // namespace gcpp diff --git a/io/migrate_weights.cc b/io/migrate_weights.cc index d20835f..beb268e 100644 --- a/io/migrate_weights.cc +++ b/io/migrate_weights.cc @@ -23,7 +23,9 @@ namespace gcpp { namespace { struct WriterArgs : public ArgsBase { - WriterArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + WriterArgs(int argc, char* argv[], ConsumedArgs& consumed) { + InitAndParse(argc, argv, consumed); + } Path output_weights; @@ -38,12 +40,15 @@ struct WriterArgs : public ArgsBase { } // namespace gcpp int main(int argc, char** argv) { - gcpp::WriterArgs args(argc, argv); - if (args.output_weights.Empty()) { + gcpp::ConsumedArgs consumed(argc, argv); + gcpp::GemmaArgs args(argc, argv, consumed); + gcpp::WriterArgs writer_args(argc, argv, consumed); + if (writer_args.output_weights.Empty()) { HWY_ABORT("Missing --output_weights flag, a file for the model weights."); } + consumed.AbortIfUnconsumed(); - gcpp::GemmaEnv env(argc, argv); - env.GetGemma()->Save(args.output_weights, env.Env().ctx); + gcpp::GemmaEnv env(args); + env.GetGemma()->Save(writer_args.output_weights, env.Env().ctx); return 0; } diff --git a/ops/dot-inl.h b/ops/dot-inl.h index dae2106..ecf1ecf 100644 --- a/ops/dot-inl.h +++ b/ops/dot-inl.h @@ -413,7 +413,8 @@ using DotKernelDefault = template HWY_INLINE float Dot(D d, const PackedSpan& w, size_t w_ofs, const VT* HWY_RESTRICT vec, size_t num) { - return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelDefault()); + return DecompressAndCall(d, w, w_ofs, MakeConstSpan(vec, num), + DotKernelDefault()); } // Adapter for two pointers, no bounds checking. diff --git a/ops/dot_test.cc b/ops/dot_test.cc index 827b6b4..bce8904 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -891,18 +891,6 @@ class DotStats { hwy::Stats s_times[kVariants]; }; -// Returns normalized value in [-1, 1). -float RandomFloat(RngStream& rng) { - const uint32_t exp = hwy::BitCastScalar(1.0f); - const uint32_t mantissa_mask = hwy::MantissaMask(); - const uint32_t representation = exp | (rng() & mantissa_mask); - const float f12 = hwy::BitCastScalar(representation); - HWY_DASSERT(1.0f <= f12 && f12 < 2.0f); // exponent is 2^0, only mantissa - const float f = (2.0f * (f12 - 1.0f)) - 1.0f; - HWY_DASSERT(-1.0f <= f && f < 1.0f); - return f; -} - // `raw` holds the decompressed values, so that the test measures only the // error from the Dot algorithms, not the compression. 
template @@ -1126,7 +1114,7 @@ void TestAllDot() { std::array all_stats; ParallelFor( - ParallelismStrategy::kWithinCluster, kReps, ctx, 0, Callers::kTest, + Parallelism::kWithinCluster, kReps, ctx, 0, Callers::kTest, [&](size_t rep, size_t thread) { float* HWY_RESTRICT pa = a.Row(thread); float* HWY_RESTRICT pb = b.Row(thread); diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 96cd4f1..4b217a1 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -837,10 +837,11 @@ class MMImpl { hwy::platform::InvariantTicksPerSecond(); const double flops = 2 * M * K * N * num_B / min_elapsed; // * 2 for FMA if (HWY_UNLIKELY(env.print_measurement && tuner.ShouldPrint())) { - fprintf(stderr, "%zu,%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu\n", - M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, cfg.MR(), - cfg.MC(), cfg.KC(), cfg.NC(), StringFromOrder(cfg.Order()), - cfg.InnerTasks()); + fprintf( + stderr, + "%4zu,%4zu,%4zu,B%zu,%7.1f,%.2f ms, MR%zu,%4zu,%4zu,%5zu,%-7s,%zu\n", + M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, cfg.MR(), cfg.MC(), + cfg.KC(), cfg.NC(), StringFromOrder(cfg.Order()), cfg.InnerTasks()); } if (HWY_UNLIKELY(env.print_best && tuner.Best())) { const auto ratio = [&tuner](uint64_t ticks) -> double { @@ -850,7 +851,8 @@ class MMImpl { const MMConfig& best = *tuner.Best(); fprintf( stderr, - "\n%zu,%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%.2f,%.2f\n", + "\n%4zu,%4zu,%4zu,B%zu,%7.1f,%.2f ms, MR%zu,%4zu,%4zu,%5zu,%-7s,%zu, " + "%.2fx,%.2fx\n", M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(), best.KC(), best.NC(), StringFromOrder(best.Order()), best.InnerTasks(), ratio(tuner.WorstMinTicks()), @@ -906,8 +908,8 @@ class MMLoops { const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT); HWY_DASSERT(args.ranges_mc.NumTasks() == 1); HWY_DASSERT(args.ranges_kc.NumTasks() == 1); - const IndexRange& range_mc = args.ranges_mc.Range(0); - const IndexRange& range_kc = args.ranges_kc.Range(0); + const IndexRange& range_mc = args.ranges_mc.Range(0); // whole M + const IndexRange& range_kc = args.ranges_kc.Range(0); // whole K parallel.ForN( args.env.ctx, args.range_n, MultipleN(sizeof(TC), args.line_bytes), @@ -941,7 +943,7 @@ class MMLoops { const MMArgs& args) { const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_K); HWY_DASSERT(args.ranges_mc.NumTasks() == 1); - const IndexRange& range_mc = args.ranges_mc.Range(0); + const IndexRange& range_mc = args.ranges_mc.Range(0); // whole M parallel.ForN(args.env.ctx, args.range_n, MultipleN(sizeof(TC), args.line_bytes), args.inner_tasks, @@ -977,7 +979,7 @@ class MMLoops { const MMArgs& args) { const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_MT); HWY_DASSERT(args.ranges_kc.NumTasks() == 1); - const IndexRange& range_kc = args.ranges_kc.Range(0); + const IndexRange& range_kc = args.ranges_kc.Range(0); // whole K parallel.ForRangesMC_NC( args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx, @@ -1158,8 +1160,9 @@ HWY_NOINLINE MMPerKey* TwoMatMul(const MatPtrT& A, const MatPtrT& B1, HWY_ASSERT(K <= MMEntireA::kMaxK); HWY_ASSERT(N % kNR == 0); MMImpl::EnsureAligned(A, cache.VectorBytes()); - tuner.SetCandidates( - MMCandidates(cache, M, K, N, num_B, sizeof(BF16), env.print_config)); + const size_t max_M = MMKeys::BucketM(M); + tuner.SetCandidates(MMCandidates(cache, max_M, K, N, num_B, sizeof(BF16), + env.print_config)); } const MMConfig& cfg = tuner.NextConfig(); diff --git a/ops/matmul.cc b/ops/matmul.cc index ebeff9b..c01943d 100644 --- a/ops/matmul.cc +++ 
b/ops/matmul.cc @@ -21,6 +21,7 @@ #include #include +#include #include #include "util/allocator.h" @@ -46,7 +47,9 @@ size_t RoundDownWithFloor(size_t value, size_t multiple) { // multiple of `multiple`, or 0 if none exists. size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim, const size_t multiple) { - HWY_DASSERT(end != 0 && dim != 0 && multiple != 0); + HWY_DASSERT(end != 0); + HWY_DASSERT(dim != 0); + HWY_DASSERT(multiple != 0); size_t prev = RoundDownWithFloor(end, multiple); // Avoid returning `end` if rounding down had no effect. if (prev == end) prev -= multiple; @@ -62,10 +65,10 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim, // and holds most of their arguments in member variables. class GenerateCandidates { public: - GenerateCandidates(const CacheInfo& cache, size_t M, size_t K, size_t N, + GenerateCandidates(const CacheInfo& cache, size_t max_M, size_t K, size_t N, size_t num_B, size_t sizeof_TC, bool print_config) : cache_(cache), - M_(M), + max_M_(max_M), K_(K), N_(N), num_B_(num_B), @@ -89,14 +92,14 @@ class GenerateCandidates { for (size_t mc : MC(mr, kc, order)) { for (size_t nc : NC(mr, mc, kc, order)) { for (int inner_tasks : all_inner_tasks) { - const MMConfig config(K_, N_, mr, mc, kc, nc, kc_multiple_, - nc_multiple_, order, inner_tasks); - const size_t M_tasks = config.RangesOfMC(M_).NumTasks(); + const MMConfig config(max_M_, K_, N_, mr, mc, kc, nc, + kc_multiple_, nc_multiple_, order, + inner_tasks); + const size_t M_tasks = config.RangesOfMC(max_M_).NumTasks(); const size_t K_tasks = config.RangesOfKC(K_).NumTasks(); - // Blocks only make sense when there are multiple M tasks. - if (IsBlock(order) != (M_tasks > 1)) continue; - // Single KC only makes sense when there is a single K task. + // Do not use single-MC/KC order if there are multiple. + if (IsOneMC(order) != (M_tasks == 1)) continue; if (IsOneKC(order) != (K_tasks == 1)) continue; candidates.push_back(config); @@ -114,6 +117,25 @@ class GenerateCandidates { private: using SizeVec = std::vector; + // Concatenate and print once because this can be called concurrently. + void MaybePrintSizes(size_t dim, size_t max, const char* caption, + const SizeVec& sizes) const { + if (!print_config_ || sizes.empty()) return; + std::string out("num_B "); + out += std::to_string(num_B_); + out += " ("; + out += std::to_string(dim); + out += ", max "; + out += std::to_string(max); + out += ") "; + out += caption; + out += ": "; + for (size_t size : sizes) { + out += std::to_string(size) + " "; + } + fprintf(stderr, "%s\n", out.c_str()); + } + // How many rows of A per call to `MMKernel::LoopKC`. Lower values may // be better for SIMD targets with fewer registers. SizeVec MR() const { @@ -125,14 +147,14 @@ class GenerateCandidates { SizeVec all_mr; all_mr.reserve(3); // AVX2's 16 registers are not enough for four rows, but SSE4 may benefit. - if (M_ >= kMaxMR && !is_avx2) all_mr.push_back(kMaxMR); + if (max_M_ >= kMaxMR && !is_avx2) all_mr.push_back(kMaxMR); // Allow for AVX-512 but not SSE4 (for which 4 are usually better). Also // enable if not enough rows for 4. - if (M_ >= 2 && (M_ < kMaxMR || (!is_sse && !is_wasm))) { + if (max_M_ >= 2 && (max_M_ < kMaxMR || (!is_sse && !is_wasm))) { all_mr.push_back(size_t{2}); } // Even SSE4 usually prefers 2 rows; only enable for single rows. 
- if (M_ == 1) all_mr.push_back(size_t{1}); + if (max_M_ == 1) all_mr.push_back(size_t{1}); HWY_ASSERT(!all_mr.empty()); return all_mr; } @@ -143,18 +165,26 @@ class GenerateCandidates { for (size_t order_idx = 0;; ++order_idx) { const MMOrder order = static_cast(order_idx); if (StringFromOrder(order) == nullptr) return orders; // done - // 2D blocking is useless for a single row of M. - if (IsBlock(order) && M_ <= mr) continue; + // Multiple-MC is useless for a single row of M. + if (!IsOneMC(order) && max_M_ <= mr) continue; // Conversely, N-only parallelism is uncompetitive for large M. - if (!IsBlock(order) && M_ >= kMaxTilesM * mr) continue; + if (IsOneMC(order) && max_M_ >= 8 * mr) continue; orders.push_back(order); } } // The number of A and B columns to read between updating `C`. SizeVec KC(size_t mr, MMOrder order) const { + if (IsOneKC(order)) { + // A single KC range is infeasible when K exceeds the max. The caller + // will skip all configs with `order`. + if (K_ > kMaxKC) return SizeVec(); + // Must return the actual value: although ignored by `RangesOfKC`, this + // will be used in MC() and NC(). + return SizeVec(1, K_); + } // `LoopKC` handles up to `mr` rows of A. - const size_t rows_a = HWY_MIN(M_, mr); + const size_t rows_a = HWY_MIN(max_M_, mr); // After looping over `kc` columns, we write `mr x 4` outputs and 16 vector // `buf`. To amortize the write cost, we want to maximize `kc`. However, it @@ -186,7 +216,7 @@ class GenerateCandidates { // If we can afford a single K task, that's usually best; only try one // more. Otherwise, blocks may require smaller kc (more options). - const size_t reps = (kc_max == K_) ? 1 : IsBlock(order) ? 3 : 2; + const size_t reps = (kc_max == K_) ? 1 : IsOneMC(order) ? 2 : 3; size_t prev = kc_max; for (size_t rep = 0; rep < reps; ++rep) { @@ -196,22 +226,27 @@ class GenerateCandidates { } } - if (print_config_ && all_kc.size() > 1) { - fprintf(stderr, "num_B %zu: KC: ", num_B_); - for (size_t kc : all_kc) { - fprintf(stderr, "%zu ", kc); - } - fprintf(stderr, "\n"); - } - + MaybePrintSizes(K_, kc_max, "KC", all_kc); return all_kc; } // The number of (L2 resident) A rows for `A2C0` to loop over. SizeVec MC(size_t mr, size_t kc, MMOrder order) const { + if (max_M_ <= mr) return SizeVec(1, max_M_); + if (IsOneMC(order)) { + // A single MC range is infeasible when M exceeds the max. The caller + // will skip all configs with `order`. + if (max_M_ > kMaxMC) return SizeVec(); + // Must return the actual value: although ignored by `RangesOfMC`, this + // will be used in NC(). + return SizeVec(1, max_M_); + } + // Typically 12-24K. The B rows are pinned in L1, but also occupy L2 because // it is typically inclusive. const size_t bytes_b = kNR * kc * (sizeof(SfpStream) + sizeof(BF16)); + // `kc` was chosen to fit in L1, hence this should not exceed L2. + HWY_ASSERT(bytes_b <= cache_.L2Bytes()); // Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the // packed B. 
We want `mc * kc` elements of A to fit in L2, alongside @@ -219,35 +254,45 @@ class GenerateCandidates { const size_t bytes_per_mc = kc * sizeof(BF16) + cache_.LineBytes(); size_t mc_max = hwy::DivCeil(cache_.L2Bytes() - bytes_b, bytes_per_mc); mc_max = HWY_MIN(mc_max, HWY_MIN(kMaxBatchSize, kMaxMC)); - HWY_DASSERT(mc_max != 0); - mc_max = HWY_MIN(mc_max, M_); - mc_max = hwy::RoundDownTo(mc_max, mr); + mc_max = HWY_MIN(mc_max, max_M_); + HWY_ASSERT(mc_max != 0); - SizeVec all_mc(1, mc_max); - // Larger MC is better for non-blocks, otherwise we want more small options, - // especially for two B. - const size_t reps = !IsBlock(order) ? 2 : (2 + num_B_); + SizeVec all_mc; + all_mc.reserve(6); - size_t prev = mc_max; - for (size_t rep = 0; rep < reps; ++rep) { - prev = PrevDivisor(1, prev, M_, mr); - if (prev >= mc_max || prev == 0) break; + const size_t rounded_M = HWY_MAX(mr, hwy::RoundDownTo(max_M_, mr)); + size_t prev = hwy::RoundDownTo(mc_max, mr); + + // If mc_max is large enough, allow using the whole range without rounding + // down (which may require two ranges). + if (mc_max == max_M_ && (max_M_ % mr) != 0) { + all_mc.push_back(max_M_); + // The next option should be considerably smaller than `max_M_`. + prev = HWY_MAX(mr, hwy::RoundDownTo(3 * prev / 4, mr)); + } else { all_mc.push_back(prev); } - // Blocks: largest is not useful. - if (IsBlock(order) && all_mc.size() > 1) { - all_mc.erase(all_mc.begin(), all_mc.begin() + 1); - } - - if (print_config_ && all_mc.size() > 1) { - fprintf(stderr, "num_B %zu: MC: ", num_B_); - for (size_t mc : all_mc) { - fprintf(stderr, "%zu ", mc); + // We know `order` is multiple MC, where more/smaller values of `mc` are + // helpful, especially for two B, hence add iterations. + const size_t reps = 2 + num_B_; + for (size_t rep = 0; rep < reps; ++rep) { + prev = PrevDivisor(mr, prev, rounded_M, mr); + if (prev == 0) break; // none found + if (prev == mr) { + if (all_mc.back() != prev) all_mc.push_back(prev); + break; } - fprintf(stderr, "\n"); + if (prev <= mc_max / 8) break; + all_mc.push_back(prev); } + if (all_mc.size() <= 2) { + if (max_M_ > mr) all_mc.push_back(max_M_ / 2); + if (mc_max > mr) all_mc.push_back(mc_max / 2); + } + + MaybePrintSizes(max_M_, mc_max, "MC", all_mc); return all_mc; } @@ -257,7 +302,7 @@ class GenerateCandidates { // Only if there will be reuse of B: choose the largest `nc_max` (C cols) // such that `nc x kc` of B and `mc x nc` of `C` fit in L3. Otherwise, // leave it unbounded. - if (M_ > mr) { + if (max_M_ > mr) { const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * sizeof_TC_); nc_max = HWY_MIN(hwy::DivCeil(cache_.L3Bytes(), bytes_per_nc), kMaxNC); } @@ -271,8 +316,8 @@ class GenerateCandidates { nc_max = RoundDownWithFloor(N_ / 2, nc_multiple_); } - // Non-block calls ForNP, which ignores `range_nc` and uses `range_np`. - if (!IsBlock(order)) return SizeVec(1, N_); + // Single-MC calls `ForNP`, which ignores `range_nc`. + if (IsOneMC(order)) return SizeVec(1, N_); SizeVec all_nc(1, nc_max); @@ -282,7 +327,7 @@ class GenerateCandidates { // hence autotune a wider range of nc than the other dimensions. size_t reps = 9 + num_B_; // For small M, we can afford larger NC, hence allow fewer small options. 
- if (M_ <= 2 * mr) reps -= 1; + if (max_M_ <= 2 * mr) reps -= 1; size_t prev = nc_max; for (size_t rep = 0; rep < reps; ++rep) { @@ -302,14 +347,7 @@ class GenerateCandidates { all_nc.begin() + HWY_MIN(want_delete, max_delete)); } - if (print_config_ && all_nc.size() > 1) { - fprintf(stderr, "num_B %zu: NC: ", num_B_); - for (size_t nc : all_nc) { - fprintf(stderr, "%zu ", nc); - } - fprintf(stderr, "\n"); - } - + MaybePrintSizes(N_, nc_max, "NC", all_nc); return all_nc; } @@ -319,8 +357,8 @@ class GenerateCandidates { std::vector inner_tasks; inner_tasks.reserve(3); inner_tasks.push_back(1); - // Blocks have one task per mc/nc range and ignore this parameter. - if (!IsBlock(order)) { + // Multiple-MC have one task per mc/nc range and ignore this parameter. + if (IsOneMC(order)) { inner_tasks.push_back(2); inner_tasks.push_back(4); } @@ -328,7 +366,7 @@ class GenerateCandidates { } const CacheInfo& cache_; - const size_t M_; + const size_t max_M_; const size_t K_; const size_t N_; const size_t num_B_; @@ -343,10 +381,11 @@ class GenerateCandidates { } // namespace // Facade to avoid exposing `GenerateCandidates` in the header. -std::vector MMCandidates(const CacheInfo& cache, size_t M, size_t K, - size_t N, size_t num_B, size_t sizeof_TC, - bool print_config) { - return GenerateCandidates(cache, M, K, N, num_B, sizeof_TC, print_config)(); +std::vector MMCandidates(const CacheInfo& cache, size_t max_M, + size_t K, size_t N, size_t num_B, + size_t sizeof_TC, bool print_config) { + return GenerateCandidates(cache, max_M, K, N, num_B, sizeof_TC, + print_config)(); } MatMulEnv::MatMulEnv(ThreadingContext& ctx) diff --git a/ops/matmul.h b/ops/matmul.h index ea7c090..85deb62 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -61,7 +61,7 @@ HWY_INLINE_VAR constexpr size_t kMaxNC = 16384; // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024; -// Policy classes for parallelism, implementing some of `ParallelismStrategy`. +// Policy classes for parallelism, implementing some of `Parallelism`. struct MMParallelNone { template @@ -103,18 +103,14 @@ struct MMParallelWithinCluster { template void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple, size_t inner_tasks, size_t cluster_idx, const Func& func) const { - HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); + const hwy::pool::Caller caller = + ctx.pool_callers.Get(Callers::kMMClusterForN); - hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); - const size_t base = ctx.Worker(cluster_idx); - - const IndexRangePartition ranges_n = StaticPartition( - range_n, cluster.NumWorkers() * inner_tasks, n_multiple); - ParallelizeOneRange(ranges_n, cluster, - ctx.pool_callers.Get(Callers::kMMClusterForN), - [&](const IndexRange& worker_range, size_t worker) { - func(worker_range, base + worker); - }); + ParallelPartitionWithinCluster( + range_n, n_multiple, inner_tasks, ctx, cluster_idx, caller, + [&](const IndexRange& worker_range, size_t worker) { + func(worker_range, worker); + }); } template @@ -122,79 +118,56 @@ struct MMParallelWithinCluster { const IndexRangePartition& ranges_mc, const IndexRangePartition& ranges_nc, size_t cluster_idx, const Func& func) const { - hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); - const size_t base = ctx.Worker(cluster_idx); + const hwy::pool::Caller caller = + ctx.pool_callers.Get(Callers::kMMClusterForMCNC); - // Low-batch: avoid Divide/Remainder. 
- if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) { - ParallelizeOneRange(ranges_nc, cluster, - ctx.pool_callers.Get(Callers::kMMClusterForMCNC), - [&](const IndexRange& range_nc, size_t worker) { - func(ranges_mc.Range(0), range_nc, base + worker); - }); - } else { - ParallelizeTwoRanges( - ranges_mc, ranges_nc, cluster, - ctx.pool_callers.Get(Callers::kMMClusterForMCNC), - [&](const IndexRange& range_mc, const IndexRange& range_nc, - size_t worker) { func(range_mc, range_nc, base + worker); }); - } + // We are running on one pool, hence collapse into a 1D range. + const hwy::Divisor div_m(static_cast(ranges_mc.NumTasks())); + const auto get_mc = [&](uint64_t task) { + return ranges_mc.Range(div_m.Remainder(static_cast(task))); + }; + const auto get_nc = [&](uint64_t task) { + return ranges_nc.Range(div_m.Divide(static_cast(task))); + }; + const size_t num_tasks = ranges_mc.NumTasks() * ranges_nc.NumTasks(); + + ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller, + [&](uint64_t task, size_t worker) { + func(get_mc(task), get_nc(task), worker); + }); } template void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, size_t cluster_idx, const Func& func) const { - hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); - const size_t base = ctx.Worker(cluster_idx); + const hwy::pool::Caller caller = + ctx.pool_callers.Get(Callers::kMMClusterForMC); - cluster.Run( - range_mc.begin(), range_mc.end(), - ctx.pool_callers.Get(Callers::kMMClusterForMC), - [&](uint64_t row_a, size_t worker) { func(row_a, base + worker); }); + ParallelForWithinCluster( + range_mc.Num(), ctx, cluster_idx, caller, + [&](uint64_t i, size_t worker) { func(range_mc.begin() + i, worker); }); } }; struct MMParallelHierarchical { - // Cluster/CCX-aware parallel-for over B rows in `range_n`. `n_multiple` is - // the granularity of per-cluster tasks. Calls `func(worker_range, worker)`. + // Similar to `HierarchicalParallelFor`, but over *sub-ranges* of B rows in + // `range_n` governed by `n_multiple` and `inner_tasks`. template void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple, - size_t inner_tasks, HWY_MAYBE_UNUSED size_t caller_cluster_idx, + size_t inner_tasks, size_t caller_cluster_idx, const Func& func) const { - HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); HWY_DASSERT(caller_cluster_idx == 0); + (void)caller_cluster_idx; const hwy::pool::Caller caller = ctx.pool_callers.Get(Callers::kMMHierForN); - // Single cluster: parallel-for over static partition of `range_n`. - hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(); - const size_t num_clusters = all_clusters.NumWorkers(); - if (num_clusters == 1) { - const size_t cluster_idx = 0; - hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); - const IndexRangePartition ranges_n = StaticPartition( - range_n, cluster.NumWorkers() * inner_tasks, n_multiple); - return ParallelizeOneRange( - ranges_n, cluster, caller, - [&](const IndexRange& worker_range, size_t worker) { - func(worker_range, worker); - }); - } - - // Assign each cluster a sub-range of `range_n` (typically hundreds). - const IndexRangePartition ranges_n = - StaticPartition(range_n, num_clusters, n_multiple); - ParallelizeOneRange( - ranges_n, all_clusters, caller, - [&](const IndexRange& n_range, const size_t cluster_idx) { - hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); - const size_t cluster_base = ctx.Worker(cluster_idx); - // Parallel-for over sub-ranges of `cluster_range` within the cluster. 
- const IndexRangePartition worker_ranges = StaticPartition( - n_range, cluster.NumWorkers() * inner_tasks, n_multiple); - ParallelizeOneRange( - worker_ranges, cluster, caller, + // Assign clusters (if any) a sub-range of `range_n` (typically hundreds). + ParallelPartitionAcrossClusters( + range_n, n_multiple, /*inner_tasks=*/1, ctx, caller, + [&](const IndexRange& cluster_range, size_t cluster_idx) { + ParallelPartitionWithinCluster( + cluster_range, n_multiple, inner_tasks, ctx, cluster_idx, caller, [&](const IndexRange& worker_range, size_t worker) { - func(worker_range, cluster_base + worker); + func(worker_range, worker); }); }); } @@ -205,69 +178,56 @@ struct MMParallelHierarchical { void ForRangesMC_NC(ThreadingContext& ctx, const IndexRangePartition& ranges_mc, const IndexRangePartition& ranges_nc, - HWY_MAYBE_UNUSED size_t caller_cluster_idx, - const Func& func) const { + size_t caller_cluster_idx, const Func& func) const { HWY_DASSERT(caller_cluster_idx == 0); + (void)caller_cluster_idx; const hwy::pool::Caller caller = ctx.pool_callers.Get(Callers::kMMHierForMCNC); - hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(); - // `all_clusters` is a pool with one worker per cluster in a package. - const size_t num_clusters = all_clusters.NumWorkers(); - // Single (big) cluster: collapse two range indices into one parallel-for - // to reduce the number of fork-joins. - if (num_clusters == 1) { - const size_t cluster_idx = 0; - hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); - // Low-batch: avoid Divide/Remainder. - if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) { - return ParallelizeOneRange( - ranges_nc, cluster, caller, - [&](const IndexRange& range_nc, size_t worker) { - func(ranges_mc.Range(0), range_nc, worker); - }); - } else { - return ParallelizeTwoRanges( - ranges_mc, ranges_nc, cluster, caller, - [&](const IndexRange& range_mc, const IndexRange& range_nc, - size_t worker) { func(range_mc, range_nc, worker); }); - } - } + // Collapse two range indices into a 1D range for better load-balancing, + // because `ranges_mc` may just have one task. + const hwy::Divisor div_m(static_cast(ranges_mc.NumTasks())); + const auto get_mc = [&](uint64_t task) { + return ranges_mc.Range(div_m.Remainder(static_cast(task))); + }; + const auto get_nc = [&](uint64_t task) { + return ranges_nc.Range(div_m.Divide(static_cast(task))); + }; + const IndexRange all_range(0, ranges_mc.NumTasks() * ranges_nc.NumTasks()); - // Multiple clusters: N across clusters (both are usually the larger), and - // M within each cluster. We assume auto-tuning finds small MC/NC tasks. - ParallelizeOneRange( - ranges_nc, all_clusters, caller, - [&](const IndexRange range_nc, size_t cluster_idx) { - const size_t cluster_base = ctx.Worker(cluster_idx); - hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); - ParallelizeOneRange(ranges_mc, cluster, caller, - [&](const IndexRange& range_mc, size_t worker) { - func(range_mc, range_nc, cluster_base + worker); - }); + ParallelPartitionAcrossClusters( + all_range, /*task_multiple=*/1, /*inner_tasks=*/1, ctx, caller, + [&](const IndexRange& cluster_range, size_t cluster_idx) { + ParallelForWithinCluster(cluster_range.Num(), ctx, cluster_idx, + caller, [&](uint64_t i, size_t worker) { + const size_t task = + cluster_range.begin() + i; + func(get_mc(task), get_nc(task), worker); + }); }); } - // Calls `func(row_a, worker)` in parallel. + // No multiple/inner_tasks, so this is just HierarchicalParallelFor. 
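+  // Calls `func(row_a, worker)` once per row in `range_mc`, where `row_a` is
+  // the absolute row index.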
template void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, size_t caller_cluster_idx, const Func& func) const { - HierarchicalParallelFor(range_mc.Num(), ctx, Callers::kMMHierForMC, - [&](size_t task, size_t worker) { - func(range_mc.begin() + task, worker); - }); + HWY_DASSERT(caller_cluster_idx == 0); + (void)caller_cluster_idx; + HierarchicalParallelFor( + range_mc.Num(), ctx, Callers::kMMHierForMC, + [&](size_t i, size_t worker) { func(range_mc.begin() + i, worker); }); } }; template -void DispatchParallelism(ParallelismStrategy parallelism, const Func& func, +void DispatchParallelism(Parallelism parallelism, const Func& func, Args&&... args) { switch (parallelism) { - case ParallelismStrategy::kNone: + case Parallelism::kNone: return func(MMParallelNone(), std::forward(args)...); - case ParallelismStrategy::kWithinCluster: + case Parallelism::kWithinCluster: return func(MMParallelWithinCluster(), std::forward(args)...); - case ParallelismStrategy::kHierarchical: + case Parallelism::kHierarchical: return func(MMParallelHierarchical(), std::forward(args)...); default: HWY_UNREACHABLE; @@ -371,8 +331,8 @@ void DispatchOrder(MMOrder order, const Func& func, Args&&... args) { } } -static inline bool IsBlock(MMOrder order) { - return order == MMOrder::kNT_MT_K || order == MMOrder::kNT_MT; +static inline bool IsOneMC(MMOrder order) { + return order == MMOrder::kNT || order == MMOrder::kNT_K; } static inline bool IsOneKC(MMOrder order) { @@ -421,6 +381,8 @@ static inline const char* StringFromParA(MMParA par_a) { // `mc` := A rows such that `kc` columns fit in L2, // `nc` := B rows such that `kc` columns fit in L3 alongside `mc x nc` C. // Also includes loop order and task granularity. +// +// This is shared by multiple M which return the same `BucketM`. #pragma pack(push, 1) class MMConfig { public: @@ -428,8 +390,8 @@ class MMConfig { // `mr` is the number of A rows per call to `MMKernel::LoopKC`. // `MMOrder` is how to parallelize the outer loops. // `inner_tasks` chooses the within-cluster task granularity in `ForN`. - MMConfig(size_t K, size_t N, size_t mr, size_t mc, size_t kc, size_t nc, - size_t kc_multiple, size_t nc_multiple, MMOrder order, + MMConfig(size_t M, size_t K, size_t N, size_t mr, size_t mc, size_t kc, + size_t nc, size_t kc_multiple, size_t nc_multiple, MMOrder order, int inner_tasks) : mr_(static_cast(mr)), mc_(static_cast(mc)), @@ -441,11 +403,7 @@ class MMConfig { inner_tasks_(static_cast(inner_tasks)), reserved_{} { HWY_DASSERT(mr == 1 || mr == 2 || mr == 4); - if (mc % mr != 0) { - HWY_WARN("mc %zu not a multiple of mr %zu", mc, mr); - } - // Do not warn for single-kc tasks; some models unfortunately have K which - // are not multiples of `kc_multiple`. + // Some models have K which are not multiples of `kc_multiple`. if (kc != K && (kc % kc_multiple) != 0) { HWY_WARN("kc %zu not a multiple of kc_multiple %zu", kc, kc_multiple); } @@ -457,11 +415,21 @@ class MMConfig { } // Splits M/N into blocks which are visited sequentially or in parallel. - // K is always sequential, see `MMOrder`. IndexRangePartition RangesOfMC(size_t M) const { - return MaxSizePartition(IndexRange(0, M), mc_, mr_); + if (IsOneMC(order_)) { + // Must have exactly one M range/tile, regardless of `mr_` and `mc_`. + return IndexRangePartition(M); + } + const size_t mc = HWY_MIN(M, MC()); + const size_t mr = HWY_MIN(M, MR()); + return MaxSizePartition(IndexRange(0, M), mc, mr); } + // K is either a single range, or a sequential loop. 
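+  // When not `IsOneKC`, the caller iterates the KC ranges in order, updating
+  // `C` after each.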
IndexRangePartition RangesOfKC(size_t K) const { + if (IsOneKC(order_)) { + // Must have exactly one K range/tile, regardless of `kc_`. + return IndexRangePartition(K); + } return MaxSizePartition(IndexRange(0, K), kc_, kc_multiple_); } IndexRangePartition RangesOfNC(size_t N) const { @@ -488,7 +456,7 @@ class MMConfig { uint32_t kc_multiple_; MMOrder order_; uint8_t inner_tasks_; - HWY_MAYBE_UNUSED uint8_t reserved_[6]; + HWY_MEMBER_VAR_MAYBE_UNUSED uint8_t reserved_[6]; }; static_assert(sizeof(MMConfig) == 32); // for faster indexing #pragma pack(pop) @@ -597,26 +565,27 @@ class MMAutoTune { //------------------------------------------------------------------------------ -// Minimum M, in units of tile rows of height mr={1, 2, 4}, from which -// `MMOrder::kNT[_K]` are no longer allowed. They require a single MC range, -// but choosing the same config for a larger M can result in multiple MC ranges. -// Thus M less than this must have unique keys/configs. -HWY_INLINE_VAR constexpr size_t kMaxTilesM = 8; - // Map of previously seen dimensions to index via linear search. class MMKeys { - // Group batch size into buckets to reduce #auto-tunes. - static size_t BucketM(size_t M) { - if (M < kMaxTilesM * kMaxMR) return M; // See kMaxTilesM above. - if (M <= 128) return 128; - return 512; - } - public: using Key = uint64_t; // KeyFromDims will only return this if all dims are zero, which is invalid. static constexpr Key kPadding = 0; + // Returns the maximum permissible M in the bucket, for grouping batch sizes + // into buckets to reduce #auto-tunes. + static size_t BucketM(size_t M) { + HWY_DASSERT(M != 0); + // Small M: 1..3, 4..7, 8..15, etc. share the same config. + if (M < 64) return M | (kMaxMR - 1); + // Larger M use power of two buckets: 64..127, 128..255, etc. + const size_t floor_log2_M = + 31 - hwy::Num0BitsAboveMS1Bit_Nonzero32(static_cast(M)); + const size_t min_M = size_t{1} << floor_log2_M; + HWY_DASSERT(min_M <= M && M < 2 * min_M); + return 2 * min_M - 1; + } + // Compresses the dimensions into a single Key for faster comparison. static Key KeyFromDims(size_t M, size_t K, size_t N, size_t num_B) { HWY_DASSERT(M < (Key{1} << 16)); // batch sizes are smaller @@ -747,7 +716,7 @@ class MMOptions { const void* opaque = nullptr; uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`. - ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical; + Parallelism parallelism = Parallelism::kHierarchical; }; // Arguments to MatMul() that are independent of the A/B/C types. 
Reduces diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 2aaf301..4787122 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -195,9 +195,10 @@ HWY_INLINE void MatMulSlow(const MatPtrT A, const MatPtrT B, const size_t multiple = env.ctx.allocator.QuantumBytes() / sizeof(TB); const IndexRangePartition get_col_c = StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple); - ParallelizeOneRange( - get_col_c, all_clusters, env.ctx.pool_callers.Get(Callers::kTest), - [&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR { + ParallelForAcrossClusters( + get_col_c.NumTasks(), env.ctx, env.ctx.pool_callers.Get(Callers::kTest), + [&](size_t range_idx, size_t cluster_idx) HWY_ATTR { + const IndexRange cols_c = get_col_c.Range(range_idx); for (size_t r : all_rows_c) { TC* HWY_RESTRICT C_row = C.Row(r); for (size_t c : cols_c) { diff --git a/ops/ops-inl.h b/ops/ops-inl.h index f2933ca..0eeec31 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -25,9 +25,11 @@ #include #include #include // std::enable_if_t +#include #include #include "ops/matmul.h" +#include "ops/ops.h" #include "util/allocator.h" #include "util/basics.h" // TokenAndProb, RngStream #include "util/mat.h" @@ -61,6 +63,9 @@ namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; +// Computes C = A * B + add via MatMulStatic. +// This function uses CallUpcasted to dispatch to the correct MatMulStatic +// instantiation based on the runtime type of B. template MMPerKey* CallMatMul(const MatPtrT& A, const MatPtr& B, const float* HWY_RESTRICT add, MatMulEnv& env, @@ -497,10 +502,10 @@ void RMSNormBatched(const MatPtrT& activations, const MatPtr& weights, size_t cluster_idx = 0) { HWY_DASSERT(weights.Rows() == 1); HWY_DASSERT(weights.Cols() == activations.Cols()); - HWY_DASSERT(activations.SameShape(out)); + activations.DebugCheckSameShape(out); CallUpcasted(&weights, [&](const auto* weights_t) { - ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx, + ParallelFor(Parallelism::kFlat, activations.Rows(), ctx, cluster_idx, Callers::kOpsRMSNormBatched, [&](uint64_t token_idx, size_t worker) { RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), @@ -517,7 +522,7 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT& inout, HWY_DASSERT(weights.Cols() == inout.Cols()); CallUpcasted(&weights, [&](const auto* weights_t) { - ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx, + ParallelFor(Parallelism::kFlat, inout.Rows(), ctx, cluster_idx, Callers::kOpsRMSNormInplaceBatched, [&](uint64_t token_idx, size_t worker) { RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, @@ -550,7 +555,7 @@ static HWY_INLINE void AddFromBatched(const MatPtrT& x, MatPtrT& out, size_t cluster_idx = 0) { HWY_DASSERT(out.SameShape(x)); ParallelFor( - ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx, + Parallelism::kFlat, out.Rows(), ctx, cluster_idx, Callers::kOpsAddFromBatched, [&](uint64_t token_idx, size_t worker) { AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), ctx, worker); }); @@ -1122,9 +1127,25 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( // See below for a specialized version for top-1 sampling. // TODO: support bf16 logits using Decompress2. +// Computes softmax probabilities for the given logits, normalizing in-place. +// The calculation is numerically stable, using the max-subtraction trick to +// compute exp(logits[i] - max(logits)) before normalizing by the sum. 
+// Note: each exp() term is divided by `temperature`, but that division +// cancels in the final normalization, so temperature currently has no effect +// on the output probabilities. +// @param logits In-out: on input, contains logits; on output, overwritten with +// probabilities. +// @param ctx Input: threading context for parallelism and profiling. +// @param worker Input: worker thread index. +// @param temperature Input: softmax temperature. +// @param sm_options Optional outputs: if `max_out` is non-null, it receives +// the max logit; `d_out` must then also be non-null and receives the sum of +// exp(logit - max). static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx, - const size_t worker, - float temperature = 1.0f) { + const size_t worker, float temperature = 1.0f, + const SMOptions& sm_options = {}) { GCPP_ZONE(ctx, worker, Zones::kOpsSoftmax); HWY_DASSERT(logits.size() != 0); @@ -1168,6 +1189,10 @@ static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx, // Double-precision reciprocal does not appear to affect the results. const float mul = 1.0f / sum_exp; MulByConst(mul, logits.data(), logits.size()); + if (sm_options.max_out) { + *sm_options.max_out = hn::GetLane(vmax); + *sm_options.d_out = sum_exp; + } } // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max / @@ -1290,7 +1315,7 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched( const float cap, MatPtrT& x, const hwy::BitSet4096<>& non_eos, ThreadingContext& ctx, size_t cluster_idx = 0) { if (cap == 0.0f) return; - ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx, + ParallelFor(Parallelism::kFlat, x.Rows(), ctx, cluster_idx, Callers::kOpsMaybeLogitsSoftCapBatched, [&](uint64_t task, size_t worker) { if (non_eos.Get(task)) { diff --git a/ops/ops.h b/ops/ops.h index 03b023b..002cb97 100644 --- a/ops/ops.h +++ b/ops/ops.h @@ -41,6 +41,11 @@ static inline HWY_MAYBE_UNUSED MatStorageT CreateInvTimescale( return inv_timescale; } +struct SMOptions { + float* HWY_RESTRICT max_out = nullptr; + float* HWY_RESTRICT d_out = nullptr; +}; + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_OPS_OPS_H_ diff --git a/ops/ops_test.cc b/ops/ops_test.cc index 4d94b61..0f83df1 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -14,7 +14,6 @@ // limitations under the License.
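For illustration, a minimal sketch of how a caller could use the new `SMOptions` outputs, e.g. to recover the log-probability of one token from the softmax state via log p = logit - max - log(sum_exp). It assumes the same surrounding setup as `TestSoftmaxState` below (SIMD namespace, `Logits`, `ThreadingContext`); the `token` index and use of `std::log` are illustrative, not part of this change.

```cpp
// Sketch (not part of this change): read the logit of interest before Softmax
// overwrites the buffer with probabilities, then use the returned state.
float softmax_max = 0.0f, softmax_d = 0.0f;
const float saved_logit = logits[token];  // hypothetical token index
Softmax(logits, ctx, /*worker=*/0, /*temperature=*/1.0f,
        {.max_out = &softmax_max, .d_out = &softmax_d});
// log p(token) = logit - max - log(sum of exp(logit - max)); needs <cmath>.
const float log_prob = saved_logit - softmax_max - std::log(softmax_d);
```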
#include "compression/types.h" -#include "util/zones.h" #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS @@ -38,7 +37,6 @@ #include "util/mat.h" // MatStorageT #include "util/test_util.h" #include "util/threading_context.h" -#include "hwy/profiler.h" #include "hwy/tests/hwy_gtest.h" // clang-format off @@ -348,6 +346,51 @@ void TestAllSoftmax() { hn::ForPartialVectors>()(float()); } +class TestSoftmaxState { + public: + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + hwy::RandomState& rng) { + if (count == 0) return; // *Softmax would assert + if (misalign_b == 0) return; + using T = hn::TFromD; + + hwy::AlignedFreeUniquePtr px = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + hwy::AlignedFreeUniquePtr pe = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + HWY_ASSERT(px && pe); + + T* x = px.get() + misalign_a; + T* initial_logits = pe.get() + misalign_a; + + for (size_t i = 0; i < count; ++i) { + x[i] = Random(rng); + initial_logits[i] = x[i]; + } + + float softmax_max; + float softmax_d; + Softmax(Logits(x, count), Ctx(), /*worker=*/0, /*temperature=*/1.0f, + {.max_out = &softmax_max, .d_out = &softmax_d}); + + const float maxval = + *std::max_element(initial_logits, initial_logits + count); + + float sum_exp = 0.0f; + for (size_t i = 0; i < count; ++i) { + sum_exp += std::exp(initial_logits[i] - maxval); + } + + ASSERT_NEAR(softmax_max, maxval, 1e-6); + ASSERT_NEAR(softmax_d, sum_exp, 1e-6); + } +}; + +void TestAllSoftmaxState() { + hn::ForPartialVectors>()(float()); +} + template struct TestCreateDistribution { void operator()(hwy::RandomState& rng) { @@ -456,7 +499,7 @@ void TestRopeAndMulBy() { x.Row(0)[i] = random_float(); } - const float qmul = AttentionActivations::ChooseQueryScale(config); + const float qmul = ChooseQueryScale(config); constexpr float kmul = 1.0f; MatStorageT qexpected("qexpected", dim_qkv, ctx.allocator); @@ -771,6 +814,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstTo); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax); +HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmaxState); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSigmoid); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllGelu); diff --git a/paligemma/BUILD.bazel b/paligemma/BUILD.bazel index 7a2e870..cc6c6e1 100644 --- a/paligemma/BUILD.bazel +++ b/paligemma/BUILD.bazel @@ -29,6 +29,7 @@ cc_test( deps = [ ":image", "@googletest//:gtest_main", # buildcleaner: keep + "@highway//:hwy", ], ) diff --git a/paligemma/image.cc b/paligemma/image.cc index 20ecad8..d8b0cfc 100644 --- a/paligemma/image.cc +++ b/paligemma/image.cc @@ -37,8 +37,6 @@ namespace gcpp { namespace { -// Hardcoded for PaliGemma ViT input. -constexpr size_t kPatchSize = 14; // Returns the linearly scaled index in [0, to_size) closest to the // value in [0, from_size). @@ -208,24 +206,25 @@ bool Image::WriteBinary(const std::string& filename) const { } // Image.data() is H x W x 3. -// We want the N-th patch of size kPatchSize x kPatchSize x 3. -void Image::GetPatch(size_t patch_num, float* patch) const { +// We want the N-th patch of size patch_dim x patch_dim x 3. 
+void Image::GetPatch(size_t patch_num, const hwy::Divisor& div_patch_dim, + float* patch) const { PROFILER_FUNC; constexpr size_t kNumChannels = 3; - constexpr size_t kBytesPerPixel = (kNumChannels * sizeof(float)); - constexpr size_t kBytesPerRow = (kPatchSize * kBytesPerPixel); - const size_t kDataSize = width_ * height_ * kNumChannels; + constexpr size_t kBytesPerPixel = kNumChannels * sizeof(float); + const size_t patch_dim = div_patch_dim.GetDivisor(); + const size_t bytes_per_row = (patch_dim * kBytesPerPixel); const size_t in_bytes_to_next_row = (width_ * kBytesPerPixel); - HWY_ASSERT(size() == kDataSize); - HWY_ASSERT(width_ % kPatchSize == 0); - HWY_ASSERT(height_ % kPatchSize == 0); - const size_t kNumPatchesPerRow = width_ / kPatchSize; - size_t patch_y = patch_num / kNumPatchesPerRow; - size_t patch_x = patch_num % kNumPatchesPerRow; - HWY_ASSERT(0 <= patch_y && patch_y < height_ / kPatchSize); - HWY_ASSERT(0 <= patch_x && patch_x < kNumPatchesPerRow); - patch_y *= kPatchSize; - patch_x *= kPatchSize; + HWY_ASSERT(size() == width_ * height_ * kNumChannels); + HWY_ASSERT(div_patch_dim.Remainder(width_) == 0); + HWY_ASSERT(div_patch_dim.Remainder(height_) == 0); + const size_t patches_x = div_patch_dim.Divide(width_); + size_t patch_y = patch_num / patches_x; + size_t patch_x = patch_num % patches_x; + HWY_DASSERT(0 <= patch_y && patch_y < div_patch_dim.Divide(height_)); + HWY_DASSERT(0 <= patch_x && patch_x < patches_x); + patch_y *= patch_dim; + patch_x *= patch_dim; // Move `out` and `in` to the start of the patch. char* out = reinterpret_cast(patch); @@ -233,9 +232,9 @@ void Image::GetPatch(size_t patch_num, float* patch) const { in += (((patch_y * width_) + patch_x) * kBytesPerPixel); // Copy the patch one row at a time. - for (size_t y = 0; y < kPatchSize; ++y) { - std::memcpy(out, in, kBytesPerRow); - out += kBytesPerRow; + for (size_t y = 0; y < patch_dim; ++y) { + std::memcpy(out, in, bytes_per_row); + out += bytes_per_row; in += in_bytes_to_next_row; } } diff --git a/paligemma/image.h b/paligemma/image.h index e0b1530..e54bf86 100644 --- a/paligemma/image.h +++ b/paligemma/image.h @@ -21,6 +21,7 @@ #include #include "hwy/aligned_allocator.h" // Span +#include "hwy/base.h" // Divisor namespace gcpp { @@ -44,11 +45,12 @@ class Image { bool WriteBinary(const std::string& filename) const; // Stores the patch for the given patch number in `patch`. // Patches are numbered in usual raster-order. E.g. for an image of size - // 224 x 224, there are 16 x 16 = 256 patches. - // `patch` should have space for at least 14 * 14 * 3 = 588 floats. + // 224 x 224 and patch_dim = 14, there are 16 x 16 = 256 patches. + // `patch` should have space for at least patch_dim * patch_dim * 3. // Requires that Normalize() has been called and that the image width and - // height are multiples of 14. - void GetPatch(size_t patch_num, float* patch) const; + // height are multiples of patch_dim. + void GetPatch(size_t patch_num, const hwy::Divisor& div_patch_dim, + float* patch) const; float *data() { return data_.data(); } const float *data() const { return data_.data(); } diff --git a/paligemma/image_test.cc b/paligemma/image_test.cc index e2c4bbf..3721363 100644 --- a/paligemma/image_test.cc +++ b/paligemma/image_test.cc @@ -20,6 +20,7 @@ #include #include "gtest/gtest.h" +#include "hwy/base.h" namespace gcpp { namespace { @@ -61,11 +62,12 @@ TEST(ImageTest, LoadResize224GetPatch) { EXPECT_EQ(image.data()[image.size() - 1], Normalize(122)); // Extract two patches. 
float patch[588]; - image.GetPatch(0, patch); + const hwy::Divisor div_patch_dim(14); + image.GetPatch(0, div_patch_dim, patch); EXPECT_EQ(patch[0], Normalize(160)); EXPECT_EQ(patch[1], Normalize(184)); EXPECT_EQ(patch[2], Normalize(188)); - image.GetPatch(18, patch); + image.GetPatch(18, div_patch_dim, patch); // Check the first row of the patch. for (size_t i = 0; i < 14 * 3; ++i) { EXPECT_EQ(patch[i], image.data()[(14 * 224 + 2 * 14) * 3 + i]); @@ -108,14 +110,15 @@ TEST(ImageTest, Non224) { // Extract two patches. const size_t kPatchValues = 14 * 14 * 3; // = 588 float patch[kPatchValues]; + const hwy::Divisor div_patch_dim(14); // Patch 0 is just the "start" of the image. - image.GetPatch(0, patch); + image.GetPatch(0, div_patch_dim, patch); EXPECT_NEAR(patch[0], Normalize(0.0f, max_value), 1e-6); EXPECT_NEAR(patch[1], Normalize(1.0f, max_value), 1e-6); EXPECT_NEAR(patch[2], Normalize(2.0f, max_value), 1e-6); // The "image" has 4x3 patches, so patch 6 has coordinates (1, 2) and its // pixel coordinates are offset by (14, 28). - image.GetPatch(6, patch); + image.GetPatch(6, div_patch_dim, patch); for (size_t n = 0; n < kPatchValues; ++n) { size_t k = n % 3; size_t j = ((n - k) / 3) % 14; diff --git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc index 0a7401a..7bfd78c 100644 --- a/paligemma/paligemma_test.cc +++ b/paligemma/paligemma_test.cc @@ -21,10 +21,9 @@ #include "evals/benchmark_helper.h" #include "gemma/configs.h" #include "gemma/gemma.h" -#include "io/io.h" +#include "paligemma/paligemma_helper.h" #include "util/allocator.h" #include "hwy/tests/hwy_gtest.h" -#include "paligemma/paligemma_helper.h" // This test can be run manually with the downloaded PaliGemma weights. // It should pass for `paligemma-3b-mix-224` and `paligemma2-3b-pt-448`. @@ -72,9 +71,12 @@ TEST_F(PaliGemmaTest, QueryObjects) { int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); - gcpp::InternalInit(); - gcpp::GemmaEnv env(argc, argv); + gcpp::ConsumedArgs consumed(argc, argv); + gcpp::GemmaArgs args(argc, argv, consumed); + consumed.AbortIfUnconsumed(); + + gcpp::GemmaEnv env(args); gcpp::s_env = &env; return RUN_ALL_TESTS(); diff --git a/python/configs.cc b/python/configs.cc index e544bb0..0d505dc 100644 --- a/python/configs.cc +++ b/python/configs.cc @@ -173,6 +173,8 @@ PYBIND11_MODULE(configs, py_module) { .def_readwrite("secondary_eos_id", &ModelConfig::secondary_eos_id) .def_readwrite("scale_base_names", &ModelConfig::scale_base_names) .def_readwrite("internal", &ModelConfig::internal) + .def_readwrite("use_global_timescale", + &ModelConfig::use_global_timescale) .def("add_layer_config", &ModelConfig::AddLayerConfig, arg("layer_config")) diff --git a/python/gemma_py.cc b/python/gemma_py.cc index 1bab194..0d056d9 100644 --- a/python/gemma_py.cc +++ b/python/gemma_py.cc @@ -45,10 +45,7 @@ static void RemoveTrailingZeros(std::vector &vec) { // Wrapper around GemmaEnv to expose to Python. class GemmaModel { public: - GemmaModel(const gcpp::LoaderArgs& loader, - const gcpp::ThreadingArgs& threading, - const gcpp::InferenceArgs& inference) - : env_(loader, threading, inference), last_prob_(0.0f) {} + GemmaModel(const gcpp::GemmaArgs& args) : env_(args), last_prob_(0.0f) {} // Generates a single example, given a prompt and a callback to stream the // generated tokens. 
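Both call sites above follow the same argument-parsing flow; as a minimal sketch (assuming the `gcpp` namespace and the include paths used elsewhere in this change), a standalone `main` would look like:

```cpp
// Sketch of the ConsumedArgs flow: each *Args ctor marks the argv entries it
// recognizes, and AbortIfUnconsumed() rejects anything left over, e.g. a
// misspelled flag.
#include "gemma/gemma_args.h"  // GemmaArgs (path assumed)
#include "util/args.h"         // ConsumedArgs

int main(int argc, char** argv) {
  gcpp::ConsumedArgs consumed(argc, argv);
  const gcpp::GemmaArgs args(argc, argv, consumed);
  consumed.AbortIfUnconsumed();  // aborts and names the first unknown arg
  // ... construct gcpp::GemmaEnv(args), run inference, etc. ...
  return 0;
}
```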
@@ -254,13 +251,15 @@ PYBIND11_MODULE(gemma, mod) { py::class_(mod, "GemmaModel") .def(py::init([](const std::string& tokenizer, const std::string& weights, size_t max_threads) { - const gcpp::LoaderArgs loader(tokenizer, weights); gcpp::ThreadingArgs threading; threading.max_lps = max_threads; + gcpp::InferenceArgs inference; inference.max_generated_tokens = 512; - auto gemma = - std::make_unique(loader, threading, inference); + + const gcpp::GemmaArgs args(gcpp::LoaderArgs(tokenizer, weights), + threading, inference); + auto gemma = std::make_unique(args); if (!gemma->ModelIsLoaded()) { throw std::invalid_argument("Could not load model."); } diff --git a/util/args.h b/util/args.h index 32d54fb..8c6423b 100644 --- a/util/args.h +++ b/util/args.h @@ -22,6 +22,7 @@ #include // std::transform #include +#include #include "io/io.h" // Path #include "util/basics.h" // Tristate @@ -29,6 +30,56 @@ namespace gcpp { +// For checking which args were not matched/consumed. Passed to each `*Args` +// ctor that parses argc/argv to ensure that their args are tracked, without +// requiring global state. +class ConsumedArgs { + public: + ConsumedArgs(int argc, char** argv) : argv_(argv), consumed_(argc) { + // We assume argc >= 1, because argv[0] is the binary name. That allows us + // to signal "called AbortIfUnconsumed" with an empty vector. + HWY_ASSERT(!consumed_.empty()); + } + + ~ConsumedArgs() { + if (HWY_UNLIKELY(!consumed_.empty())) { + HWY_ABORT("AbortIfUnconsumed was not called."); + } + } + + void NotifyConsumed(size_t idx) { + HWY_ASSERT(idx < consumed_.size()); + HWY_ASSERT(consumed_[idx] == 0); + consumed_[idx] = 1; + } + + // Returns index of first unconsumed arg, or 0 if none. Also disarms the + // warning in the dtor checking whether this/`AbortIfUnconsumed` were called. + size_t FirstUnconsumed() { + // Ignore argv[0], which is the binary name. + for (size_t i = 1; i < consumed_.size(); ++i) { + if (HWY_UNLIKELY(consumed_[i] == 0)) { + consumed_.clear(); + return i; + } + } + + consumed_.clear(); + return 0; + } + + void AbortIfUnconsumed() { + const size_t i = FirstUnconsumed(); + if (HWY_UNLIKELY(i != 0)) { + HWY_ABORT("Unrecognized arg %zu: %s\n", i, argv_[i]); + } + } + + private: + char** argv_; + std::vector consumed_; +}; + // Args is a class that provides a ForEach member function which visits each of // its member variables. ArgsBase provides functions called by Args to // initialize values to their defaults (passed as an argument to the visitor), @@ -93,12 +144,14 @@ class ArgsBase { // consider adding a hash-map to speed this up. 
class ParseVisitor { public: - ParseVisitor(int argc, char* argv[]) : argc_(argc), argv_(argv) {} + ParseVisitor(int argc, char* argv[], ConsumedArgs& consumed) + : argc_(argc), argv_(argv), consumed_(consumed) {} template void operator()(T& t, const char* name, const T& /*init*/, const char* /*help*/, int /*print_verbosity*/ = 0) const { const std::string prefixed = std::string("--") + name; + const std::string prefixed_eq = prefixed + "="; for (int i = 1; i < argc_; ++i) { if (std::string(argv_[i]) == prefixed) { if (i + 1 >= argc_) { @@ -107,6 +160,16 @@ class ArgsBase { if (!SetValue(argv_[i + 1], t)) { HWY_ABORT("Invalid value for %s, got %s\n", name, argv_[i + 1]); } + consumed_.NotifyConsumed(i); + consumed_.NotifyConsumed(i + 1); + return; + } + if (std::string(argv_[i]).find(prefixed_eq) == 0) { + const char* value = argv_[i] + prefixed_eq.length(); + if (!SetValue(value, t)) { + HWY_ABORT("Invalid value for %s, got %s\n", name, value); + } + consumed_.NotifyConsumed(i); + return; + } } } @@ -173,8 +236,9 @@ class ArgsBase { } } - int argc_; - char** argv_; + const int argc_; + char** const argv_; + ConsumedArgs& consumed_; }; // ParseVisitor template @@ -203,15 +267,15 @@ class ArgsBase { ForEach(visitor); } - void Parse(int argc, char* argv[]) { - ParseVisitor visitor(argc, argv); + void Parse(int argc, char* argv[], ConsumedArgs& consumed) { + ParseVisitor visitor(argc, argv, consumed); ForEach(visitor); } // For convenience, enables single-line constructor. - void InitAndParse(int argc, char* argv[]) { + void InitAndParse(int argc, char* argv[], ConsumedArgs& consumed) { Init(); - Parse(argc, argv); + Parse(argc, argv, consumed); } }; diff --git a/util/basics.h b/util/basics.h index 5a7f0d5..49996ba 100644 --- a/util/basics.h +++ b/util/basics.h @@ -33,6 +33,9 @@ namespace gcpp { // For hwy::BitSet4096. Note that KVs are extremely large for such batches. HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096; +// Multiplier so a u64 occupies an entire cache line; avoids false sharing. +HWY_INLINE_VAR constexpr size_t kU64PerLine = HWY_ALIGNMENT / sizeof(uint64_t); + enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 }; static inline const char* ToString(Tristate t) { diff --git a/util/mat.h b/util/mat.h index 59eceaa..83d03b1 100644 --- a/util/mat.h +++ b/util/mat.h @@ -181,7 +181,15 @@ class MatPtr : public IFields { Extents2D Extents() const { return Extents2D(Rows(), cols_); } bool IsEmpty() const { return Rows() == 0 || cols_ == 0; } bool SameShape(const MatPtr& other) const { - return Rows() == other.Rows() && cols_ == other.cols_; + return Rows() == other.Rows() && Cols() == other.Cols(); + } + void DebugCheckSameShape(const MatPtr& other) const { + if constexpr (HWY_IS_DEBUG_BUILD) { + if (!SameShape(other)) { + HWY_ABORT("%s: shape mismatch %zu x %zu vs %zu x %zu\n", name_.c_str(), + Rows(), Cols(), other.Rows(), other.Cols()); + } + } } // Future calls to `Rows()` during this class' lifetime (not serialized) // will return this value. Used to set the actual number of rows for @@ -284,6 +292,9 @@ class MatPtrT : public MatPtr { public: using T = MatT; + // Default constructor for use with uninitialized views. + MatPtrT() = default; + // Called by `MatStorageT`.
MatPtrT(const char* name, Extents2D extents) : MatPtr(name, TypeEnum(), extents) {} @@ -296,7 +307,10 @@ class MatPtrT : public MatPtr { if (GetType() == Type::kUnknown) { SetType(TypeEnum()); } else { - HWY_ASSERT(other.GetType() == TypeEnum()); + if (HWY_UNLIKELY(other.GetType() != TypeEnum())) { + HWY_ABORT("Type mismatch: MatT %s, constructing from %s", + TypeName(), TypeName(other.GetType())); + } } } MatPtrT& operator=(const MatPtr& other) { @@ -440,6 +454,21 @@ decltype(auto) CallUpcastedActivation(const MatPtr* base, const Func& func, } } +// Like CallUpcasted, but only for kv_cache types: kBF16 and kF32. +template +decltype(auto) CallUpcastedKV(const MatPtr* base, const Func& func, + Args&&... args) { + if (base->GetType() == Type::kF32) { + const MatPtrT mat(*base); + return func(&mat, std::forward(args)...); + } else if (base->GetType() == Type::kBF16) { + const MatPtrT mat(*base); + return func(&mat, std::forward(args)...); + } else { + HWY_ABORT("Unhandled type %s.", TypeName(base->GetType())); + } +} + void CopyMat(const MatPtr& from, MatPtr& to); void ZeroInit(MatPtr& mat); diff --git a/util/test_util.h b/util/test_util.h index 355b096..f0c37f9 100644 --- a/util/test_util.h +++ b/util/test_util.h @@ -19,20 +19,51 @@ #include #include +#include // std::sort #include +#include +#include "util/basics.h" // RngStream +#include "util/mat.h" #include "hwy/base.h" // IWYU pragma: begin_exports +#include "hwy/nanobenchmark.h" #include "hwy/stats.h" #include "hwy/tests/test_util.h" // RandomState // IWYU pragma: end_exports namespace gcpp { +// Excludes outliers; we might not have enough samples for a reliable mode. +HWY_INLINE double TrimmedMean(double* seconds, size_t num) { + std::sort(seconds, seconds + num); + double sum = 0; + int count = 0; + for (size_t i = num / 4; i < num / 2; ++i) { + sum += seconds[i]; + count += 1; + } + HWY_DASSERT(num != 0); + return sum / count; +} + +// Returns normalized value in [-1, 1). +HWY_INLINE float RandomFloat(RngStream& rng) { + const uint32_t exp = hwy::BitCastScalar(1.0f); + const uint32_t mantissa_mask = hwy::MantissaMask(); + const uint32_t representation = exp | (rng() & mantissa_mask); + const float f12 = hwy::BitCastScalar(representation); + HWY_DASSERT(1.0f <= f12 && f12 < 2.0f); // exponent is 2^0, only mantissa + const float f = (2.0f * (f12 - 1.0f)) - 1.0f; + HWY_DASSERT(-1.0f <= f && f < 1.0f); + return f; +} + // Returns random Gaussian (mean=0, stddev=1/3 similar to expected weights) // using the central limit theorem. Avoid std::normal_distribution for // consistent cross-platform output. +// TODO: use RngStream instead of RandomState. HWY_INLINE double RandomGaussian(hwy::RandomState& rng) { uint64_t sum = 0; constexpr int kReps = 40; @@ -71,6 +102,25 @@ HWY_INLINE void VerifyGaussian(hwy::Stats& stats) { HWY_ASSERT(IsNear(3.0, stats.Kurtosis(), 0.3)); } +template +void FillMatPtrT(MatPtrT& mat) { + for (int i = 0; i < mat.Rows(); ++i) { + for (int j = 0; j < mat.Cols(); ++j) { + mat.Row(i)[j] = hwy::Unpredictable1() * 0.01f * (i + j + 1); + } + } +} + +template +void PrintMatPtr(MatPtrT mat) { + for (int i = 0; i < mat.Rows(); ++i) { + for (int j = 0; j < mat.Cols(); ++j) { + std::cerr << mat.Row(i)[j] << " ,"; + } + std::cerr << std::endl; + } +}; + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_TEST_UTIL_H_ diff --git a/util/threading.h b/util/threading.h index dcdcf24..3fb0227 100644 --- a/util/threading.h +++ b/util/threading.h @@ -187,7 +187,9 @@ class NestedPools { // functions below. 
class IndexRangePartition { public: - IndexRangePartition() = default; // for MMPartitions + explicit IndexRangePartition(size_t single_task) + : range_(0, single_task), task_size_(single_task), num_tasks_(1) {} + IndexRangePartition(const IndexRange& range, const size_t task_size) : range_(range), task_size_(static_cast(task_size)) { const uint32_t num = static_cast(range.Num()); @@ -262,43 +264,6 @@ static inline IndexRangePartition StaticPartition(const IndexRange& range, return IndexRangePartition(range, size); } -// Parallel-for over a single range. This takes care of translating the task -// index to a range. -template -void ParallelizeOneRange(const IndexRangePartition& get1, hwy::ThreadPool& pool, - hwy::pool::Caller caller, const Func& func) { - const size_t num_tasks = get1.NumTasks(); - pool.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) { - const IndexRange range1 = get1.Range(task); - func(range1, thread); - }); -} - -// Parallel-for over the Cartesian product of the two sets of ranges. This -// combines their indices into a single 'task' so they can be executed by one -// `pool.Run`, which increases the amount of work available to workers and -// reduces fork-join overhead vs. nested parallel-for loops. Calls `func` with -// the two ranges and the thread index within `pool`. -template -void ParallelizeTwoRanges(const IndexRangePartition& get1, - const IndexRangePartition& get2, - hwy::ThreadPool& pool, hwy::pool::Caller caller, - const Func& func) { - const hwy::Divisor div1(static_cast(get1.NumTasks())); - - const size_t num_tasks = get1.NumTasks() * get2.NumTasks(); - pool.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) { - HWY_DASSERT(task < (uint64_t{1} << 32)); - const size_t idx2 = div1.Divide(static_cast(task)); - const size_t idx1 = div1.Remainder(static_cast(task)); - HWY_DASSERT(idx1 < get1.NumTasks()); - HWY_DASSERT(idx2 < get2.NumTasks()); - const IndexRange range1 = get1.Range(idx1); - const IndexRange range2 = get2.Range(idx2); - func(range1, range2, thread); - }); -} - } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_ diff --git a/util/threading_context.cc b/util/threading_context.cc index e725ce3..4a3b927 100644 --- a/util/threading_context.cc +++ b/util/threading_context.cc @@ -43,12 +43,9 @@ static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) { const size_t num_tasks[4] = {HWY_MAX(1, num_workers / 2), num_workers * 1, num_workers * 5, num_workers * 20}; - // Count tasks executed to ensure workers aren't optimized out. One per - // cache line to avoid false sharing. - const size_t kSizePerLine = HWY_ALIGNMENT / sizeof(size_t); - - std::vector counters(num_workers * kSizePerLine); - size_t prev_total = 0; // avoids having to reset counters. + // Count tasks executed to ensure workers aren't optimized out. + std::vector counters(num_workers * kU64PerLine); + uint64_t prev_total = 0; // avoids having to reset counters. hwy::RandomState rng; for (size_t rep = 0; rep < 500; ++rep) { @@ -63,13 +60,13 @@ static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) { pool.Run(begin, end, [&](uint64_t task, size_t thread) { HWY_ASSERT(begin <= task && task < end); HWY_ASSERT(thread < num_workers); - counters[thread * kSizePerLine]++; + counters[thread * kU64PerLine]++; }); // Reduce count and ensure it matches the expected number of tasks. 
- size_t total = 0; + uint64_t total = 0; for (size_t i = 0; i < num_workers; ++i) { - total += counters[i * kSizePerLine]; + total += counters[i * kU64PerLine]; } const size_t expected = end - begin; HWY_ASSERT(total == prev_total + expected); @@ -100,7 +97,8 @@ ThreadingContext::ThreadingContext(const ThreadingArgs& args) BoundedSlice(args.skip_lps, args.max_lps)), cache_info(topology), allocator(topology, cache_info, args.bind != Tristate::kFalse), - pools(topology, allocator, args.max_threads, args.pin) { + pools(topology, allocator, args.max_threads, args.pin), + tensor_output(args.tensor_output) { PROFILER_ZONE("Startup.ThreadingContext autotune"); TunePools(hwy::PoolWaitMode::kSpin, *this); // kBlock is the default, hence set/tune it last. diff --git a/util/threading_context.h b/util/threading_context.h index 251888f..7e595ba 100644 --- a/util/threading_context.h +++ b/util/threading_context.h @@ -23,6 +23,7 @@ #include // IWYU pragma: begin_exports +#include "io/io.h" // Path #include "util/allocator.h" #include "util/args.h" #include "util/basics.h" // Tristate @@ -37,7 +38,9 @@ namespace gcpp { // Optional arguments for `ThreadingContext` from the command line. class ThreadingArgs : public ArgsBase { public: - ThreadingArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + ThreadingArgs(int argc, char* argv[], ConsumedArgs& consumed) { + InitAndParse(argc, argv, consumed); + } ThreadingArgs() { Init(); }; // For BoundedTopology: @@ -55,6 +58,8 @@ class ThreadingArgs : public ArgsBase { Tristate pin; // pin threads? Tristate spin; // use spin waits? + Path tensor_output; // empty, or directory for tensor output + template void ForEach(const Visitor& visitor) { // These can be used to partition CPU packages/sockets and their @@ -85,6 +90,9 @@ class ThreadingArgs : public ArgsBase { visitor(bind, "bind", Tristate::kDefault, "Bind memory to sockets? -1 = auto, 0 = no, 1 = yes.", 2); + + visitor(tensor_output, "tensor_output", Path(), + "Empty, or directory for tensor output.", 2); } }; @@ -100,7 +108,7 @@ struct ThreadingContext { // Returns a worker index compatible with those from `ParallelFor`, assuming // the current thread is running on one thread per cluster, which happens - // when `ParallelismStrategy` is `kAcrossClusters`. + // when `Parallelism` is `kAcrossClusters`. size_t Worker(size_t cluster_idx) const { return cluster_idx * pools.MaxWorkersPerCluster(); } @@ -124,13 +132,15 @@ struct ThreadingContext { // Per-package/cluster/within cluster pools of threads, matching `topology`. NestedPools pools; + + Path tensor_output; // used by `TensorStats::Notify`. }; #define GCPP_ZONE(ctx, global_idx, zone_enum) \ PROFILER_ZONE3(ctx.profiler, global_idx, ctx.profiler_zones.Get(zone_enum)) // Describes the strategy for distributing parallel work across cores. -enum class ParallelismStrategy : uint8_t { +enum class Parallelism : uint8_t { // Execute using a single-threaded loop on the calling thread. The `worker` // index passed to the user's `Func` is unique across clusters. kNone, @@ -154,56 +164,110 @@ enum class ParallelismStrategy : uint8_t { kHierarchical, }; -// Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes -// over clusters of ONE package, then within each cluster. +// Helper functions used to implement `ParallelFor`, also reused in multiple +// places. User code should call `ParallelFor` instead, which accepts the more +// convenient `Callers` enum. +// +// These call `func(task, worker)` for each task in `[0, num_tasks)`. 
+ +// NOTE: the worker argument is actually the `cluster_idx`, so that `Func` can +// pass that to `ParallelForWithinCluster`. +template +void ParallelForAcrossClusters(size_t num_tasks, ThreadingContext& ctx, + hwy::pool::Caller caller, const Func& func) { + ctx.pools.AllClusters().Run( + 0, num_tasks, caller, + [&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); }); +} + +template +void ParallelForWithinCluster(size_t num_tasks, ThreadingContext& ctx, + size_t cluster_idx, hwy::pool::Caller caller, + const Func& func) { + const size_t cluster_base = ctx.Worker(cluster_idx); + ctx.pools.Cluster(cluster_idx) + .Run(0, num_tasks, caller, [&](uint64_t task, size_t worker) { + func(task, cluster_base + worker); + }); +} + +// Calls `func(range, cluster_idx)`, for passing to `*WithinCluster`. +template +void ParallelPartitionAcrossClusters(const IndexRange range, + size_t task_multiple, size_t inner_tasks, + ThreadingContext& ctx, + hwy::pool::Caller caller, + const Func& func) { + HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); + const IndexRangePartition ranges = StaticPartition( + range, ctx.pools.NumClusters() * inner_tasks, task_multiple); + ParallelForAcrossClusters(ranges.NumTasks(), ctx, caller, + [&](uint64_t task, size_t cluster_idx) { + func(ranges.Range(task), cluster_idx); + }); +} + +// Calls `func(range, worker)`. +template +void ParallelPartitionWithinCluster(const IndexRange range, + size_t task_multiple, size_t inner_tasks, + ThreadingContext& ctx, size_t cluster_idx, + hwy::pool::Caller caller, + const Func& func) { + HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); + const size_t num_workers = ctx.pools.Cluster(cluster_idx).NumWorkers(); + const IndexRangePartition ranges = + StaticPartition(range, num_workers * inner_tasks, task_multiple); + ParallelForWithinCluster( + ranges.NumTasks(), ctx, cluster_idx, caller, + [&](uint64_t task, size_t worker) { func(ranges.Range(task), worker); }); +} + +// Parallelizes across clusters, then within each cluster. template void HierarchicalParallelFor(size_t num_tasks, ThreadingContext& ctx, Callers callers, const Func& func) { const hwy::pool::Caller caller = ctx.pool_callers.Get(callers); - // If few tasks, run on a single cluster. Also avoids a bit of overhead if - // there is only one cluster. - hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(); - const size_t num_clusters = all_clusters.NumWorkers(); - hwy::ThreadPool& cluster = ctx.pools.Cluster(0); - if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) { - return cluster.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) { - func(task, thread); - }); + + // If at most one task per cluster worker, run on a single cluster to avoid + // the expensive cross-cluster barrier. + { + const size_t cluster_idx = 0; + const size_t cluster_workers = ctx.pools.Cluster(cluster_idx).NumWorkers(); + if (HWY_UNLIKELY(num_tasks <= cluster_workers)) { + return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller, + func); + } } - // Assign each cluster a sub-range. 
- const IndexRangePartition ranges = - StaticPartition(IndexRange(0, num_tasks), num_clusters, 1); - ParallelizeOneRange(ranges, all_clusters, caller, - [&](const IndexRange& range, const size_t cluster_idx) { - hwy::ThreadPool& cluster = - ctx.pools.Cluster(cluster_idx); - const size_t cluster_base = - cluster_idx * ctx.pools.MaxWorkersPerCluster(); - cluster.Run(range.begin(), range.end(), caller, - [&](uint64_t task, size_t thread) { - func(task, cluster_base + thread); - }); - }); + ParallelPartitionAcrossClusters( + IndexRange(0, num_tasks), /*task_multiple=*/1, /*inner_tasks=*/1, ctx, + caller, [&](const IndexRange& cluster_range, size_t cluster_idx) { + ParallelForWithinCluster(cluster_range.Num(), ctx, cluster_idx, caller, + [&](uint64_t i, size_t worker) { + func(cluster_range.begin() + i, worker); + }); + }); } // Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the -// number/type of workers determined by `parallelism`. `cluster_idx` is for -// `parallelism == kWithinCluster`, and should be 0 if unknown. +// number/type of workers determined by `parallelism`. NOTE: worker is actually +// `cluster_idx` for `kAcrossClusters`. The `cluster_idx` argument is for +// `parallelism == {kWithinCluster, kNone}`, and should be 0 if unknown. template -void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, +void ParallelFor(Parallelism parallelism, size_t num_tasks, ThreadingContext& ctx, size_t cluster_idx, Callers callers, const Func& func) { HWY_DASSERT(cluster_idx < ctx.topology.NumClusters()); if (cluster_idx != 0) { // If already running across clusters, only use within-cluster modes. - HWY_DASSERT(parallelism == ParallelismStrategy::kNone || - parallelism == ParallelismStrategy::kWithinCluster); + HWY_DASSERT(parallelism == Parallelism::kNone || + parallelism == Parallelism::kWithinCluster); } const hwy::pool::Caller caller = ctx.pool_callers.Get(callers); switch (parallelism) { - case ParallelismStrategy::kNone: { + case Parallelism::kNone: { const size_t worker = ctx.Worker(cluster_idx); for (size_t task = 0; task < num_tasks; ++task) { func(task, worker); @@ -211,40 +275,28 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, return; } - case ParallelismStrategy::kAcrossClusters: - return ctx.pools.AllClusters().Run( - 0, num_tasks, caller, + case Parallelism::kAcrossClusters: + return ParallelForAcrossClusters( + num_tasks, ctx, caller, [&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); }); - case ParallelismStrategy::kWithinCluster: { - // Ensure the worker argument is unique across clusters, because it is - // used for TLS indexing for example in profiler.h. - const size_t base = ctx.Worker(cluster_idx); - return ctx.pools.Cluster(cluster_idx) - .Run(0, num_tasks, caller, [&](uint64_t task, size_t worker) { - func(task, base + worker); - }); - } + case Parallelism::kWithinCluster: + return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller, + func); - case ParallelismStrategy::kFlat: { - // Check for single cluster; if not, we must compute `cluster_base` for - // consistent and non-overlapping worker indices. 
- hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(); - const size_t num_clusters = all_clusters.NumWorkers(); - if (num_clusters == 1) { - return ctx.pools.Cluster(cluster_idx) - .Run(0, num_tasks, caller, - [&](uint64_t task, size_t worker) { func(task, worker); }); + case Parallelism::kFlat: + // Choose a single pool: the only cluster, or across all clusters + // (slower synchronization, but more memory bandwidth) + if (HWY_UNLIKELY(ctx.pools.NumClusters() == 1)) { + return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller, + func); } + return ParallelForAcrossClusters(num_tasks, ctx, caller, + [&](uint64_t task, size_t cluster_idx) { + func(task, ctx.Worker(cluster_idx)); + }); - return all_clusters.Run(0, num_tasks, caller, - [&](uint64_t task, size_t cluster_idx) { - const size_t worker = ctx.Worker(cluster_idx); - func(task, worker); - }); - } - - case ParallelismStrategy::kHierarchical: + case Parallelism::kHierarchical: return HierarchicalParallelFor(num_tasks, ctx, callers, func); } } diff --git a/util/threading_test.cc b/util/threading_test.cc index 14ea1a0..b5d1858 100644 --- a/util/threading_test.cc +++ b/util/threading_test.cc @@ -202,59 +202,7 @@ TEST(ThreadingTest, TestStaticPartition) { } } -TEST(ThreadingTest, TestParallelizeOneRange) { - const IndexRange range(0, 10); - const IndexRangePartition partition = StaticPartition(range, 2, 4); - hwy::ThreadPool null_pool(0); - size_t calls = 0; - ParallelizeOneRange(partition, null_pool, kCaller, - [&](const IndexRange& range, size_t) { - if (++calls == 1) { - HWY_ASSERT(range.begin() == 0 && range.end() == 8); - } else { - HWY_ASSERT(range.begin() == 8 && range.end() == 10); - } - }); - HWY_ASSERT(calls == 2); -} - -TEST(ThreadingTest, TestParallelizeTwoRanges) { - const IndexRangePartition partition1 = - StaticPartition(IndexRange(0, 10), 2, 4); - const IndexRangePartition partition2 = - MaxSizePartition(IndexRange(128, 256), 32, 32); - HWY_ASSERT(partition2.NumTasks() == 4); - hwy::ThreadPool null_pool(0); - { - size_t calls = 0; - ParallelizeTwoRanges( - partition1, partition2, null_pool, kCaller, - [&](const IndexRange& range1, const IndexRange& range2, size_t) { - ++calls; - HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8); - HWY_ASSERT(range2.begin() % 32 == 0); - HWY_ASSERT(range2.Num() % 32 == 0); - }); - HWY_ASSERT(calls == 2 * 4); - } - - // Also swap order to test Remainder() logic. - { - size_t calls = 0; - ParallelizeTwoRanges( - partition2, partition1, null_pool, kCaller, - [&](const IndexRange& range2, const IndexRange& range1, size_t) { - ++calls; - HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8); - HWY_ASSERT(range2.begin() % 32 == 0); - HWY_ASSERT(range2.Num() % 32 == 0); - }); - HWY_ASSERT(calls == 2 * 4); - } -} - -static constexpr size_t kU64PerThread = HWY_ALIGNMENT / sizeof(size_t); -static uint64_t outputs[hwy::kMaxLogicalProcessors * kU64PerThread]; +static uint64_t outputs[hwy::kMaxLogicalProcessors * kU64PerLine]; std::vector MeasureForkJoin(hwy::ThreadPool& pool) { // Governs duration of test; avoid timeout in debug builds. 
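To make the worker-index convention of the new helpers concrete, here is a sketch of nesting `ParallelForAcrossClusters` and `ParallelForWithinCluster` (assuming a `ThreadingContext ctx`; the task counts and `DoWork` are placeholders):

```cpp
// Sketch: the outer "worker" argument is really a cluster_idx, which the
// inner call turns into globally unique worker indices (cluster_base + thread).
const hwy::pool::Caller caller = ctx.pool_callers.Get(gcpp::Callers::kTest);
gcpp::ParallelForAcrossClusters(
    outer_tasks, ctx, caller, [&](uint64_t outer, size_t cluster_idx) {
      gcpp::ParallelForWithinCluster(
          inner_tasks, ctx, cluster_idx, caller,
          [&](uint64_t inner, size_t worker) {
            // `worker` is unique across clusters, so it is safe to use for
            // per-worker state such as profiler TLS.
            DoWork(outer, inner, worker);  // placeholder
          });
    });
```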
@@ -268,7 +216,7 @@ std::vector MeasureForkJoin(hwy::ThreadPool& pool) { const double t0 = hwy::platform::Now(); for (size_t reps = 0; reps < 1200; ++reps) { pool.Run(0, pool.NumWorkers(), kCaller, [&](uint64_t task, size_t thread) { - outputs[thread * kU64PerThread] = base + thread; + outputs[thread * kU64PerLine] = base + thread; }); hwy::PreventElision(outputs[base]); if (pool.AutoTuneComplete()) break; @@ -309,7 +257,7 @@ std::vector MeasureForkJoin(hwy::ThreadPool& pool) { const uint64_t t0 = hwy::timer::Start(); pool.Run(0, pool.NumWorkers(), kCaller, [&](uint64_t task, size_t thread) { - outputs[thread * kU64PerThread] = base + thread; + outputs[thread * kU64PerLine] = base + thread; }); const uint64_t t1 = hwy::timer::Stop(); times.push_back(t1 - t0); @@ -319,7 +267,7 @@ std::vector MeasureForkJoin(hwy::ThreadPool& pool) { const uint64_t t0 = hwy::timer::Start(); pool.Run(0, pool.NumWorkers(), kCaller, [&](uint64_t task, size_t thread) { - outputs[thread * kU64PerThread] = base + thread; + outputs[thread * kU64PerLine] = base + thread; }); const uint64_t t1 = hwy::timer::Start(); times.push_back(t1 - t0); @@ -366,10 +314,10 @@ TEST(ThreadingTest, BenchJoin) { // Verify outputs to ensure the measured code is not a no-op. for (size_t lp = 0; lp < pool.NumWorkers(); ++lp) { - HWY_ASSERT(outputs[lp * kU64PerThread] >= 1); - HWY_ASSERT(outputs[lp * kU64PerThread] <= 1 + pool.NumWorkers()); - for (size_t i = 1; i < kU64PerThread; ++i) { - HWY_ASSERT(outputs[lp * kU64PerThread + i] == 0); + HWY_ASSERT(outputs[lp * kU64PerLine] >= 1); + HWY_ASSERT(outputs[lp * kU64PerLine] <= 1 + pool.NumWorkers()); + for (size_t i = 1; i < kU64PerLine; ++i) { + HWY_ASSERT(outputs[lp * kU64PerLine + i] == 0); } } }; diff --git a/util/topology.cc b/util/topology.cc index f20b7f9..fae1dee 100644 --- a/util/topology.cc +++ b/util/topology.cc @@ -21,12 +21,14 @@ #include #include "hwy/base.h" +#include "hwy/bit_set.h" namespace gcpp { // Returns set of LPs available for use. static LPS EnabledLPs(const BoundedSlice& lp_slice) { LPS enabled_lps; + const size_t num_lps = hwy::TotalLogicalProcessors(); // Thread-safe caching during the first call because subsequent pinning // overwrites the main thread's affinity. @@ -35,6 +37,7 @@ static LPS EnabledLPs(const BoundedSlice& lp_slice) { if (!GetThreadAffinity(affinity)) affinity = LPS(); return affinity; }(); + if (HWY_LIKELY(affinity.Any())) { // To honor taskset/numactl *and* the users's `lp_slice`, we interpret // the latter as a slice of the 1-bits of `enabled_lps`. Note that this @@ -48,18 +51,32 @@ static LPS EnabledLPs(const BoundedSlice& lp_slice) { } ++enabled_idx; }); - } else { - const size_t num_lps = hwy::TotalLogicalProcessors(); - // Do not warn on Apple, where affinity is not supported. - if (!HWY_OS_APPLE) { - HWY_WARN("unknown OS affinity, max %zu LPs and slice %zu.", num_lps, - lp_slice.Num(num_lps)); + } + + if (HWY_UNLIKELY(!enabled_lps.Any())) { + // First warn: either about unknown affinity, or no overlap with `lp_slice`. + if (!affinity.Any()) { + // Do not warn on Apple, where affinity is not supported. + if (!HWY_OS_APPLE) { + HWY_WARN("unknown OS affinity, max %zu LPs and slice %zu.", num_lps, + lp_slice.Num(num_lps)); + } + } else { + HWY_WARN("LP slice [%zu, %zu) of initial affinity %zu is empty.", + lp_slice.Begin(), lp_slice.End(num_lps), affinity.Count()); } + + // Set `enabled_lps` based only on `lp_slice` and total logical processors. 
for (size_t lp = 0; lp < num_lps; ++lp) { if (lp_slice.Contains(num_lps, lp)) { enabled_lps.Set(lp); } } + + if (!enabled_lps.Any()) { + HWY_WARN("no enabled LPs of total %zu, slice [%zu, %zu).", num_lps, + lp_slice.Begin(), lp_slice.End(affinity.Count())); + } } // Without threading support, only keep the first enabled LP; it might still @@ -72,6 +89,7 @@ static LPS EnabledLPs(const BoundedSlice& lp_slice) { HWY_WARN("Warning, threads not supported, using only the main thread."); } + HWY_ASSERT(enabled_lps.Any()); return enabled_lps; } @@ -156,12 +174,13 @@ constexpr size_t kMaxLPsPerCluster = 6; #if !GEMMA_DISABLE_TOPOLOGY -static size_t CoresFromLPs(const LPS& lps, const hwy::Topology& topology) { - LPS cores; - lps.Foreach([&](size_t lp) { - if (topology.lps[lp].smt == 0) cores.Set(lp); - }); - return cores.Count(); +// Returns number of distinct SMT (hyperthreads). +static size_t NumSMT(const hwy::Topology& topology) { + hwy::BitSet64 smt; + for (const hwy::Topology::LP& lp : topology.lps) { + smt.Set(lp.smt); + } + return smt.Count(); } // tcluster is a modifiable copy of the first cluster in the package. @@ -187,34 +206,66 @@ void BoundedTopology::SplitLargeCluster(const LPS& enabled_lps, } } -// Main part of ctor, called when topology is known. -bool BoundedTopology::InitFromTopology(const LPS& enabled_lps) { - const size_t tpkg_idx = package_slice_.Begin(); - HWY_ASSERT(tpkg_idx < topology_.packages.size()); - const hwy::Topology::Package& tpackage = topology_.packages[tpkg_idx]; - const std::vector& tclusters = tpackage.clusters; +using TClusters = std::vector; + +// Returns false if no cluster in `tclusters` has any enabled LPs. +static bool AnyEnabledLPs(const TClusters& tclusters, const LPS& enabled_lps) { if (HWY_UNLIKELY(tclusters.empty())) { - HWY_WARN("Topology: no clusters found in package %zu.", tpkg_idx); + HWY_WARN("Topology: no clusters found."); return false; } - size_t max_tcluster_cores = 0; - size_t max_tcluster_lps = 0; for (const hwy::Topology::Cluster& tcluster : tclusters) { - const size_t cores = CoresFromLPs(tcluster.lps, topology_); - const size_t lps = tcluster.lps.Count(); - max_tcluster_cores = HWY_MAX(max_tcluster_cores, cores); - max_tcluster_lps = HWY_MAX(max_tcluster_lps, lps); + bool any_lp_enabled = false; + tcluster.lps.Foreach( + [&](size_t lp) { any_lp_enabled |= (enabled_lps.Get(lp)); }); + if (any_lp_enabled) return true; } - HWY_ASSERT(max_tcluster_cores != 0); - HWY_ASSERT(max_tcluster_lps >= max_tcluster_cores); + + // No warning: this can happen if OS affinity limits us to the second package. + return false; +} + +// Returns nullptr on failure. Also attempts `1 - tpkg_idx`, which is suitable +// for the common case of up to two packages. +static const TClusters* GetPackageClusters(const hwy::Topology& topology, + size_t tpkg_idx, + const LPS& enabled_lps) { + const size_t num_packages = topology.packages.size(); + HWY_ASSERT(tpkg_idx < num_packages); + { + const TClusters& tclusters = topology.packages[tpkg_idx].clusters; + if (AnyEnabledLPs(tclusters, enabled_lps)) return &tclusters; + } + + // Retry with the other package, if any. 
+ tpkg_idx ^= 1; + if (tpkg_idx == num_packages) return nullptr; + { + const TClusters& tclusters = topology.packages[tpkg_idx].clusters; + if (AnyEnabledLPs(tclusters, enabled_lps)) return &tclusters; + } + + HWY_WARN( + "Ignoring topology (%zu tpackages) because no clusters overlap with the " + "OS affinity (%zu enabled LPs): ", + num_packages, enabled_lps.Count()); + enabled_lps.Foreach([](size_t lp) { fprintf(stderr, "%zu, ", lp); }); + return nullptr; +} + +// Main part of ctor, called when topology is known. +bool BoundedTopology::InitFromTopology(const LPS& enabled_lps) { + const TClusters* maybe_tclusters = + GetPackageClusters(topology_, package_slice_.Begin(), enabled_lps); + if (!maybe_tclusters) return false; + const TClusters& tclusters = *maybe_tclusters; // Populate `clusters` with the subset of clusters in `cluster_slice` that // have any enabled LPs. clusters_.reserve(cluster_slice_.Num(tclusters.size())); cluster_slice_.Foreach("cluster", tclusters.size(), [&](size_t cluster_idx) { - const hwy::Topology::Cluster& tcluster = tpackage.clusters[cluster_idx]; - Cluster cluster(enabled_lps, topology_.lps, tcluster); + Cluster cluster(enabled_lps, topology_.lps, tclusters[cluster_idx]); // Skip if empty, i.e. too few `enabled_lps`. if (HWY_LIKELY(cluster.NumWorkers() != 0)) { @@ -223,14 +274,10 @@ bool BoundedTopology::InitFromTopology(const LPS& enabled_lps) { nodes_.Set(cluster.Node()); } }); - if (HWY_UNLIKELY(clusters_.empty())) { - HWY_WARN("Too restrictive cluster_slice or enabled_lps, no clusters left."); - return false; - } if (kSplitLargeClusters && clusters_.size() == 1 && enabled_lps.Count() >= 16) { - SplitLargeCluster(enabled_lps, tpackage.clusters[0]); + SplitLargeCluster(enabled_lps, tclusters[0]); } // Sort by descending 'size' so that users who only use one get the largest. @@ -239,20 +286,23 @@ bool BoundedTopology::InitFromTopology(const LPS& enabled_lps) { return a.NumWorkers() > b.NumWorkers(); }); - // Largest number of enabled workers in any cluster, for `topology_string_`. - // This may be less than `max_tcluster_cores` if `enabled_lps` excludes some. - size_t max_cluster_workers = 0; - for (const Cluster& c : clusters_) { - max_cluster_workers = HWY_MAX(max_cluster_workers, c.NumWorkers()); + // Happens if all LPs are HTs (we checked that at least some LPs are enabled). + if (HWY_UNLIKELY(clusters_.empty())) { + HWY_WARN( + "Ignoring topology - no usable clusters. cluster_slice [%zu, %zu), " + "%zu tclusters, %zu tLPs, %zu enabled LPs: ", + cluster_slice_.Begin(), cluster_slice_.End(tclusters.size()), + tclusters.size(), topology_.lps.size(), enabled_lps.Count()); + enabled_lps.Foreach([](size_t lp) { fprintf(stderr, "%zu, ", lp); }); + return false; } - HWY_ASSERT(max_cluster_workers <= max_tcluster_cores); - // Do not warn about large clusters: GNR has 40. 
+ const size_t num_smt = NumSMT(topology_); snprintf(topology_string_, sizeof(topology_string_), "%zuS %zuX %zuC %zuH, using %zuX %zuC (nodes=%zu)", - topology_.packages.size(), tclusters.size(), max_tcluster_cores, - max_tcluster_lps / max_tcluster_cores, NumClusters(), - max_cluster_workers, nodes_.Count()); + topology_.packages.size(), tclusters.size(), + tclusters[0].lps.Count() / num_smt, num_smt, NumClusters(), + clusters_[0].NumWorkers(), nodes_.Count()); return true; } diff --git a/util/topology.h b/util/topology.h index d4f80cc..b4a03fe 100644 --- a/util/topology.h +++ b/util/topology.h @@ -93,7 +93,7 @@ class BoundedTopology { class Cluster { public: - Cluster(const LPS& lps); + explicit Cluster(const LPS& lps); Cluster(const LPS& enabled_lps, const std::vector& all_lps, const hwy::Topology::Cluster& tcluster); diff --git a/util/zones.cc b/util/zones.cc index a474311..6480b96 100644 --- a/util/zones.cc +++ b/util/zones.cc @@ -51,6 +51,8 @@ const char* ZoneName(Zones zone) { return "Gen.SampleTop1"; case Zones::kGenSampleTopK: return "Gen.SampleTopK"; + case Zones::kGenStats: + return "Gen.Stats"; case Zones::kMMDecompressA: return "MM.DecompressA"; case Zones::kMMDispatch: @@ -163,6 +165,8 @@ const char* CallerName(Callers caller) { return "ReadBatches"; case Callers::kSampleAndStream: return "SampleAndStream"; + case Callers::kTensorStats: + return "TensorStats"; case Callers::kTest: // only for unit tests. return "Test-only!"; case Callers::kTunePool: diff --git a/util/zones.h b/util/zones.h index 5624e24..f324086 100644 --- a/util/zones.h +++ b/util/zones.h @@ -31,6 +31,7 @@ enum class Zones { // Keep sorted kGenFFW, kGenSampleTop1, kGenSampleTopK, + kGenStats, kMMDecompressA, kMMDispatch, kMMMatMul, @@ -96,6 +97,7 @@ enum class Callers { // Keep sorted kReadAllToBF16, kReadBatches, kSampleAndStream, + kTensorStats, kTest, // only for unit tests. kTunePool, kVitDotSoftmax1,
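Finally, a small sketch of the batch-size bucketing behavior introduced in `MMKeys::BucketM` earlier in this change; the exact small-M bucket widths assume `kMaxMR == 4`.

```cpp
// Sketch: all M in a bucket share one MMConfig and thus one auto-tune result.
// Assumes kMaxMR == 4, so small-M buckets are 1..3, 4..7, 8..11, ...
#include "ops/matmul.h"  // MMKeys (assumed to live here)
#include "hwy/base.h"    // HWY_ASSERT

static void CheckBucketM() {
  HWY_ASSERT(gcpp::MMKeys::BucketM(1) == 3);      // 1..3     -> 3
  HWY_ASSERT(gcpp::MMKeys::BucketM(5) == 7);      // 4..7     -> 7
  HWY_ASSERT(gcpp::MMKeys::BucketM(9) == 11);     // 8..11    -> 11
  HWY_ASSERT(gcpp::MMKeys::BucketM(64) == 127);   // 64..127  -> 127
  HWY_ASSERT(gcpp::MMKeys::BucketM(200) == 255);  // 128..255 -> 255
}
```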