Removed the PROFILER_ZONE from the most highly called functions to reduce the overhead.

PiperOrigin-RevId: 819739402
This commit is contained in:
Ray Smith 2025-10-15 07:09:32 -07:00 committed by Copybara-Service
parent e3e8511e79
commit ee18916abf
6 changed files with 43 additions and 59 deletions

View File

@ -89,7 +89,7 @@ void PositionalEncodingQK(float* qk, const size_t layer_idx,
// PostQKType::Rope
if (post_qk == PostQKType::HalfRope) {
Rope(qk, qkv_dim / 2, inv_timescale, pos, p, worker);
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, p, worker);
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim);
} else {
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, p, worker);
}
@ -113,7 +113,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), p,
worker);
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), p, worker);
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
}
} else {
{
@ -122,8 +122,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
}
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = div_seq_len.Remainder(pos);
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p,
worker);
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols());
}
}
}

View File

@ -131,10 +131,11 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
}
// Handles a single v row of flash attention for a single q.k dot product.
void HWY_INLINE SingleFlashAttentionStep(
float x, float cap, float& old_max, float& old_d,
const float* HWY_RESTRICT v, const size_t v_cols,
float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) {
void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,
float& old_d,
const float* HWY_RESTRICT v,
const size_t v_cols,
float* HWY_RESTRICT att_out) {
if (cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
x = cap * std::tanh(x / cap);
@ -147,8 +148,8 @@ void HWY_INLINE SingleFlashAttentionStep(
float one_over_d = 1.0f / old_d;
scale *= one_over_d;
x *= one_over_d;
MulByConst(scale, att_out, v_cols, p, worker);
MulByConstAndAdd(x, v, att_out, v_cols, p, worker);
MulByConst(scale, att_out, v_cols);
MulByConstAndAdd(x, v, att_out, v_cols);
}
// Calculates the complete attention outputs for a single row of q.
@ -174,7 +175,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
const size_t pos_mod = activations.div_seq_len.Remainder(pos);
float x = Dot(q, k.Row(pos_mod), k.Cols());
SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
v.Row(pos_mod), v.Cols(), att_out, p, worker);
v.Row(pos_mod), v.Cols(), att_out);
}
}
@ -183,7 +184,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
template <class DF, class VF = hn::Vec<DF>>
VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
const size_t k_pos, const MatPtrT<KV_t>& q,
const MatPtrT<KV_t>& k, hwy::Profiler& p, const size_t worker) {
const MatPtrT<KV_t>& k) {
hn::TFromD<DF> results[hn::MaxLanes(df)];
for (size_t i = 0; i < hn::Lanes(df); ++i) {
results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols());
@ -198,9 +199,8 @@ VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
// consecutive elements, and other columns by adding q_stride.
template <class DF, class VF = hn::Vec<DF>>
void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
const MatPtrT<KV_t>& k, const size_t* k_pos,
hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1,
VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
const MatPtrT<KV_t>& k, const size_t* k_pos, VF& sum0,
VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
VF& sum7) {
constexpr size_t kHTileSize = kNFx8HTileSize;
sum0 = hn::Zero(df);
@ -303,8 +303,8 @@ void TileFlashAttention(
k_pos[i] = activations.div_seq_len.Remainder(position + i);
}
VF x0, x1, x2, x3, x4, x5, x6, x7;
QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, p, worker, x0, x1, x2, x3,
x4, x5, x6, x7);
QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6,
x7);
if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
VF cap = hn::Set(df, activations.config.att_cap);
@ -343,12 +343,12 @@ void TileFlashAttention(
x6 = hn::Mul(x6, one_over_d);
x7 = hn::Mul(x7, one_over_d);
MulByConstAndAddTile(df, scale, x0, x1, x2, x3, x4, x5, x6, x7, v, k_pos,
att_out.Row(0), out_offsets, v.Cols(), p, worker);
att_out.Row(0), out_offsets, v.Cols());
position += kHTileSize;
}
while (position <= max_last_pos) {
size_t k_pos = activations.div_seq_len.Remainder(position);
VF x0 = QDotKVector(df, q_offsets, k_pos, q, k, p, worker);
VF x0 = QDotKVector(df, q_offsets, k_pos, q, k);
if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the vector.
VF cap = hn::Set(df, activations.config.att_cap);
@ -369,7 +369,7 @@ void TileFlashAttention(
x0 = hn::Mul(x0, one_over_d);
scale = hn::Mul(scale, one_over_d);
MulByConstAndAddVector(df, scale, x0, v, k_pos, att_out.Row(0), out_offsets,
v.Cols(), p, worker);
v.Cols());
++position;
}
}
@ -380,8 +380,8 @@ void TileFlashAttention(
template <class DF, class VF = hn::Vec<DF>>
void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k,
const int32_t* HWY_RESTRICT k_offsets, hwy::Profiler& p,
const size_t worker, VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1,
VF& sum2, VF& sum3) {
sum0 = hn::Zero(df);
sum1 = hn::Zero(df);
sum2 = hn::Zero(df);
@ -462,8 +462,7 @@ void TileFlashAttention4(
k_offsets[i] = k.Row(v_pos[i]) - k.Row(0);
}
VF x0, x1, x2, x3;
QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, p, worker, x0, x1, x2,
x3);
QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, x0, x1, x2, x3);
if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
VF cap = hn::Set(df, activations.config.att_cap);
@ -478,7 +477,7 @@ void TileFlashAttention4(
scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2);
scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3);
MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0),
out_offsets, v.Cols(), p, worker);
out_offsets, v.Cols());
position += kHTileSize;
}
while (position <= max_last_pos) {
@ -488,28 +487,28 @@ void TileFlashAttention4(
float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[0], p, worker);
att_out.Row(0) + out_offsets[0]);
}
if (position <= last_pos[1]) {
// Past the last position, x1 doesn't count.
float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[1], p, worker);
att_out.Row(0) + out_offsets[1]);
}
if (position <= last_pos[2]) {
// Past the last position, x2 doesn't count.
float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[2], p, worker);
att_out.Row(0) + out_offsets[2]);
}
if (position <= last_pos[3]) {
// Past the last position, x3 doesn't count.
float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3,
v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[3], p, worker);
att_out.Row(0) + out_offsets[3]);
}
++position;
}

