mirror of https://github.com/google/gemma.cpp.git
Batch bench: 4 runs to give autotuning more time
Also print auto-tune info for verbosity 3. PiperOrigin-RevId: 823555008
This commit is contained in:
parent
1bdde1af3c
commit
8198e7104a
|
|
@@ -47,6 +47,9 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
|
|||
ShowConfig(loader, threading, inference, config, gemma_.WeightReadMode(),
|
||||
ctx_);
|
||||
}
|
||||
if (inference.verbosity >= 3) {
|
||||
env_.print_config = env_.print_best = true;
|
||||
}
|
||||
|
||||
runtime_config_ = {
|
||||
.max_generated_tokens = inference.max_generated_tokens,
|
||||
|
|
|
|||
|
|
@@ -48,7 +48,7 @@ class GemmaBatchBench : public ::testing::Test {
|
|||
}
|
||||
};
|
||||
|
||||
TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
|
||||
std::vector<std::string> GenerateInputs() {
|
||||
std::vector<std::string> prompts = {
|
||||
{"Describe dynamic programming."},
|
||||
{"Explain how electric cars work."},
|
||||
|
|
@@ -122,22 +122,25 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
|
|||
inputs.push_back(prompts[qpos++]);
|
||||
if (qpos == prompts.size()) qpos = 0;
|
||||
}
|
||||
s_env->SetMaxGeneratedTokens(24);
|
||||
return inputs;
|
||||
}
|
||||
|
||||
TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
|
||||
s_env->SetMaxGeneratedTokens(12);
|
||||
const std::vector<std::string> inputs = GenerateInputs();
|
||||
|
||||
// Run multiple times so that auto-tuning is closer to complete.
|
||||
for (size_t rep = 0; rep < 4; ++rep) {
|
||||
std::vector<std::string> responses = BatchGemmaReply(inputs);
|
||||
for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size());
|
||||
++i) {
|
||||
fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
|
||||
fprintf(stderr, "Rep %zu batch answer %zu '%s'\n\n", rep, i,
|
||||
responses[i].c_str());
|
||||
}
|
||||
|
||||
PROFILER_PRINT_RESULTS();
|
||||
|
||||
// Run again: prefill will be faster due to autotuning. Fewer decode steps
|
||||
// because those are already fast.
|
||||
s_env->SetMaxGeneratedTokens(2);
|
||||
responses = BatchGemmaReply(inputs);
|
||||
|
||||
PROFILER_PRINT_RESULTS();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gcpp
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue