Added a global (to gemma) zones list so that most call sites of PROFILER_ZONE3 can avoid the synchronization required for the static const initialization of the zone handle.

Improved flash_attention to enable profiling using the new zones.

PiperOrigin-RevId: 819235421
This commit is contained in:
Ray Smith 2025-10-14 08:30:23 -07:00 committed by Copybara-Service
parent 035273c184
commit fb6fa793f4
14 changed files with 247 additions and 77 deletions

View File

@ -111,6 +111,7 @@ cc_library(
":basics", ":basics",
":threading", ":threading",
":topology", ":topology",
":zones",
"@highway//:hwy", "@highway//:hwy",
"@highway//:hwy_test_util", "@highway//:hwy_test_util",
"@highway//:profiler", "@highway//:profiler",
@ -118,6 +119,15 @@ cc_library(
], ],
) )
# Global (to gemma) list of profiler zones, declared in util/zones.h and
# implemented in util/zones.cc. Its only dependency is the Highway profiler.
cc_library(
name = "zones",
srcs = ["util/zones.cc"],
hdrs = ["util/zones.h"],
deps = [
"@highway//:profiler",
],
)
cc_test( cc_test(
name = "flash_attention_test", name = "flash_attention_test",
srcs = ["gemma/flash_attention_test.cc"], srcs = ["gemma/flash_attention_test.cc"],
@ -263,6 +273,7 @@ cc_library(
":model_store", ":model_store",
":tensor_info", ":tensor_info",
":threading_context", ":threading_context",
":zones",
"//compression:compress", "//compression:compress",
"//io:blob_store", "//io:blob_store",
"@highway//:hwy", "@highway//:hwy",
@ -321,6 +332,7 @@ cc_library(
":matmul_env", ":matmul_env",
":threading", ":threading",
":threading_context", ":threading_context",
":zones",
"//compression:compress", "//compression:compress",
"@highway//:bit_set", "@highway//:bit_set",
"@highway//:hwy", "@highway//:hwy",
@ -352,6 +364,7 @@ cc_library(
":matmul", ":matmul",
":matmul_env", ":matmul_env",
":threading_context", ":threading_context",
":zones",
"//compression:compress", "//compression:compress",
"//compression:types", "//compression:types",
"@highway//:hwy", "@highway//:hwy",
@ -376,6 +389,7 @@ cc_library(
":matmul_env", # MMOptions ":matmul_env", # MMOptions
":matmul_static", ":matmul_static",
":threading_context", ":threading_context",
":zones",
"//compression:compress", "//compression:compress",
"@highway//:algo", "@highway//:algo",
"@highway//:bit_set", "@highway//:bit_set",
@ -431,6 +445,7 @@ cc_test(
":ops", ":ops",
":test_util", ":test_util",
":threading_context", ":threading_context",
":zones",
"@googletest//:gtest_main", # buildcleaner: keep "@googletest//:gtest_main", # buildcleaner: keep
"//compression:test_util", "//compression:test_util",
"//compression:types", "//compression:types",
@ -556,6 +571,7 @@ cc_library(
":threading", ":threading",
":threading_context", ":threading_context",
":weights", ":weights",
":zones",
"//compression:compress", "//compression:compress",
"//compression:types", "//compression:types",
"//io", "//io",

View File

@ -130,6 +130,8 @@ set(SOURCES
util/threading.h util/threading.h
util/topology.cc util/topology.cc
util/topology.h util/topology.h
util/zones.cc
util/zones.h
) )
# Add C API sources only when building DLL # Add C API sources only when building DLL

View File

@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS #include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "util/zones.h"
#ifndef HWY_DISABLED_TARGETS #ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS
@ -55,8 +56,7 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
const float* HWY_RESTRICT q, const float* HWY_RESTRICT q,
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att, const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
hwy::Profiler& p, const size_t worker) { hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Gen.Attention.QDotK"); PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenAttentionQDotK));
PROFILER_ZONE3(p, worker, zone);
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) { if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound. // Slightly faster: no wraparound.
for (size_t pos = start_pos; pos <= last_pos; ++pos) { for (size_t pos = start_pos; pos <= last_pos; ++pos) {
@ -175,7 +175,12 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer, const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch, AttentionActivations& activations, QBatch& qbatch,
ThreadingContext& ctx) { ThreadingContext& ctx) {
static const auto zone = ctx.profiler.AddZone("Gen.Attention.DotSoftmax.par"); static const auto root_zone =
ctx.profiler.AddZone("Gen.Attention.DotSoftmaxWeightedSumInclusive",
hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(ctx.profiler, 0, root_zone);
const auto zone =
GetProfilerZone(Zones::kGenAttentionDotSoftmaxWeightedSumPar);
const hwy::Divisor div_qbatch(qbatch.Size()); const hwy::Divisor div_qbatch(qbatch.Size());
const LayerConfig& layer_config = layer.layer_config; const LayerConfig& layer_config = layer.layer_config;

View File

@ -22,6 +22,7 @@
#include "compression/types.h" // GEMMA_DISABLED_TARGETS #include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "util/threading_context.h" #include "util/threading_context.h"
#include "util/zones.h"
#ifndef HWY_DISABLED_TARGETS #ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS
@ -60,7 +61,7 @@ static constexpr size_t kNFx8HTileSize = 8;
// possible consecutive elements have the same KV. // possible consecutive elements have the same KV.
static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t, static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
const size_t qbatch_size, ThreadingContext& ctx) { const size_t qbatch_size, ThreadingContext& ctx) {
static const auto zone = ctx.profiler.AddZone("Gen.Attention.TransposeQ"); const auto zone = GetProfilerZone(Zones::kFlashAttentionTransposeQ);
// Group floats by the number of floats in a cache line. // Group floats by the number of floats in a cache line.
const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float); const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float);
const size_t num_heads = q.Cols() / q_t.Rows(); const size_t num_heads = q.Cols() / q_t.Rows();
@ -95,8 +96,8 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
const LayerWeightsPtrs& layer, const LayerWeightsPtrs& layer,
const AttentionActivations& activations, const AttentionActivations& activations,
ThreadingContext& ctx) { ThreadingContext& ctx) {
static const auto zone = const auto zone =
ctx.profiler.AddZone("Gen.Attention.RMSNormAndPositionalEncoding"); GetProfilerZone(Zones::kFlashAttentionRmsNormAndPositionalEncoding);
const float query_scale = activations.query_scale; const float query_scale = activations.query_scale;
const hwy::Divisor div_qbatch(qbatch.Size()); const hwy::Divisor div_qbatch(qbatch.Size());
const auto func = [&](const size_t task, size_t worker) HWY_ATTR { const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
@ -158,8 +159,8 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
const AttentionActivations& activations, const AttentionActivations& activations,
float* HWY_RESTRICT att_out, hwy::Profiler& p, float* HWY_RESTRICT att_out, hwy::Profiler& p,
const size_t worker) { const size_t worker) {
static const auto zone = p.AddZone("Gen.Attention.SingleFlashAttention"); PROFILER_ZONE3(p, worker,
PROFILER_ZONE3(p, worker, zone); GetProfilerZone(Zones::kFlashAttentionSingleFlashAttention));
const size_t pos_mod = activations.div_seq_len.Remainder(start_pos); const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
float m = Dot(q, k.Row(pos_mod), k.Cols()); float m = Dot(q, k.Row(pos_mod), k.Cols());
if (float cap = activations.config.att_cap; cap > 0.0f) { if (float cap = activations.config.att_cap; cap > 0.0f) {
@ -276,8 +277,8 @@ void TileFlashAttention(
const LayerWeightsPtrs& layer, const AttentionActivations& activations, const LayerWeightsPtrs& layer, const AttentionActivations& activations,
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets, MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
hwy::Profiler& p, const size_t worker) { hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Gen.Attention.TileFlashAttention"); PROFILER_ZONE3(p, worker,
PROFILER_ZONE3(p, worker, zone); GetProfilerZone(Zones::kFlashAttentionTileFlashAttention));
constexpr int kHTileSize = kNFx8HTileSize; constexpr int kHTileSize = kNFx8HTileSize;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
const DF df; const DF df;
@ -430,8 +431,8 @@ void TileFlashAttention4(
const LayerWeightsPtrs& layer, const AttentionActivations& activations, const LayerWeightsPtrs& layer, const AttentionActivations& activations,
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets, MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
hwy::Profiler& p, const size_t worker) { hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Gen.Attention.TileFlashAttention4"); PROFILER_ZONE3(p, worker,
PROFILER_ZONE3(p, worker, zone); GetProfilerZone(Zones::kFlashAttentionTileFlashAttention4));
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
const DF df; const DF df;
using VF = hn::Vec<DF>; using VF = hn::Vec<DF>;
@ -524,6 +525,21 @@ static size_t RoundToSuitablePowerOf2(size_t n) {
return 32; return 32;
} }
// Chooses the vertical tile size for flash attention.
//
// Candidate sizes, in order of preference for efficiency, are kNF, 4 and 1
// (kNF is likely to be 4, 8 or 16). The result is the largest candidate that
// still leaves at least `target_parallelism` parallel tasks, and it never
// exceeds RoundToSuitablePowerOf2(num_head_groups * num_tokens).
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
                    size_t total_tasks, size_t target_parallelism) {
  const size_t max_equal_k =
      RoundToSuitablePowerOf2(num_head_groups * num_tokens);

  // Preferred case: a full kNF-sized tile fits and still yields enough tasks.
  if (kNF <= max_equal_k && total_tasks / kNF >= target_parallelism) {
    return kNF;
  }

  // Otherwise fall back to 4 if that keeps enough tasks, else 1, capped by
  // max_equal_k.
  size_t fallback_tile = 1;
  if (total_tasks / 4 >= target_parallelism) {
    fallback_tile = 4;
  }
  return std::min(fallback_tile, max_equal_k);
}
// The nominal aim of attention is to combine 3 inputs Q[L,D], K[L,D], V[L,D] // The nominal aim of attention is to combine 3 inputs Q[L,D], K[L,D], V[L,D]
// into a single output O[L,D]. // into a single output O[L,D].
// Conventional attention first computes A[L,L] = Q . KT // Conventional attention first computes A[L,L] = Q . KT
@ -582,7 +598,10 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
const size_t layer_idx, const LayerWeightsPtrs& layer, const size_t layer_idx, const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch, AttentionActivations& activations, QBatch& qbatch,
ThreadingContext& ctx) { ThreadingContext& ctx) {
static const auto zone = ctx.profiler.AddZone("Gen.Attention.FlashAttention"); static const auto root_zone = ctx.profiler.AddZone(
"FlashAttention.Inclusive", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(ctx.profiler, 0, root_zone);
const auto zone = GetProfilerZone(Zones::kFlashAttentionFlashAttention);
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx, RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx,
layer, activations, ctx); layer, activations, ctx);
const hwy::Divisor div_qbatch(qbatch.Size()); const hwy::Divisor div_qbatch(qbatch.Size());
@ -603,17 +622,8 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
const size_t kNF = hn::Lanes(df); const size_t kNF = hn::Lanes(df);
constexpr size_t kMaxNF = hn::MaxLanes(df); constexpr size_t kMaxNF = hn::MaxLanes(df);
HWY_DASSERT(kNF <= kMaxNF); HWY_DASSERT(kNF <= kMaxNF);
// The vertical tile size is determined by the ability to use tiling and the const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, num_tokens,
// target_parallelism. In practice the possible tile sizes in order of total_tasks, target_parallelism);
// preference for efficiency are kNF, 4, 1, where kNF is likely to be 4 8 or
// 16. The final tile size is chosen to be the largest possible that allows
// for target_parallelism parallel tasks.
const size_t kMaxEqualK = RoundToSuitablePowerOf2(kHeadGroups * num_tokens);
const size_t kMinTileSize = (total_tasks / 4 >= target_parallelism) ? 4 : 1;
const size_t kVTileSize =
(kNF <= kMaxEqualK && total_tasks / kNF >= target_parallelism)
? kNF
: std::min(kMinTileSize, kMaxEqualK);
// Only transpose Q if we are using tiling. // Only transpose Q if we are using tiling.
if (kVTileSize == kNF) { if (kVTileSize == kNF) {
size_t max_last = 0, min_start = std::numeric_limits<size_t>::max(); size_t max_last = 0, min_start = std::numeric_limits<size_t>::max();

View File

@ -42,6 +42,9 @@ namespace gcpp {
float* HWY_RESTRICT att_out, hwy::Profiler& p, \ float* HWY_RESTRICT att_out, hwy::Profiler& p, \
size_t worker); \ size_t worker); \
\ \
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
size_t total_tasks, size_t target_parallelism); \
\
void FlashAttention(size_t num_tokens, size_t target_parallelism, \ void FlashAttention(size_t num_tokens, size_t target_parallelism, \
size_t layer_idx, const LayerWeightsPtrs& layer, \ size_t layer_idx, const LayerWeightsPtrs& layer, \
AttentionActivations& activations, QBatch& qbatch, \ AttentionActivations& activations, QBatch& qbatch, \

View File

@ -101,7 +101,6 @@ void AssertClose(const MatPtrT<float>& a, const MatPtrT<float>& b) {
void TestFlashAttention(size_t target_parallelism) { void TestFlashAttention(size_t target_parallelism) {
ThreadingArgs threading_args; ThreadingArgs threading_args;
ThreadingContext ctx(threading_args); ThreadingContext ctx(threading_args);
// hwy::ThreadPool& pool = ctx.pools.Pool();
constexpr size_t kOuter = 1024; constexpr size_t kOuter = 1024;
constexpr size_t kInner = 256; constexpr size_t kInner = 256;
ModelConfig config(Model::GEMMA2_2B, Type::kF32, PromptWrapping::GEMMA_PT); ModelConfig config(Model::GEMMA2_2B, Type::kF32, PromptWrapping::GEMMA_PT);
@ -150,9 +149,19 @@ void TestFlashAttention(size_t target_parallelism) {
// Copy the output to saved_att to allow for comparison. // Copy the output to saved_att to allow for comparison.
auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator); auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator);
SetMat(1, attention.q); SetMat(1, attention.q);
using DF = hn::ScalableTag<float>;
const DF df;
const size_t kNF = hn::Lanes(df);
const size_t total_tasks =
tokens.size() * div_qbatch.GetDivisor() * layer_config.heads;
const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, tokens.size(),
total_tasks, target_parallelism);
printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n",
target_parallelism, kNF, kVTileSize);
FlashAttention(tokens.size(), target_parallelism, 0, layers, attention, FlashAttention(tokens.size(), target_parallelism, 0, layers, attention,
qbatch, ctx); qbatch, ctx);
AssertClose(attention.att_out, *saved_att); AssertClose(attention.att_out, *saved_att);
ctx.profiler.PrintResults();
} }
void TestAttention() { void TestAttention() {

View File

@ -24,6 +24,7 @@
#include "ops/matmul.h" #include "ops/matmul.h"
#include "util/mat.h" #include "util/mat.h"
#include "util/threading.h" #include "util/threading.h"
#include "util/zones.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
// Include guard (still compiled once per target) // Include guard (still compiled once per target)
@ -48,8 +49,7 @@ template <typename T1, typename T2>
void Activation(ActivationType activation, T1* HWY_RESTRICT c1, void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
const T2* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p, const T2* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p,
const size_t worker) { const size_t worker) {
static const auto zone = p.AddZone("Gen.Activation"); PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenActivation));
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>; using VF = hn::Vec<DF>;
@ -88,8 +88,7 @@ static inline void Activation(ActivationType activation, const RowPtrsBF C1,
const IndexRange range_r, const IndexRange range_r,
const IndexRange range_c, const StridedViewBF C2, const IndexRange range_c, const StridedViewBF C2,
hwy::Profiler& p, const size_t worker) { hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Gen.ActivationFused"); PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kGenActivationFused));
PROFILER_ZONE3(p, worker, zone);
const size_t cols = range_c.Num(); const size_t cols = range_c.Num();
HWY_DASSERT(C2.Cols() == cols); HWY_DASSERT(C2.Cols() == cols);

View File

@ -19,6 +19,7 @@
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "compression/types.h" // GEMMA_DISABLED_TARGETS #include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "util/zones.h"
#ifndef HWY_DISABLED_TARGETS #ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS
@ -466,14 +467,12 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config,
// If user provided a sample_func, use it. // If user provided a sample_func, use it.
if (runtime_config.sample_func) return runtime_config.sample_func; if (runtime_config.sample_func) return runtime_config.sample_func;
static const auto zone_top1 = ctx.profiler.AddZone("Gen.Sample Top1");
static const auto zone_topK = ctx.profiler.AddZone("Gen.Sample general");
// Fast path for top-1 with no accept_token. // Fast path for top-1 with no accept_token.
if (runtime_config.top_k == 1 && !runtime_config.accept_token) { if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
return [&](size_t /*qi*/, size_t /*pos*/, Logits logits, size_t worker) return [&](size_t /*qi*/, size_t /*pos*/, Logits logits, size_t worker)
HWY_ATTR -> TokenAndProb { HWY_ATTR -> TokenAndProb {
PROFILER_ZONE3(ctx.profiler, worker, zone_top1); PROFILER_ZONE3(ctx.profiler, worker,
GetProfilerZone(Zones::kGenSampleTop1));
return Top1OfSoftmax(logits); return Top1OfSoftmax(logits);
}; };
} }
@ -481,7 +480,8 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config,
// General case: Softmax with top-k sampling. // General case: Softmax with top-k sampling.
return [&](size_t qi, size_t pos, Logits logits, return [&](size_t qi, size_t pos, Logits logits,
size_t worker) HWY_ATTR -> TokenAndProb { size_t worker) HWY_ATTR -> TokenAndProb {
PROFILER_ZONE3(ctx.profiler, worker, zone_topK); PROFILER_ZONE3(ctx.profiler, worker,
GetProfilerZone(Zones::kGenSampleTopK));
// We want a different sequence for each batch element and position. // We want a different sequence for each batch element and position.
const uint64_t stream = (static_cast<uint64_t>(qi) << 32) | pos; const uint64_t stream = (static_cast<uint64_t>(qi) << 32) | pos;
RngStream gen(engine, stream); RngStream gen(engine, stream);

View File

@ -32,6 +32,7 @@
#include "io/blob_store.h" #include "io/blob_store.h"
#include "util/mat.h" #include "util/mat.h"
#include "util/threading_context.h" #include "util/threading_context.h"
#include "util/zones.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h" #include "hwy/highway.h"
@ -379,8 +380,7 @@ static void DecompressToBF16(MatPtr& mat,
static void ReadAllToBF16(const std::vector<TensorToRead>& tensors, static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
const BlobReader& reader, ThreadingContext& ctx) { const BlobReader& reader, ThreadingContext& ctx) {
static const auto zone = const auto zone = GetProfilerZone(Zones::kStartupWeightsReadAllToBF16);
ctx.profiler.AddZone("Startup.Weights.ReadAllToBF16");
// Especially TSAN is slow enough to warrant hierarchical parallelism. // Especially TSAN is slow enough to warrant hierarchical parallelism.
const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD
? ParallelismStrategy::kHierarchical ? ParallelismStrategy::kHierarchical
@ -463,7 +463,7 @@ static std::vector<IOBatch> MakeBatches(
static void ReadBatches(const BlobReader& reader, static void ReadBatches(const BlobReader& reader,
const std::vector<IOBatch>& batches, const std::vector<IOBatch>& batches,
ThreadingContext& ctx) { ThreadingContext& ctx) {
static const auto zone = ctx.profiler.AddZone("Startup.Weights.ReadBatches"); const auto zone = GetProfilerZone(Zones::kStartupWeightsReadBatches);
// >5x speedup from parallel reads when cached. // >5x speedup from parallel reads when cached.
ParallelFor(ParallelismStrategy::kHierarchical, ParallelFor(ParallelismStrategy::kHierarchical,
batches.size(), ctx, /*cluster_idx=*/0, batches.size(), ctx, /*cluster_idx=*/0,

View File

@ -25,6 +25,7 @@
#include "util/basics.h" #include "util/basics.h"
#include "util/mat.h" #include "util/mat.h"
#include "util/threading_context.h" #include "util/threading_context.h"
#include "util/zones.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
#include "hwy/timer.h" #include "hwy/timer.h"
@ -290,7 +291,7 @@ class MMDecompress {
const hn::ScalableTag<BF16> dbf; const hn::ScalableTag<BF16> dbf;
const size_t NBF = hn::Lanes(dbf); const size_t NBF = hn::Lanes(dbf);
static const auto zone = env.ctx.profiler.AddZone("MM.DecompressA"); const auto zone = GetProfilerZone(Zones::kMMDecompressA);
const auto do_range = const auto do_range =
[&](const IndexRange& range_M, const IndexRange& range_K, size_t worker) [&](const IndexRange& range_M, const IndexRange& range_K, size_t worker)
@ -878,9 +879,9 @@ class MMLoops {
static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT<TB>& B, static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C, const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) { const MMArgs& args) {
static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch");
PROFILER_ZONE3(args.env.ctx.profiler, PROFILER_ZONE3(args.env.ctx.profiler,
args.env.ctx.Worker(args.options.cluster_idx), zone); args.env.ctx.Worker(args.options.cluster_idx),
GetProfilerZone(Zones::kMMDispatch));
DispatchParallelism( DispatchParallelism(
args.options.parallelism, [&](const auto& parallel) HWY_ATTR { args.options.parallelism, [&](const auto& parallel) HWY_ATTR {
@ -903,7 +904,7 @@ class MMLoops {
const StridedViewBF A, const MatPtrT<TB>& B, const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C, const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) { const MMArgs& args) {
static const auto zone = args.env.ctx.profiler.AddZone("MM.NT"); const auto zone = GetProfilerZone(Zones::kMMNT);
HWY_DASSERT(args.ranges_mc.NumTasks() == 1); HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
HWY_DASSERT(args.ranges_kc.NumTasks() == 1); HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
const IndexRange& range_mc = args.ranges_mc.Range(0); const IndexRange& range_mc = args.ranges_mc.Range(0);
@ -939,7 +940,7 @@ class MMLoops {
const StridedViewBF A, const MatPtrT<TB>& B, const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C, const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) { const MMArgs& args) {
static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_K"); const auto zone = GetProfilerZone(Zones::kMMNT_K);
HWY_DASSERT(args.ranges_mc.NumTasks() == 1); HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
const IndexRange& range_mc = args.ranges_mc.Range(0); const IndexRange& range_mc = args.ranges_mc.Range(0);
@ -975,7 +976,7 @@ class MMLoops {
const StridedViewBF A, const MatPtrT<TB>& B, const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C, const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) { const MMArgs& args) {
static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT"); const auto zone = GetProfilerZone(Zones::kMMNT_MT);
HWY_DASSERT(args.ranges_kc.NumTasks() == 1); HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
const IndexRange& range_kc = args.ranges_kc.Range(0); const IndexRange& range_kc = args.ranges_kc.Range(0);
@ -1009,7 +1010,7 @@ class MMLoops {
const StridedViewBF A, const MatPtrT<TB>& B, const StridedViewBF A, const MatPtrT<TB>& B,
const MatPtrT<TB>* B2, RowPtrs<TC> C, const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) { const MMArgs& args) {
static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT_K"); const auto zone = GetProfilerZone(Zones::kMMNT_MT_K);
parallel.ForRangesMC_NC( parallel.ForRangesMC_NC(
args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx, args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx,
@ -1060,10 +1061,10 @@ template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B, HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const float* HWY_RESTRICT add, MatMulEnv& env, const float* HWY_RESTRICT add, MatMulEnv& env,
MatPtrT<TC>& C, MMOptions options = MMOptions()) { MatPtrT<TC>& C, MMOptions options = MMOptions()) {
static const auto zone = env.ctx.profiler.AddZone("MM.MatMul");
const size_t cluster_idx = options.cluster_idx; const size_t cluster_idx = options.cluster_idx;
HWY_DASSERT(cluster_idx < env.row_ptrs.size()); HWY_DASSERT(cluster_idx < env.row_ptrs.size());
PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), zone); PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx),
GetProfilerZone(Zones::kMMMatMul));
RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]); RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]);
@ -1121,10 +1122,10 @@ template <typename TB>
HWY_NOINLINE MMPerKey* TwoMatMul(const MatPtrT<BF16>& A, const MatPtrT<TB>& B1, HWY_NOINLINE MMPerKey* TwoMatMul(const MatPtrT<BF16>& A, const MatPtrT<TB>& B1,
const MatPtrT<TB>& B2, MatMulEnv& env, const MatPtrT<TB>& B2, MatMulEnv& env,
MatPtrT<BF16>& C, MMOptions options) { MatPtrT<BF16>& C, MMOptions options) {
static const auto zone = env.ctx.profiler.AddZone("MM.TwoMatMul");
const size_t cluster_idx = options.cluster_idx; const size_t cluster_idx = options.cluster_idx;
HWY_DASSERT(cluster_idx < env.row_ptrs.size()); HWY_DASSERT(cluster_idx < env.row_ptrs.size());
PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), zone); PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx),
GetProfilerZone(Zones::kMMTwoMatMul));
HWY_DASSERT(options.func != nullptr); // no other way to get access to C2. HWY_DASSERT(options.func != nullptr); // no other way to get access to C2.

View File

@ -32,6 +32,7 @@
#include "util/basics.h" // TokenAndProb, RngStream #include "util/basics.h" // TokenAndProb, RngStream
#include "util/mat.h" #include "util/mat.h"
#include "util/threading_context.h" #include "util/threading_context.h"
#include "util/zones.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/bit_set.h" #include "hwy/bit_set.h"
#include "hwy/contrib/sort/order.h" #include "hwy/contrib/sort/order.h"
@ -206,8 +207,7 @@ namespace detail {
template <typename VT> template <typename VT>
float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p, float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, hwy::Profiler& p,
const size_t worker) { const size_t worker) {
static const auto zone = p.AddZone("Ops.RMSNormMul"); PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNormMul));
PROFILER_ZONE3(p, worker, zone);
const hn::ScalableTag<float> d; const hn::ScalableTag<float> d;
const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault()); const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault());
@ -223,8 +223,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
OT* HWY_RESTRICT out, OT* HWY_RESTRICT out,
const size_t size, hwy::Profiler& p, const size_t size, hwy::Profiler& p,
const size_t worker) { const size_t worker) {
static const auto zone = p.AddZone("Ops.RMSNorm"); PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNorm));
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
@ -248,8 +247,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight,
const size_t size, const size_t size,
hwy::Profiler& p, hwy::Profiler& p,
const size_t worker) { const size_t worker) {
static const auto zone = p.AddZone("Ops.RMSNormInplace"); PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRmsNormInplace));
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
@ -365,8 +363,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
float* HWY_RESTRICT x, const size_t dim_qkv, float* HWY_RESTRICT x, const size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p, const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p,
const size_t worker) { const size_t worker) {
static const auto zone = p.AddZone("Ops.Rope"); PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRope));
PROFILER_ZONE3(p, worker, zone);
HWY_DASSERT(dim_qkv % 2 == 0); HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2; const size_t half_dim_qkv = dim_qkv / 2;
@ -425,8 +422,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
const float mul, float* HWY_RESTRICT x, const size_t dim_qkv, const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p, const float* HWY_RESTRICT inv_timescale, const int pos, hwy::Profiler& p,
const size_t worker) { const size_t worker) {
static const auto zone = p.AddZone("Ops.RopeAndMulBy"); PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsRopeAndMulBy));
PROFILER_ZONE3(p, worker, zone);
HWY_DASSERT(dim_qkv % 2 == 0); HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2; const size_t half_dim_qkv = dim_qkv / 2;
@ -488,8 +484,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
const size_t size, const size_t size,
hwy::Profiler& p, hwy::Profiler& p,
const size_t worker) { const size_t worker) {
static const auto zone = p.AddZone("Ops.AddFrom"); PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsAddFrom));
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
@ -568,8 +563,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
const size_t size, const size_t size,
hwy::Profiler& p, hwy::Profiler& p,
const size_t worker) { const size_t worker) {
static const auto zone = p.AddZone("Ops.MulByConst"); PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConst));
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>; using VF = hn::Vec<DF>;
@ -587,8 +581,7 @@ template <typename XT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo( HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
const size_t size, hwy::Profiler& p, const size_t worker) { const size_t size, hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Ops.MulByConstTo"); PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstTo));
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>; using VF = hn::Vec<DF>;
@ -606,8 +599,7 @@ template <typename XT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
const size_t size, hwy::Profiler& p, const size_t worker) { const size_t size, hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Ops.MulByConstAndAdd"); PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAdd));
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>; using VF = hn::Vec<DF>;
@ -744,8 +736,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile(
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out, const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size, const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
hwy::Profiler& p, const size_t worker) { hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Ops.MulByConstAndAdd"); PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddTile));
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
@ -1007,8 +998,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4(
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out, const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size, const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
hwy::Profiler& p, const size_t worker) { hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Ops.MulByConstAndAddTile4"); PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddTile4));
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
@ -1049,8 +1039,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
const size_t pos, float* HWY_RESTRICT out, const size_t pos, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size, const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
hwy::Profiler& p, const size_t worker) { hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Ops.MulByConstAndAdd"); PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddVector));
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
@ -1146,8 +1135,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p, static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
const size_t worker, const size_t worker,
float temperature = 1.0f) { float temperature = 1.0f) {
static const auto zone = p.AddZone("Ops.Softmax"); PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsSoftmax));
PROFILER_ZONE3(p, worker, zone);
HWY_DASSERT(logits.size() != 0); HWY_DASSERT(logits.size() != 0);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
@ -1280,8 +1268,7 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(Logits logits) {
static HWY_NOINLINE void LogitsSoftCap(const float cap, Logits logits, static HWY_NOINLINE void LogitsSoftCap(const float cap, Logits logits,
hwy::Profiler& p, const size_t worker) { hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Ops.LogitsSoftCap"); PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsLogitsSoftCap));
PROFILER_ZONE3(p, worker, zone);
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;

View File

@ -14,6 +14,7 @@
// limitations under the License. // limitations under the License.
#include "compression/types.h" #include "compression/types.h"
#include "util/zones.h"
#ifndef HWY_DISABLED_TARGETS #ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS
@ -132,6 +133,7 @@ class TestAddFrom {
} }
SimpleAddFrom(o, e, count); SimpleAddFrom(o, e, count);
InitProfilerZones(hwy::Profiler::Get());
AddFrom(o, x, count, hwy::Profiler::Get(), /*worker=*/0); AddFrom(o, x, count, hwy::Profiler::Get(), /*worker=*/0);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
@ -180,6 +182,7 @@ class TestMulByConstAndAdd {
T constant = Random<T>(rng); T constant = Random<T>(rng);
SimpleMulByConstAndAdd(constant, o, e, count); SimpleMulByConstAndAdd(constant, o, e, count);
InitProfilerZones(hwy::Profiler::Get());
MulByConstAndAdd(constant, o, x, count, hwy::Profiler::Get(), /*worker=*/0); MulByConstAndAdd(constant, o, x, count, hwy::Profiler::Get(), /*worker=*/0);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
@ -228,6 +231,7 @@ class TestMulByConst {
T constant = Random<T>(rng); T constant = Random<T>(rng);
SimpleMulByConst(constant, e, count); SimpleMulByConst(constant, e, count);
InitProfilerZones(hwy::Profiler::Get());
MulByConst(constant, x, count, hwy::Profiler::Get(), /*worker=*/0); MulByConst(constant, x, count, hwy::Profiler::Get(), /*worker=*/0);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
@ -274,6 +278,7 @@ struct TestMulByConstTo {
hwy::ConvertScalarTo<float>(constant)); hwy::ConvertScalarTo<float>(constant));
} }
InitProfilerZones(hwy::Profiler::Get());
MulByConstTo(constant, x, actual, count, hwy::Profiler::Get(), MulByConstTo(constant, x, actual, count, hwy::Profiler::Get(),
/*worker=*/0); /*worker=*/0);
@ -310,6 +315,7 @@ class TestSoftmax {
} }
SimpleSoftmax(e, count); SimpleSoftmax(e, count);
InitProfilerZones(hwy::Profiler::Get());
Softmax(Logits(x, count), hwy::Profiler::Get(), /*worker=*/0); Softmax(Logits(x, count), hwy::Profiler::Get(), /*worker=*/0);
T sum = 0.0f; T sum = 0.0f;
@ -437,6 +443,7 @@ void TestRopeAndMulBy() {
ThreadingArgs threading_args; ThreadingArgs threading_args;
ThreadingContext ctx(threading_args); ThreadingContext ctx(threading_args);
hwy::Profiler& p = ctx.profiler; hwy::Profiler& p = ctx.profiler;
InitProfilerZones(p);
const size_t worker = 0; const size_t worker = 0;
const ModelConfig config(Model::GEMMA2_9B, Type::kSFP, const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
@ -551,6 +558,7 @@ struct TestRMSNorm {
} }
ScalarRMSNorm(vec, weight, expected, kSize); ScalarRMSNorm(vec, weight, expected, kSize);
InitProfilerZones(hwy::Profiler::Get());
RMSNorm(vec, weight, actual, kSize, hwy::Profiler::Get(), /*worker=*/0); RMSNorm(vec, weight, actual, kSize, hwy::Profiler::Get(), /*worker=*/0);
for (size_t i = 0; i < kSize; i++) { for (size_t i = 0; i < kSize; i++) {
@ -585,6 +593,7 @@ struct TestRMSNormInplace {
} }
ScalarRMSNorm(expected, weight, expected, kSize); ScalarRMSNorm(expected, weight, expected, kSize);
InitProfilerZones(hwy::Profiler::Get());
RMSNormInplace(weight, actual, kSize, hwy::Profiler::Get(), RMSNormInplace(weight, actual, kSize, hwy::Profiler::Get(),
/*worker=*/0); /*worker=*/0);
@ -707,6 +716,7 @@ void TestAllLayerNorm() {
void TestSampleTopK() { void TestSampleTopK() {
hwy::Profiler& p = hwy::Profiler::Get(); hwy::Profiler& p = hwy::Profiler::Get();
InitProfilerZones(p);
const size_t worker = 0; const size_t worker = 0;
const size_t kSize = 52; const size_t kSize = 52;
std::vector<float> logits_vec(kSize); std::vector<float> logits_vec(kSize);

70
util/zones.cc Normal file
View File

@ -0,0 +1,70 @@
#include "util/zones.h"
#include "hwy/profiler.h"
namespace gcpp {
#if PROFILER_ENABLED
// Number of real zones. Zones::kNumZones is the sentinel enumerator that
// follows the last zone, so its numeric value equals the zone count.
static constexpr size_t kNumZones = static_cast<size_t>(Zones::kNumZones);

// Names passed to hwy::Profiler::AddZone, indexed by the numeric value of the
// corresponding Zones enumerator. Keep in sync with the Zones enum.
// The array size is deliberately deduced from the initializer list (rather
// than spelled as kNumZones) so the static_assert below can detect a missing
// or extra entry; with an explicit size, a missing entry would silently become
// a null name.
static constexpr const char* kProfilerZoneNames[] = {
    "Ops.RMSNormMul",
    "Ops.RMSNorm",
    "Ops.RMSNormInplace",
    "Ops.Rope",
    "Ops.RopeAndMulBy",
    "Ops.AddFrom",
    "Ops.MulByConst",
    "Ops.MulByConstTo",
    "Ops.MulByConstAndAdd",
    "Ops.MulByConstAndAddTile",
    "Ops.MulByConstAndAddTile4",
    "Ops.MulByConstAndAddVector",
    "Ops.Softmax",
    "Ops.LogitsSoftCap",
    "FlashAttention.TransposeQ",
    "FlashAttention.RMSNormAndPositionalEncoding",
    "FlashAttention.SingleFlashAttention",
    "FlashAttention.TileFlashAttention",
    "FlashAttention.TileFlashAttention4",
    "FlashAttention.FlashAttention",
    "Gen.Activation",
    "Gen.ActivationFused",
    "Gen.SampleTop1",
    "Gen.SampleTopK",
    "Gen.Attention.QDotK",
    "Gen.Attention.DotSoftmaxWeightedSum.par",
    "Startup.Weights.ReadAllToBF16",
    "Startup.Weights.ReadBatches",
    "MM.Dispatch",
    "MM.MatMul",
    "MM.TwoMatMul",
    "MM.DecompressA",
    "MM.NT",
    "MM.NT_K",
    "MM.NT_MT",
    "MM.NT_MT_K",
};

static_assert(
    sizeof(kProfilerZoneNames) / sizeof(kProfilerZoneNames[0]) == kNumZones,
    "kProfilerZoneNames must have exactly one entry per Zones enumerator");

// Handles returned by AddZone, one per zone, in enum order. Written by
// InitProfilerZones and read by GetProfilerZone.
static hwy::profiler::ZoneHandle profiler_zone_handles[kNumZones];
#endif
// Registers every entry of kProfilerZoneNames with `profiler`, in table order,
// and stores the returned handle at the same index of profiler_zone_handles.
// Compiles to an empty function when PROFILER_ENABLED is 0.
void InitProfilerZones(hwy::Profiler& profiler) {
#if PROFILER_ENABLED
  // Intended to run once at startup; each call overwrites all stored handles.
  size_t slot = 0;
  for (const char* zone_name : kProfilerZoneNames) {
    profiler_zone_handles[slot] = profiler.AddZone(zone_name);
    ++slot;
  }
#endif
}
// Looks up the handle stored for `zone` by InitProfilerZones. The enum value
// is used directly as an index into profiler_zone_handles, without a bounds
// check. With PROFILER_ENABLED == 0 there is no table, so a
// default-constructed handle is returned instead.
hwy::profiler::ZoneHandle GetProfilerZone(Zones zone) {
#if !PROFILER_ENABLED
  return hwy::profiler::ZoneHandle();
#else
  const size_t index = static_cast<size_t>(zone);
  return profiler_zone_handles[index];
#endif
}
} // namespace gcpp

58
util/zones.h Normal file
View File

@ -0,0 +1,58 @@
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_
#include "hwy/profiler.h"
namespace gcpp {
// Zones for the profiler.
// Each enumerator identifies one profiler zone. The numeric value of an
// enumerator is used as an index into the name and handle tables in
// util/zones.cc, so the order here must match the order of the names there:
// when adding, removing or reordering an enumerator, update that table too.
enum class Zones {
// "Ops.*" zones.
kOpsRmsNormMul,
kOpsRmsNorm,
kOpsRmsNormInplace,
kOpsRope,
kOpsRopeAndMulBy,
kOpsAddFrom,
kOpsMulByConst,
kOpsMulByConstTo,
kOpsMulByConstAndAdd,
kOpsMulByConstAndAddTile,
kOpsMulByConstAndAddTile4,
kOpsMulByConstAndAddVector,
kOpsSoftmax,
kOpsLogitsSoftCap,
// "FlashAttention.*" zones.
kFlashAttentionTransposeQ,
kFlashAttentionRmsNormAndPositionalEncoding,
kFlashAttentionSingleFlashAttention,
kFlashAttentionTileFlashAttention,
kFlashAttentionTileFlashAttention4,
kFlashAttentionFlashAttention,
// "Gen.*" zones.
kGenActivation,
kGenActivationFused,
kGenSampleTop1,
kGenSampleTopK,
kGenAttentionQDotK,
kGenAttentionDotSoftmaxWeightedSumPar,
// "Startup.*" zones.
kStartupWeightsReadAllToBF16,
kStartupWeightsReadBatches,
// "MM.*" zones.
kMMDispatch,
kMMMatMul,
kMMTwoMatMul,
kMMDecompressA,
kMMNT,
kMMNT_K,
kMMNT_MT,
kMMNT_MT_K,
// Sentinel: the number of zones above, not a zone itself. Must remain last.
kNumZones
};
// Initializes the profiler zones: registers the name of every Zones enumerator
// with `profiler` via AddZone and stores the returned handles for later lookup
// by GetProfilerZone. Must be called before the first GetProfilerZone call;
// until then the stored handles have not been obtained from any profiler.
// Does nothing when PROFILER_ENABLED is 0.
// NOTE(review): the handles are stored in a single table shared by all callers
// and written without synchronization - presumably this is meant to run once
// at startup before worker threads use the zones; confirm.
void InitProfilerZones(hwy::Profiler& profiler);
// Returns the zone handle for the given zone enum value, as stored by the most
// recent InitProfilerZones call. `zone` must be one of the real zones, not
// Zones::kNumZones: the value is used as an array index without a bounds
// check. When PROFILER_ENABLED is 0, returns a default-constructed ZoneHandle.
hwy::profiler::ZoneHandle GetProfilerZone(Zones zone);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_ZONES_H_