gemma.cpp/evals/gemma_test.cc

187 lines
6.6 KiB
C++

// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/gemma.h"

#include <stdio.h>

#include <string>
#include <utility>
#include <vector>

#include "evals/benchmark_helper.h"
#include "gemma/configs.h"
#include "io/io.h"
#include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h"
// This test can be run manually with the downloaded gemma weights.
// To run the test, pass the following flags:
// --model <model> --tokenizer <tokenizer_path> --weights <weights_path>
// or, for a single-file weights file, just --weights <weights_path>.
// It should pass for the following models:
// Gemma1: 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), gr2b-it;
// Gemma2: gemma2-2b-it, 9b-it, 27b-it.
// Exception: CrossEntropySmall only has reference values for the Gemma2
// 2B/9B/27B models and fails for any other model.
namespace gcpp {
namespace {
// Test fixture whose tests share a single `GemmaEnv` (model, tokenizer, KV
// cache), which is created once from main() via InitEnv and destroyed via
// DeleteEnv.
class GemmaTest : public ::testing::Test {
 public:
  // Creates the shared environment from the command-line flags and prints the
  // model specifier to stderr.
  // Requires argc/argv, hence do not use `SetUpTestSuite`.
  static void InitEnv(int argc, char** argv) {
    HWY_ASSERT(s_env == nullptr);  // Should only be called once.
    s_env = new GemmaEnv(argc, argv);
    const gcpp::ModelConfig& config = s_env->GetGemma()->Config();
    fprintf(stderr, "Using %s\n", config.Specifier().c_str());
  }

  // Destroys the shared environment. The pointer is reset so that it does not
  // dangle: later use then trips the `HWY_ASSERT(s_env)` checks instead of
  // touching freed memory, and InitEnv's "called once" check holds again.
  static void DeleteEnv() {
    delete s_env;
    s_env = nullptr;
  }

 protected:
  // Configures the shared environment for deterministic output (temperature
  // 0, at most 64 generated tokens, verbosity 0), runs `inputs` through
  // `BatchQueryModel` and returns the `response` of each result.
  // NOTE(review): callers index the result by input position, i.e. assume
  // BatchQueryModel preserves input order — confirm.
  std::vector<std::string> BatchGemmaReply(
      const std::vector<std::string>& inputs) {
    HWY_ASSERT(s_env);  // must have called InitEnv()
    s_env->SetMaxGeneratedTokens(64);
    s_env->MutableConfig().temperature = 0.0f;  // deterministic
    s_env->MutableConfig().verbosity = 0;
    // Always use turn structure (WrapAndTokenize).
    auto results = s_env->BatchQueryModel(inputs);
    std::vector<std::string> replies;
    replies.reserve(results.size());
    // Bind by reference and move the strings out: `results` is a local that is
    // discarded afterwards, so copying each result (and its string) is waste.
    for (QueryResult& result : results) {
      replies.push_back(std::move(result.response));
    }
    return replies;
  }

