gemma.cpp/evals/wheat_from_chaff_test.cc

179 lines
6.4 KiB
C++

// Copyright 2026 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <cstring>
#include <string>
#include <vector>
#include "evals/benchmark_helper.h"
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "io/io.h"
#include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h"
// This test can be run manually with the downloaded gemma weights.
// To run the test, pass the following flags:
// --tokenizer <tokenizer_path> --weights <weights_path>
// or just use the single-file weights file with --weights <weights_path>.
// It should pass for the following models:
// Gemma2: gemma2-2b-it
namespace gcpp {
namespace {
// Question block appended after the background text of every prompt. The
// adjacent literals are concatenated into a single string; each question ends
// with a trailing space so they remain separated.
// `constexpr` makes the pointer itself immutable (the previous declaration
// allowed it to be reseated), and internal linkage is already provided by the
// enclosing anonymous namespace.
constexpr const char* kQuestions =
    "From the above information, please answer the following questions: "
    "What did Marcia find in the sand? "
    "What is Albert's preferred holiday activity? "
    "How long did it take to dig out the object from the sand? "
    "What is Marcia's preferred holiday activity? "
    "What made the castle turrets look like daleks? "
    "Which people first proposed the quark model of hadrons, and when?";
// All phrases in kAnswers must appear in the response in the order given for
// the test to pass. Each phrase is searched for after the end of the previous
// match, so the order of the entries matters.
// NOTE(review): some answers are presumably split into several short phrases
// (e.g. "enormous sand" / "castles") to tolerate wording in between - confirm
// against the test data.
// `constexpr` keeps the table entries from being reseated at runtime.
constexpr const char* kAnswers[] = {
    "a ship's anchor",   //
    "a dark forest",     //
    "an hour",           //
    "enormous sand",     //
    "castles",           //
    "limpet shells",     //
    "Murray Gell-Mann",  //
    "George Zweig",      //
    "1964"};
std::string LoadPromptFile(const std::string& filename) {
// If the filename is empty, return an empty string.
if (filename.empty()) {
return "";
}
std::string path = testing::SrcDir() +
"evals/testdata/"
+ filename;
return ReadFileToString(Path(path));
}
// Concatenates the contents of every file in `files` (in the order given) and
// then appends `suffix`. Nothing is inserted between the pieces.
std::string BuildPrompt(const std::vector<std::string>& files,
                        const std::string& suffix) {
  std::string combined;
  for (size_t i = 0; i < files.size(); ++i) {
    combined.append(LoadPromptFile(files[i]));
  }
  combined.append(suffix);
  return combined;
}
class GemmaTest : public ::testing::Test {
public:
// Requires argc/argv, hence do not use `SetUpTestSuite`.
static void InitEnv(int argc, char** argv) {
HWY_ASSERT(s_env == nullptr); // Should only be called once.
ConsumedArgs consumed(argc, argv);
GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();
s_env = new GemmaEnv(args);
const gcpp::ModelConfig& config = s_env->GetGemma()->Config();
fprintf(stderr, "Using %s\n", config.Specifier().c_str());
}
static void DeleteEnv() { delete s_env; }
protected:
std::string GemmaReply(const std::string& input,
AttentionImpl attention_mode) {
HWY_ASSERT(s_env); // must have called InitEnv()
s_env->SetMaxGeneratedTokens(256);
s_env->MutableConfig().attention_impl = attention_mode;
s_env->MutableConfig().temperature = 0.0f; // deterministic
s_env->MutableConfig().verbosity = 1;
// Always use turn structure (WrapAndTokenize).
auto response = s_env->QueryModel(input);
return response.response.substr(response.response_start_pos);
}
// Checks that the response contains the expected answer substrings int the
// expected order. Testing against a few keywords is more robust than checking
// the whole string.
void TestExpectations(const std::string& response) {
fprintf(stderr, "Response: '%s'\n", response.c_str());
size_t pos = 0;
for (const char* answer : kAnswers) {
auto found = response.find(answer, pos);
EXPECT_NE(found, std::string::npos)
<< "Response does not contain " << answer;
if (found != std::string::npos) {
pos = found + strlen(answer);
}
}
s_env->PrintProfileResults();
}
// Shared state. Requires argc/argv, so construct in main via InitEnv.
// Note that the style guide forbids non-local static variables with dtors.
static GemmaEnv* s_env;
};
GemmaEnv* GemmaTest::s_env = nullptr;
// Verifies that Gemma still extracts the right answers as the amount of
// surrounding background text grows, from the bare facts up to outright
// distraction. Every prompt is run once per attention implementation.
TEST_F(GemmaTest, WheatFromChaff) {
  const AttentionImpl modes[] = {AttentionImpl::kOld, AttentionImpl::kFlash};

  // Builds a prompt from `files` followed by the questions, queries the model
  // with `mode` and checks the reply against the expected answers.
  const auto query_and_check = [this](const std::vector<std::string>& files,
                                      AttentionImpl mode) {
    const std::string prompt = BuildPrompt(files, kQuestions);
    const std::string reply = GemmaReply(prompt, mode);
    TestExpectations(reply);
  };

  // Warmup run with the first attention mode.
  fprintf(stderr, "Warmup, mode %s\n", GetAttentionImplName(modes[0]).c_str());
  query_and_check({"quark_1.txt", "holiday_story.txt"}, modes[0]);

  for (const AttentionImpl mode : modes) {
    const std::string mode_name = GetAttentionImplName(mode);

    fprintf(stderr, "\nTesting quark_1 prompt, mode %s\n", mode_name.c_str());
    query_and_check({"holiday_story.txt", "quark_1.txt"}, mode);

    fprintf(stderr, "\nTesting quark_2 prompt, mode %s\n", mode_name.c_str());
    query_and_check({"holiday_story.txt", "quark_2.txt"}, mode);

    fprintf(stderr, "\nTesting standard_model prompt, mode %s\n",
            mode_name.c_str());
    query_and_check({"holiday_story.txt", "quark_2.txt", "standard_model.txt"},
                    mode);

    // NOTE(review): the threshold presumably reflects the context length the
    // special_relativity prompt needs - confirm. Otherwise the warmup prompt
    // is run again in its place.
    const bool long_context = s_env->MutableKVCache().SeqLen() > 38000;
    if (long_context) {
      fprintf(stderr, "\nTesting special_relativity, mode %s\n",
              mode_name.c_str());
      query_and_check(
          {"holiday_story.txt", "quark_2.txt", "special_relativity.txt"}, mode);
    } else {
      fprintf(stderr, "\nSkipping special_relativity, mode %s\n",
              mode_name.c_str());
      query_and_check({"quark_1.txt", "holiday_story.txt"}, mode);
    }
  }
}
} // namespace
} // namespace gcpp
// Entry point. GoogleTest sees the command line first, then the same
// argc/argv are used to build the shared Gemma environment. The environment is
// torn down after all tests ran, and the test result becomes the exit code.
int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  gcpp::GemmaTest::InitEnv(argc, argv);

  const int exit_code = RUN_ALL_TESTS();

  gcpp::GemmaTest::DeleteEnv();
  return exit_code;
}