(Resubmit) Prepare profiler annotations for new API

Pass hwy::Profiler& to low-level functions.
Use ThreadingContext arg instead of NestedPools.
Use new PROFILER_ZONE3.

PiperOrigin-RevId: 794461159
This commit is contained in:
Jan Wassenberg 2025-08-13 01:37:53 -07:00 committed by Copybara-Service
parent a2d9133f7d
commit faa4102992
20 changed files with 220 additions and 168 deletions

View File

@ -80,6 +80,7 @@ cc_library(
":topology", ":topology",
# Placeholder for container detection, do not remove # Placeholder for container detection, do not remove
"@highway//:hwy", "@highway//:hwy",
"@highway//:profiler",
"@highway//:thread_pool", "@highway//:thread_pool",
"@highway//:topology", "@highway//:topology",
], ],
@ -379,6 +380,7 @@ cc_test(
"@highway//:hwy", "@highway//:hwy",
"@highway//:hwy_test_util", "@highway//:hwy_test_util",
"@highway//:nanobenchmark", #buildcleaner: keep "@highway//:nanobenchmark", #buildcleaner: keep
"@highway//:profiler",
], ],
) )

View File

@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9414b48aeec251b69e6cadbfa42bebb5ddae1c34 EXCLUDE_FROM_ALL) FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64 EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(highway) FetchContent_MakeAvailable(highway)
## Note: absl needs to be installed by sentencepiece. This will only happen if ## Note: absl needs to be installed by sentencepiece. This will only happen if

View File

@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
# Require a more recent version. # Require a more recent version.
git_override( git_override(
module_name = "highway", module_name = "highway",
commit = "9414b48aeec251b69e6cadbfa42bebb5ddae1c34", commit = "92d327e841d78e11ae888757a3e16d291951cf64",
remote = "https://github.com/google/highway", remote = "https://github.com/google/highway",
) )

View File

@ -469,7 +469,7 @@ FetchContent_MakeAvailable(sentencepiece)
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main) FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
FetchContent_MakeAvailable(gemma) FetchContent_MakeAvailable(gemma)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f) FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64)
FetchContent_MakeAvailable(highway) FetchContent_MakeAvailable(highway)
``` ```

View File

@ -84,8 +84,9 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp { namespace gcpp {
namespace HWY_NAMESPACE { namespace HWY_NAMESPACE {
void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size) { void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size,
Softmax(logits, vocab_size, /*worker=*/0); hwy::Profiler& p) {
Softmax(logits, vocab_size, p, hwy::Profiler::Thread());
} }
} // namespace HWY_NAMESPACE } // namespace HWY_NAMESPACE
@ -109,7 +110,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
const SampleFunc sample_token = [&](float* probs, const SampleFunc sample_token = [&](float* probs,
size_t vocab_size) -> TokenAndProb { size_t vocab_size) -> TokenAndProb {
// input is logits, not yet probabilities // input is logits, not yet probabilities
HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size); HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size, env.ctx.profiler);
// We are called for each token, but pos starts at 1. Clamping // We are called for each token, but pos starts at 1. Clamping
// max_generated_tokens to prompt.size() should prevent overrun. // max_generated_tokens to prompt.size() should prevent overrun.
HWY_ASSERT(pos < prompt.size()); HWY_ASSERT(pos < prompt.size());

View File

@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
include(FetchContent) include(FetchContent)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG bfc36a6e633af94e63ac4b91c687bf0354cb24e0) FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64)
FetchContent_MakeAvailable(highway) FetchContent_MakeAvailable(highway)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9) FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9)
FetchContent_MakeAvailable(sentencepiece) FetchContent_MakeAvailable(sentencepiece)

View File

@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
include(FetchContent) include(FetchContent)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9414b48aeec251b69e6cadbfa42bebb5ddae1c34) FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64)
FetchContent_MakeAvailable(highway) FetchContent_MakeAvailable(highway)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece) FetchContent_MakeAvailable(sentencepiece)

View File

