mirror of https://github.com/google/gemma.cpp.git
1.04x speedup: Parallelize SoftCap
Also require opt-in constexpr flag for observer callbacks, update zones PiperOrigin-RevId: 799655163
This commit is contained in:
parent
ed2f0bd1b0
commit
86afd53076
|
|
@ -61,6 +61,9 @@
|
|||
#include "hwy/base.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
// Require opt-in to debug/introspection functions to eliminate their overhead.
|
||||
HWY_INLINE_VAR constexpr bool kObserver = false;
|
||||
|
||||
#endif // GEMMA_CC_ONCE
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
|
|
@ -143,6 +146,10 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
|
|||
MatStorageT<float>& x, ThreadingContext& ctx,
|
||||
const ImageTokens* image_tokens = nullptr,
|
||||
size_t image_token_position = 0) {
|
||||
static const auto zone =
|
||||
ctx.profiler.AddZone("Gen.Embed", hwy::ProfilerFlags::kInclusive);
|
||||
PROFILER_ZONE3(ctx.profiler, hwy::Profiler::Thread(), zone);
|
||||
|
||||
// Image tokens just need to be copied.
|
||||
if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
|
||||
image_tokens != nullptr && token == -2 &&
|
||||
|
|
@ -295,6 +302,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
|
|||
const WeightsPtrs& weights,
|
||||
Activations& activations, QBatch& qbatch,
|
||||
MatMulEnv& env) {
|
||||
if constexpr (kObserver) {
|
||||
if (HWY_UNLIKELY(runtime_config.layers_output)) {
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
const float token_f = qbatch.PrevToken(qi);
|
||||
|
|
@ -302,6 +310,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
|
|||
"tokens", -1, &token_f, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: parallelize?
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
|
|
@ -313,12 +322,14 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
|
|||
TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx),
|
||||
activations, qbatch, env);
|
||||
|
||||
if constexpr (kObserver) {
|
||||
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
||||
runtime_config.activations_observer(
|
||||
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
|
||||
activations);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Populates KV cache for the batch queries, one token at a time.
|
||||
|
|
@ -403,23 +414,29 @@ static void SampleAndStream(
|
|||
|
||||
RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx);
|
||||
|
||||
if constexpr (kObserver) {
|
||||
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
||||
runtime_config.activations_observer(
|
||||
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
PROFILER_ZONE("Gen.EmbeddingMatmul");
|
||||
static const auto zone = env.ctx.profiler.AddZone(
|
||||
"Gen.EmbeddingMatmul", hwy::ProfilerFlags::kInclusive);
|
||||
PROFILER_ZONE3(env.ctx.profiler, /*worker=*/0, zone);
|
||||
// Compute logits from last layer activations.
|
||||
CallMatMul(activations.x, weights.embedder_input_embedding,
|
||||
/*add=*/nullptr, env, activations.logits);
|
||||
}
|
||||
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
|
||||
const size_t worker = 0; // TODO: parallelize
|
||||
|
||||
MaybeLogitsSoftCapBatched(config.final_cap, activations.logits, non_eos,
|
||||
env.ctx);
|
||||
|
||||
// TODO: parallelize
|
||||
non_eos.Foreach([&](size_t qi) {
|
||||
float* HWY_RESTRICT logits = activations.logits.Row(qi);
|
||||
MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size,
|
||||
env.ctx.profiler, worker);
|
||||
const TokenAndProb tp = sample_token(logits, config.vocab_size);
|
||||
timing_info.NotifyGenerated();
|
||||
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@
|
|||
#include "util/mat.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/bit_set.h"
|
||||
#include "hwy/contrib/sort/order.h"
|
||||
#include "hwy/contrib/sort/vqsort.h"
|
||||
#include "hwy/detect_targets.h"
|
||||
|
|
@ -932,6 +933,17 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
|
|||
}
|
||||
}
|
||||
|
||||
static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
|
||||
const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos,
|
||||
ThreadingContext& ctx) {
|
||||
if (cap == 0.0f) return;
|
||||
SmallParallelFor(x.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
|
||||
if (non_eos.Get(task)) {
|
||||
LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler, worker);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED size_t
|
||||
SampleArgmax(const float* probabilities, size_t vocab_size) {
|
||||
size_t max_index = 0;
|
||||
|
|
|
|||
Loading…
Reference in New Issue