From 9b6ed1a58f631c85693117f38b1fea36c7e82a2f Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Wed, 15 Oct 2025 15:45:27 -0700
Subject: [PATCH] gemma_batch_bench: generate more unique prompts

PiperOrigin-RevId: 819944137
---
 evals/gemma_batch_bench.cc | 99 +++++++++++++++++++++++++-------------
 1 file changed, 65 insertions(+), 34 deletions(-)

diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc
index ff81671..4a6f5ea 100644
--- a/evals/gemma_batch_bench.cc
+++ b/evals/gemma_batch_bench.cc
@@ -15,6 +15,7 @@
 #include
+#include <algorithm>
 #include
 #include
@@ -48,48 +49,78 @@ class GemmaBatchBench : public ::testing::Test {
 };
 
 TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
-  const std::vector<std::string> questions = {
-      {"Write me a poem about Australia?"},
-      {"What's the history of Denmark?"},
-      {"Write me a comedy story about the USA."},
-      {"Teach me about GPU programming."},
-      {"Write me a story about the moon."},
-      {"Write me a story about the universe."},
-      {"Write a poem about planet earth."},
-      {"Tell me more about olympic sports."},
-      {"How would you describe Washington State?"},
-      {"Write me a story about Silicon Valley."},
-      {"Write me about your best friend."},
-      {"How would you describe a unicorn?"},
-      {"Tell me about world war history."},
-      {"Tell me about Google."},
+  std::vector<std::string> prompts = {
+      {"Describe dynamic programming."},
+      {"Explain how electric cars work."},
       {"Explain to me how to use Google Maps."},
-      {"Explain to me how AI works."},
-      {"Write me a poem about France."},
-      {"What's the history of Great Britain?"},
-      {"Write me a comedy story about Florida."},
-      {"Teach me about dynamic programming."},
-      {"Write me a story about Jupiter."},
-      {"Write me a story about space ships."},
-      {"Write a poem about some random planet."},
-      {"Tell me more about team sports."},
-      {"How would you describe Michigan State?"},
-      {"Write me a story about Europe."},
-      {"Write me about your best colleague."},
-      {"How would you describe a horse?"},
-      {"Tell me about World War 2."},
+      {"How does AI work?"},
+      {"How would you describe a unicorn?"},
       {"Please share some good cooking tips."},
-      {"Tell me about space travel."},
-      {"Explain to me how electric cars work."},
+      {"Teach me about GPU programming."},
+      {"Tell me a fact about World War 2."},
+      {"Tell me about Google."},
+      {"Tell me more about olympic sports."},
+      {"Tell me something about space travel."},
+      {"What is a horse?"},
+      {"What is Michigan State?"},
+      {"What's the history of Denmark?"},
+      {"Write a poem about planet earth."},
+      {"Write a story about Jupiter."},
+      {"Write about the moon."},
+      {"Write me a comedy story about Florida."},
+      {"Write me a poem about France."},
   };
+  const std::vector<std::string> start = {
+      {"What is"}, {"When did"}, {"Where did"}, {"How did"}, {"Why did"}};
+  const std::vector<std::string> concepts = {"Socrates",
+                                             "Einstein",
+                                             "Leonardo",
+                                             "Cleopatra",
+                                             "Adele",
+                                             "Mars",
+                                             "Turing",
+                                             "Mozart",
+                                             "democracy",
+                                             "gravity",
+                                             "AI",
+                                             "evolution",
+                                             "physics",
+                                             "the internet",
+                                             "steam engine",
+                                             "inflation",
+                                             "electricity",
+                                             "the Sahara",
+                                             "NASA",
+                                             "Rome",
+                                             "the UN",
+                                             "Google",
+                                             "the Renaissance",
+                                             "Hamlet",
+                                             "poetry",
+                                             "Stoicism",
+                                             "geometry",
+                                             "DNA",
+                                             "Star Wars",
+                                             "1984"};
+  const std::vector<std::string> end = {"exist?", "work?", "happen?",
+                                        "lead to?", "believe?", "result in?"};
+  for (const std::string& s : start) {
+    for (const std::string& c : concepts) {
+      for (const std::string& e : end) {
+        prompts.push_back(s + " " + c + " " + e);
+      }
+    }
+  }
+  AesCtrEngine engine(true);
+  std::shuffle(prompts.begin(), prompts.end(), RngStream(engine, 123));
 
-  // Fills prompts round robin from `questions` until the desired batch size.
+  // Fills `inputs` by repeating from `prompts` until the desired batch size.
   std::vector<std::string> inputs;
   inputs.reserve(s_env->MutableConfig().decode_qbatch_size);
   size_t qpos = 0;
   for (size_t i = 0; i < inputs.capacity(); ++i) {
-    inputs.push_back(questions[qpos++]);
-    if (qpos == questions.size()) qpos = 0;
+    inputs.push_back(prompts[qpos++]);
+    if (qpos == prompts.size()) qpos = 0;
   }
   s_env->SetMaxGeneratedTokens(24);
   std::vector<std::string> responses = BatchGemmaReply(inputs);
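
For reference, a minimal standalone sketch of the prompt-generation scheme used by this patch: build the start x concepts x end cross product, shuffle with a fixed seed so runs are reproducible, then fill a batch round-robin. std::mt19937 and the hard-coded batch_size are stand-ins for the repo's AesCtrEngine/RngStream and decode_qbatch_size; this is an illustration, not the benchmark code itself.

// Sketch only: reproduces the prompt generation outside the benchmark harness.
#include <algorithm>
#include <cstdio>
#include <random>
#include <string>
#include <vector>

int main() {
  // Shortened word lists; the patch uses 5 starts x 30 concepts x 6 endings.
  const std::vector<std::string> start = {"What is", "When did", "How did"};
  const std::vector<std::string> concepts = {"Socrates", "gravity", "NASA"};
  const std::vector<std::string> end = {"exist?", "work?", "happen?"};

  // Cross product yields many unique (if sometimes ungrammatical) prompts.
  std::vector<std::string> prompts;
  for (const std::string& s : start) {
    for (const std::string& c : concepts) {
      for (const std::string& e : end) {
        prompts.push_back(s + " " + c + " " + e);
      }
    }
  }

  // Fixed seed: every run sees the same prompt order (stand-in for RngStream).
  std::mt19937 rng(123);
  std::shuffle(prompts.begin(), prompts.end(), rng);

  // Fill a batch by cycling through the prompts, as the benchmark does with
  // decode_qbatch_size inputs.
  const size_t batch_size = 8;  // placeholder value
  std::vector<std::string> inputs;
  inputs.reserve(batch_size);
  for (size_t i = 0, qpos = 0; i < batch_size; ++i) {
    inputs.push_back(prompts[qpos++]);
    if (qpos == prompts.size()) qpos = 0;
  }

  for (const std::string& p : inputs) std::printf("%s\n", p.c_str());
  return 0;
}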