Major revamp #2 of Prefill: fix token order, parallelize across queries

- Allocate only the required number of KV caches, and size activations to the actual batch size
- Add flags for the batch sizes
- Const-correct interface: Span of const int.
- Also change the KVCache argument to a span.
- Move kPrefillBatchSize into RuntimeConfig and remove related global constants.

PiperOrigin-RevId: 655893197
Jan Wassenberg 2024-07-25 03:28:10 -07:00 committed by Copybara-Service
parent c1f243c351
commit aaf51898b6
15 changed files with 445 additions and 327 deletions

View File

@ -224,6 +224,7 @@ cc_library(
":common", ":common",
":cross_entropy", ":cross_entropy",
":gemma_lib", ":gemma_lib",
":kv_cache",
# Placeholder for internal dep, do not remove., # Placeholder for internal dep, do not remove.,
"@benchmark//:benchmark", "@benchmark//:benchmark",
"//compression:compress", "//compression:compress",

View File

@ -52,7 +52,7 @@ TEST(OptimizeTest, GradientDescent) {
CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight); CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
ByteStorageT backward = ByteStorageT backward =
CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight); CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
KVCache kv_cache = KVCache::Create(info.model); KVCache kv_cache = KVCache::Create(info.model, /*prefill_tbatch_size=*/16);
Gemma gemma(GemmaTokenizer(), info, pool); Gemma gemma(GemmaTokenizer(), info, pool);

View File

@ -128,7 +128,8 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens); size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
std::vector<int> prompt_slice(prompt.begin() + pos, std::vector<int> prompt_slice(prompt.begin() + pos,
prompt.begin() + pos + num_tokens); prompt.begin() + pos + num_tokens);
KVCache kv_cache = KVCache::Create(env.Info().model); KVCache kv_cache = KVCache::Create(
env.Info().model, env.MutableInferenceArgs().prefill_tbatch_size);
float entropy = ComputeCrossEntropy( float entropy = ComputeCrossEntropy(
*env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity()); *env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
total_entropy += entropy; total_entropy += entropy;
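
The KV cache is now sized from the prefill token batch size, so KVCache::Create takes it as a second argument. A minimal sketch of allocating one cache per query, assuming `info`, `inference`, and `num_queries` are in scope; the sizing policy inside Create is not visible in this hunk:

    // One KVCache per query; Create presumably reserves room for a full
    // prefill token batch (assumption, not shown in this diff).
    std::vector<gcpp::KVCache> caches(num_queries);
    for (size_t i = 0; i < num_queries; ++i) {
      caches[i] = gcpp::KVCache::Create(info.model,
                                        inference.prefill_tbatch_size);
    }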

View File

@ -34,9 +34,9 @@
#include "evals/cross_entropy.h" #include "evals/cross_entropy.h"
#include "gemma/common.h" // StringFromType #include "gemma/common.h" // StringFromType
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "gemma/kv_cache.h"
#include "util/app.h" #include "util/app.h"
#include "util/args.h" #include "util/args.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h" #include "hwy/highway.h"
@ -76,10 +76,10 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
fprintf(stderr, "Loading model...\n"); fprintf(stderr, "Loading model...\n");
model_ = AllocateGemma(loader_, pool_); model_ = AllocateGemma(loader_, pool_);
kv_caches_.reserve(kBatchedQueryBatchSize); // Only allocate one for starters because GenerateBatch might not be called.
for (int i = 0; i < kBatchedQueryBatchSize; ++i) { kv_caches_.resize(1);
kv_caches_.push_back(new KVCache(KVCache::Create(model_->Info().model))); kv_caches_[0] =
} KVCache::Create(model_->Info().model, inference.prefill_tbatch_size);
} }
InitGenerator(inference_args_, gen_); InitGenerator(inference_args_, gen_);
@ -132,7 +132,7 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
} }
gcpp::TimingInfo timing_info; gcpp::TimingInfo timing_info;
runtime_config_.batch_stream_token = batch_stream_token; runtime_config_.batch_stream_token = batch_stream_token;
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, *kv_caches_[0], model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
timing_info); timing_info);
if (app_.verbosity >= 1) { if (app_.verbosity >= 1) {
LogSpeedStats(time_start, total_tokens); LogSpeedStats(time_start, total_tokens);
@ -141,8 +141,10 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
} }
std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2( std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
const hwy::Span<const hwy::Span<int>>& prompts) { const MultiplePromptsTokens& prompts) {
std::vector<std::pair<std::string, size_t>> res(prompts.size()); const size_t num_queries = prompts.size();
HWY_ASSERT(num_queries != 0);
std::vector<std::pair<std::string, size_t>> res(num_queries);
std::fill(res.begin(), res.end(), std::make_pair("", 0)); std::fill(res.begin(), res.end(), std::make_pair("", 0));
size_t total_tokens = 0; size_t total_tokens = 0;
@ -162,14 +164,29 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
return true; return true;
}; };
if (app_.verbosity >= 2) { if (app_.verbosity >= 2) {
std::cout << inference_args_.max_tokens << " " fprintf(stderr,
<< inference_args_.max_generated_tokens << " " "Max tok: %zu max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
<< inference_args_.temperature; inference_args_.max_tokens, inference_args_.max_generated_tokens,
inference_args_.temperature, inference_args_.prefill_tbatch_size,
inference_args_.decode_qbatch_size);
} }
// Ensure we have one KVCache per query.
if (kv_caches_.size() < num_queries) {
kv_caches_.resize(num_queries);
}
for (size_t i = 1; i < num_queries; ++i) {
if (kv_caches_[i].seq_len == 0) {
kv_caches_[i] = KVCache::Create(model_->Info().model,
inference_args_.prefill_tbatch_size);
}
}
gcpp::TimingInfo timing_info; gcpp::TimingInfo timing_info;
runtime_config_.batch_stream_token = batch_stream_token; runtime_config_.batch_stream_token = batch_stream_token;
model_->GenerateBatch(runtime_config_, prompts, /*start_pos=*/0, kv_caches_, inference_args_.CopyTo(runtime_config_);
timing_info); model_->GenerateBatch(runtime_config_, prompts, /*start_pos=*/0,
KVCaches(&kv_caches_[0], num_queries), timing_info);
if (app_.verbosity >= 1) { if (app_.verbosity >= 1) {
LogSpeedStats(time_start, total_tokens); LogSpeedStats(time_start, total_tokens);
} }
@ -191,13 +208,12 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
prompts.push_back(WrapAndTokenize(model_->Tokenizer(), model_->Info(), prompts.push_back(WrapAndTokenize(model_->Tokenizer(), model_->Info(),
/*pos=*/0, mutable_prompt)); /*pos=*/0, mutable_prompt));
} }
std::vector<hwy::Span<int>> prompt_vector; std::vector<PromptTokens> prompt_vector;
prompt_vector.reserve(prompts.size()); prompt_vector.reserve(prompts.size());
for (auto& prompt : prompts) { for (auto& prompt : prompts) {
prompt_vector.push_back(hwy::Span<int>(prompt.data(), prompt.size())); prompt_vector.push_back(PromptTokens(prompt.data(), prompt.size()));
} }
hwy::Span<const hwy::Span<int>> prompt_span = hwy::Span<const hwy::Span<int>>( MultiplePromptsTokens prompt_span(prompt_vector.data(), prompt_vector.size());
prompt_vector.data(), prompt_vector.size());
return BatchQueryModel2(prompt_span); return BatchQueryModel2(prompt_span);
} }
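
PromptTokens, MultiplePromptsTokens, and KVCaches appear to be span aliases (the commit message says "Span of const int"), so callers wrap existing storage rather than copying. A sketch of wrapping tokenized prompts, assuming the aliases live in namespace gcpp:

    std::vector<std::vector<int>> prompts;  // tokenized queries, filled elsewhere
    std::vector<gcpp::PromptTokens> spans;
    spans.reserve(prompts.size());
    for (const std::vector<int>& p : prompts) {
      spans.push_back(gcpp::PromptTokens(p.data(), p.size()));
    }
    const gcpp::MultiplePromptsTokens all(spans.data(), spans.size());
    // `spans` and `prompts` must outlive the BatchQueryModel2/GenerateBatch call.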
@ -226,8 +242,8 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
if (app.verbosity >= 2) { if (app.verbosity >= 2) {
time_t now = time(nullptr); time_t now = time(nullptr);
char* dt = ctime(&now); // NOLINT char* dt = ctime(&now); // NOLINT
// TODO: replace hardware_concurrency with detected topology.
std::cout << "Date & Time : " << dt std::cout << "Date & Time : " << dt
<< "Prefill Token Batch Size : " << kPrefillBatchSize << "\n"
<< "Hardware concurrency : " << "Hardware concurrency : "
<< std::thread::hardware_concurrency() << "\n" << std::thread::hardware_concurrency() << "\n"
<< "Instruction set : " << "Instruction set : "

View File

@ -69,7 +69,7 @@ class GemmaEnv {
// the number of tokens that were generated. // the number of tokens that were generated.
std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens); std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
std::vector<std::pair<std::string, size_t>> BatchQueryModel2( std::vector<std::pair<std::string, size_t>> BatchQueryModel2(
const hwy::Span<const hwy::Span<int>>& prompts); const MultiplePromptsTokens& prompts);
// Adds turn structure to input, tokenizes and calls the above overload. // Adds turn structure to input, tokenizes and calls the above overload.
std::pair<std::string, size_t> QueryModel(std::string& input); std::pair<std::string, size_t> QueryModel(std::string& input);
std::vector<std::pair<std::string, size_t>> BatchQueryModel( std::vector<std::pair<std::string, size_t>> BatchQueryModel(
@ -88,7 +88,7 @@ class GemmaEnv {
const ModelInfo& Info() const { return loader_.Info(); } const ModelInfo& Info() const { return loader_.Info(); }
InferenceArgs& MutableInferenceArgs() { return inference_args_; } InferenceArgs& MutableInferenceArgs() { return inference_args_; }
std::mt19937& MutableGen() { return gen_; } std::mt19937& MutableGen() { return gen_; }
KVCache& MutableKVCache() { return *kv_caches_[0]; } KVCache& MutableKVCache() { return kv_caches_[0]; }
private: private:
// Arguments to the model loader: file locations, etc. // Arguments to the model loader: file locations, etc.
@ -103,8 +103,8 @@ class GemmaEnv {
std::mt19937 gen_; std::mt19937 gen_;
// The model to run inference on. // The model to run inference on.
std::unique_ptr<Gemma> model_; std::unique_ptr<Gemma> model_;
// The KV cache to use for inference. // KV caches, same number as query batch.
std::vector<KVCache*> kv_caches_; std::vector<KVCache> kv_caches_;
RuntimeConfig runtime_config_; RuntimeConfig runtime_config_;
}; };

View File

@ -17,14 +17,12 @@
#include <stdio.h> #include <stdio.h>
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "evals/benchmark_helper.h" #include "evals/benchmark_helper.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/tokenizer.h" #include "hwy/base.h"
#include "hwy/aligned_allocator.h"
#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/hwy_gtest.h"
// This test can be run manually with the downloaded gemma weights. // This test can be run manually with the downloaded gemma weights.
@ -75,21 +73,17 @@ class GemmaTest : public ::testing::Test {
replies.push_back(response); replies.push_back(response);
} }
} else { // Not Gemma-2 27B. Do not use turn structure. } else { // Not Gemma-2 27B. Do not use turn structure.
std::vector<std::unique_ptr<std::vector<int>>> prompts; std::vector<std::vector<int>> prompts_vector;
prompts.reserve(inputs.size()); prompts_vector.reserve(inputs.size());
for (auto input_string : inputs) { for (const auto& input_string : inputs) {
std::string mutable_input_string = input_string; prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
prompts.push_back(std::make_unique<std::vector<int>>(
s_env->TokenizeAndPrependBOS(input_string)));
} }
std::vector<hwy::Span<int>> prompt_vector; std::vector<PromptTokens> prompt_spans;
for (auto& prompt : prompts) { for (const auto& prompt : prompts_vector) {
prompt_vector.push_back(hwy::Span<int>(prompt->data(), prompt->size())); prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
} }
hwy::Span<const hwy::Span<int>> prompt_span = MultiplePromptsTokens prompts(prompt_spans.data(), prompt_spans.size());
hwy::Span<const hwy::Span<int>>(prompt_vector.data(), for (auto [response, n] : s_env->BatchQueryModel2(prompts)) {
prompt_vector.size());
for (auto [response, n] : s_env->BatchQueryModel2(prompt_span)) {
replies.push_back(response); replies.push_back(response);
} }
} }
@ -121,18 +115,20 @@ class GemmaTest : public ::testing::Test {
} }
}; };
TEST_F(GemmaTest, Geography) { TEST_F(GemmaTest, GeographyBatched) {
s_env->MutableInferenceArgs().decode_qbatch_size = 3;
// 6 are enough to test batching and the loop.
static const char* kQA[][2] = { static const char* kQA[][2] = {
{"What is the capital of Hungary?", "Budapest"},
{"What is the capital of Australia?", "Canberra"}, {"What is the capital of Australia?", "Canberra"},
{"What is the capital of Denmark?", "Copenhagen"},
{"Ljubljana is the capital of which country?", "Slovenia"},
{"Is Chicago a country?", "not"},
{"How many states does the US have?", "50"}, {"How many states does the US have?", "50"},
{"What is the Pacific?", "ocean"},
}; };
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]); static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
TestQuestions(kQA, kNum, /*batch=*/false); TestQuestions(kQA, HWY_MIN(kNum, 3), /*batch=*/false);
static const char* kQA_single_question[][2] = { TestQuestions(kQA, 1, /*batch=*/true);
{"What is the capital of Australia?", "Canberra"},
};
TestQuestions(kQA_single_question, 1, /*batch=*/true);
TestQuestions(kQA, kNum, /*batch=*/true); TestQuestions(kQA, kNum, /*batch=*/true);
} }

View File

@ -31,6 +31,7 @@
int main(int argc, char** argv) { int main(int argc, char** argv) {
gcpp::LoaderArgs loader(argc, argv); gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
if (gcpp::HasHelp(argc, argv)) { if (gcpp::HasHelp(argc, argv)) {
loader.Help(); loader.Help();
return 0; return 0;
@ -42,7 +43,8 @@ int main(int argc, char** argv) {
// Instantiate model and KV Cache // Instantiate model and KV Cache
hwy::ThreadPool pool(gcpp::AppArgs::GetSupportedThreadCount()); hwy::ThreadPool pool(gcpp::AppArgs::GetSupportedThreadCount());
gcpp::Gemma model = gcpp::CreateGemma(loader, pool); gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.Info().model); gcpp::KVCache kv_cache =
gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size);
size_t pos = 0; // KV Cache position size_t pos = 0; // KV Cache position
// Initialize random number generator // Initialize random number generator
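
run.cc now parses InferenceArgs as well, so the batch sizes reach both the cache allocation and the RuntimeConfig. A sketch of that flow; the flag spellings are assumed to match the InferenceArgs member names:

    gcpp::LoaderArgs loader(argc, argv);
    gcpp::InferenceArgs inference(argc, argv);  // e.g. --prefill_tbatch_size (assumed flag name)
    hwy::ThreadPool pool(gcpp::AppArgs::GetSupportedThreadCount());
    gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
    gcpp::KVCache kv_cache = gcpp::KVCache::Create(
        loader.Info().model, inference.prefill_tbatch_size);
    gcpp::RuntimeConfig runtime_config;  // other fields still need to be set
    inference.CopyTo(runtime_config);    // overrides the prefill_tbatch_size default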

View File

@ -36,11 +36,6 @@ ByteStorageT AllocateSizeof() {
return hwy::AllocateAligned<uint8_t>(sizeof(T)); return hwy::AllocateAligned<uint8_t>(sizeof(T));
} }
// Relatively small so that we can also parallelize non-Matmul work. There is
// one outer thread per batch, each with --num_threads / batches inner threads.
constexpr size_t kPrefillBatchSize = 64;
constexpr size_t kBatchedQueryBatchSize = 16;
// Model variants: see configs.h for details. When adding a new one, also // Model variants: see configs.h for details. When adding a new one, also
// update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc. // update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc.
enum class Model { enum class Model {

View File

@ -73,10 +73,10 @@ template <class TConfig>
HWY_NOINLINE void GriffinRecurrent( HWY_NOINLINE void GriffinRecurrent(
size_t batch_start, size_t num_tokens, size_t num_queries, size_t layer, size_t batch_start, size_t num_tokens, size_t num_queries, size_t layer,
Activations& activations, const CompressedLayer<TConfig>* layer_weights, Activations& activations, const CompressedLayer<TConfig>* layer_weights,
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) { const KVCaches& kv_caches, hwy::ThreadPool& pool) {
PROFILER_ZONE("Gen.Griffin"); PROFILER_ZONE("Gen.Griffin");
HWY_ASSERT(num_queries == 1); // TODO: add batch query support for Griffin. HWY_ASSERT(num_queries == 1); // TODO: add batch query support for Griffin.
KVCache& kv_cache = *kv_caches[0]; KVCache& kv_cache = kv_caches[0];
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>; using D = hn::ScalableTag<float>;
static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kModelDim = TConfig::kModelDim;
@ -208,7 +208,7 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
size_t num_queries, size_t layer, size_t num_queries, size_t layer,
Activations& activations, Activations& activations,
const CompressedLayer<TConfig>* layer_weights, const CompressedLayer<TConfig>* layer_weights,
const std::vector<KVCache*>& kv_caches, const KVCaches& kv_caches,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
PROFILER_ZONE("Gen.Attention"); PROFILER_ZONE("Gen.Attention");
HWY_DASSERT(interleaved_start % num_queries == 0); HWY_DASSERT(interleaved_start % num_queries == 0);
@ -221,6 +221,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
constexpr size_t kKVHeads = TConfig::kKVHeads; constexpr size_t kKVHeads = TConfig::kKVHeads;
constexpr size_t kSeqLen = TConfig::kSeqLen; constexpr size_t kSeqLen = TConfig::kSeqLen;
GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale<TConfig>(); GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale<TConfig>();
HWY_ASSERT(num_queries <= kv_caches.size());
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
// Multi-Head Attention a.k.a. "use_qkv_einsum". // Multi-Head Attention a.k.a. "use_qkv_einsum".
constexpr bool kIsMHA = Activations::IsMHA<TConfig>(); constexpr bool kIsMHA = Activations::IsMHA<TConfig>();
static_assert(!kIsMHA || TConfig::kInterleaveQKV); // MHA => interleaved static_assert(!kIsMHA || TConfig::kInterleaveQKV); // MHA => interleaved
@ -245,9 +249,9 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
const float* x = activations.pre_att_rms_out.Batch(interleaved_idx); const float* x = activations.pre_att_rms_out.Batch(interleaved_idx);
const size_t query_idx = interleaved_idx % num_queries; const size_t query_idx = interleaved_idx % num_queries;
const size_t batch_idx = interleaved_idx / num_queries; const size_t batch_idx = interleaved_idx / num_queries;
KVCache& kv_cache = *kv_caches[query_idx]; KVCache& kv_cache = kv_caches[query_idx];
const size_t pos = batch_start + batch_idx; const size_t pos = batch_start + batch_idx;
const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize); const size_t cache_pos = div_seq_len.Remainder(pos);
const size_t kv_offset = const size_t kv_offset =
cache_pos * kCachePosSize + layer * kCacheLayerSize; cache_pos * kCachePosSize + layer * kCacheLayerSize;
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
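
Cache positions are now taken modulo the runtime cache length via hwy::Divisor, replacing the compile-time `pos % (kSeqLen + kPrefillBatchSize)`, so caches sized by prefill_tbatch_size index correctly. The pattern, assuming `pos`, `layer`, and the per-layer constants from the surrounding code:

    // Precomputed once per call; Remainder() avoids a hardware divide per token.
    const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
    const size_t cache_pos = div_seq_len.Remainder(static_cast<uint32_t>(pos));
    const size_t kv_offset = cache_pos * kCachePosSize + layer * kCacheLayerSize;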
@ -268,10 +272,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
const size_t query_idx = interleaved_idx % num_queries; const size_t query_idx = interleaved_idx % num_queries;
const size_t batch_idx = interleaved_idx / num_queries; const size_t batch_idx = interleaved_idx / num_queries;
const size_t pos = batch_start + batch_idx; const size_t pos = batch_start + batch_idx;
const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize); const size_t cache_pos = div_seq_len.Remainder(pos);
const size_t kv_offset = cache_pos * kCachePosSize + const size_t kv_offset = cache_pos * kCachePosSize +
layer * kCacheLayerSize + head * kQKVDim * 2; layer * kCacheLayerSize + head * kQKVDim * 2;
KVCache& kv_cache = *kv_caches[query_idx]; KVCache& kv_cache = kv_caches[query_idx];
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
if constexpr (kIsMHA) { if constexpr (kIsMHA) {
// For MHA, copy KV into the KV cache from scratch space (see above). // For MHA, copy KV into the KV cache from scratch space (see above).
@ -297,7 +301,7 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
const size_t query_idx = interleaved_idx % num_queries; const size_t query_idx = interleaved_idx % num_queries;
const size_t batch_idx = interleaved_idx / num_queries; const size_t batch_idx = interleaved_idx / num_queries;
const size_t head_offset = (head / kHeadGroups) * kQKVDim * 2; const size_t head_offset = (head / kHeadGroups) * kQKVDim * 2;
KVCache& kv_cache = *kv_caches[query_idx]; KVCache& kv_cache = kv_caches[query_idx];
float* HWY_RESTRICT q = float* HWY_RESTRICT q =
activations.q.Batch(interleaved_idx) + head * kQStride; activations.q.Batch(interleaved_idx) + head * kQStride;
@ -314,10 +318,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
const size_t start_pos = const size_t start_pos =
pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos); pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize); const size_t cache_pos = div_seq_len.Remainder(pos2);
const size_t kv_offset = const size_t kv_offset =
cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset; cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
const float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset; const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
const float score = Dot(q, k, kQKVDim); const float score = Dot(q, k, kQKVDim);
head_att[pos2 % kSeqLen] = score; head_att[pos2 % kSeqLen] = score;
} }
@ -337,7 +341,7 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
activations.att_out.Batch(interleaved_idx) + head * kQKVDim; activations.att_out.Batch(interleaved_idx) + head * kQKVDim;
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize); const size_t cache_pos = div_seq_len.Remainder(pos2);
const size_t kv_offset = const size_t kv_offset =
cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset; cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
float* HWY_RESTRICT v = float* HWY_RESTRICT v =
@ -383,8 +387,7 @@ HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start,
size_t num_tokens, size_t num_queries, size_t layer, size_t num_tokens, size_t num_queries, size_t layer,
Activations& activations, Activations& activations,
const CompressedLayer<TConfig>* layer_weights, const CompressedLayer<TConfig>* layer_weights,
const std::vector<KVCache*>& kv_caches, const KVCaches& kv_caches, hwy::ThreadPool& pool) {
hwy::ThreadPool& pool) {
if (type == LayerAttentionType::kGemma) { if (type == LayerAttentionType::kGemma) {
GemmaAttention<TConfig>(interleaved_start, num_tokens, num_queries, layer, GemmaAttention<TConfig>(interleaved_start, num_tokens, num_queries, layer,
activations, layer_weights, kv_caches, pool); activations, layer_weights, kv_caches, pool);
@ -458,12 +461,13 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
output_bias, pool); output_bias, pool);
} }
// TODO: pass Activations.x instead of Activations. // `batch_idx` indicates which row of `x` to write to.
// `pos` is for the entire batch and does not include `batch_idx`. // `pos` is the *token*'s position, not the start of the batch, because this is
// called for batches of tokens in prefill, but batches of queries in decode.
template <class TConfig> template <class TConfig>
HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos, HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
const CompressedWeights<TConfig>& weights, const CompressedWeights<TConfig>& weights,
Activations& activations) { RowVectorBatch<float>& x) {
constexpr size_t kModelDim = TConfig::kModelDim; constexpr size_t kModelDim = TConfig::kModelDim;
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling = GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
EmbeddingScaling<TConfig>(); EmbeddingScaling<TConfig>();
@ -472,11 +476,10 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
HWY_DASSERT(token < TConfig::kVocabSize); HWY_DASSERT(token < TConfig::kVocabSize);
Decompress(weights.embedder_input_embedding, token * kModelDim, Decompress(weights.embedder_input_embedding, token * kModelDim,
activations.x.Batch(batch_idx), kModelDim); x.Batch(batch_idx), kModelDim);
MulByConst(kEmbScaling, activations.x.Batch(batch_idx), kModelDim); MulByConst(kEmbScaling, x.Batch(batch_idx), kModelDim);
if constexpr (TConfig::kAbsolutePE) { if constexpr (TConfig::kAbsolutePE) {
AddAbsolutePositionalEmbeddings(activations.x.Batch(batch_idx), kModelDim, AddAbsolutePositionalEmbeddings(x.Batch(batch_idx), kModelDim, pos);
pos + batch_idx);
}; };
} }
@ -501,7 +504,7 @@ template <class TConfig>
HWY_NOINLINE void TransformerLayer( HWY_NOINLINE void TransformerLayer(
size_t num_tokens, size_t num_queries, size_t pos, size_t layer, size_t num_tokens, size_t num_queries, size_t pos, size_t layer,
const CompressedLayer<TConfig>* layer_weights, Activations& activations, const CompressedLayer<TConfig>* layer_weights, Activations& activations,
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool) { const KVCaches& kv_caches, hwy::ThreadPool& pool) {
constexpr size_t kModelDim = TConfig::kModelDim; constexpr size_t kModelDim = TConfig::kModelDim;
const size_t num_interleaved = num_tokens * num_queries; const size_t num_interleaved = num_tokens * num_queries;
auto type = TConfig::kLayerConfig[layer]; auto type = TConfig::kLayerConfig[layer];
@ -536,116 +539,220 @@ HWY_NOINLINE void TransformerLayer(
/*is_attention=*/false); /*is_attention=*/false);
} }
// For prefill, we have two-level parallelism: // Batches are important for amortizing loading weights over multiple tokens.
// - Outer: input tokens are split into batches, each of which is one task // This is possible in prefill because we know all tokens beforehand, whereas
// processed by a worker in `outer_pool_`, which includes the main thread // decode depends on the previous output token. However, each prefill batch of a
// because it is the one that calls `Prefill`. // query requires that preceding batches already wrote to the KV cache, hence we
// sequentially loop over token batches. We can reduce the number of iterations
// by increasing the batch size, but this also increases arithmetic intensity,
// and so we are eventually compute-limited. The tensor parallelism (number of
// threads collaborating on MatMul) is also limited by the CPU topology:
// fork/join barriers are slow(er) when some threads reside in a different NUMA
// node. To allow more threads to help, we also support parallelizing over
// queries in case GenerateBatch was called.
//
// Thus we have two-level parallelism:
// - Outer: handles one 'qbatch' of entire queries. The set of outer workers
// includes the main thread because it is the one that calls `Prefill`, and is
// determined by the number of 'clusters' (shared L3 caches or sockets).
// - Inner: each `outer` worker passes `inner_pools_[outer]` to // - Inner: each `outer` worker passes `inner_pools_[outer]` to
// `TransformerLayer` for tensor-level parallelism. // `TransformerLayer` for tensor-level parallelism, and processes
// `tbatch_size` tokens from a single query at a time.
// //
// This class holds the thread pools and activations, recreated for each query. // This class holds the thread pools and one activation per outer worker. It is
// // NOT reused across calls to GenerateSingle/GenerateBatch so that we can adapt
// It is safe to parallelize batches because we write to KVCache at // to their num_queries.
// `pos % kSeqLen`, which is far greater than the number of workers.
// Note however that this currently leads to nondeterministic results because
// the RNG is invoked in different order.
class PrefillState { class PrefillState {
public: // TODO: move helper functions, also those in app.h, to a threading header
explicit PrefillState(hwy::ThreadPool& main_pool) : main_pool_(&main_pool) {} using LPS = hwy::LogicalProcessorSet;
LPS Intersection(const LPS& big, const LPS& small) {
~PrefillState() { DeleteInnerPools(); } LPS both;
// Reduce expected work by iterating over the smaller set.
// Called before each query. Recreates thread pools, which has the advantage small.Foreach([big, &both](size_t idx) {
// of tailoring the parallelism to the prompt length. if (big.Get(idx)) both.Set(idx);
template <class TConfig>
void Init(size_t prefill_size) {
// Would be zero for single-token prompts (prefill_size == num_tokens - 1).
num_batches_ =
HWY_MAX(size_t{1}, hwy::DivCeil(prefill_size, kPrefillBatchSize));
// More than num_batches_ would waste workers on idling in the outer Run;
// more than NumWorkers() would exceed the global --num_threads.
const size_t outer_workers =
HWY_MIN(num_batches_, main_pool_->NumWorkers());
HWY_ASSERT(outer_workers != 0); // Otherwise activations_ is empty.
// One activation per outer worker. Allocating in parallel saves 30 ms.
activations_.resize(outer_workers);
main_pool_->Run(0, outer_workers, [this](uint64_t task, size_t /*thread*/) {
activations_[task].Allocate<TConfig>(kPrefillBatchSize);
}); });
return both;
}
DeleteInnerPools(); std::vector<size_t> CoresInLPS(const LPS& cluster) {
std::vector<size_t> cores;
cores.reserve(cluster.Count());
cluster.Foreach([&cores](size_t idx) { cores.push_back(idx); });
return cores;
}
// If we'd create just one inner pool with all the workers, skip the cost of // For each cluster (shared L3 cache), a bitset of cores.
// thread creation and pinning (about 60 ms) by reusing the main pool. using CoresPerCluster = std::vector<LPS>;
if (outer_workers <= 1) {
// Still allocate a dummy pool to simplify Prefill(). // Returns empty if detection failed.
outer_pool_ = std::make_unique<hwy::ThreadPool>(1); CoresPerCluster DetectClusters() {
inner_pools_.push_back(main_pool_); CoresPerCluster clusters;
return; // Which processors are not disabled via OS, taskset, or numactl.
LPS enabled;
// If we don't know, better to use just a single inner pool rather than risk
// oversubscribing to enabled cores.
if (!GetThreadAffinity(enabled)) return clusters;
hwy::Topology topology;
if (topology.packages.empty()) return clusters;
// For each cluster = outer, the cores that will be used for an inner pool.
CoresPerCluster inner_lps;
for (const hwy::Topology::Package& package : topology.packages) {
for (const hwy::Topology::Cluster& cluster : package.clusters) {
// Only use enabled cores, and only add if not empty.
const LPS lps = Intersection(enabled, cluster.lps);
if (lps.Any()) clusters.push_back(lps);
}
} }
// Sort by descending number of enabled cores, so that we preferentially
// use the largest clusters.
std::sort(clusters.begin(), clusters.end(),
[](const LPS& a, const LPS& b) { return a.Count() > b.Count(); });
return clusters;
}
// Returns false if the main pool should be reused instead.
bool AssignInnerPoolsToClusters(const size_t num_queries) {
HWY_ASSERT(num_queries != 0);
CoresPerCluster inner_lps = DetectClusters();
// If we have more outer workers than queries, discard the excess.
if (inner_lps.size() > num_queries) inner_lps.resize(num_queries);
// If we're not going to create multiple pools, avoid the overhead of
// re-pinning (60 ms) and reuse the main pool.
if (inner_lps.size() <= 1) return false;
// Before creating new threads, stop the old ones from spinning. Caller is // Before creating new threads, stop the old ones from spinning. Caller is
// responsible for undoing this by calling `ResumeMainSpinning`. // responsible for undoing this by calling `ResumeMainSpinning`.
main_pool_->SetWaitMode(hwy::PoolWaitMode::kBlock); main_pool_->SetWaitMode(hwy::PoolWaitMode::kBlock);
outer_pool_ = std::make_unique<hwy::ThreadPool>(outer_workers);
outer_pool_ = std::make_unique<hwy::ThreadPool>(inner_lps.size());
outer_pool_->SetWaitMode(hwy::PoolWaitMode::kSpin); outer_pool_->SetWaitMode(hwy::PoolWaitMode::kSpin);
// Assign up to `max_workers` to inner pools. Each inner pool creates HWY_ASSERT(inner_pools_.empty());
// `workers_per_outer - 1` threads in addition to its 'main' thread, which for (const LPS& inner : inner_lps) {
// is the one calling `inner_pools[outer]->Run`, i.e., `outer`. In total, inner_pools_.push_back(new hwy::ThreadPool(inner.Count()));
// `outer_workers * (max_workers / outer_workers)` workers are used.
const size_t workers_per_outer = main_pool_->NumWorkers() / outer_workers;
for (size_t outer = 0; outer < outer_workers; ++outer) {
inner_pools_.push_back(new hwy::ThreadPool(workers_per_outer));
inner_pools_.back()->SetWaitMode(hwy::PoolWaitMode::kSpin); inner_pools_.back()->SetWaitMode(hwy::PoolWaitMode::kSpin);
} }
PinThreads(outer_workers, workers_per_outer); // For each inner pool, pin their threads AND the associated outer thread
// to the enabled cores in the cluster.
outer_pool_->Run(
0, inner_lps.size(),
[this, &inner_lps](uint64_t outer, size_t outer_thread) {
HWY_ASSERT(outer == outer_thread); // each outer has one task
const std::vector<size_t> cores = CoresInLPS(inner_lps[outer]);
inner_pools_[outer]->Run(
0, cores.size(), [&cores](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each inner has one task
hwy::PinThreadToLogicalProcessor(cores[task]);
});
});
return true;
} }
// `tokens` are from interleaved queries. (See InterleaveQueries() below.) void ReuseMainPoolAsInner() {
// Still allocate an empty pool to simplify Prefill().
outer_pool_ = std::make_unique<hwy::ThreadPool>(1);
HWY_ASSERT(inner_pools_.empty());
inner_pools_.push_back(main_pool_);
}
public:
// Creates pools. AllocateActivations must still be called separately; it has
// a template argument.
PrefillState(hwy::ThreadPool& main_pool, size_t num_queries)
: main_pool_(&main_pool) {
PROFILER_ZONE("Init.Prefill.Ctor");
if (!AssignInnerPoolsToClusters(num_queries)) {
ReuseMainPoolAsInner();
}
}
~PrefillState() {
for (hwy::ThreadPool* p : inner_pools_) {
if (p != main_pool_) delete p;
}
}
// `tbatch_size` is the number of tokens from one query to prefill at a time.
template <class TConfig> template <class TConfig>
HWY_NOINLINE void Prefill(hwy::Span<const int> tokens, size_t num_queries, void AllocateActivations(size_t num_queries, size_t tbatch_size) {
size_t pos, PROFILER_ZONE("Init.Prefill.AllocateActivations");
const size_t outer_workers = outer_pool_->NumWorkers();
HWY_ASSERT(outer_workers != 0); // Otherwise activations_ is empty.
HWY_ASSERT(activations_.empty()); // only call once.
activations_.resize(outer_workers);
if (outer_workers == 1) {
activations_[0].Allocate<TConfig>(tbatch_size);
} else {
// Allocating in parallel can save 30 ms.
main_pool_->Run(0, outer_workers,
[this, tbatch_size](uint64_t task, size_t /*thread*/) {
activations_[task].Allocate<TConfig>(tbatch_size);
});
}
}
template <class TConfig>
HWY_NOINLINE void Prefill(const MultiplePromptsTokens& prompts,
const size_t prefill_per_query, const size_t pos,
const size_t query_idx_start,
const CompressedWeights<TConfig>& weights, const CompressedWeights<TConfig>& weights,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const std::vector<KVCache*>& kv_caches) { const KVCaches& kv_caches) {
PROFILER_ZONE("Gen.Prefill"); PROFILER_ZONE("Gen.Prefill");
const size_t num_queries = prompts.size();
HWY_ASSERT(kv_caches.size() == num_queries);
const size_t max_tbatch_size = activations_[0].x.BatchSize();
HWY_ASSERT(activations_.size() == outer_pool_->NumWorkers()); // For each query (parallel): an outer worker processes all its tokens.
HWY_ASSERT(inner_pools_.size() == outer_pool_->NumWorkers()); // `qi` is relative to the batch, not the global query index.
outer_pool_->Run( outer_pool_->Run(
0, num_batches_, [&](const uint64_t batch_num, size_t thread) HWY_ATTR { 0, num_queries, [&](const uint64_t qi, size_t qthread) HWY_ATTR {
const size_t batch_start = batch_num * kPrefillBatchSize; Activations& activations = activations_[qthread];
const size_t batch_size = hwy::ThreadPool& inner_pool = *inner_pools_[qthread];
HWY_MIN(kPrefillBatchSize, tokens.size() - batch_start);
HWY_DASSERT(batch_start % num_queries == 0);
HWY_DASSERT(batch_size % num_queries == 0);
const size_t pos_per_query = pos + batch_start / num_queries;
const size_t num_tokens = batch_size / num_queries;
// Negligible time compared to TransformerLayer. // Single query at a time, so pass a slice of the KV cache because
for (size_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { // GemmaAttention will only access the first.
EmbedToken<TConfig>(tokens[batch_start + batch_idx], batch_idx, const size_t kPrefillQueries = 1;
pos_per_query, weights, activations_[thread]); KVCaches prefill_kv_caches(&kv_caches[qi], kPrefillQueries);
}
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { // For each batch of tokens in the query:
const auto* layer_weights = weights.GetLayer(layer); for (size_t tbatch_start = 0; tbatch_start < prefill_per_query;
TransformerLayer<TConfig>( tbatch_start += max_tbatch_size) {
num_tokens, num_queries, pos_per_query, layer, layer_weights, // Fill activations.x (much faster than TransformerLayer).
activations_[thread], kv_caches, *inner_pools_[thread]); const size_t tbatch_size =
} HWY_MIN(max_tbatch_size, prefill_per_query - tbatch_start);
for (size_t ti = 0; ti < tbatch_size; ++ti) {
const int token = prompts[qi][tbatch_start + ti];
EmbedToken<TConfig>(token, ti, pos + ti, weights, activations.x);
}
// NOTE: we unconditionally call StreamToken, even if EOS. // Transformer with one batch of tokens from a single query.
for (size_t i = 0; i < batch_size; ++i) { for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
const size_t query_idx = i % num_queries; const auto* layer_weights = weights.GetLayer(layer);
const size_t batch_idx = i / num_queries; TransformerLayer<TConfig>(
runtime_config.StreamToken(query_idx, pos_per_query + batch_idx, tbatch_size, kPrefillQueries, pos + tbatch_start, layer,
tokens[i], 0.0f); layer_weights, activations, prefill_kv_caches, inner_pool);
} }
// NOTE: we unconditionally call StreamToken, even if EOS.
for (size_t ti = 0; ti < tbatch_size; ++ti) {
const int token = prompts[qi][tbatch_start + ti];
runtime_config.StreamToken(query_idx_start + qi,
pos + tbatch_start + ti, token, 0.0f);
}
} // for tbatch_start
}); });
} }
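
Summarizing the new control flow: the outer pool runs one task per query (at most one per detected cluster), and each outer worker loops sequentially over token batches of size prefill_tbatch_size, handing tensor-level work to its own inner pool. A stripped-down sketch with the transformer math elided; pool construction and pinning are handled by PrefillState above:

    outer_pool.Run(0, num_queries, [&](uint64_t qi, size_t worker) {
      Activations& act = activations[worker];          // one per outer worker
      hwy::ThreadPool& inner = *inner_pools[worker];   // tensor parallelism
      const KVCaches kv(&kv_caches[qi], 1);            // this query's cache only
      for (size_t start = 0; start < prefill_per_query; start += tbatch_size) {
        const size_t len = HWY_MIN(tbatch_size, prefill_per_query - start);
        // Embed `len` tokens of query qi into act.x, run every layer on this
        // batch via `inner`, writing K/V for positions [pos+start, pos+start+len).
        // Later batches read the K/V written here, hence the sequential loop.
      }
    });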
@ -663,39 +770,15 @@ class PrefillState {
} }
private: private:
// Pins each outer thread after their inner threads so they are likely to
// run on the same socket.
void PinThreads(size_t outer_workers, size_t workers_per_outer) {
outer_pool_->Run(
0, outer_workers,
[this, workers_per_outer](uint64_t outer, size_t outer_thread) {
HWY_ASSERT(outer == outer_thread);
// Pins inner *and* `outer` - the latter is the calling thread.
inner_pools_[outer]->Run(
0, workers_per_outer,
[outer, workers_per_outer](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each worker has one task
const size_t lp = outer * workers_per_outer + task;
hwy::PinThreadToLogicalProcessor(lp);
});
});
}
void DeleteInnerPools() {
for (hwy::ThreadPool* p : inner_pools_) {
if (p != main_pool_) delete p;
}
inner_pools_.clear();
}
hwy::ThreadPool* main_pool_; hwy::ThreadPool* main_pool_;
std::unique_ptr<hwy::ThreadPool> outer_pool_; // always allocated std::unique_ptr<hwy::ThreadPool> outer_pool_; // always allocated
std::vector<Activations> activations_; // size == outer_pool->NumWorkers() // Holds a single pointer equal to main_pool_, or new allocations; in either
// Either there is a single pointer equal to main_pool, or newly created pools // case, size() is equal to outer_pool_->NumWorkers(). The first case avoids
// that we own. The former case avoids thread creation overhead for prompts // allocation overhead for the common case of a single query.
// that fit in a single batch.
std::vector<hwy::ThreadPool*> inner_pools_; std::vector<hwy::ThreadPool*> inner_pools_;
size_t num_batches_ = 0;
// size() == outer_pool_->NumWorkers(); filled by AllocateActivations.
std::vector<Activations> activations_;
}; };
// `tokens` is length `num_tokens * num_queries`. In autoregressive decode, // `tokens` is length `num_tokens * num_queries`. In autoregressive decode,
@ -705,8 +788,7 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
size_t num_queries, size_t pos, size_t num_queries, size_t pos,
const CompressedWeights<TConfig>& weights, const CompressedWeights<TConfig>& weights,
Activations& activations, Activations& activations,
const std::vector<KVCache*>& kv_caches, const KVCaches& kv_caches, hwy::ThreadPool& pool,
hwy::ThreadPool& pool,
const LayersOutputFunc& layers_output) { const LayersOutputFunc& layers_output) {
const size_t num_interleaved = num_tokens * num_queries; const size_t num_interleaved = num_tokens * num_queries;
if (layers_output) { if (layers_output) {
@ -718,7 +800,7 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
constexpr size_t kModelDim = TConfig::kModelDim; constexpr size_t kModelDim = TConfig::kModelDim;
for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) { for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) {
EmbedToken<TConfig>(tokens[token_idx], token_idx, pos, weights, EmbedToken<TConfig>(tokens[token_idx], token_idx, pos, weights,
activations); activations.x);
} }
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
@ -781,10 +863,10 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
// Returns interleaved tokens: one from each query, followed by the second from // Returns interleaved tokens: one from each query, followed by the second from
// all queries, with EOS padding. // all queries, with EOS padding.
static std::vector<int> InterleaveQueries( static std::vector<int> InterleaveQueries(const MultiplePromptsTokens& queries,
const hwy::Span<const hwy::Span<int>>& queries, const RuntimeConfig& runtime_config,
const RuntimeConfig& runtime_config, size_t& min_prompt_size, size_t& min_prompt_size,
size_t& max_prompt_size) { size_t& max_prompt_size) {
const size_t num_queries = queries.size(); const size_t num_queries = queries.size();
min_prompt_size = hwy::LimitsMax<size_t>(); min_prompt_size = hwy::LimitsMax<size_t>();
max_prompt_size = 0; max_prompt_size = 0;
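
Interleaving is query-major per position: the first token of every query, then the second token of every query, and so on, padding shorter prompts with EOS up to max_prompt_size. For prompts [a0, a1, a2] and [b0, b1] the result is a0, b0, a1, b1, a2, EOS. A minimal sketch, assuming the EOS id is reachable as runtime_config.eos_id:

    std::vector<int> interleaved;
    interleaved.reserve(max_prompt_size * num_queries);
    for (size_t pos = 0; pos < max_prompt_size; ++pos) {
      for (size_t q = 0; q < num_queries; ++q) {
        interleaved.push_back(pos < queries[q].size()
                                  ? queries[q][pos]
                                  : runtime_config.eos_id);
      }
    }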
@ -829,28 +911,34 @@ class TokenStreamer {
private: private:
const RuntimeConfig& runtime_config_; const RuntimeConfig& runtime_config_;
// BitSet4096 divides the arg by 64, so ensure it is at least 64. hwy::BitSet4096<> is_eos_;
hwy::BitSet4096<HWY_MAX(64, kBatchedQueryBatchSize)> is_eos_;
}; };
// Generates one token per query in the batch. // Generates one token for each query in `prompts`, which is one qbatch whose
// size is at most the `batch_size` passed to `activations.Allocate`.
// //
// pos indexes the KV cache. In the first turn of a chat, pos = 0, and it // `pos` indexes the KV cache. In the first turn of a chat, pos = 0, and it
// continues to increase by one for each prefilled/generated token per query. // continues to increase by one for each prefilled/generated token per query.
// query_idx_start is the first query index in the batch. //
template <class TConfig, size_t kQueryBatchSize> // `query_idx_start` is the query_idx of the first query in the batch, so that
// `StreamFunc` gets the global query index, not relative to the batch.
//
// `kv_caches` is for the batch, size must match `prompts`.
template <class TConfig>
void GenerateT(const ByteStorageT& weights_u8, Activations& activations, void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const hwy::Span<const hwy::Span<int>>& prompts, const size_t pos, const MultiplePromptsTokens& prompts, const size_t pos,
const size_t query_idx_start, const size_t query_idx_start, const KVCaches& kv_caches,
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool, hwy::ThreadPool& pool, TimingInfo& timing_info) {
TimingInfo& timing_info) {
constexpr size_t kVocabSize = TConfig::kVocabSize; constexpr size_t kVocabSize = TConfig::kVocabSize;
const CompressedWeights<TConfig>& weights = const CompressedWeights<TConfig>& weights =
*reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get()); *reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
const size_t num_queries = prompts.size(); const size_t num_queries = prompts.size();
HWY_DASSERT(num_queries <= kQueryBatchSize); HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096.
HWY_ASSERT(num_queries <= activations.x.BatchSize());
HWY_ASSERT(kv_caches.size() == num_queries);
size_t min_prompt_size, max_prompt_size; size_t min_prompt_size, max_prompt_size;
const std::vector<int> prompt = InterleaveQueries( const std::vector<int> prompt = InterleaveQueries(
prompts, runtime_config, min_prompt_size, max_prompt_size); prompts, runtime_config, min_prompt_size, max_prompt_size);
@ -877,28 +965,28 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
// Prefill stops before min_prompt_size - 1 because the last prompt token is // Prefill stops before min_prompt_size - 1 because the last prompt token is
// the first input token for generation. // the first input token for generation.
const size_t prefill_per_query = min_prompt_size - 1; const size_t prefill_per_query = min_prompt_size - 1;
const hwy::Span<const int> prefill_tokens(prompt.data(), double prefill_start;
prefill_per_query * num_queries); {
PrefillState prefill(pool); PrefillState prefill(pool, num_queries);
prefill.Init<TConfig>(prefill_tokens.size()); prefill.AllocateActivations<TConfig>(num_queries,
const double prefill_start = hwy::platform::Now(); runtime_config.prefill_tbatch_size);
size_t interleaved_pos = pos * num_queries; prefill_start = hwy::platform::Now();
prefill.Prefill<TConfig>(prefill_tokens, num_queries, interleaved_pos, prefill.Prefill<TConfig>(prompts, prefill_per_query, pos, query_idx_start,
weights, runtime_config, kv_caches); weights, runtime_config, kv_caches);
interleaved_pos += prefill_tokens.size(); timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
timing_info.NotifyPrefill(prefill_tokens.size(), prefill_start); prefill.ResumeMainSpinning();
}
prefill.ResumeMainSpinning(); size_t interleaved_pos = (pos + prefill_per_query) * num_queries;
// Storage for the last generated token from each query, passed to the next // Storage for the last generated token from each query, passed to the next
// Transformer() call. // Transformer() call.
std::vector<int> gen_tokens(num_queries); std::vector<int> gen_tokens(num_queries);
// Stream the last prompt token from each query and fill gen_tokens. // Stream the last prompt token from each query and fill gen_tokens.
hwy::CopyBytes(&prompt[prefill_tokens.size()], gen_tokens.data(),
num_queries * sizeof(prompt[0]));
TokenStreamer token_streamer(runtime_config); TokenStreamer token_streamer(runtime_config);
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
gen_tokens[query_idx] = prompts[query_idx][prefill_per_query];
(void)token_streamer(query_idx_start + query_idx, prefill_per_query, (void)token_streamer(query_idx_start + query_idx, prefill_per_query,
gen_tokens[query_idx], 0.0f); gen_tokens[query_idx], 0.0f);
} }
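
Decode consumes one token from every query per step, so positions in the interleaved stream advance by num_queries for each per-query position; after prefill, decode therefore starts at interleaved_pos = (pos + prefill_per_query) * num_queries. For example, pos = 0, prefill_per_query = 7 and num_queries = 4 gives interleaved_pos = 28, i.e. the eighth token slot of each of the four queries.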
@ -940,42 +1028,49 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
timing_info.NotifyGenerateDone(gen_start); timing_info.NotifyGenerateDone(gen_start);
} }
// TODO: prompt should also be span, not a vector.
template <class TConfig> template <class TConfig>
void GenerateSingleT(const ByteStorageT& weights_u8, Activations& activations, void GenerateSingleT(const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos, const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
KVCache& kv_cache, hwy::ThreadPool& pool, hwy::ThreadPool& pool, TimingInfo& timing_info) {
TimingInfo& timing_info) { const size_t num_queries = 1;
const hwy::Span<int> prompt_span(const_cast<int*>(prompt.data()), const size_t qbatch_start = 0;
prompt.size());
const hwy::Span<const hwy::Span<int>> prompts(&prompt_span, 1); Activations activations;
// TODO: also span of kv_cache, or batching inside KVCache? activations.Allocate<TConfig>(num_queries);
std::vector<KVCache*> kv_caches = {&kv_cache};
const size_t query_idx_start = 0; const MultiplePromptsTokens prompts(&prompt, num_queries);
GenerateT<TConfig, /*kQueryBatchSize=*/1>( const KVCaches kv_caches{&kv_cache, num_queries};
weights_u8, activations, runtime_config, prompts, pos, query_idx_start,
kv_caches, pool, timing_info); GenerateT<TConfig>(weights_u8, activations, runtime_config, prompts, pos,
qbatch_start, kv_caches, pool, timing_info);
} }
template <class TConfig> template <class TConfig>
void GenerateBatchT(const ByteStorageT& weights_u8, Activations& activations, void GenerateBatchT(const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos, const MultiplePromptsTokens& prompts, size_t pos,
const std::vector<KVCache*>& kv_caches, const KVCaches& kv_caches, hwy::ThreadPool& pool,
hwy::ThreadPool& pool, TimingInfo& timing_info) { TimingInfo& timing_info) {
// Disable query batching for Griffin models. HWY_ASSERT(prompts.size() == kv_caches.size());
constexpr size_t kQueryBatchSize = // Griffin does not support query batching.
(TConfig::kGriffinLayers > 0) ? 1 : kBatchedQueryBatchSize; const size_t max_qbatch_size =
for (size_t query_idx_start = 0; query_idx_start < prompts.size(); (TConfig::kGriffinLayers > 0) ? 1 : runtime_config.decode_qbatch_size;
query_idx_start += kQueryBatchSize) {
const size_t num_queries = Activations activations;
std::min(prompts.size() - query_idx_start, kQueryBatchSize); activations.Allocate<TConfig>(max_qbatch_size);
const hwy::Span<const hwy::Span<int>> query_batch(
prompts.data() + query_idx_start, num_queries); const size_t num_queries = prompts.size();
GenerateT<TConfig, kQueryBatchSize>(weights_u8, activations, runtime_config, for (size_t qbatch_start = 0; qbatch_start < num_queries;
query_batch, pos, query_idx_start, qbatch_start += max_qbatch_size) {
kv_caches, pool, timing_info); // Generate one batch of tokens from `qbatch_size` queries.
const size_t qbatch_size =
HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
const MultiplePromptsTokens qbatch_prompts(&prompts[qbatch_start],
qbatch_size);
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
GenerateT<TConfig>(weights_u8, activations, runtime_config, qbatch_prompts,
pos, qbatch_start, qbatch_kv, pool, timing_info);
} }
} }
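
Callers now pass plain spans to GenerateBatch with one KV cache per prompt; the loop above slices both into qbatches of at most decode_qbatch_size. A usage sketch, assuming `model`, a filled `runtime_config`, tokenized `prompt_spans`, and `caches` of matching size:

    const gcpp::MultiplePromptsTokens prompts(prompt_spans.data(),
                                              prompt_spans.size());
    const gcpp::KVCaches kv_caches(caches.data(), caches.size());
    gcpp::TimingInfo timing_info;
    model.GenerateBatch(runtime_config, prompts, /*start_pos=*/0, kv_caches,
                        timing_info);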
@ -986,24 +1081,20 @@ void GenerateBatchT(const ByteStorageT& weights_u8, Activations& activations,
// These are extern functions defined by instantiations/*.cc, which include this // These are extern functions defined by instantiations/*.cc, which include this
// 'header' after defining GEMMA_CONFIG, which is for function overloading. // 'header' after defining GEMMA_CONFIG, which is for function overloading.
void GenerateSingle( // NOLINT(misc-definitions-in-headers) void GenerateSingle( // NOLINT(misc-definitions-in-headers)
GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& activations, GEMMA_CONFIG, const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config, const std::vector<int>& prompt, const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, KVCache& kv_cache, hwy::ThreadPool& pool, TimingInfo& timing_info) {
TimingInfo& timing_info) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_CONFIG>) HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_CONFIG>)
(weights_u8, activations, runtime_config, prompt, pos, kv_cache, pool, (weights_u8, runtime_config, prompt, pos, kv_cache, pool, timing_info);
timing_info);
} }
void GenerateBatch( // NOLINT(misc-definitions-in-headers) void GenerateBatch( // NOLINT(misc-definitions-in-headers)
GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& activations, GEMMA_CONFIG, const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config, const MultiplePromptsTokens& prompts,
const hwy::Span<const hwy::Span<int>>& prompts, size_t pos, size_t pos, const KVCaches& kv_caches, hwy::ThreadPool& pool,
const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
TimingInfo& timing_info) { TimingInfo& timing_info) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>) HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>)
(weights_u8, activations, runtime_config, prompts, pos, kv_caches, pool, (weights_u8, runtime_config, prompts, pos, kv_caches, pool, timing_info);
timing_info);
} }
#endif // HWY_ONCE #endif // HWY_ONCE

View File

@ -24,32 +24,19 @@
#include <string.h> #include <string.h>
#include <utility> // std::move #include <utility> // std::move
#include <vector>
#include "compression/io.h" // Path #include "compression/io.h" // Path
#include "gemma/activations.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/weights.h" #include "gemma/weights.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h" #include "hwy/highway.h"
namespace gcpp { namespace gcpp {
-template <typename TConfig>
-struct AllocateActivations {
-  void operator()(Activations& decode) const {
-    // TODO: this is wasted if we only have single-batch queries. Instead
-    // re-allocate when the query batch size is actually > 1?
-    decode.Allocate<TConfig>(kBatchedQueryBatchSize);
-  }
-};
-
 Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
              const ModelInfo& info, hwy::ThreadPool& pool)
     : pool_(pool), tokenizer_(tokenizer_path), info_(info) {
   weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool);
-  CallForModelAndWeight<AllocateActivations>(info.model, info.weight, decode_);
 }

 Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
@@ -58,7 +45,6 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
   HWY_ASSERT(info.weight == Type::kF32);
   weights_u8_ =
       CallForModel<float, AllocateCompressedWeights>(info.model, pool);
-  CallForModelAndWeight<AllocateActivations>(info.model, info.weight, decode_);
 }

 Gemma::~Gemma() {
@@ -70,67 +56,64 @@ Gemma::~Gemma() {
 // we shard them across multiple translation units in instantiations/*.cc.
 // This declares the functions defined there. We use overloading because
 // explicit instantiations are still too slow to compile.
 #define GEMMA_DECLARE(CONFIGT, TWEIGHT)                                        \
-  extern void GenerateSingle(                                                  \
-      CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, Activations& decode,   \
-      const RuntimeConfig& runtime_config, const std::vector<int>& prompt,     \
-      size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool,                    \
-      TimingInfo& timing_info);                                                \
-  extern void GenerateBatch(                                                   \
-      CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, Activations& decode,   \
-      const RuntimeConfig& runtime_config,                                     \
-      const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,              \
-      const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,           \
-      TimingInfo& timing_info);
+  extern void GenerateSingle(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
+                             const RuntimeConfig& runtime_config,              \
+                             const PromptTokens& prompt, size_t pos,           \
+                             KVCache& kv_cache, hwy::ThreadPool& pool,         \
+                             TimingInfo& timing_info);                         \
+  extern void GenerateBatch(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8,  \
+                            const RuntimeConfig& runtime_config,               \
+                            const MultiplePromptsTokens& prompts, size_t pos,  \
+                            const KVCaches& kv_caches, hwy::ThreadPool& pool,  \
+                            TimingInfo& timing_info);
 GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE);

 // Adapters to select from the above overloads via CallForModelAndWeight.
+// TODO: gather all ByteStorageT into a type-erased model struct?
 template <class TConfig>
 struct GenerateSingleT {
-  void operator()(const ByteStorageT& weights_u8, Activations& decode,
+  void operator()(const ByteStorageT& weights_u8,
                   const RuntimeConfig& runtime_config,
-                  const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
+                  const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
                   hwy::ThreadPool& pool, TimingInfo& timing_info) const {
-    GenerateSingle(TConfig(), weights_u8, decode, runtime_config, prompt, pos,
-                   kv_cache, pool, timing_info);
+    GenerateSingle(TConfig(), weights_u8, runtime_config, prompt, pos, kv_cache,
+                   pool, timing_info);
   }
 };

 template <class TConfig>
 struct GenerateBatchT {
-  void operator()(const ByteStorageT& weights_u8, Activations& decode,
+  void operator()(const ByteStorageT& weights_u8,
                   const RuntimeConfig& runtime_config,
-                  const hwy::Span<const hwy::Span<int>>& prompts, size_t pos,
-                  const std::vector<KVCache*>& kv_caches, hwy::ThreadPool& pool,
+                  const MultiplePromptsTokens& prompts, size_t pos,
+                  const KVCaches& kv_caches, hwy::ThreadPool& pool,
                   TimingInfo& timing_info) const {
-    GenerateBatch(TConfig(), weights_u8, decode, runtime_config, prompts, pos,
+    GenerateBatch(TConfig(), weights_u8, runtime_config, prompts, pos,
                   kv_caches, pool, timing_info);
   }
 };

 void Gemma::Generate(const RuntimeConfig& runtime_config,
-                     const std::vector<int>& prompt, size_t start_pos,
+                     const PromptTokens& prompt, size_t start_pos,
                      KVCache& kv_cache, TimingInfo& timing_info) {
   pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
-  CallForModelAndWeight<GenerateSingleT>(
-      info_.model, info_.weight, weights_u8_, decode_, runtime_config, prompt,
-      start_pos, kv_cache, pool_, timing_info);
+  CallForModelAndWeight<GenerateSingleT>(info_.model, info_.weight, weights_u8_,
+                                         runtime_config, prompt, start_pos,
+                                         kv_cache, pool_, timing_info);
   pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }

 void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
-                          const hwy::Span<const hwy::Span<int>>& prompts,
-                          size_t start_pos,
-                          const std::vector<KVCache*>& kv_caches,
+                          const MultiplePromptsTokens& prompts,
+                          size_t start_pos, const KVCaches& kv_caches,
                           TimingInfo& timing_info) {
   pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
-  CallForModelAndWeight<GenerateBatchT>(
-      info_.model, info_.weight, weights_u8_, decode_, runtime_config, prompts,
-      start_pos, kv_caches, pool_, timing_info);
+  CallForModelAndWeight<GenerateBatchT>(info_.model, info_.weight, weights_u8_,
+                                        runtime_config, prompts, start_pos,
+                                        kv_caches, pool_, timing_info);
   pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
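For orientation, here is a minimal caller-side sketch of the reworked single-query path. It assumes a loaded Gemma, a RuntimeConfig and a KVCache already exist; the function name and the token vector are illustrative, only Generate and PromptTokens come from this change.

#include <vector>
#include "gemma/gemma.h"

// Sketch: wrap an already-tokenized prompt in a span and run one query.
void GenerateOne(gcpp::Gemma& model, const gcpp::RuntimeConfig& runtime_config,
                 const std::vector<int>& tokens, gcpp::KVCache& kv_cache) {
  const gcpp::PromptTokens prompt(tokens.data(), tokens.size());
  gcpp::TimingInfo timing_info;
  model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache,
                 timing_info);
}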
@@ -30,7 +30,7 @@
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/timer.h"
 // IWYU pragma: end_exports
-#include "hwy/aligned_allocator.h"
+#include "hwy/aligned_allocator.h"  // Span
 #include "hwy/base.h"  // hwy::bfloat16_t

 namespace gcpp {

@@ -67,6 +67,13 @@ struct RuntimeConfig {
   size_t max_tokens;
   size_t max_generated_tokens;

+  // These defaults are overridden by InferenceArgs::CopyTo(*this):
+  // Max tokens per batch during prefill.
+  size_t prefill_tbatch_size = 32;
+  // Max queries per batch (one token from each) during decode.
+  size_t decode_qbatch_size = 16;
+
   float temperature;
   int verbosity;
   std::mt19937* gen;
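To make the two knobs concrete: prefill_tbatch_size splits a single query's prompt into token batches, while decode_qbatch_size caps how many queries emit a token together during decode. A rough sketch of the prefill split follows; it is illustrative only, the real loop lives in the generation code.

#include <cstddef>

// Sketch: how many prefill steps a prompt of prompt_len tokens needs.
size_t NumPrefillBatches(size_t prompt_len, size_t prefill_tbatch_size) {
  return (prompt_len + prefill_tbatch_size - 1) / prefill_tbatch_size;  // ceil
}
// Example: 100 prompt tokens with the RuntimeConfig default of 32 are
// prefilled in 4 batches of 32, 32, 32 and 4 tokens.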
@@ -105,6 +112,10 @@ struct TimingInfo {
   size_t tokens_generated;
 };

+using PromptTokens = hwy::Span<const int>;
+using MultiplePromptsTokens = hwy::Span<const PromptTokens>;
+using KVCaches = hwy::Span<KVCache>;
+
 class Gemma {
  public:
   Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
@@ -118,25 +129,20 @@ class Gemma {
   const ModelInfo& Info() const { return info_; }
   const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
   const ByteStorageT& Weights() const { return weights_u8_; }
-  const Activations& Decode() const { return decode_; }

-  void Generate(const RuntimeConfig& runtime_config,
-                const std::vector<int>& prompt, size_t start_pos,
-                KVCache& kv_cache, TimingInfo& timing_info);
+  void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
+                size_t start_pos, KVCache& kv_cache, TimingInfo& timing_info);

   void GenerateBatch(const RuntimeConfig& runtime_config,
-                     const hwy::Span<const hwy::Span<int>>& prompts,
-                     size_t start_pos, const std::vector<KVCache*>& kv_caches,
-                     TimingInfo& timing_info);
+                     const MultiplePromptsTokens& prompts, size_t start_pos,
+                     const KVCaches& kv_caches, TimingInfo& timing_info);

  private:
   hwy::ThreadPool& pool_;
   GemmaTokenizer tokenizer_;
-  // Type-erased so that this can be defined in the header, without requiring
-  // forwarding functions.
+  // Type-erased so that this can be defined in the header.
   ByteStorageT weights_u8_;
-  Activations decode_;
   ModelInfo info_;
 };
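A hedged sketch of how a caller can assemble these span types for a multi-query batch; the helper name and the per-query setup are assumptions for illustration, while GenerateBatch, the aliases above and KVCache::Create come from this change.

#include <vector>
#include "gemma/gemma.h"
#include "gemma/kv_cache.h"

// Sketch: run several already-tokenized queries as one batch.
void GenerateAll(gcpp::Gemma& model, const gcpp::RuntimeConfig& runtime_config,
                 const std::vector<std::vector<int>>& all_tokens) {
  std::vector<gcpp::PromptTokens> prompts;
  std::vector<gcpp::KVCache> caches;  // one KV cache per query
  for (const std::vector<int>& tokens : all_tokens) {
    prompts.emplace_back(tokens.data(), tokens.size());
    caches.push_back(gcpp::KVCache::Create(model.Info().model,
                                           runtime_config.prefill_tbatch_size));
  }
  gcpp::TimingInfo timing_info;
  model.GenerateBatch(
      runtime_config,
      gcpp::MultiplePromptsTokens(prompts.data(), prompts.size()),
      /*start_pos=*/0, gcpp::KVCaches(caches.data(), caches.size()),
      timing_info);
}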
@@ -23,13 +23,16 @@ namespace gcpp {
 namespace {

 template <class TConfig>
 struct CreateKVCache {
-  KVCache operator()() const {
+  KVCache operator()(size_t prefill_tbatch_size) const {
     KVCache kv_cache = {};

     const size_t size_cache_pos = CachePosSize<TConfig>()();
     if (size_cache_pos != 0) {
-      const size_t seq_len = (TConfig::kSeqLen + kPrefillBatchSize);
-      kv_cache.kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
+      // Allocate more so that prefill can always access one batch, even if
+      // near the end of the sequence.
+      kv_cache.seq_len = TConfig::kSeqLen + prefill_tbatch_size;
+      kv_cache.kv_cache =
+          hwy::AllocateAligned<float>(kv_cache.seq_len * size_cache_pos);
     }

     // TODO(patrickms): Add query batching support for Griffin.
@@ -58,10 +61,13 @@ struct CreateKVCache {
 };
 }  // namespace

-KVCache KVCache::Create(Model model_type) {
+// prefill_tbatch_size is the maximum number of tokens from one query to
+// prefill at a time.
+KVCache KVCache::Create(Model model_type, size_t prefill_tbatch_size) {
   // TWeight=float is a placeholder and unused because CreateKVCache does not
   // use TConfig::Weight.
-  return CallForModel</*TWeight=*/float, CreateKVCache>(model_type);
+  return CallForModel</*TWeight=*/float, CreateKVCache>(model_type,
+                                                        prefill_tbatch_size);
 }

 }  // namespace gcpp
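The extra headroom above exists so that a full prefill token batch can be written even when the write position is close to kSeqLen. A small sketch of the invariant this buys; it is illustrative, the actual bounds handling lives in the prefill/attention code.

#include <cstddef>
#include "hwy/base.h"  // HWY_ASSERT

// Sketch: a prefill batch starting at `pos` writes cache rows [pos, pos + n).
// With seq_len = kSeqLen + prefill_tbatch_size, pos <= kSeqLen and
// n <= prefill_tbatch_size, the write always fits.
void CheckPrefillFits(size_t pos, size_t n, size_t seq_len) {
  HWY_ASSERT(pos + n <= seq_len);
}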
@@ -16,13 +16,17 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_

+#include <stddef.h>
+
 #include "gemma/common.h"  // Model
 #include "hwy/aligned_allocator.h"

 namespace gcpp {

 struct KVCache {
-  // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2
+  size_t seq_len = 0;  // = kSeqLen + prefill_tbatch_size
+
+  // seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2
   hwy::AlignedFreeUniquePtr<float[]> kv_cache;

   // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
@@ -31,7 +35,7 @@ struct KVCache {
   // kModelDim * kGriffinLayers
   hwy::AlignedFreeUniquePtr<float[]> rglru_cache;

-  static KVCache Create(Model type);
+  static KVCache Create(Model type, size_t prefill_tbatch_size);
 };

 }  // namespace gcpp
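For a sense of scale, a back-of-the-envelope using the size comment above; the per-model constants here are placeholders rather than the values of any particular TConfig.

#include <cstddef>

// Sketch: total f32 elements in kv_cache for hypothetical model constants.
constexpr size_t kSeqLen = 4096, kGemmaLayers = 26, kKVHeads = 1, kQKVDim = 256;
constexpr size_t kPrefillTBatch = 64;
constexpr size_t kSeqLenWithHeadroom = kSeqLen + kPrefillTBatch;         // 4160
constexpr size_t kFloatsPerPos = kGemmaLayers * kKVHeads * kQKVDim * 2;  // K and V
constexpr size_t kTotalFloats = kSeqLenWithHeadroom * kFloatsPerPos;
// 4160 * 13312 = 55,377,920 floats, i.e. about 211 MiB of f32 per query.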
@@ -145,14 +145,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
   TimingInfo timing_info;
   RuntimeConfig runtime_config = {
-      .max_tokens = args.max_tokens,
-      .max_generated_tokens = args.max_generated_tokens,
-      .temperature = args.temperature,
       .verbosity = verbosity,
       .gen = &gen,
       .stream_token = stream_token,
       .accept_token = accept_token,
   };
+  args.CopyTo(runtime_config);
   model.Generate(runtime_config, prompt, abs_pos, kv_cache, timing_info);
   if (verbosity >= 2) {
     std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
@@ -181,7 +179,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   }

   Gemma model = CreateGemma(loader, pool);
-  KVCache kv_cache = KVCache::Create(model.Info().model);
+  KVCache kv_cache =
+      KVCache::Create(model.Info().model, inference.prefill_tbatch_size);

   if (app.verbosity >= 1) {
     std::string instructions =
@@ -248,6 +248,9 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   size_t max_tokens;
   size_t max_generated_tokens;

+  size_t prefill_tbatch_size;
+  size_t decode_qbatch_size;
+
   float temperature;
   bool deterministic;
   bool multiturn;
@@ -272,6 +275,11 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
     visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
             "Maximum number of tokens to generate.");

+    visitor(prefill_tbatch_size, "prefill_tbatch", size_t{64},
+            "Prefill: max tokens per batch.");
+    visitor(decode_qbatch_size, "decode_qbatch", size_t{16},
+            "Decode: max queries per batch.");
+
     visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
     visitor(deterministic, "deterministic", false,
             "Make top-k sampling deterministic", 2);
@@ -281,6 +289,16 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
             " Default : 0 (conversation "
             "resets every turn)");
   }

+  void CopyTo(RuntimeConfig& runtime_config) const {
+    runtime_config.max_tokens = max_tokens;
+    runtime_config.max_generated_tokens = max_generated_tokens;
+
+    runtime_config.prefill_tbatch_size = prefill_tbatch_size;
+    runtime_config.decode_qbatch_size = decode_qbatch_size;
+
+    runtime_config.temperature = temperature;
+  }
 };

 }  // namespace gcpp
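Putting it together, the new flags (for example --prefill_tbatch 128 --decode_qbatch 8 on the gemma command line) are parsed into InferenceArgs and reach RuntimeConfig through CopyTo; the same value also sizes the KV cache. A condensed sketch of that wiring, assuming the parsed args and a loaded model; the helper name is illustrative.

#include "gemma/gemma.h"
#include "gemma/kv_cache.h"
#include "util/app.h"

// Sketch: flags -> InferenceArgs -> RuntimeConfig, plus KV cache sizing.
gcpp::KVCache SetUp(const gcpp::InferenceArgs& inference, gcpp::Gemma& model,
                    gcpp::RuntimeConfig& runtime_config) {
  inference.CopyTo(runtime_config);  // max_tokens, batch sizes, temperature
  return gcpp::KVCache::Create(model.Info().model,
                               inference.prefill_tbatch_size);
}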