  // Shared state. Requires argc/argv, so construct in main via InitEnv.
  // Note that the style guide forbids non-local static variables with dtors,
  // hence a raw pointer with explicit InitEnv/DeleteEnv.
  static GemmaEnv* s_env;
};

GemmaEnv* GemmaTest::s_env = nullptr;
// Asks several short factual questions in one batch and checks that each
// reply contains the expected answer as a substring.
TEST_F(GemmaTest, Batched) {
  // Test remainder handling in MatMul (four rows per tile), but avoid a
  // second batch in debug builds to speed up the test.
  s_env->MutableConfig().decode_qbatch_size = HWY_IS_DEBUG_BUILD ? 6 : 3;
  // Each row: {question, substring expected somewhere in the reply}.
  static constexpr const char* kQA[][2] = {
      {"What is the capital of Australia?", "Canberra"},
      {"How many states does the US have?", "50"},
      {"What is the Pacific?", "ocean"},
      {"When was the battle of Hastings?", "1066"},
      {"what is 13 + 14?", "27"},
      {"what is 7 * 8?", "56"},
  };
  constexpr size_t kNum = sizeof(kQA) / sizeof(kQA[0]);

  std::vector<std::string> inputs;
  inputs.reserve(kNum);
  for (const auto& qa : kQA) {
    inputs.push_back(qa[0]);
  }

  const std::vector<std::string> responses = BatchGemmaReply(inputs);
  HWY_ASSERT(responses.size() == kNum);
  for (size_t i = 0; i < kNum; ++i) {
    fprintf(stderr, "#%zu: '%s'\n\n", i, responses[i].c_str());
    // EXPECT_NE plus a streamed message (instead of EXPECT_TRUE on a bool)
    // reports which question failed and which answer was expected.
    EXPECT_NE(responses[i].find(kQA[i][1]), std::string::npos)
        << "question: '" << kQA[i][0] << "', expected substring: '"
        << kQA[i][1] << "'";
  }
}
// Two-turn conversation on the shared KV cache: the first turn states an
// unusual fact (a turquoise car), the second asks the model to repeat all
// prior statements. Passes if the text decoded during the second Generate
// call mentions "turquoise" or "car".
TEST_F(GemmaTest, Multiturn) {
const Gemma* model = s_env->GetGemma();
const ModelConfig& config = model->Config();
// Next position to be streamed; advanced by the callback below, so after each
// Generate call it points just past the last streamed token.
size_t abs_pos = 0;
// Decoded text of all non-EOS tokens streamed since the last clear().
std::string response;
// Called per streamed token. Asserts a single query and contiguous positions,
// then advances abs_pos. EOS tokens advance abs_pos but are not decoded.
// NOTE(review): whether prompt tokens are also streamed here (and thus end up
// in `response`) depends on Generate's implementation — confirm. Returning
// true presumably means "continue generating" — confirm.
auto stream_token = [&](size_t query_idx, size_t pos, int token, float) {
HWY_ASSERT(query_idx == 0);
HWY_ASSERT(pos == abs_pos);
++abs_pos;
if (config.IsEOS(token)) return true;
std::string token_text;
EXPECT_TRUE(
model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
response += token_text;
return true;
};
// Greedy decoding (temperature 0) of at most 64 tokens per turn.
RuntimeConfig runtime_config{
.max_generated_tokens = 64,
.temperature = 0.0f,
.verbosity = 2,
.batch_stream_token = stream_token,
};
TimingInfo timing_info{.verbosity = 0};
// First "say" something slightly unusual.
std::string mutable_prompt = "I have a car and its color is turquoise.";
std::vector<int> tokens =
WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(),
config.wrapping, abs_pos, mutable_prompt);
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
s_env->MutableEnv(), timing_info);
// Note: we do not rewind any <end_of_turn> tokens here. If the model
// produced one and WrapAndTokenize() inserts another one, it will just be
// duplicated.
mutable_prompt = "Please repeat all prior statements.";
tokens = WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(),
config.wrapping, abs_pos, mutable_prompt);
// Reset the `response` string here, then check that the model actually has
// access to the previous turn by asking to reproduce.
response.clear();
// -1 because our prefill does not generate KVs for the last token. Do not
// just pass abs_pos - 1 because our callback checks pos == abs_pos.
HWY_ASSERT(abs_pos > 0);
--abs_pos;
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
s_env->MutableEnv(), timing_info);
fprintf(stderr, "decoded: '%s'\n", response.c_str());
// Either keyword counts as remembering the first turn.
// NOTE(review): "car" is a plain substring match and also matches words such
// as "care" or "scar", so this check is lenient.
bool remembered_turquoise =
response.find("turquoise") != std::string::npos; // NOLINT
bool remembered_car = response.find("car") != std::string::npos; // NOLINT
EXPECT_TRUE(remembered_turquoise || remembered_car);
}
// Computes the per-token cross-entropy of a short fixed sentence and compares
// it against a per-model reference value. Only the Gemma2 2B/9B/27B models
// have reference values; any other model fails the test.
TEST_F(GemmaTest, CrossEntropySmall) {
  HWY_ASSERT(s_env->GetGemma() != nullptr);
  const ModelConfig& config = s_env->GetGemma()->Config();
  static const char kSmall[] =
      "The capital of Hungary is Budapest which is located in Europe.";
  // Allowed absolute deviation from each reference entropy.
  constexpr float kTolerance = 0.02f;
  const float entropy = s_env->CrossEntropy(kSmall);
  fprintf(stderr, "per-token entropy: %f\n", entropy);
  switch (config.model) {
    case gcpp::Model::GEMMA2_2B:
      EXPECT_NEAR(entropy, 1.14f, kTolerance);
      break;
    case gcpp::Model::GEMMA2_9B:
      EXPECT_NEAR(entropy, 1.28f, kTolerance);
      break;
    case gcpp::Model::GEMMA2_27B:
      EXPECT_NEAR(entropy, 1.30f, kTolerance);
      break;
    default:
      // Name the model so the missing expectation is easy to add. FAIL()
      // returns from the test body, so no `break` is needed.
      FAIL() << "no entropy expectation for this model ("
             << config.Specifier() << ")";
  }
}
} // namespace
} // namespace gcpp
// Test entry point. googletest initializes first so that its own flags are
// removed from argv before the remaining flags are handed to GemmaEnv. The
// shared environment lives exactly as long as the test run.
int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  gcpp::InternalInit();

  gcpp::GemmaTest::InitEnv(argc, argv);
  const int status = RUN_ALL_TESTS();
  gcpp::GemmaTest::DeleteEnv();

  return status;
}