From f2adbfbcab3f4c721fc4ffb5f3e9da9b32420725 Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Tue, 17 Jun 2025 07:09:00 -0700
Subject: [PATCH] Batch inference fixes: set pos during prefill, fix assert

Prefill now streams each prompt token with its position within the prompt
and keeps the per-query position in sync; after prefill, the position is
set to the last prompt token, which is the first input token for
generation. The seq_len assert and the max_gen_steps truncation now check
the per-query max_prompt_size instead of the prefill-token total summed
across queries. MakeMatMulEnv now also takes InferenceArgs so that large
decode batches (decode_qbatch_size >= 256) are restricted to a single
package. Also removes the unused kMaxSeqLen constant from gemma/configs.cc.

PiperOrigin-RevId: 772458760
---
 evals/benchmark_helper.cc           |  3 ++-
 examples/hello_world/run.cc         |  2 +-
 examples/simplified_gemma/gemma.hpp |  2 +-
 gemma/bindings/context.cc           |  2 +-
 gemma/configs.cc                    |  1 -
 gemma/gemma.cc                      | 33 ++++++++++++++++++++---------
 gemma/gemma.h                       |  3 ++-
 gemma/run.cc                        |  2 +-
 util/threading_context.h            |  8 +++----
 9 files changed, 35 insertions(+), 21 deletions(-)

diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc
index 6184d97..902bc87 100644
--- a/evals/benchmark_helper.cc
+++ b/evals/benchmark_helper.cc
@@ -49,7 +49,8 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
 
 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
                    const InferenceArgs& inference)
-    : env_(MakeMatMulEnv(threading)), gemma_(loader, inference, env_) {
+    : env_(MakeMatMulEnv(threading, inference)),
+      gemma_(loader, inference, env_) {
   const ModelConfig& config = gemma_.GetModelConfig();
   // Only allocate one for starters because GenerateBatch might not be called.
   kv_caches_.push_back(KVCache(config, inference));
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index 931de8a..e5b57da 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -51,7 +51,7 @@ int main(int argc, char** argv) {
   }
 
   // Instantiate model and KV Cache
-  gcpp::MatMulEnv env(MakeMatMulEnv(threading));
+  gcpp::MatMulEnv env(MakeMatMulEnv(threading, inference));
   gcpp::Gemma gemma(loader, inference, env);
   gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);
   size_t generated = 0;
diff --git a/examples/simplified_gemma/gemma.hpp b/examples/simplified_gemma/gemma.hpp
index 551cab4..2372591 100644
--- a/examples/simplified_gemma/gemma.hpp
+++ b/examples/simplified_gemma/gemma.hpp
@@ -35,7 +35,7 @@ class SimplifiedGemma {
   SimplifiedGemma(const gcpp::LoaderArgs& loader,
                   const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
                   const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
-      : env_(MakeMatMulEnv(threading)),
+      : env_(MakeMatMulEnv(threading, inference)),
         gemma_(loader, inference, env_),
         kv_cache_(gemma_.GetModelConfig(), inference) {
     // Initialize random number generator
diff --git a/gemma/bindings/context.cc b/gemma/bindings/context.cc
index e7edeaf..89fb650 100644
--- a/gemma/bindings/context.cc
+++ b/gemma/bindings/context.cc
@@ -101,7 +101,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
                            int max_generated_tokens)
     : inference_args(inference_args),
       threading_args(threading_args),
-      matmul_env(MakeMatMulEnv(threading_args)),
+      matmul_env(MakeMatMulEnv(threading_args, inference_args)),
       active_conversation_name("default"),
       model(loader, inference_args, matmul_env) {
   std::stringstream ss;
diff --git a/gemma/configs.cc b/gemma/configs.cc
index 3c95cf9..9d548ef 100644
--- a/gemma/configs.cc
+++ b/gemma/configs.cc
@@ -28,7 +28,6 @@
 namespace gcpp {
 
 static constexpr size_t kVocabSize = 256000;
-static constexpr size_t kMaxSeqLen = 4096;
 
 static ModelConfig ConfigNoSSM() {
   ModelConfig config;
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index b8a625d..a585532 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -344,8 +344,9 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
       token = qbatch.Prompt(qi)[pos_in_prompt];
       // Ignore StreamToken return value because requesting to stop does not
       // make sense during prefill.
-      (void)runtime_config.StreamToken(qbatch.QueryIdx(qi), qbatch.Pos(qi),
+      (void)runtime_config.StreamToken(qbatch.QueryIdx(qi), pos_in_prompt,
                                        token, 0.0f);
+      qbatch.MutablePos(qi) = pos_in_prompt;
     }
 
     qbatch.PrevToken(qi) = token;
@@ -356,6 +357,10 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
     // probabilities, which are not required for the prompt tokens.
     Transformer(config, runtime_config, weights, activations, qbatch, env);
   }
+
+  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+    qbatch.MutablePos(qi) = qbatch.Prompt(qi).size() - 1;
+  }
 }
 
 // Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
@@ -457,7 +462,7 @@ static void GenerateT(const ModelConfig& config,
 
   size_t max_prompt_size = 0;
   bool all_prefix_end_are_zero = true;
-  size_t prefill_tokens = 0;  // only for timing.
+  size_t total_prefill_tokens = 0;  // only for throughput stats.
   const size_t seq_len = qbatch.KV(0).SeqLen();
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     const PromptTokens& prompt = qbatch.Prompt(qi);
@@ -465,7 +470,7 @@ static void GenerateT(const ModelConfig& config,
 
     // Prefill stops before size - 1 because the last prompt token is the
     // first input token for generation.
-    prefill_tokens += prompt.size() - 1;
+    total_prefill_tokens += prompt.size() - 1;
 
     // Sanity check: prompts should not be empty, nor start with EOS.
     HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
@@ -475,7 +480,7 @@ static void GenerateT(const ModelConfig& config,
     // We use a single divisor, so all sequence lengths must be the same.
     HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
   }
-  HWY_ASSERT(prefill_tokens < seq_len);
+  HWY_ASSERT(max_prompt_size < seq_len);
   HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);
 
   // Lacks a constructor to bulk-set, hence initialized by Prefill* which have
@@ -494,7 +499,7 @@ static void GenerateT(const ModelConfig& config,
     activations.SetBatchSize(qbatch.Size());  // Restore after PrefillTBatch.
   }
   HWY_DASSERT(non_eos.Count() == qbatch.Size());
-  timing_info.NotifyPrefill(prefill_tokens);
+  timing_info.NotifyPrefill(total_prefill_tokens);
 
   // queries_pos have been incremented by Prefill.
   // Stream the last prompt token from each query, fill activations.gen_tokens.
@@ -505,10 +510,10 @@ static void GenerateT(const ModelConfig& config,
   }
 
   size_t max_gen_steps = runtime_config.max_generated_tokens;
-  if (prefill_tokens + max_gen_steps > seq_len) {
+  if (max_prompt_size + max_gen_steps > seq_len) {
     HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.",
-             prefill_tokens, max_gen_steps, seq_len);
-    max_gen_steps = seq_len - prefill_tokens;
+             max_prompt_size, max_gen_steps, seq_len);
+    max_gen_steps = seq_len - max_prompt_size;
   }
 
   const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
@@ -588,8 +593,16 @@ HWY_EXPORT(GenerateSingleT);
 HWY_EXPORT(GenerateBatchT);
 HWY_EXPORT(GenerateImageTokensT);
 
-MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
-  ThreadingContext::SetArgs(threading_args);
+MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
+                        const InferenceArgs& inference_args) {
+  if (inference_args.decode_qbatch_size >= 256) {
+    ThreadingArgs copy = threading_args;
+    copy.max_packages = 1;
+    ThreadingContext::SetArgs(copy);
+  } else {
+    ThreadingContext::SetArgs(threading_args);
+  }
+
   return MatMulEnv(ThreadingContext::Get());
 }
 
diff --git a/gemma/gemma.h b/gemma/gemma.h
index 40cc82c..423133d 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -226,7 +226,8 @@ struct TimingInfo {
 };
 
 // Returns the `MatMulEnv` after calling `SetArgs`.
-MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args);
+MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
+                        const InferenceArgs& inference_args);
 
 class Gemma {
  public:
diff --git a/gemma/run.cc b/gemma/run.cc
index bf9b32b..071a5a5 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -254,7 +254,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
          const InferenceArgs& inference) {
   PROFILER_ZONE("Run.misc");
 
-  MatMulEnv env(MakeMatMulEnv(threading));
+  MatMulEnv env(MakeMatMulEnv(threading, inference));
   if (inference.verbosity >= 2) env.print_best = true;
   const Gemma gemma(loader, inference, env);
   KVCache kv_cache(gemma.GetModelConfig(), inference);
diff --git a/util/threading_context.h b/util/threading_context.h
index be9bf59..e35d368 100644
--- a/util/threading_context.h
+++ b/util/threading_context.h
@@ -61,20 +61,20 @@ class ThreadingArgs : public ArgsBase {
     visitor(skip_packages, "skip_packages", size_t{0},
             "Index of the first socket to use; default 0 = unlimited.", 2);
     visitor(max_packages, "max_packages", size_t{0},
-            "Maximum number of sockets to use; default 0 = unlimited.", 2);
+            "Max sockets to use; default 0 = all unless large batch size.", 2);
     visitor(skip_clusters, "skip_clusters", size_t{0},
             "Index of the first CCX to use; default 0 = unlimited.", 2);
     visitor(max_clusters, "max_clusters", size_t{0},
-            "Maximum number of CCXs to use; default 0 = unlimited.", 2);
+            "Max CCXs to use; default 0 = unlimited.", 2);
 
     // These are only used when CPU topology is unknown.
     visitor(skip_lps, "skip_lps", size_t{0},
             "Index of the first LP to use; default 0 = unlimited.", 2);
     visitor(max_lps, "max_lps", size_t{0},
-            "Maximum number of LPs to use; default 0 = unlimited.", 2);
+            "Max LPs to use; default 0 = unlimited.", 2);
     // The exact meaning is more subtle: see the comment at NestedPools ctor.
     visitor(max_threads, "num_threads", size_t{0},
-            "Maximum number of threads to use; default 0 = unlimited.", 2);
+            "Max threads to use; default 0 = unlimited.", 2);
     visitor(pin, "pin", Tristate::kDefault,
             "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
     visitor(spin, "spin", Tristate::kDefault,
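
Note for reviewers (not part of the patch): every call site changes the same
way, so below is a minimal sketch of the updated usage, modeled on the
examples/hello_world/run.cc hunk above. The Args types constructed from
argc/argv and the single gemma/gemma.h include are assumptions based on the
files touched in this patch; tokenization, generation, and error handling are
elided.

  // Sketch: the two-argument MakeMatMulEnv. InferenceArgs is now required so
  // that decode_qbatch_size can be inspected: values >= 256 force
  // max_packages = 1 (one socket) before ThreadingContext::SetArgs runs.
  #include "gemma/gemma.h"

  int main(int argc, char** argv) {
    gcpp::LoaderArgs loader(argc, argv);
    gcpp::ThreadingArgs threading(argc, argv);
    gcpp::InferenceArgs inference(argc, argv);

    // Was: MakeMatMulEnv(threading). Found via ADL on the gcpp argument types.
    gcpp::MatMulEnv env(MakeMatMulEnv(threading, inference));
    gcpp::Gemma gemma(loader, inference, env);
    gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);
    // ... generate as before; small decode batches keep the user-specified
    // max_packages, so only large-batch runs are pinned to a single package.
    return 0;
  }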