From da7507e6f0f6038deb33455c74bdc8570c05f5df Mon Sep 17 00:00:00 2001 From: "The gemma.cpp Authors" Date: Mon, 1 Jul 2024 07:50:53 -0700 Subject: [PATCH] Add prompt batching to Gemma.cpp. This CL adds a new function to Gemma that allows for batching of multiple prompts. The function takes a vector of prompts and returns a vector of responses. The prompts are processed in parallel, and the responses are returned in the same order as the prompts. PiperOrigin-RevId: 648367559 --- BUILD.bazel | 13 + gemma/benchmark_helper.cc | 79 +++++- gemma/benchmark_helper.h | 8 +- gemma/gemma.cc | 517 +++++++++++++++++++++++++++----------- gemma/gemma.h | 14 ++ gemma/gemma_test.cc | 70 +++++- 6 files changed, 542 insertions(+), 159 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index a4f7636..c744213 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -103,6 +103,10 @@ cc_library( "gemma/activations.h", "gemma/gemma.h", ], + exec_properties = { + # Avoid linker OOMs when building with sanitizer instrumentation. + "mem": "28g", + }, textual_hdrs = [ # Placeholder for internal file1, do not remove, # Placeholder for internal file2, do not remove, @@ -194,6 +198,7 @@ cc_test( ":ops", "@googletest//:gtest_main", "//compression:io", + "@hwy//:hwy", "@hwy//:hwy_test_util", "@hwy//:thread_pool", ], @@ -370,6 +375,10 @@ cc_test( "backprop/backward_test.cc", "backprop/test_util.h", ], + exec_properties = { + # Avoid linker OOMs when building with sanitizer instrumentation. + "mem": "28g", + }, deps = [ ":backprop", ":backprop_scalar", @@ -406,6 +415,10 @@ cc_test( srcs = [ "backprop/optimize_test.cc", ], + exec_properties = { + # Avoid linker OOMs when building with sanitizer instrumentation. + "mem": "28g", + }, deps = [ ":backprop", ":common", diff --git a/gemma/benchmark_helper.cc b/gemma/benchmark_helper.cc index c80b236..68dbbee 100644 --- a/gemma/benchmark_helper.cc +++ b/gemma/benchmark_helper.cc @@ -18,6 +18,8 @@ #include #include +#include +#include #include #include #include @@ -34,6 +36,7 @@ #include "gemma/gemma.h" #include "util/app.h" #include "util/args.h" +#include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" @@ -72,7 +75,11 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference, } else { fprintf(stderr, "Loading model...\n"); model_ = AllocateGemma(loader_, pool_); - kv_cache_ = KVCache::Create(loader_.ModelType()); + + kv_caches_.reserve(16); + for (int i = 0; i < 16; ++i) { + kv_caches_.push_back(new KVCache(KVCache::Create(loader_.ModelType()))); + } } InitGenerator(inference_args_, gen_); @@ -107,8 +114,9 @@ std::pair GemmaEnv::QueryModel( size_t total_tokens = 0; const double time_start = hwy::platform::Now(); - const StreamFunc stream_token = [&res, &total_tokens, &time_start, this]( - int token, float) { + const BatchStreamFunc batch_stream_token = + [&res, &total_tokens, &time_start, this]( + size_t query_index, size_t pos, int token, float) { ++total_tokens; res += StringFromTokens(std::vector{token}); if (app_.verbosity >= 1 && total_tokens % 128 == 0) { @@ -123,8 +131,8 @@ std::pair GemmaEnv::QueryModel( << "\ttemperature: " << inference_args_.temperature << "\n"; } gcpp::TimingInfo timing_info; - runtime_config_.stream_token = stream_token; - model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_cache_, + runtime_config_.batch_stream_token = batch_stream_token; + model_->Generate(runtime_config_, tokens, /*start_pos=*/0, *kv_caches_[0], timing_info); if (app_.verbosity >= 1) { 
LogSpeedStats(time_start, total_tokens); @@ -132,12 +140,73 @@ std::pair GemmaEnv::QueryModel( return {res, total_tokens}; } +std::vector> GemmaEnv::BatchQueryModel2( + const hwy::Span>& prompts) { + std::vector> res(prompts.size()); + std::fill(res.begin(), res.end(), std::make_pair("", 0)); + size_t total_tokens = 0; + + const double time_start = hwy::platform::Now(); + const BatchStreamFunc batch_stream_token = + [&res, &total_tokens, &time_start, this]( + size_t query_index, size_t pos, int token, float) { + std::string token_text; + HWY_ASSERT( + model_->Tokenizer().Decode(std::vector{token}, &token_text)); + // fprintf(stderr, "Query %zu returned token \"%s\"\n\n", query_index, + // token_text.c_str()); + std::string single_res = res[query_index].first + token_text; + size_t current_token_count = res[query_index].second + 1; + res[query_index] = std::make_pair(single_res, current_token_count); + + ++total_tokens; + if (app_.verbosity >= 1 && total_tokens % 128 == 0) { + LogSpeedStats(time_start, total_tokens); + } + return true; + }; + if (app_.verbosity >= 2) { + std::cout << inference_args_.max_tokens << " " + << inference_args_.max_generated_tokens << " " + << inference_args_.temperature; + } + gcpp::TimingInfo timing_info; + runtime_config_.batch_stream_token = batch_stream_token; + model_->GenerateBatch(runtime_config_, prompts, /*start_pos=*/0, kv_caches_, + timing_info); + if (app_.verbosity >= 1) { + LogSpeedStats(time_start, total_tokens); + } + return res; +} + std::pair GemmaEnv::QueryModel(std::string& input) { const std::vector prompt = WrapAndTokenize(model_->Tokenizer(), loader_.ModelTrainingType(), /*pos=*/0, input); return QueryModel(prompt); } +std::vector> GemmaEnv::BatchQueryModel( + const std::vector& inputs) { + std::vector>> prompts; + prompts.reserve(inputs.size()); + for (auto& input : inputs) { + std::string mutable_prompt = input; + prompts.push_back(std::make_unique>( + WrapAndTokenize(model_->Tokenizer(), + loader_.ModelTrainingType(), + /*pos=*/0, mutable_prompt))); + } + std::vector> prompt_vector; + prompt_vector.reserve(prompts.size()); + for (auto& prompt : prompts) { + prompt_vector.push_back(hwy::Span( + prompt->data(), prompt->size())); + } + hwy::Span> prompt_span = hwy::Span>( + prompt_vector.data(), prompt_vector.size()); + return BatchQueryModel2(prompt_span); +} float GemmaEnv::CrossEntropy(const std::string& input) { std::vector prompt = Tokenize(input); diff --git a/gemma/benchmark_helper.h b/gemma/benchmark_helper.h index 3c5b63a..3909606 100644 --- a/gemma/benchmark_helper.h +++ b/gemma/benchmark_helper.h @@ -69,8 +69,12 @@ class GemmaEnv { // Runs inference on the given input and returns the top-1 result string and // the number of tokens that were generated. std::pair QueryModel(const std::vector& tokens); + std::vector> BatchQueryModel2( + const hwy::Span>& prompts); // Adds turn structure to input, tokenizes and calls the above overload. std::pair QueryModel(std::string& input); + std::vector> BatchQueryModel( + const std::vector& inputs); // Runs inference on the given input and returns the cross entropy, a measure // of how well the model predicts the correct output. 
It is the average @@ -87,7 +91,7 @@ class GemmaEnv { RuntimeConfig& MutableConfig() { return runtime_config_; } InferenceArgs& MutableInferenceArgs() { return inference_args_; } std::mt19937& MutableGen() { return gen_; } - KVCache& MutableKVCache() { return kv_cache_; } + KVCache& MutableKVCache() { return *kv_caches_[0]; } private: // Arguments to the model loader: file locations, etc. @@ -103,7 +107,7 @@ class GemmaEnv { // The model to run inference on. std::unique_ptr model_; // The KV cache to use for inference. - KVCache kv_cache_; + std::vector kv_caches_; RuntimeConfig runtime_config_; }; diff --git a/gemma/gemma.cc b/gemma/gemma.cc index c035947..09a5aab 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -75,23 +75,30 @@ struct Activations { static constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1); std::array x; // input - std::array pre_att_rms_out; - std::array q; // query vector + std::array + q; // query vector std::array - att; // attention vector - std::array att_out; // attention output + att; // attention vector + std::array + att_out; // attention output std::array att_post1; // attention output after linear transformation, per head std::array att_post2; // accumulation of attention outputs over heads + std::array + bf_pre_ffw_rms_out; + std::array + ffw_hidden; - std::array bf_pre_ffw_rms_out; - std::array ffw_hidden; - std::array C1; // MatMul output + // For FFW MatMul. + std::array C1; std::array C2; - std::array ffw_out; + // bf_ version can't be used until GeluMulToBF16 issue in FFW() is resolved. + // std::array + // bf_ffw_hidden; + std::array ffw_out; std::array logits; // For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into @@ -104,7 +111,8 @@ struct Activations { std::array griffin_x; std::array griffin_y; std::array griffin_gate_x; - std::array griffin_multiplier; + std::array + griffin_multiplier; }; namespace { @@ -116,10 +124,12 @@ struct CreateKVCache { const size_t size_cache_pos = CachePosSize()(); if (size_cache_pos != 0) { - const size_t seq_len = TConfig::kSeqLen + kPrefillBatchSize; + const size_t seq_len = + (TConfig::kSeqLen + kPrefillBatchSize); kv_cache.kv_cache = hwy::AllocateAligned(seq_len * size_cache_pos); } + // TODO(patrickms): Add query batching support for Griffin. if (TConfig::kGriffinLayers) { constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; const size_t conv1d_cache_size = @@ -226,19 +236,24 @@ namespace gcpp { namespace HWY_NAMESPACE { namespace { -template +template HWY_NOINLINE void GriffinRecurrent( - size_t batch_start, size_t num_tokens, size_t layer, - Activations& activations, - const CompressedLayer* layer_weights, KVCache& kv_cache, - hwy::ThreadPool& pool) { + size_t batch_start, size_t num_tokens, size_t num_queries, size_t layer, + Activations& activations, + const CompressedLayer* layer_weights, + const std::vector& kv_caches, hwy::ThreadPool& pool) { PROFILER_ZONE("Gen.Griffin"); + static_assert(kQueryBatchSize == 1, + "Griffin does not support batched queries."); + HWY_DASSERT(num_queries == 1); // TODO: add batch query support for Griffin. 
+ KVCache& kv_cache = *kv_caches[0]; namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; HWY_DASSERT(num_tokens <= kBatchSize); - constexpr size_t kModelDim = Activations::kModelDim; - constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; - constexpr size_t kHeads = TConfig::kHeads; + static constexpr size_t kModelDim = + gcpp::Activations::kModelDim; + static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; + static constexpr size_t kHeads = TConfig::kHeads; // X / Y linear layers. for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { @@ -357,14 +372,19 @@ HWY_NOINLINE void GriffinRecurrent( } } -template -HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer, - Activations& activations, - const CompressedLayer* layer_weights, - KVCache& kv_cache, hwy::ThreadPool& pool) { +template +HWY_NOINLINE void Attention( + size_t batch_and_query_start, size_t num_tokens, size_t num_queries, + size_t layer, + Activations& activations, + const CompressedLayer* layer_weights, + const std::vector& kv_caches, + hwy::ThreadPool& pool) { PROFILER_ZONE("Gen.Attention"); HWY_DASSERT(num_tokens <= kBatchSize); - using TActivations = Activations; + HWY_DASSERT(num_queries <= kQueryBatchSize); + HWY_DASSERT(batch_and_query_start % num_queries == 0); + using TActivations = Activations; constexpr size_t kQKVDim = TActivations::kQKVDim; constexpr size_t kQStride = TActivations::kQStride; constexpr size_t kCachePosSize = CachePosSize()(); @@ -376,15 +396,22 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer, GEMMA_CONSTEXPR_SQRT const float kQueryScale = 1.0f / Sqrt(static_cast(kQKVDim)); constexpr bool kIsMHA = TActivations::kIsMHA; // Multi-Head Attention + const size_t batch_start = batch_and_query_start / num_queries; + const size_t num_tokens_and_queries = num_tokens * num_queries; // If MHA, this also computes KV, which we copy to the KV cache below. 
static_assert(!kIsMHA || TConfig::kInterleaveQKV); // MHA => interleaved MatMul_4x4_Batch( - num_tokens, activations.pre_att_rms_out.data(), + num_tokens_and_queries, activations.pre_att_rms_out.data(), layer_weights->qkv_einsum_w.data(), activations.q.data(), pool); - for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { - const float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim; + for (size_t batch_and_query_idx = 0; + batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) { + const float* x = activations.pre_att_rms_out.data() + batch_and_query_idx + * kModelDim; + const size_t query_idx = batch_and_query_idx % num_queries; + const size_t batch_idx = batch_and_query_idx / num_queries; + KVCache& kv_cache = *kv_caches[query_idx]; // QKV projections: if constexpr (!kIsMHA) { const size_t pos = batch_start + batch_idx; @@ -401,18 +428,23 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer, // Positional encodings for kv: pool.Run( - 0, kKVHeads * num_tokens, [&](uint64_t task, size_t thread) HWY_ATTR { + 0, kKVHeads * num_tokens_and_queries, + [&](uint64_t task, size_t thread) HWY_ATTR { const size_t head = task % kKVHeads; - const size_t batch_idx = task / kKVHeads; + const size_t batch_and_query_idx = task / kKVHeads; + const size_t query_idx = batch_and_query_idx % num_queries; + const size_t batch_idx = batch_and_query_idx / num_queries; const size_t pos = batch_start + batch_idx; const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize); const size_t kv_offset = cache_pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim * 2; + KVCache& kv_cache = *kv_caches[query_idx]; float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; if constexpr (kIsMHA) { // For MHA, copy kv into the KV cache from scratch space (see above). const float* HWY_RESTRICT q = - activations.q.data() + (batch_idx * kHeads + head) * kQStride; + activations.q.data() + (batch_and_query_idx * kHeads + + head) * kQStride; // Skip past the Q part of `q`, and copy KV to `kv`. memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float)); } @@ -422,17 +454,22 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer, static_assert((kHeads % kKVHeads) == 0, "query heads must be a multiple of key-value heads"); constexpr size_t kGroupHeads = kHeads / kKVHeads; - pool.Run(0, kHeads * num_tokens, [&](uint64_t task, size_t thread) HWY_ATTR { + pool.Run(0, kHeads * num_tokens_and_queries, + [&](uint64_t task, size_t thread) HWY_ATTR { const size_t head = task % kHeads; - const size_t batch_idx = task / kHeads; + const size_t batch_and_query_idx = task / kHeads; + const size_t query_idx = batch_and_query_idx % num_queries; + const size_t batch_idx = batch_and_query_idx / num_queries; const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2; + KVCache& kv_cache = *kv_caches[query_idx]; float* HWY_RESTRICT q = - activations.q.data() + (batch_idx * kHeads + head) * kQStride; + activations.q.data() + (batch_and_query_idx * kHeads + head) * kQStride; const size_t pos = batch_start + batch_idx; // Calculate scores float* HWY_RESTRICT head_att = - activations.att.data() + head * kSeqLen + batch_idx * kHeads * kSeqLen; + activations.att.data() + head * kSeqLen + + batch_and_query_idx * kHeads * kSeqLen; Rope(q, TConfig::kUseHalfRope ? 
kQKVDim / 2 : kQKVDim, pos); MulByConst(kQueryScale, q, kQKVDim); @@ -451,7 +488,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer, // Weighted summation float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim + - batch_idx * kHeads * kQKVDim; + batch_and_query_idx * kHeads * kQKVDim; hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize); @@ -462,20 +499,23 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer, } }); - for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { + for (size_t batch_and_query_idx = 0; + batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) { // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after // rearranging the weights. float* HWY_RESTRICT att_out = - activations.att_out.data() + batch_idx * kHeads * kQKVDim; + activations.att_out.data() + batch_and_query_idx * kHeads * kQKVDim; float* HWY_RESTRICT layer_out = - activations.att_post2.data() + batch_idx * kModelDim; + activations.att_post2.data() + batch_and_query_idx * kModelDim; MatVecT( layer_weights->attn_vec_einsum_w, 0, att_out, layer_weights->attention_output_biases.data(), activations.even_odd.data(), layer_out, pool); for (size_t head = 1; head < kHeads; ++head) { + // TODO(patrickms): Check this calculation float* HWY_RESTRICT head_out = - activations.att_post1.data() + head * kBatchSize * kModelDim; + activations.att_post1.data() + + head * kBatchSize * kQueryBatchSize * kModelDim; // TODO: requires MatMul support for offsets. MatVec( layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim, @@ -583,103 +623,138 @@ HWY_NOINLINE void EmbedToken(int token, size_t token_idx, size_t pos, }; } -template +template HWY_NOINLINE void TransformerLayer( - size_t num_tokens, size_t pos, size_t layer, + size_t num_tokens, size_t num_queries, size_t pos, size_t layer, const CompressedLayer* layer_weights, - Activations& activations, KVCache& kv_cache, - hwy::ThreadPool& pool) { + Activations& activations, + const std::vector& kv_caches, hwy::ThreadPool& pool) { constexpr size_t kModelDim = TConfig::kModelDim; + const size_t num_tokens_and_queries = num_tokens * num_queries; auto type = TConfig::kLayerConfig[layer]; size_t layer_of_type = NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer); - RMSNormBatched(num_tokens, activations.x.data(), - layer_weights->pre_attention_norm_scale.data(), - activations.pre_att_rms_out.data(), kModelDim); + RMSNormBatched( + num_tokens_and_queries, activations.x.data(), + layer_weights->pre_attention_norm_scale.data(), + activations.pre_att_rms_out.data(), kModelDim); if (type == LayerAttentionType::kGemma) { - Attention(pos, num_tokens, layer_of_type, activations, layer_weights, - kv_cache, pool); + Attention( + pos, num_tokens, num_queries, layer_of_type, activations, + layer_weights, kv_caches, pool); } else { - GriffinRecurrent(pos, num_tokens, layer_of_type, activations, layer_weights, - kv_cache, pool); + // This Griffin layers should never exist unless the model is a Griffin + // model. This conditional prevents the compiler from generating code for + // this branch when building a non-Griffin model, since we have static + // asserts about the query batch size for Griffin layers. 
+ if constexpr (TConfig::kGriffinLayers > 0) { + GriffinRecurrent( + pos, num_tokens, num_queries, layer_of_type, activations, + layer_weights, kv_caches, pool); + } } if (TConfig::kPostNormScale) { - RMSNormInplaceBatched( - num_tokens, layer_weights->post_attention_norm_scale.data(), + RMSNormInplaceBatched( + num_tokens_and_queries, + layer_weights->post_attention_norm_scale.data(), activations.att_post2.data(), kModelDim); } - AddFromBatched(num_tokens, activations.att_post2.data(), - activations.x.data(), kModelDim); - RMSNormBatched(num_tokens, activations.x.data(), - layer_weights->pre_ffw_norm_scale.data(), - activations.bf_pre_ffw_rms_out.data(), kModelDim); - FFW(activations, num_tokens, layer_weights, pool); + AddFromBatched(num_tokens_and_queries, + activations.att_post2.data(), + activations.x.data(), kModelDim); + RMSNormBatched( + num_tokens_and_queries, activations.x.data(), + layer_weights->pre_ffw_norm_scale.data(), + activations.bf_pre_ffw_rms_out.data(), kModelDim); + FFW( + activations, num_tokens_and_queries, layer_weights, pool); if (TConfig::kPostNormScale) { - RMSNormInplaceBatched(num_tokens, - layer_weights->post_ffw_norm_scale.data(), - activations.ffw_out.data(), kModelDim); + RMSNormInplaceBatched( + num_tokens_and_queries, layer_weights->post_ffw_norm_scale.data(), + activations.ffw_out.data(), kModelDim); } - AddFromBatched(num_tokens, activations.ffw_out.data(), - activations.x.data(), kModelDim); + AddFromBatched( + num_tokens_and_queries, activations.ffw_out.data(), + activations.x.data(), kModelDim); } -template -HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos, - const CompressedWeights& weights, - Activations& activations, - KVCache& kv_cache, hwy::ThreadPool& pool) { +template +HWY_NOINLINE void Prefill( + const int* tokens, size_t num_tokens, size_t num_queries, size_t pos, + const CompressedWeights& weights, + Activations& activations, + const std::vector& kv_caches, hwy::ThreadPool& pool) { + HWY_DASSERT(num_queries <= kQueryBatchSize); + const size_t minibatch_size = std::min(num_tokens, kBatchSize); PROFILER_ZONE("Gen.Prefill\\Att\\FFW"); + // TODO(patrickms): Try to hoist pool.Run out of the loop. + for (size_t i = 0; i < num_tokens; i += minibatch_size) { + const size_t offset = i * num_queries; + const size_t current_token_count = std::min( + minibatch_size, num_tokens - i); + pool.Run(0, current_token_count * num_queries, + [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR { + EmbedToken( + tokens[token_idx + offset], token_idx, pos + offset, + weights, activations); + }); - pool.Run( - 0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR { - EmbedToken(tokens[token_idx], token_idx, pos, weights, activations); - }); - - for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { - const auto* layer_weights = weights.GetLayer(layer); - TransformerLayer(num_tokens, pos, layer, layer_weights, activations, - kv_cache, pool); + for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { + const auto* layer_weights = weights.GetLayer(layer); + TransformerLayer( + current_token_count, num_queries, pos + offset , layer, layer_weights, + activations, kv_caches, pool); + } } } // Compute the transformer for a batch of input tokens. During generation, // we usually have num_tokens == 1 (and also kBatchSize == 1). 
-template -HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, size_t pos, - const CompressedWeights& weights, - Activations& activations, - KVCache& kv_cache, hwy::ThreadPool& pool, - const LayersOutputFunc& layers_output) { +template +HWY_NOINLINE void Transformer( + const int* tokens, size_t num_tokens, size_t num_queries, size_t pos, + const CompressedWeights& weights, + Activations& activations, + const std::vector& kv_caches, + hwy::ThreadPool& pool, + const LayersOutputFunc& layers_output) { HWY_ASSERT(num_tokens <= kBatchSize); + const size_t num_tokens_and_queries = num_tokens * num_queries; if (layers_output) { - for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (size_t token_idx = 0; token_idx < num_tokens_and_queries; + ++token_idx) { float token_f = tokens[token_idx]; layers_output(pos + token_idx, "Tokens", &token_f, 1); } } constexpr size_t kModelDim = TConfig::kModelDim; - for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { - EmbedToken(tokens[token_idx], token_idx, pos, weights, activations); + for (size_t token_idx = 0; token_idx < num_tokens_and_queries; ++token_idx) { + EmbedToken( + tokens[token_idx], token_idx, pos, weights, activations); } for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { const CompressedLayer* layer_weights = weights.GetLayer(layer); - TransformerLayer(num_tokens, pos, layer, layer_weights, activations, - kv_cache, pool); + TransformerLayer( + num_tokens, num_queries, pos, layer, layer_weights, + activations, kv_caches, pool); if (layers_output) { const std::string block_name = "blocks." + std::to_string(layer); - for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (size_t token_idx = 0; token_idx < num_tokens_and_queries; + ++token_idx) { layers_output(pos + token_idx, block_name, activations.x.data() + token_idx * kModelDim, kModelDim); } } } - RMSNormInplaceBatched(num_tokens, weights.final_norm_scale.data(), - activations.x.data(), kModelDim); + RMSNormInplaceBatched( + num_tokens * num_queries, weights.final_norm_scale.data(), + activations.x.data(), kModelDim); if (layers_output) { - for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (size_t token_idx = 0; token_idx < num_tokens_and_queries; + ++token_idx) { layers_output(pos + token_idx, "final_norm", activations.x.data() + token_idx * kModelDim, kModelDim); } @@ -719,31 +794,69 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens, } template -Activations& GetActivations(const ByteStorageT& state_u8) { - return *reinterpret_cast*>(state_u8.get()); +Activations& GetActivations( + const ByteStorageT& state_u8) { + return *reinterpret_cast*>( + state_u8.get()); } } // namespace // Placeholder for internal test3, do not remove -template +bool StreamToken(size_t query_idx, size_t pos, int token, float weight, + const RuntimeConfig& runtime_config) { + if (runtime_config.batch_stream_token) { + return runtime_config.batch_stream_token(query_idx, pos, token, weight); + } + return runtime_config.stream_token(token, weight); +} + +template void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, const RuntimeConfig& runtime_config, - const std::vector& prompt, size_t pos, KVCache& kv_cache, - hwy::ThreadPool& pool, TimingInfo& timing_info) { + const hwy::Span>& prompts, size_t pos, + const size_t query_index_offset, + const std::vector& kv_caches, hwy::ThreadPool& pool, + TimingInfo& timing_info) { + constexpr size_t 
kAdjustedPrefillBatchSize = + std::max((size_t)1, kPrefillBatchSize / kQueryBatchSize); + static_assert(kAdjustedPrefillBatchSize >= kMinAdjustedPrefillBatchSize); + const size_t num_queries = prompts.size(); + HWY_DASSERT(num_queries <= kQueryBatchSize); + pos *= num_queries; // position in (num_queries) interleaved token sequence. const CompressedWeights& weights = *reinterpret_cast*>(weights_u8.get()); auto& prefill_activations = - GetActivations(prefill_u8); - auto& activations = GetActivations(decode_u8); + GetActivations(prefill_u8); + auto& activations = GetActivations(decode_u8); + + size_t min_prompt_size = (size_t)-1; + size_t max_prompt_size = 0; + for (int i=0; i < prompts.size(); ++i) { + min_prompt_size = std::min(min_prompt_size, prompts[i].size()); + max_prompt_size = std::max(max_prompt_size, prompts[i].size()); + } + + std::vector prompt; + prompt.reserve(max_prompt_size * prompts.size()); + for (int i = 0; i < max_prompt_size; ++i) { + for (int j=0; j < prompts.size(); ++j) { + if (i < prompts[j].size()) { + prompt.push_back(prompts[j][i]); + } else { + prompt.push_back(0); + } + } + } constexpr size_t kVocabSize = TConfig::kVocabSize; - size_t prompt_size = prompt.size(); + size_t max_tokens = runtime_config.max_tokens; size_t max_generated_tokens = runtime_config.max_generated_tokens; - RangeChecks(max_tokens, max_generated_tokens, prompt_size); + RangeChecks(max_tokens, max_generated_tokens, max_prompt_size); if (pos >= max_tokens) { fprintf(stderr, "Warning: pos %zu >= max_tokens %zu, aborting.\n", pos, max_tokens); @@ -760,6 +873,9 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, runtime_config.accept_token); }; + std::vector reached_eos(num_queries); + std::fill(reached_eos.begin(), reached_eos.end(), false); + // pos indexes the KV cache. In the first turn of a chat, pos = 0. // // After the first turn, pos gets passed in with > 0 corresponding to the @@ -772,23 +888,44 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, // In single-turn (non-chat) usage, pos and pos_offset start at 0 and are // always equal. size_t pos_offset = 0; // offset relative to pos + // Used to keep track of how many tokens are processed per prompt, + // so that we know when to start generating tokens. + size_t single_prompt_pos_offset = 0; const double prefill_start = hwy::platform::Now(); // Prefill stops before prompt_size - 1 since the last prompt token is the // first input token for generation. 
- while (pos_offset < prompt_size - 1) { - const size_t batch_size = - std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset); + while (single_prompt_pos_offset < min_prompt_size - 1) { + const size_t batch_size = std::min( + kPrefillBatchSize, min_prompt_size - 1 - single_prompt_pos_offset); + const size_t batch_and_query_size = batch_size * num_queries; HWY_DASSERT(batch_size <= kPrefillBatchSize); - HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1); + HWY_DASSERT(single_prompt_pos_offset + batch_size <= min_prompt_size - 1); + HWY_DASSERT(pos_offset + batch_size <= (min_prompt_size - 1) * num_queries); const int* batch_tokens = prompt.data() + pos_offset; - Prefill(batch_tokens, batch_size, pos, weights, prefill_activations, - kv_cache, pool); + Prefill( + batch_tokens, batch_size, num_queries, pos, weights, + prefill_activations, kv_caches, pool); for (size_t idx = 0; idx < batch_size; ++idx) { - if (!runtime_config.stream_token(batch_tokens[idx], 0.0f)) return; + bool all_tokens_eos = true; + for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { + if (reached_eos[query_idx]) continue; + if (StreamToken( + query_idx + query_index_offset, single_prompt_pos_offset, + batch_tokens[idx * num_queries + query_idx], 0.0f, + runtime_config)) { + all_tokens_eos = false; + } else { + reached_eos[query_idx] = true; + } + } + if (all_tokens_eos) { + return; + } } - pos += batch_size; - pos_offset += batch_size; + pos += batch_and_query_size; + pos_offset += batch_and_query_size; + single_prompt_pos_offset += batch_size; } timing_info.prefill_tok_sec = @@ -796,45 +933,81 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, // Start generation. const double gen_start = hwy::platform::Now(); - HWY_DASSERT(pos_offset == prompt_size - 1); + HWY_DASSERT(single_prompt_pos_offset == min_prompt_size - 1); size_t pos_gen_start = pos_offset; int token = prompt.at(pos_offset); - // The loop below is not yet prepared for batch size > 1. + std::vector::const_iterator first = prompt.begin() + pos_offset; + std::vector::const_iterator last = first + num_queries; + std::vector gen_tokens(first, last); + // The loop below is not yet prepared for decode batch size > 1. HWY_ASSERT(kDecodeBatchSize == 1); - if (!runtime_config.stream_token(token, 0.0f)) return; + bool all_tokens_eos = true; + for (size_t i=0; i < num_queries; ++i) { + if (reached_eos[i]) continue; + if (StreamToken(i + query_index_offset, + single_prompt_pos_offset, gen_tokens[i], 0.0f, + runtime_config)) { + all_tokens_eos = false; + } else { + reached_eos[i] = true; + } + } + if (all_tokens_eos) { + return; + } for (size_t generate_pos = 0; - pos < max_tokens && generate_pos < max_generated_tokens; - ++pos, ++pos_offset, ++generate_pos) { - Transformer(&token, kDecodeBatchSize, pos, weights, activations, kv_cache, - pool, runtime_config.layers_output); + generate_pos < max_tokens && generate_pos < max_generated_tokens; + ++single_prompt_pos_offset, ++generate_pos) { + Transformer( + gen_tokens.data(), kDecodeBatchSize, num_queries, pos, weights, + activations, kv_caches, pool, runtime_config.layers_output); float token_logit = 0.0f; // The condition below is always true if we are doing Prefill above. // We keep it here for clarity so that the code is correct even if Prefill // is disabled. - const bool is_generating_phase = pos_offset >= prompt_size - 1; - if (is_generating_phase) { - PROFILER_ZONE("Gen.Embedding"); - // Compute logits from last layer activations. 
- MatVec( - weights.embedder_input_embedding, 0, activations.x.data(), - activations.even_odd.data(), activations.logits.data(), pool); - LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize); - // Barrier: must have all logits so we can subtract max. - Softmax(activations.logits.data(), kVocabSize); - token = sample_token(activations.logits.data(), kVocabSize); - token_logit = activations.logits[token]; - if (generate_pos == 0) { - timing_info.time_to_first_token = hwy::platform::Now() - gen_start; + bool all_tokens_eos = true; + float* x = activations.x.data(); + float* logits = activations.logits.data(); + for (size_t i = 0; i < num_queries; ++i, ++pos, ++pos_offset, + x += TConfig::kModelDim, logits += kVocabSize) { + const size_t prompt_size = prompts[i].size(); + const bool is_generating_phase = + (single_prompt_pos_offset >= prompt_size - 1); + if (is_generating_phase) { + PROFILER_ZONE("Gen.Embedding"); + // Compute logits from last layer activations. + MatVec( + weights.embedder_input_embedding, 0, x, activations.even_odd.data(), + logits, pool); + LogitsSoftCap(30.0f, logits, kVocabSize); + // Barrier: must have all logits so we can subtract max. + Softmax(logits, kVocabSize); + token = sample_token(logits, kVocabSize); + token_logit = logits[token]; + if (generate_pos == 0) { + timing_info.time_to_first_token = hwy::platform::Now() - gen_start; + } + } else { + // We would take this branch if we were not doing Prefill but would + // process the tokens of the prompt one at a time. + token = prompt.at(pos_offset); + token_logit = 0.0f; } - } else { - // We would take this branch if we were not doing Prefill but would - // process the tokens of the prompt one at a time. - token = prompt.at(pos_offset + 1); + + if (!reached_eos[i]) { + if (!StreamToken(i + query_index_offset, single_prompt_pos_offset+1, + token, token_logit, runtime_config)) { + token = runtime_config.eos_id; + } + if (token != runtime_config.eos_id) { + all_tokens_eos = false; + } else { + reached_eos[i] = true; + } + } + gen_tokens[i] = token; } - if (!runtime_config.stream_token(token, token_logit)) { - token = runtime_config.eos_id; - } - if (token == runtime_config.eos_id) { + if (all_tokens_eos) { break; } } @@ -842,6 +1015,46 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, (hwy::platform::Now() - gen_start); } +template +void GenerateOneQueryT(const ByteStorageT& weights_u8, + const ByteStorageT& prefill_u8, + const ByteStorageT& decode_u8, + const RuntimeConfig& runtime_config, + const std::vector& prompt, size_t pos, + KVCache& kv_cache, hwy::ThreadPool& pool, + TimingInfo& timing_info) { + std::vector> prompt_vector = { + hwy::Span(const_cast(prompt.data()), prompt.size())}; + const hwy::Span> prompts( + prompt_vector.data(), prompt_vector.size()); + std::vector kv_caches = {&kv_cache}; + GenerateT(weights_u8, prefill_u8, decode_u8, + runtime_config, prompts, pos, 0, + kv_caches, pool, timing_info); +} + +template +void GenerateBatchT(const ByteStorageT& weights_u8, + const ByteStorageT& prefill_u8, + const ByteStorageT& decode_u8, + const RuntimeConfig& runtime_config, + const hwy::Span>& prompts, + size_t pos, const std::vector& kv_caches, + hwy::ThreadPool& pool, + TimingInfo& timing_info) { + // Disable query batching for Griffin models. + constexpr size_t kQueryBatchSize = + (TConfig::kGriffinLayers > 0) ? 
1 : kBatchedQueryBatchSize; + for (size_t i = 0; i < prompts.size(); i += kQueryBatchSize) { + const size_t num_queries = std::min(prompts.size() - i, kQueryBatchSize); + const hwy::Span> current_prompts( + prompts.data() + i, num_queries); + GenerateT(weights_u8, prefill_u8, decode_u8, + runtime_config, current_prompts, + pos, i, kv_caches, pool, timing_info); + } +} + } // namespace HWY_NAMESPACE } // namespace gcpp HWY_AFTER_NAMESPACE(); @@ -853,8 +1066,13 @@ namespace { template struct AllocateState { void operator()(ByteStorageT& prefill, ByteStorageT& decode) const { - prefill = AllocateSizeof>(); - decode = AllocateSizeof>(); + // When batching queries, the prefill batch size is reduced by a factor + // of kBatchedQueryBatchSize + prefill = AllocateSizeof< + Activations>(); + decode = AllocateSizeof< + Activations>(); } }; @@ -895,13 +1113,28 @@ void Gemma::Generate(const RuntimeConfig& runtime_config, pool_.SetWaitMode(hwy::PoolWaitMode::kSpin); GEMMA_EXPORT_AND_DISPATCH( - model_type_, weight_type_, GenerateT, + model_type_, weight_type_, GenerateOneQueryT, (weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompt, start_pos, kv_cache, pool_, timing_info)); pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); } +void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, + const hwy::Span>& prompts, + size_t start_pos, + const std::vector& kv_caches, + TimingInfo& timing_info) { + pool_.SetWaitMode(hwy::PoolWaitMode::kSpin); + + GEMMA_EXPORT_AND_DISPATCH( + model_type_, weight_type_, GenerateBatchT, + (weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompts, start_pos, + kv_caches, pool_, timing_info)); + + pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); +} + std::vector WrapAndTokenize(const GemmaTokenizer& tokenizer, const ModelTraining training, size_t pos, std::string& prompt) { diff --git a/gemma/gemma.h b/gemma/gemma.h index 9a91870..280b451 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -32,6 +32,9 @@ namespace gcpp { constexpr size_t kPrefillBatchSize = 16; constexpr size_t kDecodeBatchSize = 1; +constexpr size_t kBatchedQueryBatchSize = 16; +constexpr size_t kMinAdjustedPrefillBatchSize = + HWY_MAX((size_t)1, kPrefillBatchSize / kBatchedQueryBatchSize); constexpr bool kSystemPrompt = false; struct KVCache { @@ -72,6 +75,11 @@ class GemmaTokenizer { // probability is 0.0f. StreamFunc should return false to stop generation and // true to continue generation. using StreamFunc = std::function; +// BatchStreamFunc is called with (query_idx, pos, token, probability). +// For prompt tokens, +// probability is 0.0f. StreamFunc should return false to stop generation and +// true to continue generation. +using BatchStreamFunc = std::function; // If not empty, AcceptFunc is called with token. It should return false for // tokens you don't want to generate and true for tokens you want to generate. using AcceptFunc = std::function; @@ -93,6 +101,7 @@ struct RuntimeConfig { int verbosity; std::mt19937* gen; StreamFunc stream_token; + BatchStreamFunc batch_stream_token; AcceptFunc accept_token; // if empty, accepts all tokens. SampleFunc sample_func; // if empty, uses SampleTopK. LayersOutputFunc layers_output; // if not empty, called after each layer. 
@@ -125,6 +134,11 @@ class Gemma { const std::vector& prompt, size_t start_pos, KVCache& kv_cache, TimingInfo& timing_info); + void GenerateBatch(const RuntimeConfig& runtime_config, + const hwy::Span>& prompts, + size_t start_pos, const std::vector& kv_caches, + TimingInfo& timing_info); + private: hwy::ThreadPool& pool_; diff --git a/gemma/gemma_test.cc b/gemma/gemma_test.cc index b885195..c50760b 100644 --- a/gemma/gemma_test.cc +++ b/gemma/gemma_test.cc @@ -17,9 +17,11 @@ #include +#include #include #include +#include "hwy/aligned_allocator.h" #include "gemma/benchmark_helper.h" #include "gemma/common.h" #include "hwy/tests/hwy_gtest.h" @@ -44,13 +46,55 @@ class GemmaTest : public ::testing::Test { return response; } - void TestQuestions(const char* kQA[][2], size_t num_questions) { + std::vector BatchGemmaReply( + const std::vector& inputs) { + s_env->SetMaxGeneratedTokens(64); + s_env->MutableConfig().temperature = 0.0f; // deterministic + s_env->MutableConfig().verbosity = 0; + // Using the turn structure worsens results. + std::vector>> prompts; + prompts.reserve(inputs.size()); + for (auto input_string : inputs) { + std::string mutable_input_string = input_string; + prompts.push_back(std::make_unique>( + s_env->TokenizeAndPrependBOS(input_string))); + } + std::vector> prompt_vector; + for (auto& prompt : prompts) { + prompt_vector.push_back(hwy::Span( + prompt->data(), prompt->size())); + } + hwy::Span> prompt_span = + hwy::Span>( + prompt_vector.data(), prompt_vector.size()); + std::vector replies; + for (auto [response, n] : s_env->BatchQueryModel2(prompt_span)) { + replies.push_back(response); + } + return replies; + } + + void TestQuestions(const char* kQA[][2], size_t num_questions, bool batch) { if (!s_env->GetModel()) return; - for (size_t i = 0; i < num_questions; ++i) { - fprintf(stderr, "Question %zu\n\n", i + 1); - std::string response = GemmaReply(kQA[i][0]); - fprintf(stderr, "'%s'\n\n", response.c_str()); - EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT + if (batch) { + std::vector inputs; + for (size_t i = 0; i < num_questions; ++i) { + fprintf(stderr, "Batch Question %zu\n\n", i + 1); + inputs.push_back(kQA[i][0]); + } + std::vector responses = BatchGemmaReply(inputs); + for (size_t i = 0; i < num_questions; ++i) { + std::string response = responses.at(i); + fprintf(stderr, "Batch answer %zu '%s'\n\n", i + 1, response.c_str()); + EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT + } + } else { + for (size_t i = 0; i < num_questions; ++i) { + fprintf(stderr, "Question %zu\n\n", i + 1); + std::string response = GemmaReply(kQA[i][0]); + fprintf(stderr, "'%s'\n\n", response.c_str()); + EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT + } } } }; @@ -58,10 +102,16 @@ class GemmaTest : public ::testing::Test { TEST_F(GemmaTest, Geography) { static const char* kQA[][2] = { {"What is the capital of Hungary?", "Budapest"}, + {"What is the capital of Australia?", "Canberra"}, {"How many states does the US have?", "50"}, }; static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]); - TestQuestions(kQA, kNum); + TestQuestions(kQA, kNum, /* batch= */ false); + static const char* kQA_single_question[][2] = { + {"What is the capital of Australia?", "Canberra"}, + }; + TestQuestions(kQA_single_question, 1, /* batch= */ true); + TestQuestions(kQA, kNum, /* batch= */ true); } TEST_F(GemmaTest, History) { @@ -69,7 +119,7 @@ TEST_F(GemmaTest, History) { {"When was the battle of Hastings?", "1066"}, }; static const size_t 
kNum = sizeof(kQA) / sizeof(kQA[0]); - TestQuestions(kQA, kNum); + TestQuestions(kQA, kNum, /* batch= */ false); } TEST_F(GemmaTest, Arithmetic) { @@ -78,7 +128,7 @@ TEST_F(GemmaTest, Arithmetic) { {"what is 7 * 8?", "56"}, }; static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]); - TestQuestions(kQA, kNum); + TestQuestions(kQA, kNum, /* batch= */ false); } static const char kJingleBells[] = R"( @@ -152,4 +202,4 @@ int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); -} \ No newline at end of file +}
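
Example (illustrative sketch, not part of the patch itself): how the new batched entry point is expected to be driven from caller code. It assumes the existing GemmaEnv(argc, argv) constructor and the SetMaxGeneratedTokens()/MutableConfig() helpers in benchmark_helper.h; per this patch, BatchQueryModel() takes a vector of prompt strings and returns one (response text, generated-token count) pair per prompt, in the same order as the inputs.

// batch_query_example.cc -- illustrative sketch only, assumptions noted above.
#include <cstdio>
#include <string>
#include <vector>

#include "gemma/benchmark_helper.h"

int main(int argc, char** argv) {
  // Loads the model; the GemmaEnv constructor in this patch also allocates
  // one KVCache per query slot (16 of them).
  gcpp::GemmaEnv env(argc, argv);
  env.SetMaxGeneratedTokens(64);
  env.MutableConfig().temperature = 0.0f;  // deterministic decoding

  const std::vector<std::string> prompts = {
      "What is the capital of Hungary?",
      "What is the capital of Australia?",
  };

  // Prompts are tokenized, interleaved and decoded together; results come
  // back in the same order as the inputs.
  const auto results = env.BatchQueryModel(prompts);
  for (size_t i = 0; i < results.size(); ++i) {
    std::printf("Q%zu: %s\n  -> %s (%zu tokens)\n", i, prompts[i].c_str(),
                results[i].first.c_str(), results[i].second);
  }
  return 0;
}

Lower-level callers can instead install a RuntimeConfig::batch_stream_token callback of the form bool(size_t query_idx, size_t pos, int token, float prob) and call Gemma::GenerateBatch() directly with one KVCache per query; GenerateBatch splits the incoming prompts into groups of kBatchedQueryBatchSize, forced to 1 for Griffin models, which do not yet support query batching.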