mirror of https://github.com/google/gemma.cpp.git
Shorten gemma_test so we can run it for more models.
PiperOrigin-RevId: 759685282
This commit is contained in:
parent
e890d46f30
commit
d6cfabc2c1
|
|
@ -43,92 +43,26 @@ class GemmaTest : public ::testing::Test {
|
|||
HWY_ASSERT(s_env == nullptr); // Should only be called once.
|
||||
s_env = new GemmaEnv(argc, argv);
|
||||
const gcpp::ModelConfig& config = s_env->GetGemma()->GetModelConfig();
|
||||
fprintf(stderr, "Using %s)\n", config.Specifier().c_str());
|
||||
fprintf(stderr, "Using %s\n", config.Specifier().c_str());
|
||||
}
|
||||
|
||||
static void DeleteEnv() { delete s_env; }
|
||||
|
||||
protected:
|
||||
std::string GemmaReply(const std::string& prompt) {
|
||||
HWY_ASSERT(s_env); // must have called InitEnv()
|
||||
s_env->SetMaxGeneratedTokens(2048);
|
||||
s_env->MutableConfig().temperature = 0.0f; // deterministic
|
||||
s_env->MutableConfig().verbosity = 0;
|
||||
const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
|
||||
// Using the turn structure worsens results sometimes.
|
||||
// However, some models need the turn structure to work.
|
||||
// It would be good to make these tests more consistent.
|
||||
if (config.model == Model::GEMMA2_27B ||
|
||||
config.model == Model::GRIFFIN_2B) {
|
||||
std::string mutable_prompt = prompt;
|
||||
QueryResult result = s_env->QueryModel(mutable_prompt); // Uses turns.
|
||||
return result.response;
|
||||
}
|
||||
// Otherwise, do not use turn structure.
|
||||
const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
|
||||
QueryResult result = s_env->QueryModel(tokens);
|
||||
return result.response;
|
||||
}
|
||||
|
||||
std::vector<std::string> BatchGemmaReply(
|
||||
const std::vector<std::string>& inputs) {
|
||||
HWY_ASSERT(s_env); // must have called InitEnv()
|
||||
s_env->SetMaxGeneratedTokens(64);
|
||||
s_env->MutableConfig().temperature = 0.0f; // deterministic
|
||||
s_env->MutableConfig().verbosity = 0;
|
||||
const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
|
||||
// Always use turn structure (WrapAndTokenize).
|
||||
std::vector<std::string> replies;
|
||||
// Using the turn structure worsens results sometimes.
|
||||
// However, some models need the turn structure to work.
|
||||
// It would be good to make these tests more consistent.
|
||||
if (config.model == Model::GEMMA2_27B ||
|
||||
config.model == Model::GRIFFIN_2B) {
|
||||
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
|
||||
replies.push_back(result.response);
|
||||
}
|
||||
return replies;
|
||||
}
|
||||
// Otherwise, do not use turn structure.
|
||||
std::vector<std::vector<int>> prompts_vector;
|
||||
prompts_vector.reserve(inputs.size());
|
||||
for (const auto& input_string : inputs) {
|
||||
prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
|
||||
}
|
||||
std::vector<PromptTokens> prompt_spans;
|
||||
for (const auto& prompt : prompts_vector) {
|
||||
prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
|
||||
}
|
||||
QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
|
||||
for (const QueryResult& result : s_env->BatchQueryModel(prompts)) {
|
||||
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
|
||||
replies.push_back(result.response);
|
||||
}
|
||||
return replies;
|
||||
}
|
||||
|
||||
void TestQuestions(const char* kQA[][2], size_t num_questions, bool batch) {
|
||||
HWY_ASSERT(s_env->GetGemma() != nullptr);
|
||||
if (batch) {
|
||||
std::vector<std::string> inputs;
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
fprintf(stderr, "Batch Question %zu\n\n", i + 1);
|
||||
inputs.push_back(kQA[i][0]);
|
||||
}
|
||||
std::vector<std::string> responses = BatchGemmaReply(inputs);
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
std::string response = responses.at(i);
|
||||
fprintf(stderr, "Batch answer %zu '%s'\n\n", i + 1, response.c_str());
|
||||
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
fprintf(stderr, "Question %zu\n\n", i + 1);
|
||||
std::string response = GemmaReply(kQA[i][0]);
|
||||
fprintf(stderr, "'%s'\n\n", response.c_str());
|
||||
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shared state. Requires argc/argv, so construct in main via InitEnv.
|
||||
// Note that the style guide forbids non-local static variables with dtors.
|
||||
static GemmaEnv* s_env;
|
||||
|
|
@ -136,44 +70,34 @@ class GemmaTest : public ::testing::Test {
|
|||
|
||||
GemmaEnv* GemmaTest::s_env = nullptr;
|
||||
|
||||
TEST_F(GemmaTest, GeographyBatched) {
|
||||
s_env->MutableConfig().decode_qbatch_size = 3;
|
||||
// 6 are enough to test batching and the loop.
|
||||
TEST_F(GemmaTest, Batched) {
|
||||
// Test remainder handling in MatMul (four rows per tile), but avoid a
|
||||
// second batch in debug builds to speed up the test.
|
||||
s_env->MutableConfig().decode_qbatch_size = HWY_IS_DEBUG_BUILD ? 6 : 3;
|
||||
static const char* kQA[][2] = {
|
||||
{"What is the capital of Australia?", "Canberra"},
|
||||
{"What is the capital of Denmark?", "Copenhagen"},
|
||||
{"Ljubljana is the capital of which country?", "Slovenia"},
|
||||
{"Is Chicago a country?", "city"},
|
||||
{"How many states does the US have?", "50"},
|
||||
{"What is the Pacific?", "ocean"},
|
||||
};
|
||||
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
|
||||
TestQuestions(kQA, HWY_MIN(kNum, 3), /*batch=*/false);
|
||||
TestQuestions(kQA, 1, /*batch=*/true);
|
||||
TestQuestions(kQA, kNum, /*batch=*/true);
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, History) {
|
||||
static const char* kQA[][2] = {
|
||||
{"When was the battle of Hastings?", "1066"},
|
||||
};
|
||||
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
|
||||
TestQuestions(kQA, kNum, /*batch=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, Arithmetic) {
|
||||
static const char* kQA[][2] = {
|
||||
{"what is 13 + 14?", "27"},
|
||||
{"what is 7 * 8?", "56"},
|
||||
};
|
||||
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
|
||||
TestQuestions(kQA, kNum, /*batch=*/false);
|
||||
const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
|
||||
std::vector<std::string> inputs;
|
||||
for (size_t i = 0; i < kNum; ++i) {
|
||||
inputs.push_back(kQA[i][0]);
|
||||
}
|
||||
std::vector<std::string> responses = BatchGemmaReply(inputs);
|
||||
HWY_ASSERT(responses.size() == kNum);
|
||||
for (size_t i = 0; i < kNum; ++i) {
|
||||
fprintf(stderr, "#%zu: '%s'\n\n", i, responses[i].c_str());
|
||||
EXPECT_TRUE(responses[i].find(kQA[i][1]) != std::string::npos); // NOLINT
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, Multiturn) {
|
||||
const Gemma* model = s_env->GetGemma();
|
||||
const ModelConfig& config = model->GetModelConfig();
|
||||
HWY_ASSERT(model != nullptr);
|
||||
size_t abs_pos = 0;
|
||||
std::string response;
|
||||
auto stream_token = [&](int token, float) {
|
||||
|
|
@ -220,41 +144,6 @@ TEST_F(GemmaTest, Multiturn) {
|
|||
EXPECT_TRUE(remembered_turquoise || remembered_car);
|
||||
}
|
||||
|
||||
static const char kJingleBells[] = R"(
|
||||
Dashing through the snow
|
||||
In a one-horse open sleigh
|
||||
O'er the fields we go
|
||||
Laughing all the way
|
||||
Bells on bobtails ring
|
||||
Making spirits bright
|
||||
What fun it is to ride and sing
|
||||
A sleighing song tonight
|
||||
)";
|
||||
|
||||
// The "Hay Draft" of the Gettysburg Address.
|
||||
static const char kGettysburg[] = {
|
||||
"Four score and seven years ago our fathers brought forth, upon this "
|
||||
"continent, a new nation, conceived in Liberty, and dedicated to the "
|
||||
"proposition that all men are created equal.\n\nNow we are engaged in a "
|
||||
"great civil war, testing whether that nation, or any nation, so "
|
||||
"conceived, and so dedicated, can long endure. We are met here on a great "
|
||||
"battlefield of that war. We have come to dedicate a portion of it as a "
|
||||
"final resting place for those who here gave their lives that that nation "
|
||||
"might live. It is altogether fitting and proper that we should do "
|
||||
"this.\n\nBut in a larger sense we can not dedicate -- we can not "
|
||||
"consecrate -- we can not hallow this ground. The brave men, living and "
|
||||
"dead, who struggled, here, have consecrated it far above our poor power "
|
||||
"to add or detract. The world will little note, nor long remember, what we "
|
||||
"say here, but can never forget what they did here. It is for us, the "
|
||||
"living, rather to be dedicated here to the unfinished work which they "
|
||||
"have, thus far, so nobly carried on. It is rather for us to be here "
|
||||
"dedicated to the great task remaining before us -- that from these "
|
||||
"honored dead we take increased devotion to that cause for which they here "
|
||||
"gave the last full measure of devotion -- that we here highly resolve "
|
||||
"that these dead shall not have died in vain; that this nation shall have "
|
||||
"a new birth of freedom; and that this government of the people, by the "
|
||||
"people, for the people, shall not perish from the earth.\n"};
|
||||
|
||||
TEST_F(GemmaTest, CrossEntropySmall) {
|
||||
HWY_ASSERT(s_env->GetGemma() != nullptr);
|
||||
const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
|
||||
|
|
@ -281,54 +170,6 @@ TEST_F(GemmaTest, CrossEntropySmall) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
||||
HWY_ASSERT(s_env->GetGemma() != nullptr);
|
||||
const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
|
||||
float entropy = s_env->CrossEntropy(kJingleBells);
|
||||
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
||||
switch (config.model) {
|
||||
case gcpp::Model::GRIFFIN_2B:
|
||||
EXPECT_NEAR(entropy, 1.62f, 0.02f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA2_2B:
|
||||
EXPECT_NEAR(entropy, 0.49f, 0.02f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA2_9B:
|
||||
EXPECT_NEAR(entropy, 0.37f, 0.02f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA2_27B:
|
||||
EXPECT_NEAR(entropy, 0.33f, 0.02f);
|
||||
break;
|
||||
default:
|
||||
FAIL() << "no entropy expectation for this model";
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GemmaTest, CrossEntropyGettysburg) {
|
||||
HWY_ASSERT(s_env->GetGemma() != nullptr);
|
||||
const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
|
||||
float entropy = s_env->CrossEntropy(kGettysburg);
|
||||
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
||||
switch (config.model) {
|
||||
case gcpp::Model::GRIFFIN_2B:
|
||||
EXPECT_NEAR(entropy, 0.71f, 0.02f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA2_2B:
|
||||
EXPECT_NEAR(entropy, 0.20f, 0.02f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA2_9B:
|
||||
EXPECT_NEAR(entropy, 0.15f, 0.02f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA2_27B:
|
||||
EXPECT_NEAR(entropy, 0.14f, 0.02f);
|
||||
break;
|
||||
default:
|
||||
FAIL() << "no entropy expectation for this model";
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gcpp
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue