diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index b506e75..bef3a70 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -61,6 +61,9 @@
 #include "hwy/base.h"
 #include "hwy/timer.h"
 
+// Require opt-in to debug/introspection functions to eliminate their overhead.
+HWY_INLINE_VAR constexpr bool kObserver = false;
+
 #endif  // GEMMA_CC_ONCE
 
 HWY_BEFORE_NAMESPACE();
@@ -143,6 +146,10 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
              MatStorageT<float>& x, ThreadingContext& ctx,
              const ImageTokens* image_tokens = nullptr,
              size_t image_token_position = 0) {
+  static const auto zone =
+      ctx.profiler.AddZone("Gen.Embed", hwy::ProfilerFlags::kInclusive);
+  PROFILER_ZONE3(ctx.profiler, hwy::Profiler::Thread(), zone);
+
   // Image tokens just need to be copied.
   if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
       image_tokens != nullptr && token == -2 &&
@@ -295,11 +302,13 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
                                      const WeightsPtrs& weights,
                                      Activations& activations, QBatch& qbatch,
                                      MatMulEnv& env) {
-  if (HWY_UNLIKELY(runtime_config.layers_output)) {
-    for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
-      const float token_f = qbatch.PrevToken(qi);
-      runtime_config.layers_output(qbatch.QueryIdx(qi), qbatch.Pos(qi),
-                                   "tokens", -1, &token_f, 1);
+  if constexpr (kObserver) {
+    if (HWY_UNLIKELY(runtime_config.layers_output)) {
+      for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+        const float token_f = qbatch.PrevToken(qi);
+        runtime_config.layers_output(qbatch.QueryIdx(qi), qbatch.Pos(qi),
+                                     "tokens", -1, &token_f, 1);
+      }
     }
   }
 
@@ -313,10 +322,12 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
     TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx),
                      activations, qbatch, env);
 
-    if (HWY_UNLIKELY(runtime_config.activations_observer)) {
-      runtime_config.activations_observer(
-          QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
-          activations);
+    if constexpr (kObserver) {
+      if (HWY_UNLIKELY(runtime_config.activations_observer)) {
+        runtime_config.activations_observer(
+            QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
+            activations);
+      }
     }
   }
 }
@@ -403,23 +414,29 @@ static void SampleAndStream(
 
   RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx);
 
-  if (HWY_UNLIKELY(runtime_config.activations_observer)) {
-    runtime_config.activations_observer(
-        QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
+  if constexpr (kObserver) {
+    if (HWY_UNLIKELY(runtime_config.activations_observer)) {
+      runtime_config.activations_observer(
+          QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
+    }
   }
 
   {
-    PROFILER_ZONE("Gen.EmbeddingMatmul");
+    static const auto zone = env.ctx.profiler.AddZone(
+        "Gen.EmbeddingMatmul", hwy::ProfilerFlags::kInclusive);
+    PROFILER_ZONE3(env.ctx.profiler, /*worker=*/0, zone);
     // Compute logits from last layer activations.
     CallMatMul(activations.x, weights.embedder_input_embedding,
                /*add=*/nullptr, env, activations.logits);
   }
 
   PROFILER_ZONE("Gen.Softcap+Sample+Stream");
-  const size_t worker = 0;  // TODO: parallelize
+
+  MaybeLogitsSoftCapBatched(config.final_cap, activations.logits, non_eos,
+                            env.ctx);
+
+  // TODO: parallelize
   non_eos.Foreach([&](size_t qi) {
     float* HWY_RESTRICT logits = activations.logits.Row(qi);
-    MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size,
-                       env.ctx.profiler, worker);
     const TokenAndProb tp = sample_token(logits, config.vocab_size);
     timing_info.NotifyGenerated();
diff --git a/ops/ops-inl.h b/ops/ops-inl.h
index 343600a..95228f4 100644
--- a/ops/ops-inl.h
+++ b/ops/ops-inl.h
@@ -33,6 +33,7 @@
 #include "util/mat.h"
 #include "util/threading_context.h"
 #include "hwy/base.h"
+#include "hwy/bit_set.h"
 #include "hwy/contrib/sort/order.h"
 #include "hwy/contrib/sort/vqsort.h"
 #include "hwy/detect_targets.h"
@@ -932,6 +933,17 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
   }
 }
 
+static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
+    const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos,
+    ThreadingContext& ctx) {
+  if (cap == 0.0f) return;
+  SmallParallelFor(x.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
+    if (non_eos.Get(task)) {
+      LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler, worker);
+    }
+  });
+}
+
 static HWY_NOINLINE HWY_MAYBE_UNUSED size_t
 SampleArgmax(const float* probabilities, size_t vocab_size) {
   size_t max_index = 0;
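
Note: MaybeLogitsSoftCapBatched hoists the soft-cap out of the per-query
sampling loop, so all logit rows are capped in one parallel pass that skips
queries which have already emitted EOS. Below is a standalone sketch of that
pattern, not repo code: SoftCapRow and SoftCapBatched are illustrative names,
a plain loop stands in for SmallParallelFor, and it assumes LogitsSoftCap
computes cap * tanh(x / cap).

  #include <bitset>
  #include <cmath>
  #include <cstddef>
  #include <vector>

  // Soft-cap one row of logits in place: x -> cap * tanh(x / cap).
  void SoftCapRow(float cap, float* row, size_t cols) {
    for (size_t i = 0; i < cols; ++i) {
      row[i] = cap * std::tanh(row[i] / cap);
    }
  }

  // Cap each still-active row; rows whose query already hit EOS are skipped,
  // mirroring the non_eos.Get(task) check in MaybeLogitsSoftCapBatched.
  void SoftCapBatched(float cap, std::vector<std::vector<float>>& logits,
                      const std::bitset<4096>& non_eos) {
    if (cap == 0.0f) return;  // Same early-out as the MaybeLogitsSoftCap* helpers.
    for (size_t qi = 0; qi < logits.size(); ++qi) {
      if (non_eos.test(qi)) {
        SoftCapRow(cap, logits[qi].data(), logits[qi].size());
      }
    }
  }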