From faa41029926bc6f827b94937fd915a9f3810adb3 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 13 Aug 2025 01:37:53 -0700 Subject: [PATCH] (Resubmit) Prepare profiler annotations for new API Pass hwy::Profiler& to low-level functions. Use ThreadingContext arg instead of NestedPools. Use new PROFILER_ZONE3. PiperOrigin-RevId: 794461159 --- BUILD.bazel | 2 + CMakeLists.txt | 2 +- MODULE.bazel | 2 +- README.md | 2 +- evals/cross_entropy.cc | 7 +- examples/hello_world/CMakeLists.txt | 2 +- examples/simplified_gemma/CMakeLists.txt | 2 +- gemma/attention.cc | 77 ++++++++------ gemma/attention.h | 4 +- gemma/gemma-inl.h | 29 ++--- gemma/gemma.cc | 29 ++--- gemma/tokenizer.cc | 4 +- gemma/vit.cc | 18 ++-- gemma/weights.cc | 23 ++-- ops/matmul.h | 17 +-- ops/ops-inl.h | 128 +++++++++++++---------- ops/ops_test.cc | 33 +++--- util/threading.cc | 3 + util/threading_context.cc | 3 +- util/threading_context.h | 1 + 20 files changed, 220 insertions(+), 168 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index ef111d4..6429f6d 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -80,6 +80,7 @@ cc_library( ":topology", # Placeholder for container detection, do not remove "@highway//:hwy", + "@highway//:profiler", "@highway//:thread_pool", "@highway//:topology", ], @@ -379,6 +380,7 @@ cc_test( "@highway//:hwy", "@highway//:hwy_test_util", "@highway//:nanobenchmark", #buildcleaner: keep + "@highway//:profiler", ], ) diff --git a/CMakeLists.txt b/CMakeLists.txt index ef2f2c8..b29a379 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9414b48aeec251b69e6cadbfa42bebb5ddae1c34 EXCLUDE_FROM_ALL) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64 EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(highway) ## Note: absl needs to be installed by sentencepiece. This will only happen if diff --git a/MODULE.bazel b/MODULE.bazel index 95fb5cc..73ae1ec 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5") # Require a more recent version.
git_override( module_name = "highway", - commit = "9414b48aeec251b69e6cadbfa42bebb5ddae1c34", + commit = "92d327e841d78e11ae888757a3e16d291951cf64", remote = "https://github.com/google/highway", ) diff --git a/README.md b/README.md index b389934..067051d 100644 --- a/README.md +++ b/README.md @@ -469,7 +469,7 @@ FetchContent_MakeAvailable(sentencepiece) FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main) FetchContent_MakeAvailable(gemma) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64) FetchContent_MakeAvailable(highway) ``` diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc index 09c3a42..c150041 100644 --- a/evals/cross_entropy.cc +++ b/evals/cross_entropy.cc @@ -84,8 +84,9 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size) { - Softmax(logits, vocab_size, /*worker=*/0); +void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size, + hwy::Profiler& p) { + Softmax(logits, vocab_size, p, hwy::Profiler::Thread()); } } // namespace HWY_NAMESPACE @@ -109,7 +110,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens, const SampleFunc sample_token = [&](float* probs, size_t vocab_size) -> TokenAndProb { // input is logits, not yet probabilities - HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size); + HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size, env.ctx.profiler); // We are called for each token, but pos starts at 1. Clamping // max_generated_tokens to prompt.size() should prevent overrun. 
HWY_ASSERT(pos < prompt.size()); diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt index 2f5d648..c466e1c 100644 --- a/examples/hello_world/CMakeLists.txt +++ b/examples/hello_world/CMakeLists.txt @@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) include(FetchContent) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG bfc36a6e633af94e63ac4b91c687bf0354cb24e0) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64) FetchContent_MakeAvailable(highway) FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9) FetchContent_MakeAvailable(sentencepiece) diff --git a/examples/simplified_gemma/CMakeLists.txt b/examples/simplified_gemma/CMakeLists.txt index 5595164..4723852 100644 --- a/examples/simplified_gemma/CMakeLists.txt +++ b/examples/simplified_gemma/CMakeLists.txt @@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) include(FetchContent) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9414b48aeec251b69e6cadbfa42bebb5ddae1c34) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64) FetchContent_MakeAvailable(highway) FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) FetchContent_MakeAvailable(sentencepiece) diff --git a/gemma/attention.cc b/gemma/attention.cc index 7fc8e76..bd8917e 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -19,6 +19,7 @@ #include #include "compression/types.h" // GEMMA_DISABLED_TARGETS +#include "util/threading_context.h" #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS @@ -53,8 +54,9 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT q, const MatPtrT& k, float* HWY_RESTRICT att, - const size_t worker) { - PROFILER_ZONE2(worker, "Gen.Attention.QDotK"); + hwy::Profiler& p, const size_t worker) { + static const auto zone = p.AddZone("Gen.Attention.QDotK"); + PROFILER_ZONE3(p, worker, zone); if (HWY_LIKELY(last_pos < static_cast(div_seq_len.GetDivisor()))) { // Slightly faster: no wraparound. for (size_t pos = start_pos; pos <= last_pos; ++pos) { @@ -73,8 +75,8 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, static void PositionalEncodingQK(float* qk, const size_t layer_idx, const LayerWeightsPtrs& layer, const AttentionActivations& activations, - const size_t worker, const size_t pos, - const float mul = 1.0f) { + hwy::Profiler& p, const size_t worker, + const size_t pos, const float mul = 1.0f) { const size_t qkv_dim = layer.layer_config.qkv_dim; const PostQKType& post_qk = layer.layer_config.post_qk; // qk is either q or k, so qkv_dim is the length we operate on. 
@@ -86,10 +88,10 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx, } // PostQKType::Rope if (post_qk == PostQKType::HalfRope) { - Rope(qk, qkv_dim / 2, inv_timescale, pos, worker); - if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, worker); + Rope(qk, qkv_dim / 2, inv_timescale, pos, p, worker); + if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, p, worker); } else { - RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, worker); + RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, p, worker); } } @@ -97,26 +99,31 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx, // `att_out`. Equivalent in gemma/modules.py: // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj) // `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim]. -static HWY_INLINE void WeightedSumV( - const size_t start_pos, const size_t last_pos, - const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att, - const MatPtrT& v, float* HWY_RESTRICT att_out, const size_t worker) { +static HWY_INLINE void WeightedSumV(const size_t start_pos, + const size_t last_pos, + const hwy::Divisor& div_seq_len, + const float* HWY_RESTRICT att, + const MatPtrT& v, + float* HWY_RESTRICT att_out, + hwy::Profiler& p, const size_t worker) { if (HWY_LIKELY(last_pos < static_cast(div_seq_len.GetDivisor()))) { // Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if // we supported non-transposed B. // TODO: 2..4x unroll - MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), worker); + MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), p, + worker); for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { - MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), worker); + MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), p, worker); } } else { { const size_t pos_mod = div_seq_len.Remainder(start_pos); - MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker); + MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p, worker); } for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { const size_t pos_mod = div_seq_len.Remainder(pos); - MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker); + MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p, + worker); } } } @@ -128,7 +135,7 @@ void SingleDotSoftmaxWeightedSum( float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, const size_t layer_idx, const LayerWeightsPtrs& layer, const AttentionActivations& activations, float* HWY_RESTRICT att, - float* HWY_RESTRICT att_out, const size_t worker) { + float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) { const float att_cap = activations.config.att_cap; const float query_scale = activations.query_scale; const size_t seq_len = @@ -138,21 +145,21 @@ void SingleDotSoftmaxWeightedSum( if (layer.query_norm_scale.HasPtr()) { CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { RMSNormInplace(weights_t->PackedScale1(), 0, q, - layer.layer_config.qkv_dim, worker); + layer.layer_config.qkv_dim, p, worker); }); } - PositionalEncodingQK(q, layer_idx, layer, activations, worker, pos, + PositionalEncodingQK(q, layer_idx, layer, activations, p, worker, pos, query_scale); - QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, worker); + QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, p, worker); // SoftMax with optional SoftCap yields "probabilities" in att. 
const size_t att_len = HWY_MIN(last_pos + 1, seq_len); - MaybeLogitsSoftCap(att_cap, att, att_len, worker); - Softmax(att, att_len, worker, /*temperature=*/1.0f); + MaybeLogitsSoftCap(att_cap, att, att_len, p, worker); + Softmax(att, att_len, p, worker, /*temperature=*/1.0f); - WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, + WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, p, worker); } @@ -167,9 +174,8 @@ static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config, void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, AttentionActivations& activations, QBatch& qbatch, - NestedPools& pools) { - static const uint32_t HWY_MAYBE_UNUSED zone_id_par = - PROFILER_ADD_ZONE("Gen.Attention.DotSoftmax.par"); + ThreadingContext& ctx) { + static const auto zone = ctx.profiler.AddZone("Gen.Attention.DotSoftmax.par"); const hwy::Divisor div_qbatch(qbatch.Size()); const LayerConfig& layer_config = layer.layer_config; @@ -189,9 +195,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, const auto func = [&](const size_t task, size_t worker) HWY_ATTR { const size_t tq_idx = activations.div_heads.Divide(task); const size_t head = activations.div_heads.Remainder(task); -#if PROFILER_ENABLED - const hwy::Zone zone(worker, zone_id_par); -#endif + PROFILER_ZONE3(ctx.profiler, worker, zone); const size_t qi = div_qbatch.Remainder(tq_idx); const size_t batch_idx = div_qbatch.Divide(tq_idx); @@ -223,14 +227,15 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride()); SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx, - layer, activations, att, att_out, worker); + layer, activations, att, att_out, ctx.profiler, + worker); }; { PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin"); // Full parallelism is helpful, SmallParallelFor is insufficient. ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads, - pools, func); + ctx.pools, func); } } @@ -303,12 +308,12 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, if (layer.key_norm_scale.HasPtr()) { CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) { RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim, - thread); + env.ctx.profiler, thread); }); } - PositionalEncodingQK(kv_f32, layer_idx, layer, activations, thread, - pos); + PositionalEncodingQK(kv_f32, layer_idx, layer, activations, + env.ctx.profiler, thread, pos); CompressPerThread tls; Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0); }); @@ -339,6 +344,10 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, AttentionActivations& activations, QBatch& qbatch, MatMulEnv& env, int flags) { + static const auto zone = + env.ctx.profiler.AddZone("Gen.Attention", ProfilerFlags::kInclusive); + PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone); + const LayerConfig& layer_config = layer.layer_config; HWY_DASSERT(!layer_config.IsMHA()); // No longer supported. 
HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0, @@ -347,7 +356,7 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx, ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env); DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch, - env.ctx.pools); + env.ctx); SumHeads(layer, activations, env); } diff --git a/gemma/attention.h b/gemma/attention.h index 42b2be1..c69cc8f 100644 --- a/gemma/attention.h +++ b/gemma/attention.h @@ -33,12 +33,12 @@ namespace gcpp { float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, \ size_t layer_idx, const LayerWeightsPtrs& layer, \ const AttentionActivations& activations, float* HWY_RESTRICT att, \ - float* HWY_RESTRICT att_out, size_t worker); \ + float* HWY_RESTRICT att_out, hwy::Profiler& p, size_t worker); \ \ void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \ const LayerWeightsPtrs& layer, \ AttentionActivations& activations, \ - QBatch& qbatch, NestedPools& pools); \ + QBatch& qbatch, ThreadingContext& ctx); \ \ void GemmaAttention(size_t num_tokens, const size_t layer_idx, \ const LayerWeightsPtrs& layer, \ diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 6c75cd2..9bda90f 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -45,9 +45,10 @@ namespace HWY_NAMESPACE { template void Activation(ActivationType activation, T* HWY_RESTRICT c1, - const T* HWY_RESTRICT c2, const size_t count, + const T* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p, const size_t worker) { - PROFILER_ZONE2(worker, "Gen.Activation"); + static const auto zone = p.AddZone("Gen.Activation"); + PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; @@ -64,28 +65,30 @@ void Activation(ActivationType activation, T* HWY_RESTRICT c1, // No C2 multiplier. template -void ActivationBatched(ActivationType activation, Mat& c1, NestedPools& pools) { +void ActivationBatched(ActivationType activation, Mat& c1, + ThreadingContext& ctx) { using T = typename Mat::T; - SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) { + SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) { // Cast to correct type so type deduction works. 
Activation(activation, c1.Row(task), static_cast(nullptr), - c1.Cols(), worker); + c1.Cols(), ctx.profiler, worker); }); } template HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1, - const Mat* c2, NestedPools& pools) { + const Mat* c2, ThreadingContext& ctx) { using T = typename Mat::T; HWY_DASSERT(c1.SameShape(*c2)); if (c2 && c2->HasPtr()) { - SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) { - Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(), worker); + SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) { + Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(), + ctx.profiler, worker); }); } else { // No multiplier - SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) { + SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) { Activation(activation, c1.Row(task), static_cast(nullptr), - c1.Cols(), worker); + c1.Cols(), ctx.profiler, worker); }); } } @@ -110,7 +113,9 @@ void PostNorm(PostNormType post_norm, const MatPtr& weights, static inline void FFWNoVit(const LayerWeightsPtrs& layer, Activations& activations, MatMulEnv& env) { - PROFILER_ZONE("Gen.FFW"); + static const auto zone = + env.ctx.profiler.AddZone("Gen.FFW", ProfilerFlags::kInclusive); + PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone); const LayerConfig& layer_config = layer.layer_config; const size_t ffh_hidden_dim = layer_config.ff_hidden_dim; @@ -129,7 +134,7 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer, // Activation (Gelu) and maybe multiply by gate. Store activations in act. ActivationBatched(layer_config.activation, activations.C1, &activations.C2, - env.ctx.pools); + env.ctx); // Hidden layer -> output layer. CallMatMul(activations.C1, layer.linear_w, output_bias, env, diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 8c05306..a7b8423 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -55,6 +55,7 @@ #include "io/io.h" // Path #include "ops/matmul.h" #include "paligemma/image.h" +#include "util/basics.h" // PROFILER_ZONE3 #include "util/threading_context.h" #include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" @@ -138,7 +139,8 @@ static float EmbeddingScaling(size_t model_dim) { static HWY_NOINLINE size_t EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt, const ModelConfig& model_config, const WeightsPtrs& weights, - MatStorageT& x, const ImageTokens* image_tokens = nullptr, + MatStorageT& x, ThreadingContext& ctx, + const ImageTokens* image_tokens = nullptr, size_t image_token_position = 0) { // Image tokens just need to be copied. if (model_config.wrapping == PromptWrapping::GEMMA_VLM && @@ -174,7 +176,8 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt, const hn::ScalableTag df; DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi), model_dim); - MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim, worker); + MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim, + ctx.profiler, worker); }); if (model_config.absolute_pe) { @@ -249,7 +252,7 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config, const int token = qbatch_1.Prompt(0)[pos_in_prompt]; image_token_position = EmbedMMToken( token, ti, pos, pos_in_prompt, config, weights, activations.x, - runtime_config.image_tokens, image_token_position); + env.ctx, runtime_config.image_tokens, image_token_position); } // Transformer with one batch of tokens from a single query. 
@@ -306,7 +309,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config, // TODO: parallelize? for (size_t qi = 0; qi < qbatch.Size(); ++qi) { EmbedMMToken(qbatch.PrevToken(qi), qi, qbatch.Pos(qi), - /*pos_in_prompt=*/0, config, weights, activations.x); + /*pos_in_prompt=*/0, config, weights, activations.x, env.ctx); } for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) { @@ -419,7 +422,8 @@ static void DecodeStepT(const ModelConfig& config, const size_t worker = 0; // TODO: parallelize non_eos.Foreach([&](size_t qi) { float* HWY_RESTRICT logits = activations.logits.Row(qi); - MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size, worker); + MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size, + env.ctx.profiler, worker); const TokenAndProb tp = sample_token(logits, config.vocab_size); timing_info.NotifyGenerated(); @@ -429,27 +433,28 @@ static void DecodeStepT(const ModelConfig& config, } static HWY_INLINE SampleFunc -ChooseSampleFunc(const RuntimeConfig& runtime_config) { +ChooseSampleFunc(const RuntimeConfig& runtime_config, ThreadingContext& ctx) { // If user provided a sample_func, use it. if (runtime_config.sample_func) return runtime_config.sample_func; + static const auto zone = ctx.profiler.AddZone("Gen.Sample Top1"); const size_t worker = 0; // TODO: parallelize // Fast path for top-1 with no accept_token. if (runtime_config.top_k == 1 && !runtime_config.accept_token) { - return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb { - PROFILER_ZONE2(worker, "Gen.Sample Top1"); + return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb { + PROFILER_ZONE3(ctx.profiler, worker, zone); return Top1OfSoftmax(logits, vocab_size); }; } // General case: Softmax with top-k sampling. 
- return [&runtime_config](float* logits, - size_t vocab_size) HWY_ATTR -> TokenAndProb { + return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb { PROFILER_ZONE("Gen.Sample general"); return FusedSoftmaxAndSampleTopK( logits, runtime_config.top_k, vocab_size, *runtime_config.gen, - runtime_config.temperature, runtime_config.accept_token, worker); + runtime_config.temperature, runtime_config.accept_token, ctx.profiler, + worker); }; } @@ -524,7 +529,7 @@ static void GenerateT(const ModelConfig& config, max_gen_steps = seq_len - max_prompt_size; } - const SampleFunc sample_token = ChooseSampleFunc(runtime_config); + const SampleFunc sample_token = ChooseSampleFunc(runtime_config, env.ctx); { timing_info.generate_start = hwy::platform::Now(); diff --git a/gemma/tokenizer.cc b/gemma/tokenizer.cc index e0e071c..6e6d7d1 100644 --- a/gemma/tokenizer.cc +++ b/gemma/tokenizer.cc @@ -21,8 +21,8 @@ #include #include -#include "gemma/configs.h" // PromptWrapping -#include "hwy/base.h" // HWY_ASSERT +#include "gemma/configs.h" // PromptWrapping +#include "hwy/base.h" // HWY_ASSERT #include "hwy/profiler.h" // copybara:import_next_line:sentencepiece #include "src/sentencepiece_processor.h" diff --git a/gemma/vit.cc b/gemma/vit.cc index 3549f85..a694187 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -95,7 +95,7 @@ class VitAttention { float* HWY_RESTRICT q = activations_.attention.q.Row(token) + head * 3 * qkv_dim; // TODO: shift to MatMul with A.scale once MatMul is confirmed working - MulByConst(query_scale, q, qkv_dim, worker); + MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker); hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float)); }); @@ -111,7 +111,7 @@ class VitAttention { pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR { float* HWY_RESTRICT c = C.Row(task); - Softmax(c, C.Cols(), worker); + Softmax(c, C.Cols(), env_.ctx.profiler, worker); }); pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR { @@ -121,7 +121,8 @@ class VitAttention { for (size_t i = 0; i < seq_len; ++i) { float* HWY_RESTRICT v = activations_.attention.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim; - MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim, worker); + MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim, + env_.ctx.profiler, worker); } }); } @@ -144,7 +145,7 @@ class VitAttention { // Compute Q.K scores, which are "logits" stored in head_att. float* HWY_RESTRICT q = activations_.attention.q.Row(token) + head * 3 * qkv_dim; - MulByConst(query_scale, q, qkv_dim, worker); + MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker); float* HWY_RESTRICT head_att = activations_.attention.att.Row(token) + head * seq_len; for (size_t i = 0; i < seq_len; ++i) { @@ -153,7 +154,7 @@ class VitAttention { head_att[i] = Dot(q, k, qkv_dim); // score = q.k } // SoftMax yields "probabilities" in head_att. - Softmax(head_att, seq_len, worker); + Softmax(head_att, seq_len, env_.ctx.profiler, worker); // Compute weighted sum of v into att_out. 
float* HWY_RESTRICT att_out = activations_.attention.att_out.Row(token) + head * qkv_dim; @@ -161,7 +162,8 @@ class VitAttention { for (size_t i = 0; i < seq_len; ++i) { float* HWY_RESTRICT v = activations_.attention.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim; - MulByConstAndAdd(head_att[i], v, att_out, qkv_dim, worker); + MulByConstAndAdd(head_att[i], v, att_out, qkv_dim, + env_.ctx.profiler, worker); } }); } @@ -224,7 +226,7 @@ void FFWVit(const LayerWeightsPtrs& layer, Activations& activations, activations.C1); // Activation (Gelu), store in C1. - ActivationBatched(layer_config.activation, activations.C1, env.ctx.pools); + ActivationBatched(layer_config.activation, activations.C1, env.ctx); // Hidden layer -> output layer. CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias, env, @@ -334,7 +336,7 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights, // Apply soft embedding norm before input projection. CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) { RMSNormInplace(weights_t->PackedScale1(), 0, activations.x.Row(0), - vit_model_dim, /*worker=*/0); + vit_model_dim, env.ctx.profiler, hwy::Profiler::Thread()); }); } diff --git a/gemma/weights.cc b/gemma/weights.cc index 721cfb6..4124247 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -381,9 +381,11 @@ static void DecompressToBF16(MatPtr& mat, } static void ReadAllToBF16(const std::vector& tensors, - const BlobReader& reader, hwy::ThreadPool& pool) { - pool.Run(0, tensors.size(), [&](uint64_t task, size_t thread) { - PROFILER_ZONE2(thread, "Startup.Weights.ReadBF16"); + const BlobReader& reader, ThreadingContext& ctx) { + static const auto zone = + ctx.profiler.AddZone("Startup.Weights.ReadAllToBF16"); + ctx.pools.Pool().Run(0, tensors.size(), [&](uint64_t task, size_t thread) { + PROFILER_ZONE3(ctx.profiler, thread, zone); const TensorToRead& tensor = tensors[task]; MatPtr& mat = *tensor.mat; @@ -460,10 +462,11 @@ static std::vector MakeBatches( // want to use the OS cache between consecutive runs. static void ReadBatches(const BlobReader& reader, const std::vector& batches, - hwy::ThreadPool& pool) { + ThreadingContext& ctx) { + static const auto zone = ctx.profiler.AddZone("Startup.Weights.ReadBatches"); // >5x speedup from parallel reads when cached. - pool.Run(0, batches.size(), [&](uint64_t i, size_t thread) { - PROFILER_ZONE2(thread, "Startup.Weights.Read"); + ctx.pools.Pool().Run(0, batches.size(), [&](uint64_t i, size_t thread) { + PROFILER_ZONE3(ctx.profiler, thread, zone); const IOBatch& batch = batches[i]; const std::string& key = reader.Keys()[batch.KeyIdx()]; const uint64_t bytes_read = batch.Read(reader.file()); @@ -500,16 +503,14 @@ static MapPtr MapOrReadAll(std::vector& tensors, AllocateAndBindAll(tensors, *mode, mat_owners, ctx); } - hwy::ThreadPool& pool = ctx.pools.Pool(); - if (*mode == WeightsPtrs::Mode::kReadBF16) { - ReadAllToBF16(tensors, reader, pool); + ReadAllToBF16(tensors, reader, ctx); return MapPtr(); } const std::vector batches = MakeBatches(tensors, reader.file_bytes()); - ReadBatches(reader, batches, pool); + ReadBatches(reader, batches, ctx); return MapPtr(); } @@ -519,7 +520,7 @@ WeightsPtrs::Mode WeightsPtrs::ReadFromBlobs(const ModelStore& model, const InferenceArgs& inference, std::vector& mat_owners, ThreadingContext& ctx) { - PROFILER_ZONE("Startup.ReadFromBlobs"); + PROFILER_ZONE("Startup.Weights.ReadFromBlobs"); // List of tensors to read/map, and where from. 
std::vector tensors; diff --git a/ops/matmul.h b/ops/matmul.h index 69e2256..de8ef8c 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -720,32 +720,33 @@ struct MMArgs { // Wrapper over hwy::Zone that is only enabled when autotuning finished. #if PROFILER_ENABLED class MMZone { - using Zone = hwy::Zone; - static_assert(alignof(Zone) <= 8 && sizeof(Zone) <= 8); + using Zone = hwy::profiler::Zone; + static_assert(alignof(Zone) <= 8 && sizeof(Zone) <= 16); public: ~MMZone() { - if (used_) { + if (data_ != 0) { Zone* zone = reinterpret_cast(&data_); zone->~Zone(); } } // `name` must be a string literal. - void MaybeEnter(size_t thread_id, uint32_t zone_id, const MMArgs& args) { + void MaybeEnter(size_t thread, hwy::profiler::ZoneHandle zone, + const MMArgs& args) { if (args.per_key->WantProfile()) { - new (&data_) Zone(thread_id, zone_id); - used_ = true; + new (&data_) Zone(args.env->ctx.profiler, thread, zone); + HWY_DASSERT(data_ != 0); } } private: uint64_t data_ = 0; - bool used_ = false; + uint64_t data2_ = 0; }; #else struct MMZone { - void MaybeEnter(size_t, uint32_t, const MMArgs&) {} + void MaybeEnter(size_t, hwy::profiler::ZoneHandle, const MMArgs&) {} }; #endif // PROFILER_ENABLED diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 688450b..343600a 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -125,9 +125,9 @@ HWY_INLINE hn::Vec Gelu(D d, hn::Vec v) { return hn::Mul(v, cdf); } +// Activation already has a profiler zone. static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x, size_t size) { - PROFILER_ZONE("ops.Gelu"); namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; hn::Transform(D(), x, size, @@ -191,9 +191,10 @@ namespace detail { // Shared by RMSNorm and RMSNormInplace. template -float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, - const HWY_MAYBE_UNUSED size_t worker) { - PROFILER_ZONE2(worker, "ops.RMSNormMul"); +float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p, + const size_t worker) { + static const auto zone = p.AddZone("Ops.RMSNormMul"); + PROFILER_ZONE3(p, worker, zone); const hn::ScalableTag d; const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault()); @@ -205,18 +206,20 @@ float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, // `x_ofs` is the offset within `x`, required for NuqStream. 
template -HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( - const XT* HWY_RESTRICT x, const WT* HWY_RESTRICT weight, size_t w_ofs, - OT* HWY_RESTRICT out, const size_t size, - const size_t HWY_MAYBE_UNUSED worker) { - PROFILER_ZONE2(worker, "ops.RMSNorm"); +HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x, + const WT* HWY_RESTRICT weight, + size_t w_ofs, OT* HWY_RESTRICT out, + const size_t size, hwy::Profiler& p, + const size_t worker) { + static const auto zone = p.AddZone("Ops.RMSNorm"); + PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; const hn::ScalableTag df; using VF = hn::Vec; const size_t NF = hn::Lanes(df); - const VF mul = hn::Set(df, detail::RMSNormMul(x, size, worker)); + const VF mul = hn::Set(df, detail::RMSNormMul(x, size, p, worker)); const auto packed_x = MakeSpan(x, size); const auto packed_w = MakeSpan(weight, w_ofs + size); @@ -240,15 +243,16 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( template HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( const WT* HWY_RESTRICT weight, size_t w_ofs, XT* HWY_RESTRICT inout, - const size_t size, const HWY_MAYBE_UNUSED size_t worker) { - PROFILER_ZONE2(worker, "ops.RMSNormInplace"); + const size_t size, hwy::Profiler& p, const size_t worker) { + static const auto zone = p.AddZone("Ops.RMSNormInplace"); + PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; const hn::ScalableTag df; using VF = hn::Vec; const size_t NF = hn::Lanes(df); - const VF mul = hn::Set(df, detail::RMSNormMul(inout, size, worker)); + const VF mul = hn::Set(df, detail::RMSNormMul(inout, size, p, worker)); const auto packed_w = MakeSpan(weight, w_ofs + size); const auto packed_x = MakeSpan(inout, size); @@ -407,9 +411,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings( // This overload is called if `post_qk == PostQKType::HalfRope`. static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope( float* HWY_RESTRICT x, const size_t dim_qkv, - const float* HWY_RESTRICT inv_timescale, const int pos, - const size_t HWY_MAYBE_UNUSED worker = 0) { - PROFILER_ZONE2(worker, "ops.Rope"); + const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p, + const size_t worker) { + static const auto zone = p.AddZone("Ops.Rope"); + PROFILER_ZONE3(p, worker, zone); HWY_DASSERT(dim_qkv % 2 == 0); const size_t half_dim_qkv = dim_qkv / 2; @@ -466,9 +471,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope( // `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations. 
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy( const float mul, float* HWY_RESTRICT x, const size_t dim_qkv, - const float* HWY_RESTRICT inv_timescale, const int pos, - const size_t HWY_MAYBE_UNUSED worker = 0) { - PROFILER_ZONE2(worker, "ops.RopeAndMulBy"); + const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p, + const size_t worker) { + static const auto zone = p.AddZone("Ops.RopeAndMulBy"); + PROFILER_ZONE3(p, worker, zone); HWY_DASSERT(dim_qkv % 2 == 0); const size_t half_dim_qkv = dim_qkv / 2; @@ -525,10 +531,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy( } template -static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom( - const XT* HWY_RESTRICT x, float* HWY_RESTRICT out, const size_t size, - const HWY_MAYBE_UNUSED size_t worker) { - PROFILER_ZONE2(worker, "ops.AddFrom"); +static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x, + float* HWY_RESTRICT out, + const size_t size, + hwy::Profiler& p, + const size_t worker) { + static const auto zone = p.AddZone("Ops.AddFrom"); + PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; const hn::ScalableTag df; @@ -576,12 +585,11 @@ void RMSNormBatched(const MatPtrT& activations, const MatPtr& weights, HWY_DASSERT(activations.SameShape(out)); CallUpcasted(&weights, [&](const auto* weights_t) { - SmallParallelFor(activations.Rows(), ctx.pools, - [&](uint64_t token_idx, size_t worker) { - RMSNorm(activations.Row(token_idx), - weights_t->PackedScale1(), 0, out.Row(token_idx), - activations.Cols(), worker); - }); + SmallParallelFor( + activations.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) { + RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), 0, + out.Row(token_idx), activations.Cols(), ctx.profiler, worker); + }); }); } @@ -592,12 +600,11 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT& inout, HWY_DASSERT(weights.Cols() == inout.Cols()); CallUpcasted(&weights, [&](const auto* weights_t) { - SmallParallelFor(inout.Rows(), ctx.pools, - [&](uint64_t token_idx, size_t worker) { - RMSNormInplace(weights_t->PackedScale1(), 0, - inout.Row(token_idx), inout.Cols(), - worker); - }); + SmallParallelFor( + inout.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) { + RMSNormInplace(weights_t->PackedScale1(), 0, inout.Row(token_idx), + inout.Cols(), ctx.profiler, worker); + }); }); } @@ -622,17 +629,20 @@ template static HWY_INLINE void AddFromBatched(const MatPtrT& x, MatPtrT& out, ThreadingContext& ctx) { HWY_DASSERT(out.SameShape(x)); - SmallParallelFor( - out.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) { - AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), worker); - }); + SmallParallelFor(out.Rows(), ctx.pools, + [&](uint64_t token_idx, size_t worker) { + AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), + ctx.profiler, worker); + }); } template -HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst( - const float c, XT* HWY_RESTRICT x, const size_t size, - const HWY_MAYBE_UNUSED size_t worker) { - PROFILER_ZONE2(worker, "ops.MulByConst"); +HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x, + const size_t size, + hwy::Profiler& p, + const size_t worker) { + static const auto zone = p.AddZone("Ops.MulByConst"); + PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; const hn::ScalableTag df; const size_t NF = hn::Lanes(df); @@ -672,8 +682,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst( template HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo( const float c, const XT* 
HWY_RESTRICT x, OT* HWY_RESTRICT out, - const size_t size, const HWY_MAYBE_UNUSED size_t worker) { - PROFILER_ZONE2(worker, "ops.MulByConstTo"); + const size_t size, hwy::Profiler& p, const size_t worker) { + static const auto zone = p.AddZone("Ops.MulByConstTo"); + PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; const hn::ScalableTag df; const size_t NF = hn::Lanes(df); @@ -714,8 +725,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo( template HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, - const size_t size, const HWY_MAYBE_UNUSED size_t worker) { - PROFILER_ZONE2(worker, "ops.MulByConstAndAdd"); + const size_t size, hwy::Profiler& p, const size_t worker) { + static const auto zone = p.AddZone("Ops.MulByConstAndAdd"); + PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; const hn::ScalableTag df; const size_t NF = hn::Lanes(df); @@ -760,9 +772,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( // See below for a specialized version for top-1 sampling. static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, - const size_t worker, + hwy::Profiler& p, const size_t worker, float temperature = 1.0f) { - PROFILER_ZONE2(worker, "ops.Softmax"); + static const auto zone = p.AddZone("Ops.Softmax"); + PROFILER_ZONE3(p, worker, zone); HWY_DASSERT(size != 0); namespace hn = hwy::HWY_NAMESPACE; @@ -803,7 +816,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, const float sum_exp = Sum(d, x, size); // Double-precision reciprocal does not appear to affect the results. const float mul = 1.0f / sum_exp; - MulByConst(mul, x, size, worker); + MulByConst(mul, x, size, p, worker); } // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max / @@ -893,9 +906,10 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x, } static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, - const size_t size, - const HWY_MAYBE_UNUSED size_t worker) { - PROFILER_ZONE2(worker, "ops.LogitsSoftCap"); + const size_t size, hwy::Profiler& p, + const size_t worker) { + static const auto zone = p.AddZone("Ops.LogitsSoftCap"); + PROFILER_ZONE3(p, worker, zone); namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; @@ -911,10 +925,10 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, // Calls LogitsSoftCap if cap != 0.0f. static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap( - const float cap, float* HWY_RESTRICT x, const size_t size, + const float cap, float* HWY_RESTRICT x, const size_t size, hwy::Profiler& p, const size_t worker) { if (cap != 0.0f) { - LogitsSoftCap(cap, x, size, worker); + LogitsSoftCap(cap, x, size, p, worker); } } @@ -998,7 +1012,7 @@ template HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK( const float* HWY_RESTRICT logits, size_t k, size_t vocab_size, std::mt19937& gen, float temperature, TAcceptToken& accept_token, - size_t worker) { + hwy::Profiler& p, size_t worker) { // Softmax and sample top-K is equivalent to taking the top-K logits and // sampling from the softmax of the top-K logits. The latter is faster as it // avoids computing the softmax of all logits. 
@@ -1012,7 +1026,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK( } size_t mask = token_logits.size(); - Softmax(topk_logits.data(), mask, worker, temperature); + Softmax(topk_logits.data(), mask, p, worker, temperature); auto distribution = std::discrete_distribution( std::begin(topk_logits), std::begin(topk_logits) + mask); int topk_sampled_index = distribution(gen); diff --git a/ops/ops_test.cc b/ops/ops_test.cc index 2a51839..e935ccf 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -18,8 +18,6 @@ #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS -#include "ops/ops.h" - #include #include @@ -33,11 +31,13 @@ #include "gemma/activations.h" // ChooseQueryScale #include "gemma/configs.h" +#include "ops/ops.h" #include "util/allocator.h" #include "util/basics.h" // BF16 #include "util/mat.h" // MatStorageT #include "util/test_util.h" #include "util/threading_context.h" +#include "hwy/profiler.h" #include "hwy/tests/hwy_gtest.h" // clang-format off @@ -166,7 +166,7 @@ struct TestAddFrom { } SimpleAddFrom(o, e, count); - AddFrom(o, x, count, /*worker=*/0); + AddFrom(o, x, count, hwy::Profiler::Get(), /*worker=*/0); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); @@ -199,7 +199,7 @@ struct TestMulByConstAndAdd { T constant = Random(rng); SimpleMulByConstAndAdd(constant, o, e, count); - MulByConstAndAdd(constant, o, x, count, /*worker=*/0); + MulByConstAndAdd(constant, o, x, count, hwy::Profiler::Get(), /*worker=*/0); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); @@ -229,7 +229,7 @@ struct TestMulByConst { T constant = Random(rng); SimpleMulByConst(constant, e, count); - MulByConst(constant, x, count, /*worker=*/0); + MulByConst(constant, x, count, hwy::Profiler::Get(), /*worker=*/0); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); @@ -259,7 +259,7 @@ struct TestSoftmax { } SimpleSoftmax(e, count); - Softmax(x, count, /*worker=*/0); + Softmax(x, count, hwy::Profiler::Get(), /*worker=*/0); T sum = 0.0f; for (size_t i = 0; i < count; ++i) { @@ -349,6 +349,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy( void TestRopeAndMulBy() { ThreadingArgs threading_args; ThreadingContext ctx(threading_args); + hwy::Profiler& p = ctx.profiler; + const size_t worker = 0; + const ModelConfig config(Model::GEMMA2_9B, Type::kSFP, ChooseWrapping(Model::GEMMA2_9B)); const size_t dim_qkv = config.layer_configs[0].qkv_dim; @@ -381,7 +384,8 @@ void TestRopeAndMulBy() { CopyMat(x, qactual); ScalarRopeAndMulBy(qmul, qexpected.Row(0), dim_qkv, inv_timescale.Row(0), pos); - RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos); + RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, + worker); for (size_t i = 0; i < dim_qkv; ++i) { EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i; } @@ -391,7 +395,7 @@ void TestRopeAndMulBy() { CopyMat(x, qactual); ScalarRopeAndMulBy(1.0f, qexpected.Row(0), dim_qkv, inv_timescale.Row(0), pos); - Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos); + Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, worker); for (size_t i = 0; i < dim_qkv; ++i) { EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i; } @@ -402,9 +406,10 @@ void TestRopeAndMulBy() { CopyMat(x, kactual2); ScalarRopeAndMulBy(kmul, kexpected.Row(0), dim_qkv, inv_timescale.Row(0), pos); - RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, 
inv_timescale.Row(0), pos); + RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, + worker); static_assert(kmul == 1.0f, ""); - Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos); + Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, worker); for (size_t i = 0; i < dim_qkv; ++i) { EXPECT_NEAR(kexpected.Row(0)[i], kactual.Row(0)[i], 1e-4) << " " << i; @@ -454,7 +459,7 @@ void TestRMSNorm(hwy::RandomState& rng) { } ScalarRMSNorm(vec, weight, expected, kSize); - RMSNorm(vec, weight, 0, actual, kSize, /*worker=*/0); + RMSNorm(vec, weight, 0, actual, kSize, hwy::Profiler::Get(), /*worker=*/0); for (size_t i = 0; i < kSize; i++) { const float e = hwy::ConvertScalarTo(expected[i]); @@ -580,11 +585,13 @@ void TestAllLayerNorm() { } void TestSampleTopK() { + hwy::Profiler& p = hwy::Profiler::Get(); + const size_t worker = 0; const size_t kSize = 52; std::vector logits(kSize); // Create a vector going from -100 to -100+51=49 and take Softmax. std::iota(logits.begin(), logits.end(), -100.0f); - Softmax(logits.data(), kSize, /*worker=*/0); + Softmax(logits.data(), kSize, p, worker); std::mt19937 gen; gen.seed(0x12345678); float temperature = 1.0f; @@ -600,7 +607,7 @@ void TestSampleTopK() { EXPECT_EQ(sample, 50); // Last even index. // Reset the logits to a positive, increasing sequence and take Softmax. std::iota(logits.begin(), logits.end(), 1.0f); - Softmax(logits.data(), kSize, /*worker=*/0); + Softmax(logits.data(), kSize, p, worker); // Sample from the top 3, expect one of the top 3 even indices. for (int i = 0; i < 100; ++i) { sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature, diff --git a/util/threading.cc b/util/threading.cc index 6930612..1001f05 100644 --- a/util/threading.cc +++ b/util/threading.cc @@ -29,6 +29,7 @@ #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/topology.h" +#include "hwy/profiler.h" namespace gcpp { @@ -171,6 +172,8 @@ NestedPools::NestedPools(const BoundedTopology& topology, HWY_ASSERT(max_clusters_per_package_ <= 64); HWY_ASSERT(max_workers_per_cluster_ >= 1); HWY_ASSERT(max_workers_per_cluster_ <= 256); + + hwy::Profiler::Get().SetMaxThreads(MaxWorkers()); } // `max_or_zero` == 0 means no limit. diff --git a/util/threading_context.cc b/util/threading_context.cc index 3bb1080..81155c5 100644 --- a/util/threading_context.cc +++ b/util/threading_context.cc @@ -72,7 +72,8 @@ static void TunePool(hwy::ThreadPool& pool) { } ThreadingContext::ThreadingContext(const ThreadingArgs& args) - : topology(BoundedSlice(args.skip_packages, args.max_packages), + : profiler(hwy::Profiler::Get()), + topology(BoundedSlice(args.skip_packages, args.max_packages), BoundedSlice(args.skip_clusters, args.max_clusters), BoundedSlice(args.skip_lps, args.max_lps)), allocator(topology, args.bind != Tristate::kFalse), diff --git a/util/threading_context.h b/util/threading_context.h index 8d14fdf..08387d0 100644 --- a/util/threading_context.h +++ b/util/threading_context.h @@ -90,6 +90,7 @@ struct ThreadingContext { // Expected to be called early in the program, before threading starts. explicit ThreadingContext(const ThreadingArgs& args); + hwy::Profiler& profiler; BoundedTopology topology; Allocator allocator; NestedPools pools;
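
For reviewers new to the updated Highway profiler API, the per-call-site pattern this patch migrates to is sketched below. This is a minimal illustration and not part of the patch: LowLevelOp is a hypothetical function, and the hwy::Profiler members (AddZone, Thread) and the PROFILER_ZONE3 macro are assumed to behave as at the pinned Highway commit 92d327e.

#include <stddef.h>
#include "hwy/profiler.h"

// Hypothetical low-level op, annotated the way this patch annotates
// Softmax/Rope/MulByConst: callers pass hwy::Profiler& plus a worker index.
void LowLevelOp(float* data, size_t size, hwy::Profiler& p, size_t worker) {
  // Register the zone name once per process; reuse the handle on every call.
  static const auto zone = p.AddZone("Ops.LowLevelOp");
  PROFILER_ZONE3(p, worker, zone);  // RAII timer for the rest of this scope.
  for (size_t i = 0; i < size; ++i) data[i] *= 2.0f;
}

// Call sites obtain the profiler from ThreadingContext instead of receiving
// the old NestedPools argument, e.g. on the main thread:
//   LowLevelOp(x, n, ctx.profiler, hwy::Profiler::Thread());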