mirror of https://github.com/google/gemma.cpp.git
Removed PROFILER_ZONE from the most frequently called functions to reduce profiling overhead.
PiperOrigin-RevId: 819739402
This commit is contained in:
parent
e3e8511e79
commit
ee18916abf
|
|
@ -89,7 +89,7 @@ void PositionalEncodingQK(float* qk, const size_t layer_idx,
|
|||
// PostQKType::Rope
|
||||
if (post_qk == PostQKType::HalfRope) {
|
||||
Rope(qk, qkv_dim / 2, inv_timescale, pos, p, worker);
|
||||
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, p, worker);
|
||||
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim);
|
||||
} else {
|
||||
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, p, worker);
|
||||
}
|
||||
|
|
@ -113,7 +113,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
|
|||
MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), p,
|
||||
worker);
|
||||
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
|
||||
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), p, worker);
|
||||
MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols());
|
||||
}
|
||||
} else {
|
||||
{
|
||||
|
|
@ -122,8 +122,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
|
|||
}
|
||||
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
|
||||
const size_t pos_mod = div_seq_len.Remainder(pos);
|
||||
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), p,
|
||||
worker);
|
||||
MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -131,10 +131,11 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
|
|||
}
|
||||
|
||||
// Handles a single v row of flash attention for a single q.k dot product.
|
||||
void HWY_INLINE SingleFlashAttentionStep(
|
||||
float x, float cap, float& old_max, float& old_d,
|
||||
const float* HWY_RESTRICT v, const size_t v_cols,
|
||||
float* HWY_RESTRICT att_out, hwy::Profiler& p, const size_t worker) {
|
||||
void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,
|
||||
float& old_d,
|
||||
const float* HWY_RESTRICT v,
|
||||
const size_t v_cols,
|
||||
float* HWY_RESTRICT att_out) {
|
||||
if (cap > 0.0f) {
|
||||
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
|
||||
x = cap * std::tanh(x / cap);
|
||||
|
|
@ -147,8 +148,8 @@ void HWY_INLINE SingleFlashAttentionStep(
|
|||
float one_over_d = 1.0f / old_d;
|
||||
scale *= one_over_d;
|
||||
x *= one_over_d;
|
||||
MulByConst(scale, att_out, v_cols, p, worker);
|
||||
MulByConstAndAdd(x, v, att_out, v_cols, p, worker);
|
||||
MulByConst(scale, att_out, v_cols);
|
||||
MulByConstAndAdd(x, v, att_out, v_cols);
|
||||
}
|
||||
|
||||
// Calculates the complete attention outputs for a single row of q.
|
||||
|
|
@ -174,7 +175,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
|
|||
const size_t pos_mod = activations.div_seq_len.Remainder(pos);
|
||||
float x = Dot(q, k.Row(pos_mod), k.Cols());
|
||||
SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
|
||||
v.Row(pos_mod), v.Cols(), att_out, p, worker);
|
||||
v.Row(pos_mod), v.Cols(), att_out);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -183,7 +184,7 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
|
|||
template <class DF, class VF = hn::Vec<DF>>
|
||||
VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
|
||||
const size_t k_pos, const MatPtrT<KV_t>& q,
|
||||
const MatPtrT<KV_t>& k, hwy::Profiler& p, const size_t worker) {
|
||||
const MatPtrT<KV_t>& k) {
|
||||
hn::TFromD<DF> results[hn::MaxLanes(df)];
|
||||
for (size_t i = 0; i < hn::Lanes(df); ++i) {
|
||||
results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols());
|
||||
|
|
@ -198,9 +199,8 @@ VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
|
|||
// consecutive elements, and other columns by adding q_stride.
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
|
||||
const MatPtrT<KV_t>& k, const size_t* k_pos,
|
||||
hwy::Profiler& p, const size_t worker, VF& sum0, VF& sum1,
|
||||
VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
|
||||
const MatPtrT<KV_t>& k, const size_t* k_pos, VF& sum0,
|
||||
VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
|
||||
VF& sum7) {
|
||||
constexpr size_t kHTileSize = kNFx8HTileSize;
|
||||
sum0 = hn::Zero(df);
|
||||
|
|
@ -303,8 +303,8 @@ void TileFlashAttention(
|
|||
k_pos[i] = activations.div_seq_len.Remainder(position + i);
|
||||
}
|
||||
VF x0, x1, x2, x3, x4, x5, x6, x7;
|
||||
QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, p, worker, x0, x1, x2, x3,
|
||||
x4, x5, x6, x7);
|
||||
QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6,
|
||||
x7);
|
||||
if (activations.config.att_cap > 0.0f) {
|
||||
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
|
||||
VF cap = hn::Set(df, activations.config.att_cap);
|
||||
|
|
@ -343,12 +343,12 @@ void TileFlashAttention(
|
|||
x6 = hn::Mul(x6, one_over_d);
|
||||
x7 = hn::Mul(x7, one_over_d);
|
||||
MulByConstAndAddTile(df, scale, x0, x1, x2, x3, x4, x5, x6, x7, v, k_pos,
|
||||
att_out.Row(0), out_offsets, v.Cols(), p, worker);
|
||||
att_out.Row(0), out_offsets, v.Cols());
|
||||
position += kHTileSize;
|
||||
}
|
||||
while (position <= max_last_pos) {
|
||||
size_t k_pos = activations.div_seq_len.Remainder(position);
|
||||
VF x0 = QDotKVector(df, q_offsets, k_pos, q, k, p, worker);
|
||||
VF x0 = QDotKVector(df, q_offsets, k_pos, q, k);
|
||||
if (activations.config.att_cap > 0.0f) {
|
||||
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the vector.
|
||||
VF cap = hn::Set(df, activations.config.att_cap);
|
||||
|
|
@ -369,7 +369,7 @@ void TileFlashAttention(
|
|||
x0 = hn::Mul(x0, one_over_d);
|
||||
scale = hn::Mul(scale, one_over_d);
|
||||
MulByConstAndAddVector(df, scale, x0, v, k_pos, att_out.Row(0), out_offsets,
|
||||
v.Cols(), p, worker);
|
||||
v.Cols());
|
||||
++position;
|
||||
}
|
||||
}
|
||||
|
|
@ -380,8 +380,8 @@ void TileFlashAttention(
|
|||
template <class DF, class VF = hn::Vec<DF>>
|
||||
void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
|
||||
const uint32_t* HWY_RESTRICT q_offsets, const MatPtrT<KV_t>& k,
|
||||
const int32_t* HWY_RESTRICT k_offsets, hwy::Profiler& p,
|
||||
const size_t worker, VF& sum0, VF& sum1, VF& sum2, VF& sum3) {
|
||||
const int32_t* HWY_RESTRICT k_offsets, VF& sum0, VF& sum1,
|
||||
VF& sum2, VF& sum3) {
|
||||
sum0 = hn::Zero(df);
|
||||
sum1 = hn::Zero(df);
|
||||
sum2 = hn::Zero(df);
|
||||
|
|
@ -462,8 +462,7 @@ void TileFlashAttention4(
|
|||
k_offsets[i] = k.Row(v_pos[i]) - k.Row(0);
|
||||
}
|
||||
VF x0, x1, x2, x3;
|
||||
QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, p, worker, x0, x1, x2,
|
||||
x3);
|
||||
QDotKTilex4(df, q.Row(0), q_offsets, k, k_offsets, x0, x1, x2, x3);
|
||||
if (activations.config.att_cap > 0.0f) {
|
||||
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
|
||||
VF cap = hn::Set(df, activations.config.att_cap);
|
||||
|
|
@ -478,7 +477,7 @@ void TileFlashAttention4(
|
|||
scales[2] = SingleFlashAttentionRowVector(df, x2, old_m2, old_d2);
|
||||
scales[3] = SingleFlashAttentionRowVector(df, x3, old_m3, old_d3);
|
||||
MulByConstAndAddTile4(df, scales, x0, x1, x2, x3, v, v_pos, att_out.Row(0),
|
||||
out_offsets, v.Cols(), p, worker);
|
||||
out_offsets, v.Cols());
|
||||
position += kHTileSize;
|
||||
}
|
||||
while (position <= max_last_pos) {
|
||||
|
|
@ -488,28 +487,28 @@ void TileFlashAttention4(
|
|||
float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols());
|
||||
SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0,
|
||||
v.Row(k_pos), v.Cols(),
|
||||
att_out.Row(0) + out_offsets[0], p, worker);
|
||||
att_out.Row(0) + out_offsets[0]);
|
||||
}
|
||||
if (position <= last_pos[1]) {
|
||||
// Past the last position, x1 doesn't count.
|
||||
float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols());
|
||||
SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1,
|
||||
v.Row(k_pos), v.Cols(),
|
||||
att_out.Row(0) + out_offsets[1], p, worker);
|
||||
att_out.Row(0) + out_offsets[1]);
|
||||
}
|
||||
if (position <= last_pos[2]) {
|
||||
// Past the last position, x2 doesn't count.
|
||||
float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols());
|
||||
SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2,
|
||||
v.Row(k_pos), v.Cols(),
|
||||
att_out.Row(0) + out_offsets[2], p, worker);
|
||||
att_out.Row(0) + out_offsets[2]);
|
||||
}
|
||||
if (position <= last_pos[3]) {
|
||||
// Past the last position, x3 doesn't count.
|
||||
float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols());
|
||||
SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3,
|
||||
v.Row(k_pos), v.Cols(),
|
||||
att_out.Row(0) + out_offsets[3], p, worker);
|
||||
att_out.Row(0) + out_offsets[3]);
|
||||
}
|
||||
++position;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -160,7 +160,6 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
|
|||
|
||||
const size_t model_dim = model_config.model_dim;
|
||||
const float emb_scaling = EmbeddingScaling(model_dim);
|
||||
const size_t worker = 0; // Not yet parallelized.
|
||||
|
||||
HWY_DASSERT(token >= 0);
|
||||
HWY_DASSERT(token < static_cast<int>(model_config.vocab_size));
|
||||
|
|
@ -176,8 +175,7 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
|
|||
const hn::ScalableTag<float> df;
|
||||
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(x_row),
|
||||
model_dim);
|
||||
MulByConst(emb_scaling * weights_t->Scale(), x.Row(x_row), model_dim,
|
||||
ctx.profiler, worker);
|
||||
MulByConst(emb_scaling * weights_t->Scale(), x.Row(x_row), model_dim);
|
||||
});
|
||||
|
||||
if (model_config.absolute_pe) {
|
||||
|
|
|
|||
10
gemma/vit.cc
10
gemma/vit.cc
|
|
@ -95,7 +95,7 @@ class VitAttention {
|
|||
float* HWY_RESTRICT q =
|
||||
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
|
||||
// TODO: shift to MatMul with A.scale once MatMul is confirmed working
|
||||
MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker);
|
||||
MulByConst(query_scale, q, qkv_dim);
|
||||
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
|
||||
});
|
||||
|
||||
|
|
@ -120,8 +120,7 @@ class VitAttention {
|
|||
for (size_t i = 0; i < seq_len; ++i) {
|
||||
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
|
||||
head * 3 * qkv_dim + 2 * qkv_dim;
|
||||
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim,
|
||||
env_.ctx.profiler, worker);
|
||||
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
@ -144,7 +143,7 @@ class VitAttention {
|
|||
// Compute Q.K scores, which are "logits" stored in head_att.
|
||||
float* HWY_RESTRICT q =
|
||||
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
|
||||
MulByConst(query_scale, q, qkv_dim, env_.ctx.profiler, worker);
|
||||
MulByConst(query_scale, q, qkv_dim);
|
||||
float* HWY_RESTRICT head_att =
|
||||
activations_.attention.att.Row(token) + head * seq_len;
|
||||
for (size_t i = 0; i < seq_len; ++i) {
|
||||
|
|
@ -161,8 +160,7 @@ class VitAttention {
|
|||
for (size_t i = 0; i < seq_len; ++i) {
|
||||
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
|
||||
head * 3 * qkv_dim + 2 * qkv_dim;
|
||||
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim,
|
||||
env_.ctx.profiler, worker);
|
||||
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -560,10 +560,7 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
|
|||
|
||||
template <typename XT>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
|
||||
const size_t size,
|
||||
hwy::Profiler& p,
|
||||
const size_t worker) {
|
||||
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConst));
|
||||
const size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using DF = hn::ScalableTag<float>;
|
||||
using VF = hn::Vec<DF>;
|
||||
|
|
@ -596,10 +593,10 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
|
|||
|
||||
// out[i] += x[i] * c.
|
||||
template <typename XT, typename OT>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
||||
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
|
||||
const size_t size, hwy::Profiler& p, const size_t worker) {
|
||||
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAdd));
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(const float c,
|
||||
const XT* HWY_RESTRICT x,
|
||||
OT* HWY_RESTRICT out,
|
||||
const size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using DF = hn::ScalableTag<float>;
|
||||
using VF = hn::Vec<DF>;
|
||||
|
|
@ -734,9 +731,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile(
|
|||
DF df, const VF scale, const VF c0, const VF c1, const VF c2, const VF c3,
|
||||
const VF c4, const VF c5, const VF c6, const VF c7, const MatPtrT<float>& v,
|
||||
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
|
||||
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
|
||||
hwy::Profiler& p, const size_t worker) {
|
||||
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddTile));
|
||||
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
|
||||
|
||||
|
|
@ -996,9 +991,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4(
|
|||
DF df, const float* HWY_RESTRICT scales, const VF c0, const VF c1,
|
||||
const VF c2, const VF c3, const MatPtrT<float>& v,
|
||||
const size_t* HWY_RESTRICT pos, float* HWY_RESTRICT out,
|
||||
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
|
||||
hwy::Profiler& p, const size_t worker) {
|
||||
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddTile4));
|
||||
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
|
||||
|
||||
|
|
@ -1037,9 +1030,7 @@ template <class DF, class VF = hn::Vec<DF>>
|
|||
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddVector(
|
||||
DF df, const VF scale, const VF c0, const MatPtrT<float>& v,
|
||||
const size_t pos, float* HWY_RESTRICT out,
|
||||
const uint32_t* HWY_RESTRICT out_offsets, const size_t size,
|
||||
hwy::Profiler& p, const size_t worker) {
|
||||
PROFILER_ZONE3(p, worker, GetProfilerZone(Zones::kOpsMulByConstAndAddVector));
|
||||
const uint32_t* HWY_RESTRICT out_offsets, const size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
|
||||
|
||||
|
|
@ -1177,7 +1168,7 @@ static HWY_NOINLINE void Softmax(Logits logits, hwy::Profiler& p,
|
|||
const float sum_exp = Sum(d, logits.data(), logits.size());
|
||||
// Double-precision reciprocal does not appear to affect the results.
|
||||
const float mul = 1.0f / sum_exp;
|
||||
MulByConst(mul, logits.data(), logits.size(), p, worker);
|
||||
MulByConst(mul, logits.data(), logits.size());
|
||||
}
|
||||
|
||||
// Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
|
||||
|
|
|
|||
|
|
@ -183,7 +183,7 @@ class TestMulByConstAndAdd {
|
|||
|
||||
SimpleMulByConstAndAdd(constant, o, e, count);
|
||||
InitProfilerZones(hwy::Profiler::Get());
|
||||
MulByConstAndAdd(constant, o, x, count, hwy::Profiler::Get(), /*worker=*/0);
|
||||
MulByConstAndAdd(constant, o, x, count);
|
||||
|
||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||
__LINE__);
|
||||
|
|
@ -232,7 +232,7 @@ class TestMulByConst {
|
|||
|
||||
SimpleMulByConst(constant, e, count);
|
||||
InitProfilerZones(hwy::Profiler::Get());
|
||||
MulByConst(constant, x, count, hwy::Profiler::Get(), /*worker=*/0);
|
||||
MulByConst(constant, x, count);
|
||||
|
||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||
__LINE__);
|
||||
|
|
@ -443,7 +443,6 @@ void TestRopeAndMulBy() {
|
|||
ThreadingArgs threading_args;
|
||||
ThreadingContext ctx(threading_args);
|
||||
hwy::Profiler& p = ctx.profiler;
|
||||
InitProfilerZones(p);
|
||||
const size_t worker = 0;
|
||||
|
||||
const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
|
||||
|
|
|
|||
Loading…
Reference in New Issue