diff --git a/BUILD.bazel b/BUILD.bazel index 6ee5ec6..15ce06b 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -224,6 +224,7 @@ cc_library( ":common", ":cross_entropy", ":gemma_lib", + ":kv_cache", # Placeholder for internal dep, do not remove., "@benchmark//:benchmark", "//compression:compress", diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc index 2e031d0..8df94e9 100644 --- a/backprop/optimize_test.cc +++ b/backprop/optimize_test.cc @@ -52,7 +52,7 @@ TEST(OptimizeTest, GradientDescent) { CallForModelAndWeight(info.model, info.weight); ByteStorageT backward = CallForModelAndWeight(info.model, info.weight); - KVCache kv_cache = KVCache::Create(info.model); + KVCache kv_cache = KVCache::Create(info.model, /*prefill_tbatch_size=*/16); Gemma gemma(GemmaTokenizer(), info, pool); diff --git a/evals/benchmark.cc b/evals/benchmark.cc index 9b0d609..bbabc7e 100644 --- a/evals/benchmark.cc +++ b/evals/benchmark.cc @@ -128,7 +128,8 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text, size_t num_tokens = std::min(prompt.size() - pos, batch_tokens); std::vector prompt_slice(prompt.begin() + pos, prompt.begin() + pos + num_tokens); - KVCache kv_cache = KVCache::Create(env.Info().model); + KVCache kv_cache = KVCache::Create( + env.Info().model, env.MutableInferenceArgs().prefill_tbatch_size); float entropy = ComputeCrossEntropy( *env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity()); total_entropy += entropy; diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index caa489e..4111bd9 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -34,9 +34,9 @@ #include "evals/cross_entropy.h" #include "gemma/common.h" // StringFromType #include "gemma/gemma.h" +#include "gemma/kv_cache.h" #include "util/app.h" #include "util/args.h" -#include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" @@ -76,10 +76,10 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference, fprintf(stderr, "Loading model...\n"); model_ = AllocateGemma(loader_, pool_); - kv_caches_.reserve(kBatchedQueryBatchSize); - for (int i = 0; i < kBatchedQueryBatchSize; ++i) { - kv_caches_.push_back(new KVCache(KVCache::Create(model_->Info().model))); - } + // Only allocate one for starters because GenerateBatch might not be called. 
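For call sites outside this patch, the migration is mechanical: KVCache::Create now takes the prefill token batch size so the cache can be over-allocated by one batch. A minimal sketch, assuming a `LoaderArgs loader` and `InferenceArgs inference` are in scope (the helper name is illustrative):

```cpp
#include "gemma/kv_cache.h"
#include "util/app.h"  // LoaderArgs, InferenceArgs

// Previously: KVCache::Create(loader.Info().model), sized with the
// compile-time kPrefillBatchSize. The batch size is now a runtime flag.
gcpp::KVCache MakeCache(const gcpp::LoaderArgs& loader,
                        const gcpp::InferenceArgs& inference) {
  return gcpp::KVCache::Create(loader.Info().model,
                               inference.prefill_tbatch_size);
}
```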
+ kv_caches_.resize(1); + kv_caches_[0] = + KVCache::Create(model_->Info().model, inference.prefill_tbatch_size); } InitGenerator(inference_args_, gen_); @@ -132,7 +132,7 @@ std::pair GemmaEnv::QueryModel( } gcpp::TimingInfo timing_info; runtime_config_.batch_stream_token = batch_stream_token; - model_->Generate(runtime_config_, tokens, /*start_pos=*/0, *kv_caches_[0], + model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], timing_info); if (app_.verbosity >= 1) { LogSpeedStats(time_start, total_tokens); @@ -141,8 +141,10 @@ std::pair GemmaEnv::QueryModel( } std::vector> GemmaEnv::BatchQueryModel2( - const hwy::Span>& prompts) { - std::vector> res(prompts.size()); + const MultiplePromptsTokens& prompts) { + const size_t num_queries = prompts.size(); + HWY_ASSERT(num_queries != 0); + std::vector> res(num_queries); std::fill(res.begin(), res.end(), std::make_pair("", 0)); size_t total_tokens = 0; @@ -162,14 +164,29 @@ std::vector> GemmaEnv::BatchQueryModel2( return true; }; if (app_.verbosity >= 2) { - std::cout << inference_args_.max_tokens << " " - << inference_args_.max_generated_tokens << " " - << inference_args_.temperature; + fprintf(stderr, + "Max tok: %zu max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n", + inference_args_.max_tokens, inference_args_.max_generated_tokens, + inference_args_.temperature, inference_args_.prefill_tbatch_size, + inference_args_.decode_qbatch_size); } + + // Ensure we have one KVCache per query. + if (kv_caches_.size() < num_queries) { + kv_caches_.resize(num_queries); + } + for (size_t i = 1; i < num_queries; ++i) { + if (kv_caches_[i].seq_len == 0) { + kv_caches_[i] = KVCache::Create(model_->Info().model, + inference_args_.prefill_tbatch_size); + } + } + gcpp::TimingInfo timing_info; runtime_config_.batch_stream_token = batch_stream_token; - model_->GenerateBatch(runtime_config_, prompts, /*start_pos=*/0, kv_caches_, - timing_info); + inference_args_.CopyTo(runtime_config_); + model_->GenerateBatch(runtime_config_, prompts, /*start_pos=*/0, + KVCaches(&kv_caches_[0], num_queries), timing_info); if (app_.verbosity >= 1) { LogSpeedStats(time_start, total_tokens); } @@ -191,13 +208,12 @@ std::vector> GemmaEnv::BatchQueryModel( prompts.push_back(WrapAndTokenize(model_->Tokenizer(), model_->Info(), /*pos=*/0, mutable_prompt)); } - std::vector> prompt_vector; + std::vector prompt_vector; prompt_vector.reserve(prompts.size()); for (auto& prompt : prompts) { - prompt_vector.push_back(hwy::Span(prompt.data(), prompt.size())); + prompt_vector.push_back(PromptTokens(prompt.data(), prompt.size())); } - hwy::Span> prompt_span = hwy::Span>( - prompt_vector.data(), prompt_vector.size()); + MultiplePromptsTokens prompt_span(prompt_vector.data(), prompt_vector.size()); return BatchQueryModel2(prompt_span); } @@ -226,8 +242,8 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { if (app.verbosity >= 2) { time_t now = time(nullptr); char* dt = ctime(&now); // NOLINT + // TODO: replace hardware_concurrency with detected topology. std::cout << "Date & Time : " << dt - << "Prefill Token Batch Size : " << kPrefillBatchSize << "\n" << "Hardware concurrency : " << std::thread::hardware_concurrency() << "\n" << "Instruction set : " diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h index 1de6d70..21b0d2f 100644 --- a/evals/benchmark_helper.h +++ b/evals/benchmark_helper.h @@ -69,7 +69,7 @@ class GemmaEnv { // the number of tokens that were generated. 
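The lazy-growth pattern above can be factored into a helper for other callers. This is a hypothetical sketch (EnsureKVCaches is not part of the patch), relying on the new `seq_len == 0` state of a default-constructed KVCache:

```cpp
#include <cstddef>
#include <vector>

#include "gemma/gemma.h"     // KVCaches
#include "gemma/kv_cache.h"  // KVCache

// Grows `caches` to `num_queries` entries, allocating only those not yet
// created, and returns a span over the first `num_queries`. The span aliases
// the vector, so the vector must not be resized while the span is in use.
gcpp::KVCaches EnsureKVCaches(std::vector<gcpp::KVCache>& caches,
                              gcpp::Model model, size_t prefill_tbatch_size,
                              size_t num_queries) {
  if (caches.size() < num_queries) caches.resize(num_queries);
  for (size_t i = 0; i < num_queries; ++i) {
    if (caches[i].seq_len == 0) {  // default-constructed, not yet allocated
      caches[i] = gcpp::KVCache::Create(model, prefill_tbatch_size);
    }
  }
  return gcpp::KVCaches(&caches[0], num_queries);
}
```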
std::pair QueryModel(const std::vector& tokens); std::vector> BatchQueryModel2( - const hwy::Span>& prompts); + const MultiplePromptsTokens& prompts); // Adds turn structure to input, tokenizes and calls the above overload. std::pair QueryModel(std::string& input); std::vector> BatchQueryModel( @@ -88,7 +88,7 @@ class GemmaEnv { const ModelInfo& Info() const { return loader_.Info(); } InferenceArgs& MutableInferenceArgs() { return inference_args_; } std::mt19937& MutableGen() { return gen_; } - KVCache& MutableKVCache() { return *kv_caches_[0]; } + KVCache& MutableKVCache() { return kv_caches_[0]; } private: // Arguments to the model loader: file locations, etc. @@ -103,8 +103,8 @@ class GemmaEnv { std::mt19937 gen_; // The model to run inference on. std::unique_ptr model_; - // The KV cache to use for inference. - std::vector kv_caches_; + // KV caches, same number as query batch. + std::vector kv_caches_; RuntimeConfig runtime_config_; }; diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc index 6436908..29c189e 100644 --- a/evals/gemma_test.cc +++ b/evals/gemma_test.cc @@ -17,14 +17,12 @@ #include -#include #include #include #include "evals/benchmark_helper.h" #include "gemma/common.h" -#include "gemma/tokenizer.h" -#include "hwy/aligned_allocator.h" +#include "hwy/base.h" #include "hwy/tests/hwy_gtest.h" // This test can be run manually with the downloaded gemma weights. @@ -75,21 +73,17 @@ class GemmaTest : public ::testing::Test { replies.push_back(response); } } else { // Not Gemma-2 27B. Do not use turn structure. - std::vector>> prompts; - prompts.reserve(inputs.size()); - for (auto input_string : inputs) { - std::string mutable_input_string = input_string; - prompts.push_back(std::make_unique>( - s_env->TokenizeAndPrependBOS(input_string))); + std::vector> prompts_vector; + prompts_vector.reserve(inputs.size()); + for (const auto& input_string : inputs) { + prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string)); } - std::vector> prompt_vector; - for (auto& prompt : prompts) { - prompt_vector.push_back(hwy::Span(prompt->data(), prompt->size())); + std::vector prompt_spans; + for (const auto& prompt : prompts_vector) { + prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size())); } - hwy::Span> prompt_span = - hwy::Span>(prompt_vector.data(), - prompt_vector.size()); - for (auto [response, n] : s_env->BatchQueryModel2(prompt_span)) { + MultiplePromptsTokens prompts(prompt_spans.data(), prompt_spans.size()); + for (auto [response, n] : s_env->BatchQueryModel2(prompts)) { replies.push_back(response); } } @@ -121,18 +115,20 @@ class GemmaTest : public ::testing::Test { } }; -TEST_F(GemmaTest, Geography) { +TEST_F(GemmaTest, GeographyBatched) { + s_env->MutableInferenceArgs().decode_qbatch_size = 3; + // 6 are enough to test batching and the loop. 
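Because prompts are now passed as non-owning spans, callers must keep the token vectors (and the span array) alive across the call. A usage sketch mirroring gemma_test.cc, assuming TokenizeAndPrependBOS returns std::vector<int> and BatchQueryModel2 returns (reply, token count) pairs as above:

```cpp
#include <cstdio>
#include <string>
#include <vector>

#include "evals/benchmark_helper.h"  // GemmaEnv
#include "gemma/gemma.h"             // PromptTokens, MultiplePromptsTokens

void QueryAll(gcpp::GemmaEnv& env, const std::vector<std::string>& inputs) {
  // Owning storage: must outlive the BatchQueryModel2 call.
  std::vector<std::vector<int>> tokenized;
  tokenized.reserve(inputs.size());
  for (const std::string& input : inputs) {
    tokenized.push_back(env.TokenizeAndPrependBOS(input));
  }
  // Non-owning views over the owning storage.
  std::vector<gcpp::PromptTokens> spans;
  spans.reserve(tokenized.size());
  for (const auto& tokens : tokenized) {
    spans.push_back(gcpp::PromptTokens(tokens.data(), tokens.size()));
  }
  const gcpp::MultiplePromptsTokens prompts(spans.data(), spans.size());
  for (const auto& [reply, num_tokens] : env.BatchQueryModel2(prompts)) {
    printf("(%zu tokens) %s\n", num_tokens, reply.c_str());
  }
}
```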
static const char* kQA[][2] = { - {"What is the capital of Hungary?", "Budapest"}, {"What is the capital of Australia?", "Canberra"}, + {"What is the capital of Denmark?", "Copenhagen"}, + {"Ljubljana is the capital of which country?", "Slovenia"}, + {"Is Chicago a country?", "not"}, {"How many states does the US have?", "50"}, + {"What is the Pacific?", "ocean"}, }; static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]); - TestQuestions(kQA, kNum, /*batch=*/false); - static const char* kQA_single_question[][2] = { - {"What is the capital of Australia?", "Canberra"}, - }; - TestQuestions(kQA_single_question, 1, /*batch=*/true); + TestQuestions(kQA, HWY_MIN(kNum, 3), /*batch=*/false); + TestQuestions(kQA, 1, /*batch=*/true); TestQuestions(kQA, kNum, /*batch=*/true); } diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index f3a331e..bfef76f 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -31,6 +31,7 @@ int main(int argc, char** argv) { gcpp::LoaderArgs loader(argc, argv); + gcpp::InferenceArgs inference(argc, argv); if (gcpp::HasHelp(argc, argv)) { loader.Help(); return 0; @@ -42,7 +43,8 @@ int main(int argc, char** argv) { // Instantiate model and KV Cache hwy::ThreadPool pool(gcpp::AppArgs::GetSupportedThreadCount()); gcpp::Gemma model = gcpp::CreateGemma(loader, pool); - gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.Info().model); + gcpp::KVCache kv_cache = + gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size); size_t pos = 0; // KV Cache position // Initialize random number generator diff --git a/gemma/common.h b/gemma/common.h index 099314f..7471ceb 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -36,11 +36,6 @@ ByteStorageT AllocateSizeof() { return hwy::AllocateAligned(sizeof(T)); } -// Relatively small so that we can also parallelize non-Matmul work. There is -// one outer thread per batch, each with --num_threads / batches inner threads. -constexpr size_t kPrefillBatchSize = 64; -constexpr size_t kBatchedQueryBatchSize = 16; - // Model variants: see configs.h for details. When adding a new one, also // update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc. enum class Model { diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 62ebcb0..d4b6745 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -73,10 +73,10 @@ template HWY_NOINLINE void GriffinRecurrent( size_t batch_start, size_t num_tokens, size_t num_queries, size_t layer, Activations& activations, const CompressedLayer* layer_weights, - const std::vector& kv_caches, hwy::ThreadPool& pool) { + const KVCaches& kv_caches, hwy::ThreadPool& pool) { PROFILER_ZONE("Gen.Griffin"); HWY_ASSERT(num_queries == 1); // TODO: add batch query support for Griffin. 
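End to end, the hello_world flow after this change looks roughly as follows. This is a condensed sketch (error handling omitted, prompt text illustrative); it assumes WrapAndTokenize and GemmaTokenizer::Decode behave as used elsewhere in the repository:

```cpp
#include <cstdio>
#include <random>
#include <string>
#include <vector>

#include "gemma/gemma.h"
#include "util/app.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

int RunHelloWorld(int argc, char** argv) {
  gcpp::LoaderArgs loader(argc, argv);
  gcpp::InferenceArgs inference(argc, argv);  // new: provides batch sizes

  hwy::ThreadPool pool(gcpp::AppArgs::GetSupportedThreadCount());
  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
  gcpp::KVCache kv_cache = gcpp::KVCache::Create(
      loader.Info().model, inference.prefill_tbatch_size);

  std::string text = "Write a greeting to the world.";
  std::vector<int> tokens =
      gcpp::WrapAndTokenize(model.Tokenizer(), model.Info(), /*pos=*/0, text);

  std::mt19937 gen(42);
  gcpp::RuntimeConfig runtime_config = {
      .verbosity = 0,
      .gen = &gen,
      .stream_token = [&](int token, float) {
        std::string piece;
        (void)model.Tokenizer().Decode(std::vector<int>{token}, &piece);
        printf("%s", piece.c_str());
        return true;
      }};
  inference.CopyTo(runtime_config);  // max tokens, tbatch/qbatch, temperature

  gcpp::TimingInfo timing_info;
  model.Generate(runtime_config,
                 gcpp::PromptTokens(tokens.data(), tokens.size()),
                 /*start_pos=*/0, kv_cache, timing_info);
  return 0;
}
```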
- KVCache& kv_cache = *kv_caches[0]; + KVCache& kv_cache = kv_caches[0]; namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; static constexpr size_t kModelDim = TConfig::kModelDim; @@ -208,7 +208,7 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens, size_t num_queries, size_t layer, Activations& activations, const CompressedLayer* layer_weights, - const std::vector& kv_caches, + const KVCaches& kv_caches, hwy::ThreadPool& pool) { PROFILER_ZONE("Gen.Attention"); HWY_DASSERT(interleaved_start % num_queries == 0); @@ -221,6 +221,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens, constexpr size_t kKVHeads = TConfig::kKVHeads; constexpr size_t kSeqLen = TConfig::kSeqLen; GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale(); + + HWY_ASSERT(num_queries <= kv_caches.size()); + const hwy::Divisor div_seq_len(static_cast(kv_caches[0].seq_len)); + // Multi-Head Attention a.k.a. "use_qkv_einsum". constexpr bool kIsMHA = Activations::IsMHA(); static_assert(!kIsMHA || TConfig::kInterleaveQKV); // MHA => interleaved @@ -245,9 +249,9 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens, const float* x = activations.pre_att_rms_out.Batch(interleaved_idx); const size_t query_idx = interleaved_idx % num_queries; const size_t batch_idx = interleaved_idx / num_queries; - KVCache& kv_cache = *kv_caches[query_idx]; + KVCache& kv_cache = kv_caches[query_idx]; const size_t pos = batch_start + batch_idx; - const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize); + const size_t cache_pos = div_seq_len.Remainder(pos); const size_t kv_offset = cache_pos * kCachePosSize + layer * kCacheLayerSize; float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; @@ -268,10 +272,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens, const size_t query_idx = interleaved_idx % num_queries; const size_t batch_idx = interleaved_idx / num_queries; const size_t pos = batch_start + batch_idx; - const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize); + const size_t cache_pos = div_seq_len.Remainder(pos); const size_t kv_offset = cache_pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim * 2; - KVCache& kv_cache = *kv_caches[query_idx]; + KVCache& kv_cache = kv_caches[query_idx]; float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; if constexpr (kIsMHA) { // For MHA, copy KV into the KV cache from scratch space (see above). 
@@ -297,7 +301,7 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens, const size_t query_idx = interleaved_idx % num_queries; const size_t batch_idx = interleaved_idx / num_queries; const size_t head_offset = (head / kHeadGroups) * kQKVDim * 2; - KVCache& kv_cache = *kv_caches[query_idx]; + KVCache& kv_cache = kv_caches[query_idx]; float* HWY_RESTRICT q = activations.q.Batch(interleaved_idx) + head * kQStride; @@ -314,10 +318,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens, const size_t start_pos = pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos); for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { - const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize); + const size_t cache_pos = div_seq_len.Remainder(pos2); const size_t kv_offset = cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset; - const float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset; + const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset]; const float score = Dot(q, k, kQKVDim); head_att[pos2 % kSeqLen] = score; } @@ -337,7 +341,7 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens, activations.att_out.Batch(interleaved_idx) + head * kQKVDim; hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { - const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize); + const size_t cache_pos = div_seq_len.Remainder(pos2); const size_t kv_offset = cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset; float* HWY_RESTRICT v = @@ -383,8 +387,7 @@ HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start, size_t num_tokens, size_t num_queries, size_t layer, Activations& activations, const CompressedLayer* layer_weights, - const std::vector& kv_caches, - hwy::ThreadPool& pool) { + const KVCaches& kv_caches, hwy::ThreadPool& pool) { if (type == LayerAttentionType::kGemma) { GemmaAttention(interleaved_start, num_tokens, num_queries, layer, activations, layer_weights, kv_caches, pool); @@ -458,12 +461,13 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved, output_bias, pool); } -// TODO: pass Activations.x instead of Activations. -// `pos` is for the entire batch and does not include `batch_idx`. +// `batch_idx` indicates which row of `x` to write to. +// `pos` is the *token*'s position, not the start of the batch, because this is +// called for batches of tokens in prefill, but batches of queries in decode. 
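The sliding-window bound computed above (`start_pos`) clamps the attention window at the beginning of the sequence. A worked sketch with illustrative numbers:

```cpp
#include <algorithm>
#include <cstddef>

// Inclusive start of the attention window for a token at `pos`, given the
// layer's window size (TConfig::kAttentionWindowSizes[layer] above).
inline size_t WindowStart(size_t pos, size_t window_size) {
  return pos - std::min(window_size - 1, pos);
}

// window_size = 4096: pos = 10   -> start 0    (11 positions attended)
//                     pos = 5000 -> start 905  (4096 positions attended)
```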
template HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos, const CompressedWeights& weights, - Activations& activations) { + RowVectorBatch& x) { constexpr size_t kModelDim = TConfig::kModelDim; GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling = EmbeddingScaling(); @@ -472,11 +476,10 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos, HWY_DASSERT(token < TConfig::kVocabSize); Decompress(weights.embedder_input_embedding, token * kModelDim, - activations.x.Batch(batch_idx), kModelDim); - MulByConst(kEmbScaling, activations.x.Batch(batch_idx), kModelDim); + x.Batch(batch_idx), kModelDim); + MulByConst(kEmbScaling, x.Batch(batch_idx), kModelDim); if constexpr (TConfig::kAbsolutePE) { - AddAbsolutePositionalEmbeddings(activations.x.Batch(batch_idx), kModelDim, - pos + batch_idx); + AddAbsolutePositionalEmbeddings(x.Batch(batch_idx), kModelDim, pos); }; } @@ -501,7 +504,7 @@ template HWY_NOINLINE void TransformerLayer( size_t num_tokens, size_t num_queries, size_t pos, size_t layer, const CompressedLayer* layer_weights, Activations& activations, - const std::vector& kv_caches, hwy::ThreadPool& pool) { + const KVCaches& kv_caches, hwy::ThreadPool& pool) { constexpr size_t kModelDim = TConfig::kModelDim; const size_t num_interleaved = num_tokens * num_queries; auto type = TConfig::kLayerConfig[layer]; @@ -536,116 +539,220 @@ HWY_NOINLINE void TransformerLayer( /*is_attention=*/false); } -// For prefill, we have two-level parallelism: -// - Outer: input tokens are split into batches, each of which is one task -// processed by a worker in `outer_pool_`, which includes the main thread -// because it is the one that calls `Prefill`. +// Batches are important for amortizing loading weights over multiple tokens. +// This is possible in prefill because we know all tokens beforehand, whereas +// decode depends on the previous output token. However, each prefill batch of a +// query requires that preceding batches already wrote to the KV cache, hence we +// sequentially loop over token batches. We can reduce the number of iterations +// by increasing the batch size, but this also increases arithmetic intensity, +// and so we are eventually compute-limited. The tensor parallelism (number of +// threads collaborating on MatMul) is also limited by the CPU topology: +// fork/join barriers are slow(er) when some threads reside in a different NUMA +// node. To allow more threads to help, we also support parallelizing over +// queries in case GenerateBatch was called. +// +// Thus we have two-level parallelism: +// - Outer: handles one 'qbatch' of entire queries. The set of outer workers +// includes the main thread because it is the one that calls `Prefill`, and is +// determined by the number of 'clusters' (shared L3 caches or sockets). // - Inner: each `outer` worker passes `inner_pools_[outer]` to -// `TransformerLayer` for tensor-level parallelism. +// `TransformerLayer` for tensor-level parallelism, and processes +// `tbatch_size` tokens from a single query at a time. // -// This class holds the thread pools and activations, recreated for each query. -// -// It is safe to parallelize batches because we write to KVCache at -// `pos % kSeqLen`, which is far greater than the number of workers. -// Note however that this currently leads to nondeterministic results because -// the RNG is invoked in different order. +// This class holds the thread pools and one activation per outer worker. 
It is +// NOT reused across calls to GenerateSingle/GenerateBatch so that we can adapt +// to their num_queries. class PrefillState { - public: - explicit PrefillState(hwy::ThreadPool& main_pool) : main_pool_(&main_pool) {} - - ~PrefillState() { DeleteInnerPools(); } - - // Called before each query. Recreates thread pools, which has the advantage - // of tailoring the parallelism to the prompt length. - template - void Init(size_t prefill_size) { - // Would be zero for single-token prompts (prefill_size == num_tokens - 1). - num_batches_ = - HWY_MAX(size_t{1}, hwy::DivCeil(prefill_size, kPrefillBatchSize)); - // More than num_batches_ would waste workers on idling in the outer Run; - // more than NumWorkers() would exceed the global --num_threads. - const size_t outer_workers = - HWY_MIN(num_batches_, main_pool_->NumWorkers()); - HWY_ASSERT(outer_workers != 0); // Otherwise activations_ is empty. - - // One activation per outer worker. Allocating in parallel saves 30 ms. - activations_.resize(outer_workers); - main_pool_->Run(0, outer_workers, [this](uint64_t task, size_t /*thread*/) { - activations_[task].Allocate(kPrefillBatchSize); + // TODO: move helper functions, also those in app.h, to a threading header + using LPS = hwy::LogicalProcessorSet; + LPS Intersection(const LPS& big, const LPS& small) { + LPS both; + // Reduce expected work by iterating over the smaller set. + small.Foreach([big, &both](size_t idx) { + if (big.Get(idx)) both.Set(idx); }); + return both; + } - DeleteInnerPools(); + std::vector CoresInLPS(const LPS& cluster) { + std::vector cores; + cores.reserve(cluster.Count()); + cluster.Foreach([&cores](size_t idx) { cores.push_back(idx); }); + return cores; + } - // If we'd create just one inner pool with all the workers, skip the cost of - // thread creation and pinning (about 60 ms) by reusing the main pool. - if (outer_workers <= 1) { - // Still allocate a dummy pool to simplify Prefill(). - outer_pool_ = std::make_unique(1); - inner_pools_.push_back(main_pool_); - return; + // For each cluster (shared L3 cache), a bitset of cores. + using CoresPerCluster = std::vector; + + // Returns empty if detection failed. + CoresPerCluster DetectClusters() { + CoresPerCluster clusters; + // Which processors are not disabled via OS, taskset, or numactl. + LPS enabled; + // If we don't know, better to use just a single inner pool rather than risk + // oversubscribing to enabled cores. + if (!GetThreadAffinity(enabled)) return clusters; + + hwy::Topology topology; + if (topology.packages.empty()) return clusters; + + // For each cluster = outer, the cores that will be used for an inner pool. + CoresPerCluster inner_lps; + for (const hwy::Topology::Package& package : topology.packages) { + for (const hwy::Topology::Cluster& cluster : package.clusters) { + // Only use enabled cores, and only add if not empty. + const LPS lps = Intersection(enabled, cluster.lps); + if (lps.Any()) clusters.push_back(lps); + } } + // Sort by descending number of enabled cores, so that we preferentially + // use the largest clusters. + std::sort(clusters.begin(), clusters.end(), + [](const LPS& a, const LPS& b) { return a.Count() > b.Count(); }); + + return clusters; + } + + // Returns false if the main pool should be reused instead. + bool AssignInnerPoolsToClusters(const size_t num_queries) { + HWY_ASSERT(num_queries != 0); + + CoresPerCluster inner_lps = DetectClusters(); + // If we have more outer workers than queries, discard the excess. 
+ if (inner_lps.size() > num_queries) inner_lps.resize(num_queries); + // If we're not going to create multiple pools, avoid the overhead of + // re-pinning (60 ms) and reuse the main pool. + if (inner_lps.size() <= 1) return false; + // Before creating new threads, stop the old ones from spinning. Caller is // responsible for undoing this by calling `ResumeMainSpinning`. main_pool_->SetWaitMode(hwy::PoolWaitMode::kBlock); - outer_pool_ = std::make_unique(outer_workers); + + outer_pool_ = std::make_unique(inner_lps.size()); outer_pool_->SetWaitMode(hwy::PoolWaitMode::kSpin); - // Assign up to `max_workers` to inner pools. Each inner pool creates - // `workers_per_outer - 1` threads in addition to its 'main' thread, which - // is the one calling `inner_pools[outer]->Run`, i.e., `outer`. In total, - // `outer_workers * (max_workers / outer_workers)` workers are used. - const size_t workers_per_outer = main_pool_->NumWorkers() / outer_workers; - for (size_t outer = 0; outer < outer_workers; ++outer) { - inner_pools_.push_back(new hwy::ThreadPool(workers_per_outer)); + HWY_ASSERT(inner_pools_.empty()); + for (const LPS& inner : inner_lps) { + inner_pools_.push_back(new hwy::ThreadPool(inner.Count())); inner_pools_.back()->SetWaitMode(hwy::PoolWaitMode::kSpin); } - PinThreads(outer_workers, workers_per_outer); + // For each inner pool, pin their threads AND the associated outer thread + // to the enabled cores in the cluster. + outer_pool_->Run( + 0, inner_lps.size(), + [this, &inner_lps](uint64_t outer, size_t outer_thread) { + HWY_ASSERT(outer == outer_thread); // each outer has one task + const std::vector cores = CoresInLPS(inner_lps[outer]); + + inner_pools_[outer]->Run( + 0, cores.size(), [&cores](uint64_t task, size_t thread) { + HWY_ASSERT(task == thread); // each inner has one task + hwy::PinThreadToLogicalProcessor(cores[task]); + }); + }); + + return true; } - // `tokens` are from interleaved queries. (See InterleaveQueries() below.) + void ReuseMainPoolAsInner() { + // Still allocate an empty pool to simplify Prefill(). + outer_pool_ = std::make_unique(1); + + HWY_ASSERT(inner_pools_.empty()); + inner_pools_.push_back(main_pool_); + } + + public: + // Creates pools. AllocateActivations must still be called separately; it has + // a template argument. + PrefillState(hwy::ThreadPool& main_pool, size_t num_queries) + : main_pool_(&main_pool) { + PROFILER_ZONE("Init.Prefill.Ctor"); + if (!AssignInnerPoolsToClusters(num_queries)) { + ReuseMainPoolAsInner(); + } + } + + ~PrefillState() { + for (hwy::ThreadPool* p : inner_pools_) { + if (p != main_pool_) delete p; + } + } + + // `tbatch_size` is the number of tokens from one query to prefill at a time. template - HWY_NOINLINE void Prefill(hwy::Span tokens, size_t num_queries, - size_t pos, + void AllocateActivations(size_t num_queries, size_t tbatch_size) { + PROFILER_ZONE("Init.Prefill.AllocateActivations"); + + const size_t outer_workers = outer_pool_->NumWorkers(); + HWY_ASSERT(outer_workers != 0); // Otherwise activations_ is empty. + + HWY_ASSERT(activations_.empty()); // only call once. + activations_.resize(outer_workers); + + if (outer_workers == 1) { + activations_[0].Allocate(tbatch_size); + } else { + // Allocating in parallel can save 30 ms. 
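For orientation, the nesting that PrefillState sets up can be pictured as below: the outer pool hands whole queries to workers, and each worker drives the inner pool of its cluster. A schematic sketch with a hypothetical task count; like the real code, it indexes the per-worker state by the outer thread index:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

#include "hwy/contrib/thread_pool/thread_pool.h"

// `inner_pools.size()` must be >= outer.NumWorkers(); queries beyond the
// worker count are processed sequentially by whichever worker picks them up.
void NestedPrefill(hwy::ThreadPool& outer,
                   std::vector<hwy::ThreadPool*>& inner_pools,
                   size_t num_queries) {
  outer.Run(0, num_queries, [&](uint64_t qi, size_t outer_thread) {
    hwy::ThreadPool& inner = *inner_pools[outer_thread];
    // Tensor-level parallelism for query `qi` (e.g. rows of a MatMul):
    inner.Run(0, /*num_tasks=*/8, [&](uint64_t task, size_t /*thread*/) {
      (void)qi;
      (void)task;  // ... per-task work ...
    });
  });
}
```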
+ main_pool_->Run(0, outer_workers, + [this, tbatch_size](uint64_t task, size_t /*thread*/) { + activations_[task].Allocate(tbatch_size); + }); + } + } + + template + HWY_NOINLINE void Prefill(const MultiplePromptsTokens& prompts, + const size_t prefill_per_query, const size_t pos, + const size_t query_idx_start, const CompressedWeights& weights, const RuntimeConfig& runtime_config, - const std::vector& kv_caches) { + const KVCaches& kv_caches) { PROFILER_ZONE("Gen.Prefill"); + const size_t num_queries = prompts.size(); + HWY_ASSERT(kv_caches.size() == num_queries); + const size_t max_tbatch_size = activations_[0].x.BatchSize(); - HWY_ASSERT(activations_.size() == outer_pool_->NumWorkers()); - HWY_ASSERT(inner_pools_.size() == outer_pool_->NumWorkers()); - + // For each query (parallel): an outer worker processes all its tokens. + // `qi` is relative to the batch, not the global query index. outer_pool_->Run( - 0, num_batches_, [&](const uint64_t batch_num, size_t thread) HWY_ATTR { - const size_t batch_start = batch_num * kPrefillBatchSize; - const size_t batch_size = - HWY_MIN(kPrefillBatchSize, tokens.size() - batch_start); - HWY_DASSERT(batch_start % num_queries == 0); - HWY_DASSERT(batch_size % num_queries == 0); - const size_t pos_per_query = pos + batch_start / num_queries; - const size_t num_tokens = batch_size / num_queries; + 0, num_queries, [&](const uint64_t qi, size_t qthread) HWY_ATTR { + Activations& activations = activations_[qthread]; + hwy::ThreadPool& inner_pool = *inner_pools_[qthread]; - // Negligible time compared to TransformerLayer. - for (size_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { - EmbedToken(tokens[batch_start + batch_idx], batch_idx, - pos_per_query, weights, activations_[thread]); - } + // Single query at a time, so pass a slice of the KV cache because + // GemmaAttention will only access the first. + const size_t kPrefillQueries = 1; + KVCaches prefill_kv_caches(&kv_caches[qi], kPrefillQueries); - for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { - const auto* layer_weights = weights.GetLayer(layer); - TransformerLayer( - num_tokens, num_queries, pos_per_query, layer, layer_weights, - activations_[thread], kv_caches, *inner_pools_[thread]); - } + // For each batch of tokens in the query: + for (size_t tbatch_start = 0; tbatch_start < prefill_per_query; + tbatch_start += max_tbatch_size) { + // Fill activations.x (much faster than TransformerLayer). + const size_t tbatch_size = + HWY_MIN(max_tbatch_size, prefill_per_query - tbatch_start); + for (size_t ti = 0; ti < tbatch_size; ++ti) { + const int token = prompts[qi][tbatch_start + ti]; + EmbedToken(token, ti, pos + ti, weights, activations.x); + } - // NOTE: we unconditionally call StreamToken, even if EOS. - for (size_t i = 0; i < batch_size; ++i) { - const size_t query_idx = i % num_queries; - const size_t batch_idx = i / num_queries; - runtime_config.StreamToken(query_idx, pos_per_query + batch_idx, - tokens[i], 0.0f); - } + // Transformer with one batch of tokens from a single query. + for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { + const auto* layer_weights = weights.GetLayer(layer); + TransformerLayer( + tbatch_size, kPrefillQueries, pos + tbatch_start, layer, + layer_weights, activations, prefill_kv_caches, inner_pool); + } + + // NOTE: we unconditionally call StreamToken, even if EOS. 
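The token-batch loop below visits each query's prompt in chunks of at most `prefill_tbatch_size`; a worked sketch of the loop bounds (helper names are illustrative):

```cpp
#include <cstddef>

#include "hwy/base.h"  // hwy::DivCeil, HWY_MIN

struct TBatch {
  size_t start;  // offset of the first token in the batch
  size_t size;   // number of tokens in the batch
};

inline size_t NumTBatches(size_t prefill_per_query, size_t max_tbatch_size) {
  return hwy::DivCeil(prefill_per_query, max_tbatch_size);
}

inline TBatch GetTBatch(size_t i, size_t prefill_per_query,
                        size_t max_tbatch_size) {
  const size_t start = i * max_tbatch_size;
  return {start, HWY_MIN(max_tbatch_size, prefill_per_query - start)};
}

// e.g. prefill_per_query = 100 and --prefill_tbatch 32 gives 4 batches of
// 32, 32, 32 and 4 tokens; each batch runs all layers before the next starts,
// because later batches read earlier batches' KV entries.
```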
+ for (size_t ti = 0; ti < tbatch_size; ++ti) { + const int token = prompts[qi][tbatch_start + ti]; + runtime_config.StreamToken(query_idx_start + qi, + pos + tbatch_start + ti, token, 0.0f); + } + } // for tbatch_start }); } @@ -663,39 +770,15 @@ class PrefillState { } private: - // Pins each outer thread after their inner threads so they are likely to - // run on the same socket. - void PinThreads(size_t outer_workers, size_t workers_per_outer) { - outer_pool_->Run( - 0, outer_workers, - [this, workers_per_outer](uint64_t outer, size_t outer_thread) { - HWY_ASSERT(outer == outer_thread); - // Pins inner *and* `outer` - the latter is the calling thread. - inner_pools_[outer]->Run( - 0, workers_per_outer, - [outer, workers_per_outer](uint64_t task, size_t thread) { - HWY_ASSERT(task == thread); // each worker has one task - const size_t lp = outer * workers_per_outer + task; - hwy::PinThreadToLogicalProcessor(lp); - }); - }); - } - - void DeleteInnerPools() { - for (hwy::ThreadPool* p : inner_pools_) { - if (p != main_pool_) delete p; - } - inner_pools_.clear(); - } - hwy::ThreadPool* main_pool_; std::unique_ptr outer_pool_; // always allocated - std::vector activations_; // size == outer_pool->NumWorkers() - // Either there is a single pointer equal to main_pool, or newly created pools - // that we own. The former case avoids thread creation overhead for prompts - // that fit in a single batch. + // Holds a single pointer equal to main_pool_, or new allocations; in either + // case, size() is equal to outer_pool_->NumWorkers(). The first case avoids + // allocation overhead for the common case of a single query. std::vector inner_pools_; - size_t num_batches_ = 0; + + // size() == outer_pool_->NumWorkers(); filled by AllocateActivations. + std::vector activations_; }; // `tokens` is length `num_tokens * num_queries`. In autoregressive decode, @@ -705,8 +788,7 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, size_t num_queries, size_t pos, const CompressedWeights& weights, Activations& activations, - const std::vector& kv_caches, - hwy::ThreadPool& pool, + const KVCaches& kv_caches, hwy::ThreadPool& pool, const LayersOutputFunc& layers_output) { const size_t num_interleaved = num_tokens * num_queries; if (layers_output) { @@ -718,7 +800,7 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, constexpr size_t kModelDim = TConfig::kModelDim; for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) { EmbedToken(tokens[token_idx], token_idx, pos, weights, - activations); + activations.x); } for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { @@ -781,10 +863,10 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens, // Returns interleaved tokens: one from each query, followed by the second from // all queries, with EOS padding. -static std::vector InterleaveQueries( - const hwy::Span>& queries, - const RuntimeConfig& runtime_config, size_t& min_prompt_size, - size_t& max_prompt_size) { +static std::vector InterleaveQueries(const MultiplePromptsTokens& queries, + const RuntimeConfig& runtime_config, + size_t& min_prompt_size, + size_t& max_prompt_size) { const size_t num_queries = queries.size(); min_prompt_size = hwy::LimitsMax(); max_prompt_size = 0; @@ -829,28 +911,34 @@ class TokenStreamer { private: const RuntimeConfig& runtime_config_; - // BitSet4096 divides the arg by 64, so ensure it is at least 64. - hwy::BitSet4096 is_eos_; + hwy::BitSet4096<> is_eos_; }; -// Generates one token per query in the batch. 
+// Generates one token for each query in `prompts`, which is one qbatch whose +// size is at most the `batch_size` passed to `activations.Allocate`. // -// pos indexes the KV cache. In the first turn of a chat, pos = 0, and it +// `pos` indexes the KV cache. In the first turn of a chat, pos = 0, and it // continues to increase by one for each prefilled/generated token per query. -// query_idx_start is the first query index in the batch. -template +// +// `query_idx_start` is the query_idx of the first query in the batch, so that +// `StreamFunc` gets the global query index, not relative to the batch. +// +// `kv_caches` is for the batch, size must match `prompts`. +template void GenerateT(const ByteStorageT& weights_u8, Activations& activations, const RuntimeConfig& runtime_config, - const hwy::Span>& prompts, const size_t pos, - const size_t query_idx_start, - const std::vector& kv_caches, hwy::ThreadPool& pool, - TimingInfo& timing_info) { + const MultiplePromptsTokens& prompts, const size_t pos, + const size_t query_idx_start, const KVCaches& kv_caches, + hwy::ThreadPool& pool, TimingInfo& timing_info) { constexpr size_t kVocabSize = TConfig::kVocabSize; const CompressedWeights& weights = *reinterpret_cast*>(weights_u8.get()); const size_t num_queries = prompts.size(); - HWY_DASSERT(num_queries <= kQueryBatchSize); + HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096. + HWY_ASSERT(num_queries <= activations.x.BatchSize()); + HWY_ASSERT(kv_caches.size() == num_queries); + size_t min_prompt_size, max_prompt_size; const std::vector prompt = InterleaveQueries( prompts, runtime_config, min_prompt_size, max_prompt_size); @@ -877,28 +965,28 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, // Prefill stops before min_prompt_size - 1 because the last prompt token is // the first input token for generation. const size_t prefill_per_query = min_prompt_size - 1; - const hwy::Span prefill_tokens(prompt.data(), - prefill_per_query * num_queries); - PrefillState prefill(pool); - prefill.Init(prefill_tokens.size()); - const double prefill_start = hwy::platform::Now(); - size_t interleaved_pos = pos * num_queries; - prefill.Prefill(prefill_tokens, num_queries, interleaved_pos, - weights, runtime_config, kv_caches); - interleaved_pos += prefill_tokens.size(); - timing_info.NotifyPrefill(prefill_tokens.size(), prefill_start); + double prefill_start; + { + PrefillState prefill(pool, num_queries); + prefill.AllocateActivations(num_queries, + runtime_config.prefill_tbatch_size); + prefill_start = hwy::platform::Now(); + prefill.Prefill(prompts, prefill_per_query, pos, query_idx_start, + weights, runtime_config, kv_caches); + timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start); + prefill.ResumeMainSpinning(); + } - prefill.ResumeMainSpinning(); + size_t interleaved_pos = (pos + prefill_per_query) * num_queries; // Storage for the last generated token from each query, passed to the next // Transformer() call. std::vector gen_tokens(num_queries); // Stream the last prompt token from each query and fill gen_tokens. 
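For context, InterleaveQueries produces one token from each query, then the next token from each query, padding shorter prompts with EOS. A hypothetical standalone version with an explicit eos_id parameter (the real function takes it from RuntimeConfig and also reports min/max prompt sizes):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<int> Interleave(const std::vector<std::vector<int>>& queries,
                            int eos_id) {
  size_t max_size = 0;
  for (const auto& q : queries) max_size = std::max(max_size, q.size());
  std::vector<int> interleaved;
  interleaved.reserve(max_size * queries.size());
  for (size_t pos = 0; pos < max_size; ++pos) {
    for (const auto& q : queries) {
      interleaved.push_back(pos < q.size() ? q[pos] : eos_id);
    }
  }
  return interleaved;
}

// q0 = {1, 2, 3}, q1 = {4, 5}  ->  {1, 4, 2, 5, 3, eos}
// (min_prompt_size = 2, max_prompt_size = 3 in the real function)
```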
- hwy::CopyBytes(&prompt[prefill_tokens.size()], gen_tokens.data(), - num_queries * sizeof(prompt[0])); TokenStreamer token_streamer(runtime_config); for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { + gen_tokens[query_idx] = prompts[query_idx][prefill_per_query]; (void)token_streamer(query_idx_start + query_idx, prefill_per_query, gen_tokens[query_idx], 0.0f); } @@ -940,42 +1028,49 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, timing_info.NotifyGenerateDone(gen_start); } -// TODO: prompt should also be span, not a vector. template -void GenerateSingleT(const ByteStorageT& weights_u8, Activations& activations, +void GenerateSingleT(const ByteStorageT& weights_u8, const RuntimeConfig& runtime_config, - const std::vector& prompt, size_t pos, - KVCache& kv_cache, hwy::ThreadPool& pool, - TimingInfo& timing_info) { - const hwy::Span prompt_span(const_cast(prompt.data()), - prompt.size()); - const hwy::Span> prompts(&prompt_span, 1); - // TODO: also span of kv_cache, or batching inside KVCache? - std::vector kv_caches = {&kv_cache}; - const size_t query_idx_start = 0; - GenerateT( - weights_u8, activations, runtime_config, prompts, pos, query_idx_start, - kv_caches, pool, timing_info); + const PromptTokens& prompt, size_t pos, KVCache& kv_cache, + hwy::ThreadPool& pool, TimingInfo& timing_info) { + const size_t num_queries = 1; + const size_t qbatch_start = 0; + + Activations activations; + activations.Allocate(num_queries); + + const MultiplePromptsTokens prompts(&prompt, num_queries); + const KVCaches kv_caches{&kv_cache, num_queries}; + + GenerateT(weights_u8, activations, runtime_config, prompts, pos, + qbatch_start, kv_caches, pool, timing_info); } template -void GenerateBatchT(const ByteStorageT& weights_u8, Activations& activations, +void GenerateBatchT(const ByteStorageT& weights_u8, const RuntimeConfig& runtime_config, - const hwy::Span>& prompts, size_t pos, - const std::vector& kv_caches, - hwy::ThreadPool& pool, TimingInfo& timing_info) { - // Disable query batching for Griffin models. - constexpr size_t kQueryBatchSize = - (TConfig::kGriffinLayers > 0) ? 1 : kBatchedQueryBatchSize; - for (size_t query_idx_start = 0; query_idx_start < prompts.size(); - query_idx_start += kQueryBatchSize) { - const size_t num_queries = - std::min(prompts.size() - query_idx_start, kQueryBatchSize); - const hwy::Span> query_batch( - prompts.data() + query_idx_start, num_queries); - GenerateT(weights_u8, activations, runtime_config, - query_batch, pos, query_idx_start, - kv_caches, pool, timing_info); + const MultiplePromptsTokens& prompts, size_t pos, + const KVCaches& kv_caches, hwy::ThreadPool& pool, + TimingInfo& timing_info) { + HWY_ASSERT(prompts.size() == kv_caches.size()); + // Griffin does not support query batching. + const size_t max_qbatch_size = + (TConfig::kGriffinLayers > 0) ? 1 : runtime_config.decode_qbatch_size; + + Activations activations; + activations.Allocate(max_qbatch_size); + + const size_t num_queries = prompts.size(); + for (size_t qbatch_start = 0; qbatch_start < num_queries; + qbatch_start += max_qbatch_size) { + // Generate one batch of tokens from `qbatch_size` queries. 
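From the caller's side, the public GenerateBatch now takes spans and one KVCache per prompt; decode_qbatch_size only controls how the loop below slices them internally. A usage sketch with already-tokenized prompts (variable names illustrative; benchmark_helper.cc sets batch_stream_token instead so the callback receives the query index):

```cpp
#include <random>
#include <vector>

#include "gemma/gemma.h"
#include "gemma/kv_cache.h"
#include "util/app.h"

void GenerateAll(gcpp::Gemma& model, const gcpp::InferenceArgs& inference,
                 const std::vector<std::vector<int>>& tokenized,
                 std::mt19937& gen) {
  std::vector<gcpp::KVCache> caches;  // one per prompt, required by the API
  std::vector<gcpp::PromptTokens> spans;
  for (const auto& tokens : tokenized) {
    caches.push_back(gcpp::KVCache::Create(model.Info().model,
                                           inference.prefill_tbatch_size));
    spans.push_back(gcpp::PromptTokens(tokens.data(), tokens.size()));
  }

  gcpp::RuntimeConfig runtime_config = {
      .verbosity = 0,
      .gen = &gen,
      .stream_token = [](int /*token*/, float /*prob*/) { return true; }};
  inference.CopyTo(runtime_config);  // batch sizes, max tokens, temperature

  gcpp::TimingInfo timing_info;
  model.GenerateBatch(
      runtime_config,
      gcpp::MultiplePromptsTokens(spans.data(), spans.size()),
      /*start_pos=*/0, gcpp::KVCaches(caches.data(), caches.size()),
      timing_info);
}
```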
+ const size_t qbatch_size = + HWY_MIN(num_queries - qbatch_start, max_qbatch_size); + const MultiplePromptsTokens qbatch_prompts(&prompts[qbatch_start], + qbatch_size); + const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size); + GenerateT(weights_u8, activations, runtime_config, qbatch_prompts, + pos, qbatch_start, qbatch_kv, pool, timing_info); } } @@ -986,24 +1081,20 @@ void GenerateBatchT(const ByteStorageT& weights_u8, Activations& activations, // These are extern functions defined by instantiations/*.cc, which include this // 'header' after defining GEMMA_CONFIG, which is for function overloading. void GenerateSingle( // NOLINT(misc-definitions-in-headers) - GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& activations, - const RuntimeConfig& runtime_config, const std::vector& prompt, - size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, - TimingInfo& timing_info) { + GEMMA_CONFIG, const ByteStorageT& weights_u8, + const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, + KVCache& kv_cache, hwy::ThreadPool& pool, TimingInfo& timing_info) { HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT) - (weights_u8, activations, runtime_config, prompt, pos, kv_cache, pool, - timing_info); + (weights_u8, runtime_config, prompt, pos, kv_cache, pool, timing_info); } void GenerateBatch( // NOLINT(misc-definitions-in-headers) - GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& activations, - const RuntimeConfig& runtime_config, - const hwy::Span>& prompts, size_t pos, - const std::vector& kv_caches, hwy::ThreadPool& pool, + GEMMA_CONFIG, const ByteStorageT& weights_u8, + const RuntimeConfig& runtime_config, const MultiplePromptsTokens& prompts, + size_t pos, const KVCaches& kv_caches, hwy::ThreadPool& pool, TimingInfo& timing_info) { HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT) - (weights_u8, activations, runtime_config, prompts, pos, kv_caches, pool, - timing_info); + (weights_u8, runtime_config, prompts, pos, kv_caches, pool, timing_info); } #endif // HWY_ONCE diff --git a/gemma/gemma.cc b/gemma/gemma.cc index df6474f..d3b9b6a 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -24,32 +24,19 @@ #include #include // std::move -#include #include "compression/io.h" // Path -#include "gemma/activations.h" #include "gemma/common.h" #include "gemma/weights.h" -#include "hwy/aligned_allocator.h" // Span #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" namespace gcpp { -template -struct AllocateActivations { - void operator()(Activations& decode) const { - // TODO: this is wasted if we only have single-batch queries. Instead - // re-allocate when the query batch size is actually > 1? - decode.Allocate(kBatchedQueryBatchSize); - } -}; - Gemma::Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info, hwy::ThreadPool& pool) : pool_(pool), tokenizer_(tokenizer_path), info_(info) { weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool); - CallForModelAndWeight(info.model, info.weight, decode_); } Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, @@ -58,7 +45,6 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, HWY_ASSERT(info.weight == Type::kF32); weights_u8_ = CallForModel(info.model, pool); - CallForModelAndWeight(info.model, info.weight, decode_); } Gemma::~Gemma() { @@ -70,67 +56,64 @@ Gemma::~Gemma() { // we shard them across multiple translation units in instantiations/*.cc. // This declares the functions defined there. 
We use overloading because // explicit instantiations are still too slow to compile. -#define GEMMA_DECLARE(CONFIGT, TWEIGHT) \ - extern void GenerateSingle( \ - CONFIGT, const ByteStorageT& weights_u8, Activations& decode, \ - const RuntimeConfig& runtime_config, const std::vector& prompt, \ - size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, \ - TimingInfo& timing_info); \ - extern void GenerateBatch( \ - CONFIGT, const ByteStorageT& weights_u8, Activations& decode, \ - const RuntimeConfig& runtime_config, \ - const hwy::Span>& prompts, size_t pos, \ - const std::vector& kv_caches, hwy::ThreadPool& pool, \ - TimingInfo& timing_info); +#define GEMMA_DECLARE(CONFIGT, TWEIGHT) \ + extern void GenerateSingle(CONFIGT, const ByteStorageT& weights_u8, \ + const RuntimeConfig& runtime_config, \ + const PromptTokens& prompt, size_t pos, \ + KVCache& kv_cache, hwy::ThreadPool& pool, \ + TimingInfo& timing_info); \ + extern void GenerateBatch(CONFIGT, const ByteStorageT& weights_u8, \ + const RuntimeConfig& runtime_config, \ + const MultiplePromptsTokens& prompts, size_t pos, \ + const KVCaches& kv_caches, hwy::ThreadPool& pool, \ + TimingInfo& timing_info); GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE); // Adapters to select from the above overloads via CallForModelAndWeight. -// TODO: gather all ByteStorageT into a type-erased model struct? template struct GenerateSingleT { - void operator()(const ByteStorageT& weights_u8, Activations& decode, + void operator()(const ByteStorageT& weights_u8, const RuntimeConfig& runtime_config, - const std::vector& prompt, size_t pos, KVCache& kv_cache, + const PromptTokens& prompt, size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, TimingInfo& timing_info) const { - GenerateSingle(TConfig(), weights_u8, decode, runtime_config, prompt, pos, - kv_cache, pool, timing_info); + GenerateSingle(TConfig(), weights_u8, runtime_config, prompt, pos, kv_cache, + pool, timing_info); } }; template struct GenerateBatchT { - void operator()(const ByteStorageT& weights_u8, Activations& decode, + void operator()(const ByteStorageT& weights_u8, const RuntimeConfig& runtime_config, - const hwy::Span>& prompts, size_t pos, - const std::vector& kv_caches, hwy::ThreadPool& pool, + const MultiplePromptsTokens& prompts, size_t pos, + const KVCaches& kv_caches, hwy::ThreadPool& pool, TimingInfo& timing_info) const { - GenerateBatch(TConfig(), weights_u8, decode, runtime_config, prompts, pos, + GenerateBatch(TConfig(), weights_u8, runtime_config, prompts, pos, kv_caches, pool, timing_info); } }; void Gemma::Generate(const RuntimeConfig& runtime_config, - const std::vector& prompt, size_t start_pos, + const PromptTokens& prompt, size_t start_pos, KVCache& kv_cache, TimingInfo& timing_info) { pool_.SetWaitMode(hwy::PoolWaitMode::kSpin); - CallForModelAndWeight( - info_.model, info_.weight, weights_u8_, decode_, runtime_config, prompt, - start_pos, kv_cache, pool_, timing_info); + CallForModelAndWeight(info_.model, info_.weight, weights_u8_, + runtime_config, prompt, start_pos, + kv_cache, pool_, timing_info); pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); } void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, - const hwy::Span>& prompts, - size_t start_pos, - const std::vector& kv_caches, + const MultiplePromptsTokens& prompts, + size_t start_pos, const KVCaches& kv_caches, TimingInfo& timing_info) { pool_.SetWaitMode(hwy::PoolWaitMode::kSpin); - CallForModelAndWeight( - info_.model, info_.weight, weights_u8_, decode_, runtime_config, prompts, - start_pos, kv_caches, 
pool_, timing_info); + CallForModelAndWeight(info_.model, info_.weight, weights_u8_, + runtime_config, prompts, start_pos, + kv_caches, pool_, timing_info); pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); } diff --git a/gemma/gemma.h b/gemma/gemma.h index 3caeab5..e734456 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -30,7 +30,7 @@ #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/timer.h" // IWYU pragma: end_exports -#include "hwy/aligned_allocator.h" +#include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" // hwy::bfloat16_t namespace gcpp { @@ -67,6 +67,13 @@ struct RuntimeConfig { size_t max_tokens; size_t max_generated_tokens; + + // These defaults are overridden by InferenceArgs::CopyTo(*this): + // Max tokens per batch during prefill. + size_t prefill_tbatch_size = 32; + // Max queries per batch (one token from each) during decode. + size_t decode_qbatch_size = 16; + float temperature; int verbosity; std::mt19937* gen; @@ -105,6 +112,10 @@ struct TimingInfo { size_t tokens_generated; }; +using PromptTokens = hwy::Span; +using MultiplePromptsTokens = hwy::Span; +using KVCaches = hwy::Span; + class Gemma { public: Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info, @@ -118,25 +129,20 @@ class Gemma { const ModelInfo& Info() const { return info_; } const GemmaTokenizer& Tokenizer() const { return tokenizer_; } const ByteStorageT& Weights() const { return weights_u8_; } - const Activations& Decode() const { return decode_; } - void Generate(const RuntimeConfig& runtime_config, - const std::vector& prompt, size_t start_pos, - KVCache& kv_cache, TimingInfo& timing_info); + void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt, + size_t start_pos, KVCache& kv_cache, TimingInfo& timing_info); void GenerateBatch(const RuntimeConfig& runtime_config, - const hwy::Span>& prompts, - size_t start_pos, const std::vector& kv_caches, - TimingInfo& timing_info); + const MultiplePromptsTokens& prompts, size_t start_pos, + const KVCaches& kv_caches, TimingInfo& timing_info); private: hwy::ThreadPool& pool_; GemmaTokenizer tokenizer_; - // Type-erased so that this can be defined in the header, without requiring - // forwarding functions. + // Type-erased so that this can be defined in the header. ByteStorageT weights_u8_; - Activations decode_; ModelInfo info_; }; diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index 10e76e7..d12c9af 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -23,13 +23,16 @@ namespace gcpp { namespace { template struct CreateKVCache { - KVCache operator()() const { + KVCache operator()(size_t prefill_tbatch_size) const { KVCache kv_cache = {}; const size_t size_cache_pos = CachePosSize()(); if (size_cache_pos != 0) { - const size_t seq_len = (TConfig::kSeqLen + kPrefillBatchSize); - kv_cache.kv_cache = hwy::AllocateAligned(seq_len * size_cache_pos); + // Allocate more so that prefill can always access one batch, even if + // near the end of the sequence. + kv_cache.seq_len = TConfig::kSeqLen + prefill_tbatch_size; + kv_cache.kv_cache = + hwy::AllocateAligned(kv_cache.seq_len * size_cache_pos); } // TODO(patrickms): Add query batching support for Griffin. @@ -58,10 +61,13 @@ struct CreateKVCache { }; } // namespace -KVCache KVCache::Create(Model model_type) { +// prefill_tbatch_size is the maximum number of tokens from one query to +// prefill at a time. 
+KVCache KVCache::Create(Model model_type, size_t prefill_tbatch_size) { // TWeight=float is a placeholder and unused because CreateKVCache does not // use TConfig::Weight. - return CallForModel(model_type); + return CallForModel(model_type, + prefill_tbatch_size); } } // namespace gcpp diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index 1c92b40..65b40c1 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -16,13 +16,17 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_ +#include + #include "gemma/common.h" // Model #include "hwy/aligned_allocator.h" namespace gcpp { struct KVCache { - // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2 + size_t seq_len = 0; // = kSeqLen + prefill_tbatch_size + + // seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2 hwy::AlignedFreeUniquePtr kv_cache; // (kConv1dWidth - 1) * kModelDim * kGriffinLayers @@ -31,7 +35,7 @@ struct KVCache { // kModelDim * kGriffinLayers hwy::AlignedFreeUniquePtr rglru_cache; - static KVCache Create(Model type); + static KVCache Create(Model type, size_t prefill_tbatch_size); }; } // namespace gcpp diff --git a/gemma/run.cc b/gemma/run.cc index 2d33827..ce4902f 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -145,14 +145,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool, TimingInfo timing_info; RuntimeConfig runtime_config = { - .max_tokens = args.max_tokens, - .max_generated_tokens = args.max_generated_tokens, - .temperature = args.temperature, .verbosity = verbosity, .gen = &gen, .stream_token = stream_token, .accept_token = accept_token, }; + args.CopyTo(runtime_config); model.Generate(runtime_config, prompt, abs_pos, kv_cache, timing_info); if (verbosity >= 2) { std::cout << current_pos << " tokens (" << abs_pos << " total tokens)" @@ -181,7 +179,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { } Gemma model = CreateGemma(loader, pool); - KVCache kv_cache = KVCache::Create(model.Info().model); + KVCache kv_cache = + KVCache::Create(model.Info().model, inference.prefill_tbatch_size); if (app.verbosity >= 1) { std::string instructions = diff --git a/util/app.h b/util/app.h index bbd2047..160d120 100644 --- a/util/app.h +++ b/util/app.h @@ -248,6 +248,9 @@ struct InferenceArgs : public ArgsBase { size_t max_tokens; size_t max_generated_tokens; + size_t prefill_tbatch_size; + size_t decode_qbatch_size; + float temperature; bool deterministic; bool multiturn; @@ -272,6 +275,11 @@ struct InferenceArgs : public ArgsBase { visitor(max_generated_tokens, "max_generated_tokens", size_t{2048}, "Maximum number of tokens to generate."); + visitor(prefill_tbatch_size, "prefill_tbatch", size_t{64}, + "Prefill: max tokens per batch."); + visitor(decode_qbatch_size, "decode_qbatch", size_t{16}, + "Decode: max queries per batch."); + visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2); visitor(deterministic, "deterministic", false, "Make top-k sampling deterministic", 2); @@ -281,6 +289,16 @@ struct InferenceArgs : public ArgsBase { " Default : 0 (conversation " "resets every turn)"); } + + void CopyTo(RuntimeConfig& runtime_config) const { + runtime_config.max_tokens = max_tokens; + runtime_config.max_generated_tokens = max_generated_tokens; + + runtime_config.prefill_tbatch_size = prefill_tbatch_size; + runtime_config.decode_qbatch_size = decode_qbatch_size; + + runtime_config.temperature = temperature; + } }; } // namespace gcpp
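Finally, the memory impact of the new `seq_len` field follows directly from the layout comment in kv_cache.h: the cache is over-allocated by one prefill batch so that a batch starting near the end of the sequence can still be written. A worked sketch (dimensions are illustrative, not a specific model config):

```cpp
#include <cstddef>

// Number of floats in KVCache::kv_cache, per the layout comment in kv_cache.h.
inline size_t KVCacheFloats(size_t kSeqLen, size_t prefill_tbatch_size,
                            size_t kGemmaLayers, size_t kKVHeads,
                            size_t kQKVDim) {
  const size_t seq_len = kSeqLen + prefill_tbatch_size;  // ring-buffer slack
  return seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2;  // K and V
}

// e.g. an 8192-token context with --prefill_tbatch 64, 18 layers, 1 KV head
// and 256-dim heads: (8192 + 64) * 18 * 1 * 256 * 2 ≈ 76M floats ≈ 290 MiB
// per query, so per-query caches dominate memory when batching many queries.
```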