@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS #include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "util/threading_context.h"
#ifndef HWY_DISABLED_TARGETS #ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS
@ -53,8 +54,9 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
const hwy::Divisor& div_seq_len, const hwy::Divisor& div_seq_len,
const float* HWY_RESTRICT q, const float* HWY_RESTRICT q,
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att, const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
const size_t worker) { hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE2(worker, "Gen.Attention.QDotK"); static const auto zone = p.AddZone("Gen.Attention.QDotK");
PROFILER_ZONE3(p, worker, zone);
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) { if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound. // Slightly faster: no wraparound.
for (size_t pos = start_pos; pos <= last_pos; ++pos) { for (size_t pos = start_pos; pos <= last_pos; ++pos) {
@ -73,8 +75,8 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
static void PositionalEncodingQK(float* qk, const size_t layer_idx, static void PositionalEncodingQK(float* qk, const size_t layer_idx,
const LayerWeightsPtrs& layer, const LayerWeightsPtrs& layer,
const AttentionActivations& activations, const AttentionActivations& activations,
const size_t worker, const size_t pos, hwy::Profiler& p, const size_t worker,
const float mul = 1.0f) { const size_t pos, const float mul = 1.0f) {
const size_t qkv_dim = layer.layer_config.qkv_dim; const size_t qkv_dim = layer.layer_config.qkv_dim;
const PostQKType& post_qk = layer.layer_config.post_qk; const PostQKType& post_qk = layer.layer_config.post_qk;
// qk is either q or k, so qkv_dim is the length we operate on. // qk is either q or k, so qkv_dim is the length we operate on.
@ -86,10 +88,10 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
} }
// PostQKType::Rope // PostQKType::Rope
if (post_qk == PostQKType::HalfRope) { if (post_qk == PostQKType::HalfRope) {
Rope(qk, qkv_dim / 2, inv_timescale, pos, worker); Rope(qk, qkv_dim / 2, inv_timescale, pos, p, worker);
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, worker); if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, p, worker);
} else { } else {
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, worker); RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, p, worker);
} }
} }
@ -97,26 +99,31 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx,
// `att_out`. Equivalent in gemma/modules.py: // `att_out`. Equivalent in gemma/modules.py:
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj) // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
// `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim]. // `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
static HWY_INLINE void WeightedSumV( static HWY_INLINE void WeightedSumV(const size_t start_pos,
const size_t start_pos, const size_t last_pos, const size_t last_pos,
const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att, const hwy::Divisor& div_seq_len,
const MatPtrT<KV_t>& v, float* HWY_RESTRICT att_out, const size_t worker) { const float* HWY_RESTRICT att,
const MatPtrT<KV_t>& v,
float* HWY_RESTRICT att_out,
hwy::Profiler& p, const size_t worker) {
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) { if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if // Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if
// we supported non-transposed B. // we supported non-transposed B.
// TODO: 2..4x unroll // TODO: 2..4x unroll
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), worker); MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), p,
worker);
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), worker); MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), p, worker);
} }
} else { } else {
{ {
const size_t pos_mod = div_seq_len.Remainder(start_pos); const size_t pos_mod = div_seq_len.Remainder(start_pos);
MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker); MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p, worker);
} }
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = div_seq_len.Remainder(pos); const size_t pos_mod = div_seq_len.Remainder(pos);
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker); MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p,
worker);
} }
} }
} }
@ -128,7 +135,7 @@ void SingleDotSoftmaxWeightedSum(
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
const size_t layer_idx, const LayerWeightsPtrs& layer, const size_t layer_idx, const LayerWeightsPtrs& layer,
const AttentionActivations& activations, float* HWY_RESTRICT att, const AttentionActivations& activations, float* HWY_RESTRICT att,
float* HWY_RESTRICT att_out, const size_t worker) { float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) {
const float att_cap = activations.config.att_cap; const float att_cap = activations.config.att_cap;
const float query_scale = activations.query_scale; const float query_scale = activations.query_scale;
const size_t seq_len = const size_t seq_len =
@ -138,21 +145,21 @@ void SingleDotSoftmaxWeightedSum(
if (layer.query_norm_scale.HasPtr()) { if (layer.query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), 0, q, RMSNormInplace(weights_t->PackedScale1(), 0, q,
layer.layer_config.qkv_dim, worker); layer.layer_config.qkv_dim, p, worker);
}); });
} }
PositionalEncodingQK(q, layer_idx, layer, activations, worker, pos, PositionalEncodingQK(q, layer_idx, layer, activations, p, worker, pos,
query_scale); query_scale);
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, worker); QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, p, worker);
// SoftMax with optional SoftCap yields "probabilities" in att. // SoftMax with optional SoftCap yields "probabilities" in att.
const size_t att_len = HWY_MIN(last_pos + 1, seq_len); const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
MaybeLogitsSoftCap(att_cap, att, att_len, worker); MaybeLogitsSoftCap(att_cap, att, att_len, p, worker);
Softmax(att, att_len, worker, /*temperature=*/1.0f); Softmax(att, att_len, p, worker, /*temperature=*/1.0f);
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, p,
worker); worker);
} }
@ -167,9 +174,8 @@ static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer, const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch, AttentionActivations& activations, QBatch& qbatch,
NestedPools& pools) { ThreadingContext& ctx) {
static const uint32_t HWY_MAYBE_UNUSED zone_id_par = static const auto zone = ctx.profiler.AddZone("Gen.Attention.DotSoftmax.par");
PROFILER_ADD_ZONE("Gen.Attention.DotSoftmax.par");
const hwy::Divisor div_qbatch(qbatch.Size()); const hwy::Divisor div_qbatch(qbatch.Size());
const LayerConfig& layer_config = layer.layer_config; const LayerConfig& layer_config = layer.layer_config;
@ -189,9 +195,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
const auto func = [&](const size_t task, size_t worker) HWY_ATTR { const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
const size_t tq_idx = activations.div_heads.Divide(task); const size_t tq_idx = activations.div_heads.Divide(task);
const size_t head = activations.div_heads.Remainder(task); const size_t head = activations.div_heads.Remainder(task);
#if PROFILER_ENABLED PROFILER_ZONE3(ctx.profiler, worker, zone);
const hwy::Zone zone(worker, zone_id_par);
#endif
const size_t qi = div_qbatch.Remainder(tq_idx); const size_t qi = div_qbatch.Remainder(tq_idx);
const size_t batch_idx = div_qbatch.Divide(tq_idx); const size_t batch_idx = div_qbatch.Divide(tq_idx);
@ -223,14 +227,15 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride()); v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx, SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
layer, activations, att, att_out, worker); layer, activations, att, att_out, ctx.profiler,
worker);
}; };
{ {
PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin"); PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
// Full parallelism is helpful, SmallParallelFor is insufficient. // Full parallelism is helpful, SmallParallelFor is insufficient.
ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads, ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
pools, func); ctx.pools, func);
} }
} }
@ -303,12 +308,12 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
if (layer.key_norm_scale.HasPtr()) { if (layer.key_norm_scale.HasPtr()) {
CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) { CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim, RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim,
thread); env.ctx.profiler, thread);
}); });
} }
PositionalEncodingQK(kv_f32, layer_idx, layer, activations, thread, PositionalEncodingQK(kv_f32, layer_idx, layer, activations,
pos); env.ctx.profiler, thread, pos);
CompressPerThread tls; CompressPerThread tls;
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0); Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
}); });
@ -339,6 +344,10 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer, const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch, AttentionActivations& activations, QBatch& qbatch,
MatMulEnv& env, int flags) { MatMulEnv& env, int flags) {
static const auto zone =
env.ctx.profiler.AddZone("Gen.Attention", ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
const LayerConfig& layer_config = layer.layer_config; const LayerConfig& layer_config = layer.layer_config;
HWY_DASSERT(!layer_config.IsMHA()); // No longer supported. HWY_DASSERT(!layer_config.IsMHA()); // No longer supported.
HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0, HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,
@ -347,7 +356,7 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env); ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch, DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
env.ctx.pools); env.ctx);
SumHeads(layer, activations, env); SumHeads(layer, activations, env);
} }

View File

