From a60b564b883fb2f85f11031e5421178b67cd794d Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 23 Jan 2025 01:54:50 -0800 Subject: [PATCH] Infra improvements (2) ops.h: move CreateInvTimescale to allow calling without depending on gemma Pass around MatMulEnv instead of pools to avoid re-creating the env profiler.h can now be used outside SIMD code allocator: add StepBytes and QuantumSteps rename worker thread with package/cluster in the name threading: add Visit* to IndexRange PiperOrigin-RevId: 718766704 --- BUILD.bazel | 11 ++++-- CMakeLists.txt | 3 +- MODULE.bazel | 2 +- backprop/backward_test.cc | 7 ++-- backprop/optimize_test.cc | 7 ++-- compression/compress-inl.h | 42 ++++++++++++++++++++-- compression/shared.h | 2 +- examples/hello_world/CMakeLists.txt | 2 +- gemma/activations.h | 29 ++++----------- gemma/gemma-inl.h | 43 +++++++++++----------- gemma/gemma.cc | 55 +++++++++++++++-------------- gemma/gemma.h | 3 +- ops/bench_matmul.cc | 52 +++++++++++++++------------ ops/dot-inl.h | 2 +- ops/dot_test.cc | 2 +- ops/matmul.h | 24 ++++++++++--- ops/matmul_test.cc | 47 ++++++++++++------------ ops/matvec-inl.h | 2 +- ops/ops-inl.h | 2 +- ops/ops.h | 45 +++++++++++++++++++++++ ops/ops_test.cc | 9 ++--- util/allocator.cc | 22 ++++++++++-- util/allocator.h | 16 +++++---- util/basics.h | 3 ++ util/threading.cc | 37 +++++++++++-------- util/threading.h | 19 ++++++++++ 26 files changed, 317 insertions(+), 171 deletions(-) create mode 100644 ops/ops.h diff --git a/BUILD.bazel b/BUILD.bazel index 260b90c..b65a7d9 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -86,6 +86,7 @@ cc_library( name = "ops", hdrs = [ "ops/matmul.h", + "ops/ops.h", ], textual_hdrs = [ "ops/dot-inl.h", @@ -104,6 +105,7 @@ cc_library( "@highway//:hwy", "@highway//:math", "@highway//:matvec", + "@highway//:nanobenchmark", # timer "@highway//:profiler", "@highway//:thread_pool", ], @@ -145,7 +147,6 @@ cc_test( deps = [ ":allocator", ":common", - ":gemma_lib", ":ops", ":test_util", "@googletest//:gtest_main", # buildcleaner: keep @@ -221,8 +222,11 @@ cc_test( timeout = "long", srcs = ["ops/bench_matmul.cc"], local_defines = ["HWY_IS_TEST"], - # for test_suite. - tags = ["ops_tests"], + tags = [ + "manual", + "notap", + "ops_tests", # for test_suite. + ], deps = [ ":allocator", ":basics", @@ -683,6 +687,7 @@ cc_test( ":basics", ":common", ":gemma_lib", + ":ops", ":optimizer", ":prompt", ":sampler", diff --git a/CMakeLists.txt b/CMakeLists.txt index d40ffc0..70ce270 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2b565e87d50b151660494624af532ac0b6076c79 EXCLUDE_FROM_ALL) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG f2209b911c74019e85d0b7a7a2833c9a2e1b7995 EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(highway) ## Note: absl needs to be installed by sentencepiece. This will only happen if @@ -94,6 +94,7 @@ set(SOURCES ops/matmul-inl.h ops/matvec-inl.h ops/ops-inl.h + ops/ops.h ops/sum-inl.h paligemma/image.cc paligemma/image.h diff --git a/MODULE.bazel b/MODULE.bazel index 8d30ce7..e835941 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -17,7 +17,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5") # Require a more recent version. 
git_override( module_name = "highway", - commit = "2b565e87d50b151660494624af532ac0b6076c79", + commit = "f2209b911c74019e85d0b7a7a2833c9a2e1b7995", remote = "https://github.com/google/highway", ) diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc index 0df079d..c5671c7 100644 --- a/backprop/backward_test.cc +++ b/backprop/backward_test.cc @@ -31,8 +31,8 @@ #include "backprop/prompt.h" #include "backprop/sampler.h" #include "backprop/test_util.h" -#include "gemma/activations.h" #include "gemma/configs.h" +#include "ops/ops.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -223,8 +223,9 @@ void TestEndToEnd() { ReverseSequenceSampler training_task({0, 0, 1, 1}); std::vector batch = training_task.SampleBatch(3, gen); - RowVectorBatch inv_timescale = Activations::CreateInvTimescale( - config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk); + RowVectorBatch inv_timescale = CreateInvTimescale( + config.layer_configs[0].qkv_dim, + config.layer_configs[0].post_qk == PostQKType::HalfRope); for (const Prompt& prompt : batch) { ReverseSequenceSampler::LogPrompt(prompt); RandInit(weights.get(), 1.0f, gen); diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc index cc7cef7..6d3de50 100644 --- a/backprop/optimize_test.cc +++ b/backprop/optimize_test.cc @@ -28,11 +28,11 @@ #include "backprop/prompt.h" #include "backprop/sampler.h" #include "compression/shared.h" -#include "gemma/activations.h" #include "gemma/common.h" #include "gemma/configs.h" #include "gemma/gemma.h" #include "gemma/weights.h" +#include "ops/ops.h" #include "util/allocator.h" #include "util/basics.h" #include "util/threading.h" @@ -62,8 +62,9 @@ TEST(OptimizeTest, GradientDescent) { ForwardPass forward(config), backward(config); KVCache kv_cache = KVCache::Create(config, /*prefill_tbatch_size=*/16); - RowVectorBatch inv_timescale = Activations::CreateInvTimescale( - config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk); + RowVectorBatch inv_timescale = CreateInvTimescale( + config.layer_configs[0].qkv_dim, + config.layer_configs[0].post_qk == PostQKType::HalfRope); Gemma gemma(GemmaTokenizer(), info, pools); diff --git a/compression/compress-inl.h b/compression/compress-inl.h index e850a0e..8638b5f 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -303,6 +303,42 @@ struct CompressTraits { } } +#if 0 + template + static HWY_INLINE void DecompressAndZeroPad( + DBF dbf, const PackedSpan& packed, const size_t packed_ofs, + BF16* HWY_RESTRICT raw, size_t num) { + const size_t N16 = hn::Lanes(dbf); + +#if 1 + const BF16* HWY_RESTRICT start = packed.ptr + packed_ofs; + using VBF = hn::Vec; + size_t i = 0; + if (num >= 4 * N16) { + for (; i <= num - 4 * N16; i += 4 * N16) { + const VBF packed0 = hn::LoadU(dbf, start + i + 0 * N16); + const VBF packed1 = hn::LoadU(dbf, start + i + 1 * N16); + const VBF packed2 = hn::LoadU(dbf, start + i + 2 * N16); + const VBF packed3 = hn::LoadU(dbf, start + i + 3 * N16); + hn::StoreU(packed0, dbf, raw + i + 0 * N16); + hn::StoreU(packed1, dbf, raw + i + 1 * N16); + hn::StoreU(packed2, dbf, raw + i + 2 * N16); + hn::StoreU(packed3, dbf, raw + i + 3 * N16); + } + } + + for (; i < num; i += N16) { + const size_t remaining = num - i; + const VBF packed0 = hn::LoadN(dbf, start + i, remaining); + hn::StoreU(packed0, dbf, raw + i); + } +#else + hwy::CopyBytes(packed.ptr + packed_ofs, raw, num * sizeof(BF16)); + hwy::ZeroBytes(raw + num, (hwy::RoundUpTo(num, N16) - num) * sizeof(BF16)); +#endif + } +#endif + 
template static HWY_INLINE void DecompressAndZeroPad( DF df, const PackedSpan& packed, const size_t packed_ofs, @@ -534,9 +570,9 @@ HWY_INLINE void Decompress2(DRaw d, const PackedSpan& packed, // also wants to scale the decompressed elements. // `TRaw` can be `float/BF16`, or `double` if `Packed` is `float`. template > -HWY_NOINLINE void DecompressAndZeroPad(DRaw d, const PackedSpan& packed, - const size_t packed_ofs, TRaw* raw, - size_t num) { +HWY_INLINE void DecompressAndZeroPad(DRaw d, const PackedSpan& packed, + const size_t packed_ofs, TRaw* raw, + size_t num) { detail::VerifyRawAndPackedForDecompress(); packed.BoundsCheck(packed_ofs, num); using Traits = CompressTraits>; diff --git a/compression/shared.h b/compression/shared.h index f9814a1..eb33d48 100644 --- a/compression/shared.h +++ b/compression/shared.h @@ -286,7 +286,7 @@ struct PackedSpan { } Packed* HWY_RESTRICT ptr; - size_t num; // for BoundsCheck and nuq-inl.h HWY_ASSERT. + size_t num; // for BoundsCheck, also required by nuq-inl.h. }; // Avoids spelling out the template parameter in every call. diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt index 0679aa5..030b2ba 100644 --- a/examples/hello_world/CMakeLists.txt +++ b/examples/hello_world/CMakeLists.txt @@ -17,7 +17,7 @@ project(hello_world) set(CMAKE_CXX_STANDARD_REQUIRED ON) include(FetchContent) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2b565e87d50b151660494624af532ac0b6076c79) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG f2209b911c74019e85d0b7a7a2833c9a2e1b7995) FetchContent_MakeAvailable(highway) FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) FetchContent_MakeAvailable(sentencepiece) diff --git a/gemma/activations.h b/gemma/activations.h index 9d6ccb5..f08470c 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -18,12 +18,10 @@ #include -#include -#include // std::unique_ptr - #include "compression/shared.h" // BF16 #include "gemma/configs.h" #include "ops/matmul.h" // MatMulEnv +#include "ops/ops.h" // CreateInvTimescale #include "util/allocator.h" // RowVectorBatch #include "util/threading.h" #include "hwy/base.h" // HWY_DASSERT @@ -65,7 +63,7 @@ struct Activations { RowVectorBatch inv_timescale; // Dynamic because no default ctor and only initialized in `Allocate`. - std::unique_ptr env; + MatMulEnv* env; PostQKType post_qk = PostQKType::Rope; // And the config. @@ -74,23 +72,7 @@ struct Activations { size_t seq_len; size_t cache_pos_size = 0; - static RowVectorBatch CreateInvTimescale( - size_t qkv_dim, PostQKType post_qk, double base_frequency = 10000.0) { - const size_t rope_dim = - post_qk == PostQKType::HalfRope ? qkv_dim / 2 : qkv_dim; - RowVectorBatch inv_timescale(Extents2D(1, rope_dim / 2)); - for (size_t dim = 0; dim < rope_dim / 2; ++dim) { - const double freq_exponents = - static_cast(2 * dim) / static_cast(rope_dim); - // Replacing with expf(ln(1E4) * freq_exponents) changes results - // noticeably. 
- inv_timescale.Batch(0)[dim] = - static_cast(1.0 / std::pow(base_frequency, freq_exponents)); - } - return inv_timescale; - } - - void Allocate(size_t batch_size, NestedPools& pools) { + void Allocate(size_t batch_size, MatMulEnv* env) { post_qk = layer_config.post_qk; const size_t model_dim = weights_config.model_dim; const size_t ff_hidden_dim = layer_config.ff_hidden_dim; @@ -124,9 +106,10 @@ struct Activations { RowVectorBatch(Extents2D(batch_size, model_dim)); } - inv_timescale = CreateInvTimescale(qkv_dim, post_qk); + inv_timescale = CreateInvTimescale(layer_config.qkv_dim, + post_qk == PostQKType::HalfRope); - env = std::make_unique(pools); + this->env = env; } }; diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 43d2f77..9aa3d11 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -37,6 +37,7 @@ #include "hwy/base.h" #include "hwy/bit_set.h" #include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/profiler.h" #include "hwy/timer.h" // Include guard (still compiled once per target) @@ -53,7 +54,6 @@ #include "ops/matmul-inl.h" #include "ops/matvec-inl.h" #include "ops/ops-inl.h" -#include "hwy/profiler.h" // also uses SIMD #ifndef GEMMA_TYPE #if HWY_IDE @@ -81,7 +81,7 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens, const KVCaches& kv_caches) { PROFILER_ZONE("Gen.Griffin"); KVCache& kv_cache = kv_caches[0]; - hwy::ThreadPool& pool = activations.env->Pool(); + hwy::ThreadPool& pool = activations.env->parallel.Pools().Pool(0); namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; const size_t model_dim = layer_weights->layer_config.model_dim; @@ -517,7 +517,7 @@ class GemmaAttention { layer_weights_(*layer_weights), div_seq_len_(div_seq_len), kv_caches_(kv_caches), - pool_(activations.env->Pool()) { + pool_(activations.env->parallel.Pools().Pool(0)) { HWY_DASSERT(num_queries_ <= kv_caches_.size()); HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0, "query heads must be a multiple of key-value heads"); @@ -654,7 +654,7 @@ class VitAttention { activations_(activations), layer_weights_(*layer_weights), layer_config_(layer_weights->layer_config), - pool_(activations.env->Pool()) {} + pool_(activations.env->parallel.Pools().Pool(0)) {} HWY_INLINE void operator()() { ComputeQKV(); @@ -1072,10 +1072,10 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image, // which is not the case here. We should relax that requirement on MatMul and // then use the above. For now, we rely on MatVecAdd instead. for (size_t i = 0; i < seq_len; ++i) { - MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size, - image_patches[i].get(), - weights.vit_img_embedding_bias.data_scale1(), - activations.x.Batch(i), activations.env->Pool()); + MatVecAdd( + weights.vit_img_embedding_kernel, 0, model_dim, patch_size, + image_patches[i].get(), weights.vit_img_embedding_bias.data_scale1(), + activations.x.Batch(i), activations.env->parallel.Pools().Pool(0)); } // Add position embeddings. 
AddFrom(weights.vit_img_pos_embedding.data_scale1(), activations.x.All(), @@ -1283,7 +1283,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations, Activations prefill_activations(weights.weights_config); if (use_prefill_activations) { prefill_activations.Allocate(runtime_config.prefill_tbatch_size, - activations.env->Pools()); + activations.env); } Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end, query_idx_start, weights, @@ -1354,14 +1354,14 @@ template void GenerateSingleT(const ModelWeightsStorage& model, const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, size_t prefix_end, - KVCache& kv_cache, NestedPools& pools, + KVCache& kv_cache, MatMulEnv* env, TimingInfo& timing_info) { constexpr size_t kNumQueries = 1; const size_t qbatch_start = 0; // TODO: move into Gemma? Activations activations(model.Config()); - activations.Allocate(kNumQueries, pools); + activations.Allocate(kNumQueries, env); const QueriesPromptTokens queries_prompt(&prompt, kNumQueries); QueriesPos queries_pos(&pos, kNumQueries); @@ -1378,7 +1378,7 @@ void GenerateBatchT(const ModelWeightsStorage& model, const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, - const KVCaches& kv_caches, NestedPools& pools, + const KVCaches& kv_caches, MatMulEnv* env, TimingInfo& timing_info) { const size_t num_queries = queries_prompt.size(); HWY_ASSERT(queries_pos.size() == num_queries); @@ -1393,7 +1393,7 @@ void GenerateBatchT(const ModelWeightsStorage& model, } Activations activations(model.Config()); - activations.Allocate(max_qbatch_size, pools); + activations.Allocate(max_qbatch_size, env); for (size_t qbatch_start = 0; qbatch_start < num_queries; qbatch_start += max_qbatch_size) { @@ -1415,7 +1415,7 @@ template void GenerateImageTokensT(const ModelWeightsStorage& model, const RuntimeConfig& runtime_config, const Image& image, ImageTokens& image_tokens, - NestedPools& pools) { + MatMulEnv* env) { if (model.Config().vit_config.layer_configs.empty()) { HWY_ABORT("Model does not support generating image tokens."); } @@ -1423,7 +1423,7 @@ void GenerateImageTokensT(const ModelWeightsStorage& model, ModelConfig vit_config = GetVitConfig(model.Config()); prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len; Activations prefill_activations(vit_config); - prefill_activations.Allocate(vit_config.seq_len, pools); + prefill_activations.Allocate(vit_config.seq_len, env); // Weights are for the full PaliGemma model, not just the ViT part. 
PrefillVit(*model.GetWeightsOfType(), prefill_runtime_config, image, image_tokens, prefill_activations); @@ -1438,11 +1438,10 @@ void GenerateImageTokensT(const ModelWeightsStorage& model, void GenerateSingle( // NOLINT(misc-definitions-in-headers) GEMMA_TYPE, const ModelWeightsStorage& model, const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, - size_t prefix_end, KVCache& kv_cache, NestedPools& pools, + size_t prefix_end, KVCache& kv_cache, MatMulEnv* env, TimingInfo& timing_info) { HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT) - (model, runtime_config, prompt, pos, prefix_end, kv_cache, pools, - timing_info); + (model, runtime_config, prompt, pos, prefix_end, kv_cache, env, timing_info); } void GenerateBatch( // NOLINT(misc-definitions-in-headers) @@ -1450,18 +1449,18 @@ void GenerateBatch( // NOLINT(misc-definitions-in-headers) const RuntimeConfig& runtime_config, const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, const KVCaches& kv_caches, - NestedPools& pools, TimingInfo& timing_info) { + MatMulEnv* env, TimingInfo& timing_info) { HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT) (model, runtime_config, queries_prompt, queries_pos, queries_prefix_end, - kv_caches, pools, timing_info); + kv_caches, env, timing_info); } void GenerateImageTokens( // NOLINT(misc-definitions-in-headers) GEMMA_TYPE, const ModelWeightsStorage& model, const RuntimeConfig& runtime_config, const Image& image, - ImageTokens& image_tokens, NestedPools& pools) { + ImageTokens& image_tokens, MatMulEnv* env) { HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImageTokensT) - (model, runtime_config, image, image_tokens, pools); + (model, runtime_config, image, image_tokens, env); } #endif // HWY_ONCE diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 328144b..20b1c75 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -41,23 +41,24 @@ namespace gcpp { Gemma::Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info, NestedPools& pools) - : pools_(pools), tokenizer_(tokenizer_path) { - model_.Load(weights, info.model, info.weight, info.wrapping, pools_.Pool(), + : env_(pools), tokenizer_(tokenizer_path) { + model_.Load(weights, info.model, info.weight, info.wrapping, + env_.parallel.Pools().Pool(0), /*tokenizer_proto=*/nullptr); } -Gemma::Gemma(const Path& weights, NestedPools& pools) : pools_(pools) { +Gemma::Gemma(const Path& weights, NestedPools& pools) : env_(pools) { std::string tokenizer_proto; model_.Load(weights, Model::UNKNOWN, Type::kUnknown, PromptWrapping::GEMMA_IT, - pools_.Pool(), &tokenizer_proto); + env_.parallel.Pools().Pool(0), &tokenizer_proto); tokenizer_.Deserialize(tokenizer_proto); } Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, NestedPools& pools) - : pools_(pools), tokenizer_(std::move(tokenizer)) { + : env_(pools), tokenizer_(std::move(tokenizer)) { HWY_ASSERT(info.weight == Type::kF32); - model_.Allocate(info.model, info.weight, pools_.Pool()); + model_.Allocate(info.model, info.weight, env_.parallel.Pools().Pool(0)); } Gemma::~Gemma() { @@ -72,16 +73,16 @@ Gemma::~Gemma() { const RuntimeConfig& runtime_config, \ const PromptTokens& prompt, size_t pos, \ size_t prefix_end, KVCache& kv_cache, \ - NestedPools& pools, TimingInfo& timing_info); \ + MatMulEnv* env, TimingInfo& timing_info); \ extern void GenerateBatch( \ TWEIGHT, const ModelWeightsStorage& model, \ const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \ const QueriesPos& queries_pos, const 
QueriesPos& queries_prefix_end, \ - const KVCaches& kv_caches, NestedPools& pools, TimingInfo& timing_info); \ - extern void GenerateImageTokens( \ - TWEIGHT, const ModelWeightsStorage& model, \ - const RuntimeConfig& runtime_config, const Image& image, \ - ImageTokens& image_tokens, NestedPools& pools); + const KVCaches& kv_caches, MatMulEnv* env, TimingInfo& timing_info); \ + extern void GenerateImageTokens(TWEIGHT, const ModelWeightsStorage& model, \ + const RuntimeConfig& runtime_config, \ + const Image& image, \ + ImageTokens& image_tokens, MatMulEnv* env); GEMMA_DECLARE(float) GEMMA_DECLARE(BF16) GEMMA_DECLARE(NuqStream) @@ -93,10 +94,10 @@ struct GenerateSingleT { void operator()(const ModelWeightsStorage& model, const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, size_t prefix_end, - KVCache& kv_cache, NestedPools& pools, + KVCache& kv_cache, MatMulEnv* env, TimingInfo& timing_info) const { GenerateSingle(TConfig(), model, runtime_config, prompt, pos, prefix_end, - kv_cache, pools, timing_info); + kv_cache, env, timing_info); } }; @@ -107,10 +108,10 @@ struct GenerateBatchT { const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, - const KVCaches& kv_caches, NestedPools& pools, + const KVCaches& kv_caches, MatMulEnv* env, TimingInfo& timing_info) const { GenerateBatch(TConfig(), model, runtime_config, queries_prompt, queries_pos, - queries_prefix_end, kv_caches, pools, timing_info); + queries_prefix_end, kv_caches, env, timing_info); } }; @@ -118,21 +119,21 @@ template struct GenerateImageTokensT { void operator()(const ModelWeightsStorage& model, const RuntimeConfig& runtime_config, const Image& image, - ImageTokens& image_tokens, NestedPools& pools) const { + ImageTokens& image_tokens, MatMulEnv* env) const { GenerateImageTokens(TConfig(), model, runtime_config, image, image_tokens, - pools); + env); } }; void Gemma::Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, size_t prefix_end, KVCache& kv_cache, TimingInfo& timing_info) { - pools_.MaybeStartSpinning(runtime_config.use_spinning); + env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning); model_.CallForModelWeight( - runtime_config, prompt, pos, prefix_end, kv_cache, pools_, timing_info); + runtime_config, prompt, pos, prefix_end, kv_cache, &env_, timing_info); - pools_.MaybeStopSpinning(runtime_config.use_spinning); + env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning); } void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, @@ -149,23 +150,23 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, QueriesPos(prefix_end_vec.data(), prefix_end_vec.size()); } - pools_.MaybeStartSpinning(runtime_config.use_spinning); + env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning); model_.CallForModelWeight( runtime_config, queries_prompt, queries_pos, mutable_queries_prefix_end, - kv_caches, pools_, timing_info); + kv_caches, &env_, timing_info); - pools_.MaybeStopSpinning(runtime_config.use_spinning); + env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning); } void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config, const Image& image, ImageTokens& image_tokens) { - pools_.MaybeStartSpinning(runtime_config.use_spinning); + env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning); model_.CallForModelWeight(runtime_config, image, - image_tokens, pools_); + image_tokens, &env_); - 
pools_.MaybeStopSpinning(runtime_config.use_spinning); + env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning); } // Non-template functions moved from gemma-inl.h to avoid ODR violations. diff --git a/gemma/gemma.h b/gemma/gemma.h index 15d22a1..d0b0427 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -29,6 +29,7 @@ #include "gemma/kv_cache.h" #include "gemma/tokenizer.h" #include "gemma/weights.h" +#include "ops/matmul.h" // MatMulEnv #include "paligemma/image.h" #include "util/allocator.h" // RowVectorBatch #include "util/basics.h" // TokenAndProb @@ -246,7 +247,7 @@ class Gemma { const Image& image, ImageTokens& image_tokens); private: - NestedPools& pools_; + MatMulEnv env_; GemmaTokenizer tokenizer_; // Type-erased so that this can be defined in the header. diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index 3d8ff72..fa38c50 100644 --- a/ops/bench_matmul.cc +++ b/ops/bench_matmul.cc @@ -42,6 +42,7 @@ #include "util/threading.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/nanobenchmark.h" +#include "hwy/profiler.h" #include "hwy/timer.h" // clang-format off @@ -53,7 +54,6 @@ // After highway.h #include "compression/compress-inl.h" #include "ops/matmul-inl.h" -#include "hwy/profiler.h" // also uses SIMD #include "hwy/tests/test_util-inl.h" HWY_BEFORE_NAMESPACE(); @@ -119,24 +119,24 @@ MatStoragePtr GenerateTransposedMat(const Extents2D extents, void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents, std::vector& times) { std::sort(times.begin(), times.end()); - // Many measurements are with suboptimal configs, so report the best like - // bench_dnn, but also the ratio to the 3rd best. - const double elapsed = times[0]; - const double ratio = times[2] / HWY_MAX(elapsed, 1E-6); + // bench_dnn reports the best and average, but the median seems more + // consistent and resistant to outliers. + const double elapsed = times[times.size() / 2]; + const double ratio = elapsed / (times[0] + 1E-6); // vs best, avoid / 0 const size_t num_b = B_extents.Area(); - // 2x because of FMA. - fprintf(stderr, "%.1f\t%.2f\n", 2 * 1E-9 * A_extents.rows * num_b / elapsed, - ratio); + // FMA counts as two FLOP. + fprintf(stderr, "%.1f\t(med %.3f ms = %0.2fx min)\n", + 2 * 1E-9 * A_extents.rows * num_b / elapsed, elapsed * 1E3, ratio); } // Generates inputs and prints observed throughput of MatMul. // M = A rows, K = A cols, N = C cols. template void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { - hwy::ThreadPool& pool = env.Pool(); - fprintf(stderr, "BenchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", M, - K, N, add, TypeName(), TypeName()); + hwy::ThreadPool& pool = env.parallel.Pools().Pool(0); + fprintf(stderr, "\nBenchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", + M, K, N, add, TypeName(), TypeName()); const Extents2D A_extents(M, K); const Extents2D B_extents(N, K); // already transposed @@ -161,17 +161,26 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { const float* add_row = add ? 
add_storage->data_scale1() : nullptr; const RowPtrF C = RowPtrFromBatch(c_batch); + constexpr size_t kSamples = 20; std::vector times; - times.reserve(20); - double result = 0.0; - for (;;) { + times.reserve(kSamples); + + Tristate use_spinning = Tristate::kDefault; + env.parallel.Pools().MaybeStartSpinning(use_spinning); + + double keep = 0.0; + // Until enough samples collected *after* autotuning finished: + while (times.size() < kSamples) { const double t0 = hwy::platform::Now(); MatMul(A, B, add_row, env, C); - times.push_back(hwy::platform::Now() - t0); - result += C.Row(0)[hwy::Unpredictable1()]; - if (times.size() >= 20) break; + const double t1 = hwy::platform::Now(); + double elapsed = t1 - t0; + keep += C.Row(0)[hwy::Unpredictable1()]; + + times.push_back(elapsed); } - hwy::PreventElision(result); + hwy::PreventElision(keep); + env.parallel.Pools().MaybeStopSpinning(use_spinning); PrintSpeed(A_extents, B_extents, times); } @@ -182,7 +191,7 @@ void BenchAllMatMul() { if (first_target == 0) first_target = HWY_TARGET; if (HWY_TARGET != first_target) return; - for (size_t max_packages : {1, 2}) { + for (size_t max_packages : {/*1,*/ 2}) { const size_t max_threads = 0; // no limit NestedPools pools(max_threads, Tristate::kDefault, BoundedSlice(0, max_packages)); @@ -195,17 +204,14 @@ void BenchAllMatMul() { fprintf(stderr, "BenchAllMatMul %zu: %s %s\n", max_packages, pools.TopologyString(), pools.PinString()); - Tristate use_spinning = Tristate::kDefault; - pools.MaybeStartSpinning(use_spinning); Allocator::Init(pools.Topology()); MatMulEnv env(pools); - for (size_t batch_size : {1, /* 4, 128,*/ 512}) { + for (size_t batch_size : {1, 4, 128, 512}) { constexpr bool kAdd = false; BenchMatMul(batch_size, 24576, 3072, kAdd, env); BenchMatMul(batch_size, 3072, 24576, kAdd, env); } - pools.MaybeStopSpinning(use_spinning); } PROFILER_PRINT_RESULTS(); diff --git a/ops/dot-inl.h b/ops/dot-inl.h index cbb34f6..f5282f2 100644 --- a/ops/dot-inl.h +++ b/ops/dot-inl.h @@ -17,6 +17,7 @@ #include "compression/compress.h" #include "hwy/base.h" +#include "hwy/profiler.h" // Include guard for (potentially) SIMD code. #if defined(THIRD_PARTY_GEMMA_CPP_DOT_TOGGLE) == defined(HWY_TARGET_TOGGLE) @@ -31,7 +32,6 @@ #include "compression/compress-inl.h" #include "ops/fp_arith-inl.h" #include "hwy/contrib/math/math-inl.h" -#include "hwy/profiler.h" // also uses SIMD HWY_BEFORE_NAMESPACE(); namespace gcpp { diff --git a/ops/dot_test.cc b/ops/dot_test.cc index 547cc63..770ad5a 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -32,6 +32,7 @@ #include "util/threading.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/profiler.h" #include "hwy/stats.h" #include "hwy/timer.h" @@ -44,7 +45,6 @@ // After highway.h #include "compression/test_util-inl.h" #include "ops/dot-inl.h" -#include "hwy/profiler.h" // also uses SIMD HWY_BEFORE_NAMESPACE(); namespace gcpp { diff --git a/ops/matmul.h b/ops/matmul.h index 00392ce..c1c6d44 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -65,22 +65,36 @@ struct CacheSizes { size_t l3_bytes; }; +class MMParallel { + public: + MMParallel() : pools_(nullptr) {} + explicit MMParallel(NestedPools& pools) : pools_(&pools) {} + + NestedPools& Pools() const { return *pools_; } + hwy::ThreadPool& Pool() const { return pools_->Pool(); } + + private: + NestedPools* pools_; +}; + // Allocations and threads, shared across MatMul calls. 
class MatMulEnv { public: - MatMulEnv() : pools_(nullptr) {} - explicit MatMulEnv(NestedPools& pools) : pools_(&pools) { + explicit MatMulEnv(NestedPools& pools) : parallel(pools) { const size_t N = hwy::VectorBytes() / sizeof(float); buf_ = RowVectorBatch(Extents2D(pools.MaxWorkers(), 16 * N)); } RowVectorBatch& Buf() { return buf_; } - NestedPools& Pools() const { return *pools_; } - hwy::ThreadPool& Pool() const { return pools_->Pool(); } + + MMParallel parallel; + + // TODO: remove once no longer used. + NestedPools& Pools() const { return parallel.Pools(); } + hwy::ThreadPool& Pool() const { return parallel.Pool(); } private: RowVectorBatch buf_; - NestedPools* pools_; }; // Used for the A and B arguments of `MatMul`, which are always const. diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 7239912..3dc90b1 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -137,8 +137,8 @@ float MaxAbs(const float* HWY_RESTRICT a, const Extents2D& extents) { } // B is already transposed. -template -void AssertClose(const ConstMat& A, const ConstMat& B, +template +void AssertClose(const ConstMat& A, const ConstMat& B, const RowPtrF& C_slow, const RowPtrF& C) { const hn::ScalableTag df; const size_t num_a = A.extents.Area(); @@ -160,13 +160,13 @@ void AssertClose(const ConstMat& A, const ConstMat& B, MaxAbs(a.get(), A.Extents()) * MaxAbs(b_trans.get(), B.Extents()); const double eps_bf16 = hwy::ConvertScalarTo(hwy::Epsilon()); const double eps_f32 = hwy::ConvertScalarTo(hwy::Epsilon()); - double tolerance = 8 * norm * eps_f32; + double tolerance = 10 * norm * eps_f32; // Dot() also rounds F32,BF16 to BF16, but not with F32,F32, so increase the // tolerance there. - if (IsF32() && !IsF32()) { + if (IsF32() && IsF32()) { tolerance += 4 * max_abs * eps_bf16; } - if (tolerance > 8.0) { + if (tolerance > 500.0) { HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs); } @@ -189,23 +189,22 @@ void AssertClose(const ConstMat& A, const ConstMat& B, } // B is already transposed. -template -HWY_INLINE void MatMulSlow(const ConstMat A, const ConstMat B, +template +HWY_INLINE void MatMulSlow(const ConstMat A, const ConstMat B, const float* HWY_RESTRICT add_row, MatMulEnv& env, const RowPtrF& C) { - // MatTA can be any Packed except NuqStream because it uses pointer + // TA can be any Packed except NuqStream because it uses pointer // arithmetic, because it is the second argument to Dot, which does not // support a v_ofs. 
- static_assert(sizeof(MatTA) >= sizeof(BF16), "A matrix must be BF16/f32"); + static_assert(sizeof(TA) >= sizeof(BF16), "A matrix must be BF16/f32"); const float scale = A.scale * B.scale; const hn::ScalableTag df; // lane type is ignored - const PackedSpan b_span = - MakeSpan(B.ptr, B.ofs + B.extents.Area()); + const PackedSpan b_span = MakeSpan(B.ptr, B.ofs + B.extents.Area()); const IndexRange all_rows_c(0, A.Extents().rows); const IndexRange all_cols_c(0, C.Cols()); - NestedPools& pools = env.Pools(); + NestedPools& pools = env.parallel.Pools(); hwy::ThreadPool& all_packages = pools.AllPackages(); const IndexRangePartition get_row_c = StaticPartition(all_rows_c, all_packages.NumWorkers(), 1); @@ -213,7 +212,7 @@ HWY_INLINE void MatMulSlow(const ConstMat A, const ConstMat B, get_row_c, all_packages, [&](const IndexRange& rows_c, size_t package_idx) HWY_ATTR { hwy::ThreadPool& all_clusters = pools.AllClusters(package_idx); - const size_t multiple = Allocator::QuantumBytes() / sizeof(MatTB); + const size_t multiple = Allocator::QuantumBytes() / sizeof(TB); const IndexRangePartition get_col_c = StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple); ParallelizeOneRange( @@ -240,20 +239,19 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents, elapsed, 2 * 1E-9 * A_extents.rows * num_b / elapsed); } -template +template void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, MatMulEnv& env) { - hwy::ThreadPool& pool = env.Pool(); - fprintf(stderr, "TestMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", - rows_ac, cols_a_rows_b, cols_bc, add, TypeName(), - TypeName()); + hwy::ThreadPool& pool = env.parallel.Pools().Pool(); + fprintf(stderr, "TestMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s\n", rows_ac, + cols_a_rows_b, cols_bc, add, TypeName(), TypeName()); const Extents2D A_extents(rows_ac, cols_a_rows_b); const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed const Extents2D C_extents(rows_ac, cols_bc); - MatStoragePtr a = GenerateMat(A_extents, pool); - MatStoragePtr b_trans = GenerateTransposedMat(B_extents, pool); + MatStoragePtr a = GenerateMat(A_extents, pool); + MatStoragePtr b_trans = GenerateTransposedMat(B_extents, pool); RowVectorBatch c_slow_batch(C_extents); RowVectorBatch c_batch(C_extents); HWY_ASSERT(a && b_trans); @@ -303,8 +301,13 @@ void TestTiny() { Allocator::Init(pools.Topology()); MatMulEnv env(pools); - for (size_t batch_size = 1; batch_size <= 3 * kRegRows; ++batch_size) { - TestMatMul(batch_size, 256, 256, /*add=*/false, env); + for (size_t M = 1; M <= 3 * kRegRows; ++M) { + for (size_t K = 64; K <= 128; K *= 2) { + for (size_t N = /*kRegRows*/ 16; N <= 64; + N += max_packages * kRegRows) { + TestMatMul(M, K, N, /*add=*/false, env); + } + } } pools.MaybeStopSpinning(use_spinning); } diff --git a/ops/matvec-inl.h b/ops/matvec-inl.h index 95f7f9e..7ad56e7 100644 --- a/ops/matvec-inl.h +++ b/ops/matvec-inl.h @@ -23,6 +23,7 @@ #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/profiler.h" #endif // THIRD_PARTY_GEMMA_CPP_OPS_MATVEC_INL_H_ @@ -38,7 +39,6 @@ #include "ops/dot-inl.h" #include "hwy/contrib/math/math-inl.h" #include "hwy/contrib/matvec/matvec-inl.h" -#include "hwy/profiler.h" // also uses SIMD HWY_BEFORE_NAMESPACE(); namespace gcpp { diff --git a/ops/ops-inl.h b/ops/ops-inl.h index f8d54e7..3da48e1 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -32,6 +32,7 @@ #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/detect_targets.h" 
+#include "hwy/profiler.h" #endif // THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_ // Include guard for (potentially) SIMD code. @@ -47,7 +48,6 @@ #include "ops/sum-inl.h" #include "hwy/contrib/algo/transform-inl.h" #include "hwy/contrib/math/math-inl.h" -#include "hwy/profiler.h" // also uses SIMD HWY_BEFORE_NAMESPACE(); namespace gcpp { diff --git a/ops/ops.h b/ops/ops.h new file mode 100644 index 0000000..6c243da --- /dev/null +++ b/ops/ops.h @@ -0,0 +1,45 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_OPS_OPS_H_ +#define THIRD_PARTY_GEMMA_CPP_OPS_OPS_H_ + +#include + +#include + +#include "util/allocator.h" +#include "hwy/base.h" + +namespace gcpp { + +static inline HWY_MAYBE_UNUSED RowVectorBatch CreateInvTimescale( + size_t qkv_dim, bool half_rope, double base_frequency = 10000.0) { + const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim; + RowVectorBatch inv_timescale(Extents2D(1, rope_dim / 2)); + for (size_t dim = 0; dim < rope_dim / 2; ++dim) { + const double freq_exponents = + static_cast(2 * dim) / static_cast(rope_dim); + // Replacing with expf(ln(1E4) * freq_exponents) changes results + // noticeably. + inv_timescale.Batch(0)[dim] = + static_cast(1.0 / std::pow(base_frequency, freq_exponents)); + } + return inv_timescale; +} + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_OPS_OPS_H_ diff --git a/ops/ops_test.cc b/ops/ops_test.cc index 93b8f31..ddf5ec6 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -18,6 +18,8 @@ #define HWY_DISABLED_TARGETS HWY_SCALAR #endif +#include "ops/ops.h" + #include #include @@ -30,12 +32,10 @@ #include #include "compression/compress.h" // BF16 -#include "gemma/activations.h" #include "gemma/common.h" #include "gemma/configs.h" #include "util/allocator.h" #include "util/test_util.h" -#include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/tests/hwy_gtest.h" @@ -407,8 +407,9 @@ void TestRopeAndMulBy() { std::vector qactual(dim_qkv); std::vector kexpected(dim_qkv); std::vector kactual(dim_qkv); - RowVectorBatch inv_timescale = gcpp::Activations::CreateInvTimescale( - config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk); + RowVectorBatch inv_timescale = gcpp::CreateInvTimescale( + config.layer_configs[0].qkv_dim, + config.layer_configs[0].post_qk == PostQKType::HalfRope); // Assert VectorizedRope computation is same as regular rope at different pos. 
for (int pos = 1; pos < 500; pos++) { // Rope'd Q embeddings diff --git a/util/allocator.cc b/util/allocator.cc index 8915c4a..3ff2b5a 100644 --- a/util/allocator.cc +++ b/util/allocator.cc @@ -18,7 +18,6 @@ #include #include -#include #include #include "util/basics.h" // MaybeCheckInitialized @@ -76,31 +75,42 @@ size_t DetectPageSize() { static size_t line_bytes_; static size_t vector_bytes_; +static size_t step_bytes_; static size_t quantum_bytes_; +static size_t quantum_steps_; static size_t l1_bytes_; static size_t l2_bytes_; +static size_t l3_bytes_; static bool should_bind_ = false; size_t Allocator::LineBytes() { return line_bytes_; } size_t Allocator::VectorBytes() { return vector_bytes_; } +size_t Allocator::StepBytes() { return step_bytes_; } size_t Allocator::QuantumBytes() { return quantum_bytes_; } +size_t Allocator::QuantumSteps() { return quantum_steps_; } size_t Allocator::L1Bytes() { return l1_bytes_; } size_t Allocator::L2Bytes() { return l2_bytes_; } +size_t Allocator::L3Bytes() { return l3_bytes_; } bool Allocator::ShouldBind() { return should_bind_; } void Allocator::Init(const BoundedTopology& topology) { line_bytes_ = DetectLineBytes(); vector_bytes_ = hwy::VectorBytes(); - quantum_bytes_ = HWY_MAX(line_bytes_, vector_bytes_); // may overwrite below + step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_); + quantum_bytes_ = step_bytes_; // may overwrite below + const BoundedTopology::Cluster& cluster = topology.GetCluster(0, 0); if (const hwy::Cache* caches = hwy::DataCaches()) { l1_bytes_ = caches[1].size_kib << 10; l2_bytes_ = caches[2].size_kib << 10; + l3_bytes_ = (caches[3].size_kib << 10) * caches[3].cores_sharing; } else { // Unknown, make reasonable assumptions. - const BoundedTopology::Cluster& cluster = topology.GetCluster(0, 0); l1_bytes_ = 32 << 10; l2_bytes_ = (cluster.PrivateKiB() ? cluster.PrivateKiB() : 256) << 10; } + if (l3_bytes_ == 0) { + l3_bytes_ = (cluster.SharedKiB() ? cluster.SharedKiB() : 1024) << 10; + } // Prerequisites for binding: // - supported by the OS (currently Linux only), @@ -115,9 +125,15 @@ void Allocator::Init(const BoundedTopology& topology) { // Ensure pages meet the alignment requirements of `AllocBytes`. HWY_ASSERT(page_bytes >= quantum_bytes_); quantum_bytes_ = page_bytes; + // Ensure MaxQuantumBytes() is an upper bound. + HWY_ASSERT(MaxQuantumBytes() >= quantum_bytes_); + quantum_bytes_ = HWY_MIN(quantum_bytes_, MaxQuantumBytes()); should_bind_ = true; } } + + HWY_DASSERT(quantum_bytes_ % step_bytes_ == 0); + quantum_steps_ = quantum_bytes_ / step_bytes_; } Allocator::PtrAndDeleter Allocator::AllocBytes(size_t bytes) { diff --git a/util/allocator.h b/util/allocator.h index 43dae62..4007b22 100644 --- a/util/allocator.h +++ b/util/allocator.h @@ -72,12 +72,17 @@ class Allocator { static size_t LineBytes(); // Bytes per full vector. Used to compute loop steps. static size_t VectorBytes(); - // Granularity of regions processed by different threads. Their start and - // length of regions should be divisible by this, which is at least - // `HWY_MAX(LineBytes(), VectorBytes())`. + // Work granularity that avoids false sharing and partial vectors. + static size_t StepBytes(); // = HWY_MAX(LineBytes(), VectorBytes()) + // Granularity like `StepBytes()`, but when NUMA may be involved. static size_t QuantumBytes(); + // Upper bound on `QuantumBytes()`, for stack allocations. 
+ static constexpr size_t MaxQuantumBytes() { return 4096; } + static size_t QuantumSteps(); // = QuantumBytes() / StepBytes() + static size_t L1Bytes(); static size_t L2Bytes(); + static size_t L3Bytes(); // Returns pointer aligned to `QuantumBytes()`. template @@ -192,10 +197,9 @@ class RowPtr { RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride) : row0_(row0), stride_(stride), - step_(static_cast( - HWY_MAX(Allocator::LineBytes(), Allocator::VectorBytes()))), + step_(static_cast(Allocator::StepBytes())), cols_(static_cast(cols)), - row_mask_(Allocator::QuantumBytes() / step_ - 1) { + row_mask_(Allocator::QuantumSteps() - 1) { HWY_DASSERT(stride >= cols); HWY_DASSERT(row_mask_ != ~size_t{0}); row_mask_ = 0; // TODO: remove diff --git a/util/basics.h b/util/basics.h index b25a9ef..c296934 100644 --- a/util/basics.h +++ b/util/basics.h @@ -85,6 +85,9 @@ struct IndexRange { IndexRange& operator=(const IndexRange& other) = default; size_t Num() const { return end_ - begin_; } + bool Contains(IndexRange other) const { + return other.begin_ >= begin_ && other.end_ <= end_; + } // Enable range-based for loops. class Iterator { diff --git a/util/threading.cc b/util/threading.cc index a10862e..1671187 100644 --- a/util/threading.cc +++ b/util/threading.cc @@ -89,22 +89,28 @@ class Pinning { // If want_pin_, tries to pin each worker in `pool` to an LP in `cluster`, // and sets `any_error_` if any fails. - void MaybePin(const BoundedTopology::Cluster& cluster, PoolPtr& pool) { - if (HWY_UNLIKELY(!want_pin_)) return; - + void MaybePin(size_t pkg_idx, size_t cluster_idx, + const BoundedTopology::Cluster& cluster, PoolPtr& pool) { const std::vector lps = cluster.LPVector(); HWY_ASSERT(pool->NumWorkers() <= lps.size()); - pool->Run( - 0, pool->NumWorkers(), - [this, &pool, &lps](uint64_t task, size_t thread) { - HWY_ASSERT(task == thread); // each worker has one task - if (HWY_UNLIKELY(!hwy::PinThreadToLogicalProcessor(lps[task]))) { - fprintf(stderr, - "Pinning failed for task %zu of %zu to %zu (size %zu)\n", - task, pool->NumWorkers(), lps[task], lps.size()); - (void)any_error_.test_and_set(); - } - }); + pool->Run(0, pool->NumWorkers(), [&](uint64_t task, size_t thread) { + HWY_ASSERT(task == thread); // each worker has one task + + char buf[16]; // Linux limitation + const int bytes_written = snprintf(buf, sizeof(buf), "P%zu X%02zu C%03zu", + pkg_idx, cluster_idx, task); + HWY_ASSERT(bytes_written < sizeof(buf)); + hwy::SetThreadName(buf, 0); // does not support varargs + + if (HWY_LIKELY(want_pin_)) { + if (HWY_UNLIKELY(!hwy::PinThreadToLogicalProcessor(lps[task]))) { + fprintf(stderr, + "Pinning failed for task %zu of %zu to %zu (size %zu)\n", + task, pool->NumWorkers(), lps[task], lps.size()); + (void)any_error_.test_and_set(); + } + } + }); } // Called ONCE after all MaybePin because it invalidates the error status. @@ -466,7 +472,8 @@ NestedPools::Package::Package(const BoundedTopology& topology, size_t pkg_idx, clusters_[cluster_idx] = MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster)); // Pin workers AND the calling thread from `all_clusters`. 
-      GetPinning().MaybePin(cluster, clusters_[cluster_idx]);
+      GetPinning().MaybePin(pkg_idx, cluster_idx, cluster,
+                            clusters_[cluster_idx]);
      });
 }
 
diff --git a/util/threading.h b/util/threading.h
index 6ac2859..c2db6ba 100644
--- a/util/threading.h
+++ b/util/threading.h
@@ -390,6 +390,25 @@ class IndexRangePartition {
                     TaskSize());
   }
 
+  template <class Func>
+  void VisitAll(const Func& func) const {
+    for (size_t task_idx = 0; task_idx < NumTasks(); ++task_idx) {
+      func(Range(task_idx));
+    }
+  }
+
+  template <class Func>
+  void VisitFirst(const Func& func) const {
+    func(Range(0));
+  }
+
+  template <class Func>
+  void VisitRemaining(const Func& func) const {
+    for (size_t task_idx = 1; task_idx < NumTasks(); ++task_idx) {
+      func(Range(task_idx));
+    }
+  }
+
  private:
   IndexRange range_;
   uint32_t task_size_;
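
Usage sketch (illustrative only, not part of the patch): one way a caller might combine
the new IndexRangePartition::Visit* helpers with the existing StaticPartition. The names
`rows`, `pool`, and `ProcessRows` are hypothetical placeholders for this example.

    // Partition [0, rows) across the pool's workers, respecting a multiple of 1.
    const IndexRangePartition partition =
        StaticPartition(IndexRange(0, rows), pool.NumWorkers(), /*multiple=*/1);
    // Handle the first sub-range specially, e.g. on the calling thread...
    partition.VisitFirst([&](const IndexRange& r) { ProcessRows(r); });
    // ...and the remaining sub-ranges separately,
    partition.VisitRemaining([&](const IndexRange& r) { ProcessRows(r); });
    // or simply visit every sub-range uniformly:
    partition.VisitAll([&](const IndexRange& r) { ProcessRows(r); });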