diff --git a/BUILD.bazel b/BUILD.bazel index ffd5435..38d79cf 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -95,7 +95,6 @@ cc_library( ":topology", # Placeholder for container detection, do not remove "@highway//:hwy", - "@highway//:profiler", "@highway//:thread_pool", "@highway//:topology", ], @@ -124,7 +123,9 @@ cc_library( srcs = ["util/zones.cc"], hdrs = ["util/zones.h"], deps = [ + "@highway//:hwy", "@highway//:profiler", + "@highway//:thread_pool", ], ) @@ -258,7 +259,6 @@ cc_library( "//io:fields", "@highway//:hwy", "@highway//:profiler", - "@highway//:thread_pool", ], ) @@ -278,7 +278,6 @@ cc_library( "//io:blob_store", "@highway//:hwy", "@highway//:profiler", - "@highway//:thread_pool", ], ) @@ -309,7 +308,6 @@ cc_library( deps = [ ":allocator", ":basics", - ":configs", ":mat", ":threading", ":threading_context", @@ -397,7 +395,6 @@ cc_library( "@highway//:hwy", "@highway//:math", "@highway//:profiler", - "@highway//:thread_pool", "@highway//hwy/contrib/sort:vqsort", ], ) @@ -424,7 +421,6 @@ cc_test( "@highway//:nanobenchmark", #buildcleaner: keep "@highway//:profiler", "@highway//:stats", - "@highway//:thread_pool", ], ) @@ -507,7 +503,6 @@ cc_test( "@highway//:hwy_test_util", "@highway//:nanobenchmark", "@highway//:profiler", - "@highway//:thread_pool", ], ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3eb2046..a707078 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee EXCLUDE_FROM_ALL) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579 EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(highway) ## Note: absl needs to be installed by sentencepiece. This will only happen if diff --git a/MODULE.bazel b/MODULE.bazel index e0ba1c7..861daba 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5") # Require a more recent version. 
git_override( module_name = "highway", - commit = "9781a1698ee0756ef1eaaf96930113ed7cb6d3ee", + commit = "2a16a50ff61071bb25ddef0ce35d92b0e2b9c579", remote = "https://github.com/google/highway", ) diff --git a/README.md b/README.md index 722c2a8..6294920 100644 --- a/README.md +++ b/README.md @@ -452,7 +452,7 @@ FetchContent_MakeAvailable(sentencepiece) FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main) FetchContent_MakeAvailable(gemma) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579) FetchContent_MakeAvailable(highway) ``` diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel index c7232e6..c04bd08 100644 --- a/compression/BUILD.bazel +++ b/compression/BUILD.bazel @@ -120,9 +120,9 @@ cc_library( ":compress", ":distortion", "//:mat", + "//:threading_context", "@highway//:hwy", "@highway//:hwy_test_util", - "@highway//:thread_pool", ], ) @@ -180,6 +180,7 @@ cc_library( ":sfp", "//:basics", "//:mat", + "//:threading_context", "@highway//:hwy", "@highway//:nanobenchmark", "@highway//:profiler", @@ -203,9 +204,9 @@ cc_test( ":test_util", "@googletest//:gtest_main", # buildcleaner: keep "//:test_util", + "//:threading_context", "@highway//:hwy", "@highway//:hwy_test_util", - "@highway//:thread_pool", ], ) diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 35f0433..10ce57c 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -26,10 +26,10 @@ #include "compression/compress.h" // IWYU pragma: export #include "compression/distortion.h" +#include "util/threading_context.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/timer.h" #if COMPRESS_STATS #include // lroundf @@ -493,13 +493,13 @@ struct CompressTraits { } }; -// Compresses `num` elements of `raw` to `packed` starting at `packed_ofs`, -// which is useful for compressing sub-regions of an array. +// DEPRECATED: Use the overload with ThreadingContext instead. template HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num, CompressWorkingSet& work, const PackedSpan& packed, - const size_t packed_ofs, hwy::ThreadPool& pool) { + const size_t packed_ofs, hwy::ThreadPool& pool, + hwy::pool::Caller caller = hwy::pool::Caller()) { packed.BoundsCheck(packed_ofs, num); work.tls.resize(pool.NumWorkers()); if constexpr (COMPRESS_STATS) { @@ -511,7 +511,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num, using Traits = CompressTraits>; constexpr size_t kBatch = 8192; const size_t num_batches = hwy::DivCeil(num, kBatch); - pool.Run(0, num_batches, + pool.Run(0, num_batches, caller, [&](const uint32_t idx_batch, size_t thread) HWY_ATTR { const hn::ScalableTag df; @@ -530,6 +530,17 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num, } } +// Compresses `num` elements of `raw` to `packed` starting at `packed_ofs`, +// which is useful for compressing sub-regions of an array. 
+template +HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num, + CompressWorkingSet& work, + const PackedSpan& packed, + const size_t packed_ofs, ThreadingContext& ctx) { + Compress(raw, num, work, packed, packed_ofs, ctx.pools.Pool(), + ctx.pool_callers.Get(Callers::kCompress)); +} + // Same as above, but without parallelization nor benchmarking. template HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num, diff --git a/compression/compress_test.cc b/compression/compress_test.cc index 2ee7f63..987f409 100644 --- a/compression/compress_test.cc +++ b/compression/compress_test.cc @@ -24,9 +24,9 @@ #include "compression/compress.h" #include "compression/distortion.h" #include "util/test_util.h" +#include "util/threading_context.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/tests/hwy_gtest.h" // clang-format off #undef HWY_TARGET_INCLUDE @@ -42,6 +42,17 @@ namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; +static ThreadingArgs SingleThreadArgs() { + ThreadingArgs args; + args.max_lps = 1; + return args; +} + +static ThreadingContext& Ctx() { + static ThreadingContext* ctx = new ThreadingContext(SingleThreadArgs()); + return *ctx; +} + // Calls Compress and Decompress2 and verifies the distortion/error. template struct TestDecompress2 { @@ -49,7 +60,9 @@ struct TestDecompress2 { HWY_INLINE void operator()(T /*unused*/, D d) { const size_t N = hn::Lanes(d); CompressWorkingSet work; - hwy::ThreadPool pool(0); + ThreadingArgs args; + args.max_lps = 1; + ThreadingContext ctx(args); hwy::RandomState rng; const size_t num = 2 * N; @@ -68,7 +81,7 @@ struct TestDecompress2 { // Short inputs fail VerifyGaussian. const size_t packed_ofs = 0; - Compress(raw.get(), num, work, packed_span, packed_ofs, pool); + Compress(raw.get(), num, work, packed_span, packed_ofs, ctx); hn::Vec raw0, raw1; Decompress2(d, MakeConst(packed_span), packed_ofs, raw0, raw1); hn::Store(raw0, d, dec.get()); @@ -129,7 +142,6 @@ struct TestShortLengths { HWY_INLINE void operator()(T /*unused*/, D d) { const size_t N = hn::Lanes(d); CompressWorkingSet work; - hwy::ThreadPool pool(0); hwy::RandomState rng; for (size_t num = 1; num < 5 * hn::Lanes(d); ++num) { @@ -149,7 +161,7 @@ struct TestShortLengths { // Short inputs fail VerifyGaussian. const size_t packed_ofs = 0; - Compress(raw.get(), num, work, packed_span, packed_ofs, pool); + Compress(raw.get(), num, work, packed_span, packed_ofs, Ctx()); DecompressAndZeroPad(d, MakeConst(packed_span), packed_ofs, dec.get(), num); diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc index 5f227ac..3568ad3 100644 --- a/compression/python/compression_clif_aux.cc +++ b/compression/python/compression_clif_aux.cc @@ -37,7 +37,6 @@ #include "util/basics.h" #include "util/mat.h" #include "util/threading_context.h" -#include "hwy/contrib/thread_pool/thread_pool.h" #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE \ @@ -57,9 +56,6 @@ class SbsWriterImpl : public ISbsWriter { template void InsertT(const char* name, F32Span weights, const TensorInfo& tensor_info) { - // TODO(janwas): 1D parallel-for. - hwy::ThreadPool& pool = ctx_.pools.Pool(); - MatPtrT mat(name, ExtentsFromInfo(&tensor_info)); // SFP and NUQ (which uses SFP for cluster centers) have a limited range // and depending on the input values may require rescaling. 
Scaling is @@ -82,21 +78,20 @@ class SbsWriterImpl : public ISbsWriter { // succeeds, but we only have 10 floats, not the full tensor. if (weights.size() == 10 && mat.Extents().Area() != 10) { Compress(weights.data(), weights.size(), working_set_, mat.Span(), - /*packed_ofs=*/0, pool); + /*packed_ofs=*/0, ctx_); writer_.Add(name, mat.Packed(), mat.ElementBytes() * 10); return; } HWY_ASSERT(weights.size() == mat.Extents().Area()); Compress(weights.data(), weights.size(), working_set_, mat.Span(), - /*packed_ofs=*/0, pool); + /*packed_ofs=*/0, ctx_); writer_.Add(name, mat.Packed(), mat.PackedBytes()); } public: SbsWriterImpl(const std::string& sbs_path) - : ctx_(ThreadingArgs()), - writer_(gcpp::Path(sbs_path), ctx_.pools.Pool()) {} + : ctx_(ThreadingArgs()), writer_(gcpp::Path(sbs_path), ctx_) {} void Insert(const char* name, F32Span weights, Type type, const TensorInfo& tensor_info) override { diff --git a/compression/test_util-inl.h b/compression/test_util-inl.h index 1c72b32..81442f7 100644 --- a/compression/test_util-inl.h +++ b/compression/test_util-inl.h @@ -23,7 +23,7 @@ // IWYU pragma: end_exports #include "compression/compress.h" -#include "hwy/contrib/thread_pool/thread_pool.h" +#include "util/threading_context.h" #endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_ @@ -98,25 +98,26 @@ void ForeachActivationType3(D d) { // Generates inputs: deterministic, within max SfpStream range. template -MatStorageT GenerateMat(const Extents2D& extents, - const Allocator& allocator, MatPadding padding, - hwy::ThreadPool& pool) { +MatStorageT GenerateMat(const Extents2D& extents, MatPadding padding, + ThreadingContext& ctx) { gcpp::CompressWorkingSet ws; - ws.tls.resize(pool.NumWorkers()); - MatStorageT raw("raw", extents, allocator, MatPadding::kPacked); - MatStorageT compressed("mat", extents, allocator, padding); + ws.tls.resize(ctx.pools.MaxWorkers()); + MatStorageT raw("raw", extents, ctx.allocator, MatPadding::kPacked); + MatStorageT compressed("mat", extents, ctx.allocator, padding); const float scale = SfpStream::kMax / extents.Area(); - pool.Run(0, extents.rows, [&](const size_t r, size_t thread) { - float* HWY_RESTRICT row = raw.Row(r); - for (size_t c = 0; c < extents.cols; c++) { - float f = static_cast(r * extents.cols + c) * scale; - if ((r + c) & 1) f = -f; // Also generate some negative values. - row[c] = f; - } - Compress(raw.Row(r), raw.Cols(), ws.tls[thread], - MakeSpan(compressed.Row(r), extents.cols), - /*packed_ofs=*/0); - }); + ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0, + Callers::kTest, [&](size_t r, size_t thread) { + float* HWY_RESTRICT row = raw.Row(r); + for (size_t c = 0; c < extents.cols; c++) { + float f = static_cast(r * extents.cols + c) * scale; + if ((r + c) & 1) + f = -f; // Also generate some negative values. + row[c] = f; + } + Compress(raw.Row(r), raw.Cols(), ws.tls[thread], + MakeSpan(compressed.Row(r), extents.cols), + /*packed_ofs=*/0); + }); compressed.SetScale(0.6f); // Arbitrary value, different from 1. return compressed; @@ -126,25 +127,26 @@ MatStorageT GenerateMat(const Extents2D& extents, // `f` swaps `r` and `c`. 
template MatStorageT GenerateTransposedMat(const Extents2D extents, -                                      const Allocator& allocator, MatPadding padding, -                                      hwy::ThreadPool& pool) { +                                      MatPadding padding, +                                      ThreadingContext& ctx) { gcpp::CompressWorkingSet ws; -  ws.tls.resize(pool.NumWorkers()); -  MatStorageT raw("raw", extents, allocator, MatPadding::kPacked); -  MatStorageT compressed("trans", extents, allocator, padding); +  ws.tls.resize(ctx.pools.MaxWorkers()); +  MatStorageT raw("raw", extents, ctx.allocator, MatPadding::kPacked); +  MatStorageT compressed("trans", extents, ctx.allocator, padding); const float scale = SfpStream::kMax / extents.Area(); -  pool.Run(0, extents.rows, [&](const size_t r, size_t thread) { -    float* HWY_RESTRICT row = raw.Row(r); -    for (size_t c = 0; c < extents.cols; c++) { -      float f = static_cast(c * extents.rows + r) * scale; -      if ((r + c) & 1) f = -f;  // Also generate some negative values. -      row[c] = f; -    } -    Compress(raw.Row(r), raw.Cols(), ws.tls[thread], -             MakeSpan(compressed.Row(r), extents.cols), -             /*packed_ofs=*/0); -  }); +  ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0, +              Callers::kTest, [&](size_t r, size_t thread) { +                float* HWY_RESTRICT row = raw.Row(r); +                for (size_t c = 0; c < extents.cols; c++) { +                  float f = static_cast(c * extents.rows + r) * scale; +                  if ((r + c) & 1) +                    f = -f;  // Also generate some negative values. +                  row[c] = f; +                } +                Compress(raw.Row(r), raw.Cols(), ws.tls[thread], +                         MakeSpan(compressed.Row(r), extents.cols), +                         /*packed_ofs=*/0); +              }); // Arbitrary value, different from 1, must match `GenerateMat`. compressed.SetScale(0.6f); diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc index 355f26d..5ff4e69 100644 --- a/evals/cross_entropy.cc +++ b/evals/cross_entropy.cc @@ -83,8 +83,8 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -void CallSoftmax(Logits logits, hwy::Profiler& p) { - Softmax(logits, p, hwy::Profiler::GlobalIdx()); +void CallSoftmax(Logits logits, ThreadingContext& ctx) { + Softmax(logits, ctx, hwy::Profiler::GlobalIdx()); } } // namespace HWY_NAMESPACE @@ -109,7 +109,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens, const SampleFunc sample_token = [&](size_t qi, size_t pos, Logits logits, size_t /*worker*/) -> TokenAndProb { // input is logits, not yet probabilities - HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx.profiler); + HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx); // We are called for each token, but pos starts at 1. Clamping // max_generated_tokens to prompt.size() should prevent overrun. 
HWY_ASSERT(pos < prompt.size()); diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt index 65541d8..1ff827e 100644 --- a/examples/hello_world/CMakeLists.txt +++ b/examples/hello_world/CMakeLists.txt @@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) include(FetchContent) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579) FetchContent_MakeAvailable(highway) FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9) FetchContent_MakeAvailable(sentencepiece) diff --git a/examples/simplified_gemma/CMakeLists.txt b/examples/simplified_gemma/CMakeLists.txt index 710f5ee..2fd4228 100644 --- a/examples/simplified_gemma/CMakeLists.txt +++ b/examples/simplified_gemma/CMakeLists.txt @@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) include(FetchContent) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 2a16a50ff61071bb25ddef0ce35d92b0e2b9c579) FetchContent_MakeAvailable(highway) FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) FetchContent_MakeAvailable(sentencepiece) diff --git a/gemma/attention.cc b/gemma/attention.cc index 1269e53..117b533 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -55,8 +55,8 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT q, const MatPtrT& k, float* HWY_RESTRICT att, - hwy::Profiler& p, const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenAttentionQDotK)); + ThreadingContext& ctx, const size_t worker) { + GCPP_ZONE(ctx, worker, Zones::kGenAttentionQDotK); if (HWY_LIKELY(last_pos < static_cast(div_seq_len.GetDivisor()))) { // Slightly faster: no wraparound. for (size_t pos = start_pos; pos <= last_pos; ++pos) { @@ -75,7 +75,7 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, void PositionalEncodingQK(float* qk, const size_t layer_idx, const LayerWeightsPtrs& layer, const AttentionActivations& activations, - hwy::Profiler& p, const size_t worker, + ThreadingContext& ctx, const size_t worker, const size_t pos, const float mul) { const size_t qkv_dim = layer.layer_config.qkv_dim; const PostQKType& post_qk = layer.layer_config.post_qk; @@ -88,10 +88,10 @@ void PositionalEncodingQK(float* qk, const size_t layer_idx, } // PostQKType::Rope if (post_qk == PostQKType::HalfRope) { - Rope(qk, qkv_dim / 2, inv_timescale, pos, p, worker); + Rope(qk, qkv_dim / 2, inv_timescale, pos, ctx, worker); if (mul != 1.0f) MulByConst(mul, qk, qkv_dim); } else { - RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, p, worker); + RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, ctx, worker); } } @@ -99,18 +99,16 @@ void PositionalEncodingQK(float* qk, const size_t layer_idx, // `att_out`. Equivalent in gemma/modules.py: // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj) // `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim]. 
-static HWY_INLINE void WeightedSumV(const size_t start_pos, - const size_t last_pos, - const hwy::Divisor& div_seq_len, - const float* HWY_RESTRICT att, - const MatPtrT& v, - float* HWY_RESTRICT att_out, - hwy::Profiler& p, const size_t worker) { +static HWY_INLINE void WeightedSumV( + const size_t start_pos, const size_t last_pos, + const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att, + const MatPtrT& v, float* HWY_RESTRICT att_out, ThreadingContext& ctx, + const size_t worker) { if (HWY_LIKELY(last_pos < static_cast(div_seq_len.GetDivisor()))) { // Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if // we supported non-transposed B. // TODO: 2..4x unroll - MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), p, + MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), ctx, worker); for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols()); @@ -118,7 +116,8 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos, } else { { const size_t pos_mod = div_seq_len.Remainder(start_pos); - MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p, worker); + MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), ctx, + worker); } for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { const size_t pos_mod = div_seq_len.Remainder(pos); @@ -134,7 +133,7 @@ void SingleDotSoftmaxWeightedSum( float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, const size_t layer_idx, const LayerWeightsPtrs& layer, const AttentionActivations& activations, float* HWY_RESTRICT att, - float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) { + float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) { const float att_cap = activations.config.att_cap; const float query_scale = activations.query_scale; const size_t seq_len = @@ -144,23 +143,23 @@ void SingleDotSoftmaxWeightedSum( if (layer.query_norm_scale.HasPtr()) { CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q, - layer.layer_config.qkv_dim, p, worker); + layer.layer_config.qkv_dim, ctx, worker); }); } - PositionalEncodingQK(q, layer_idx, layer, activations, p, worker, pos, + PositionalEncodingQK(q, layer_idx, layer, activations, ctx, worker, pos, query_scale); - QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, p, worker); + QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker); // SoftMax with optional SoftCap yields "probabilities" in att. 
const size_t att_len = HWY_MIN(last_pos + 1, seq_len); const Logits logits(att, att_len); - MaybeLogitsSoftCap(att_cap, logits, p, worker); - Softmax(logits, p, worker, /*temperature=*/1.0f); + MaybeLogitsSoftCap(att_cap, logits, ctx, worker); + Softmax(logits, ctx, worker, /*temperature=*/1.0f); - WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, p, - worker); + WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, + ctx, worker); } // The attention window usually starts at 0 unless `pos` is larger than @@ -174,12 +173,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, AttentionActivations& activations, QBatch& qbatch, ThreadingContext& ctx) { - static const auto root_zone = - ctx.profiler.AddZone("Gen.Attention.DotSoftmaxWeightedSumInclusive", - hwy::ProfilerFlags::kInclusive); - PROFILER_ZONE3(ctx.profiler, 0, root_zone); - const auto zone = - GetProfilerZone(Zones::kGenAttentionDotSoftmaxWeightedSumPar); + GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive); const hwy::Divisor div_qbatch(qbatch.Size()); const LayerConfig& layer_config = layer.layer_config; @@ -199,7 +193,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, const auto func = [&](const size_t task, size_t worker) HWY_ATTR { const size_t tq_idx = activations.div_heads.Divide(task); const size_t head = activations.div_heads.Remainder(task); - PROFILER_ZONE3(ctx.profiler, worker, zone); + GCPP_ZONE(ctx, worker, Zones::kGenAttentionDotSoftmaxWeightedSumPar); const size_t qi = div_qbatch.Remainder(tq_idx); const size_t batch_idx = div_qbatch.Divide(tq_idx); @@ -231,16 +225,15 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride()); SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx, - layer, activations, att, att_out, ctx.profiler, - worker); + layer, activations, att, att_out, ctx, worker); }; { PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin"); // Full parallelism is helpful, kAcrossClusters is insufficient. HierarchicalParallelFor( - num_tokens * div_qbatch.GetDivisor() * layer_config.heads, ctx.pools, - func); + num_tokens * div_qbatch.GetDivisor() * layer_config.heads, ctx, + Callers::kAttDotSoftmaxWeightedSum, func); } } @@ -256,9 +249,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, AttentionActivations& activations, const QBatch& qbatch, const int flags, MatMulEnv& env) { - static const auto zone = env.ctx.profiler.AddZone( - "Gen.Attention.ComputeQKV", hwy::ProfilerFlags::kInclusive); - PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone); + GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), + Zones::kGenAttentionComputeQKV); const hwy::Divisor div_qbatch(qbatch.Size()); const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor(); @@ -295,7 +287,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, // tasks are very lightweight. 
ParallelFor( ParallelismStrategy::kFlat, kv_heads * num_interleaved, env.ctx, - /*cluster_idx=*/0, [&](size_t task, size_t worker) HWY_ATTR { + /*cluster_idx=*/0, Callers::kAttComputeQKV, + [&](size_t task, size_t worker) HWY_ATTR { const size_t head = task % kv_heads; const size_t interleaved_idx = task / kv_heads; const size_t qi = div_qbatch.Remainder(interleaved_idx); @@ -316,12 +309,12 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, if (layer.key_norm_scale.HasPtr()) { CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) { RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, kv_f32, - qkv_dim, env.ctx.profiler, worker); + qkv_dim, env.ctx, worker); }); } - PositionalEncodingQK(kv_f32, layer_idx, layer, activations, - env.ctx.profiler, worker, pos, /*mul=*/1.0f); + PositionalEncodingQK(kv_f32, layer_idx, layer, activations, env.ctx, + worker, pos, /*mul=*/1.0f); CompressPerThread tls; Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0); }); @@ -332,9 +325,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer, AttentionActivations& activations, MatMulEnv& env) { - static const auto zone = env.ctx.profiler.AddZone( - "Gen.Attention.SumHeads", hwy::ProfilerFlags::kInclusive); - PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone); + GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads); const LayerConfig& layer_config = layer.layer_config; (void)layer_config; // For HWY_DASSERT // att_weights and att_out are concatenated heads, each of length @@ -352,9 +343,7 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, AttentionActivations& activations, QBatch& qbatch, MatMulEnv& env, int flags) { - static const auto zone = - env.ctx.profiler.AddZone("Gen.Attention", hwy::ProfilerFlags::kInclusive); - PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone); + GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttention); const LayerConfig& layer_config = layer.layer_config; HWY_DASSERT(!layer_config.IsMHA()); // No longer supported. diff --git a/gemma/attention.h b/gemma/attention.h index a0af4ff..6c4a48e 100644 --- a/gemma/attention.h +++ b/gemma/attention.h @@ -26,33 +26,33 @@ namespace gcpp { // Passed to HWY_VISIT_TARGETS; declares for one target. 
-#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \ - namespace NAMESPACE { \ - void PositionalEncodingQK(float* qk, size_t layer_idx, \ - const LayerWeightsPtrs& layer, \ - const AttentionActivations& activations, \ - hwy::Profiler& p, size_t worker, size_t pos, \ - float mul); \ - \ - size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx); \ - \ - void SingleDotSoftmaxWeightedSum( \ - const size_t pos, const size_t start_pos, const size_t last_pos, \ - float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, \ - size_t layer_idx, const LayerWeightsPtrs& layer, \ - const AttentionActivations& activations, float* HWY_RESTRICT att, \ - float* HWY_RESTRICT att_out, hwy::Profiler& p, size_t worker); \ - \ - void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \ - const LayerWeightsPtrs& layer, \ - AttentionActivations& activations, \ - QBatch& qbatch, ThreadingContext& ctx); \ - \ - void GemmaAttention(size_t num_tokens, const size_t layer_idx, \ - const LayerWeightsPtrs& layer, \ - AttentionActivations& activations, QBatch& qbatch, \ - MatMulEnv& env, int flags); \ - /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ +#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \ + namespace NAMESPACE { \ + void PositionalEncodingQK(float* qk, size_t layer_idx, \ + const LayerWeightsPtrs& layer, \ + const AttentionActivations& activations, \ + ThreadingContext& ctx, size_t worker, size_t pos, \ + float mul); \ + \ + size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx); \ + \ + void SingleDotSoftmaxWeightedSum( \ + const size_t pos, const size_t start_pos, const size_t last_pos, \ + float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, \ + size_t layer_idx, const LayerWeightsPtrs& layer, \ + const AttentionActivations& activations, float* HWY_RESTRICT att, \ + float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker); \ + \ + void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \ + const LayerWeightsPtrs& layer, \ + AttentionActivations& activations, \ + QBatch& qbatch, ThreadingContext& ctx); \ + \ + void GemmaAttention(size_t num_tokens, const size_t layer_idx, \ + const LayerWeightsPtrs& layer, \ + AttentionActivations& activations, QBatch& qbatch, \ + MatMulEnv& env, int flags); \ + /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE // Function declarations for each SIMD target. Allows direct call from the diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index c6a2fba..bf3aede 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -61,13 +61,12 @@ static constexpr size_t kNFx8HTileSize = 8; // possible consecutive elements have the same KV. static void TransposeQ(const MatPtrT& q, MatPtrT& q_t, const size_t qbatch_size, ThreadingContext& ctx) { - const auto zone = GetProfilerZone(Zones::kFlashAttentionTransposeQ); // Group floats by the number of floats in a cache line. 
const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float); const size_t num_heads = q.Cols() / q_t.Rows(); const size_t batch_size = q.Rows() / qbatch_size; const auto func = [&](const size_t task, size_t worker) HWY_ATTR { - PROFILER_ZONE3(ctx.profiler, worker, zone); + GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTransposeQ); for (size_t lane = 0; lane < kNF; ++lane) { size_t q_row = task * kNF + lane; if (q_row >= q_t.Rows()) break; @@ -83,10 +82,10 @@ static void TransposeQ(const MatPtrT& q, MatPtrT& q_t, } }; { + const size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF); // Better than kFlat. - size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF); ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx, - /*cluster_idx=*/0, func); + /*cluster_idx=*/0, Callers::kFlashTransposeQ, func); } } @@ -96,12 +95,10 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, const LayerWeightsPtrs& layer, const AttentionActivations& activations, ThreadingContext& ctx) { - const auto zone = - GetProfilerZone(Zones::kFlashAttentionRmsNormAndPositionalEncoding); const float query_scale = activations.query_scale; const hwy::Divisor div_qbatch(qbatch.Size()); const auto func = [&](const size_t task, size_t worker) HWY_ATTR { - PROFILER_ZONE3(ctx.profiler, worker, zone); + GCPP_ZONE(ctx, worker, Zones::kFlashAttentionRmsNormAndPositionalEncoding); size_t qi = div_qbatch.Remainder(task); size_t batch_idx = div_qbatch.Divide(task); for (size_t h = 0; h < layer.layer_config.heads; ++h) { @@ -115,18 +112,19 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, if (layer.query_norm_scale.HasPtr()) { CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q_row, - layer.layer_config.qkv_dim, ctx.profiler, worker); + layer.layer_config.qkv_dim, ctx, worker); }); } - PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler, - worker, pos, query_scale); + PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx, worker, + pos, query_scale); } }; { // kHierarchical is not worth the extra sync overhead because the tasks are // very lightweight. ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx, - /*cluster_idx=*/0, func); + /*cluster_idx=*/0, Callers::kFlashRMSNormAndPositionalEncoding, + func); } } @@ -158,10 +156,9 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos, const MatPtrT& v, const size_t layer_idx, const LayerWeightsPtrs& layer, const AttentionActivations& activations, - float* HWY_RESTRICT att_out, hwy::Profiler& p, + float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) { - PROFILER_ZONE3(p, worker, - GetProfilerZone(Zones::kFlashAttentionSingleFlashAttention)); + GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention); const size_t pos_mod = activations.div_seq_len.Remainder(start_pos); float m = Dot(q, k.Row(pos_mod), k.Cols()); if (float cap = activations.config.att_cap; cap > 0.0f) { @@ -170,7 +167,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos, } float d = 1.0f; // This is just a copy of the first token. 
- MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), p, worker); + MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker); for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { const size_t pos_mod = activations.div_seq_len.Remainder(pos); float x = Dot(q, k.Row(pos_mod), k.Cols()); @@ -276,9 +273,8 @@ void TileFlashAttention( const MatPtrT& v, const size_t layer_idx, const LayerWeightsPtrs& layer, const AttentionActivations& activations, MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, - hwy::Profiler& p, const size_t worker) { - PROFILER_ZONE3(p, worker, - GetProfilerZone(Zones::kFlashAttentionTileFlashAttention)); + ThreadingContext& ctx, const size_t worker) { + GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention); constexpr int kHTileSize = kNFx8HTileSize; using DF = hn::ScalableTag; const DF df; @@ -430,9 +426,8 @@ void TileFlashAttention4( const size_t max_last_pos, const MatPtrT& v, const size_t layer_idx, const LayerWeightsPtrs& layer, const AttentionActivations& activations, MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, - hwy::Profiler& p, const size_t worker) { - PROFILER_ZONE3(p, worker, - GetProfilerZone(Zones::kFlashAttentionTileFlashAttention4)); + ThreadingContext& ctx, const size_t worker) { + GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4); using DF = hn::ScalableTag; const DF df; using VF = hn::Vec; @@ -597,10 +592,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, const size_t layer_idx, const LayerWeightsPtrs& layer, AttentionActivations& activations, QBatch& qbatch, ThreadingContext& ctx) { - static const auto root_zone = ctx.profiler.AddZone( - "FlashAttention.Inclusive", hwy::ProfilerFlags::kInclusive); - PROFILER_ZONE3(ctx.profiler, 0, root_zone); - const auto zone = GetProfilerZone(Zones::kFlashAttentionFlashAttention); + GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive); RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx, layer, activations, ctx); const hwy::Divisor div_qbatch(qbatch.Size()); @@ -653,7 +645,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, // For each head/token/query, compute fused flash Q.K, softmax and weighted V. const auto func = [&](const size_t task, size_t worker) HWY_ATTR { - PROFILER_ZONE3(ctx.profiler, worker, zone); + GCPP_ZONE(ctx, worker, Zones::kFlashAttentionFlashAttention); // Offsets into original Q for each row in the tile. uint32_t q_offsets[kMaxNF]; // Offsets into att_out for each row in the tile. 
@@ -741,13 +733,12 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, TileFlashAttention(activations.q, q_offsets, qT, k, start_positions[offset], last_pos, min_last_pos, max_last_pos, v, layer_idx, layer, activations, - activations.att_out, out_offsets, ctx.profiler, - worker); + activations.att_out, out_offsets, ctx, worker); } else if (kVTileSize == 4) { - TileFlashAttention4( - activations.q, q_offsets, k, start_positions[offset], last_pos, - min_last_pos, max_last_pos, v, layer_idx, layer, activations, - activations.att_out, out_offsets, ctx.profiler, worker); + TileFlashAttention4(activations.q, q_offsets, k, + start_positions[offset], last_pos, min_last_pos, + max_last_pos, v, layer_idx, layer, activations, + activations.att_out, out_offsets, ctx, worker); } else { HWY_UNREACHABLE; } @@ -757,7 +748,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, activations.q.Row(0) + q_offsets[offset], k, v, layer_idx, layer, activations, activations.att_out.Row(0) + out_offsets[offset], - ctx.profiler, worker); + ctx, worker); } } }; @@ -765,7 +756,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, { PROFILER_ZONE("Gen.FlashAttention.ForkJoin"); // Full parallelism is helpful, SmallParallelFor is insufficient. - HierarchicalParallelFor(num_thread_tasks, ctx.pools, func); + HierarchicalParallelFor(num_thread_tasks, ctx, Callers::kFlashAttention, + func); } } diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h index 8aa787b..959b227 100644 --- a/gemma/flash_attention.h +++ b/gemma/flash_attention.h @@ -39,8 +39,8 @@ namespace gcpp { const MatPtrT& k, const MatPtrT& v, \ size_t layer_idx, const LayerWeightsPtrs& layer, \ const AttentionActivations& activations, \ - float* HWY_RESTRICT att_out, hwy::Profiler& p, \ - size_t worker); \ + float* HWY_RESTRICT att_out, \ + ThreadingContext& ctx, size_t worker); \ \ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \ size_t total_tasks, size_t target_parallelism); \ diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 0034f3f..dc7efea 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -47,9 +47,9 @@ namespace HWY_NAMESPACE { // For use by Vit even if !GEMMA_FUSED_FFN. template void Activation(ActivationType activation, T1* HWY_RESTRICT c1, - const T2* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p, - const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenActivation)); + const T2* HWY_RESTRICT c2, const size_t count, + ThreadingContext& ctx, const size_t worker) { + GCPP_ZONE(ctx, worker, Zones::kGenActivation); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; @@ -73,11 +73,11 @@ void ActivationBatched( ParallelismStrategy parallelism = ParallelismStrategy::kFlat) { using T = typename Mat::T; ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, - [&](uint64_t task, size_t worker) { + Callers::kActivationBatched, [&](uint64_t task, size_t worker) { // Cast to correct type so type deduction works. 
Activation(activation, c1.Row(task), - static_cast(nullptr), c1.Cols(), - ctx.profiler, worker); + static_cast(nullptr), c1.Cols(), ctx, + worker); }); } @@ -87,8 +87,8 @@ void ActivationBatched( static inline void Activation(ActivationType activation, const RowPtrsBF C1, const IndexRange range_r, const IndexRange range_c, const StridedViewBF C2, - hwy::Profiler& p, const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenActivationFused)); + ThreadingContext& ctx, const size_t worker) { + GCPP_ZONE(ctx, worker, Zones::kGenActivationFused); const size_t cols = range_c.Num(); HWY_DASSERT(C2.Cols() == cols); @@ -119,16 +119,16 @@ HWY_NOINLINE void ActivationBatched( HWY_DASSERT(c1.SameShape(*c2)); if (c2 && c2->HasPtr()) { ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, - [&](uint64_t task, size_t worker) { + Callers::kActivationBatched, [&](uint64_t task, size_t worker) { Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(), - ctx.profiler, worker); + ctx, worker); }); } else { // No multiplier ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, - [&](uint64_t task, size_t worker) { + Callers::kActivationBatched, [&](uint64_t task, size_t worker) { Activation(activation, c1.Row(task), static_cast(nullptr), - c1.Cols(), ctx.profiler, worker); + c1.Cols(), ctx, worker); }); } } @@ -153,9 +153,7 @@ void PostNorm(PostNormType post_norm, const MatPtr& weights, static inline void FFWNoVit(const LayerWeightsPtrs& layer, Activations& activations, MatMulEnv& env) { - static const auto zone = - env.ctx.profiler.AddZone("Gen.FFW", hwy::ProfilerFlags::kInclusive); - PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone); + GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenFFW); const LayerConfig& layer_config = layer.layer_config; HWY_DASSERT(!layer_config.ff_biases); // Only used in Vit. @@ -163,8 +161,8 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer, #if GEMMA_FUSED_FFN const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c, StridedViewBF C2, size_t worker) { - Activation(layer_config.activation, C1, range_r, range_c, C2, - env.ctx.profiler, worker); + Activation(layer_config.activation, C1, range_r, range_c, C2, env.ctx, + worker); }; MMOptions options; options.SetFunc(fused); diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 80bf9e2..7991c35 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -55,7 +55,7 @@ #include "io/io.h" // Path #include "ops/matmul.h" #include "paligemma/image.h" -#include "util/basics.h" // PROFILER_ZONE3 +#include "util/basics.h" #include "util/threading_context.h" #include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" @@ -138,9 +138,7 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt, MatStorageT& x, ThreadingContext& ctx, const ImageTokens* image_tokens = nullptr, size_t image_token_position = 0) { - static const auto zone = - ctx.profiler.AddZone("Gen.Embed", hwy::ProfilerFlags::kInclusive); - PROFILER_ZONE3(ctx.profiler, hwy::Profiler::GlobalIdx(), zone); + GCPP_ZONE(ctx, hwy::Profiler::GlobalIdx(), Zones::kGenEmbed); // Image tokens just need to be copied. 
if (model_config.wrapping == PromptWrapping::GEMMA_VLM && @@ -415,9 +413,7 @@ static void SampleAndStream(const ModelConfig& config, MaybeObserve(runtime_config, activations, qbatch, -1); { - static const auto zone = env.ctx.profiler.AddZone( - "Gen.EmbeddingMatmul", hwy::ProfilerFlags::kInclusive); - PROFILER_ZONE3(env.ctx.profiler, /*worker=*/0, zone); + GCPP_ZONE(env.ctx, /*worker=*/0, Zones::kGenEmbeddingMatmul); // Compute logits from last layer activations. CallMatMul(activations.x_bf, weights.embedder_input_embedding, /*add=*/nullptr, env, activations.logits); @@ -431,7 +427,8 @@ static void SampleAndStream(const ModelConfig& config, ParallelFor( ParallelismStrategy::kFlat, qbatch.Size(), env.ctx, - /*cluster_idx=*/0, [&](size_t qi, size_t worker) { + /*cluster_idx=*/0, Callers::kSampleAndStream, + [&](size_t qi, size_t worker) { if (!non_eos.Get(qi)) return; // We streamed all prefill tokens, but pos is still one behind @@ -469,8 +466,7 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config, if (runtime_config.top_k == 1 && !runtime_config.accept_token) { return [&](size_t /*qi*/, size_t /*pos*/, Logits logits, size_t worker) HWY_ATTR -> TokenAndProb { - PROFILER_ZONE3(ctx.profiler, worker, - GetProfilerZone(Zones::kGenSampleTop1)); + GCPP_ZONE(ctx, worker, Zones::kGenSampleTop1); return Top1OfSoftmax(logits); }; } @@ -478,14 +474,13 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config, // General case: Softmax with top-k sampling. return [&](size_t qi, size_t pos, Logits logits, size_t worker) HWY_ATTR -> TokenAndProb { - PROFILER_ZONE3(ctx.profiler, worker, - GetProfilerZone(Zones::kGenSampleTopK)); + GCPP_ZONE(ctx, worker, Zones::kGenSampleTopK); // We want a different sequence for each batch element and position. const uint64_t stream = (static_cast(qi) << 32) | pos; RngStream gen(engine, stream); - return FusedSoftmaxAndSampleTopK( - logits, runtime_config.top_k, gen, runtime_config.temperature, - runtime_config.accept_token, ctx.profiler, worker); + return FusedSoftmaxAndSampleTopK(logits, runtime_config.top_k, gen, + runtime_config.temperature, + runtime_config.accept_token, ctx, worker); }; } @@ -657,8 +652,8 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference, Gemma::~Gemma() = default; -void Gemma::Save(const Path& weights_path, NestedPools& pools) const { - BlobWriter writer(weights_path, pools.Pool()); +void Gemma::Save(const Path& weights_path, ThreadingContext& ctx) const { + BlobWriter writer(weights_path, ctx); const std::vector serialized_mat_ptrs = weights_.AddTensorDataToWriter(writer); WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs, diff --git a/gemma/gemma.h b/gemma/gemma.h index 5e40bda..771cd1c 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -246,7 +246,7 @@ class Gemma { const GemmaChatTemplate& ChatTemplate() const { return chat_template_; } const InferenceArgs& Inference() const { return inference_; } - void Save(const Path& weights_path, NestedPools& pools) const; + void Save(const Path& weights_path, ThreadingContext& ctx) const; // `pos` is the position in the KV cache. Users are responsible for // incrementing it in the `*StreamFunc`, or setting to zero for single-turn. 
diff --git a/gemma/model_store.h b/gemma/model_store.h index 42af343..b4d63ad 100644 --- a/gemma/model_store.h +++ b/gemma/model_store.h @@ -36,7 +36,6 @@ // IWYU pragma: end_exports #include "util/allocator.h" -#include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { diff --git a/gemma/vit.cc b/gemma/vit.cc index abe0a37..b00efda 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -90,39 +90,43 @@ class VitAttention { ZeroInit(activations_.attention.att_out); for (size_t head = 0; head < heads; ++head) { - pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR { - const size_t token = task; - float* HWY_RESTRICT q = - activations_.attention.q.Row(token) + head * 3 * qkv_dim; - // TODO: shift to MatMul with A.scale once MatMul is confirmed working - MulByConst(query_scale, q, qkv_dim); - hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float)); - }); + pool_.Run(0, num_tokens_, caller1_, + [&](uint64_t task, size_t worker) HWY_ATTR { + const size_t token = task; + float* HWY_RESTRICT q = + activations_.attention.q.Row(token) + head * 3 * qkv_dim; + // TODO: shift to MatMul with A.scale once MatMul is confirmed + // working + MulByConst(query_scale, q, qkv_dim); + hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float)); + }); - pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { - const size_t seq_idx = task; - float* HWY_RESTRICT k = activations_.attention.q.Row(seq_idx) + - head * 3 * qkv_dim + qkv_dim; - hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float)); - }); + pool_.Run( + 0, seq_len, caller2_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { + const size_t seq_idx = task; + float* HWY_RESTRICT k = activations_.attention.q.Row(seq_idx) + + head * 3 * qkv_dim + qkv_dim; + hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float)); + }); // this produces C, a (num_tokens_, seq_len) matrix of dot products CallMatMul(Q, K, nullptr, env_, C); - pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR { - Softmax(C.RowSpan(task), env_.ctx.profiler, worker); - }); + pool_.Run(0, num_tokens_, caller3_, + [&](uint64_t task, size_t worker) + HWY_ATTR { Softmax(C.RowSpan(task), env_.ctx, worker); }); - pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR { - size_t token = task; - float* HWY_RESTRICT att_out = - activations_.attention.att_out.Row(token) + head * qkv_dim; - for (size_t i = 0; i < seq_len; ++i) { - float* HWY_RESTRICT v = activations_.attention.q.Row(i) + - head * 3 * qkv_dim + 2 * qkv_dim; - MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim); - } - }); + pool_.Run( + 0, num_tokens_, caller4_, [&](uint64_t task, size_t worker) HWY_ATTR { + size_t token = task; + float* HWY_RESTRICT att_out = + activations_.attention.att_out.Row(token) + head * qkv_dim; + for (size_t i = 0; i < seq_len; ++i) { + float* HWY_RESTRICT v = activations_.attention.q.Row(i) + + head * 3 * qkv_dim + 2 * qkv_dim; + MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim); + } + }); } } @@ -136,7 +140,7 @@ class VitAttention { PROFILER_ZONE("Gen.VitAttention.DotSoftmax"); // Compute Q.K, softmax, and weighted V. - pool_.Run(0, layer_config_.heads * num_tokens_, + pool_.Run(0, layer_config_.heads * num_tokens_, caller1_, [&](uint64_t task, size_t worker) HWY_ATTR { const size_t head = task % layer_config_.heads; const size_t token = task / layer_config_.heads; @@ -152,7 +156,7 @@ class VitAttention { head_att[i] = Dot(q, k, qkv_dim); // score = q.k } // SoftMax yields "probabilities" in head_att. 
- Softmax(Logits(head_att, seq_len), env_.ctx.profiler, worker); + Softmax(Logits(head_att, seq_len), env_.ctx, worker); // Compute weighted sum of v into att_out. float* HWY_RESTRICT att_out = activations_.attention.att_out.Row(token) + head * qkv_dim; @@ -185,7 +189,11 @@ class VitAttention { layer_(layer), layer_config_(layer.layer_config), env_(env), - pool_(env_.ctx.pools.Pool(0)) {} + pool_(env_.ctx.pools.Pool(0)), + caller1_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax1)), + caller2_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax2)), + caller3_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax3)), + caller4_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax4)) {} HWY_INLINE void operator()() { ComputeQKV(); @@ -204,6 +212,10 @@ class VitAttention { const LayerConfig& layer_config_; MatMulEnv& env_; hwy::ThreadPool& pool_; + hwy::pool::Caller caller1_; + hwy::pool::Caller caller2_; + hwy::pool::Caller caller3_; + hwy::pool::Caller caller4_; }; // Same as FFWNoVit, but with different layer members and no second @@ -333,7 +345,7 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights, // Apply soft embedding norm before input projection. CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) { RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, - activations.x.Row(0), vit_model_dim, env.ctx.profiler, + activations.x.Row(0), vit_model_dim, env.ctx, hwy::Profiler::GlobalIdx()); }); } diff --git a/gemma/weights.cc b/gemma/weights.cc index d871c6f..e1e01bf 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -34,7 +34,6 @@ #include "util/threading_context.h" #include "util/zones.h" #include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" #include "hwy/profiler.h" @@ -150,7 +149,7 @@ void LayerWeightsPtrs::SplitAttW1() { static void HWY_MAYBE_UNUSED InitAttWeightsI8( const LayerConfig& layer_config, MatPtrT& attn_vec_einsum_w, MatPtrT& att_weights, std::vector& mat_owners, - const Allocator& allocator) { + ThreadingContext& ctx) { if (!attn_vec_einsum_w.HasPtr()) return; HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kI8); @@ -160,7 +159,8 @@ static void HWY_MAYBE_UNUSED InitAttWeightsI8( static std::mutex m; std::lock_guard lock(m); mat_owners.emplace_back(); - mat_owners.back().AllocateFor(att_weights, allocator, MatPadding::kPacked); + mat_owners.back().AllocateFor(att_weights, ctx.allocator, + MatPadding::kPacked); } const size_t model_dim = layer_config.model_dim; @@ -188,10 +188,9 @@ static void HWY_MAYBE_UNUSED InitAttWeightsI8( } CompressWorkingSet work; - hwy::ThreadPool pool(0); HWY_NAMESPACE::Compress(att_weights_tmp.get(), model_dim * heads * qkv_dim, work, att_weights.Span(), - /*packed_ofs=*/0, pool); + /*packed_ofs=*/0, ctx); att_weights.SetScale(attn_vec_einsum_w.Scale()); } @@ -201,7 +200,7 @@ static void HWY_MAYBE_UNUSED SplitW1I8(const LayerConfig& layer_config, MatPtrT& gating_einsum_w1, MatPtrT& gating_einsum_w2, std::vector& mat_owners, - const Allocator& allocator) { + ThreadingContext& ctx) { // Files have both or neither of w1 and w2. HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr()); // w is mutually exclusive with w1 and w2 in the file. 
@@ -228,10 +227,10 @@ static void HWY_MAYBE_UNUSED SplitW1I8(const LayerConfig& layer_config, static std::mutex m; std::lock_guard lock(m); mat_owners.emplace_back(); - mat_owners.back().AllocateFor(gating_einsum_w1, allocator, + mat_owners.back().AllocateFor(gating_einsum_w1, ctx.allocator, MatPadding::kPacked); mat_owners.emplace_back(); - mat_owners.back().AllocateFor(gating_einsum_w2, allocator, + mat_owners.back().AllocateFor(gating_einsum_w2, ctx.allocator, MatPadding::kPacked); } @@ -248,11 +247,10 @@ static void HWY_MAYBE_UNUSED SplitW1I8(const LayerConfig& layer_config, float* w2_tmp = w_tmp.get() + split_size; CompressWorkingSet work; - hwy::ThreadPool pool(0); HWY_NAMESPACE::Compress(w1_tmp, split_size, work, gating_einsum_w1.Span(), 0, - pool); + ctx); HWY_NAMESPACE::Compress(w2_tmp, split_size, work, gating_einsum_w2.Span(), 0, - pool); + ctx); gating_einsum_w1.SetScale(1.0f); gating_einsum_w2.SetScale(1.0f); @@ -265,7 +263,7 @@ static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config, MatPtrT& qkv_einsum_w1, MatPtrT& qkv_einsum_w2, std::vector& mat_owners, - const Allocator& allocator) { + ThreadingContext& ctx) { // w is mutually exclusive with w1 in the file. HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr()); // Done if we already read split tensors. @@ -291,10 +289,10 @@ static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config, static std::mutex m; std::lock_guard lock(m); mat_owners.emplace_back(); - mat_owners.back().AllocateFor(qkv_einsum_w1, allocator, + mat_owners.back().AllocateFor(qkv_einsum_w1, ctx.allocator, MatPadding::kPacked); mat_owners.emplace_back(); - mat_owners.back().AllocateFor(qkv_einsum_w2, allocator, + mat_owners.back().AllocateFor(qkv_einsum_w2, ctx.allocator, MatPadding::kPacked); } @@ -312,9 +310,8 @@ static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config, float* w2_tmp = w_tmp.get() + w1_size; CompressWorkingSet work; - hwy::ThreadPool pool(0); - HWY_NAMESPACE::Compress(w1_tmp, w1_size, work, qkv_einsum_w1.Span(), 0, pool); - HWY_NAMESPACE::Compress(w2_tmp, w2_size, work, qkv_einsum_w2.Span(), 0, pool); + HWY_NAMESPACE::Compress(w1_tmp, w1_size, work, qkv_einsum_w1.Span(), 0, ctx); + HWY_NAMESPACE::Compress(w2_tmp, w2_size, work, qkv_einsum_w2.Span(), 0, ctx); qkv_einsum_w1.SetScale(1.0f); qkv_einsum_w2.SetScale(1.0f); @@ -326,16 +323,16 @@ static void HWY_MAYBE_UNUSED SplitAttW1I8(const LayerConfig& layer_config, // TODO: exporters should bake this into the weights already. // WARNING: called from multiple threads; `mat_owners` requires a lock. 
void LayerWeightsPtrs::Fixup(std::vector& mat_owners, - const Allocator& allocator) { + ThreadingContext& ctx) { if (attn_vec_einsum_w.GetType() == Type::kI8) { MatPtrT attn_vec_einsum_w_i8(attn_vec_einsum_w); MatPtrT att_weights_i8(att_weights); InitAttWeightsI8(layer_config, attn_vec_einsum_w_i8, att_weights_i8, - mat_owners, allocator); + mat_owners, ctx); attn_vec_einsum_w = attn_vec_einsum_w_i8; att_weights = att_weights_i8; } else { - InitAttWeights(mat_owners, allocator); + InitAttWeights(mat_owners, ctx.allocator); } if (gating_einsum_w.GetType() == Type::kI8) { @@ -343,7 +340,7 @@ void LayerWeightsPtrs::Fixup(std::vector& mat_owners, MatPtrT gating_einsum_w1_i8(gating_einsum_w1); MatPtrT gating_einsum_w2_i8(gating_einsum_w2); SplitW1I8(layer_config, gating_einsum_w_i8, gating_einsum_w1_i8, - gating_einsum_w2_i8, mat_owners, allocator); + gating_einsum_w2_i8, mat_owners, ctx); gating_einsum_w = gating_einsum_w_i8; gating_einsum_w1 = gating_einsum_w1_i8; gating_einsum_w2 = gating_einsum_w2_i8; @@ -356,7 +353,7 @@ void LayerWeightsPtrs::Fixup(std::vector& mat_owners, MatPtrT qkv_einsum_w1_i8(qkv_einsum_w1); MatPtrT qkv_einsum_w2_i8(qkv_einsum_w2); SplitAttW1I8(layer_config, qkv_einsum_w_i8, qkv_einsum_w1_i8, - qkv_einsum_w2_i8, mat_owners, allocator); + qkv_einsum_w2_i8, mat_owners, ctx); qkv_einsum_w = qkv_einsum_w_i8; qkv_einsum_w1 = qkv_einsum_w1_i8; qkv_einsum_w2 = qkv_einsum_w2_i8; @@ -367,7 +364,8 @@ void LayerWeightsPtrs::Fixup(std::vector& mat_owners, static void HWY_MAYBE_UNUSED InitAttWeightsNUQ( const LayerConfig& layer_config, MatPtrT& attn_vec_einsum_w, - MatPtrT& att_weights, std::vector& mat_owners) { + MatPtrT& att_weights, std::vector& mat_owners, + ThreadingContext& ctx) { if (!attn_vec_einsum_w.HasPtr()) return; HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kNUQ); @@ -399,10 +397,9 @@ static void HWY_MAYBE_UNUSED InitAttWeightsNUQ( } CompressWorkingSet work; - hwy::ThreadPool pool(0); HWY_NAMESPACE::Compress(att_weights_tmp.get(), model_dim * heads * qkv_dim, work, att_weights.Span(), - /*packed_ofs=*/0, pool); + /*packed_ofs=*/0, ctx); att_weights.SetScale(attn_vec_einsum_w.Scale()); } @@ -435,13 +432,13 @@ void WeightsPtrs::Fixup(std::vector& mat_owners, ThreadingContext& ctx) { const size_t cluster_idx = 0; ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx, - [&](uint64_t layer, size_t /*worker*/) { - GetLayer(layer)->Fixup(mat_owners, ctx.allocator); + Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) { + GetLayer(layer)->Fixup(mat_owners, ctx); }); ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx, - [&](uint64_t layer, size_t /*worker*/) { - VitLayer(layer)->Fixup(mat_owners, ctx.allocator); + Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) { + VitLayer(layer)->Fixup(mat_owners, ctx); }); } @@ -529,8 +526,9 @@ static void AllocateAndBindAll(std::vector& tensors, owners.resize(start + tensors.size()); // Allocate in parallel because faulting in large tensors is slow. 
- ctx.pools.Pool().Run( - 0, tensors.size(), [&](uint64_t task, size_t /*thread*/) { + ParallelFor( + ParallelismStrategy::kFlat, tensors.size(), ctx, /*cluster_idx=*/0, + Callers::kAllocateAndBindAll, [&](uint64_t task, size_t /*thread*/) { TensorToRead& tensor = tensors[task]; MatPtr& mat = *tensor.mat; @@ -587,14 +585,13 @@ static void DecompressToBF16(MatPtr& mat, static void ReadAllToBF16(const std::vector& tensors, const BlobReader& reader, ThreadingContext& ctx) { - const auto zone = GetProfilerZone(Zones::kStartupWeightsReadAllToBF16); // Especially TSAN is slow enough to warrant hierarchical parallelism. const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD ? ParallelismStrategy::kHierarchical : ParallelismStrategy::kFlat; ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0, - [&](uint64_t task, size_t thread) { - PROFILER_ZONE3(ctx.profiler, thread, zone); + Callers::kReadAllToBF16, [&](uint64_t task, size_t thread) { + GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadAllToBF16); const TensorToRead& tensor = tensors[task]; MatPtr& mat = *tensor.mat; @@ -679,12 +676,11 @@ static std::vector MakeBatches( static void ReadBatches(const BlobReader& reader, const std::vector& batches, ThreadingContext& ctx) { - const auto zone = GetProfilerZone(Zones::kStartupWeightsReadBatches); // >5x speedup from parallel reads when cached. - ParallelFor(ParallelismStrategy::kHierarchical, - batches.size(), ctx, /*cluster_idx=*/0, + ParallelFor(ParallelismStrategy::kHierarchical, batches.size(), ctx, + /*cluster_idx=*/0, Callers::kReadBatches, [&](uint64_t task, size_t thread) { - PROFILER_ZONE3(ctx.profiler, thread, zone); + GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadBatches); const IOBatch& batch = batches[task]; const std::string& key = reader.Keys()[batch.KeyIdx()]; const uint64_t bytes_read = batch.Read(reader.file()); diff --git a/gemma/weights.h b/gemma/weights.h index 06c0186..3661869 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -254,7 +254,7 @@ struct LayerWeightsPtrs { // Must be called after reading weights via `ForEachTensor`. // TODO: exporters should bake this into the weights already. // WARNING: called from multiple threads; `mat_owners` requires a lock. - void Fixup(std::vector& mat_owners, const Allocator& allocator); + void Fixup(std::vector& mat_owners, ThreadingContext& ctx); private: // Copies att_weights from `attn_vec_einsum_w`. 
diff --git a/io/BUILD.bazel b/io/BUILD.bazel index ac925aa..444115a 100644 --- a/io/BUILD.bazel +++ b/io/BUILD.bazel @@ -79,7 +79,6 @@ cc_library( "//:threading_context", "@highway//:hwy", "@highway//:profiler", - "@highway//:thread_pool", ], ) @@ -108,7 +107,6 @@ cc_binary( "@highway//:hwy", "@highway//:hwy_test_util", "@highway//:nanobenchmark", - "@highway//:thread_pool", ], ) diff --git a/io/blob_compare.cc b/io/blob_compare.cc index 998036e..30a2199 100644 --- a/io/blob_compare.cc +++ b/io/blob_compare.cc @@ -107,7 +107,7 @@ void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs, HWY_ASSERT(reader.Keys().size() == blobs.size()); HWY_ASSERT(ranges.size() == blobs.size()); ParallelFor(ParallelismStrategy::kWithinCluster, blobs.size(), ctx, - cluster_idx, [&](size_t i, size_t /*thread*/) { + cluster_idx, Callers::kTest, [&](size_t i, size_t /*thread*/) { HWY_ASSERT(ranges[i].bytes == blobs[i].size()); reader.file().Read(ranges[i].offset, ranges[i].bytes, blobs[i].data()); @@ -122,7 +122,7 @@ void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2, const double t0 = hwy::platform::Now(); HWY_WARN("Reading %zu GiB, %zu clusters: ", total_bytes >> 30, ctx.pools.NumClusters()); - ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0, + ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0, Callers::kTest, [&](const size_t task, size_t cluster_idx) { ReadBlobs(task ? reader1 : reader2, task ? ranges1 : ranges2, task ? blobs1 : blobs2, ctx, cluster_idx); @@ -190,7 +190,7 @@ void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2, std::atomic blobs_equal{}; std::atomic blobs_diff{}; ParallelFor(ParallelismStrategy::kHierarchical, keys.size(), ctx, 0, - [&](size_t i, size_t /*thread*/) { + Callers::kTest, [&](size_t i, size_t /*thread*/) { const size_t mismatches = BlobDifferences(blobs1[i], blobs2[i], keys[i]); if (mismatches != 0) { diff --git a/io/blob_store.cc b/io/blob_store.cc index 176f46e..af9f81d 100644 --- a/io/blob_store.cc +++ b/io/blob_store.cc @@ -28,7 +28,6 @@ #include "util/threading_context.h" #include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/detect_compiler_arch.h" #include "hwy/profiler.h" @@ -469,8 +468,8 @@ static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes, } } -BlobWriter::BlobWriter(const Path& filename, hwy::ThreadPool& pool) - : file_(OpenFileOrNull(filename, "w+")), pool_(pool) { +BlobWriter::BlobWriter(const Path& filename, ThreadingContext& ctx) + : file_(OpenFileOrNull(filename, "w+")), ctx_(ctx) { if (!file_) HWY_ABORT("Failed to open for writing %s", filename.path.c_str()); // Write a placeholder header to the beginning of the file. If append-only, // we will later write a footer, else we will update the header. @@ -489,10 +488,13 @@ void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) { EnqueueChunks(keys_.size() - 1, curr_offset_, bytes, static_cast(data), writes); - hwy::ThreadPool null_pool(0); - hwy::ThreadPool& pool_or_serial = file_->IsAppendOnly() ? null_pool : pool_; - pool_or_serial.Run( - 0, writes.size(), [this, &writes](uint64_t i, size_t /*thread*/) { + const ParallelismStrategy strategy = file_->IsAppendOnly() + ? 
ParallelismStrategy::kNone + : ParallelismStrategy::kFlat; + ParallelFor( + strategy, writes.size(), ctx_, + /*cluster_idx=*/0, Callers::kBlobWriter, + [this, &writes](uint64_t i, size_t /*thread*/) { const BlobRange& range = writes[i].range; if (!file_->Write(writes[i].data, range.bytes, range.offset)) { const std::string& key = StringFromKey(keys_[range.key_idx]); diff --git a/io/blob_store.h b/io/blob_store.h index 7ee57d2..e5f2221 100644 --- a/io/blob_store.h +++ b/io/blob_store.h @@ -28,9 +28,9 @@ #include "io/io.h" // File, Path, MapPtr #include "util/basics.h" // Tristate +#include "util/threading_context.h" #include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" // HWY_ASSERT -#include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { @@ -117,7 +117,7 @@ class BlobReader { // does not make sense to call the methods concurrently. class BlobWriter { public: - BlobWriter(const Path& filename, hwy::ThreadPool& pool); + BlobWriter(const Path& filename, ThreadingContext& ctx); // Writes the blob to disk with padding for alignment. Aborts on error. void Add(const std::string& key, const void* data, size_t bytes); @@ -129,7 +129,7 @@ class BlobWriter { std::unique_ptr file_; std::vector keys_; std::vector blob_sizes_; - hwy::ThreadPool& pool_; + ThreadingContext& ctx_; // Current offset in the file used for writing. int64_t curr_offset_ = 0; }; diff --git a/io/blob_store_test.cc b/io/blob_store_test.cc index 078a974..bb41c7e 100644 --- a/io/blob_store_test.cc +++ b/io/blob_store_test.cc @@ -38,7 +38,6 @@ class BlobStoreTest : public testing::Test {}; TEST(BlobStoreTest, TestReadWrite) { ThreadingArgs threading_args; ThreadingContext ctx(threading_args); - hwy::ThreadPool& pool = ctx.pools.Pool(); static const std::array kOriginalData = {-1, 0, 3.14159, 2.71828}; @@ -52,7 +51,7 @@ TEST(BlobStoreTest, TestReadWrite) { const std::string keyA("0123456789abcdef"); // max 16 characters const std::string keyB("q"); - BlobWriter writer(path, pool); + BlobWriter writer(path, ctx); writer.Add(keyA, "DATA", 5); writer.Add(keyB, buffer.data(), sizeof(buffer)); writer.Finalize(); @@ -96,7 +95,6 @@ TEST(BlobStoreTest, TestReadWrite) { TEST(BlobStoreTest, TestNumBlobs) { ThreadingArgs threading_args; ThreadingContext ctx(threading_args); - hwy::ThreadPool& pool = ctx.pools.Pool(); hwy::RandomState rng; for (size_t num_blobs = 1; num_blobs <= 512; ++num_blobs) { @@ -106,7 +104,7 @@ TEST(BlobStoreTest, TestNumBlobs) { HWY_ASSERT(fd > 0); const Path path(path_str); - BlobWriter writer(path, pool); + BlobWriter writer(path, ctx); std::vector keys; keys.reserve(num_blobs); std::vector> blobs; @@ -130,26 +128,31 @@ TEST(BlobStoreTest, TestNumBlobs) { BlobReader reader(path); HWY_ASSERT_EQ(reader.Keys().size(), num_blobs); - pool.Run(0, num_blobs, [&](uint64_t i, size_t /*thread*/) { - HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(), std::to_string(i).c_str()); - const BlobRange* range = reader.Find(keys[i]); - HWY_ASSERT(range); - HWY_ASSERT_EQ(blobs[i].size(), range->bytes); - HWY_ASSERT(reader.CallWithSpan( - keys[i], [path_str, num_blobs, i, range, - &blobs](const hwy::Span span) { - HWY_ASSERT_EQ(blobs[i].size(), span.size()); - const bool match1 = span[0] == static_cast(i & 255); - // If size == 1, we don't have a second byte to check. 
- const bool match2 = - span.size() == 1 || - span[span.size() - 1] == static_cast(i >> 8); - if (!match1 || !match2) { - HWY_ABORT("%s num_blobs %zu blob %zu offset %zu is corrupted.", - path_str, num_blobs, i, range->offset); - } - })); - }); + + ParallelFor( + ParallelismStrategy::kFlat, num_blobs, ctx, /*cluster_idx=*/0, + Callers::kTest, [&](uint64_t i, size_t /*thread*/) { + HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(), + std::to_string(i).c_str()); + const BlobRange* range = reader.Find(keys[i]); + HWY_ASSERT(range); + HWY_ASSERT_EQ(blobs[i].size(), range->bytes); + HWY_ASSERT(reader.CallWithSpan( + keys[i], [path_str, num_blobs, i, range, + &blobs](const hwy::Span span) { + HWY_ASSERT_EQ(blobs[i].size(), span.size()); + const bool match1 = span[0] == static_cast(i & 255); + // If size == 1, we don't have a second byte to check. + const bool match2 = + span.size() == 1 || + span[span.size() - 1] == static_cast(i >> 8); + if (!match1 || !match2) { + HWY_ABORT( + "%s num_blobs %zu blob %zu offset %zu is corrupted.", + path_str, num_blobs, i, range->offset); + } + })); + }); close(fd); unlink(path_str); diff --git a/io/migrate_weights.cc b/io/migrate_weights.cc index aa500bb..d20835f 100644 --- a/io/migrate_weights.cc +++ b/io/migrate_weights.cc @@ -44,6 +44,6 @@ int main(int argc, char** argv) { } gcpp::GemmaEnv env(argc, argv); - env.GetGemma()->Save(args.output_weights, env.Env().ctx.pools); + env.GetGemma()->Save(args.output_weights, env.Env().ctx); return 0; } diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index 1be2bed..221405c 100644 --- a/ops/bench_matmul.cc +++ b/ops/bench_matmul.cc @@ -30,7 +30,6 @@ #include "ops/matmul.h" #include "util/basics.h" #include "util/threading_context.h" -#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/nanobenchmark.h" #include "hwy/profiler.h" #include "hwy/timer.h" @@ -72,7 +71,6 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents, // M = A rows, K = A cols, N = C cols. template void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { - hwy::ThreadPool& pool = env.ctx.pools.Pool(0); if (env.print_config || env.print_measurement) { fprintf(stderr, "\n"); } @@ -91,15 +89,14 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { MatStorageT add_storage("add", Extents2D(), env.ctx.allocator, MatPadding::kPacked); if (add) { - add_storage = GenerateMat(Extents2D(1, N), env.ctx.allocator, - MatPadding::kPacked, pool); + add_storage = + GenerateMat(Extents2D(1, N), MatPadding::kPacked, env.ctx); add_storage.SetScale(1.0f); } - MatStorageT a = - GenerateMat(A_extents, env.ctx.allocator, MatPadding::kOdd, pool); - MatStorageT b_trans = GenerateTransposedMat( - B_extents, env.ctx.allocator, MatPadding::kOdd, pool); + MatStorageT a = GenerateMat(A_extents, MatPadding::kOdd, env.ctx); + MatStorageT b_trans = + GenerateTransposedMat(B_extents, MatPadding::kOdd, env.ctx); const float* add_row = add ? 
add_storage.PackedScale1() : nullptr; diff --git a/ops/dot_test.cc b/ops/dot_test.cc index ed09429..827b6b4 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -31,7 +31,6 @@ #include "util/test_util.h" #include "util/threading_context.h" #include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/profiler.h" #include "hwy/stats.h" #include "hwy/timer.h" @@ -922,9 +921,11 @@ void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw, (void)ScaleWeights(raw, num); } - hwy::ThreadPool pool(0); // num is too small for parallelization + ThreadingArgs threading_args; + threading_args.max_lps = 1; // num is too small for parallelization + ThreadingContext ctx(threading_args); const size_t packed_ofs = 0; - Compress(raw, num, work, packed, packed_ofs, pool); + Compress(raw, num, work, packed, packed_ofs, ctx); const hn::ScalableTag df; DecompressAndZeroPad(df, MakeConst(packed), packed_ofs, raw, num); @@ -1125,7 +1126,7 @@ void TestAllDot() { std::array all_stats; ParallelFor( - ParallelismStrategy::kWithinCluster, kReps, ctx, 0, + ParallelismStrategy::kWithinCluster, kReps, ctx, 0, Callers::kTest, [&](size_t rep, size_t thread) { float* HWY_RESTRICT pa = a.Row(thread); float* HWY_RESTRICT pb = b.Row(thread); diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index d72ac38..96cd4f1 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -291,7 +291,7 @@ class MMDecompress { const hn::ScalableTag dbf; const size_t NBF = hn::Lanes(dbf); - const auto zone = GetProfilerZone(Zones::kMMDecompressA); + const auto zone = env.ctx.profiler_zones.Get(Zones::kMMDecompressA); const auto do_range = [&](const IndexRange& range_M, const IndexRange& range_K, size_t worker) @@ -879,9 +879,8 @@ class MMLoops { static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT& B, const MatPtrT* B2, RowPtrs C, const MMArgs& args) { - PROFILER_ZONE3(args.env.ctx.profiler, - args.env.ctx.Worker(args.options.cluster_idx), - GetProfilerZone(Zones::kMMDispatch)); + GCPP_ZONE(args.env.ctx, args.env.ctx.Worker(args.options.cluster_idx), + Zones::kMMDispatch); DispatchParallelism( args.options.parallelism, [&](const auto& parallel) HWY_ATTR { @@ -904,7 +903,7 @@ class MMLoops { const StridedViewBF A, const MatPtrT& B, const MatPtrT* B2, RowPtrs C, const MMArgs& args) { - const auto zone = GetProfilerZone(Zones::kMMNT); + const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT); HWY_DASSERT(args.ranges_mc.NumTasks() == 1); HWY_DASSERT(args.ranges_kc.NumTasks() == 1); const IndexRange& range_mc = args.ranges_mc.Range(0); @@ -940,7 +939,7 @@ class MMLoops { const StridedViewBF A, const MatPtrT& B, const MatPtrT* B2, RowPtrs C, const MMArgs& args) { - const auto zone = GetProfilerZone(Zones::kMMNT_K); + const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_K); HWY_DASSERT(args.ranges_mc.NumTasks() == 1); const IndexRange& range_mc = args.ranges_mc.Range(0); @@ -976,7 +975,7 @@ class MMLoops { const StridedViewBF A, const MatPtrT& B, const MatPtrT* B2, RowPtrs C, const MMArgs& args) { - const auto zone = GetProfilerZone(Zones::kMMNT_MT); + const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_MT); HWY_DASSERT(args.ranges_kc.NumTasks() == 1); const IndexRange& range_kc = args.ranges_kc.Range(0); @@ -1010,7 +1009,7 @@ class MMLoops { const StridedViewBF A, const MatPtrT& B, const MatPtrT* B2, RowPtrs C, const MMArgs& args) { - const auto zone = GetProfilerZone(Zones::kMMNT_MT_K); + const auto zone = args.env.ctx.profiler_zones.Get(Zones::kMMNT_MT_K); 
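The hunks above replace `PROFILER_ZONE3(..., GetProfilerZone(...))` either with the new `GCPP_ZONE` macro or, where the handle is captured for reuse inside a lambda, with a direct `ctx.profiler_zones.Get(...)` lookup. A small sketch of the macro form, with a hypothetical function name; per the definition added to `util/threading_context.h` later in this patch, it expands to the same `PROFILER_ZONE3` call as before:

```cpp
#include <cstddef>

#include "util/threading_context.h"  // GCPP_ZONE, Zones, ThreadingContext

namespace gcpp {

// Hypothetical worker-indexed step; only GCPP_ZONE and Zones come from the
// patch. The macro expands to
//   PROFILER_ZONE3(ctx.profiler, worker, ctx.profiler_zones.Get(zone_enum)).
inline void TimedStep(ThreadingContext& ctx, size_t worker) {
  GCPP_ZONE(ctx, worker, Zones::kOpsSoftmax);
  // ... timed work would go here ...
}

}  // namespace gcpp
```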
parallel.ForRangesMC_NC( args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx, @@ -1063,8 +1062,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, MatPtrT& C, MMOptions options = MMOptions()) { const size_t cluster_idx = options.cluster_idx; HWY_DASSERT(cluster_idx < env.row_ptrs.size()); - PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), - GetProfilerZone(Zones::kMMMatMul)); + GCPP_ZONE(env.ctx, env.ctx.Worker(cluster_idx), Zones::kMMMatMul); RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]); @@ -1124,8 +1122,7 @@ HWY_NOINLINE MMPerKey* TwoMatMul(const MatPtrT& A, const MatPtrT& B1, MatPtrT& C, MMOptions options) { const size_t cluster_idx = options.cluster_idx; HWY_DASSERT(cluster_idx < env.row_ptrs.size()); - PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), - GetProfilerZone(Zones::kMMTwoMatMul)); + GCPP_ZONE(env.ctx, env.ctx.Worker(cluster_idx), Zones::kMMTwoMatMul); HWY_DASSERT(options.func != nullptr); // no other way to get access to C2. diff --git a/ops/matmul.h b/ops/matmul.h index e16c0f2..ea7c090 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -111,6 +111,7 @@ struct MMParallelWithinCluster { const IndexRangePartition ranges_n = StaticPartition( range_n, cluster.NumWorkers() * inner_tasks, n_multiple); ParallelizeOneRange(ranges_n, cluster, + ctx.pool_callers.Get(Callers::kMMClusterForN), [&](const IndexRange& worker_range, size_t worker) { func(worker_range, base + worker); }); @@ -127,12 +128,14 @@ struct MMParallelWithinCluster { // Low-batch: avoid Divide/Remainder. if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) { ParallelizeOneRange(ranges_nc, cluster, + ctx.pool_callers.Get(Callers::kMMClusterForMCNC), [&](const IndexRange& range_nc, size_t worker) { func(ranges_mc.Range(0), range_nc, base + worker); }); } else { ParallelizeTwoRanges( ranges_mc, ranges_nc, cluster, + ctx.pool_callers.Get(Callers::kMMClusterForMCNC), [&](const IndexRange& range_mc, const IndexRange& range_nc, size_t worker) { func(range_mc, range_nc, base + worker); }); } @@ -146,6 +149,7 @@ struct MMParallelWithinCluster { cluster.Run( range_mc.begin(), range_mc.end(), + ctx.pool_callers.Get(Callers::kMMClusterForMC), [&](uint64_t row_a, size_t worker) { func(row_a, base + worker); }); } }; @@ -159,6 +163,7 @@ struct MMParallelHierarchical { const Func& func) const { HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); HWY_DASSERT(caller_cluster_idx == 0); + const hwy::pool::Caller caller = ctx.pool_callers.Get(Callers::kMMHierForN); // Single cluster: parallel-for over static partition of `range_n`. 
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(); @@ -169,7 +174,7 @@ struct MMParallelHierarchical { const IndexRangePartition ranges_n = StaticPartition( range_n, cluster.NumWorkers() * inner_tasks, n_multiple); return ParallelizeOneRange( - ranges_n, cluster, + ranges_n, cluster, caller, [&](const IndexRange& worker_range, size_t worker) { func(worker_range, worker); }); @@ -179,7 +184,7 @@ struct MMParallelHierarchical { const IndexRangePartition ranges_n = StaticPartition(range_n, num_clusters, n_multiple); ParallelizeOneRange( - ranges_n, all_clusters, + ranges_n, all_clusters, caller, [&](const IndexRange& n_range, const size_t cluster_idx) { hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); const size_t cluster_base = ctx.Worker(cluster_idx); @@ -187,7 +192,7 @@ struct MMParallelHierarchical { const IndexRangePartition worker_ranges = StaticPartition( n_range, cluster.NumWorkers() * inner_tasks, n_multiple); ParallelizeOneRange( - worker_ranges, cluster, + worker_ranges, cluster, caller, [&](const IndexRange& worker_range, size_t worker) { func(worker_range, cluster_base + worker); }); @@ -203,6 +208,8 @@ struct MMParallelHierarchical { HWY_MAYBE_UNUSED size_t caller_cluster_idx, const Func& func) const { HWY_DASSERT(caller_cluster_idx == 0); + const hwy::pool::Caller caller = + ctx.pool_callers.Get(Callers::kMMHierForMCNC); hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(); // `all_clusters` is a pool with one worker per cluster in a package. @@ -215,12 +222,13 @@ struct MMParallelHierarchical { // Low-batch: avoid Divide/Remainder. if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) { return ParallelizeOneRange( - ranges_nc, cluster, [&](const IndexRange& range_nc, size_t worker) { + ranges_nc, cluster, caller, + [&](const IndexRange& range_nc, size_t worker) { func(ranges_mc.Range(0), range_nc, worker); }); } else { return ParallelizeTwoRanges( - ranges_mc, ranges_nc, cluster, + ranges_mc, ranges_nc, cluster, caller, [&](const IndexRange& range_mc, const IndexRange& range_nc, size_t worker) { func(range_mc, range_nc, worker); }); } @@ -229,11 +237,11 @@ struct MMParallelHierarchical { // Multiple clusters: N across clusters (both are usually the larger), and // M within each cluster. We assume auto-tuning finds small MC/NC tasks. 
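Similarly, `ParallelizeOneRange` and `ParallelizeTwoRanges` (see their new signatures in `util/threading.h` later in this patch) now take a `hwy::pool::Caller`, mirroring `ThreadPool::Run`. A self-contained sketch in the style of the updated `threading_test.cc`, using a zero-worker pool and an ad-hoc caller registered via `hwy::ThreadPool::AddCaller`, as `util/threading.cc` does where no `Callers` enum value fits; the caller name and range sizes are illustrative:

```cpp
#include <cstddef>
#include <cstdio>

#include "util/threading.h"  // IndexRange, StaticPartition, ParallelizeOneRange
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

inline void SumRangeInChunks() {
  static const hwy::pool::Caller kCaller =
      hwy::ThreadPool::AddCaller("ExampleSketch");
  hwy::ThreadPool null_pool(0);  // zero workers: runs on the calling thread
  const IndexRangePartition partition =
      StaticPartition(IndexRange(0, 16), 2, 4);  // two tasks, multiple of 4
  size_t total = 0;  // serial here, so no synchronization needed
  ParallelizeOneRange(partition, null_pool, kCaller,
                      [&](const IndexRange& range, size_t /*thread*/) {
                        total += range.Num();
                      });
  std::printf("covered %zu indices\n", total);  // prints 16
}

}  // namespace gcpp
```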
ParallelizeOneRange( - ranges_nc, all_clusters, + ranges_nc, all_clusters, caller, [&](const IndexRange range_nc, size_t cluster_idx) { const size_t cluster_base = ctx.Worker(cluster_idx); hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx); - ParallelizeOneRange(ranges_mc, cluster, + ParallelizeOneRange(ranges_mc, cluster, caller, [&](const IndexRange& range_mc, size_t worker) { func(range_mc, range_nc, cluster_base + worker); }); @@ -244,7 +252,7 @@ struct MMParallelHierarchical { template void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, size_t caller_cluster_idx, const Func& func) const { - HierarchicalParallelFor(range_mc.Num(), ctx.pools, + HierarchicalParallelFor(range_mc.Num(), ctx, Callers::kMMHierForMC, [&](size_t task, size_t worker) { func(range_mc.begin() + task, worker); }); @@ -811,7 +819,7 @@ class MMZone { private: uint64_t data_ = 0; - uint64_t data2_ = 0; + HWY_MEMBER_VAR_MAYBE_UNUSED uint64_t data2_ = 0; }; #else struct MMZone { diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 101707f..2aaf301 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -196,7 +196,7 @@ HWY_INLINE void MatMulSlow(const MatPtrT A, const MatPtrT B, const IndexRangePartition get_col_c = StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple); ParallelizeOneRange( - get_col_c, all_clusters, + get_col_c, all_clusters, env.ctx.pool_callers.Get(Callers::kTest), [&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR { for (size_t r : all_rows_c) { TC* HWY_RESTRICT C_row = C.Row(r); @@ -221,7 +221,6 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents, template void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, MatMulEnv& env, int line) { - hwy::ThreadPool& pool = env.ctx.pools.Pool(); fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n", rows_ac, cols_a_rows_b, cols_bc, add, TypeName(), TypeName(), TypeName()); @@ -233,11 +232,10 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed const Extents2D C_extents(rows_ac, cols_bc); - MatStorageT A( - GenerateMat(A_extents, env.ctx.allocator, MatPadding::kOdd, pool)); + MatStorageT A(GenerateMat(A_extents, MatPadding::kOdd, env.ctx)); // Must be packed because we call Span() on it. - MatStorageT BT(GenerateTransposedMat(B_extents, env.ctx.allocator, - MatPadding::kPacked, pool)); + MatStorageT BT( + GenerateTransposedMat(B_extents, MatPadding::kPacked, env.ctx)); MatStorageT C_slow("C_slow", C_extents, env.ctx.allocator, MatPadding::kOdd); MatStorageT C("C", C_extents, env.ctx.allocator, MatPadding::kOdd); @@ -246,8 +244,8 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, C2.AllocateAndAttachRowPtrs(env.row_ptrs); MatStorageT add_storage = - add ? GenerateMat(Extents2D(1, cols_bc), env.ctx.allocator, - MatPadding::kPacked, pool) + add ? GenerateMat(Extents2D(1, cols_bc), MatPadding::kPacked, + env.ctx) : MatStorageT("add", Extents2D(), env.ctx.allocator, MatPadding::kPacked); add_storage.SetScale(1.0f); diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 4ff2c7d..f2933ca 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -205,9 +205,9 @@ namespace detail { // Shared by RMSNorm and RMSNormInplace. 
template -float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p, - const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNormMul)); +float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, + ThreadingContext& ctx, const size_t worker) { + GCPP_ZONE(ctx, worker, Zones::kOpsRmsNormMul); const hn::ScalableTag d; const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault()); @@ -218,19 +218,17 @@ float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p, } // namespace detail template -HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x, - const WT* HWY_RESTRICT weight, - const size_t w_ofs, - OT* HWY_RESTRICT out, - const size_t size, hwy::Profiler& p, - const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNorm)); +HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( + const XT* HWY_RESTRICT x, const WT* HWY_RESTRICT weight, const size_t w_ofs, + OT* HWY_RESTRICT out, const size_t size, ThreadingContext& ctx, + const size_t worker) { + GCPP_ZONE(ctx, worker, Zones::kOpsRmsNorm); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; - const VF mul = hn::Set(DF(), detail::RMSNormMul(x, size, p, worker)); + const VF mul = hn::Set(DF(), detail::RMSNormMul(x, size, ctx, worker)); const VF* HWY_RESTRICT pmul = &mul; Decompress2AndCompressTo(DF(), out, size, x, weight, w_ofs, @@ -245,13 +243,13 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x, template HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( const WT* HWY_RESTRICT weight, const size_t w_ofs, XT* HWY_RESTRICT inout, - const size_t size, hwy::Profiler& p, const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNormInplace)); + const size_t size, ThreadingContext& ctx, const size_t worker) { + GCPP_ZONE(ctx, worker, Zones::kOpsRmsNormInplace); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; - const VF mul = hn::Set(DF(), detail::RMSNormMul(inout, size, p, worker)); + const VF mul = hn::Set(DF(), detail::RMSNormMul(inout, size, ctx, worker)); const VF* HWY_RESTRICT pmul = &mul; Decompress1AndCompressInplace(DF(), inout, size, weight, w_ofs, @@ -359,9 +357,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings( // This overload is called if `post_qk == PostQKType::HalfRope`. static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope( float* HWY_RESTRICT x, const size_t dim_qkv, - const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p, - const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRope)); + const float* HWY_RESTRICT inv_timescale, const int pos, + ThreadingContext& ctx, const size_t worker) { + GCPP_ZONE(ctx, worker, Zones::kOpsRope); HWY_DASSERT(dim_qkv % 2 == 0); const size_t half_dim_qkv = dim_qkv / 2; @@ -418,9 +416,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope( // `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations. 
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy( const float mul, float* HWY_RESTRICT x, const size_t dim_qkv, - const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p, - const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRopeAndMulBy)); + const float* HWY_RESTRICT inv_timescale, const int pos, + ThreadingContext& ctx, const size_t worker) { + GCPP_ZONE(ctx, worker, Zones::kOpsRopeAndMulBy); HWY_DASSERT(dim_qkv % 2 == 0); const size_t half_dim_qkv = dim_qkv / 2; @@ -480,9 +478,9 @@ template static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x, float* HWY_RESTRICT out, const size_t size, - hwy::Profiler& p, + ThreadingContext& ctx, const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsAddFrom)); + GCPP_ZONE(ctx, worker, Zones::kOpsAddFrom); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; @@ -503,10 +501,11 @@ void RMSNormBatched(const MatPtrT& activations, const MatPtr& weights, CallUpcasted(&weights, [&](const auto* weights_t) { ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx, - cluster_idx, [&](uint64_t token_idx, size_t worker) { + cluster_idx, Callers::kOpsRMSNormBatched, + [&](uint64_t token_idx, size_t worker) { RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), /*w_ofs=*/0, out.Row(token_idx), activations.Cols(), - ctx.profiler, worker); + ctx, worker); }); }); } @@ -519,10 +518,11 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT& inout, CallUpcasted(&weights, [&](const auto* weights_t) { ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx, + Callers::kOpsRMSNormInplaceBatched, [&](uint64_t token_idx, size_t worker) { RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, - inout.Row(token_idx), inout.Cols(), - ctx.profiler, worker); + inout.Row(token_idx), inout.Cols(), ctx, + worker); }); }); } @@ -549,13 +549,14 @@ static HWY_INLINE void AddFromBatched(const MatPtrT& x, MatPtrT& out, ThreadingContext& ctx, size_t cluster_idx = 0) { HWY_DASSERT(out.SameShape(x)); - ParallelFor(ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx, - [&](uint64_t token_idx, size_t worker) { - AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), - ctx.profiler, worker); - }); + ParallelFor( + ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx, + Callers::kOpsAddFromBatched, [&](uint64_t token_idx, size_t worker) { + AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), ctx, worker); + }); } +// No profiler zone because this is short and frequently called. template HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x, const size_t size) { @@ -575,8 +576,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x, template HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo( const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, - const size_t size, hwy::Profiler& p, const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstTo)); + const size_t size, ThreadingContext& ctx, const size_t worker) { + GCPP_ZONE(ctx, worker, Zones::kOpsMulByConstTo); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; @@ -1121,10 +1122,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( // See below for a specialized version for top-1 sampling. // TODO: support bf16 logits using Decompress2. 
-static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p, +static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx, const size_t worker, float temperature = 1.0f) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsSoftmax)); + GCPP_ZONE(ctx, worker, Zones::kOpsSoftmax); HWY_DASSERT(logits.size() != 0); namespace hn = hwy::HWY_NAMESPACE; @@ -1256,8 +1257,9 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(Logits logits) { } static HWY_NOINLINE void LogitsSoftCap(const float cap, Logits logits, - hwy::Profiler& p, const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsLogitsSoftCap)); + ThreadingContext& ctx, + const size_t worker) { + GCPP_ZONE(ctx, worker, Zones::kOpsLogitsSoftCap); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; @@ -1277,9 +1279,10 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, Logits logits, // Calls LogitsSoftCap if cap != 0.0f. static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap( - const float cap, Logits logits, hwy::Profiler& p, const size_t worker) { + const float cap, Logits logits, ThreadingContext& ctx, + const size_t worker) { if (cap != 0.0f) { - LogitsSoftCap(cap, logits, p, worker); + LogitsSoftCap(cap, logits, ctx, worker); } } @@ -1288,9 +1291,10 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched( ThreadingContext& ctx, size_t cluster_idx = 0) { if (cap == 0.0f) return; ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx, + Callers::kOpsMaybeLogitsSoftCapBatched, [&](uint64_t task, size_t worker) { if (non_eos.Get(task)) { - LogitsSoftCap(cap, x.RowSpan(task), ctx.profiler, worker); + LogitsSoftCap(cap, x.RowSpan(task), ctx, worker); } }); } @@ -1371,7 +1375,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(Logits logits, size_t k, template HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK( Logits logits, size_t k, RngStream& gen, float temperature, - TAcceptToken& accept_token, hwy::Profiler& p, size_t worker) { + TAcceptToken& accept_token, ThreadingContext& ctx, size_t worker) { // Softmax and sample top-K is equivalent to taking the top-K logits and // sampling from the softmax of the top-K logits. The latter is faster as it // avoids computing the softmax of all logits. 
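The comment above relies on a standard identity: softmax over only the top-K logits equals the full softmax restricted to those K entries and renormalized, because the common normalizer cancels. A standalone numeric sketch (not part of the patch) of that equivalence:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<double> logits = {1.0, 3.0, 2.0, 0.5};  // top-2: 3.0, 2.0

  // Full softmax, then keep the top-2 probabilities and renormalize.
  double z_all = 0.0;
  for (double l : logits) z_all += std::exp(l);
  const double top_sum = (std::exp(3.0) + std::exp(2.0)) / z_all;
  const double p_full = (std::exp(3.0) / z_all) / top_sum;

  // Softmax over just the top-2 logits.
  const double p_topk = std::exp(3.0) / (std::exp(3.0) + std::exp(2.0));

  std::printf("%.12f %.12f\n", p_full, p_topk);  // identical up to rounding
  return 0;
}
```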
@@ -1384,7 +1388,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK( } const size_t mask = token_logits.size(); - Softmax(Logits(topk_logits.data(), mask), p, worker, temperature); + Softmax(Logits(topk_logits.data(), mask), ctx, worker, temperature); auto distribution = std::discrete_distribution( std::begin(topk_logits), std::begin(topk_logits) + mask); int topk_sampled_index = distribution(gen); diff --git a/ops/ops_test.cc b/ops/ops_test.cc index d46bb5c..4d94b61 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -58,6 +58,11 @@ namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; +static ThreadingContext& Ctx() { + static ThreadingContext* ctx = new ThreadingContext(ThreadingArgs()); + return *ctx; +} + static RngStream MakeRng() { static AesCtrEngine engine(/*deterministic=*/true); static uint64_t stream = 0; @@ -133,8 +138,7 @@ class TestAddFrom { } SimpleAddFrom(o, e, count); - InitProfilerZones(hwy::Profiler::Get()); - AddFrom(o, x, count, hwy::Profiler::Get(), /*worker=*/0); + AddFrom(o, x, count, Ctx(), /*worker=*/0); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); @@ -182,7 +186,6 @@ class TestMulByConstAndAdd { T constant = Random(rng); SimpleMulByConstAndAdd(constant, o, e, count); - InitProfilerZones(hwy::Profiler::Get()); MulByConstAndAdd(constant, o, x, count); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, @@ -231,7 +234,6 @@ class TestMulByConst { T constant = Random(rng); SimpleMulByConst(constant, e, count); - InitProfilerZones(hwy::Profiler::Get()); MulByConst(constant, x, count); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, @@ -278,8 +280,7 @@ struct TestMulByConstTo { hwy::ConvertScalarTo(constant)); } - InitProfilerZones(hwy::Profiler::Get()); - MulByConstTo(constant, x, actual, count, hwy::Profiler::Get(), + MulByConstTo(constant, x, actual, count, Ctx(), /*worker=*/0); hwy::AssertArraySimilar(e, actual, count, hwy::TargetName(HWY_TARGET), @@ -315,8 +316,7 @@ class TestSoftmax { } SimpleSoftmax(e, count); - InitProfilerZones(hwy::Profiler::Get()); - Softmax(Logits(x, count), hwy::Profiler::Get(), /*worker=*/0); + Softmax(Logits(x, count), Ctx(), /*worker=*/0); T sum = 0.0f; for (size_t i = 0; i < count; ++i) { @@ -440,9 +440,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy( } void TestRopeAndMulBy() { - ThreadingArgs threading_args; - ThreadingContext ctx(threading_args); - hwy::Profiler& p = ctx.profiler; + ThreadingContext& ctx = Ctx(); const size_t worker = 0; const ModelConfig config(Model::GEMMA2_9B, Type::kSFP, @@ -476,7 +474,7 @@ void TestRopeAndMulBy() { CopyMat(x, qactual); ScalarRopeAndMulBy(qmul, qexpected.Row(0), dim_qkv, inv_timescale.Row(0), pos); - RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, + RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, ctx, worker); for (size_t i = 0; i < dim_qkv; ++i) { EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i; @@ -487,7 +485,7 @@ void TestRopeAndMulBy() { CopyMat(x, qactual); ScalarRopeAndMulBy(1.0f, qexpected.Row(0), dim_qkv, inv_timescale.Row(0), pos); - Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, worker); + Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, ctx, worker); for (size_t i = 0; i < dim_qkv; ++i) { EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i; } @@ -498,10 +496,10 @@ void TestRopeAndMulBy() { CopyMat(x, kactual2); ScalarRopeAndMulBy(kmul, 
kexpected.Row(0), dim_qkv, inv_timescale.Row(0), pos); - RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, + RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, ctx, worker); static_assert(kmul == 1.0f, ""); - Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, worker); + Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos, ctx, worker); for (size_t i = 0; i < dim_qkv; ++i) { EXPECT_NEAR(kexpected.Row(0)[i], kactual.Row(0)[i], 1e-4) << " " << i; @@ -557,8 +555,7 @@ struct TestRMSNorm { } ScalarRMSNorm(vec, weight, expected, kSize); - InitProfilerZones(hwy::Profiler::Get()); - RMSNorm(vec, weight, /*w_ofs=*/0, actual, kSize, hwy::Profiler::Get(), + RMSNorm(vec, weight, /*w_ofs=*/0, actual, kSize, Ctx(), /*worker=*/0); for (size_t i = 0; i < kSize; i++) { @@ -593,8 +590,7 @@ struct TestRMSNormInplace { } ScalarRMSNorm(expected, weight, expected, kSize); - InitProfilerZones(hwy::Profiler::Get()); - RMSNormInplace(weight, /*w_ofs=*/0, actual, kSize, hwy::Profiler::Get(), + RMSNormInplace(weight, /*w_ofs=*/0, actual, kSize, Ctx(), /*worker=*/0); for (size_t i = 0; i < kSize; i++) { @@ -715,15 +711,14 @@ void TestAllLayerNorm() { } void TestSampleTopK() { - hwy::Profiler& p = hwy::Profiler::Get(); - InitProfilerZones(p); + ThreadingContext& ctx = Ctx(); const size_t worker = 0; const size_t kSize = 52; std::vector logits_vec(kSize); Logits logits(logits_vec.data(), kSize); // Create a vector going from -100 to -100+51=49 and take Softmax. std::iota(logits.begin(), logits.end(), -100.0f); - Softmax(logits, p, worker); + Softmax(logits, ctx, worker); RngStream rng = MakeRng(); float temperature = 1.0f; // SampleTopK<1> should return the argmax. @@ -736,7 +731,7 @@ void TestSampleTopK() { EXPECT_EQ(sample, 50); // Last even index. // Reset the logits to a positive, increasing sequence and take Softmax. std::iota(logits.begin(), logits.end(), 1.0f); - Softmax(logits, p, worker); + Softmax(logits, ctx, worker); // Sample from the top 3, expect one of the top 3 even indices. for (int i = 0; i < 100; ++i) { sample = SampleTopK(logits, /*k=*/3, rng, temperature, accept_token); diff --git a/util/threading.cc b/util/threading.cc index 6d4a603..d476519 100644 --- a/util/threading.cc +++ b/util/threading.cc @@ -49,9 +49,10 @@ PinningPolicy::PinningPolicy(Tristate pin) { static void MaybePin(const BoundedTopology& topology, size_t cluster_idx, const BoundedTopology::Cluster& cluster, PinningPolicy& pinning, hwy::ThreadPool& pool) { + static hwy::pool::Caller caller = hwy::ThreadPool::AddCaller("MaybePin"); const std::vector lps = cluster.LPVector(); HWY_ASSERT(pool.NumWorkers() <= lps.size()); - pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) { + pool.Run(0, pool.NumWorkers(), caller, [&](uint64_t task, size_t thread) { HWY_ASSERT(task == thread); // each worker has one task char buf[16]; // Linux limitation @@ -141,17 +142,20 @@ NestedPools::NestedPools(const BoundedTopology& topology, // Parallel so we also pin the calling worker in `all_clusters` to // `cluster.lps`. - all_clusters_->Run(0, num_clusters, [&](size_t cluster_idx, size_t thread) { - HWY_ASSERT(cluster_idx == thread); // each thread has one task - const BoundedTopology::Cluster& tcluster = topology.GetCluster(cluster_idx); - clusters_[cluster_idx] = - MakePool(allocator, workers_per_cluster[cluster_idx], - hwy::PoolWorkerMapping(cluster_idx, max_workers_per_cluster_), - tcluster.Node()); - // Pin workers AND the calling thread from `all_clusters_`. 
- MaybePin(topology, cluster_idx, tcluster, pinning_, - *clusters_[cluster_idx]); - }); + static hwy::pool::Caller caller = hwy::ThreadPool::AddCaller("NestedPools"); + all_clusters_->Run( + 0, num_clusters, caller, [&](size_t cluster_idx, size_t thread) { + HWY_ASSERT(cluster_idx == thread); // each thread has one task + const BoundedTopology::Cluster& tcluster = + topology.GetCluster(cluster_idx); + clusters_[cluster_idx] = MakePool( + allocator, workers_per_cluster[cluster_idx], + hwy::PoolWorkerMapping(cluster_idx, max_workers_per_cluster_), + tcluster.Node()); + // Pin workers AND the calling thread from `all_clusters_`. + MaybePin(topology, cluster_idx, tcluster, pinning_, + *clusters_[cluster_idx]); + }); all_pinned_ = pinning_.AllPinned(&pin_string_); } diff --git a/util/threading.h b/util/threading.h index 35c6e22..dcdcf24 100644 --- a/util/threading.h +++ b/util/threading.h @@ -266,9 +266,9 @@ static inline IndexRangePartition StaticPartition(const IndexRange& range, // index to a range. template void ParallelizeOneRange(const IndexRangePartition& get1, hwy::ThreadPool& pool, - const Func& func) { + hwy::pool::Caller caller, const Func& func) { const size_t num_tasks = get1.NumTasks(); - pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) { + pool.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) { const IndexRange range1 = get1.Range(task); func(range1, thread); }); @@ -282,11 +282,12 @@ void ParallelizeOneRange(const IndexRangePartition& get1, hwy::ThreadPool& pool, template void ParallelizeTwoRanges(const IndexRangePartition& get1, const IndexRangePartition& get2, - hwy::ThreadPool& pool, const Func& func) { + hwy::ThreadPool& pool, hwy::pool::Caller caller, + const Func& func) { const hwy::Divisor div1(static_cast(get1.NumTasks())); const size_t num_tasks = get1.NumTasks() * get2.NumTasks(); - pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) { + pool.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) { HWY_DASSERT(task < (uint64_t{1} << 32)); const size_t idx2 = div1.Divide(static_cast(task)); const size_t idx1 = div1.Remainder(static_cast(task)); @@ -298,37 +299,6 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1, }); } -// Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes -// over clusters of ONE package, then within each cluster. -template -void HierarchicalParallelFor(size_t num_tasks, NestedPools& pools, - const Func& func) { - // If few tasks, run on a single cluster. Also avoids a bit of overhead if - // there is only one cluster. - hwy::ThreadPool& all_clusters = pools.AllClusters(); - const size_t num_clusters = all_clusters.NumWorkers(); - hwy::ThreadPool& cluster = pools.Cluster(0); - if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) { - return cluster.Run(0, num_tasks, [&](uint64_t task, size_t thread) { - func(task, thread); - }); - } - - // Assign each cluster a sub-range. 
- const IndexRangePartition ranges = - StaticPartition(IndexRange(0, num_tasks), num_clusters, 1); - ParallelizeOneRange( - ranges, all_clusters, - [&](const IndexRange& range, const size_t cluster_idx) { - hwy::ThreadPool& cluster = pools.Cluster(cluster_idx); - const size_t cluster_base = cluster_idx * pools.MaxWorkersPerCluster(); - cluster.Run(range.begin(), range.end(), - [&](uint64_t task, size_t thread) { - func(task, cluster_base + thread); - }); - }); -} - } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_ diff --git a/util/threading_context.cc b/util/threading_context.cc index 0a349fc..e725ce3 100644 --- a/util/threading_context.cc +++ b/util/threading_context.cc @@ -78,31 +78,33 @@ static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) { #endif } -static void TunePools(hwy::PoolWaitMode wait_mode, NestedPools& pools) { - hwy::ThreadPool& clusters = pools.AllClusters(); +static void TunePools(hwy::PoolWaitMode wait_mode, ThreadingContext& ctx) { + hwy::ThreadPool& clusters = ctx.pools.AllClusters(); TunePool(wait_mode, clusters); // Run in parallel because Turin CPUs have 16, and in real usage, we often // run all at the same time. clusters.Run(0, clusters.NumWorkers(), + ctx.pool_callers.Get(Callers::kTunePool), [&](uint64_t cluster_idx, size_t /*thread*/) { - TunePool(wait_mode, pools.Cluster(cluster_idx)); + TunePool(wait_mode, ctx.pools.Cluster(cluster_idx)); }); } ThreadingContext::ThreadingContext(const ThreadingArgs& args) : profiler(hwy::Profiler::Get()), + profiler_zones(profiler), + pool_callers(), topology(BoundedSlice(args.skip_packages, args.max_packages), BoundedSlice(args.skip_clusters, args.max_clusters), BoundedSlice(args.skip_lps, args.max_lps)), cache_info(topology), allocator(topology, cache_info, args.bind != Tristate::kFalse), pools(topology, allocator, args.max_threads, args.pin) { - InitProfilerZones(profiler); PROFILER_ZONE("Startup.ThreadingContext autotune"); - TunePools(hwy::PoolWaitMode::kSpin, pools); + TunePools(hwy::PoolWaitMode::kSpin, *this); // kBlock is the default, hence set/tune it last. - TunePools(hwy::PoolWaitMode::kBlock, pools); + TunePools(hwy::PoolWaitMode::kBlock, *this); } } // namespace gcpp diff --git a/util/threading_context.h b/util/threading_context.h index 5c55fc4..251888f 100644 --- a/util/threading_context.h +++ b/util/threading_context.h @@ -28,6 +28,7 @@ #include "util/basics.h" // Tristate #include "util/threading.h" #include "util/topology.h" +#include "util/zones.h" #include "hwy/profiler.h" // IWYU pragma: end_exports @@ -107,6 +108,9 @@ struct ThreadingContext { // Singleton; pass around a reference to reduce overhead. hwy::Profiler& profiler; + ProfilerZones profiler_zones; + PoolCallers pool_callers; + // Detects topology, subject to limits imposed by user-specified `args`. // For example, if `args.max_clusters` is 1, then `topology.NumClusters()` // will be 1 regardless of the actual system topology. @@ -122,6 +126,9 @@ struct ThreadingContext { NestedPools pools; }; +#define GCPP_ZONE(ctx, global_idx, zone_enum) \ + PROFILER_ZONE3(ctx.profiler, global_idx, ctx.profiler_zones.Get(zone_enum)) + // Describes the strategy for distributing parallel work across cores. enum class ParallelismStrategy : uint8_t { // Execute using a single-threaded loop on the calling thread. The `worker` @@ -147,18 +154,53 @@ enum class ParallelismStrategy : uint8_t { kHierarchical, }; +// Calls `func(task, worker)` for each task in `[0, num_tasks)`. 
Parallelizes +// over clusters of ONE package, then within each cluster. +template +void HierarchicalParallelFor(size_t num_tasks, ThreadingContext& ctx, + Callers callers, const Func& func) { + const hwy::pool::Caller caller = ctx.pool_callers.Get(callers); + // If few tasks, run on a single cluster. Also avoids a bit of overhead if + // there is only one cluster. + hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(); + const size_t num_clusters = all_clusters.NumWorkers(); + hwy::ThreadPool& cluster = ctx.pools.Cluster(0); + if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) { + return cluster.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) { + func(task, thread); + }); + } + + // Assign each cluster a sub-range. + const IndexRangePartition ranges = + StaticPartition(IndexRange(0, num_tasks), num_clusters, 1); + ParallelizeOneRange(ranges, all_clusters, caller, + [&](const IndexRange& range, const size_t cluster_idx) { + hwy::ThreadPool& cluster = + ctx.pools.Cluster(cluster_idx); + const size_t cluster_base = + cluster_idx * ctx.pools.MaxWorkersPerCluster(); + cluster.Run(range.begin(), range.end(), caller, + [&](uint64_t task, size_t thread) { + func(task, cluster_base + thread); + }); + }); +} + // Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the // number/type of workers determined by `parallelism`. `cluster_idx` is for // `parallelism == kWithinCluster`, and should be 0 if unknown. template void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, - ThreadingContext& ctx, size_t cluster_idx, const Func& func) { + ThreadingContext& ctx, size_t cluster_idx, Callers callers, + const Func& func) { HWY_DASSERT(cluster_idx < ctx.topology.NumClusters()); if (cluster_idx != 0) { // If already running across clusters, only use within-cluster modes. HWY_DASSERT(parallelism == ParallelismStrategy::kNone || parallelism == ParallelismStrategy::kWithinCluster); } + const hwy::pool::Caller caller = ctx.pool_callers.Get(callers); switch (parallelism) { case ParallelismStrategy::kNone: { @@ -171,7 +213,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, case ParallelismStrategy::kAcrossClusters: return ctx.pools.AllClusters().Run( - 0, num_tasks, + 0, num_tasks, caller, [&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); }); case ParallelismStrategy::kWithinCluster: { @@ -179,7 +221,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, // used for TLS indexing for example in profiler.h. 
const size_t base = ctx.Worker(cluster_idx); return ctx.pools.Cluster(cluster_idx) - .Run(0, num_tasks, [&](uint64_t task, size_t worker) { + .Run(0, num_tasks, caller, [&](uint64_t task, size_t worker) { func(task, base + worker); }); } @@ -191,19 +233,19 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, const size_t num_clusters = all_clusters.NumWorkers(); if (num_clusters == 1) { return ctx.pools.Cluster(cluster_idx) - .Run(0, num_tasks, + .Run(0, num_tasks, caller, [&](uint64_t task, size_t worker) { func(task, worker); }); } - return ctx.pools.AllClusters().Run( - 0, num_tasks, [&](uint64_t task, size_t cluster_idx) { - const size_t worker = ctx.Worker(cluster_idx); - func(task, worker); - }); + return all_clusters.Run(0, num_tasks, caller, + [&](uint64_t task, size_t cluster_idx) { + const size_t worker = ctx.Worker(cluster_idx); + func(task, worker); + }); } case ParallelismStrategy::kHierarchical: - return HierarchicalParallelFor(num_tasks, ctx.pools, func); + return HierarchicalParallelFor(num_tasks, ctx, callers, func); } } diff --git a/util/threading_test.cc b/util/threading_test.cc index 4cd8554..14ea1a0 100644 --- a/util/threading_test.cc +++ b/util/threading_test.cc @@ -37,6 +37,8 @@ namespace { using ::testing::ElementsAre; +static const hwy::pool::Caller kCaller = hwy::ThreadPool::AddCaller("Test"); + TEST(ThreadingTest, TestBoundedSlice) { const char* name = "test"; // No args = no limit. @@ -205,7 +207,7 @@ TEST(ThreadingTest, TestParallelizeOneRange) { const IndexRangePartition partition = StaticPartition(range, 2, 4); hwy::ThreadPool null_pool(0); size_t calls = 0; - ParallelizeOneRange(partition, null_pool, + ParallelizeOneRange(partition, null_pool, kCaller, [&](const IndexRange& range, size_t) { if (++calls == 1) { HWY_ASSERT(range.begin() == 0 && range.end() == 8); @@ -226,7 +228,7 @@ TEST(ThreadingTest, TestParallelizeTwoRanges) { { size_t calls = 0; ParallelizeTwoRanges( - partition1, partition2, null_pool, + partition1, partition2, null_pool, kCaller, [&](const IndexRange& range1, const IndexRange& range2, size_t) { ++calls; HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8); @@ -240,7 +242,7 @@ TEST(ThreadingTest, TestParallelizeTwoRanges) { { size_t calls = 0; ParallelizeTwoRanges( - partition2, partition1, null_pool, + partition2, partition1, null_pool, kCaller, [&](const IndexRange& range2, const IndexRange& range1, size_t) { ++calls; HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8); @@ -265,7 +267,7 @@ std::vector MeasureForkJoin(hwy::ThreadPool& pool) { const double t0 = hwy::platform::Now(); for (size_t reps = 0; reps < 1200; ++reps) { - pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) { + pool.Run(0, pool.NumWorkers(), kCaller, [&](uint64_t task, size_t thread) { outputs[thread * kU64PerThread] = base + thread; }); hwy::PreventElision(outputs[base]); @@ -305,18 +307,20 @@ std::vector MeasureForkJoin(hwy::ThreadPool& pool) { if (have_stop) { for (size_t rep = 0; rep < max_reps; ++rep) { const uint64_t t0 = hwy::timer::Start(); - pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) { - outputs[thread * kU64PerThread] = base + thread; - }); + pool.Run(0, pool.NumWorkers(), kCaller, + [&](uint64_t task, size_t thread) { + outputs[thread * kU64PerThread] = base + thread; + }); const uint64_t t1 = hwy::timer::Stop(); times.push_back(t1 - t0); } } else { for (size_t rep = 0; rep < max_reps; ++rep) { const uint64_t t0 = hwy::timer::Start(); - pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t 
thread) { - outputs[thread * kU64PerThread] = base + thread; - }); + pool.Run(0, pool.NumWorkers(), kCaller, + [&](uint64_t task, size_t thread) { + outputs[thread * kU64PerThread] = base + thread; + }); const uint64_t t1 = hwy::timer::Start(); times.push_back(t1 - t0); } diff --git a/util/zones.cc b/util/zones.cc index abc9dc2..a474311 100644 --- a/util/zones.cc +++ b/util/zones.cc @@ -1,70 +1,201 @@ #include "util/zones.h" +#include + +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/profiler.h" namespace gcpp { +namespace { -#if PROFILER_ENABLED -static constexpr size_t kNumZones = static_cast(Zones::kNumZones); - -static const char* kProfilerZoneNames[kNumZones] = { - // Keep in sync with Zones enum. - "Ops.RMSNormMul", - "Ops.RMSNorm", - "Ops.RMSNormInplace", - "Ops.Rope", - "Ops.RopeAndMulBy", - "Ops.AddFrom", - "Ops.MulByConst", - "Ops.MulByConstTo", - "Ops.MulByConstAndAdd", - "Ops.MulByConstAndAddTile", - "Ops.MulByConstAndAddTile4", - "Ops.MulByConstAndAddVector", - "Ops.Softmax", - "Ops.LogitsSoftCap", - "FlashAttention.TransposeQ", - "FlashAttention.RMSNormAndPositionalEncoding", - "FlashAttention.SingleFlashAttention", - "FlashAttention.TileFlashAttention", - "FlashAttention.TileFlashAttention4", - "FlashAttention.FlashAttention", - "Gen.Activation", - "Gen.ActivationFused", - "Gen.SampleTop1", - "Gen.SampleTopK", - "Gen.Attention.QDotK", - "Gen.Attention.DotSoftmaxWeightedSum.par", - "Startup.Weights.ReadAllToBF16", - "Startup.Weights.ReadBatches", - "MM.Dispatch", - "MM.MatMul", - "MM.TwoMatMul", - "MM.DecompressA", - "MM.NT", - "MM.NT_K", - "MM.NT_MT", - "MM.NT_MT_K", -}; - -static hwy::profiler::ZoneHandle profiler_zone_handles[kNumZones]; -#endif - -void InitProfilerZones(hwy::Profiler& profiler) { -#if PROFILER_ENABLED - // Initialize the zone handles. This is done once at startup. 
- for (size_t i = 0; i < kNumZones; ++i) { - profiler_zone_handles[i] = profiler.AddZone(kProfilerZoneNames[i]); +const char* ZoneName(Zones zone) { + switch (zone) { + case Zones::kFlashAttentionFlashAttention: + return "FlashAttention.FlashAttention"; + case Zones::kFlashAttentionInclusive: + return "FlashAttention.Inclusive"; + case Zones::kFlashAttentionRmsNormAndPositionalEncoding: + return "FlashAttention.RMSNormAndPositionalEncoding"; + case Zones::kFlashAttentionSingleFlashAttention: + return "FlashAttention.SingleFlashAttention"; + case Zones::kFlashAttentionTileFlashAttention: + return "FlashAttention.TileFlashAttention"; + case Zones::kFlashAttentionTileFlashAttention4: + return "FlashAttention.TileFlashAttention4"; + case Zones::kFlashAttentionTransposeQ: + return "FlashAttention.TransposeQ"; + case Zones::kGenActivation: + return "Gen.Activation"; + case Zones::kGenActivationFused: + return "Gen.ActivationFused"; + case Zones::kGenAttention: + return "Gen.Attention"; + case Zones::kGenAttentionComputeQKV: + return "Gen.Attention.ComputeQKV"; + case Zones::kGenAttentionDotSoftmaxWeightedSumInclusive: + return "Gen.Attention.DotSoftmaxWeightedSumInclusive"; + case Zones::kGenAttentionDotSoftmaxWeightedSumPar: + return "Gen.Attention.DotSoftmaxWeightedSum.par"; + case Zones::kGenAttentionQDotK: + return "Gen.Attention.QDotK"; + case Zones::kGenAttentionSumHeads: + return "Gen.Attention.SumHeads"; + case Zones::kGenEmbed: + return "Gen.Embed"; + case Zones::kGenEmbeddingMatmul: + return "Gen.EmbeddingMatmul"; + case Zones::kGenFFW: + return "Gen.FFW"; + case Zones::kGenSampleTop1: + return "Gen.SampleTop1"; + case Zones::kGenSampleTopK: + return "Gen.SampleTopK"; + case Zones::kMMDecompressA: + return "MM.DecompressA"; + case Zones::kMMDispatch: + return "MM.Dispatch"; + case Zones::kMMMatMul: + return "MM.MatMul"; + case Zones::kMMNT_K: + return "MM.NT_K"; + case Zones::kMMNT_MT_K: + return "MM.NT_MT_K"; + case Zones::kMMNT_MT: + return "MM.NT_MT"; + case Zones::kMMNT: + return "MM.NT"; + case Zones::kMMTwoMatMul: + return "MM.TwoMatMul"; + case Zones::kOpsAddFrom: + return "Ops.AddFrom"; + case Zones::kOpsLogitsSoftCap: + return "Ops.LogitsSoftCap"; + // case Zones::kOpsMulByConst: // removed due to overhead + // case Zones::kOpsMulByConstAndAdd: // removed due to overhead + // case Zones::kOpsMulByConstAndAddTile: // removed due to overhead + // case Zones::kOpsMulByConstAndAddTile4: // removed due to overhead + // case Zones::kOpsMulByConstAndAddVector: // removed due to overhead + case Zones::kOpsMulByConstTo: + return "Ops.MulByConstTo"; + case Zones::kOpsRmsNorm: + return "Ops.RMSNorm"; + case Zones::kOpsRmsNormInplace: + return "Ops.RMSNormInplace"; + case Zones::kOpsRmsNormMul: + return "Ops.RMSNormMul"; + case Zones::kOpsRope: + return "Ops.Rope"; + case Zones::kOpsRopeAndMulBy: + return "Ops.RopeAndMulBy"; + case Zones::kOpsSoftmax: + return "Ops.Softmax"; + case Zones::kStartupWeightsReadAllToBF16: + return "Startup.Weights.ReadAllToBF16"; + case Zones::kStartupWeightsReadBatches: + return "Startup.Weights.ReadBatches"; + default: + HWY_ABORT("Invalid zone %d.", static_cast(zone)); } -#endif } -hwy::profiler::ZoneHandle GetProfilerZone(Zones zone) { -#if PROFILER_ENABLED - return profiler_zone_handles[static_cast(zone)]; -#else - return hwy::profiler::ZoneHandle(); -#endif +hwy::ProfilerFlags ZoneFlags(Zones zone) { + switch (zone) { + case Zones::kFlashAttentionInclusive: + case Zones::kGenAttention: + case Zones::kGenAttentionComputeQKV: + case 
Zones::kGenAttentionDotSoftmaxWeightedSumInclusive: + case Zones::kGenAttentionSumHeads: + case Zones::kGenEmbed: + case Zones::kGenEmbeddingMatmul: + case Zones::kGenFFW: + return hwy::ProfilerFlags::kInclusive; + default: + return hwy::ProfilerFlags::kDefault; + } +} + +const char* CallerName(Callers caller) { + switch (caller) { + case Callers::kActivationBatched: + return "ActivationBatched"; + case Callers::kAllocateAndBindAll: + return "AllocateAndBindAll"; + case Callers::kAttComputeQKV: + return "Att.ComputeQKV"; + case Callers::kAttDotSoftmaxWeightedSum: + return "Att.DotSoftmaxWeightedSum"; + case Callers::kBlobWriter: + return "BlobWriter"; + case Callers::kCompress: + return "Compress"; + case Callers::kFixupWeights: + return "FixupWeights"; + case Callers::kFlashAttention: + return "FlashAttention"; + case Callers::kFlashRMSNormAndPositionalEncoding: + return "Flash.RMSNormAndPositionalEncoding"; + case Callers::kFlashTransposeQ: + return "Flash.TransposeQ"; + case Callers::kMMClusterForMC: + return "MM.ClusterForMC"; + case Callers::kMMClusterForMCNC: + return "MM.ClusterForMCNC"; + case Callers::kMMClusterForN: + return "MM.ClusterForN"; + case Callers::kMMHierForMC: + return "MM.HierForMC"; + case Callers::kMMHierForMCNC: + return "MM.HierForMCNC"; + case Callers::kMMHierForN: + return "MM.HierForN"; + case Callers::kOpsAddFromBatched: + return "Ops.AddFromBatched"; + case Callers::kOpsMaybeLogitsSoftCapBatched: + return "Ops.MaybeLogitsSoftCapBatched"; + case Callers::kOpsRMSNormBatched: + return "Ops.RMSNormBatched"; + case Callers::kOpsRMSNormInplaceBatched: + return "Ops.RMSNormInplaceBatched"; + case Callers::kReadAllToBF16: + return "ReadAllToBF16"; + case Callers::kReadBatches: + return "ReadBatches"; + case Callers::kSampleAndStream: + return "SampleAndStream"; + case Callers::kTest: // only for unit tests. + return "Test-only!"; + case Callers::kTunePool: + return "TunePool"; + case Callers::kVitDotSoftmax1: + return "Vit.DotSoftmax1"; + case Callers::kVitDotSoftmax2: + return "Vit.DotSoftmax2"; + case Callers::kVitDotSoftmax3: + return "Vit.DotSoftmax3"; + case Callers::kVitDotSoftmax4: + return "Vit.DotSoftmax4"; + default: + HWY_ABORT("Invalid caller %d.", static_cast(caller)); + } +} + +} // namespace + +ProfilerZones::ProfilerZones(hwy::Profiler& profiler) { + for (size_t i = 0;; ++i) { + const Zones zone = static_cast(i); + if (zone == Zones::kNumZones) break; + handles_[i] = profiler.AddZone(ZoneName(zone), ZoneFlags(zone)); + } +} + +PoolCallers::PoolCallers() { + for (size_t i = 0;; ++i) { + const Callers caller = static_cast(i); + if (caller == Callers::kNumCallers) break; + callers_[i] = hwy::ThreadPool::AddCaller(CallerName(caller)); + } } } // namespace gcpp diff --git a/util/zones.h b/util/zones.h index e78340a..5624e24 100644 --- a/util/zones.h +++ b/util/zones.h @@ -1,57 +1,123 @@ #ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_ #define THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_ +#include + +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/profiler.h" namespace gcpp { // Zones for the profiler. 
-enum class Zones { - kOpsRmsNormMul, - kOpsRmsNorm, - kOpsRmsNormInplace, - kOpsRope, - kOpsRopeAndMulBy, - kOpsAddFrom, - kOpsMulByConst, - kOpsMulByConstTo, - kOpsMulByConstAndAdd, - kOpsMulByConstAndAddTile, - kOpsMulByConstAndAddTile4, - kOpsMulByConstAndAddVector, - kOpsSoftmax, - kOpsLogitsSoftCap, - kFlashAttentionTransposeQ, +enum class Zones { // Keep sorted + kFlashAttentionFlashAttention, + kFlashAttentionInclusive, kFlashAttentionRmsNormAndPositionalEncoding, kFlashAttentionSingleFlashAttention, kFlashAttentionTileFlashAttention, kFlashAttentionTileFlashAttention4, - kFlashAttentionFlashAttention, + kFlashAttentionTransposeQ, kGenActivation, kGenActivationFused, + kGenAttention, + kGenAttentionComputeQKV, + kGenAttentionDotSoftmaxWeightedSumInclusive, + kGenAttentionDotSoftmaxWeightedSumPar, + kGenAttentionQDotK, + kGenAttentionSumHeads, + kGenEmbed, + kGenEmbeddingMatmul, + kGenFFW, kGenSampleTop1, kGenSampleTopK, - kGenAttentionQDotK, - kGenAttentionDotSoftmaxWeightedSumPar, - kStartupWeightsReadAllToBF16, - kStartupWeightsReadBatches, + kMMDecompressA, kMMDispatch, kMMMatMul, - kMMTwoMatMul, - kMMDecompressA, - kMMNT, kMMNT_K, - kMMNT_MT, kMMNT_MT_K, - kNumZones + kMMNT_MT, + kMMNT, + kMMTwoMatMul, + kOpsAddFrom, + kOpsLogitsSoftCap, + // kOpsMulByConst, // removed due to overhead + // kOpsMulByConstAndAdd, // removed due to overhead + // kOpsMulByConstAndAddTile, // removed due to overhead + // kOpsMulByConstAndAddTile4, // removed due to overhead + // kOpsMulByConstAndAddVector, // removed due to overhead + kOpsMulByConstTo, + kOpsRmsNorm, + kOpsRmsNormInplace, + kOpsRmsNormMul, + kOpsRope, + kOpsRopeAndMulBy, + kOpsSoftmax, + kStartupWeightsReadAllToBF16, + kStartupWeightsReadBatches, + kNumZones // must be last }; -// Initializes the profiler zones. Must be called before any other profiler -// functions. -void InitProfilerZones(hwy::Profiler& profiler); +// Owned by ThreadingContext. +class ProfilerZones { + public: + ProfilerZones(hwy::Profiler& profiler); -// Returns the zone handle for the given zone enum value. -hwy::profiler::ZoneHandle GetProfilerZone(Zones zone); + hwy::profiler::ZoneHandle Get(Zones zone) { + HWY_DASSERT(zone != Zones::kNumZones); + return handles_[static_cast(zone)]; + } + + private: + hwy::profiler::ZoneHandle handles_[static_cast(Zones::kNumZones)]; +}; + +enum class Callers { // Keep sorted + kActivationBatched, + kAllocateAndBindAll, + kAttComputeQKV, + kAttDotSoftmaxWeightedSum, + kBlobWriter, + kCompress, + kFixupWeights, + kFlashAttention, + kFlashRMSNormAndPositionalEncoding, + kFlashTransposeQ, + kMMClusterForMC, + kMMClusterForMCNC, + kMMClusterForN, + kMMHierForMC, + kMMHierForMCNC, + kMMHierForN, + kOpsAddFromBatched, + kOpsMaybeLogitsSoftCapBatched, + kOpsRMSNormBatched, + kOpsRMSNormInplaceBatched, + kReadAllToBF16, + kReadBatches, + kSampleAndStream, + kTest, // only for unit tests. + kTunePool, + kVitDotSoftmax1, + kVitDotSoftmax2, + kVitDotSoftmax3, + kVitDotSoftmax4, + kNumCallers // must be last +}; + +// Owned by ThreadingContext. +class PoolCallers { + public: + PoolCallers(); + + hwy::pool::Caller Get(Callers caller) { + HWY_DASSERT(caller != Callers::kNumCallers); + return callers_[static_cast(caller)]; + } + + private: + hwy::pool::Caller callers_[static_cast(Callers::kNumCallers)]; +}; } // namespace gcpp
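To recap the two new registries: `ThreadingContext` builds `ProfilerZones` and `PoolCallers` once in its constructor, and call sites fetch cheap handles by enum value. A minimal sketch (hypothetical function name) of the lookup pattern:

```cpp
#include "util/threading_context.h"  // ThreadingContext, Zones, Callers

namespace gcpp {

// Sketch only: `caller` is what ThreadPool::Run and the Parallelize* helpers
// now expect, and `zone` is what GCPP_ZONE passes to PROFILER_ZONE3.
inline void UseRegistries(ThreadingContext& ctx) {
  const hwy::pool::Caller caller = ctx.pool_callers.Get(Callers::kTest);
  const hwy::profiler::ZoneHandle zone =
      ctx.profiler_zones.Get(Zones::kOpsSoftmax);
  (void)caller;
  (void)zone;
}

}  // namespace gcpp
```

Adding a new zone or caller therefore means adding an enum value (keeping the list sorted) plus the matching `ZoneName`/`ZoneFlags` or `CallerName` case; the `HWY_ABORT` default in those switches catches values that were forgotten.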