Fix testing::SrcDir() path resolution in wheat_from_chaff_test

Also use a list of acceptable substring matchers for each question instead of just one

PiperOrigin-RevId: 883198819
This commit is contained in:
Nikhil Dev Goyal 2026-03-13 09:17:14 -07:00 committed by Copybara-Service
parent 529c201eb6
commit 0110ddfee7
1 changed files with 28 additions and 11 deletions

View File

@ -46,11 +46,18 @@ static const char* kQuestions =
"Which people first proposed the quark model of hadrons, and when?";
// All phrases in kAnswers must appear in the response in the order given for
// the test to pass.
static const char* kAnswers[] = {
"a ship's anchor", "a dark forest", "an hour",
"enormous sand", "castles", "limpet shells",
"Murray Gell-Mann", "George Zweig", "1964"};
// the test to pass. Multiple acceptable answers can be provided for each
// expected phrase.
static const std::vector<std::vector<const char*>> kAnswers = {
{"rusty metal", "ship's anchor"},
{"dark forest"},
{"an hour"},
{"sand"},
{"castle"},
{"limpet shells"},
{"Murray Gell-Mann"},
{"George Zweig"},
{"1964"}};
std::string LoadPromptFile(const std::string& filename) {
// If the filename is empty, return an empty string.
@ -108,12 +115,22 @@ class GemmaTest : public ::testing::Test {
void TestExpectations(const std::string& response) {
fprintf(stderr, "Response: '%s'\n", response.c_str());
size_t pos = 0;
for (const char* answer : kAnswers) {
auto found = response.find(answer, pos);
EXPECT_NE(found, std::string::npos)
<< "Response does not contain " << answer;
if (found != std::string::npos) {
pos = found + strlen(answer);
for (const auto& answer_group : kAnswers) {
size_t earliest_pos = std::string::npos;
const char* matched_answer = nullptr;
for (const char* answer : answer_group) {
auto found = response.find(answer, pos);
if (found != std::string::npos &&
(earliest_pos == std::string::npos || found < earliest_pos)) {
earliest_pos = found;
matched_answer = answer;
}
}
EXPECT_NE(earliest_pos, std::string::npos)
<< "Response does not contain acceptable answers, e.g., "
<< answer_group[0];
if (earliest_pos != std::string::npos) {
pos = earliest_pos + strlen(matched_answer);
}
}
s_env->PrintProfileResults();