From 437e0eb9afab794a489460dd61002c284ef568ce Mon Sep 17 00:00:00 2001 From: Daniel Keysers Date: Tue, 3 Sep 2024 06:15:31 -0700 Subject: [PATCH] Internal change. Slight restructuring of gemma_test. PiperOrigin-RevId: 670529565 --- evals/gemma_test.cc | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc index 6b12f5c..9c56ed2 100644 --- a/evals/gemma_test.cc +++ b/evals/gemma_test.cc @@ -72,20 +72,21 @@ class GemmaTest : public ::testing::Test { for (auto [response, n] : s_env->BatchQueryModel(inputs)) { replies.push_back(response); } - } else { // Not Gemma-2 27B. Do not use turn structure. - std::vector> prompts_vector; - prompts_vector.reserve(inputs.size()); - for (const auto& input_string : inputs) { - prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string)); - } - std::vector prompt_spans; - for (const auto& prompt : prompts_vector) { - prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size())); - } - QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size()); - for (auto [response, n] : s_env->BatchQueryModel(prompts)) { - replies.push_back(response); - } + return replies; + } + // Not Gemma-2 27B. Do not use turn structure. + std::vector> prompts_vector; + prompts_vector.reserve(inputs.size()); + for (const auto& input_string : inputs) { + prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string)); + } + std::vector prompt_spans; + for (const auto& prompt : prompts_vector) { + prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size())); + } + QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size()); + for (auto [response, n] : s_env->BatchQueryModel(prompts)) { + replies.push_back(response); } return replies; } @@ -186,8 +187,9 @@ TEST_F(GemmaTest, Multiturn) { model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(), timing_info); fprintf(stderr, "decoded: %s\n", dialog.c_str()); - bool remembered_turquoise = dialog.find("turquoise") != std::string::npos; - bool remembered_car = dialog.find("car") != std::string::npos; + bool remembered_turquoise = + dialog.find("turquoise") != std::string::npos; // NOLINT + bool remembered_car = dialog.find("car") != std::string::npos; // NOLINT EXPECT_TRUE(remembered_turquoise || remembered_car); }