Internal change. Slight restructuring of gemma_test.

PiperOrigin-RevId: 670529565
This commit is contained in:
Daniel Keysers 2024-09-03 06:15:31 -07:00 committed by Copybara-Service
parent a8e08778d4
commit 437e0eb9af
1 changed files with 18 additions and 16 deletions

View File

@ -72,20 +72,21 @@ class GemmaTest : public ::testing::Test {
for (auto [response, n] : s_env->BatchQueryModel(inputs)) {
replies.push_back(response);
}
} else { // Not Gemma-2 27B. Do not use turn structure.
std::vector<std::vector<int>> prompts_vector;
prompts_vector.reserve(inputs.size());
for (const auto& input_string : inputs) {
prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
}
std::vector<PromptTokens> prompt_spans;
for (const auto& prompt : prompts_vector) {
prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
}
QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
for (auto [response, n] : s_env->BatchQueryModel(prompts)) {
replies.push_back(response);
}
return replies;
}
// Not Gemma-2 27B. Do not use turn structure.
std::vector<std::vector<int>> prompts_vector;
prompts_vector.reserve(inputs.size());
for (const auto& input_string : inputs) {
prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
}
std::vector<PromptTokens> prompt_spans;
for (const auto& prompt : prompts_vector) {
prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
}
QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
for (auto [response, n] : s_env->BatchQueryModel(prompts)) {
replies.push_back(response);
}
return replies;
}
@ -186,8 +187,9 @@ TEST_F(GemmaTest, Multiturn) {
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
timing_info);
fprintf(stderr, "decoded: %s\n", dialog.c_str());
bool remembered_turquoise = dialog.find("turquoise") != std::string::npos;
bool remembered_car = dialog.find("car") != std::string::npos;
bool remembered_turquoise =
dialog.find("turquoise") != std::string::npos; // NOLINT
bool remembered_car = dialog.find("car") != std::string::npos; // NOLINT
EXPECT_TRUE(remembered_turquoise || remembered_car);
}