From ee18916abffdcbc08fecb7c37a3d4bdc38a4bc80 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Wed, 15 Oct 2025 07:09:32 -0700 Subject: [PATCH] Removed the PROFILER_ZONE from the most highly called functions to reduce the overhead. PiperOrigin-RevId: 819739402 --- gemma/attention.cc | 7 +++--- gemma/flash_attention.cc | 49 ++++++++++++++++++++-------------------- gemma/gemma.cc | 4 +--- gemma/vit.cc | 10 ++++---- ops/ops-inl.h | 27 ++++++++-------------- ops/ops_test.cc | 5 ++-- 6 files changed, 43 insertions(+), 59 deletions(-) diff --git a/gemma/attention.cc b/gemma/attention.cc index 8950bc2..bf39702 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -89,7 +89,7 @@ void PositionalEncodingQK(float* qk, const size_t layer_idx, // PostQKType::Rope if (post_qk == PostQKType::HalfRope) { Rope(qk, qkv_dim / 2, inv_timescale, pos, p, worker); - if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, p, worker); + if (mul != 1.0f) MulByConst(mul, qk, qkv_dim); } else { RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, p, worker); } @@ -113,7 +113,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos, MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), p, worker); for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { - MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), p, worker); + MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols()); } } else { { @@ -122,8 +122,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos, } for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { const size_t pos_mod = div_seq_len.Remainder(pos); - MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p, - worker); + MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols()); } } } diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index cfadf28..df6efd1 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -131,10 +131,11 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, } // Handles a single v row of flash attention for a single q.k dot product. -void HWY_INLINE SingleFlashAttentionStep( - float x, float cap, float& old_max, float& old_d, - const float* HWY_RESTRICT v, const size_t v_cols, - float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) { +void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max, + float& old_d, + const float* HWY_RESTRICT v, + const size_t v_cols, + float* HWY_RESTRICT att_out) { if (cap > 0.0f) { // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. x = cap * std::tanh(x / cap); @@ -147,8 +148,8 @@ void HWY_INLINE SingleFlashAttentionStep( float one_over_d = 1.0f / old_d; scale *= one_over_d; x *= one_over_d; - MulByConst(scale, att_out, v_cols, p, worker); - MulByConstAndAdd(x, v, att_out, v_cols, p, worker); + MulByConst(scale, att_out, v_cols); + MulByConstAndAdd(x, v, att_out, v_cols); } // Calculates the complete attention outputs for a single row of q. @@ -174,7 +175,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos, const size_t pos_mod = activations.div_seq_len.Remainder(pos); float x = Dot(q, k.Row(pos_mod), k.Cols()); SingleFlashAttentionStep(x, activations.config.att_cap, m, d, - v.Row(pos_mod), v.Cols(), att_out, p, worker); + v.Row(pos_mod), v.Cols(), att_out); } } @@ -183,7 +184,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos, template > VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets, const size_t k_pos, const MatPtrT& q, - const MatPtrT& k, hwy::Profiler& p, const size_t worker) { + const MatPtrT& k) { hn::TFromD results[hn::MaxLanes(df)]; for (size_t i = 0; i < hn::Lanes(df); ++i) { results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols()); @@ -198,9 +199,8 @@ VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets, // consecutive elements, and other columns by adding q_stride. template > void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride, - const MatPtrT& k, const size_t* k_pos, - hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1, - VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, + const MatPtrT& k, const size_t* k_pos, VF& sum0, + VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7) { constexpr size_t kHTileSize = kNFx8HTileSize; sum0 = hn::Zero(df); @@ -303,8 +303,8 @@ void TileFlashAttention( k_pos[i] = activations.div_seq_len.Remainder(position + i); } VF x0, x1, x2, x3, x4, x5, x6, x7; - QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, p, worker, x0, x1, x2, x3, - x4, x5, x6, x7); + QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6, + x7); if (activations.config.att_cap > 0.0f) { // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile. VF cap = hn::Set(df, activations.config.att_cap); @@ -343,12 +343,12 @@ void TileFlashAttention( x6 = hn::Mul(x6, one_over_d); x7 = hn::Mul(x7, one_over_d); MulByConstAndAddTile(df, scale, x0, x1, x2, x3, x4, x5, x6, x7, v, k_pos, - att_out.Row(0), out_offsets, v.Cols(), p, worker); + att_out.Row(0), out_offsets, v.Cols()); position += kHTileSize; } while (position <= max_last_pos) { size_t k_pos = activations.div_seq_len.Remainder(position); - VF x0 = QDotKVector(df, q_offsets, k_pos, q, k, p, worker); + VF x0 = QDotKVector(df, q_offsets, k_pos, q, k); if (activations.config.att_cap > 0.0f) { // Compute tanh(x / cap) * cap, being LogitsSoftCap on the vector. VF cap = hn::Set(df, activations.config.att_cap); @@ -369,7 +369,7 @@ void TileFlashAttention( x0 = hn::Mul(x0, one_over_d); scale = hn::Mul(scale, one_over_d); MulByConstAndAddVector(df, scale, x0, v, k_pos, att_out.Row(0), out_offsets, - v.Cols(), p, worker); + v.Cols()); ++position; } } @@ -380,8 +380,8 @@ void TileFlashAttention( template > void QDotKTilex4(DF df, const float* HWY_RESTRICT q, const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT& k, - const int32_t* HWY_RESTRICT k_offsets, hwy::Profiler& p, - const size_t worker, VF& sum0, VF& sum1, VF& sum2, VF& sum3) { + const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1, + VF& sum2, VF& sum3) { sum0 = hn::Zero(df); sum1 = hn::Zero(df); sum2 = hn::Zero(df); @@ -462,8 +462,7 @@ void TileFlashAttention4( k_offsets[i] = k.Row(v_pos[i]) - k.Row(0); } VF x0, x1, x2, x3; - QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, p, worker, x0, x1, x2, - x3); + QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, x0, x1, x2, x3); if (activations.config.att_cap > 0.0f) { // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile. VF cap = hn::Set(df, activations.config.att_cap); @@ -478,7 +477,7 @@ void TileFlashAttention4( scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2); scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3); MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0), - out_offsets, v.Cols(), p, worker); + out_offsets, v.Cols()); position += kHTileSize; } while (position <= max_last_pos) { @@ -488,28 +487,28 @@ void TileFlashAttention4( float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols()); SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0, v.Row(k_pos), v.Cols(), - att_out.Row(0) + out_offsets[0], p, worker); + att_out.Row(0) + out_offsets[0]); } if (position <= last_pos[1]) { // Past the last position, x1 doesn't count. float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols()); SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1, v.Row(k_pos), v.Cols(), - att_out.Row(0) + out_offsets[1], p, worker); + att_out.Row(0) + out_offsets[1]); } if (position <= last_pos[2]) { // Past the last position, x2 doesn't count. float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols()); SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2, v.Row(k_pos), v.Cols(), - att_out.Row(0) + out_offsets[2], p, worker); + att_out.Row(0) + out_offsets[2]); } if (position <= last_pos[3]) { // Past the last position, x3 doesn't count. float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols()); SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3, v.Row(k_pos), v.Cols(), - att_out.Row(0) + out_offsets[3], p, worker); + att_out.Row(0) + out_offsets[3]); } ++position; } diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 78c9cc4..80bf9e2 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -160,7 +160,6 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt, const size_t model_dim = model_config.model_dim; const float emb_scaling = EmbeddingScaling(model_dim); - const size_t worker = 0; // Not yet parallelized. HWY_DASSERT(token >= 0); HWY_DASSERT(token < static_cast(model_config.vocab_size)); @@ -176,8 +175,7 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt, const hn::ScalableTag df; DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(x_row), model_dim); - MulByConst(emb_scaling * weights_t->Scale(), x.Row(x_row), model_dim, - ctx.profiler, worker); + MulByConst(emb_scaling * weights_t->Scale(), x.Row(x_row), model_dim); }); if (model_config.absolute_pe) { diff --git a/gemma/vit.cc b/gemma/vit.cc index 44b1bcb..d21be16 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -95,7 +95,7 @@ class VitAttention { float* HWY_RESTRICT q = activations_.attention.q.Row(token) + head * 3 * qkv_dim; // TODO: shift to MatMul with A.scale once MatMul is confirmed working - MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker); + MulByConst(query_scale, q, qkv_dim); hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float)); }); @@ -120,8 +120,7 @@ class VitAttention { for (size_t i = 0; i < seq_len; ++i) { float* HWY_RESTRICT v = activations_.attention.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim; - MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim, - env_.ctx.profiler, worker); + MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim); } }); } @@ -144,7 +143,7 @@ class VitAttention { // Compute Q.K scores, which are "logits" stored in head_att. float* HWY_RESTRICT q = activations_.attention.q.Row(token) + head * 3 * qkv_dim; - MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker); + MulByConst(query_scale, q, qkv_dim); float* HWY_RESTRICT head_att = activations_.attention.att.Row(token) + head * seq_len; for (size_t i = 0; i < seq_len; ++i) { @@ -161,8 +160,7 @@ class VitAttention { for (size_t i = 0; i < seq_len; ++i) { float* HWY_RESTRICT v = activations_.attention.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim; - MulByConstAndAdd(head_att[i], v, att_out, qkv_dim, - env_.ctx.profiler, worker); + MulByConstAndAdd(head_att[i], v, att_out, qkv_dim); } }); } diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 162b48a..c966a68 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -560,10 +560,7 @@ static HWY_INLINE void AddFromBatched(const MatPtrT& x, MatPtrT& out, template HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x, - const size_t size, - hwy::Profiler& p, - const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConst)); + const size_t size) { namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; @@ -596,10 +593,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo( // out[i] += x[i] * c. template -HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( - const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, - const size_t size, hwy::Profiler& p, const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAdd)); +HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(const float c, + const XT* HWY_RESTRICT x, + OT* HWY_RESTRICT out, + const size_t size) { namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; @@ -734,9 +731,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile( DF df, const VF scale, const VF c0, const VF c1, const VF c2, const VF c3, const VF c4, const VF c5, const VF c6, const VF c7, const MatPtrT& v, const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out, - const uint32_t* HWY_RESTRICT out_offsets, const size_t size, - hwy::Profiler& p, const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddTile)); + const uint32_t* HWY_RESTRICT out_offsets, const size_t size) { namespace hn = hwy::HWY_NAMESPACE; HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); @@ -996,9 +991,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4( DF df, const float* HWY_RESTRICT scales, const VF c0, const VF c1, const VF c2, const VF c3, const MatPtrT& v, const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out, - const uint32_t* HWY_RESTRICT out_offsets, const size_t size, - hwy::Profiler& p, const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddTile4)); + const uint32_t* HWY_RESTRICT out_offsets, const size_t size) { namespace hn = hwy::HWY_NAMESPACE; HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); @@ -1037,9 +1030,7 @@ template > HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( DF df, const VF scale, const VF c0, const MatPtrT& v, const size_t pos, float* HWY_RESTRICT out, - const uint32_t* HWY_RESTRICT out_offsets, const size_t size, - hwy::Profiler& p, const size_t worker) { - PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddVector)); + const uint32_t* HWY_RESTRICT out_offsets, const size_t size) { namespace hn = hwy::HWY_NAMESPACE; HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); @@ -1177,7 +1168,7 @@ static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p, const float sum_exp = Sum(d, logits.data(), logits.size()); // Double-precision reciprocal does not appear to affect the results. const float mul = 1.0f / sum_exp; - MulByConst(mul, logits.data(), logits.size(), p, worker); + MulByConst(mul, logits.data(), logits.size()); } // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max / diff --git a/ops/ops_test.cc b/ops/ops_test.cc index 40f1002..dd8e4e8 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -183,7 +183,7 @@ class TestMulByConstAndAdd { SimpleMulByConstAndAdd(constant, o, e, count); InitProfilerZones(hwy::Profiler::Get()); - MulByConstAndAdd(constant, o, x, count, hwy::Profiler::Get(), /*worker=*/0); + MulByConstAndAdd(constant, o, x, count); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); @@ -232,7 +232,7 @@ class TestMulByConst { SimpleMulByConst(constant, e, count); InitProfilerZones(hwy::Profiler::Get()); - MulByConst(constant, x, count, hwy::Profiler::Get(), /*worker=*/0); + MulByConst(constant, x, count); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); @@ -443,7 +443,6 @@ void TestRopeAndMulBy() { ThreadingArgs threading_args; ThreadingContext ctx(threading_args); hwy::Profiler& p = ctx.profiler; - InitProfilerZones(p); const size_t worker = 0; const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,