@ -33,12 +33,12 @@ namespace gcpp {
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \ float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \ size_t layer_idx, const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, float* HWY_RESTRICT att, \ const AttentionActivations& activations, float* HWY_RESTRICT att, \
float* HWY_RESTRICT att_out, size_t worker); \ float* HWY_RESTRICT att_out, hwy::Profiler& p, size_t worker); \
\ \
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \ void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
const LayerWeightsPtrs& layer, \ const LayerWeightsPtrs& layer, \
AttentionActivations& activations, \ AttentionActivations& activations, \
QBatch& qbatch, NestedPools& pools); \ QBatch& qbatch, ThreadingContext& ctx); \
\ \
void GemmaAttention(size_t num_tokens, const size_t layer_idx, \ void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
const LayerWeightsPtrs& layer, \ const LayerWeightsPtrs& layer, \

View File

@ -45,9 +45,10 @@ namespace HWY_NAMESPACE {
template <typename T> template <typename T>
void Activation(ActivationType activation, T* HWY_RESTRICT c1, void Activation(ActivationType activation, T* HWY_RESTRICT c1,
const T* HWY_RESTRICT c2, const size_t count, const T* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p,
const size_t worker) { const size_t worker) {
PROFILER_ZONE2(worker, "Gen.Activation"); static const auto zone = p.AddZone("Gen.Activation");
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<T>; using DF = hn::ScalableTag<T>;
using VF = hn::Vec<DF>; using VF = hn::Vec<DF>;
@ -64,28 +65,30 @@ void Activation(ActivationType activation, T* HWY_RESTRICT c1,
// No C2 multiplier. // No C2 multiplier.
template <class Mat> template <class Mat>
void ActivationBatched(ActivationType activation, Mat& c1, NestedPools& pools) { void ActivationBatched(ActivationType activation, Mat& c1,
ThreadingContext& ctx) {
using T = typename Mat::T; using T = typename Mat::T;
SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) { SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
// Cast to correct type so type deduction works. // Cast to correct type so type deduction works.
Activation(activation, c1.Row(task), static_cast<const T*>(nullptr), Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
c1.Cols(), worker); c1.Cols(), ctx.profiler, worker);
}); });
} }
template <class Mat> template <class Mat>
HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1, HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1,
const Mat* c2, NestedPools& pools) { const Mat* c2, ThreadingContext& ctx) {
using T = typename Mat::T; using T = typename Mat::T;
HWY_DASSERT(c1.SameShape(*c2)); HWY_DASSERT(c1.SameShape(*c2));
if (c2 && c2->HasPtr()) { if (c2 && c2->HasPtr()) {
SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) { SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(), worker); Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
ctx.profiler, worker);
}); });
} else { // No multiplier } else { // No multiplier
SmallParallelFor(c1.Rows(), pools, [&](uint64_t task, size_t worker) { SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
Activation(activation, c1.Row(task), static_cast<const T*>(nullptr), Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
c1.Cols(), worker); c1.Cols(), ctx.profiler, worker);
}); });
} }
} }
@ -110,7 +113,9 @@ void PostNorm(PostNormType post_norm, const MatPtr& weights,
static inline void FFWNoVit(const LayerWeightsPtrs& layer, static inline void FFWNoVit(const LayerWeightsPtrs& layer,
Activations& activations, MatMulEnv& env) { Activations& activations, MatMulEnv& env) {
PROFILER_ZONE("Gen.FFW"); static const auto zone =
env.ctx.profiler.AddZone("Gen.FFW", ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
const LayerConfig& layer_config = layer.layer_config; const LayerConfig& layer_config = layer.layer_config;
const size_t ffh_hidden_dim = layer_config.ff_hidden_dim; const size_t ffh_hidden_dim = layer_config.ff_hidden_dim;
@ -129,7 +134,7 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
// Activation (Gelu) and maybe multiply by gate. Store activations in act. // Activation (Gelu) and maybe multiply by gate. Store activations in act.
ActivationBatched(layer_config.activation, activations.C1, &activations.C2, ActivationBatched(layer_config.activation, activations.C1, &activations.C2,
env.ctx.pools); env.ctx);
// Hidden layer -> output layer. // Hidden layer -> output layer.
CallMatMul(activations.C1, layer.linear_w, output_bias, env, CallMatMul(activations.C1, layer.linear_w, output_bias, env,

View File

@ -55,6 +55,7 @@
#include "io/io.h" // Path #include "io/io.h" // Path
#include "ops/matmul.h" #include "ops/matmul.h"
#include "paligemma/image.h" #include "paligemma/image.h"
#include "util/basics.h" // PROFILER_ZONE3
#include "util/threading_context.h" #include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span #include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h" #include "hwy/base.h"
@ -138,7 +139,8 @@ static float EmbeddingScaling(size_t model_dim) {
static HWY_NOINLINE size_t static HWY_NOINLINE size_t
EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt, EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
const ModelConfig& model_config, const WeightsPtrs& weights, const ModelConfig& model_config, const WeightsPtrs& weights,
MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr, MatStorageT<float>& x, ThreadingContext& ctx,
const ImageTokens* image_tokens = nullptr,
size_t image_token_position = 0) { size_t image_token_position = 0) {
// Image tokens just need to be copied. // Image tokens just need to be copied.
if (model_config.wrapping == PromptWrapping::GEMMA_VLM && if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
@ -174,7 +176,8 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi), DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
model_dim); model_dim);
MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim, worker); MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim,
ctx.profiler, worker);
}); });
if (model_config.absolute_pe) { if (model_config.absolute_pe) {
@ -249,7 +252,7 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
const int token = qbatch_1.Prompt(0)[pos_in_prompt]; const int token = qbatch_1.Prompt(0)[pos_in_prompt];
image_token_position = EmbedMMToken( image_token_position = EmbedMMToken(
token, ti, pos, pos_in_prompt, config, weights, activations.x, token, ti, pos, pos_in_prompt, config, weights, activations.x,
runtime_config.image_tokens, image_token_position); env.ctx, runtime_config.image_tokens, image_token_position);
} }
// Transformer with one batch of tokens from a single query. // Transformer with one batch of tokens from a single query.
@ -306,7 +309,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
// TODO: parallelize? // TODO: parallelize?
for (size_t qi = 0; qi < qbatch.Size(); ++qi) { for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
EmbedMMToken(qbatch.PrevToken(qi), qi, qbatch.Pos(qi), EmbedMMToken(qbatch.PrevToken(qi), qi, qbatch.Pos(qi),
/*pos_in_prompt=*/0, config, weights, activations.x); /*pos_in_prompt=*/0, config, weights, activations.x, env.ctx);
} }
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) { for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
@ -419,7 +422,8 @@ static void DecodeStepT(const ModelConfig& config,
const size_t worker = 0; // TODO: parallelize const size_t worker = 0; // TODO: parallelize
non_eos.Foreach([&](size_t qi) { non_eos.Foreach([&](size_t qi) {
float* HWY_RESTRICT logits = activations.logits.Row(qi); float* HWY_RESTRICT logits = activations.logits.Row(qi);
MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size, worker); MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size,
env.ctx.profiler, worker);
const TokenAndProb tp = sample_token(logits, config.vocab_size); const TokenAndProb tp = sample_token(logits, config.vocab_size);
timing_info.NotifyGenerated(); timing_info.NotifyGenerated();
@ -429,27 +433,28 @@ static void DecodeStepT(const ModelConfig& config,
} }
static HWY_INLINE SampleFunc static HWY_INLINE SampleFunc
ChooseSampleFunc(const RuntimeConfig& runtime_config) { ChooseSampleFunc(const RuntimeConfig& runtime_config, ThreadingContext& ctx) {
// If user provided a sample_func, use it. // If user provided a sample_func, use it.
if (runtime_config.sample_func) return runtime_config.sample_func; if (runtime_config.sample_func) return runtime_config.sample_func;
static const auto zone = ctx.profiler.AddZone("Gen.Sample Top1");
const size_t worker = 0; // TODO: parallelize const size_t worker = 0; // TODO: parallelize
// Fast path for top-1 with no accept_token. // Fast path for top-1 with no accept_token.
if (runtime_config.top_k == 1 && !runtime_config.accept_token) { if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb { return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
PROFILER_ZONE2(worker, "Gen.Sample Top1"); PROFILER_ZONE3(ctx.profiler, worker, zone);
return Top1OfSoftmax(logits, vocab_size); return Top1OfSoftmax(logits, vocab_size);
}; };
} }
// General case: Softmax with top-k sampling. // General case: Softmax with top-k sampling.
return [&runtime_config](float* logits, return [&](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
size_t vocab_size) HWY_ATTR -> TokenAndProb {
PROFILER_ZONE("Gen.Sample general"); PROFILER_ZONE("Gen.Sample general");
return FusedSoftmaxAndSampleTopK( return FusedSoftmaxAndSampleTopK(
logits, runtime_config.top_k, vocab_size, *runtime_config.gen, logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
runtime_config.temperature, runtime_config.accept_token, worker); runtime_config.temperature, runtime_config.accept_token, ctx.profiler,
worker);
}; };
} }
@ -524,7 +529,7 @@ static void GenerateT(const ModelConfig& config,
max_gen_steps = seq_len - max_prompt_size; max_gen_steps = seq_len - max_prompt_size;
} }
const SampleFunc sample_token = ChooseSampleFunc(runtime_config); const SampleFunc sample_token = ChooseSampleFunc(runtime_config, env.ctx);
{ {
timing_info.generate_start = hwy::platform::Now(); timing_info.generate_start = hwy::platform::Now();

View File

@ -95,7 +95,7 @@ class VitAttention {
float* HWY_RESTRICT q = float* HWY_RESTRICT q =
activations_.attention.q.Row(token) + head * 3 * qkv_dim; activations_.attention.q.Row(token) + head * 3 * qkv_dim;
// TODO: shift to MatMul with A.scale once MatMul is confirmed working // TODO: shift to MatMul with A.scale once MatMul is confirmed working
MulByConst(query_scale, q, qkv_dim, worker); MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker);
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float)); hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
}); });
@ -111,7 +111,7 @@ class VitAttention {
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR { pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
float* HWY_RESTRICT c = C.Row(task); float* HWY_RESTRICT c = C.Row(task);
Softmax(c, C.Cols(), worker); Softmax(c, C.Cols(), env_.ctx.profiler, worker);
}); });
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR { pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
@ -121,7 +121,8 @@ class VitAttention {
for (size_t i = 0; i < seq_len; ++i) { for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v = activations_.attention.q.Row(i) + float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
head * 3 * qkv_dim + 2 * qkv_dim; head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim, worker); MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim,
env_.ctx.profiler, worker);
} }
}); });
} }
@ -144,7 +145,7 @@ class VitAttention {
// Compute Q.K scores, which are "logits" stored in head_att. // Compute Q.K scores, which are "logits" stored in head_att.
float* HWY_RESTRICT q = float* HWY_RESTRICT q =
activations_.attention.q.Row(token) + head * 3 * qkv_dim; activations_.attention.q.Row(token) + head * 3 * qkv_dim;
MulByConst(query_scale, q, qkv_dim, worker); MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker);
float* HWY_RESTRICT head_att = float* HWY_RESTRICT head_att =
activations_.attention.att.Row(token) + head * seq_len; activations_.attention.att.Row(token) + head * seq_len;
for (size_t i = 0; i < seq_len; ++i) { for (size_t i = 0; i < seq_len; ++i) {
@ -153,7 +154,7 @@ class VitAttention {
head_att[i] = Dot(q, k, qkv_dim); // score = q.k head_att[i] = Dot(q, k, qkv_dim); // score = q.k
} }
// SoftMax yields "probabilities" in head_att. // SoftMax yields "probabilities" in head_att.
Softmax(head_att, seq_len, worker); Softmax(head_att, seq_len, env_.ctx.profiler, worker);
// Compute weighted sum of v into att_out. // Compute weighted sum of v into att_out.
float* HWY_RESTRICT att_out = float* HWY_RESTRICT att_out =
activations_.attention.att_out.Row(token) + head * qkv_dim; activations_.attention.att_out.Row(token) + head * qkv_dim;
@ -161,7 +162,8 @@ class VitAttention {
for (size_t i = 0; i < seq_len; ++i) { for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v = activations_.attention.q.Row(i) + float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
head * 3 * qkv_dim + 2 * qkv_dim; head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim, worker); MulByConstAndAdd(head_att[i], v, att_out, qkv_dim,
env_.ctx.profiler, worker);
} }
}); });
} }
@ -224,7 +226,7 @@ void FFWVit(const LayerWeightsPtrs& layer, Activations& activations,
activations.C1); activations.C1);
// Activation (Gelu), store in C1. // Activation (Gelu), store in C1.
ActivationBatched(layer_config.activation, activations.C1, env.ctx.pools); ActivationBatched(layer_config.activation, activations.C1, env.ctx);
// Hidden layer -> output layer. // Hidden layer -> output layer.
CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias, env, CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias, env,
@ -334,7 +336,7 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
// Apply soft embedding norm before input projection. // Apply soft embedding norm before input projection.
CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) { CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), 0, activations.x.Row(0), RMSNormInplace(weights_t->PackedScale1(), 0, activations.x.Row(0),
vit_model_dim, /*worker=*/0); vit_model_dim, env.ctx.profiler, hwy::Profiler::Thread());
}); });
} }

