Removed the PROFILER_ZONE from the most highly called functions to reduce the overhead.

PiperOrigin-RevId: 819739402
This commit is contained in:
Ray Smith 2025-10-15 07:09:32 -07:00 committed by Copybara-Service
parent e3e8511e79
commit ee18916abf
6 changed files with 43 additions and 59 deletions

View File

@ -89,7 +89,7 @@ void PositionalEncodingQK(float* qk, const size_t layer_idx,
// PostQKType::Rope // PostQKType::Rope
if (post_qk == PostQKType::HalfRope) { if (post_qk == PostQKType::HalfRope) {
Rope(qk, qkv_dim / 2, inv_timescale, pos, p, worker); Rope(qk, qkv_dim / 2, inv_timescale, pos, p, worker);
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, p, worker); if (mul != 1.0f) MulByConst(mul, qk, qkv_dim);
} else { } else {
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, p, worker); RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, p, worker);
} }
@ -113,7 +113,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), p, MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), p,
worker); worker);
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), p, worker); MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
} }
} else { } else {
{ {
@ -122,8 +122,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
} }
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = div_seq_len.Remainder(pos); const size_t pos_mod = div_seq_len.Remainder(pos);
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p, MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols());
worker);
} }
} }
} }

View File

@ -131,10 +131,11 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
} }
// Handles a single v row of flash attention for a single q.k dot product. // Handles a single v row of flash attention for a single q.k dot product.
void HWY_INLINE SingleFlashAttentionStep( void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,
float x, float cap, float& old_max, float& old_d, float& old_d,
const float* HWY_RESTRICT v, const size_t v_cols, const float* HWY_RESTRICT v,
float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) { const size_t v_cols,
float* HWY_RESTRICT att_out) {
if (cap > 0.0f) { if (cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
x = cap * std::tanh(x / cap); x = cap * std::tanh(x / cap);
@ -147,8 +148,8 @@ void HWY_INLINE SingleFlashAttentionStep(
float one_over_d = 1.0f / old_d; float one_over_d = 1.0f / old_d;
scale *= one_over_d; scale *= one_over_d;
x *= one_over_d; x *= one_over_d;
MulByConst(scale, att_out, v_cols, p, worker); MulByConst(scale, att_out, v_cols);
MulByConstAndAdd(x, v, att_out, v_cols, p, worker); MulByConstAndAdd(x, v, att_out, v_cols);
} }
// Calculates the complete attention outputs for a single row of q. // Calculates the complete attention outputs for a single row of q.
@ -174,7 +175,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
const size_t pos_mod = activations.div_seq_len.Remainder(pos); const size_t pos_mod = activations.div_seq_len.Remainder(pos);
float x = Dot(q, k.Row(pos_mod), k.Cols()); float x = Dot(q, k.Row(pos_mod), k.Cols());
SingleFlashAttentionStep(x, activations.config.att_cap, m, d, SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
v.Row(pos_mod), v.Cols(), att_out, p, worker); v.Row(pos_mod), v.Cols(), att_out);
} }
} }
@ -183,7 +184,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
template <class DF, class VF = hn::Vec<DF>> template <class DF, class VF = hn::Vec<DF>>
VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets, VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
const size_t k_pos, const MatPtrT<KV_t>& q, const size_t k_pos, const MatPtrT<KV_t>& q,
const MatPtrT<KV_t>& k, hwy::Profiler& p, const size_t worker) { const MatPtrT<KV_t>& k) {
hn::TFromD<DF> results[hn::MaxLanes(df)]; hn::TFromD<DF> results[hn::MaxLanes(df)];
for (size_t i = 0; i < hn::Lanes(df); ++i) { for (size_t i = 0; i < hn::Lanes(df); ++i) {
results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols()); results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols());
@ -198,9 +199,8 @@ VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
// consecutive elements, and other columns by adding q_stride. // consecutive elements, and other columns by adding q_stride.
template <class DF, class VF = hn::Vec<DF>> template <class DF, class VF = hn::Vec<DF>>
void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride, void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
const MatPtrT<KV_t>& k, const size_t* k_pos, const MatPtrT<KV_t>& k, const size_t* k_pos, VF& sum0,
hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1, VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
VF& sum7) { VF& sum7) {
constexpr size_t kHTileSize = kNFx8HTileSize; constexpr size_t kHTileSize = kNFx8HTileSize;
sum0 = hn::Zero(df); sum0 = hn::Zero(df);
@ -303,8 +303,8 @@ void TileFlashAttention(
k_pos[i] = activations.div_seq_len.Remainder(position + i); k_pos[i] = activations.div_seq_len.Remainder(position + i);
} }
VF x0, x1, x2, x3, x4, x5, x6, x7; VF x0, x1, x2, x3, x4, x5, x6, x7;
QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, p, worker, x0, x1, x2, x3, QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6,
x4, x5, x6, x7); x7);
if (activations.config.att_cap > 0.0f) { if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile. // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
VF cap = hn::Set(df, activations.config.att_cap); VF cap = hn::Set(df, activations.config.att_cap);
@ -343,12 +343,12 @@ void TileFlashAttention(
x6 = hn::Mul(x6, one_over_d); x6 = hn::Mul(x6, one_over_d);
x7 = hn::Mul(x7, one_over_d); x7 = hn::Mul(x7, one_over_d);
MulByConstAndAddTile(df, scale, x0, x1, x2, x3, x4, x5, x6, x7, v, k_pos, MulByConstAndAddTile(df, scale, x0, x1, x2, x3, x4, x5, x6, x7, v, k_pos,
att_out.Row(0), out_offsets, v.Cols(), p, worker); att_out.Row(0), out_offsets, v.Cols());
position += kHTileSize; position += kHTileSize;
} }
while (position <= max_last_pos) { while (position <= max_last_pos) {
size_t k_pos = activations.div_seq_len.Remainder(position); size_t k_pos = activations.div_seq_len.Remainder(position);
VF x0 = QDotKVector(df, q_offsets, k_pos, q, k, p, worker); VF x0 = QDotKVector(df, q_offsets, k_pos, q, k);
if (activations.config.att_cap > 0.0f) { if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the vector. // Compute tanh(x / cap) * cap, being LogitsSoftCap on the vector.
VF cap = hn::Set(df, activations.config.att_cap); VF cap = hn::Set(df, activations.config.att_cap);
@ -369,7 +369,7 @@ void TileFlashAttention(
x0 = hn::Mul(x0, one_over_d); x0 = hn::Mul(x0, one_over_d);
scale = hn::Mul(scale, one_over_d); scale = hn::Mul(scale, one_over_d);
MulByConstAndAddVector(df, scale, x0, v, k_pos, att_out.Row(0), out_offsets, MulByConstAndAddVector(df, scale, x0, v, k_pos, att_out.Row(0), out_offsets,
v.Cols(), p, worker); v.Cols());
++position; ++position;
} }
} }
@ -380,8 +380,8 @@ void TileFlashAttention(
template <class DF, class VF = hn::Vec<DF>> template <class DF, class VF = hn::Vec<DF>>
void QDotKTilex4(DF df, const float* HWY_RESTRICT q, void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k, const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k,
const int32_t* HWY_RESTRICT k_offsets, hwy::Profiler& p, const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1,
const size_t worker, VF& sum0, VF& sum1, VF& sum2, VF& sum3) { VF& sum2, VF& sum3) {
sum0 = hn::Zero(df); sum0 = hn::Zero(df);
sum1 = hn::Zero(df); sum1 = hn::Zero(df);
sum2 = hn::Zero(df); sum2 = hn::Zero(df);
@ -462,8 +462,7 @@ void TileFlashAttention4(
k_offsets[i] = k.Row(v_pos[i]) - k.Row(0); k_offsets[i] = k.Row(v_pos[i]) - k.Row(0);
} }
VF x0, x1, x2, x3; VF x0, x1, x2, x3;
QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, p, worker, x0, x1, x2, QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, x0, x1, x2, x3);
x3);
if (activations.config.att_cap > 0.0f) { if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile. // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
VF cap = hn::Set(df, activations.config.att_cap); VF cap = hn::Set(df, activations.config.att_cap);
@ -478,7 +477,7 @@ void TileFlashAttention4(
scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2); scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2);
scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3); scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3);
MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0), MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0),
out_offsets, v.Cols(), p, worker); out_offsets, v.Cols());
position += kHTileSize; position += kHTileSize;
} }
while (position <= max_last_pos) { while (position <= max_last_pos) {
@ -488,28 +487,28 @@ void TileFlashAttention4(
float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols()); float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0, SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0,
v.Row(k_pos), v.Cols(), v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[0], p, worker); att_out.Row(0) + out_offsets[0]);
} }
if (position <= last_pos[1]) { if (position <= last_pos[1]) {
// Past the last position, x1 doesn't count. // Past the last position, x1 doesn't count.
float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols()); float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1, SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1,
v.Row(k_pos), v.Cols(), v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[1], p, worker); att_out.Row(0) + out_offsets[1]);
} }
if (position <= last_pos[2]) { if (position <= last_pos[2]) {
// Past the last position, x2 doesn't count. // Past the last position, x2 doesn't count.
float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols()); float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2, SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2,
v.Row(k_pos), v.Cols(), v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[2], p, worker); att_out.Row(0) + out_offsets[2]);
} }
if (position <= last_pos[3]) { if (position <= last_pos[3]) {
// Past the last position, x3 doesn't count. // Past the last position, x3 doesn't count.
float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols()); float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols());
SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3, SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3,
v.Row(k_pos), v.Cols(), v.Row(k_pos), v.Cols(),
att_out.Row(0) + out_offsets[3], p, worker); att_out.Row(0) + out_offsets[3]);
} }
++position; ++position;
} }

