Cleanup gemma_batch_bench

PiperOrigin-RevId: 766177406
Author: Jan Wassenberg, 2025-06-02 07:03:59 -07:00, committed by Copybara-Service
parent a3f7bf0991
commit c4a75abe43
2 changed files with 20 additions and 54 deletions

Changed file 1 of 2 (Bazel cc_test rule for the benchmark):

@@ -605,11 +605,11 @@ cc_test(
     ],
     deps = [
         ":benchmark_helper",
-        ":configs",
         ":gemma_lib",
         "@googletest//:gtest_main",  # buildcleaner: keep
         "@highway//:hwy",
         "@highway//:hwy_test_util",
-        "@highway//:nanobenchmark",
         "@highway//:profiler",
     ],
 )

Changed file 2 of 2 (the gemma_batch_bench test source):

@@ -19,18 +19,12 @@
 #include <vector>
 
 #include "evals/benchmark_helper.h"
-#include "gemma/configs.h"
 #include "gemma/gemma.h"
 #include "hwy/base.h"
-#include "hwy/nanobenchmark.h"
 #include "hwy/profiler.h"
 #include "hwy/tests/hwy_gtest.h"
 
-// This test can be run manually with the downloaded gemma weights.
-// To run the test, pass the following flags:
-// --tokenizer <tokenizer_path> --weights <weights_path>
-// It should pass for the following models:
-// Gemma2: gemma2-2b-it, 9b-it, 27b-it,
-
 namespace gcpp {
 namespace {
@@ -43,59 +37,19 @@ class GemmaTest : public ::testing::Test {
  protected:
   std::vector<std::string> BatchGemmaReply(
       const std::vector<std::string>& inputs) {
-    s_env->SetMaxGeneratedTokens(64);
+    s_env->SetMaxGeneratedTokens(32);
     s_env->MutableConfig().temperature = 0.0f;  // deterministic
-    s_env->MutableConfig().verbosity = 5;
+    s_env->MutableConfig().verbosity = 2;
-    const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
     std::vector<std::string> replies;
-    // Using the turn structure worsens results sometimes.
-    // However, some models need the turn structure to work.
-    // It would be good to make these tests more consistent.
-    if (config.model == Model::GEMMA2_27B ||
-        config.model == Model::GRIFFIN_2B) {
     for (const QueryResult& result : s_env->BatchQueryModel(inputs)) {
       replies.push_back(result.response);
     }
     return replies;
   }
-
-    // Otherwise, do not use turn structure.
-    std::vector<std::vector<int>> prompts_vector;
-    prompts_vector.reserve(inputs.size());
-    for (const auto& input_string : inputs) {
-      prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
-    }
-    std::vector<PromptTokens> prompt_spans;
-    prompt_spans.reserve(prompts_vector.size());
-    for (const auto& prompt : prompts_vector) {
-      prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
-    }
-    QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
-    for (const QueryResult& result : s_env->BatchQueryModel(prompts)) {
-      replies.push_back(result.response);
-    }
-    return replies;
-  }
-
-  void GenerateTokens(const std::vector<std::string>& questions) {
-    ASSERT_NE(s_env->GetGemma(), nullptr);
-    // Fills prompts round robin from `questions` until the desired batch size.
-    std::vector<std::string> inputs;
-    inputs.reserve(s_env->MutableConfig().decode_qbatch_size);
-    size_t qpos = 0;
-    for (size_t i = 0; i < inputs.capacity(); ++i) {
-      inputs.push_back(questions[qpos++]);
-      if (qpos == questions.size()) qpos = 0;
-    }
-    std::vector<std::string> responses = BatchGemmaReply(inputs);
-    for (size_t i = 0; i < inputs.size(); ++i) {
-      fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
-    }
-  }
 };
 
 TEST_F(GemmaTest, RandomQuestionsBatched) {
-  static std::vector<std::string> kQA = {
+  const std::vector<std::string> questions = {
       {"Write me a poem about Australia?"},
       {"What's the history of Denmark?"},
       {"Write me a comedy story about the USA."},
@@ -129,8 +83,20 @@ TEST_F(GemmaTest, RandomQuestionsBatched) {
       {"Tell me about space travel."},
       {"Explain to me how electric cars work."},
   };
-  s_env->MutableConfig().verbosity = 5;
-  GenerateTokens(kQA);
+  // Fills prompts round robin from `questions` until the desired batch size.
+  std::vector<std::string> inputs;
+  inputs.reserve(s_env->MutableConfig().decode_qbatch_size);
+  size_t qpos = 0;
+  for (size_t i = 0; i < inputs.capacity(); ++i) {
+    inputs.push_back(questions[qpos++]);
+    if (qpos == questions.size()) qpos = 0;
+  }
+
+  std::vector<std::string> responses = BatchGemmaReply(inputs);
+  for (size_t i = 0; i < hwy::Unpredictable1(); ++i) {
+    fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
+  }
   PROFILER_PRINT_RESULTS();
 }
 
 }  // namespace
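
Note: the new test body inlines the former GenerateTokens helper. It fills the batch round robin from `questions` until `decode_qbatch_size` prompts are queued, generates all replies in one batch, then prints a single representative answer; `hwy::Unpredictable1()` returns 1 in a way the optimizer cannot prove, keeping the loop bound opaque. Since the code compiles with only "hwy/base.h" included, Unpredictable1 is evidently provided there, which is why the "hwy/nanobenchmark.h" include and the @highway//:nanobenchmark dep can be dropped in the BUILD hunk. Below is a minimal standalone sketch of the round-robin fill; FillBatchRoundRobin and kBatchSize are hypothetical names standing in for the inlined loop and s_env->MutableConfig().decode_qbatch_size.

#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

// Repeats a small question pool round robin until the batch is full.
// `batch_size` plays the role of decode_qbatch_size in the test.
std::vector<std::string> FillBatchRoundRobin(
    const std::vector<std::string>& questions, size_t batch_size) {
  std::vector<std::string> inputs;
  inputs.reserve(batch_size);
  size_t qpos = 0;
  for (size_t i = 0; i < batch_size; ++i) {
    inputs.push_back(questions[qpos++]);
    if (qpos == questions.size()) qpos = 0;  // wrap around to the first question
  }
  return inputs;
}

int main() {
  const std::vector<std::string> questions = {"Q1", "Q2", "Q3"};
  const size_t kBatchSize = 8;  // hypothetical stand-in for decode_qbatch_size
  for (const std::string& input : FillBatchRoundRobin(questions, kBatchSize)) {
    std::fprintf(stderr, "%s\n", input.c_str());  // prints Q1 Q2 Q3 Q1 Q2 Q3 Q1 Q2
  }
  return 0;
}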