diff --git a/BUILD.bazel b/BUILD.bazel index 12abd4c..e652d72 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -149,13 +149,14 @@ cc_library( ":tokenizer", ":kv_cache", ":weights", - "//compression:compress", "//compression:io", "@hwy//:hwy", + "@hwy//:bit_set", "@hwy//:matvec", "@hwy//:nanobenchmark", # timer "@hwy//:profiler", "@hwy//:thread_pool", + "@hwy//:topology", ], ) diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index fd3593d..caa489e 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -76,8 +76,8 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference, fprintf(stderr, "Loading model...\n"); model_ = AllocateGemma(loader_, pool_); - kv_caches_.reserve(16); - for (int i = 0; i < 16; ++i) { + kv_caches_.reserve(kBatchedQueryBatchSize); + for (int i = 0; i < kBatchedQueryBatchSize; ++i) { kv_caches_.push_back(new KVCache(KVCache::Create(model_->Info().model))); } } diff --git a/gemma/common.h b/gemma/common.h index f8da552..099314f 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -36,11 +36,10 @@ ByteStorageT AllocateSizeof() { return hwy::AllocateAligned(sizeof(T)); } -constexpr size_t kPrefillBatchSize = 512; -constexpr size_t kDecodeBatchSize = 1; +// Relatively small so that we can also parallelize non-Matmul work. There is +// one outer thread per batch, each with --num_threads / batches inner threads. +constexpr size_t kPrefillBatchSize = 64; constexpr size_t kBatchedQueryBatchSize = 16; -constexpr size_t kMinAdjustedPrefillBatchSize = - HWY_MAX((size_t)1, kPrefillBatchSize / kBatchedQueryBatchSize); // Model variants: see configs.h for details. When adding a new one, also // update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc. diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 66edc75..0b58c39 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -26,23 +26,26 @@ #include #include -#include // memcpy -#include +#include // std::min +#include // std::unique_ptr #include #include #include #include "gemma/activations.h" #include "gemma/common.h" +#include "gemma/configs.h" #include "gemma/gemma.h" #include "gemma/ops.h" #include "gemma/weights.h" // Placeholder for internal test4, do not remove #include "hwy/aligned_allocator.h" #include "hwy/base.h" +#include "hwy/bit_set.h" #include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/contrib/thread_pool/topology.h" #include "hwy/highway.h" #include "hwy/profiler.h" #include "hwy/timer.h" @@ -269,7 +272,7 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens, const float* HWY_RESTRICT q = activations.q.Batch(batch_and_query_idx) + head * kQStride; // Skip past the Q part of `q`, and copy KV to `kv`. - memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float)); + hwy::CopyBytes(q + kQKVDim, kv, 2 * kQKVDim * sizeof(float)); } PostQK(kv, pos, layer); }); @@ -414,6 +417,7 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens, } // TODO: pass Activations.x instead of Activations. +// `pos` is for the entire batch and does not include `batch_idx`. template HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos, const CompressedWeights& weights, @@ -421,8 +425,10 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos, constexpr size_t kModelDim = TConfig::kModelDim; GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling = EmbeddingScaling(); + HWY_DASSERT(token >= 0); HWY_DASSERT(token < TConfig::kVocabSize); + Decompress(weights.embedder_input_embedding, token * kModelDim, activations.x.Batch(batch_idx), kModelDim); MulByConst(kEmbScaling, activations.x.Batch(batch_idx), kModelDim); @@ -441,7 +447,7 @@ HWY_NOINLINE void ResidualConnection( AddFromBatched(num_tokens_and_queries, other, x, kModelDim); } -template +template HWY_NOINLINE void TransformerLayer( size_t num_tokens, size_t num_queries, size_t pos, size_t layer, const CompressedLayer* layer_weights, Activations& activations, @@ -458,13 +464,10 @@ HWY_NOINLINE void TransformerLayer( Attention(pos, num_tokens, num_queries, layer_of_type, activations, layer_weights, kv_caches, pool); } else { - // This Griffin layers should never exist unless the model is a Griffin - // model. This conditional prevents the compiler from generating code for - // this branch when building a non-Griffin model, since we have static - // asserts about the query batch size for Griffin layers. + // Only reached if the model is Griffin. `if constexpr` prevents generating + // this code for non-Griffin models. if constexpr (TConfig::kGriffinLayers > 0) { - static_assert(kQueryBatchSize == 1, - "Griffin does not support batched queries."); + HWY_ASSERT(num_queries == 1); GriffinRecurrent(pos, num_tokens, num_queries, layer_of_type, activations, layer_weights, kv_caches, pool); } @@ -494,39 +497,171 @@ HWY_NOINLINE void TransformerLayer( /*is_attention=*/false); } -template -HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, - size_t num_queries, size_t pos, - const CompressedWeights& weights, - Activations& activations, - const std::vector& kv_caches, - hwy::ThreadPool& pool) { - PROFILER_ZONE("Gen.Prefill"); - HWY_DASSERT(num_queries <= kQueryBatchSize); - const size_t minibatch_size = std::min(num_tokens, kBatchSize); - // TODO: hoist pool.Run out of the loop, change the unit of work to batches. - for (size_t i = 0; i < num_tokens; i += minibatch_size) { - const size_t offset = i * num_queries; - const size_t current_token_count = std::min( - minibatch_size, num_tokens - i); - pool.Run(0, current_token_count * num_queries, - [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR { - EmbedToken(tokens[token_idx + offset], token_idx, - pos + offset, weights, activations); - }); +// For prefill, we have two-level parallelism: +// - Outer: input tokens are split into batches, each of which is one task +// processed by a worker in `outer_pool_`, which includes the main thread +// because it is the one that calls `Prefill`. +// - Inner: each `outer` worker passes `inner_pools_[outer]` to +// `TransformerLayer` for tensor-level parallelism. +// +// This class holds the thread pools and activations, recreated for each query. +// +// It is safe to parallelize batches because we write to KVCache at +// `pos % kSeqLen`, which is far greater than the number of workers. +// Note however that this currently leads to nondeterministic results because +// the RNG is invoked in different order. +class PrefillState { + public: + explicit PrefillState(hwy::ThreadPool& main_pool) : main_pool_(&main_pool) {} - for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { - const auto* layer_weights = weights.GetLayer(layer); - TransformerLayer( - current_token_count, num_queries, pos + offset, layer, layer_weights, - activations, kv_caches, pool); + ~PrefillState() { DeleteInnerPools(); } + + // Called before each query. Recreates thread pools, which has the advantage + // of tailoring the parallelism to the prompt length. + template + void Init(size_t prefill_size) { + // Would be zero for single-token prompts (prefill_size == num_tokens - 1). + num_batches_ = + HWY_MAX(size_t{1}, hwy::DivCeil(prefill_size, kPrefillBatchSize)); + // More than num_batches_ would waste workers on idling in the outer Run; + // more than NumWorkers() would exceed the global --num_threads. + const size_t outer_workers = + HWY_MIN(num_batches_, main_pool_->NumWorkers()); + HWY_ASSERT(outer_workers != 0); // Otherwise activations_ is empty. + + // One activation per outer worker. Allocating in parallel saves 30 ms. + activations_.resize(outer_workers); + main_pool_->Run(0, outer_workers, [this](uint64_t task, size_t /*thread*/) { + activations_[task].Allocate(kPrefillBatchSize); + }); + + DeleteInnerPools(); + + // If we'd create just one inner pool with all the workers, skip the cost of + // thread creation and pinning (about 60 ms) by reusing the main pool. + if (outer_workers <= 1) { + // Still allocate a dummy pool to simplify Prefill(). + outer_pool_ = std::make_unique(1); + inner_pools_.push_back(main_pool_); + return; } - } -} -// Compute the transformer for a batch of input tokens. During generation, -// we usually have num_tokens == 1 (and also kBatchSize == 1). -template + // Before creating new threads, stop the old ones from spinning. Caller is + // responsible for undoing this by calling `ResumeMainSpinning`. + main_pool_->SetWaitMode(hwy::PoolWaitMode::kBlock); + outer_pool_ = std::make_unique(outer_workers); + outer_pool_->SetWaitMode(hwy::PoolWaitMode::kSpin); + + // Assign up to `max_workers` to inner pools. Each inner pool creates + // `workers_per_outer - 1` threads in addition to its 'main' thread, which + // is the one calling `inner_pools[outer]->Run`, i.e., `outer`. In total, + // `outer_workers * (max_workers / outer_workers)` workers are used. + const size_t workers_per_outer = main_pool_->NumWorkers() / outer_workers; + for (size_t outer = 0; outer < outer_workers; ++outer) { + inner_pools_.push_back(new hwy::ThreadPool(workers_per_outer)); + inner_pools_.back()->SetWaitMode(hwy::PoolWaitMode::kSpin); + } + + PinThreads(outer_workers, workers_per_outer); + } + + // `tokens` are from interleaved queries. (See InterleaveQueries() below.) + template + HWY_NOINLINE void Prefill(hwy::Span tokens, size_t num_queries, + size_t pos, + const CompressedWeights& weights, + const RuntimeConfig& runtime_config, + const std::vector& kv_caches) { + PROFILER_ZONE("Gen.Prefill"); + + HWY_ASSERT(activations_.size() == outer_pool_->NumWorkers()); + HWY_ASSERT(inner_pools_.size() == outer_pool_->NumWorkers()); + + outer_pool_->Run( + 0, num_batches_, [&](const uint64_t batch_num, size_t thread) HWY_ATTR { + const size_t batch_start = batch_num * kPrefillBatchSize; + const size_t batch_size = + HWY_MIN(kPrefillBatchSize, tokens.size() - batch_start); + HWY_DASSERT(batch_start % num_queries == 0); + HWY_DASSERT(batch_size % num_queries == 0); + const size_t pos_per_query = pos + batch_start / num_queries; + const size_t num_tokens = batch_size / num_queries; + + // Negligible time compared to TransformerLayer. + for (size_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + EmbedToken(tokens[batch_start + batch_idx], batch_idx, + pos_per_query, weights, activations_[thread]); + } + + for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { + const auto* layer_weights = weights.GetLayer(layer); + TransformerLayer( + num_tokens, num_queries, pos_per_query, layer, layer_weights, + activations_[thread], kv_caches, *inner_pools_[thread]); + } + + // NOTE: we unconditionally call StreamToken, even if EOS. + for (size_t i = 0; i < batch_size; ++i) { + const size_t query_idx = i % num_queries; + const size_t batch_idx = i / num_queries; + runtime_config.StreamToken(query_idx, pos_per_query + batch_idx, + tokens[i], 0.0f); + } + }); + } + + // Stops spinning in our pools and resume spinning in main_pool_. + void ResumeMainSpinning() { + // If we didn't create a new inner pool, we didn't stop spinning on the + // main pool, so nothing to do here. + if (inner_pools_[0] == main_pool_) return; + + for (hwy::ThreadPool* p : inner_pools_) { + p->SetWaitMode(hwy::PoolWaitMode::kBlock); + } + outer_pool_->SetWaitMode(hwy::PoolWaitMode::kBlock); + main_pool_->SetWaitMode(hwy::PoolWaitMode::kSpin); + } + + private: + // Pins each outer thread after their inner threads so they are likely to + // run on the same socket. + void PinThreads(size_t outer_workers, size_t workers_per_outer) { + outer_pool_->Run( + 0, outer_workers, + [this, workers_per_outer](uint64_t outer, size_t outer_thread) { + HWY_ASSERT(outer == outer_thread); + // Pins inner *and* `outer` - the latter is the calling thread. + inner_pools_[outer]->Run( + 0, workers_per_outer, + [outer, workers_per_outer](uint64_t task, size_t thread) { + HWY_ASSERT(task == thread); // each worker has one task + const size_t lp = outer * workers_per_outer + task; + hwy::PinThreadToLogicalProcessor(lp); + }); + }); + } + + void DeleteInnerPools() { + for (hwy::ThreadPool* p : inner_pools_) { + if (p != main_pool_) delete p; + } + inner_pools_.clear(); + } + + hwy::ThreadPool* main_pool_; + std::unique_ptr outer_pool_; // always allocated + std::vector activations_; // size == outer_pool->NumWorkers() + // Either there is a single pointer equal to main_pool, or newly created pools + // that we own. The former case avoids thread creation overhead for prompts + // that fit in a single batch. + std::vector inner_pools_; + size_t num_batches_ = 0; +}; + +// `tokens` is length `num_tokens * num_queries`. In autoregressive decode, +// `num_tokens == 1`. +template HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, size_t num_queries, size_t pos, const CompressedWeights& weights, @@ -550,9 +685,8 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { const CompressedLayer* layer_weights = weights.GetLayer(layer); - TransformerLayer(num_tokens, num_queries, pos, - layer, layer_weights, - activations, kv_caches, pool); + TransformerLayer(num_tokens, num_queries, pos, layer, + layer_weights, activations, kv_caches, pool); if (layers_output) { const std::string block_name = "blocks." + std::to_string(layer); @@ -610,42 +744,81 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens, // Placeholder for internal test3, do not remove -template -void GenerateT(const ByteStorageT& weights_u8, Activations& prefill, - Activations& activations, const RuntimeConfig& runtime_config, - const hwy::Span>& prompts, size_t pos, - const size_t query_index_offset, - const std::vector& kv_caches, hwy::ThreadPool& pool, - TimingInfo& timing_info) { - constexpr size_t kAdjustedPrefillBatchSize = - std::max((size_t)1, kPrefillBatchSize / kQueryBatchSize); - static_assert(kAdjustedPrefillBatchSize >= kMinAdjustedPrefillBatchSize); - const size_t num_queries = prompts.size(); - HWY_DASSERT(num_queries <= kQueryBatchSize); - pos *= num_queries; // position in (num_queries) interleaved token sequence. - const CompressedWeights& weights = - *reinterpret_cast*>(weights_u8.get()); - - size_t min_prompt_size = (size_t)-1; - size_t max_prompt_size = 0; - for (int i=0; i < prompts.size(); ++i) { - min_prompt_size = std::min(min_prompt_size, prompts[i].size()); - max_prompt_size = std::max(max_prompt_size, prompts[i].size()); +// Returns interleaved tokens: one from each query, followed by the second from +// all queries, with EOS padding. +static std::vector InterleaveQueries( + const hwy::Span>& queries, + const RuntimeConfig& runtime_config, size_t& min_prompt_size, + size_t& max_prompt_size) { + const size_t num_queries = queries.size(); + min_prompt_size = hwy::LimitsMax(); + max_prompt_size = 0; + for (size_t i = 0; i < num_queries; ++i) { + min_prompt_size = std::min(min_prompt_size, queries[i].size()); + max_prompt_size = std::max(max_prompt_size, queries[i].size()); } std::vector prompt; - prompt.reserve(max_prompt_size * prompts.size()); - for (int i = 0; i < max_prompt_size; ++i) { - for (int j=0; j < prompts.size(); ++j) { - if (i < prompts[j].size()) { - prompt.push_back(prompts[j][i]); + prompt.reserve(max_prompt_size * num_queries); + for (size_t pos = 0; pos < max_prompt_size; ++pos) { + for (size_t q = 0; q < num_queries; ++q) { + if (pos < queries[q].size()) { + prompt.push_back(queries[q][pos]); } else { - prompt.push_back(0); + prompt.push_back(runtime_config.eos_id); } } } + return prompt; +} +// Holds "is at end of stream" state for each query. +class TokenStreamer { + public: + explicit TokenStreamer(const RuntimeConfig& runtime_config) + : runtime_config_(runtime_config) {} + + // Returns whether the query was already at, or has just reached, the end of + // the stream: either via token == eos_id, or StreamToken returning false. + bool operator()(size_t query_idx, size_t pos, int token, float prob) { + if (HWY_UNLIKELY(is_eos_.Get(query_idx))) return true; + + if (!runtime_config_.StreamToken(query_idx, pos, token, prob) || + token == runtime_config_.eos_id) { + is_eos_.Set(query_idx); + return true; + } + + return false; + } + + private: + const RuntimeConfig& runtime_config_; + // BitSet4096 divides the arg by 64, so ensure it is at least 64. + hwy::BitSet4096 is_eos_; +}; + +// Generates one token per query in the batch. +// +// pos indexes the KV cache. In the first turn of a chat, pos = 0, and it +// continues to increase by one for each prefilled/generated token per query. +// query_idx_start is the first query index in the batch. +template +void GenerateT(const ByteStorageT& weights_u8, Activations& activations, + const RuntimeConfig& runtime_config, + const hwy::Span>& prompts, const size_t pos, + const size_t query_idx_start, + const std::vector& kv_caches, hwy::ThreadPool& pool, + TimingInfo& timing_info) { constexpr size_t kVocabSize = TConfig::kVocabSize; + const CompressedWeights& weights = + *reinterpret_cast*>(weights_u8.get()); + + const size_t num_queries = prompts.size(); + HWY_DASSERT(num_queries <= kQueryBatchSize); + size_t min_prompt_size, max_prompt_size; + const std::vector prompt = InterleaveQueries( + prompts, runtime_config, min_prompt_size, max_prompt_size); size_t max_tokens = runtime_config.max_tokens; size_t max_generated_tokens = runtime_config.max_generated_tokens; @@ -666,171 +839,92 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& prefill, runtime_config.accept_token); }; - std::vector reached_eos(num_queries); - std::fill(reached_eos.begin(), reached_eos.end(), false); - - // pos indexes the KV cache. In the first turn of a chat, pos = 0. - // - // After the first turn, pos gets passed in with > 0 corresponding to the - // current token position in the KV cache. - // - // pos_offset keeps track of the relative position within the turn, starting - // at 0 each turn. During prefill, pos_offset corresponds to the index into - // the prompt vector. - // - // In single-turn (non-chat) usage, pos and pos_offset start at 0 and are - // always equal. - size_t pos_offset = 0; // offset relative to pos - // Used to keep track of how many tokens are processed per prompt, - // so that we know when to start generating tokens. - size_t single_prompt_pos_offset = 0; + // Prefill stops before min_prompt_size - 1 because the last prompt token is + // the first input token for generation. + const size_t prefill_per_query = min_prompt_size - 1; + const hwy::Span prefill_tokens(prompt.data(), + prefill_per_query * num_queries); + PrefillState prefill(pool); + prefill.Init(prefill_tokens.size()); const double prefill_start = hwy::platform::Now(); + size_t interleaved_pos = pos * num_queries; + prefill.Prefill(prefill_tokens, num_queries, interleaved_pos, + weights, runtime_config, kv_caches); + interleaved_pos += prefill_tokens.size(); + timing_info.NotifyPrefill(prefill_tokens.size(), prefill_start); - // Prefill stops before prompt_size - 1 since the last prompt token is the - // first input token for generation. - while (single_prompt_pos_offset < min_prompt_size - 1) { - const size_t batch_size = std::min( - kPrefillBatchSize, min_prompt_size - 1 - single_prompt_pos_offset); - const size_t batch_and_query_size = batch_size * num_queries; - HWY_DASSERT(batch_size <= kPrefillBatchSize); - HWY_DASSERT(single_prompt_pos_offset + batch_size <= min_prompt_size - 1); - HWY_DASSERT(pos_offset + batch_size <= (min_prompt_size - 1) * num_queries); - const int* batch_tokens = prompt.data() + pos_offset; - Prefill( - batch_tokens, batch_size, num_queries, pos, weights, prefill, kv_caches, - pool); - for (size_t idx = 0; idx < batch_size; ++idx) { - bool all_tokens_eos = true; - for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { - if (reached_eos[query_idx]) continue; - if (runtime_config.StreamToken( - query_idx + query_index_offset, single_prompt_pos_offset, - batch_tokens[idx * num_queries + query_idx], 0.0f)) { - all_tokens_eos = false; - } else { - reached_eos[query_idx] = true; - } - } - if (all_tokens_eos) { - return; - } - } - pos += batch_and_query_size; - pos_offset += batch_and_query_size; - single_prompt_pos_offset += batch_size; + prefill.ResumeMainSpinning(); + + // Storage for the last generated token from each query, passed to the next + // Transformer() call. + std::vector gen_tokens(num_queries); + + // Stream the last prompt token from each query and fill gen_tokens. + hwy::CopyBytes(&prompt[prefill_tokens.size()], gen_tokens.data(), + num_queries * sizeof(prompt[0])); + TokenStreamer token_streamer(runtime_config); + for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { + (void)token_streamer(query_idx_start + query_idx, prefill_per_query, + gen_tokens[query_idx], 0.0f); } - timing_info.prefill_tok_sec = - static_cast(pos_offset) / (hwy::platform::Now() - prefill_start); - - // Start generation. const double gen_start = hwy::platform::Now(); - HWY_DASSERT(single_prompt_pos_offset == min_prompt_size - 1); - size_t pos_gen_start = pos_offset; - int token = prompt.at(pos_offset); - std::vector::const_iterator first = prompt.begin() + pos_offset; - std::vector::const_iterator last = first + num_queries; - std::vector gen_tokens(first, last); - // The loop below is not yet prepared for decode batch size > 1. - HWY_ASSERT(kDecodeBatchSize == 1); - bool all_tokens_eos = true; - for (size_t i=0; i < num_queries; ++i) { - if (reached_eos[i]) continue; - if (runtime_config.StreamToken(i + query_index_offset, - single_prompt_pos_offset, gen_tokens[i], - 0.0f)) { - all_tokens_eos = false; - } else { - reached_eos[i] = true; - } - } - if (all_tokens_eos) { - return; - } - for (size_t generate_pos = 0; - generate_pos < max_tokens && generate_pos < max_generated_tokens; - ++single_prompt_pos_offset, ++generate_pos) { - Transformer( - gen_tokens.data(), kDecodeBatchSize, num_queries, pos, weights, - activations, kv_caches, pool, runtime_config.layers_output); - float token_logit = 0.0f; - // The condition below is always true if we are doing Prefill above. - // We keep it here for clarity so that the code is correct even if Prefill - // is disabled. - bool all_tokens_eos = true; - for (size_t i = 0; i < num_queries; ++i, ++pos, ++pos_offset) { - const float* HWY_RESTRICT x = activations.x.Batch(i); - float* HWY_RESTRICT logits = activations.logits.Batch(i); - const size_t prompt_size = prompts[i].size(); - const bool is_generating_phase = - (single_prompt_pos_offset >= prompt_size - 1); - if (is_generating_phase) { - PROFILER_ZONE("Gen.Embedding"); - // Compute logits from last layer activations. - MatVec(weights.embedder_input_embedding, - 0, x, activations.even_odd.All(), - logits, pool); - if constexpr (TConfig::kFinalCap > 0.0f) { - LogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize); - } - // Barrier: must have all logits so we can subtract max. - Softmax(logits, kVocabSize); - token = sample_token(logits, kVocabSize); - token_logit = logits[token]; - if (generate_pos == 0) { - timing_info.time_to_first_token = hwy::platform::Now() - gen_start; - } - } else { - // We would take this branch if we were not doing Prefill but would - // process the tokens of the prompt one at a time. - token = prompt.at(pos_offset); - token_logit = 0.0f; - } + for (size_t gen_per_query = 0; + gen_per_query < HWY_MIN(max_tokens, max_generated_tokens); + ++gen_per_query) { + // Decode: generate one token for each query. + Transformer(gen_tokens.data(), /*num_tokens=*/1, num_queries, + interleaved_pos, weights, activations, kv_caches, pool, + runtime_config.layers_output); + interleaved_pos += num_queries; - if (!reached_eos[i]) { - if (!runtime_config.StreamToken(i + query_index_offset, - single_prompt_pos_offset + 1, token, - token_logit)) { - token = runtime_config.eos_id; - } - if (token != runtime_config.eos_id) { - all_tokens_eos = false; - } else { - reached_eos[i] = true; - } + bool all_queries_eos = true; + PROFILER_ZONE("Gen.Embedding"); + for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { + float* HWY_RESTRICT logits = activations.logits.Batch(query_idx); + // Compute logits from last layer activations. TODO: MatMul + MatVec( + weights.embedder_input_embedding, 0, activations.x.Batch(query_idx), + activations.even_odd.All(), logits, pool); + if constexpr (TConfig::kFinalCap > 0.0f) { + LogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize); } - gen_tokens[i] = token; + Softmax(logits, kVocabSize); + const int token = sample_token(logits, kVocabSize); + timing_info.NotifyGenerated(prefill_start); + + const bool is_eos = token_streamer(query_idx_start + query_idx, + prefill_per_query + 1 + gen_per_query, + token, logits[token]); + all_queries_eos &= is_eos; + gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : token; } - if (all_tokens_eos) { - break; - } - } - timing_info.gen_tok_sec = static_cast(pos_offset - pos_gen_start) / - (hwy::platform::Now() - gen_start); + if (all_queries_eos) break; + } // foreach token to generate + + timing_info.NotifyGenerateDone(gen_start); } +// TODO: prompt should also be span, not a vector. template -void GenerateSingleT(const ByteStorageT& weights_u8, Activations& prefill, - Activations& activations, +void GenerateSingleT(const ByteStorageT& weights_u8, Activations& activations, const RuntimeConfig& runtime_config, const std::vector& prompt, size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, TimingInfo& timing_info) { - // TODO: the input should also be span, not a vector. const hwy::Span prompt_span(const_cast(prompt.data()), prompt.size()); const hwy::Span> prompts(&prompt_span, 1); - // TODO: also span of kv_cache. + // TODO: also span of kv_cache, or batching inside KVCache? std::vector kv_caches = {&kv_cache}; - const size_t query_index_offset = 0; + const size_t query_idx_start = 0; GenerateT( - weights_u8, prefill, activations, runtime_config, prompts, pos, - query_index_offset, kv_caches, pool, timing_info); + weights_u8, activations, runtime_config, prompts, pos, query_idx_start, + kv_caches, pool, timing_info); } template -void GenerateBatchT(const ByteStorageT& weights_u8, Activations& prefill, - Activations& activations, +void GenerateBatchT(const ByteStorageT& weights_u8, Activations& activations, const RuntimeConfig& runtime_config, const hwy::Span>& prompts, size_t pos, const std::vector& kv_caches, @@ -838,12 +932,14 @@ void GenerateBatchT(const ByteStorageT& weights_u8, Activations& prefill, // Disable query batching for Griffin models. constexpr size_t kQueryBatchSize = (TConfig::kGriffinLayers > 0) ? 1 : kBatchedQueryBatchSize; - for (size_t i = 0; i < prompts.size(); i += kQueryBatchSize) { - const size_t num_queries = std::min(prompts.size() - i, kQueryBatchSize); - const hwy::Span> current_prompts( - prompts.data() + i, num_queries); - GenerateT(weights_u8, prefill, activations, - runtime_config, current_prompts, pos, i, + for (size_t query_idx_start = 0; query_idx_start < prompts.size(); + query_idx_start += kQueryBatchSize) { + const size_t num_queries = + std::min(prompts.size() - query_idx_start, kQueryBatchSize); + const hwy::Span> query_batch( + prompts.data() + query_idx_start, num_queries); + GenerateT(weights_u8, activations, runtime_config, + query_batch, pos, query_idx_start, kv_caches, pool, timing_info); } } @@ -855,24 +951,24 @@ void GenerateBatchT(const ByteStorageT& weights_u8, Activations& prefill, // These are extern functions defined by instantiations/*.cc, which include this // 'header' after defining GEMMA_CONFIG, which is for function overloading. void GenerateSingle( // NOLINT(misc-definitions-in-headers) - GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& prefill, - Activations& activations, const RuntimeConfig& runtime_config, - const std::vector& prompt, size_t pos, KVCache& kv_cache, - hwy::ThreadPool& pool, TimingInfo& timing_info) { + GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& activations, + const RuntimeConfig& runtime_config, const std::vector& prompt, + size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, + TimingInfo& timing_info) { HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT) - (weights_u8, prefill, activations, runtime_config, prompt, pos, kv_cache, - pool, timing_info); + (weights_u8, activations, runtime_config, prompt, pos, kv_cache, pool, + timing_info); } void GenerateBatch( // NOLINT(misc-definitions-in-headers) - GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& prefill, - Activations& activations, const RuntimeConfig& runtime_config, + GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& activations, + const RuntimeConfig& runtime_config, const hwy::Span>& prompts, size_t pos, const std::vector& kv_caches, hwy::ThreadPool& pool, TimingInfo& timing_info) { HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT) - (weights_u8, prefill, activations, runtime_config, prompts, pos, kv_caches, - pool, timing_info); + (weights_u8, activations, runtime_config, prompts, pos, kv_caches, pool, + timing_info); } #endif // HWY_ONCE diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 6df0a15..df6474f 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -37,13 +37,11 @@ namespace gcpp { template -struct AllocateState { - void operator()(Activations& prefill, Activations& decode) const { - // When batching queries, the prefill batch size is reduced by a factor - // of kBatchedQueryBatchSize - prefill.Allocate(kMinAdjustedPrefillBatchSize * - kBatchedQueryBatchSize); - decode.Allocate(kDecodeBatchSize * kBatchedQueryBatchSize); +struct AllocateActivations { + void operator()(Activations& decode) const { + // TODO: this is wasted if we only have single-batch queries. Instead + // re-allocate when the query batch size is actually > 1? + decode.Allocate(kBatchedQueryBatchSize); } }; @@ -51,8 +49,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info, hwy::ThreadPool& pool) : pool_(pool), tokenizer_(tokenizer_path), info_(info) { weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool); - CallForModelAndWeight(info.model, info.weight, prefill_, - decode_); + CallForModelAndWeight(info.model, info.weight, decode_); } Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, @@ -61,8 +58,7 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, HWY_ASSERT(info.weight == Type::kF32); weights_u8_ = CallForModel(info.model, pool); - CallForModelAndWeight(info.model, info.weight, prefill_, - decode_); + CallForModelAndWeight(info.model, info.weight, decode_); } Gemma::~Gemma() { @@ -74,17 +70,17 @@ Gemma::~Gemma() { // we shard them across multiple translation units in instantiations/*.cc. // This declares the functions defined there. We use overloading because // explicit instantiations are still too slow to compile. -#define GEMMA_DECLARE(CONFIGT, TWEIGHT) \ - extern void GenerateSingle( \ - CONFIGT, const ByteStorageT& weights_u8, Activations& prefill, \ - Activations& decode, const RuntimeConfig& runtime_config, \ - const std::vector& prompt, size_t pos, KVCache& kv_cache, \ - hwy::ThreadPool& pool, TimingInfo& timing_info); \ - extern void GenerateBatch( \ - CONFIGT, const ByteStorageT& weights_u8, Activations& prefill, \ - Activations& decode, const RuntimeConfig& runtime_config, \ - const hwy::Span>& prompts, size_t pos, \ - const std::vector& kv_caches, hwy::ThreadPool& pool, \ +#define GEMMA_DECLARE(CONFIGT, TWEIGHT) \ + extern void GenerateSingle( \ + CONFIGT, const ByteStorageT& weights_u8, Activations& decode, \ + const RuntimeConfig& runtime_config, const std::vector& prompt, \ + size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, \ + TimingInfo& timing_info); \ + extern void GenerateBatch( \ + CONFIGT, const ByteStorageT& weights_u8, Activations& decode, \ + const RuntimeConfig& runtime_config, \ + const hwy::Span>& prompts, size_t pos, \ + const std::vector& kv_caches, hwy::ThreadPool& pool, \ TimingInfo& timing_info); GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE); @@ -92,24 +88,24 @@ GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE); // TODO: gather all ByteStorageT into a type-erased model struct? template struct GenerateSingleT { - void operator()(const ByteStorageT& weights_u8, Activations& prefill, - Activations& decode, const RuntimeConfig& runtime_config, + void operator()(const ByteStorageT& weights_u8, Activations& decode, + const RuntimeConfig& runtime_config, const std::vector& prompt, size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, TimingInfo& timing_info) const { - GenerateSingle(TConfig(), weights_u8, prefill, decode, runtime_config, - prompt, pos, kv_cache, pool, timing_info); + GenerateSingle(TConfig(), weights_u8, decode, runtime_config, prompt, pos, + kv_cache, pool, timing_info); } }; template struct GenerateBatchT { - void operator()(const ByteStorageT& weights_u8, Activations& prefill, - Activations& decode, const RuntimeConfig& runtime_config, + void operator()(const ByteStorageT& weights_u8, Activations& decode, + const RuntimeConfig& runtime_config, const hwy::Span>& prompts, size_t pos, const std::vector& kv_caches, hwy::ThreadPool& pool, TimingInfo& timing_info) const { - GenerateBatch(TConfig(), weights_u8, prefill, decode, runtime_config, - prompts, pos, kv_caches, pool, timing_info); + GenerateBatch(TConfig(), weights_u8, decode, runtime_config, prompts, pos, + kv_caches, pool, timing_info); } }; @@ -119,8 +115,8 @@ void Gemma::Generate(const RuntimeConfig& runtime_config, pool_.SetWaitMode(hwy::PoolWaitMode::kSpin); CallForModelAndWeight( - info_.model, info_.weight, weights_u8_, prefill_, decode_, runtime_config, - prompt, start_pos, kv_cache, pool_, timing_info); + info_.model, info_.weight, weights_u8_, decode_, runtime_config, prompt, + start_pos, kv_cache, pool_, timing_info); pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); } @@ -133,8 +129,8 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, pool_.SetWaitMode(hwy::PoolWaitMode::kSpin); CallForModelAndWeight( - info_.model, info_.weight, weights_u8_, prefill_, decode_, runtime_config, - prompts, start_pos, kv_caches, pool_, timing_info); + info_.model, info_.weight, weights_u8_, decode_, runtime_config, prompts, + start_pos, kv_caches, pool_, timing_info); pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); } diff --git a/gemma/gemma.h b/gemma/gemma.h index 477bcf8..3caeab5 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -28,6 +28,7 @@ #include "gemma/kv_cache.h" #include "gemma/tokenizer.h" #include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/timer.h" // IWYU pragma: end_exports #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t @@ -78,9 +79,30 @@ struct RuntimeConfig { }; struct TimingInfo { - double prefill_tok_sec = 0.0; - double gen_tok_sec = 0.0; - double time_to_first_token = 0.0; + void NotifyPrefill(size_t tokens, double start) { + prefill_tok_sec = + static_cast(tokens) / (hwy::platform::Now() - start); + gen_tok_sec = 0.0; + time_to_first_token = 0.0; + tokens_generated = 0; + } + + void NotifyGenerated(double prefill_start) { + ++tokens_generated; + if (HWY_UNLIKELY(tokens_generated == 1)) { + time_to_first_token = hwy::platform::Now() - prefill_start; + } + } + + void NotifyGenerateDone(double gen_start) { + gen_tok_sec = static_cast(tokens_generated) / + (hwy::platform::Now() - gen_start); + } + + double prefill_tok_sec; + double gen_tok_sec; + double time_to_first_token; + size_t tokens_generated; }; class Gemma { @@ -96,7 +118,6 @@ class Gemma { const ModelInfo& Info() const { return info_; } const GemmaTokenizer& Tokenizer() const { return tokenizer_; } const ByteStorageT& Weights() const { return weights_u8_; } - const Activations& Prefill() const { return prefill_; } const Activations& Decode() const { return decode_; } void Generate(const RuntimeConfig& runtime_config, @@ -115,7 +136,6 @@ class Gemma { // Type-erased so that this can be defined in the header, without requiring // forwarding functions. ByteStorageT weights_u8_; - Activations prefill_; Activations decode_; ModelInfo info_; }; diff --git a/gemma/run.cc b/gemma/run.cc index ea2a136..2d33827 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -97,9 +97,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool, } else if (token == EOS_ID) { if (!args.multiturn) { abs_pos = 0; - if (args.deterministic) { - gen.seed(42); - } + InitGenerator(args, gen); } if (verbosity >= 2) { std::cout << "\n[ End ]\n";