View File

@ -160,7 +160,6 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
const size_t model_dim = model_config.model_dim; const size_t model_dim = model_config.model_dim;
const float emb_scaling = EmbeddingScaling(model_dim); const float emb_scaling = EmbeddingScaling(model_dim);
const size_t worker = 0; // Not yet parallelized.
HWY_DASSERT(token >= 0); HWY_DASSERT(token >= 0);
HWY_DASSERT(token < static_cast<int>(model_config.vocab_size)); HWY_DASSERT(token < static_cast<int>(model_config.vocab_size));
@ -176,8 +175,7 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(x_row), DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(x_row),
model_dim); model_dim);
MulByConst(emb_scaling * weights_t->Scale(), x.Row(x_row), model_dim, MulByConst(emb_scaling * weights_t->Scale(), x.Row(x_row), model_dim);
ctx.profiler, worker);
}); });
if (model_config.absolute_pe) { if (model_config.absolute_pe) {

View File

@ -95,7 +95,7 @@ class VitAttention {
float* HWY_RESTRICT q = float* HWY_RESTRICT q =
activations_.attention.q.Row(token) + head * 3 * qkv_dim; activations_.attention.q.Row(token) + head * 3 * qkv_dim;
// TODO: shift to MatMul with A.scale once MatMul is confirmed working // TODO: shift to MatMul with A.scale once MatMul is confirmed working
MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker); MulByConst(query_scale, q, qkv_dim);
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float)); hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
}); });
@ -120,8 +120,7 @@ class VitAttention {
for (size_t i = 0; i < seq_len; ++i) { for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v = activations_.attention.q.Row(i) + float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
head * 3 * qkv_dim + 2 * qkv_dim; head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim, MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim);
env_.ctx.profiler, worker);
} }
}); });
} }
@ -144,7 +143,7 @@ class VitAttention {
// Compute Q.K scores, which are "logits" stored in head_att. // Compute Q.K scores, which are "logits" stored in head_att.
float* HWY_RESTRICT q = float* HWY_RESTRICT q =
activations_.attention.q.Row(token) + head * 3 * qkv_dim; activations_.attention.q.Row(token) + head * 3 * qkv_dim;
MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker); MulByConst(query_scale, q, qkv_dim);
float* HWY_RESTRICT head_att = float* HWY_RESTRICT head_att =
activations_.attention.att.Row(token) + head * seq_len; activations_.attention.att.Row(token) + head * seq_len;
for (size_t i = 0; i < seq_len; ++i) { for (size_t i = 0; i < seq_len; ++i) {
@ -161,8 +160,7 @@ class VitAttention {
for (size_t i = 0; i < seq_len; ++i) { for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v = activations_.attention.q.Row(i) + float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
head * 3 * qkv_dim + 2 * qkv_dim; head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim, MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
env_.ctx.profiler, worker);
} }
}); });
} }

