mirror of https://github.com/google/gemma.cpp.git
Batch bench: 4 runs to give autotuning more time
Also print auto-tune info when verbosity is 3 or higher.
PiperOrigin-RevId: 823555008
This commit is contained in:
parent
1bdde1af3c
commit
8198e7104a
|
|
@ -47,6 +47,9 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
|
||||||
ShowConfig(loader, threading, inference, config, gemma_.WeightReadMode(),
|
ShowConfig(loader, threading, inference, config, gemma_.WeightReadMode(),
|
||||||
ctx_);
|
ctx_);
|
||||||
}
|
}
|
||||||
|
if (inference.verbosity >= 3) {
|
||||||
|
env_.print_config = env_.print_best = true;
|
||||||
|
}
|
||||||
|
|
||||||
runtime_config_ = {
|
runtime_config_ = {
|
||||||
.max_generated_tokens = inference.max_generated_tokens,
|
.max_generated_tokens = inference.max_generated_tokens,
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,7 @@ class GemmaBatchBench : public ::testing::Test {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
|
std::vector<std::string> GenerateInputs() {
|
||||||
std::vector<std::string> prompts = {
|
std::vector<std::string> prompts = {
|
||||||
{"Describe dynamic programming."},
|
{"Describe dynamic programming."},
|
||||||
{"Explain how electric cars work."},
|
{"Explain how electric cars work."},
|
||||||
|
|
@ -122,22 +122,25 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
|
||||||
inputs.push_back(prompts[qpos++]);
|
inputs.push_back(prompts[qpos++]);
|
||||||
if (qpos == prompts.size()) qpos = 0;
|
if (qpos == prompts.size()) qpos = 0;
|
||||||
}
|
}
|
||||||
s_env->SetMaxGeneratedTokens(24);
|
return inputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
|
||||||
|
s_env->SetMaxGeneratedTokens(12);
|
||||||
|
const std::vector<std::string> inputs = GenerateInputs();
|
||||||
|
|
||||||
|
// Run multiple times so that auto-tuning is closer to complete.
|
||||||
|
for (size_t rep = 0; rep < 4; ++rep) {
|
||||||
std::vector<std::string> responses = BatchGemmaReply(inputs);
|
std::vector<std::string> responses = BatchGemmaReply(inputs);
|
||||||
for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size());
|
for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size());
|
||||||
++i) {
|
++i) {
|
||||||
fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
|
fprintf(stderr, "Rep %zu batch answer %zu '%s'\n\n", rep, i,
|
||||||
|
responses[i].c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
PROFILER_PRINT_RESULTS();
|
|
||||||
|
|
||||||
// Run again: prefill will be faster due to autotuning. Fewer decode steps
|
|
||||||
// because those are already fast.
|
|
||||||
s_env->SetMaxGeneratedTokens(2);
|
|
||||||
responses = BatchGemmaReply(inputs);
|
|
||||||
|
|
||||||
PROFILER_PRINT_RESULTS();
|
PROFILER_PRINT_RESULTS();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue