mirror of https://github.com/google/gemma.cpp.git
1.14x batch decode speedup: parallelize RMSNorm ops
Activations was over-parallelized; use a single pool instead. Also improve profiler zone annotations and pass through worker args (for tracking concurrency), which are now non-optional.

PiperOrigin-RevId: 788790976
This commit is contained in:
parent ac0d751d20
commit d1638587f0
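In short: batched per-row ops (RMSNorm, AddFrom, Activation) previously ran serially or through the fully nested ParallelFor, which over-parallelized small decode batches. They now go through a new SmallParallelFor helper (added in util/threading.h, last hunk below) that uses a single pool and forwards a per-thread worker index into each op for profiler attribution. A minimal sketch of the new call pattern, assuming a ThreadingContext `ctx` as used throughout the diff; `batch_rows` and `NormRow` are hypothetical stand-ins for the real call sites:

// Sketch only: NormRow stands in for RMSNorm/AddFrom/Activation.
const size_t pkg_idx = 0;  // single package
SmallParallelFor(batch_rows, ctx.pools, pkg_idx,
                 [&](uint64_t row, size_t worker) {
                   // `worker` is globally unique; PROFILER_ZONE2 uses it to
                   // attribute time per thread.
                   NormRow(row, worker);
                 });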
@@ -474,6 +474,7 @@ cc_library(
         ":matmul",
         "//io",
         "@highway//:hwy",
+        "@highway//:profiler",
     ],
 )
 
@@ -85,7 +85,7 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 
 void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size) {
-  Softmax(logits, vocab_size);
+  Softmax(logits, vocab_size, /*worker=*/0);
 }
 
 }  // namespace HWY_NAMESPACE
 
@@ -29,7 +29,6 @@
 #include "util/allocator.h"  // Allocator
 #include "util/basics.h"     // BF16
 #include "util/mat.h"        // MatStorageT
-#include "hwy/profiler.h"
 
 namespace gcpp {
 
@@ -182,8 +181,8 @@ struct Activations {
     // Note that BindC on any MatMul output considerably slows down Prefill.
   }
 
+  // Negligible CPU time.
   void SetBatchSize(size_t batch_size) {
-    PROFILER_ZONE("SetBatchSize");
     x.OverrideRows(batch_size);
     logits.OverrideRows(batch_size);
 
@@ -150,7 +150,7 @@ void SingleDotSoftmaxWeightedSum(
   // SoftMax with optional SoftCap yields "probabilities" in att.
   const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
   MaybeLogitsSoftCap(att_cap, att, att_len, worker);
-  Softmax(att, att_len, /*temperature=*/1.0f, worker);
+  Softmax(att, att_len, worker, /*temperature=*/1.0f);
 
   WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
                worker);
 
@@ -168,7 +168,6 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
                            const LayerWeightsPtrs& layer,
                            AttentionActivations& activations, QBatch& qbatch,
                            NestedPools& pools) {
   PROFILER_ZONE("Gen.Attention.DotSoftmax.misc");
-  static const uint32_t HWY_MAYBE_UNUSED zone_id_par =
-      PROFILER_ADD_ZONE("Gen.Attention.DotSoftmax.par");
 
@@ -227,8 +226,13 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
         layer, activations, att, att_out, worker);
   };
 
-  ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads, pools,
-              /*pkg_idx=*/0, func);
+  {
+    PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
+    const size_t pkg_idx = 0;
+    // Full parallelism is helpful, SmallParallelFor is insufficient.
+    ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
+                pools, pkg_idx, func);
+  }
 }
 
 // Different functions use different naming conventions for the number of
 
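By contrast with the batched norm ops below, attention's inner loop stays on the fully nested ParallelFor: as the added comment notes, the tokens x queries x heads task count is large enough that full parallelism pays off, whereas the single-pool SmallParallelFor would be insufficient. Roughly, with hypothetical names:

// Many heavy tasks: spread across all pools of the package.
ParallelFor(num_tokens * num_queries * num_heads, pools, /*pkg_idx=*/0, func);
// Few lightweight tasks (e.g. one per batch row): a single pool suffices.
SmallParallelFor(batch_rows, pools, /*pkg_idx=*/0, func);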
@@ -66,11 +66,12 @@ void Activation(ActivationType activation, T* HWY_RESTRICT c1,
 template <class Mat>
 void ActivationBatched(ActivationType activation, Mat& c1, NestedPools& pools) {
   using T = typename Mat::T;
-  ParallelFor(c1.Rows(), pools, /*pkg_idx=*/0,
-              [&](uint64_t task, size_t worker) {
+  const size_t pkg_idx = 0;
+  SmallParallelFor(
+      c1.Rows(), pools, pkg_idx, [&](uint64_t task, size_t worker) {
         // Cast to correct type so type deduction works.
-        Activation(activation, c1.Row(task),
-                   static_cast<const T*>(nullptr), c1.Cols(), worker);
+        Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
+                   c1.Cols(), worker);
       });
 }
 
@@ -79,17 +80,18 @@ HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1,
                                     const Mat* c2, NestedPools& pools) {
   using T = typename Mat::T;
   HWY_DASSERT(c1.SameShape(*c2));
+  const size_t pkg_idx = 0;
   if (c2 && c2->HasPtr()) {
-    ParallelFor(c1.Rows(), pools, /*pkg_idx=*/0,
+    SmallParallelFor(c1.Rows(), pools, pkg_idx,
                      [&](uint64_t task, size_t worker) {
-                  Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
-                             worker);
+                       Activation(activation, c1.Row(task), c2->Row(task),
+                                  c1.Cols(), worker);
                      });
   } else {  // No multiplier
-    ParallelFor(c1.Rows(), pools, /*pkg_idx=*/0,
-                [&](uint64_t task, size_t worker) {
-                  Activation(activation, c1.Row(task),
-                             static_cast<const T*>(nullptr), c1.Cols(), worker);
+    SmallParallelFor(
+        c1.Rows(), pools, pkg_idx, [&](uint64_t task, size_t worker) {
+          Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
+                     c1.Cols(), worker);
         });
   }
 }
 
@@ -98,17 +100,17 @@ template <typename T2, class LayerWeights>
 HWY_NOINLINE void ResidualConnection(const MatPtrT<T2>& other,
                                      MatPtrT<float>& HWY_RESTRICT x,
                                      const LayerWeights& layer,
-                                     bool is_attention) {
+                                     bool is_attention, ThreadingContext& ctx) {
   // ResidualType::Add
-  AddFromBatched(other, x);
+  AddFromBatched(other, x, ctx);
 }
 
 template <typename InOutT>
 void PostNorm(PostNormType post_norm, const MatPtr& weights,
-              MatPtrT<InOutT>& inout) {
+              MatPtrT<InOutT>& inout, ThreadingContext& ctx) {
   HWY_DASSERT(weights.Rows() == 1);
   if (post_norm == PostNormType::Scale) {
-    RMSNormInplaceBatched(weights, inout);
+    RMSNormInplaceBatched(weights, inout, ctx);
   }
 }
 
@@ -92,19 +92,19 @@ static HWY_NOINLINE void TransformerLayer(const size_t num_tokens,
   const LayerConfig& layer_config = layer.layer_config;
 
   RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
-                 activations.attention.pre_att_rms_out);
+                 activations.attention.pre_att_rms_out, env.ctx);
 
   Attention(layer_config.type, num_tokens, layer_idx, layer, activations,
             qbatch, env);
 
   PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
-           activations.attention.att_sums);
+           activations.attention.att_sums, env.ctx);
 
   ResidualConnection(activations.attention.att_sums, activations.x, layer,
-                     /*is_attention=*/true);
+                     /*is_attention=*/true, env.ctx);
 
   RMSNormBatched(activations.x, layer.pre_ffw_norm_scale,
-                 activations.pre_ffw_rms_out);
+                 activations.pre_ffw_rms_out, env.ctx);
 
   if (layer_config.type == LayerAttentionType::kVit) {
     FFWVit(layer, activations, env);
 
@@ -113,10 +113,10 @@ static HWY_NOINLINE void TransformerLayer(const size_t num_tokens,
   }
 
   PostNorm(layer_config.post_norm, layer.post_ffw_norm_scale,
-           activations.ffw_out);
+           activations.ffw_out, env.ctx);
 
   ResidualConnection(activations.ffw_out, activations.x, layer,
-                     /*is_attention=*/false);
+                     /*is_attention=*/false, env.ctx);
 }
 
 // Returns the scale value to use for the embedding (basically sqrt model_dim).
 
@@ -158,6 +158,7 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
 
   const size_t model_dim = model_config.model_dim;
   const float emb_scaling = EmbeddingScaling(model_dim);
+  const size_t worker = 0;  // Not yet parallelized.
 
   HWY_DASSERT(token >= 0);
   HWY_DASSERT(token < static_cast<int>(model_config.vocab_size));
 
@@ -173,7 +174,7 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
     const hn::ScalableTag<float> df;
     DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
                          model_dim);
-    MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim);
+    MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim, worker);
   });
 
   if (model_config.absolute_pe) {
 
@@ -302,6 +303,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
     }
   }
 
+  // TODO: parallelize?
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     EmbedMMToken(qbatch.PrevToken(qi), qi, qbatch.Pos(qi),
                  /*pos_in_prompt=*/0, config, weights, activations.x);
 
@@ -328,7 +330,7 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
                                        Activations& activations, QBatch& qbatch,
                                        MatMulEnv& env,
                                        hwy::BitSet4096<>& non_eos) {
-  PROFILER_ZONE("Gen.Prefill");
+  PROFILER_ZONE("Gen.PrefillQ");
 
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     non_eos.Set(qi);
 
@@ -400,7 +402,7 @@ static void DecodeStepT(const ModelConfig& config,
 
   Transformer(config, runtime_config, weights, activations, qbatch, env);
 
-  RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
+  RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx);
 
   if (HWY_UNLIKELY(runtime_config.activations_observer)) {
     runtime_config.activations_observer(
 
@@ -414,9 +416,10 @@ static void DecodeStepT(const ModelConfig& config,
                /*add=*/nullptr, env, activations.logits);
   }
   PROFILER_ZONE("Gen.Softcap+Sample+Stream");
+  const size_t worker = 0;  // TODO: parallelize
   non_eos.Foreach([&](size_t qi) {
     float* HWY_RESTRICT logits = activations.logits.Row(qi);
-    MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size);
+    MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size, worker);
     const TokenAndProb tp = sample_token(logits, config.vocab_size);
     timing_info.NotifyGenerated();
 
@@ -430,10 +433,12 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config) {
   // If user provided a sample_func, use it.
   if (runtime_config.sample_func) return runtime_config.sample_func;
 
+  const size_t worker = 0;  // TODO: parallelize
+
   // Fast path for top-1 with no accept_token.
   if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
     return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
-      PROFILER_ZONE("Gen.Sample Top1");
+      PROFILER_ZONE2(worker, "Gen.Sample Top1");
       return Top1OfSoftmax(logits, vocab_size);
     };
   }
 
@@ -444,7 +449,7 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config) {
     PROFILER_ZONE("Gen.Sample general");
     return FusedSoftmaxAndSampleTopK(
         logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
-        runtime_config.temperature, runtime_config.accept_token);
+        runtime_config.temperature, runtime_config.accept_token, worker);
   };
 }
 
@@ -32,6 +32,7 @@
 #include "util/mat.h"
 #include "hwy/aligned_allocator.h"  // Span
 #include "hwy/base.h"               // HWY_ABORT
+#include "hwy/profiler.h"
 
 namespace gcpp {
 
@@ -116,6 +117,7 @@ struct RuntimeConfig {
   // If non-null, `batch_stream_token` is called for each token in the batch,
   // otherwise `stream_token`. `query_idx` is absolute, not batch-relative.
   bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
+    PROFILER_ZONE("Gen.StreamToken");
     if (batch_stream_token) {
       return batch_stream_token(query_idx, pos, token, prob);
     }
 
gemma/vit.cc (28 lines changed)

@@ -90,12 +90,12 @@ class VitAttention {
     ZeroInit(activations_.attention.att_out);
 
     for (size_t head = 0; head < heads; ++head) {
-      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
        const size_t token = task;
        float* HWY_RESTRICT q =
            activations_.attention.q.Row(token) + head * 3 * qkv_dim;
        // TODO: shift to MatMul with A.scale once MatMul is confirmed working
-        MulByConst(query_scale, q, qkv_dim);
+        MulByConst(query_scale, q, qkv_dim, worker);
        hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
      });
 
@@ -109,19 +109,19 @@ class VitAttention {
     // this produces C, a (num_tokens_, seq_len) matrix of dot products
     CallMatMul(Q, K, nullptr, env_, C);
 
-    pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+    pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
       float* HWY_RESTRICT c = C.Row(task);
-      Softmax(c, C.Cols());
+      Softmax(c, C.Cols(), worker);
     });
 
-    pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+    pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
       size_t token = task;
       float* HWY_RESTRICT att_out =
           activations_.attention.att_out.Row(token) + head * qkv_dim;
       for (size_t i = 0; i < seq_len; ++i) {
         float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
                                 head * 3 * qkv_dim + 2 * qkv_dim;
-        MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim);
+        MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim, worker);
       }
     });
   }
 
@@ -138,13 +138,13 @@ class VitAttention {
 
     // Compute Q.K, softmax, and weighted V.
     pool_.Run(0, layer_config_.heads * num_tokens_,
-              [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+              [&](uint64_t task, size_t worker) HWY_ATTR {
                 const size_t head = task % layer_config_.heads;
                 const size_t token = task / layer_config_.heads;
                 // Compute Q.K scores, which are "logits" stored in head_att.
                 float* HWY_RESTRICT q =
                     activations_.attention.q.Row(token) + head * 3 * qkv_dim;
-                MulByConst(query_scale, q, qkv_dim);
+                MulByConst(query_scale, q, qkv_dim, worker);
                 float* HWY_RESTRICT head_att =
                     activations_.attention.att.Row(token) + head * seq_len;
                 for (size_t i = 0; i < seq_len; ++i) {
 
@@ -153,7 +153,7 @@ class VitAttention {
                   head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
                 }
                 // SoftMax yields "probabilities" in head_att.
-                Softmax(head_att, seq_len);
+                Softmax(head_att, seq_len, worker);
                 // Compute weighted sum of v into att_out.
                 float* HWY_RESTRICT att_out =
                     activations_.attention.att_out.Row(token) + head * qkv_dim;
 
@@ -161,7 +161,7 @@ class VitAttention {
                 for (size_t i = 0; i < seq_len; ++i) {
                   float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
                                           head * 3 * qkv_dim + 2 * qkv_dim;
-                  MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
+                  MulByConstAndAdd(head_att[i], v, att_out, qkv_dim, worker);
                 }
               });
   }
 
@@ -259,7 +259,7 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
   VitAttention(num_tokens, layer_idx, activations, layer, env)();
 
   // x = out["+sa"] = x + y
-  AddFromBatched(activations.attention.att_sums, x);
+  AddFromBatched(activations.attention.att_sums, x, env.ctx);
 
   // y = nn.LayerNorm()(x)
   // y ~ pre_ffw_rms_out
 
@@ -271,7 +271,7 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
   FFWVit(layer, activations, env);
 
   // x = out["+mlp"] = x + y
-  AddFromBatched(activations.ffw_out, x);
+  AddFromBatched(activations.ffw_out, x, env.ctx);
 }
 
 // Gets the patches of the image and embeds them with the image embedding
 
@@ -303,7 +303,7 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
   // Add position embeddings.
   CallUpcastedActivation(&weights.vit_img_pos_embedding,
                          [&](const auto* weights_t) {
-                           AddFromBatched(*weights_t, activations.x);
+                           AddFromBatched(*weights_t, activations.x, env.ctx);
                          });
 }
 
@@ -334,7 +334,7 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
   // Apply soft embedding norm before input projection.
   CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {
     RMSNormInplace(weights_t->PackedScale1(), 0, activations.x.Row(0),
-                   vit_model_dim);
+                   vit_model_dim, /*worker=*/0);
   });
 }
 
@@ -875,9 +875,6 @@ class MMPerPackage {
         inner_tasks_(config.InnerTasks()),
         out_(config.Out()),
         line_bytes_(args.env->ctx.allocator.LineBytes()) {
-    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DecompressA");
-    MMZone zone;
-    zone.MaybeEnter(pkg_idx, zone_id, args_);
     A_ = DecompressA(A);
   }
 
@@ -1119,8 +1116,14 @@ class MMPerPackage {
     const size_t NBF = hn::Lanes(dbf);
     static_assert(hwy::IsSameEither<TA, BF16, float>(), "Can seek");
 
+    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DecompressA");
+
     const auto do_range = [&](const IndexRange& range_M,
-                              const IndexRange& range_K) HWY_ATTR {
+                              const IndexRange& range_K,
+                              size_t worker) HWY_ATTR {
+      MMZone zone;
+      zone.MaybeEnter(worker, zone_id, args_);
+
       const size_t col0 = range_K.begin();
       const size_t cols = range_K.Num();
       // Must be a vector multiple, or the last range before row padding,
 
@@ -1141,7 +1144,7 @@ class MMPerPackage {
 
     switch (par_a) {
       case MMParA::kNone:
-        do_range(all_M, all_K);
+        do_range(all_M, all_K, /*worker=*/0);
         break;
       case MMParA::kK1:
       case MMParA::kK2:
 
@@ -1154,15 +1157,15 @@ class MMPerPackage {
 
         args_.env->parallel.ForNP(
             all_K, multiple_K, inner_tasks, pkg_idx_,
-            [&](const IndexRange& range_K, size_t /*worker*/) {
-              do_range(all_M, range_K);
+            [&](const IndexRange& range_K, size_t worker) {
+              do_range(all_M, range_K, worker);
             });
         break;
       }
       case MMParA::kM:
         args_.env->parallel.ForRangeMC(
-            all_M, pkg_idx_, [&](size_t row_a, size_t /*worker*/) {
-              do_range(IndexRange(row_a, row_a + 1), all_K);
+            all_M, pkg_idx_, [&](size_t row_a, size_t worker) {
+              do_range(IndexRange(row_a, row_a + 1), all_K, worker);
             });
         break;
     }
 
@@ -1190,12 +1193,9 @@ class MMPerPackage {
 
     // First call: generate candidates.
     if (HWY_UNLIKELY(!autotune.HasCandidates())) {
-      std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4};
-      if (A.Rows() == 1) {
-        candidates.push_back(MMParA::kNone);
-      } else {
-        candidates.push_back(MMParA::kM);
-      }
+      const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM;
+      std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4,
+                                        other};
       autotune.SetCandidates(candidates);
     }
 
@@ -1279,7 +1279,8 @@ struct MMImpl {
   static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                                     RowPtrs<TC> C_rows, const MMArgs& args,
                                     const MMConfig& config) {
-    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DoMatMul");
+    PROFILER_ZONE("MM.DoMatMul");
+    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DoMatMul.PerPkg");
 
     // Outermost loop: static NUMA-aware partition of B rows across packages.
     args.env->parallel.ForPkg(
 
@@ -1353,7 +1354,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
     return &per_key;
   }
 
-  PROFILER_ZONE("Matmul.Autotune");
+  // From here, CPU time is negligible except DoMatMul.
 
   // First call: enumerate all feasible configs.
   if (HWY_UNLIKELY(!tuner.HasCandidates())) {
 
@@ -1364,7 +1365,6 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   HWY_ASSERT(N <= MMStorage::kMaxN);
   HWY_ASSERT(N % kNR == 0);
 
-  // Negligible CPU time.
   tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), kMaxMR,
                                    kNR, per_key.ranges_np, env.print_config));
 }
 
@@ -208,7 +208,7 @@ template <typename XT, typename WT, typename OT>
 HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
     const XT* HWY_RESTRICT x, const WT* HWY_RESTRICT weight, size_t w_ofs,
     OT* HWY_RESTRICT out, const size_t size,
-    const size_t HWY_MAYBE_UNUSED worker = 0) {
+    const size_t HWY_MAYBE_UNUSED worker) {
   PROFILER_ZONE2(worker, "ops.RMSNorm");
 
   namespace hn = hwy::HWY_NAMESPACE;
 
@@ -240,7 +240,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
 template <typename WT, typename XT>
 HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
     const WT* HWY_RESTRICT weight, size_t w_ofs, XT* HWY_RESTRICT inout,
-    const size_t size, const HWY_MAYBE_UNUSED size_t worker = 0) {
+    const size_t size, const HWY_MAYBE_UNUSED size_t worker) {
   PROFILER_ZONE2(worker, "ops.RMSNormInplace");
 
   namespace hn = hwy::HWY_NAMESPACE;
 
@@ -527,7 +527,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
 template <typename XT>
 static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
     const XT* HWY_RESTRICT x, float* HWY_RESTRICT out, const size_t size,
-    const HWY_MAYBE_UNUSED size_t worker = 0) {
+    const HWY_MAYBE_UNUSED size_t worker) {
   PROFILER_ZONE2(worker, "ops.AddFrom");
 
   namespace hn = hwy::HWY_NAMESPACE;
 
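These ops lose the `= 0` default on `worker`: every caller must now state which worker it runs on, and the index feeds PROFILER_ZONE2 for per-thread zone accounting. Serial callers pass an explicit zero; parallel callers forward the lambda's worker index. For example (a sketch with hypothetical `x_row`/`out_row` buffers, based on the call sites in this diff):

AddFrom(x_row, out_row, cols, /*worker=*/0);  // serial caller
// Inside a pool, forward the worker index instead:
pool.Run(0, rows, [&](uint64_t task, size_t worker) {
  AddFrom(x.Row(task), out.Row(task), x.Cols(), worker);
});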
@@ -570,29 +570,36 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
 // Simple loops unless/until batch sizes are large enough to parallelize.
 template <typename XT, typename OT>
 void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
-                    MatPtrT<OT>& out) {
+                    MatPtrT<OT>& out, ThreadingContext& ctx) {
   HWY_DASSERT(weights.Rows() == 1);
   HWY_DASSERT(weights.Cols() == activations.Cols());
   HWY_DASSERT(activations.SameShape(out));
 
   CallUpcasted(&weights, [&](const auto* weights_t) {
-    for (size_t token_idx = 0; token_idx < activations.Rows(); ++token_idx) {
-      RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), 0,
-              out.Row(token_idx), activations.Cols());
-    }
+    const size_t pkg_idx = 0;
+    SmallParallelFor(activations.Rows(), ctx.pools, pkg_idx,
+                     [&](uint64_t token_idx, size_t worker) {
+                       RMSNorm(activations.Row(token_idx),
+                               weights_t->PackedScale1(), 0, out.Row(token_idx),
+                               activations.Cols(), worker);
+                     });
   });
 }
 
 template <typename XT>
-void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout) {
+void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
+                           ThreadingContext& ctx) {
   HWY_DASSERT(weights.Rows() == 1);
   HWY_DASSERT(weights.Cols() == inout.Cols());
 
   CallUpcasted(&weights, [&](const auto* weights_t) {
-    for (size_t token_idx = 0; token_idx < inout.Rows(); ++token_idx) {
-      RMSNormInplace(weights_t->PackedScale1(), 0, inout.Row(token_idx),
-                     inout.Cols());
-    }
+    const size_t pkg_idx = 0;
+    SmallParallelFor(inout.Rows(), ctx.pools, pkg_idx,
+                     [&](uint64_t token_idx, size_t worker) {
+                       RMSNormInplace(weights_t->PackedScale1(), 0,
+                                      inout.Row(token_idx), inout.Cols(),
+                                      worker);
+                     });
   });
 }
 
@@ -614,18 +621,20 @@ void LayerNormBatched(const MatPtrT<XT>& x, const MatPtr& weight,
 }
 
 template <typename XT>
-static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x,
-                                      MatPtrT<float>& out) {
+static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
+                                      ThreadingContext& ctx) {
   HWY_DASSERT(out.SameShape(x));
-  for (size_t token_idx = 0; token_idx < out.Rows(); ++token_idx) {
-    AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols());
-  }
+  const size_t pkg_idx = 0;
+  SmallParallelFor(
+      out.Rows(), ctx.pools, pkg_idx, [&](uint64_t token_idx, size_t worker) {
+        AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), worker);
+      });
 }
 
 template <typename XT>
 HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(
     const float c, XT* HWY_RESTRICT x, const size_t size,
-    const HWY_MAYBE_UNUSED size_t worker = 0) {
+    const HWY_MAYBE_UNUSED size_t worker) {
   PROFILER_ZONE2(worker, "ops.MulByConst");
   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;
 
@@ -666,7 +675,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(
 template <typename XT, typename OT>
 HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
     const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
-    const size_t size, const HWY_MAYBE_UNUSED size_t worker = 0) {
+    const size_t size, const HWY_MAYBE_UNUSED size_t worker) {
   PROFILER_ZONE2(worker, "ops.MulByConstTo");
   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;
 
@@ -708,7 +717,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
 template <typename XT, typename OT>
 HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
     const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
-    const size_t size, const HWY_MAYBE_UNUSED size_t worker = 0) {
+    const size_t size, const HWY_MAYBE_UNUSED size_t worker) {
   PROFILER_ZONE2(worker, "ops.MulByConstAndAdd");
   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;
 
@@ -754,8 +763,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
 
 // See below for a specialized version for top-1 sampling.
 static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
-                                 float temperature = 1.0f,
-                                 const HWY_MAYBE_UNUSED size_t worker = 0) {
+                                 const size_t worker,
+                                 float temperature = 1.0f) {
   PROFILER_ZONE2(worker, "ops.Softmax");
   HWY_DASSERT(size != 0);
 
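Note the reordering: since `worker` is now mandatory, it must precede the still-defaulted `temperature` (C++ requires defaulted parameters to come last). Callers keeping the default temperature simply drop it, e.g. (the 0.7f below is an arbitrary illustration, not from the source):

Softmax(att, att_len, worker);                    // temperature defaults to 1.0f
Softmax(topk_logits.data(), mask, worker, 0.7f);  // explicit temperature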
@@ -797,7 +806,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
   const float sum_exp = Sum(d, x, size);
   // Double-precision reciprocal does not appear to affect the results.
   const float mul = 1.0f / sum_exp;
-  MulByConst(mul, x, size);
+  MulByConst(mul, x, size, worker);
 }
 
 // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
 
@@ -886,9 +895,9 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
   return TokenAndProb{.token = argmax.token, .prob = prob};
 }
 
-static HWY_NOINLINE void LogitsSoftCap(
-    const float cap, float* HWY_RESTRICT x, const size_t size,
-    const HWY_MAYBE_UNUSED size_t worker = 0) {
+static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
+                                       const size_t size,
+                                       const HWY_MAYBE_UNUSED size_t worker) {
   PROFILER_ZONE2(worker, "ops.LogitsSoftCap");
 
   namespace hn = hwy::HWY_NAMESPACE;
 
@@ -906,7 +915,7 @@ static HWY_NOINLINE void LogitsSoftCap(
 // Calls LogitsSoftCap if cap != 0.0f.
 static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
     const float cap, float* HWY_RESTRICT x, const size_t size,
-    const size_t worker = 0) {
+    const size_t worker) {
   if (cap != 0.0f) {
     LogitsSoftCap(cap, x, size, worker);
   }
 
@@ -991,7 +1000,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
 template <typename TAcceptToken>
 HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
     const float* HWY_RESTRICT logits, size_t k, size_t vocab_size,
-    std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
+    std::mt19937& gen, float temperature, TAcceptToken& accept_token,
+    size_t worker) {
   // Softmax and sample top-K is equivalent to taking the top-K logits and
   // sampling from the softmax of the top-K logits. The latter is faster as it
   // avoids computing the softmax of all logits.
 
@@ -1005,7 +1015,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
   }
 
   size_t mask = token_logits.size();
-  Softmax(topk_logits.data(), mask, temperature);
+  Softmax(topk_logits.data(), mask, worker, temperature);
   auto distribution = std::discrete_distribution<int>(
       std::begin(topk_logits), std::begin(topk_logits) + mask);
   int topk_sampled_index = distribution(gen);
 
@@ -166,7 +166,7 @@ struct TestAddFrom {
     }
 
     SimpleAddFrom(o, e, count);
-    AddFrom(o, x, count);
+    AddFrom(o, x, count, /*worker=*/0);
 
     hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
                             __LINE__);
 
@@ -199,7 +199,7 @@ struct TestMulByConstAndAdd {
     T constant = Random<T>(rng);
 
     SimpleMulByConstAndAdd(constant, o, e, count);
-    MulByConstAndAdd(constant, o, x, count);
+    MulByConstAndAdd(constant, o, x, count, /*worker=*/0);
 
     hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
                             __LINE__);
 
@@ -229,7 +229,7 @@ struct TestMulByConst {
     T constant = Random<T>(rng);
 
     SimpleMulByConst(constant, e, count);
-    MulByConst(constant, x, count);
+    MulByConst(constant, x, count, /*worker=*/0);
 
     hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
                             __LINE__);
 
@@ -259,7 +259,7 @@ struct TestSoftmax {
     }
 
     SimpleSoftmax(e, count);
-    Softmax(x, count);
+    Softmax(x, count, /*worker=*/0);
 
     T sum = 0.0f;
     for (size_t i = 0; i < count; ++i) {
 
@@ -454,7 +454,7 @@ void TestRMSNorm(hwy::RandomState& rng) {
   }
 
   ScalarRMSNorm(vec, weight, expected, kSize);
-  RMSNorm(vec, weight, 0, actual, kSize);
+  RMSNorm(vec, weight, 0, actual, kSize, /*worker=*/0);
 
   for (size_t i = 0; i < kSize; i++) {
     const float e = hwy::ConvertScalarTo<float>(expected[i]);
 
@@ -584,7 +584,7 @@ void TestSampleTopK() {
   std::vector<float> logits(kSize);
   // Create a vector going from -100 to -100+51=49 and take Softmax.
   std::iota(logits.begin(), logits.end(), -100.0f);
-  Softmax(logits.data(), kSize);
+  Softmax(logits.data(), kSize, /*worker=*/0);
   std::mt19937 gen;
   gen.seed(0x12345678);
   float temperature = 1.0f;
 
@@ -600,7 +600,7 @@ void TestSampleTopK() {
   EXPECT_EQ(sample, 50);  // Last even index.
   // Reset the logits to a positive, increasing sequence and take Softmax.
   std::iota(logits.begin(), logits.end(), 1.0f);
-  Softmax(logits.data(), kSize);
+  Softmax(logits.data(), kSize, /*worker=*/0);
   // Sample from the top 3, expect one of the top 3 even indices.
   for (int i = 0; i < 100; ++i) {
     sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature,
 
@@ -355,6 +355,17 @@ void ParallelFor(size_t num_tasks, NestedPools& pools, size_t pkg_idx,
       });
 }
 
+// As above, but for lightweight tasks. Uses only one pool.
+template <class Func>
+void SmallParallelFor(size_t num_tasks, NestedPools& pools, size_t pkg_idx,
+                      const Func& func) {
+  const size_t pkg_base = pkg_idx * pools.MaxWorkersPerPackage();
+
+  pools.Pool(pkg_idx).Run(0, num_tasks, [&](uint64_t task, size_t thread) {
+    func(task, pkg_base + thread);
+  });
+}
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
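For reference: the new helper runs all tasks on the package's single outer pool rather than the nested per-cluster pools, while `pkg_base + thread` keeps worker indices globally unique, so they remain valid for PROFILER_ZONE2 and any per-worker state. A usage sketch with a hypothetical per-row op `ProcessRow`:

NestedPools& pools = ctx.pools;
SmallParallelFor(batch_rows, pools, /*pkg_idx=*/0,
                 [&](uint64_t task, size_t worker) {
                   ProcessRow(task, worker);  // hypothetical per-row op
                 });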