gemma_batch_bench: generate more unique prompts

PiperOrigin-RevId: 819944137
This commit is contained in:
Jan Wassenberg 2025-10-15 15:45:27 -07:00 committed by Copybara-Service
parent 503aaddd65
commit 9b6ed1a58f
1 changed file with 65 additions and 34 deletions

View File

@ -15,6 +15,7 @@
#include <stdio.h> #include <stdio.h>
#include <algorithm>
#include <string> #include <string>
#include <vector> #include <vector>
@ -48,48 +49,78 @@ class GemmaBatchBench : public ::testing::Test {
}; };
TEST_F(GemmaBatchBench, RandomQuestionsBatched) { TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
const std::vector<std::string> questions = { std::vector<std::string> prompts = {
{"Write me a poem about Australia?"}, {"Describe dynamic programming."},
{"What's the history of Denmark?"}, {"Explain how electric cars work."},
{"Write me a comedy story about the USA."},
{"Teach me about GPU programming."},
{"Write me a story about the moon."},
{"Write me a story about the universe."},
{"Write a poem about planet earth."},
{"Tell me more about olympic sports."},
{"How would you describe Washington State?"},
{"Write me a story about Silicon Valley."},
{"Write me about your best friend."},
{"How would you describe a unicorn?"},
{"Tell me about world war history."},
{"Tell me about Google."},
{"Explain to me how to use Google Maps."}, {"Explain to me how to use Google Maps."},
{"Explain to me how AI works."}, {"How does AI work?"},
{"Write me a poem about France."}, {"How would you describe a unicorn?"},
{"What's the history of Great Britain?"},
{"Write me a comedy story about Florida."},
{"Teach me about dynamic programming."},
{"Write me a story about Jupiter."},
{"Write me a story about space ships."},
{"Write a poem about some random planet."},
{"Tell me more about team sports."},
{"How would you describe Michigan State?"},
{"Write me a story about Europe."},
{"Write me about your best colleague."},
{"How would you describe a horse?"},
{"Tell me about World War 2."},
{"Please share some good cooking tips."}, {"Please share some good cooking tips."},
{"Tell me about space travel."}, {"Teach me about GPU programming."},
{"Explain to me how electric cars work."}, {"Tell me a fact about World War 2."},
{"Tell me about Google."},
{"Tell me more about olympic sports."},
{"Tell me something about space travel."},
{"What is a horse?"},
{"What is Michigan State?"},
{"What's the history of Denmark?"},
{"Write a poem about planet earth."},
{"Write a story about Jupiter."},
{"Write about the moon."},
{"Write me a comedy story about Florida."},
{"Write me a poem about France."},
}; };
const std::vector<std::string> start = {
{"What is"}, {"When did"}, {"Where did"}, {"How did"}, {"Why did"}};
const std::vector<std::string> concepts = {"Socrates",
"Einstein",
"Leonardo",
"Cleopatra",
"Adele",
"Mars",
"Turing",
"Mozart",
"democracy",
"gravity",
"AI",
"evolution",
"physics",
"the internet",
"steam engine",
"inflation",
"electricity",
"the Sahara",
"NASA",
"Rome",
"the UN",
"Google",
"the Renaissance",
"Hamlet",
"poetry",
"Stoicism",
"geometry",
"DNA",
"Star Wars",
"1984"};
const std::vector<std::string> end = {"exist?", "work?", "happen?",
"lead to?", "believe?", "result in?"};
for (const std::string& s : start) {
for (const std::string& c : concepts) {
for (const std::string& e : end) {
prompts.push_back(s + " " + c + " " + e);
}
}
}
AesCtrEngine engine(true);
std::shuffle(prompts.begin(), prompts.end(), RngStream(engine, 123));
// Fills prompts round robin from `questions` until the desired batch size. // Fills `inputs` by repeating from `prompts` until the desired batch size.
std::vector<std::string> inputs; std::vector<std::string> inputs;
inputs.reserve(s_env->MutableConfig().decode_qbatch_size); inputs.reserve(s_env->MutableConfig().decode_qbatch_size);
size_t qpos = 0; size_t qpos = 0;
for (size_t i = 0; i < inputs.capacity(); ++i) { for (size_t i = 0; i < inputs.capacity(); ++i) {
inputs.push_back(questions[qpos++]); inputs.push_back(prompts[qpos++]);
if (qpos == questions.size()) qpos = 0; if (qpos == prompts.size()) qpos = 0;
} }
s_env->SetMaxGeneratedTokens(24); s_env->SetMaxGeneratedTokens(24);
std::vector<std::string> responses = BatchGemmaReply(inputs); std::vector<std::string> responses = BatchGemmaReply(inputs);