View File

@ -160,7 +160,6 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
const size_t model_dim = model_config.model_dim;
const float emb_scaling = EmbeddingScaling(model_dim);
const size_t worker = 0; // Not yet parallelized.
HWY_DASSERT(token >= 0);
HWY_DASSERT(token < static_cast<int>(model_config.vocab_size));
@ -176,8 +175,7 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(x_row),
model_dim);
MulByConst(emb_scaling * weights_t->Scale(), x.Row(x_row), model_dim,
ctx.profiler, worker);
MulByConst(emb_scaling * weights_t->Scale(), x.Row(x_row), model_dim);
});
if (model_config.absolute_pe) {

View File

@ -95,7 +95,7 @@ class VitAttention {
float* HWY_RESTRICT q =
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
// TODO: shift to MatMul with A.scale once MatMul is confirmed working
MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker);
MulByConst(query_scale, q, qkv_dim);
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
});
@ -120,8 +120,7 @@ class VitAttention {
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim,
env_.ctx.profiler, worker);
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim);
}
});
}
@ -144,7 +143,7 @@ class VitAttention {
// Compute Q.K scores, which are "logits" stored in head_att.
float* HWY_RESTRICT q =
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker);
MulByConst(query_scale, q, qkv_dim);
float* HWY_RESTRICT head_att =
activations_.attention.att.Row(token) + head * seq_len;
for (size_t i = 0; i < seq_len; ++i) {
@ -161,8 +160,7 @@ class VitAttention {
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim,
env_.ctx.profiler, worker);
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
}
});
}

View File

@ -560,10 +560,7 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
template <typename XT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
const size_t size,
hwy::Profiler& p,
const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConst));
const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
@ -596,10 +593,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
// out[i] += x[i] * c.
template <typename XT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
const size_t size, hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAdd));
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(const float c,
const XT* HWY_RESTRICT x,
OT* HWY_RESTRICT out,
const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
@ -734,9 +731,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile(
DF df, const VF scale, const VF c0, const VF c1, const VF c2, const VF c3,
const VF c4, const VF c5, const VF c6, const VF c7, const MatPtrT<float>& v,
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddTile));
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
@ -996,9 +991,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4(
DF df, const float* HWY_RESTRICT scales, const VF c0, const VF c1,
const VF c2, const VF c3, const MatPtrT<float>& v,
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddTile4));
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
@ -1037,9 +1030,7 @@ template <class DF, class VF = hn::Vec<DF>>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
DF df, const VF scale, const VF c0, const MatPtrT<float>& v,
const size_t pos, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddVector));
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
@ -1177,7 +1168,7 @@ static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
const float sum_exp = Sum(d, logits.data(), logits.size());
// Double-precision reciprocal does not appear to affect the results.
const float mul = 1.0f / sum_exp;
MulByConst(mul, logits.data(), logits.size(), p, worker);
MulByConst(mul, logits.data(), logits.size());
}
// Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /

View File

@ -183,7 +183,7 @@ class TestMulByConstAndAdd {
SimpleMulByConstAndAdd(constant, o, e, count);
InitProfilerZones(hwy::Profiler::Get());
MulByConstAndAdd(constant, o, x, count, hwy::Profiler::Get(), /*worker=*/0);
MulByConstAndAdd(constant, o, x, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__);
@ -232,7 +232,7 @@ class TestMulByConst {
SimpleMulByConst(constant, e, count);
InitProfilerZones(hwy::Profiler::Get());
MulByConst(constant, x, count, hwy::Profiler::Get(), /*worker=*/0);
MulByConst(constant, x, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__);
@ -443,7 +443,6 @@ void TestRopeAndMulBy() {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
hwy::Profiler& p = ctx.profiler;
InitProfilerZones(p);
const size_t worker = 0;
const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,