View File

@ -381,9 +381,11 @@ static void DecompressToBF16(MatPtr& mat,
} }
static void ReadAllToBF16(const std::vector<TensorToRead>& tensors, static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
const BlobReader& reader, hwy::ThreadPool& pool) { const BlobReader& reader, ThreadingContext& ctx) {
pool.Run(0, tensors.size(), [&](uint64_t task, size_t thread) { static const auto zone =
PROFILER_ZONE2(thread, "Startup.Weights.ReadBF16"); ctx.profiler.AddZone("Startup.Weights.ReadAllToBF16");
ctx.pools.Pool().Run(0, tensors.size(), [&](uint64_t task, size_t thread) {
PROFILER_ZONE3(ctx.profiler, thread, zone);
const TensorToRead& tensor = tensors[task]; const TensorToRead& tensor = tensors[task];
MatPtr& mat = *tensor.mat; MatPtr& mat = *tensor.mat;
@ -460,10 +462,11 @@ static std::vector<IOBatch> MakeBatches(
// want to use the OS cache between consecutive runs. // want to use the OS cache between consecutive runs.
static void ReadBatches(const BlobReader& reader, static void ReadBatches(const BlobReader& reader,
const std::vector<IOBatch>& batches, const std::vector<IOBatch>& batches,
hwy::ThreadPool& pool) { ThreadingContext& ctx) {
static const auto zone = ctx.profiler.AddZone("Startup.Weights.ReadBatches");
// >5x speedup from parallel reads when cached. // >5x speedup from parallel reads when cached.
pool.Run(0, batches.size(), [&](uint64_t i, size_t thread) { ctx.pools.Pool().Run(0, batches.size(), [&](uint64_t i, size_t thread) {
PROFILER_ZONE2(thread, "Startup.Weights.Read"); PROFILER_ZONE3(ctx.profiler, thread, zone);
const IOBatch& batch = batches[i]; const IOBatch& batch = batches[i];
const std::string& key = reader.Keys()[batch.KeyIdx()]; const std::string& key = reader.Keys()[batch.KeyIdx()];
const uint64_t bytes_read = batch.Read(reader.file()); const uint64_t bytes_read = batch.Read(reader.file());
@ -500,16 +503,14 @@ static MapPtr MapOrReadAll(std::vector<TensorToRead>& tensors,
AllocateAndBindAll(tensors, *mode, mat_owners, ctx); AllocateAndBindAll(tensors, *mode, mat_owners, ctx);
} }
hwy::ThreadPool& pool = ctx.pools.Pool();
if (*mode == WeightsPtrs::Mode::kReadBF16) { if (*mode == WeightsPtrs::Mode::kReadBF16) {
ReadAllToBF16(tensors, reader, pool); ReadAllToBF16(tensors, reader, ctx);
return MapPtr(); return MapPtr();
} }
const std::vector<IOBatch> batches = const std::vector<IOBatch> batches =
MakeBatches(tensors, reader.file_bytes()); MakeBatches(tensors, reader.file_bytes());
ReadBatches(reader, batches, pool); ReadBatches(reader, batches, ctx);
return MapPtr(); return MapPtr();
} }
@ -519,7 +520,7 @@ WeightsPtrs::Mode WeightsPtrs::ReadFromBlobs(const ModelStore& model,
const InferenceArgs& inference, const InferenceArgs& inference,
std::vector<MatOwner>& mat_owners, std::vector<MatOwner>& mat_owners,
ThreadingContext& ctx) { ThreadingContext& ctx) {
PROFILER_ZONE("Startup.ReadFromBlobs"); PROFILER_ZONE("Startup.Weights.ReadFromBlobs");
// List of tensors to read/map, and where from. // List of tensors to read/map, and where from.
std::vector<TensorToRead> tensors; std::vector<TensorToRead> tensors;

View File

@ -720,32 +720,33 @@ struct MMArgs {
// Wrapper over hwy::Zone that is only enabled when autotuning finished. // Wrapper over hwy::Zone that is only enabled when autotuning finished.
#if PROFILER_ENABLED #if PROFILER_ENABLED
class MMZone { class MMZone {
using Zone = hwy::Zone; using Zone = hwy::profiler::Zone;
static_assert(alignof(Zone) <= 8 && sizeof(Zone) <= 8); static_assert(alignof(Zone) <= 8 && sizeof(Zone) <= 16);
public: public:
~MMZone() { ~MMZone() {
if (used_) { if (data_ != 0) {
Zone* zone = reinterpret_cast<Zone*>(&data_); Zone* zone = reinterpret_cast<Zone*>(&data_);
zone->~Zone(); zone->~Zone();
} }
} }
// Starts the profiler zone only if `args.per_key->WantProfile()`. // Starts the profiler zone only if `args.per_key->WantProfile()`.
void MaybeEnter(size_t thread_id, uint32_t zone_id, const MMArgs& args) { void MaybeEnter(size_t thread, hwy::profiler::ZoneHandle zone,
const MMArgs& args) {
if (args.per_key->WantProfile()) { if (args.per_key->WantProfile()) {
new (&data_) Zone(thread_id, zone_id); new (&data_) Zone(args.env->ctx.profiler, thread, zone);
used_ = true; HWY_DASSERT(data_ != 0);
} }
} }
private: private:
uint64_t data_ = 0; uint64_t data_ = 0;
bool used_ = false; uint64_t data2_ = 0;
}; };
#else #else
struct MMZone { struct MMZone {
void MaybeEnter(size_t, uint32_t, const MMArgs&) {} void MaybeEnter(size_t, hwy::profiler::ZoneHandle, const MMArgs&) {}
}; };
#endif // PROFILER_ENABLED #endif // PROFILER_ENABLED

View File

@ -125,9 +125,9 @@ HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
return hn::Mul(v, cdf); return hn::Mul(v, cdf);
} }
// Activation already has a profiler zone.
static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x, static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
size_t size) { size_t size) {
PROFILER_ZONE("ops.Gelu");
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>; using D = hn::ScalableTag<float>;
hn::Transform(D(), x, size, hn::Transform(D(), x, size,
@ -191,9 +191,10 @@ namespace detail {
// Shared by RMSNorm and RMSNormInplace. // Shared by RMSNorm and RMSNormInplace.
template <typename VT> template <typename VT>
float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
const HWY_MAYBE_UNUSED size_t worker) { const size_t worker) {
PROFILER_ZONE2(worker, "ops.RMSNormMul"); static const auto zone = p.AddZone("Ops.RMSNormMul");
PROFILER_ZONE3(p, worker, zone);
const hn::ScalableTag<float> d; const hn::ScalableTag<float> d;
const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault()); const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());
@ -205,18 +206,20 @@ float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size,
// `x_ofs` is the offset within `x`, required for NuqStream. // `x_ofs` is the offset within `x`, required for NuqStream.
template <typename XT, typename WT, typename OT> template <typename XT, typename WT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
const XT* HWY_RESTRICT x, const WT* HWY_RESTRICT weight, size_t w_ofs, const WT* HWY_RESTRICT weight,
OT* HWY_RESTRICT out, const size_t size, size_t w_ofs, OT* HWY_RESTRICT out,
const size_t HWY_MAYBE_UNUSED worker) { const size_t size, hwy::Profiler& p,
PROFILER_ZONE2(worker, "ops.RMSNorm"); const size_t worker) {
static const auto zone = p.AddZone("Ops.RMSNorm");
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>; using VF = hn::Vec<decltype(df)>;
const size_t NF = hn::Lanes(df); const size_t NF = hn::Lanes(df);
const VF mul = hn::Set(df, detail::RMSNormMul(x, size, worker)); const VF mul = hn::Set(df, detail::RMSNormMul(x, size, p, worker));
const auto packed_x = MakeSpan(x, size); const auto packed_x = MakeSpan(x, size);
const auto packed_w = MakeSpan(weight, w_ofs + size); const auto packed_w = MakeSpan(weight, w_ofs + size);
@ -240,15 +243,16 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
template <typename WT, typename XT> template <typename WT, typename XT>
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const WT* HWY_RESTRICT weight, size_t w_ofs, XT* HWY_RESTRICT inout, const WT* HWY_RESTRICT weight, size_t w_ofs, XT* HWY_RESTRICT inout,
const size_t size, const HWY_MAYBE_UNUSED size_t worker) { const size_t size, hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE2(worker, "ops.RMSNormInplace"); static const auto zone = p.AddZone("Ops.RMSNormInplace");
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>; using VF = hn::Vec<decltype(df)>;
const size_t NF = hn::Lanes(df); const size_t NF = hn::Lanes(df);
const VF mul = hn::Set(df, detail::RMSNormMul(inout, size, worker)); const VF mul = hn::Set(df, detail::RMSNormMul(inout, size, p, worker));
const auto packed_w = MakeSpan(weight, w_ofs + size); const auto packed_w = MakeSpan(weight, w_ofs + size);
const auto packed_x = MakeSpan(inout, size); const auto packed_x = MakeSpan(inout, size);
@ -407,9 +411,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
// This overload is called if `post_qk == PostQKType::HalfRope`. // This overload is called if `post_qk == PostQKType::HalfRope`.
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope( static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
float* HWY_RESTRICT x, const size_t dim_qkv, float* HWY_RESTRICT x, const size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, const int pos, const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p,
const size_t HWY_MAYBE_UNUSED worker = 0) { const size_t worker) {
PROFILER_ZONE2(worker, "ops.Rope"); static const auto zone = p.AddZone("Ops.Rope");
PROFILER_ZONE3(p, worker, zone);
HWY_DASSERT(dim_qkv % 2 == 0); HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2; const size_t half_dim_qkv = dim_qkv / 2;
@ -466,9 +471,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations. // `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy( static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
const float mul, float* HWY_RESTRICT x, const size_t dim_qkv, const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, const int pos, const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p,
const size_t HWY_MAYBE_UNUSED worker = 0) { const size_t worker) {
PROFILER_ZONE2(worker, "ops.RopeAndMulBy"); static const auto zone = p.AddZone("Ops.RopeAndMulBy");
PROFILER_ZONE3(p, worker, zone);
HWY_DASSERT(dim_qkv % 2 == 0); HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2; const size_t half_dim_qkv = dim_qkv / 2;
@ -525,10 +531,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
} }
template <typename XT> template <typename XT>
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom( static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
const XT* HWY_RESTRICT x, float* HWY_RESTRICT out, const size_t size, float* HWY_RESTRICT out,
const HWY_MAYBE_UNUSED size_t worker) { const size_t size,
PROFILER_ZONE2(worker, "ops.AddFrom"); hwy::Profiler& p,
const size_t worker) {
static const auto zone = p.AddZone("Ops.AddFrom");
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
@ -576,11 +585,10 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
HWY_DASSERT(activations.SameShape(out)); HWY_DASSERT(activations.SameShape(out));
CallUpcasted(&weights, [&](const auto* weights_t) { CallUpcasted(&weights, [&](const auto* weights_t) {
SmallParallelFor(activations.Rows(), ctx.pools, SmallParallelFor(
[&](uint64_t token_idx, size_t worker) { activations.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
RMSNorm(activations.Row(token_idx), RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), 0,
weights_t->PackedScale1(), 0, out.Row(token_idx), out.Row(token_idx), activations.Cols(), ctx.profiler, worker);
activations.Cols(), worker);
}); });
}); });
} }
@ -592,11 +600,10 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
HWY_DASSERT(weights.Cols() == inout.Cols()); HWY_DASSERT(weights.Cols() == inout.Cols());
CallUpcasted(&weights, [&](const auto* weights_t) { CallUpcasted(&weights, [&](const auto* weights_t) {
SmallParallelFor(inout.Rows(), ctx.pools, SmallParallelFor(
[&](uint64_t token_idx, size_t worker) { inout.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
RMSNormInplace(weights_t->PackedScale1(), 0, RMSNormInplace(weights_t->PackedScale1(), 0, inout.Row(token_idx),
inout.Row(token_idx), inout.Cols(), inout.Cols(), ctx.profiler, worker);
worker);
}); });
}); });
} }
@ -622,17 +629,20 @@ template <typename XT>
static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out, static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
ThreadingContext& ctx) { ThreadingContext& ctx) {
HWY_DASSERT(out.SameShape(x)); HWY_DASSERT(out.SameShape(x));
SmallParallelFor( SmallParallelFor(out.Rows(), ctx.pools,
out.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) { [&](uint64_t token_idx, size_t worker) {
AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), worker); AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(),
ctx.profiler, worker);
}); });
} }
template <typename XT> template <typename XT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst( HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
const float c, XT* HWY_RESTRICT x, const size_t size, const size_t size,
const HWY_MAYBE_UNUSED size_t worker) { hwy::Profiler& p,
PROFILER_ZONE2(worker, "ops.MulByConst"); const size_t worker) {
static const auto zone = p.AddZone("Ops.MulByConst");
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
const size_t NF = hn::Lanes(df); const size_t NF = hn::Lanes(df);
@ -672,8 +682,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(
template <typename XT, typename OT> template <typename XT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo( HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
const size_t size, const HWY_MAYBE_UNUSED size_t worker) { const size_t size, hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE2(worker, "ops.MulByConstTo"); static const auto zone = p.AddZone("Ops.MulByConstTo");
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
const size_t NF = hn::Lanes(df); const size_t NF = hn::Lanes(df);
@ -714,8 +725,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
template <typename XT, typename OT> template <typename XT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
const size_t size, const HWY_MAYBE_UNUSED size_t worker) { const size_t size, hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE2(worker, "ops.MulByConstAndAdd"); static const auto zone = p.AddZone("Ops.MulByConstAndAdd");
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
const size_t NF = hn::Lanes(df); const size_t NF = hn::Lanes(df);
@ -760,9 +772,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
// See below for a specialized version for top-1 sampling. // See below for a specialized version for top-1 sampling.
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
const size_t worker, hwy::Profiler& p, const size_t worker,
float temperature = 1.0f) { float temperature = 1.0f) {
PROFILER_ZONE2(worker, "ops.Softmax"); static const auto zone = p.AddZone("Ops.Softmax");
PROFILER_ZONE3(p, worker, zone);
HWY_DASSERT(size != 0); HWY_DASSERT(size != 0);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
@ -803,7 +816,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
const float sum_exp = Sum(d, x, size); const float sum_exp = Sum(d, x, size);
// Double-precision reciprocal does not appear to affect the results. // Double-precision reciprocal does not appear to affect the results.
const float mul = 1.0f / sum_exp; const float mul = 1.0f / sum_exp;
MulByConst(mul, x, size, worker); MulByConst(mul, x, size, p, worker);
} }
// Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max / // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
@ -893,9 +906,10 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
} }
static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
const size_t size, const size_t size, hwy::Profiler& p,
const HWY_MAYBE_UNUSED size_t worker) { const size_t worker) {
PROFILER_ZONE2(worker, "ops.LogitsSoftCap"); static const auto zone = p.AddZone("Ops.LogitsSoftCap");
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>; using D = hn::ScalableTag<float>;
@ -911,10 +925,10 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
// Calls LogitsSoftCap if cap != 0.0f. // Calls LogitsSoftCap if cap != 0.0f.
static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap( static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
const float cap, float* HWY_RESTRICT x, const size_t size, const float cap, float* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
const size_t worker) { const size_t worker) {
if (cap != 0.0f) { if (cap != 0.0f) {
LogitsSoftCap(cap, x, size, worker); LogitsSoftCap(cap, x, size, p, worker);
} }
} }
@ -998,7 +1012,7 @@ template <typename TAcceptToken>
HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK( HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
const float* HWY_RESTRICT logits, size_t k, size_t vocab_size, const float* HWY_RESTRICT logits, size_t k, size_t vocab_size,
std::mt19937& gen, float temperature, TAcceptToken& accept_token, std::mt19937& gen, float temperature, TAcceptToken& accept_token,
size_t worker) { hwy::Profiler& p, size_t worker) {
// Softmax and sample top-K is equivalent to taking the top-K logits and // Softmax and sample top-K is equivalent to taking the top-K logits and
// sampling from the softmax of the top-K logits. The latter is faster as it // sampling from the softmax of the top-K logits. The latter is faster as it
// avoids computing the softmax of all logits. // avoids computing the softmax of all logits.
@ -1012,7 +1026,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
} }
size_t mask = token_logits.size(); size_t mask = token_logits.size();
Softmax(topk_logits.data(), mask, worker, temperature); Softmax(topk_logits.data(), mask, p, worker, temperature);
auto distribution = std::discrete_distribution<int>( auto distribution = std::discrete_distribution<int>(
std::begin(topk_logits), std::begin(topk_logits) + mask); std::begin(topk_logits), std::begin(topk_logits) + mask);
int topk_sampled_index = distribution(gen); int topk_sampled_index = distribution(gen);

