Major Prefill/Generate cleanup, 1.3x Prefill speedup

This also fixes TTFT, which previously did not include prefill time.

PiperOrigin-RevId: 653690626
Jan Wassenberg 2024-07-18 11:15:57 -07:00 committed by Copybara-Service
parent 3fe79b3876
commit 12016d31c3
7 changed files with 387 additions and 277 deletions

View File

@ -149,13 +149,14 @@ cc_library(
":tokenizer",
":kv_cache",
":weights",
"//compression:compress",
"//compression:io",
"@hwy//:hwy",
"@hwy//:bit_set",
"@hwy//:matvec",
"@hwy//:nanobenchmark", # timer
"@hwy//:profiler",
"@hwy//:thread_pool",
"@hwy//:topology",
],
)

View File

@ -76,8 +76,8 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
fprintf(stderr, "Loading model...\n");
model_ = AllocateGemma(loader_, pool_);
kv_caches_.reserve(16);
for (int i = 0; i < 16; ++i) {
kv_caches_.reserve(kBatchedQueryBatchSize);
for (int i = 0; i < kBatchedQueryBatchSize; ++i) {
kv_caches_.push_back(new KVCache(KVCache::Create(model_->Info().model)));
}
}
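Each interleaved query writes keys/values into its own cache, so the test environment now sizes kv_caches_ by kBatchedQueryBatchSize instead of a hard-coded 16. A minimal sketch of the same pattern as a standalone helper (hypothetical name, assuming the KVCache::Create(Model) factory and move construction used above):

// Sketch only: allocate one KVCache per query slot; the caller owns the pointers.
std::vector<KVCache*> AllocateQueryCaches(Model model, size_t num_query_slots) {
  std::vector<KVCache*> caches;
  caches.reserve(num_query_slots);
  for (size_t i = 0; i < num_query_slots; ++i) {
    caches.push_back(new KVCache(KVCache::Create(model)));
  }
  return caches;
}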

View File

@ -36,11 +36,10 @@ ByteStorageT AllocateSizeof() {
return hwy::AllocateAligned<uint8_t>(sizeof(T));
}
constexpr size_t kPrefillBatchSize = 512;
constexpr size_t kDecodeBatchSize = 1;
// Relatively small so that we can also parallelize non-Matmul work. There is
// one outer thread per batch, each with --num_threads / batches inner threads.
constexpr size_t kPrefillBatchSize = 64;
constexpr size_t kBatchedQueryBatchSize = 16;
constexpr size_t kMinAdjustedPrefillBatchSize =
HWY_MAX((size_t)1, kPrefillBatchSize / kBatchedQueryBatchSize);
// Model variants: see configs.h for details. When adding a new one, also
// update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc.
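For intuition: prefill tokens are interleaved across queries, so with a full query batch of 16, each 64-token prefill batch advances every query by 64 / 16 = 4 positions. A small sketch of that arithmetic (names are illustrative):

// Sketch only: worked numbers for the constants above.
constexpr size_t kSketchPrefillBatch = 64;  // tokens per prefill batch
constexpr size_t kSketchQueryBatch = 16;    // interleaved queries per batch
// One full prefill batch advances each query by 64 / 16 = 4 positions.
static_assert(kSketchPrefillBatch / kSketchQueryBatch == 4, "4 positions per query");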

View File

@ -26,23 +26,26 @@
#include <stddef.h>
#include <stdio.h>
#include <string.h> // memcpy
#include <algorithm>
#include <algorithm> // std::min
#include <memory> // std::unique_ptr
#include <string>
#include <type_traits>
#include <vector>
#include "gemma/activations.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "gemma/ops.h"
#include "gemma/weights.h"
// Placeholder for internal test4, do not remove
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/bit_set.h"
#include "hwy/contrib/matvec/matvec-inl.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/contrib/thread_pool/topology.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
@ -269,7 +272,7 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
const float* HWY_RESTRICT q =
activations.q.Batch(batch_and_query_idx) + head * kQStride;
// Skip past the Q part of `q`, and copy KV to `kv`.
memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
hwy::CopyBytes(q + kQKVDim, kv, 2 * kQKVDim * sizeof(float));
}
PostQK<TConfig>(kv, pos, layer);
});
@ -414,6 +417,7 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens,
}
// TODO: pass Activations.x instead of Activations.
// `pos` is for the entire batch and does not include `batch_idx`.
template <class TConfig>
HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
const CompressedWeights<TConfig>& weights,
@ -421,8 +425,10 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
constexpr size_t kModelDim = TConfig::kModelDim;
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
EmbeddingScaling<TConfig>();
HWY_DASSERT(token >= 0);
HWY_DASSERT(token < TConfig::kVocabSize);
Decompress(weights.embedder_input_embedding, token * kModelDim,
activations.x.Batch(batch_idx), kModelDim);
MulByConst(kEmbScaling, activations.x.Batch(batch_idx), kModelDim);
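In plain terms, embedding a token is a row lookup followed by a constant scale. A float-only sketch (ignoring the compressed weight format; emb_scaling stands in for whatever EmbeddingScaling<TConfig>() returns):

// Sketch only, uncompressed float weights: copy the token's embedding row, then scale.
void EmbedTokenSketch(int token, const float* embedding_table, size_t model_dim,
                      float emb_scaling, float* x_out) {
  const float* row = embedding_table + static_cast<size_t>(token) * model_dim;
  for (size_t i = 0; i < model_dim; ++i) {
    x_out[i] = emb_scaling * row[i];
  }
}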
@ -441,7 +447,7 @@ HWY_NOINLINE void ResidualConnection(
AddFromBatched(num_tokens_and_queries, other, x, kModelDim);
}
template <class TConfig, size_t kQueryBatchSize>
template <class TConfig>
HWY_NOINLINE void TransformerLayer(
size_t num_tokens, size_t num_queries, size_t pos, size_t layer,
const CompressedLayer<TConfig>* layer_weights, Activations& activations,
@ -458,13 +464,10 @@ HWY_NOINLINE void TransformerLayer(
Attention<TConfig>(pos, num_tokens, num_queries, layer_of_type, activations,
layer_weights, kv_caches, pool);
} else {
// This Griffin layers should never exist unless the model is a Griffin
// model. This conditional prevents the compiler from generating code for
// this branch when building a non-Griffin model, since we have static
// asserts about the query batch size for Griffin layers.
// Only reached if the model is Griffin. `if constexpr` prevents generating
// this code for non-Griffin models.
if constexpr (TConfig::kGriffinLayers > 0) {
static_assert(kQueryBatchSize == 1,
"Griffin does not support batched queries.");
HWY_ASSERT(num_queries == 1);
GriffinRecurrent<TConfig>(pos, num_tokens, num_queries, layer_of_type,
activations, layer_weights, kv_caches, pool);
}
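The guard relies on C++17 if constexpr: the untaken branch of a template is not instantiated, so it may reference members that only Griffin configs define. A minimal illustration with hypothetical config structs (not the real ones in configs.h):

// Illustration only.
struct SketchPlainConfig   { static constexpr size_t kGriffinLayers = 0; };
struct SketchGriffinConfig {
  static constexpr size_t kGriffinLayers = 2;
  static constexpr size_t kConv1dWidth = 4;  // Griffin-only member
};

template <class TConfig>
size_t GriffinConvWidth() {
  if constexpr (TConfig::kGriffinLayers > 0) {
    return TConfig::kConv1dWidth;  // not instantiated for SketchPlainConfig
  } else {
    return 0;
  }
}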
@ -494,39 +497,171 @@ HWY_NOINLINE void TransformerLayer(
/*is_attention=*/false);
}
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens,
size_t num_queries, size_t pos,
const CompressedWeights<TConfig>& weights,
Activations& activations,
const std::vector<KVCache*>& kv_caches,
hwy::ThreadPool& pool) {
PROFILER_ZONE("Gen.Prefill");
HWY_DASSERT(num_queries <= kQueryBatchSize);
const size_t minibatch_size = std::min(num_tokens, kBatchSize);
// TODO: hoist pool.Run out of the loop, change the unit of work to batches.
for (size_t i = 0; i < num_tokens; i += minibatch_size) {
const size_t offset = i * num_queries;
const size_t current_token_count = std::min(
minibatch_size, num_tokens - i);
pool.Run(0, current_token_count * num_queries,
[&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
EmbedToken<TConfig>(tokens[token_idx + offset], token_idx,
pos + offset, weights, activations);
// For prefill, we have two-level parallelism:
// - Outer: input tokens are split into batches, each of which is one task
// processed by a worker in `outer_pool_`, which includes the main thread
// because it is the one that calls `Prefill`.
// - Inner: each `outer` worker passes `inner_pools_[outer]` to
// `TransformerLayer` for tensor-level parallelism.
//
// This class holds the thread pools and activations, recreated for each query.
//
// It is safe to parallelize batches because we write to KVCache at
// `pos % kSeqLen`, which is far greater than the number of workers.
// Note however that this currently leads to nondeterministic results because
// the RNG is invoked in different order.
class PrefillState {
public:
explicit PrefillState(hwy::ThreadPool& main_pool) : main_pool_(&main_pool) {}
~PrefillState() { DeleteInnerPools(); }
// Called before each query. Recreates thread pools, which has the advantage
// of tailoring the parallelism to the prompt length.
template <class TConfig>
void Init(size_t prefill_size) {
// prefill_size is zero for single-token prompts (prefill_size == num_tokens - 1),
// hence clamp to at least one batch.
num_batches_ =
HWY_MAX(size_t{1}, hwy::DivCeil(prefill_size, kPrefillBatchSize));
// More than num_batches_ would waste workers on idling in the outer Run;
// more than NumWorkers() would exceed the global --num_threads.
const size_t outer_workers =
HWY_MIN(num_batches_, main_pool_->NumWorkers());
HWY_ASSERT(outer_workers != 0); // Otherwise activations_ is empty.
// One activation per outer worker. Allocating in parallel saves 30 ms.
activations_.resize(outer_workers);
main_pool_->Run(0, outer_workers, [this](uint64_t task, size_t /*thread*/) {
activations_[task].Allocate<TConfig>(kPrefillBatchSize);
});
DeleteInnerPools();
// If we'd create just one inner pool with all the workers, skip the cost of
// thread creation and pinning (about 60 ms) by reusing the main pool.
if (outer_workers <= 1) {
// Still allocate a dummy pool to simplify Prefill().
outer_pool_ = std::make_unique<hwy::ThreadPool>(1);
inner_pools_.push_back(main_pool_);
return;
}
// Before creating new threads, stop the old ones from spinning. Caller is
// responsible for undoing this by calling `ResumeMainSpinning`.
main_pool_->SetWaitMode(hwy::PoolWaitMode::kBlock);
outer_pool_ = std::make_unique<hwy::ThreadPool>(outer_workers);
outer_pool_->SetWaitMode(hwy::PoolWaitMode::kSpin);
// Assign up to `max_workers` to inner pools. Each inner pool creates
// `workers_per_outer - 1` threads in addition to its 'main' thread, which
// is the one calling `inner_pools[outer]->Run`, i.e., `outer`. In total,
// `outer_workers * (max_workers / outer_workers)` workers are used.
const size_t workers_per_outer = main_pool_->NumWorkers() / outer_workers;
for (size_t outer = 0; outer < outer_workers; ++outer) {
inner_pools_.push_back(new hwy::ThreadPool(workers_per_outer));
inner_pools_.back()->SetWaitMode(hwy::PoolWaitMode::kSpin);
}
PinThreads(outer_workers, workers_per_outer);
}
// `tokens` are from interleaved queries. (See InterleaveQueries() below.)
template <class TConfig>
HWY_NOINLINE void Prefill(hwy::Span<const int> tokens, size_t num_queries,
size_t pos,
const CompressedWeights<TConfig>& weights,
const RuntimeConfig& runtime_config,
const std::vector<KVCache*>& kv_caches) {
PROFILER_ZONE("Gen.Prefill");
HWY_ASSERT(activations_.size() == outer_pool_->NumWorkers());
HWY_ASSERT(inner_pools_.size() == outer_pool_->NumWorkers());
outer_pool_->Run(
0, num_batches_, [&](const uint64_t batch_num, size_t thread) HWY_ATTR {
const size_t batch_start = batch_num * kPrefillBatchSize;
const size_t batch_size =
HWY_MIN(kPrefillBatchSize, tokens.size() - batch_start);
HWY_DASSERT(batch_start % num_queries == 0);
HWY_DASSERT(batch_size % num_queries == 0);
const size_t pos_per_query = pos + batch_start / num_queries;
const size_t num_tokens = batch_size / num_queries;
// Negligible time compared to TransformerLayer.
for (size_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
EmbedToken<TConfig>(tokens[batch_start + batch_idx], batch_idx,
pos_per_query, weights, activations_[thread]);
}
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
const auto* layer_weights = weights.GetLayer(layer);
TransformerLayer<TConfig, kQueryBatchSize>(
current_token_count, num_queries, pos + offset, layer, layer_weights,
activations, kv_caches, pool);
}
}
TransformerLayer<TConfig>(
num_tokens, num_queries, pos_per_query, layer, layer_weights,
activations_[thread], kv_caches, *inner_pools_[thread]);
}
// Compute the transformer for a batch of input tokens. During generation,
// we usually have num_tokens == 1 (and also kBatchSize == 1).
template <class TConfig, size_t kQueryBatchSize>
// NOTE: we unconditionally call StreamToken, even if EOS.
for (size_t i = 0; i < batch_size; ++i) {
const size_t query_idx = i % num_queries;
const size_t batch_idx = i / num_queries;
runtime_config.StreamToken(query_idx, pos_per_query + batch_idx,
tokens[i], 0.0f);
}
});
}
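The index arithmetic above can be spelled out: tokens are interleaved, so a batch starting at interleaved offset batch_start advances every query by batch_start / num_queries positions. A small sketch with worked numbers (hypothetical helper, not part of this change):

// Sketch only: per-batch index arithmetic for interleaved prefill tokens.
void PrefillBatchIndices(size_t batch_num, size_t batch_size, size_t num_queries,
                         size_t pos, size_t& pos_per_query, size_t& num_tokens) {
  const size_t batch_start = batch_num * batch_size;
  pos_per_query = pos + batch_start / num_queries;
  num_tokens = batch_size / num_queries;
}
// Example: batch_num=3, batch_size=64, num_queries=2, pos=0
//          -> pos_per_query=96, num_tokens=32 positions per query in this batch.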
// Stops spinning in our pools and resumes spinning in main_pool_.
void ResumeMainSpinning() {
// If we didn't create a new inner pool, we didn't stop spinning on the
// main pool, so nothing to do here.
if (inner_pools_[0] == main_pool_) return;
for (hwy::ThreadPool* p : inner_pools_) {
p->SetWaitMode(hwy::PoolWaitMode::kBlock);
}
outer_pool_->SetWaitMode(hwy::PoolWaitMode::kBlock);
main_pool_->SetWaitMode(hwy::PoolWaitMode::kSpin);
}
private:
// Pins each outer thread after its inner threads so they are likely to
// run on the same socket.
void PinThreads(size_t outer_workers, size_t workers_per_outer) {
outer_pool_->Run(
0, outer_workers,
[this, workers_per_outer](uint64_t outer, size_t outer_thread) {
HWY_ASSERT(outer == outer_thread);
// Pins inner *and* `outer` - the latter is the calling thread.
inner_pools_[outer]->Run(
0, workers_per_outer,
[outer, workers_per_outer](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each worker has one task
const size_t lp = outer * workers_per_outer + task;
hwy::PinThreadToLogicalProcessor(lp);
});
});
}
void DeleteInnerPools() {
for (hwy::ThreadPool* p : inner_pools_) {
if (p != main_pool_) delete p;
}
inner_pools_.clear();
}
hwy::ThreadPool* main_pool_;
std::unique_ptr<hwy::ThreadPool> outer_pool_; // always allocated
std::vector<Activations> activations_; // size == outer_pool->NumWorkers()
// Either there is a single pointer equal to main_pool, or newly created pools
// that we own. The former case avoids thread creation overhead for prompts
// that fit in a single batch.
std::vector<hwy::ThreadPool*> inner_pools_;
size_t num_batches_ = 0;
};
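The pool sizing in Init can be summarized with a worked example: a 32-worker main pool and a 1000-token interleaved prefill yields ceil(1000 / 64) = 16 batches, 16 outer workers, and 32 / 16 = 2 inner workers each. A rough standalone sketch of that arithmetic (assuming <algorithm> for std::min/std::max; the batch size defaults to kPrefillBatchSize above):

// Sketch only: reproduces the pool-sizing arithmetic from PrefillState::Init.
struct PrefillPoolPlan {
  size_t num_batches;
  size_t outer_workers;      // one per batch, capped by the main pool size
  size_t workers_per_outer;  // inner pool size per outer worker
};

PrefillPoolPlan PlanPrefillPools(size_t prefill_size, size_t main_pool_workers,
                                 size_t prefill_batch_size = 64) {
  PrefillPoolPlan plan;
  plan.num_batches = std::max<size_t>(
      size_t{1}, (prefill_size + prefill_batch_size - 1) / prefill_batch_size);
  plan.outer_workers = std::min(plan.num_batches, main_pool_workers);
  plan.workers_per_outer = main_pool_workers / plan.outer_workers;
  return plan;
}
// Example: PlanPrefillPools(1000, 32) -> {16 batches, 16 outer, 2 inner each}.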
// `tokens` is length `num_tokens * num_queries`. In autoregressive decode,
// `num_tokens == 1`.
template <class TConfig>
HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
size_t num_queries, size_t pos,
const CompressedWeights<TConfig>& weights,
@ -550,9 +685,8 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
const CompressedLayer<TConfig>* layer_weights = weights.GetLayer(layer);
TransformerLayer<TConfig, kQueryBatchSize>(num_tokens, num_queries, pos,
layer, layer_weights,
activations, kv_caches, pool);
TransformerLayer<TConfig>(num_tokens, num_queries, pos, layer,
layer_weights, activations, kv_caches, pool);
if (layers_output) {
const std::string block_name = "blocks." + std::to_string(layer);
@ -610,42 +744,81 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
// Placeholder for internal test3, do not remove
template <class TConfig, size_t kQueryBatchSize>
void GenerateT(const ByteStorageT& weights_u8, Activations& prefill,
Activations& activations, const RuntimeConfig& runtime_config,
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
const size_t query_index_offset,
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
TimingInfo& timing_info) {
constexpr size_t kAdjustedPrefillBatchSize =
std::max((size_t)1, kPrefillBatchSize / kQueryBatchSize);
static_assert(kAdjustedPrefillBatchSize >= kMinAdjustedPrefillBatchSize);
const size_t num_queries = prompts.size();
HWY_DASSERT(num_queries <= kQueryBatchSize);
pos *= num_queries; // position in (num_queries) interleaved token sequence.
const CompressedWeights<TConfig>& weights =
*reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
size_t min_prompt_size = (size_t)-1;
size_t max_prompt_size = 0;
for (int i=0; i < prompts.size(); ++i) {
min_prompt_size = std::min(min_prompt_size, prompts[i].size());
max_prompt_size = std::max(max_prompt_size, prompts[i].size());
// Returns interleaved tokens: one from each query, followed by the second from
// all queries, with EOS padding.
static std::vector<int> InterleaveQueries(
const hwy::Span<const hwy::Span<int>>& queries,
const RuntimeConfig& runtime_config, size_t& min_prompt_size,
size_t& max_prompt_size) {
const size_t num_queries = queries.size();
min_prompt_size = hwy::LimitsMax<size_t>();
max_prompt_size = 0;
for (size_t i = 0; i < num_queries; ++i) {
min_prompt_size = std::min(min_prompt_size, queries[i].size());
max_prompt_size = std::max(max_prompt_size, queries[i].size());
}
std::vector<int> prompt;
prompt.reserve(max_prompt_size * prompts.size());
for (int i = 0; i < max_prompt_size; ++i) {
for (int j=0; j < prompts.size(); ++j) {
if (i < prompts[j].size()) {
prompt.push_back(prompts[j][i]);
prompt.reserve(max_prompt_size * num_queries);
for (size_t pos = 0; pos < max_prompt_size; ++pos) {
for (size_t q = 0; q < num_queries; ++q) {
if (pos < queries[q].size()) {
prompt.push_back(queries[q][pos]);
} else {
prompt.push_back(0);
prompt.push_back(runtime_config.eos_id);
}
}
}
return prompt;
}
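For example, interleaving prompts {1, 2, 3} and {4, 5} with eos_id == 0 yields {1, 4, 2, 5, 3, 0}: position-major order, padded with EOS so every position carries one token per query. An illustrative re-implementation using std::vector instead of hwy::Span:

// Sketch only: same interleaving, simplified container types.
std::vector<int> InterleaveSketch(const std::vector<std::vector<int>>& queries,
                                  int eos_id) {
  size_t max_len = 0;
  for (const auto& q : queries) max_len = std::max(max_len, q.size());
  std::vector<int> out;
  out.reserve(max_len * queries.size());
  for (size_t pos = 0; pos < max_len; ++pos) {
    for (const auto& q : queries) {
      out.push_back(pos < q.size() ? q[pos] : eos_id);
    }
  }
  return out;
}
// InterleaveSketch({{1, 2, 3}, {4, 5}}, /*eos_id=*/0) == {1, 4, 2, 5, 3, 0}.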
// Holds "is at end of stream" state for each query.
class TokenStreamer {
public:
explicit TokenStreamer(const RuntimeConfig& runtime_config)
: runtime_config_(runtime_config) {}
// Returns whether the query was already at, or has just reached, the end of
// the stream: either via token == eos_id, or StreamToken returning false.
bool operator()(size_t query_idx, size_t pos, int token, float prob) {
if (HWY_UNLIKELY(is_eos_.Get(query_idx))) return true;
if (!runtime_config_.StreamToken(query_idx, pos, token, prob) ||
token == runtime_config_.eos_id) {
is_eos_.Set(query_idx);
return true;
}
return false;
}
private:
const RuntimeConfig& runtime_config_;
// BitSet4096 divides the arg by 64, so ensure it is at least 64.
hwy::BitSet4096<HWY_MAX(64, kBatchedQueryBatchSize)> is_eos_;
};
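Intended usage is one call per streamed token per query; generation stops once every query has reached EOS or its callback returned false. A sketch of that aggregation (illustrative helper, not part of this change):

// Sketch only: stream one token per query, return whether all queries are done.
bool StreamBatchSketch(TokenStreamer& streamer, const int* tokens, const float* probs,
                       size_t num_queries, size_t query_idx_start, size_t pos) {
  bool all_eos = true;
  for (size_t q = 0; q < num_queries; ++q) {
    all_eos &= streamer(query_idx_start + q, pos, tokens[q], probs[q]);
  }
  return all_eos;  // the caller breaks out of its generation loop when true
}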
// Generates a continuation for each query in the batch.
//
// pos indexes the KV cache. In the first turn of a chat, pos = 0, and it
// continues to increase by one for each prefilled/generated token per query.
// query_idx_start is the first query index in the batch.
template <class TConfig, size_t kQueryBatchSize>
void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
const RuntimeConfig& runtime_config,
const hwy::Span<const hwy::Span<int>>& prompts, const size_t pos,
const size_t query_idx_start,
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
TimingInfo& timing_info) {
constexpr size_t kVocabSize = TConfig::kVocabSize;
const CompressedWeights<TConfig>& weights =
*reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
const size_t num_queries = prompts.size();
HWY_DASSERT(num_queries <= kQueryBatchSize);
size_t min_prompt_size, max_prompt_size;
const std::vector<int> prompt = InterleaveQueries(
prompts, runtime_config, min_prompt_size, max_prompt_size);
size_t max_tokens = runtime_config.max_tokens;
size_t max_generated_tokens = runtime_config.max_generated_tokens;
@ -666,171 +839,92 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& prefill,
runtime_config.accept_token);
};
std::vector<bool> reached_eos(num_queries);
std::fill(reached_eos.begin(), reached_eos.end(), false);
// pos indexes the KV cache. In the first turn of a chat, pos = 0.
//
// After the first turn, pos gets passed in with > 0 corresponding to the
// current token position in the KV cache.
//
// pos_offset keeps track of the relative position within the turn, starting
// at 0 each turn. During prefill, pos_offset corresponds to the index into
// the prompt vector.
//
// In single-turn (non-chat) usage, pos and pos_offset start at 0 and are
// always equal.
size_t pos_offset = 0; // offset relative to pos
// Used to keep track of how many tokens are processed per prompt,
// so that we know when to start generating tokens.
size_t single_prompt_pos_offset = 0;
// Prefill stops before min_prompt_size - 1 because the last prompt token is
// the first input token for generation.
const size_t prefill_per_query = min_prompt_size - 1;
const hwy::Span<const int> prefill_tokens(prompt.data(),
prefill_per_query * num_queries);
PrefillState prefill(pool);
prefill.Init<TConfig>(prefill_tokens.size());
const double prefill_start = hwy::platform::Now();
size_t interleaved_pos = pos * num_queries;
prefill.Prefill<TConfig>(prefill_tokens, num_queries, interleaved_pos,
weights, runtime_config, kv_caches);
interleaved_pos += prefill_tokens.size();
timing_info.NotifyPrefill(prefill_tokens.size(), prefill_start);
// Prefill stops before prompt_size - 1 since the last prompt token is the
// first input token for generation.
while (single_prompt_pos_offset < min_prompt_size - 1) {
const size_t batch_size = std::min(
kPrefillBatchSize, min_prompt_size - 1 - single_prompt_pos_offset);
const size_t batch_and_query_size = batch_size * num_queries;
HWY_DASSERT(batch_size <= kPrefillBatchSize);
HWY_DASSERT(single_prompt_pos_offset + batch_size <= min_prompt_size - 1);
HWY_DASSERT(pos_offset + batch_size <= (min_prompt_size - 1) * num_queries);
const int* batch_tokens = prompt.data() + pos_offset;
Prefill<TConfig, kAdjustedPrefillBatchSize, kQueryBatchSize>(
batch_tokens, batch_size, num_queries, pos, weights, prefill, kv_caches,
pool);
for (size_t idx = 0; idx < batch_size; ++idx) {
bool all_tokens_eos = true;
prefill.ResumeMainSpinning();
// Storage for the last generated token from each query, passed to the next
// Transformer() call.
std::vector<int> gen_tokens(num_queries);
// Stream the last prompt token from each query and fill gen_tokens.
hwy::CopyBytes(&prompt[prefill_tokens.size()], gen_tokens.data(),
num_queries * sizeof(prompt[0]));
TokenStreamer token_streamer(runtime_config);
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
if (reached_eos[query_idx]) continue;
if (runtime_config.StreamToken(
query_idx + query_index_offset, single_prompt_pos_offset,
batch_tokens[idx * num_queries + query_idx], 0.0f)) {
all_tokens_eos = false;
} else {
reached_eos[query_idx] = true;
}
}
if (all_tokens_eos) {
return;
}
}
pos += batch_and_query_size;
pos_offset += batch_and_query_size;
single_prompt_pos_offset += batch_size;
(void)token_streamer(query_idx_start + query_idx, prefill_per_query,
gen_tokens[query_idx], 0.0f);
}
timing_info.prefill_tok_sec =
static_cast<double>(pos_offset) / (hwy::platform::Now() - prefill_start);
// Start generation.
const double gen_start = hwy::platform::Now();
HWY_DASSERT(single_prompt_pos_offset == min_prompt_size - 1);
size_t pos_gen_start = pos_offset;
int token = prompt.at(pos_offset);
std::vector<int>::const_iterator first = prompt.begin() + pos_offset;
std::vector<int>::const_iterator last = first + num_queries;
std::vector<int> gen_tokens(first, last);
// The loop below is not yet prepared for decode batch size > 1.
HWY_ASSERT(kDecodeBatchSize == 1);
bool all_tokens_eos = true;
for (size_t i=0; i < num_queries; ++i) {
if (reached_eos[i]) continue;
if (runtime_config.StreamToken(i + query_index_offset,
single_prompt_pos_offset, gen_tokens[i],
0.0f)) {
all_tokens_eos = false;
} else {
reached_eos[i] = true;
}
}
if (all_tokens_eos) {
return;
}
for (size_t generate_pos = 0;
generate_pos < max_tokens && generate_pos < max_generated_tokens;
++single_prompt_pos_offset, ++generate_pos) {
Transformer<TConfig, kQueryBatchSize>(
gen_tokens.data(), kDecodeBatchSize, num_queries, pos, weights,
activations, kv_caches, pool, runtime_config.layers_output);
float token_logit = 0.0f;
// The condition below is always true if we are doing Prefill above.
// We keep it here for clarity so that the code is correct even if Prefill
// is disabled.
bool all_tokens_eos = true;
for (size_t i = 0; i < num_queries; ++i, ++pos, ++pos_offset) {
const float* HWY_RESTRICT x = activations.x.Batch(i);
float* HWY_RESTRICT logits = activations.logits.Batch(i);
const size_t prompt_size = prompts[i].size();
const bool is_generating_phase =
(single_prompt_pos_offset >= prompt_size - 1);
if (is_generating_phase) {
for (size_t gen_per_query = 0;
gen_per_query < HWY_MIN(max_tokens, max_generated_tokens);
++gen_per_query) {
// Decode: generate one token for each query.
Transformer<TConfig>(gen_tokens.data(), /*num_tokens=*/1, num_queries,
interleaved_pos, weights, activations, kv_caches, pool,
runtime_config.layers_output);
interleaved_pos += num_queries;
bool all_queries_eos = true;
PROFILER_ZONE("Gen.Embedding");
// Compute logits from last layer activations.
MatVec<kVocabSize, TConfig::kModelDim>(weights.embedder_input_embedding,
0, x, activations.even_odd.All(),
logits, pool);
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
// Compute logits from last layer activations. TODO: MatMul
MatVec<kVocabSize, TConfig::kModelDim>(
weights.embedder_input_embedding, 0, activations.x.Batch(query_idx),
activations.even_odd.All(), logits, pool);
if constexpr (TConfig::kFinalCap > 0.0f) {
LogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize);
}
// Barrier: must have all logits so we can subtract max.
Softmax(logits, kVocabSize);
token = sample_token(logits, kVocabSize);
token_logit = logits[token];
if (generate_pos == 0) {
timing_info.time_to_first_token = hwy::platform::Now() - gen_start;
const int token = sample_token(logits, kVocabSize);
timing_info.NotifyGenerated(prefill_start);
const bool is_eos = token_streamer(query_idx_start + query_idx,
prefill_per_query + 1 + gen_per_query,
token, logits[token]);
all_queries_eos &= is_eos;
gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : token;
}
} else {
// We would take this branch if we were not doing Prefill but would
// process the tokens of the prompt one at a time.
token = prompt.at(pos_offset);
token_logit = 0.0f;
}
if (!reached_eos[i]) {
if (!runtime_config.StreamToken(i + query_index_offset,
single_prompt_pos_offset + 1, token,
token_logit)) {
token = runtime_config.eos_id;
}
if (token != runtime_config.eos_id) {
all_tokens_eos = false;
} else {
reached_eos[i] = true;
}
}
gen_tokens[i] = token;
}
if (all_tokens_eos) {
break;
}
}
timing_info.gen_tok_sec = static_cast<double>(pos_offset - pos_gen_start) /
(hwy::platform::Now() - gen_start);
if (all_queries_eos) break;
} // foreach token to generate
timing_info.NotifyGenerateDone(gen_start);
}
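Per query, the decode step above reduces to: project the last-layer activation onto the vocabulary, optionally soft-cap, softmax, then sample. A scalar sketch of that pipeline, assuming an uncompressed float embedding matrix, a soft-cap of the usual cap * tanh(logit / cap) form, and greedy argmax in place of sample_token:

// Sketch only: logits -> optional soft-cap -> softmax -> greedy sample.
int DecodeOneTokenSketch(const float* x, const float* embedding,  // [vocab][model_dim]
                         size_t vocab_size, size_t model_dim, float final_cap,
                         std::vector<float>& logits) {
  logits.resize(vocab_size);
  for (size_t v = 0; v < vocab_size; ++v) {
    float dot = 0.0f;
    for (size_t d = 0; d < model_dim; ++d) dot += embedding[v * model_dim + d] * x[d];
    logits[v] = (final_cap > 0.0f) ? final_cap * std::tanh(dot / final_cap) : dot;
  }
  // Softmax: subtract the max for numerical stability (the "barrier" above).
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  float sum = 0.0f;
  for (float& l : logits) { l = std::exp(l - max_logit); sum += l; }
  for (float& l : logits) l /= sum;
  // Greedy sample: pick the most probable token.
  return static_cast<int>(std::max_element(logits.begin(), logits.end()) - logits.begin());
}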
// TODO: prompt should also be span, not a vector.
template <class TConfig>
void GenerateSingleT(const ByteStorageT& weights_u8, Activations& prefill,
Activations& activations,
void GenerateSingleT(const ByteStorageT& weights_u8, Activations& activations,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
TimingInfo& timing_info) {
// TODO: the input should also be span, not a vector.
const hwy::Span<int> prompt_span(const_cast<int*>(prompt.data()),
prompt.size());
const hwy::Span<const hwy::Span<int>> prompts(&prompt_span, 1);
// TODO: also span of kv_cache.
// TODO: also span of kv_cache, or batching inside KVCache?
std::vector<KVCache*> kv_caches = {&kv_cache};
const size_t query_index_offset = 0;
const size_t query_idx_start = 0;
GenerateT<TConfig, /*kQueryBatchSize=*/1>(
weights_u8, prefill, activations, runtime_config, prompts, pos,
query_index_offset, kv_caches, pool, timing_info);
weights_u8, activations, runtime_config, prompts, pos, query_idx_start,
kv_caches, pool, timing_info);
}
template <class TConfig>
void GenerateBatchT(const ByteStorageT& weights_u8, Activations& prefill,
Activations& activations,
void GenerateBatchT(const ByteStorageT& weights_u8, Activations& activations,
const RuntimeConfig& runtime_config,
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
const std::vector<KVCache*>& kv_caches,
@ -838,12 +932,14 @@ void GenerateBatchT(const ByteStorageT& weights_u8, Activations& prefill,
// Disable query batching for Griffin models.
constexpr size_t kQueryBatchSize =
(TConfig::kGriffinLayers > 0) ? 1 : kBatchedQueryBatchSize;
for (size_t i = 0; i < prompts.size(); i += kQueryBatchSize) {
const size_t num_queries = std::min(prompts.size() - i, kQueryBatchSize);
const hwy::Span<const hwy::Span<int>> current_prompts(
prompts.data() + i, num_queries);
GenerateT<TConfig, kQueryBatchSize>(weights_u8, prefill, activations,
runtime_config, current_prompts, pos, i,
for (size_t query_idx_start = 0; query_idx_start < prompts.size();
query_idx_start += kQueryBatchSize) {
const size_t num_queries =
std::min(prompts.size() - query_idx_start, kQueryBatchSize);
const hwy::Span<const hwy::Span<int>> query_batch(
prompts.data() + query_idx_start, num_queries);
GenerateT<TConfig, kQueryBatchSize>(weights_u8, activations, runtime_config,
query_batch, pos, query_idx_start,
kv_caches, pool, timing_info);
}
}
@ -855,24 +951,24 @@ void GenerateBatchT(const ByteStorageT& weights_u8, Activations& prefill,
// These are extern functions defined by instantiations/*.cc, which include this
// 'header' after defining GEMMA_CONFIG, which is for function overloading.
void GenerateSingle( // NOLINT(misc-definitions-in-headers)
GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& prefill,
Activations& activations, const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info) {
GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& activations,
const RuntimeConfig& runtime_config, const std::vector<int>& prompt,
size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool,
TimingInfo& timing_info) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_CONFIG>)
(weights_u8, prefill, activations, runtime_config, prompt, pos, kv_cache,
pool, timing_info);
(weights_u8, activations, runtime_config, prompt, pos, kv_cache, pool,
timing_info);
}
void GenerateBatch( // NOLINT(misc-definitions-in-headers)
GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& prefill,
Activations& activations, const RuntimeConfig& runtime_config,
GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& activations,
const RuntimeConfig& runtime_config,
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
TimingInfo& timing_info) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>)
(weights_u8, prefill, activations, runtime_config, prompts, pos, kv_caches,
pool, timing_info);
(weights_u8, activations, runtime_config, prompts, pos, kv_caches, pool,
timing_info);
}
#endif // HWY_ONCE

View File

@ -37,13 +37,11 @@
namespace gcpp {
template <typename TConfig>
struct AllocateState {
void operator()(Activations& prefill, Activations& decode) const {
// When batching queries, the prefill batch size is reduced by a factor
// of kBatchedQueryBatchSize
prefill.Allocate<TConfig>(kMinAdjustedPrefillBatchSize *
kBatchedQueryBatchSize);
decode.Allocate<TConfig>(kDecodeBatchSize * kBatchedQueryBatchSize);
struct AllocateActivations {
void operator()(Activations& decode) const {
// TODO: this is wasted if we only have single-batch queries. Instead
// re-allocate when the query batch size is actually > 1?
decode.Allocate<TConfig>(kBatchedQueryBatchSize);
}
};
@ -51,8 +49,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
const ModelInfo& info, hwy::ThreadPool& pool)
: pool_(pool), tokenizer_(tokenizer_path), info_(info) {
weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool);
CallForModelAndWeight<AllocateState>(info.model, info.weight, prefill_,
decode_);
CallForModelAndWeight<AllocateActivations>(info.model, info.weight, decode_);
}
Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
@ -61,8 +58,7 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
HWY_ASSERT(info.weight == Type::kF32);
weights_u8_ =
CallForModel<float, AllocateCompressedWeights>(info.model, pool);
CallForModelAndWeight<AllocateState>(info.model, info.weight, prefill_,
decode_);
CallForModelAndWeight<AllocateActivations>(info.model, info.weight, decode_);
}
Gemma::~Gemma() {
@ -76,13 +72,13 @@ Gemma::~Gemma() {
// explicit instantiations are still too slow to compile.
#define GEMMA_DECLARE(CONFIGT, TWEIGHT) \
extern void GenerateSingle( \
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, Activations& prefill, \
Activations& decode, const RuntimeConfig& runtime_config, \
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache, \
hwy::ThreadPool& pool, TimingInfo& timing_info); \
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, Activations& decode, \
const RuntimeConfig& runtime_config, const std::vector<int>& prompt, \
size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, \
TimingInfo& timing_info); \
extern void GenerateBatch( \
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, Activations& prefill, \
Activations& decode, const RuntimeConfig& runtime_config, \
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, Activations& decode, \
const RuntimeConfig& runtime_config, \
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos, \
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool, \
TimingInfo& timing_info);
@ -92,24 +88,24 @@ GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE);
// TODO: gather all ByteStorageT into a type-erased model struct?
template <class TConfig>
struct GenerateSingleT {
void operator()(const ByteStorageT& weights_u8, Activations& prefill,
Activations& decode, const RuntimeConfig& runtime_config,
void operator()(const ByteStorageT& weights_u8, Activations& decode,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info) const {
GenerateSingle(TConfig(), weights_u8, prefill, decode, runtime_config,
prompt, pos, kv_cache, pool, timing_info);
GenerateSingle(TConfig(), weights_u8, decode, runtime_config, prompt, pos,
kv_cache, pool, timing_info);
}
};
template <class TConfig>
struct GenerateBatchT {
void operator()(const ByteStorageT& weights_u8, Activations& prefill,
Activations& decode, const RuntimeConfig& runtime_config,
void operator()(const ByteStorageT& weights_u8, Activations& decode,
const RuntimeConfig& runtime_config,
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
TimingInfo& timing_info) const {
GenerateBatch(TConfig(), weights_u8, prefill, decode, runtime_config,
prompts, pos, kv_caches, pool, timing_info);
GenerateBatch(TConfig(), weights_u8, decode, runtime_config, prompts, pos,
kv_caches, pool, timing_info);
}
};
@ -119,8 +115,8 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
CallForModelAndWeight<GenerateSingleT>(
info_.model, info_.weight, weights_u8_, prefill_, decode_, runtime_config,
prompt, start_pos, kv_cache, pool_, timing_info);
info_.model, info_.weight, weights_u8_, decode_, runtime_config, prompt,
start_pos, kv_cache, pool_, timing_info);
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
}
@ -133,8 +129,8 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
CallForModelAndWeight<GenerateBatchT>(
info_.model, info_.weight, weights_u8_, prefill_, decode_, runtime_config,
prompts, start_pos, kv_caches, pool_, timing_info);
info_.model, info_.weight, weights_u8_, decode_, runtime_config, prompts,
start_pos, kv_caches, pool_, timing_info);
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
}

View File

@ -28,6 +28,7 @@
#include "gemma/kv_cache.h"
#include "gemma/tokenizer.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/timer.h"
// IWYU pragma: end_exports
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::bfloat16_t
@ -78,9 +79,30 @@ struct RuntimeConfig {
};
struct TimingInfo {
double prefill_tok_sec = 0.0;
double gen_tok_sec = 0.0;
double time_to_first_token = 0.0;
void NotifyPrefill(size_t tokens, double start) {
prefill_tok_sec =
static_cast<double>(tokens) / (hwy::platform::Now() - start);
gen_tok_sec = 0.0;
time_to_first_token = 0.0;
tokens_generated = 0;
}
void NotifyGenerated(double prefill_start) {
++tokens_generated;
if (HWY_UNLIKELY(tokens_generated == 1)) {
time_to_first_token = hwy::platform::Now() - prefill_start;
}
}
void NotifyGenerateDone(double gen_start) {
gen_tok_sec = static_cast<double>(tokens_generated) /
(hwy::platform::Now() - gen_start);
}
double prefill_tok_sec;
double gen_tok_sec;
double time_to_first_token;
size_t tokens_generated;
};
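Taken together, these hooks are driven as follows: TTFT is measured from the start of prefill (the fix mentioned in the commit message), and generation throughput from gen_start. A sketch of the intended call sequence (dummy workload, illustrative only):

// Sketch only: intended call order for the Notify* hooks.
TimingInfo MeasureSketch(size_t num_prefill_tokens, size_t num_gen_tokens) {
  TimingInfo timing;
  const double prefill_start = hwy::platform::Now();
  // ... prefill num_prefill_tokens interleaved tokens here ...
  timing.NotifyPrefill(num_prefill_tokens, prefill_start);  // sets prefill_tok_sec

  const double gen_start = hwy::platform::Now();
  for (size_t i = 0; i < num_gen_tokens; ++i) {
    // ... one Transformer step plus sampling here ...
    timing.NotifyGenerated(prefill_start);  // first call records TTFT from prefill_start
  }
  timing.NotifyGenerateDone(gen_start);  // derives gen_tok_sec from tokens_generated
  return timing;
}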
class Gemma {
@ -96,7 +118,6 @@ class Gemma {
const ModelInfo& Info() const { return info_; }
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
const ByteStorageT& Weights() const { return weights_u8_; }
const Activations& Prefill() const { return prefill_; }
const Activations& Decode() const { return decode_; }
void Generate(const RuntimeConfig& runtime_config,
@ -115,7 +136,6 @@ class Gemma {
// Type-erased so that this can be defined in the header, without requiring
// forwarding functions.
ByteStorageT weights_u8_;
Activations prefill_;
Activations decode_;
ModelInfo info_;
};

View File

@ -97,9 +97,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
} else if (token == EOS_ID) {
if (!args.multiturn) {
abs_pos = 0;
if (args.deterministic) {
gen.seed(42);
}
InitGenerator(args, gen);
}
if (verbosity >= 2) {
std::cout << "\n[ End ]\n";