View File

@ -560,10 +560,7 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
template <typename XT> template <typename XT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x, HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
const size_t size, const size_t size) {
hwy::Profiler& p,
const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConst));
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>; using VF = hn::Vec<DF>;
@ -596,10 +593,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
// out[i] += x[i] * c. // out[i] += x[i] * c.
template <typename XT, typename OT> template <typename XT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(const float c,
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, const XT* HWY_RESTRICT x,
const size_t size, hwy::Profiler& p, const size_t worker) { OT* HWY_RESTRICT out,
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAdd)); const size_t size) {
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>; using VF = hn::Vec<DF>;
@ -734,9 +731,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile(
DF df, const VF scale, const VF c0, const VF c1, const VF c2, const VF c3, DF df, const VF scale, const VF c0, const VF c1, const VF c2, const VF c3,
const VF c4, const VF c5, const VF c6, const VF c7, const MatPtrT<float>& v, const VF c4, const VF c5, const VF c6, const VF c7, const MatPtrT<float>& v,
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out, const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size, const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddTile));
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
@ -996,9 +991,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4(
DF df, const float* HWY_RESTRICT scales, const VF c0, const VF c1, DF df, const float* HWY_RESTRICT scales, const VF c0, const VF c1,
const VF c2, const VF c3, const MatPtrT<float>& v, const VF c2, const VF c3, const MatPtrT<float>& v,
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out, const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size, const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddTile4));
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
@ -1037,9 +1030,7 @@ template <class DF, class VF = hn::Vec<DF>>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector( HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
DF df, const VF scale, const VF c0, const MatPtrT<float>& v, DF df, const VF scale, const VF c0, const MatPtrT<float>& v,
const size_t pos, float* HWY_RESTRICT out, const size_t pos, float* HWY_RESTRICT out,
const uint32_t* HWY_RESTRICT out_offsets, const size_t size, const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
hwy::Profiler& p, const size_t worker) {
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddVector));
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df); HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
@ -1177,7 +1168,7 @@ static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
const float sum_exp = Sum(d, logits.data(), logits.size()); const float sum_exp = Sum(d, logits.data(), logits.size());
// Double-precision reciprocal does not appear to affect the results. // Double-precision reciprocal does not appear to affect the results.
const float mul = 1.0f / sum_exp; const float mul = 1.0f / sum_exp;
MulByConst(mul, logits.data(), logits.size(), p, worker); MulByConst(mul, logits.data(), logits.size());
} }
// Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max / // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /

