From 8198e7104a78643a58b807188eafbb40f94157c8 Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Fri, 24 Oct 2025 09:14:07 -0700
Subject: [PATCH] Batch bench: 4 runs to give autotuning more time

Also print auto-tune info for verbosity 3.

PiperOrigin-RevId: 823555008
---
 evals/benchmark_helper.cc  |  3 +++
 evals/gemma_batch_bench.cc | 35 +++++++++++++++++++----------------
 2 files changed, 22 insertions(+), 16 deletions(-)

diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc
index a495dea..9e4c1b6 100644
--- a/evals/benchmark_helper.cc
+++ b/evals/benchmark_helper.cc
@@ -47,6 +47,9 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
     ShowConfig(loader, threading, inference, config, gemma_.WeightReadMode(),
                ctx_);
   }
+  if (inference.verbosity >= 3) {
+    env_.print_config = env_.print_best = true;
+  }
 
   runtime_config_ = {
       .max_generated_tokens = inference.max_generated_tokens,
diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc
index 4a6f5ea..90e46d4 100644
--- a/evals/gemma_batch_bench.cc
+++ b/evals/gemma_batch_bench.cc
@@ -48,7 +48,7 @@ class GemmaBatchBench : public ::testing::Test {
   }
 };
 
-TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
+std::vector<std::string> GenerateInputs() {
   std::vector<std::string> prompts = {
       {"Describe dynamic programming."},
       {"Explain how electric cars work."},
@@ -122,22 +122,25 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
     inputs.push_back(prompts[qpos++]);
     if (qpos == prompts.size()) qpos = 0;
   }
-  s_env->SetMaxGeneratedTokens(24);
-  std::vector<std::string> responses = BatchGemmaReply(inputs);
-  for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size());
-       ++i) {
-    fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
-  }
-
-  PROFILER_PRINT_RESULTS();
-
-  // Run again: prefill will be faster due to autotuning. Fewer decode steps
-  // because those are already fast.
-  s_env->SetMaxGeneratedTokens(2);
-  responses = BatchGemmaReply(inputs);
-
-  PROFILER_PRINT_RESULTS();
+  return inputs;
 }
+
+TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
+  s_env->SetMaxGeneratedTokens(12);
+  const std::vector<std::string> inputs = GenerateInputs();
+
+  // Run multiple times so that auto-tuning is closer to complete.
+  for (size_t rep = 0; rep < 4; ++rep) {
+    std::vector<std::string> responses = BatchGemmaReply(inputs);
+    for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size());
+         ++i) {
+      fprintf(stderr, "Rep %zu batch answer %zu '%s'\n\n", rep, i,
+              responses[i].c_str());
+    }
+    PROFILER_PRINT_RESULTS();
+  }
+}
+
 
 }  // namespace
 }  // namespace gcpp