View File

@ -18,8 +18,6 @@
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS
#include "ops/ops.h"
#include <stddef.h> #include <stddef.h>
#include <stdio.h> #include <stdio.h>
@ -33,11 +31,13 @@
#include "gemma/activations.h" // ChooseQueryScale #include "gemma/activations.h" // ChooseQueryScale
#include "gemma/configs.h" #include "gemma/configs.h"
#include "ops/ops.h"
#include "util/allocator.h" #include "util/allocator.h"
#include "util/basics.h" // BF16 #include "util/basics.h" // BF16
#include "util/mat.h" // MatStorageT #include "util/mat.h" // MatStorageT
#include "util/test_util.h" #include "util/test_util.h"
#include "util/threading_context.h" #include "util/threading_context.h"
#include "hwy/profiler.h"
#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/hwy_gtest.h"
// clang-format off // clang-format off
@ -166,7 +166,7 @@ struct TestAddFrom {
} }
SimpleAddFrom(o, e, count); SimpleAddFrom(o, e, count);
AddFrom(o, x, count, /*worker=*/0); AddFrom(o, x, count, hwy::Profiler::Get(), /*worker=*/0);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__); __LINE__);
@ -199,7 +199,7 @@ struct TestMulByConstAndAdd {
T constant = Random<T>(rng); T constant = Random<T>(rng);
SimpleMulByConstAndAdd(constant, o, e, count); SimpleMulByConstAndAdd(constant, o, e, count);
MulByConstAndAdd(constant, o, x, count, /*worker=*/0); MulByConstAndAdd(constant, o, x, count, hwy::Profiler::Get(), /*worker=*/0);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__); __LINE__);
@ -229,7 +229,7 @@ struct TestMulByConst {
T constant = Random<T>(rng); T constant = Random<T>(rng);
SimpleMulByConst(constant, e, count); SimpleMulByConst(constant, e, count);
MulByConst(constant, x, count, /*worker=*/0); MulByConst(constant, x, count, hwy::Profiler::Get(), /*worker=*/0);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__); __LINE__);
@ -259,7 +259,7 @@ struct TestSoftmax {
} }
SimpleSoftmax(e, count); SimpleSoftmax(e, count);
Softmax(x, count, /*worker=*/0); Softmax(x, count, hwy::Profiler::Get(), /*worker=*/0);
T sum = 0.0f; T sum = 0.0f;
for (size_t i = 0; i < count; ++i) { for (size_t i = 0; i < count; ++i) {
@ -349,6 +349,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
void TestRopeAndMulBy() { void TestRopeAndMulBy() {
ThreadingArgs threading_args; ThreadingArgs threading_args;
ThreadingContext ctx(threading_args); ThreadingContext ctx(threading_args);
hwy::Profiler& p = ctx.profiler;
const size_t worker = 0;
const ModelConfig config(Model::GEMMA2_9B, Type::kSFP, const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
ChooseWrapping(Model::GEMMA2_9B)); ChooseWrapping(Model::GEMMA2_9B));
const size_t dim_qkv = config.layer_configs[0].qkv_dim; const size_t dim_qkv = config.layer_configs[0].qkv_dim;
@ -381,7 +384,8 @@ void TestRopeAndMulBy() {
CopyMat(x, qactual); CopyMat(x, qactual);
ScalarRopeAndMulBy(qmul, qexpected.Row(0), dim_qkv, inv_timescale.Row(0), ScalarRopeAndMulBy(qmul, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
pos); pos);
RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos); RopeAndMulBy(qmul, qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p,
worker);
for (size_t i = 0; i < dim_qkv; ++i) { for (size_t i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i; EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
} }
@ -391,7 +395,7 @@ void TestRopeAndMulBy() {
CopyMat(x, qactual); CopyMat(x, qactual);
ScalarRopeAndMulBy(1.0f, qexpected.Row(0), dim_qkv, inv_timescale.Row(0), ScalarRopeAndMulBy(1.0f, qexpected.Row(0), dim_qkv, inv_timescale.Row(0),
pos); pos);
Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos); Rope(qactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, worker);
for (size_t i = 0; i < dim_qkv; ++i) { for (size_t i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i; EXPECT_NEAR(qexpected.Row(0)[i], qactual.Row(0)[i], 1e-4) << " " << i;
} }
@ -402,9 +406,10 @@ void TestRopeAndMulBy() {
CopyMat(x, kactual2); CopyMat(x, kactual2);
ScalarRopeAndMulBy(kmul, kexpected.Row(0), dim_qkv, inv_timescale.Row(0), ScalarRopeAndMulBy(kmul, kexpected.Row(0), dim_qkv, inv_timescale.Row(0),
pos); pos);
RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos); RopeAndMulBy(kmul, kactual.Row(0), dim_qkv, inv_timescale.Row(0), pos, p,
worker);
static_assert(kmul == 1.0f, ""); static_assert(kmul == 1.0f, "");
Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos); Rope(kactual2.Row(0), dim_qkv, inv_timescale.Row(0), pos, p, worker);
for (size_t i = 0; i < dim_qkv; ++i) { for (size_t i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(kexpected.Row(0)[i], kactual.Row(0)[i], 1e-4) << " " << i; EXPECT_NEAR(kexpected.Row(0)[i], kactual.Row(0)[i], 1e-4) << " " << i;
@ -454,7 +459,7 @@ void TestRMSNorm(hwy::RandomState& rng) {
} }
ScalarRMSNorm(vec, weight, expected, kSize); ScalarRMSNorm(vec, weight, expected, kSize);
RMSNorm(vec, weight, 0, actual, kSize, /*worker=*/0); RMSNorm(vec, weight, 0, actual, kSize, hwy::Profiler::Get(), /*worker=*/0);
for (size_t i = 0; i < kSize; i++) { for (size_t i = 0; i < kSize; i++) {
const float e = hwy::ConvertScalarTo<float>(expected[i]); const float e = hwy::ConvertScalarTo<float>(expected[i]);
@ -580,11 +585,13 @@ void TestAllLayerNorm() {
} }
void TestSampleTopK() { void TestSampleTopK() {
hwy::Profiler& p = hwy::Profiler::Get();
const size_t worker = 0;
const size_t kSize = 52; const size_t kSize = 52;
std::vector<float> logits(kSize); std::vector<float> logits(kSize);
// Create a vector going from -100 to -100+51=-49 and take Softmax. // Create a vector going from -100 to -100+51=-49 and take Softmax.
std::iota(logits.begin(), logits.end(), -100.0f); std::iota(logits.begin(), logits.end(), -100.0f);
Softmax(logits.data(), kSize, /*worker=*/0); Softmax(logits.data(), kSize, p, worker);
std::mt19937 gen; std::mt19937 gen;
gen.seed(0x12345678); gen.seed(0x12345678);
float temperature = 1.0f; float temperature = 1.0f;
@ -600,7 +607,7 @@ void TestSampleTopK() {
EXPECT_EQ(sample, 50); // Last even index. EXPECT_EQ(sample, 50); // Last even index.
// Reset the logits to a positive, increasing sequence and take Softmax. // Reset the logits to a positive, increasing sequence and take Softmax.
std::iota(logits.begin(), logits.end(), 1.0f); std::iota(logits.begin(), logits.end(), 1.0f);
Softmax(logits.data(), kSize, /*worker=*/0); Softmax(logits.data(), kSize, p, worker);
// Sample from the top 3, expect one of the top 3 even indices. // Sample from the top 3, expect one of the top 3 even indices.
for (int i = 0; i < 100; ++i) { for (int i = 0; i < 100; ++i) {
sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature, sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature,

View File

@ -29,6 +29,7 @@
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/contrib/thread_pool/topology.h" #include "hwy/contrib/thread_pool/topology.h"
#include "hwy/profiler.h"
namespace gcpp { namespace gcpp {
@ -171,6 +172,8 @@ NestedPools::NestedPools(const BoundedTopology& topology,
HWY_ASSERT(max_clusters_per_package_ <= 64); HWY_ASSERT(max_clusters_per_package_ <= 64);
HWY_ASSERT(max_workers_per_cluster_ >= 1); HWY_ASSERT(max_workers_per_cluster_ >= 1);
HWY_ASSERT(max_workers_per_cluster_ <= 256); HWY_ASSERT(max_workers_per_cluster_ <= 256);
hwy::Profiler::Get().SetMaxThreads(MaxWorkers());
} }
// `max_or_zero` == 0 means no limit. // `max_or_zero` == 0 means no limit.

View File

@ -72,7 +72,8 @@ static void TunePool(hwy::ThreadPool& pool) {
} }
ThreadingContext::ThreadingContext(const ThreadingArgs& args) ThreadingContext::ThreadingContext(const ThreadingArgs& args)
: topology(BoundedSlice(args.skip_packages, args.max_packages), : profiler(hwy::Profiler::Get()),
topology(BoundedSlice(args.skip_packages, args.max_packages),
BoundedSlice(args.skip_clusters, args.max_clusters), BoundedSlice(args.skip_clusters, args.max_clusters),
BoundedSlice(args.skip_lps, args.max_lps)), BoundedSlice(args.skip_lps, args.max_lps)),
allocator(topology, args.bind != Tristate::kFalse), allocator(topology, args.bind != Tristate::kFalse),

View File

@ -90,6 +90,7 @@ struct ThreadingContext {
// Expected to be called early in the program, before threading starts. // Expected to be called early in the program, before threading starts.
explicit ThreadingContext(const ThreadingArgs& args); explicit ThreadingContext(const ThreadingArgs& args);
hwy::Profiler& profiler;
BoundedTopology topology; BoundedTopology topology;
Allocator allocator; Allocator allocator;
NestedPools pools; NestedPools pools;