From c4a75abe43e5cd68f6bfc840b0d5dbfb9bac791f Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Mon, 2 Jun 2025 07:03:59 -0700 Subject: [PATCH] Cleanup gemma_batch_bench PiperOrigin-RevId: 766177406 --- BUILD.bazel | 2 +- evals/gemma_batch_bench.cc | 72 ++++++++++---------------------------- 2 files changed, 20 insertions(+), 54 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index fc5c0c5..8b8eb94 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -605,11 +605,11 @@ cc_test( ], deps = [ ":benchmark_helper", - ":configs", ":gemma_lib", "@googletest//:gtest_main", # buildcleaner: keep "@highway//:hwy", "@highway//:hwy_test_util", + "@highway//:nanobenchmark", "@highway//:profiler", ], ) diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc index b691706..5411d95 100644 --- a/evals/gemma_batch_bench.cc +++ b/evals/gemma_batch_bench.cc @@ -19,18 +19,12 @@ #include #include "evals/benchmark_helper.h" -#include "gemma/configs.h" #include "gemma/gemma.h" #include "hwy/base.h" +#include "hwy/nanobenchmark.h" #include "hwy/profiler.h" #include "hwy/tests/hwy_gtest.h" -// This test can be run manually with the downloaded gemma weights. -// To run the test, pass the following flags: -// --tokenizer --weights -// It should pass for the following models: -// Gemma2: gemma2-2b-it, 9b-it, 27b-it, - namespace gcpp { namespace { @@ -43,59 +37,19 @@ class GemmaTest : public ::testing::Test { protected: std::vector BatchGemmaReply( const std::vector& inputs) { - s_env->SetMaxGeneratedTokens(64); + s_env->SetMaxGeneratedTokens(32); s_env->MutableConfig().temperature = 0.0f; // deterministic - s_env->MutableConfig().verbosity = 5; - const ModelConfig& config = s_env->GetGemma()->GetModelConfig(); + s_env->MutableConfig().verbosity = 2; std::vector replies; - // Using the turn structure worsens results sometimes. - // However, some models need the turn structure to work. - // It would be good to make these tests more consistent. - if (config.model == Model::GEMMA2_27B || - config.model == Model::GRIFFIN_2B) { - for (const QueryResult& result : s_env->BatchQueryModel(inputs)) { - replies.push_back(result.response); - } - return replies; - } - // Otherwise, do not use turn structure. - std::vector> prompts_vector; - prompts_vector.reserve(inputs.size()); - for (const auto& input_string : inputs) { - prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string)); - } - std::vector prompt_spans; - prompt_spans.reserve(prompts_vector.size()); - for (const auto& prompt : prompts_vector) { - prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size())); - } - QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size()); - for (const QueryResult& result : s_env->BatchQueryModel(prompts)) { + for (const QueryResult& result : s_env->BatchQueryModel(inputs)) { replies.push_back(result.response); } return replies; } - - void GenerateTokens(const std::vector& questions) { - ASSERT_NE(s_env->GetGemma(), nullptr); - - // Fills prompts round robin from `questions` until the desired batch size. - std::vector inputs; - inputs.reserve(s_env->MutableConfig().decode_qbatch_size); - size_t qpos = 0; - for (size_t i = 0; i < inputs.capacity(); ++i) { - inputs.push_back(questions[qpos++]); - if (qpos == questions.size()) qpos = 0; - } - std::vector responses = BatchGemmaReply(inputs); - for (size_t i = 0; i < inputs.size(); ++i) { - fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str()); - } - } }; TEST_F(GemmaTest, RandomQuestionsBatched) { - static std::vector kQA = { + const std::vector questions = { {"Write me a poem about Australia?"}, {"What's the history of Denmark?"}, {"Write me a comedy story about the USA."}, @@ -129,8 +83,20 @@ TEST_F(GemmaTest, RandomQuestionsBatched) { {"Tell me about space travel."}, {"Explain to me how electric cars work."}, }; - s_env->MutableConfig().verbosity = 5; - GenerateTokens(kQA); + + // Fills prompts round robin from `questions` until the desired batch size. + std::vector inputs; + inputs.reserve(s_env->MutableConfig().decode_qbatch_size); + size_t qpos = 0; + for (size_t i = 0; i < inputs.capacity(); ++i) { + inputs.push_back(questions[qpos++]); + if (qpos == questions.size()) qpos = 0; + } + std::vector responses = BatchGemmaReply(inputs); + for (size_t i = 0; i < hwy::Unpredictable1(); ++i) { + fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str()); + } + PROFILER_PRINT_RESULTS(); } } // namespace