View File

@ -183,7 +183,7 @@ class TestMulByConstAndAdd {
SimpleMulByConstAndAdd(constant, o, e, count); SimpleMulByConstAndAdd(constant, o, e, count);
InitProfilerZones(hwy::Profiler::Get()); InitProfilerZones(hwy::Profiler::Get());
MulByConstAndAdd(constant, o, x, count, hwy::Profiler::Get(), /*worker=*/0); MulByConstAndAdd(constant, o, x, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__); __LINE__);
@ -232,7 +232,7 @@ class TestMulByConst {
SimpleMulByConst(constant, e, count); SimpleMulByConst(constant, e, count);
InitProfilerZones(hwy::Profiler::Get()); InitProfilerZones(hwy::Profiler::Get());
MulByConst(constant, x, count, hwy::Profiler::Get(), /*worker=*/0); MulByConst(constant, x, count);
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
__LINE__); __LINE__);
@ -443,7 +443,6 @@ void TestRopeAndMulBy() {
ThreadingArgs threading_args; ThreadingArgs threading_args;
ThreadingContext ctx(threading_args); ThreadingContext ctx(threading_args);
hwy::Profiler& p = ctx.profiler; hwy::Profiler& p = ctx.profiler;
InitProfilerZones(p);
const size_t worker = 0; const size_t worker = 0;
const ModelConfig config(Model::GEMMA2_9B, Type::kSFP, const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,