mirror of https://github.com/google/gemma.cpp.git
1.04x speedup: Parallelize SoftCap
Also require an opt-in constexpr flag for observer callbacks, and update profiler zones.
PiperOrigin-RevId: 799655163
This commit is contained in:
parent
ed2f0bd1b0
commit
86afd53076
|
|
@ -61,6 +61,9 @@
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
|
|
||||||
|
// Require opt-in to debug/introspection functions to eliminate their overhead.
|
||||||
|
HWY_INLINE_VAR constexpr bool kObserver = false;
|
||||||
|
|
||||||
#endif // GEMMA_CC_ONCE
|
#endif // GEMMA_CC_ONCE
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
HWY_BEFORE_NAMESPACE();
|
||||||
|
|
@ -143,6 +146,10 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
|
||||||
MatStorageT<float>& x, ThreadingContext& ctx,
|
MatStorageT<float>& x, ThreadingContext& ctx,
|
||||||
const ImageTokens* image_tokens = nullptr,
|
const ImageTokens* image_tokens = nullptr,
|
||||||
size_t image_token_position = 0) {
|
size_t image_token_position = 0) {
|
||||||
|
static const auto zone =
|
||||||
|
ctx.profiler.AddZone("Gen.Embed", hwy::ProfilerFlags::kInclusive);
|
||||||
|
PROFILER_ZONE3(ctx.profiler, hwy::Profiler::Thread(), zone);
|
||||||
|
|
||||||
// Image tokens just need to be copied.
|
// Image tokens just need to be copied.
|
||||||
if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
|
if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
|
||||||
image_tokens != nullptr && token == -2 &&
|
image_tokens != nullptr && token == -2 &&
|
||||||
|
|
@ -295,6 +302,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
|
||||||
const WeightsPtrs& weights,
|
const WeightsPtrs& weights,
|
||||||
Activations& activations, QBatch& qbatch,
|
Activations& activations, QBatch& qbatch,
|
||||||
MatMulEnv& env) {
|
MatMulEnv& env) {
|
||||||
|
if constexpr (kObserver) {
|
||||||
if (HWY_UNLIKELY(runtime_config.layers_output)) {
|
if (HWY_UNLIKELY(runtime_config.layers_output)) {
|
||||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||||
const float token_f = qbatch.PrevToken(qi);
|
const float token_f = qbatch.PrevToken(qi);
|
||||||
|
|
@ -302,6 +310,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
|
||||||
"tokens", -1, &token_f, 1);
|
"tokens", -1, &token_f, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: parallelize?
|
// TODO: parallelize?
|
||||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||||
|
|
@ -313,6 +322,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
|
||||||
TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx),
|
TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx),
|
||||||
activations, qbatch, env);
|
activations, qbatch, env);
|
||||||
|
|
||||||
|
if constexpr (kObserver) {
|
||||||
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
||||||
runtime_config.activations_observer(
|
runtime_config.activations_observer(
|
||||||
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
|
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
|
||||||
|
|
@ -320,6 +330,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Populates KV cache for the batch queries, one token at a time.
|
// Populates KV cache for the batch queries, one token at a time.
|
||||||
static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
|
static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
|
||||||
|
|
@ -403,23 +414,29 @@ static void SampleAndStream(
|
||||||
|
|
||||||
RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx);
|
RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx);
|
||||||
|
|
||||||
|
if constexpr (kObserver) {
|
||||||
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
||||||
runtime_config.activations_observer(
|
runtime_config.activations_observer(
|
||||||
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
|
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
PROFILER_ZONE("Gen.EmbeddingMatmul");
|
static const auto zone = env.ctx.profiler.AddZone(
|
||||||
|
"Gen.EmbeddingMatmul", hwy::ProfilerFlags::kInclusive);
|
||||||
|
PROFILER_ZONE3(env.ctx.profiler, /*worker=*/0, zone);
|
||||||
// Compute logits from last layer activations.
|
// Compute logits from last layer activations.
|
||||||
CallMatMul(activations.x, weights.embedder_input_embedding,
|
CallMatMul(activations.x, weights.embedder_input_embedding,
|
||||||
/*add=*/nullptr, env, activations.logits);
|
/*add=*/nullptr, env, activations.logits);
|
||||||
}
|
}
|
||||||
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
|
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
|
||||||
const size_t worker = 0; // TODO: parallelize
|
|
||||||
|
MaybeLogitsSoftCapBatched(config.final_cap, activations.logits, non_eos,
|
||||||
|
env.ctx);
|
||||||
|
|
||||||
|
// TODO: parallelize
|
||||||
non_eos.Foreach([&](size_t qi) {
|
non_eos.Foreach([&](size_t qi) {
|
||||||
float* HWY_RESTRICT logits = activations.logits.Row(qi);
|
float* HWY_RESTRICT logits = activations.logits.Row(qi);
|
||||||
MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size,
|
|
||||||
env.ctx.profiler, worker);
|
|
||||||
const TokenAndProb tp = sample_token(logits, config.vocab_size);
|
const TokenAndProb tp = sample_token(logits, config.vocab_size);
|
||||||
timing_info.NotifyGenerated();
|
timing_info.NotifyGenerated();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@
|
||||||
#include "util/mat.h"
|
#include "util/mat.h"
|
||||||
#include "util/threading_context.h"
|
#include "util/threading_context.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/bit_set.h"
|
||||||
#include "hwy/contrib/sort/order.h"
|
#include "hwy/contrib/sort/order.h"
|
||||||
#include "hwy/contrib/sort/vqsort.h"
|
#include "hwy/contrib/sort/vqsort.h"
|
||||||
#include "hwy/detect_targets.h"
|
#include "hwy/detect_targets.h"
|
||||||
|
|
@ -932,6 +933,17 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Batched soft-capping of logits: calls `LogitsSoftCap(cap, ...)` on every row
// of `x` whose row index is set in `non_eos`; rows whose bit is clear are left
// untouched. Rows are distributed via `SmallParallelFor` over `ctx.pools`, and
// the `worker` index it supplies is forwarded together with `ctx.profiler`.
//
// `cap`:     soft-cap value; if it is exactly 0.0f the function returns
//            immediately without touching `x`.
// `x`:       matrix of logits, modified in place one row at a time.
// `non_eos`: bit set selecting which rows to process.
// `ctx`:     provides the thread pools and the profiler.
//
// NOTE(review): `non_eos` is a `hwy::BitSet4096<>`, so this presumably requires
// `x.Rows()` <= 4096 - confirm against callers.
// NOTE(review): the per-row length is `x.Cols()`; presumably this equals the
// vocabulary size - TODO confirm.
static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
|
||||||
|
const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos,
|
||||||
|
ThreadingContext& ctx) {
|
||||||
|
if (cap == 0.0f) return;  // Zero cap: nothing to do.
|
||||||
|
SmallParallelFor(x.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
|
||||||
|
if (non_eos.Get(task)) {
|
||||||
|
LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler, worker);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED size_t
|
static HWY_NOINLINE HWY_MAYBE_UNUSED size_t
|
||||||
SampleArgmax(const float* probabilities, size_t vocab_size) {
|
SampleArgmax(const float* probabilities, size_t vocab_size) {
|
||||||
size_t max_index = 0;
|
size_t max_index = 0;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue