Update gemma_test to also pass for the v1.1 models.

Make it an error if the model cannot be loaded.

PiperOrigin-RevId: 650232602
This commit is contained in:
Daniel Keysers 2024-07-08 06:44:56 -07:00 committed by Copybara-Service
parent 6a3f7cf3ea
commit cf76f0a401
2 changed files with 66 additions and 47 deletions

View File

@ -213,19 +213,20 @@ cc_library(
cc_test( cc_test(
name = "gemma_test", name = "gemma_test",
srcs = ["evals/gemma_test.cc"], srcs = ["evals/gemma_test.cc"],
# Requires model files
tags = [
"local",
"manual",
"no_tap",
],
deps = [ deps = [
":app",
":args",
":benchmark_helper", ":benchmark_helper",
":common", ":common",
":cross_entropy",
":gemma_lib", ":gemma_lib",
":ops", ":tokenizer",
"@googletest//:gtest_main", "@googletest//:gtest_main",
"//compression:io",
"@hwy//:hwy", "@hwy//:hwy",
"@hwy//:hwy_test_util", "@hwy//:hwy_test_util",
"@hwy//:thread_pool",
], ],
) )

View File

@ -23,13 +23,15 @@
#include "evals/benchmark_helper.h" #include "evals/benchmark_helper.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/tokenizer.h"
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/hwy_gtest.h"
// This test can be run manually with the downloaded gemma weights. // This test can be run manually with the downloaded gemma weights.
// To run the test, pass the following flags: // To run the test, pass the following flags:
// --model <model> --tokenizer <tokenizer_path> --weights <weights_path> // --model <model> --tokenizer <tokenizer_path> --weights <weights_path>
// It should pass for the following models: 2b-it, 7b-it, 9b-it, 27b-it // It should pass for the following models:
// 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), 9b-it, 27b-it
namespace gcpp { namespace gcpp {
namespace { namespace {
@ -45,7 +47,15 @@ class GemmaTest : public ::testing::Test {
s_env->SetMaxGeneratedTokens(2048); s_env->SetMaxGeneratedTokens(2048);
s_env->MutableConfig().temperature = 0.0f; // deterministic s_env->MutableConfig().temperature = 0.0f; // deterministic
s_env->MutableConfig().verbosity = 0; s_env->MutableConfig().verbosity = 0;
// Using the turn structure worsens results. // Using the turn structure worsens results sometimes.
// However, gemma-2 27B seems to need the turn structure to work.
// It would be good to make these tests more consistent.
if (s_env->GetModel()->Info().model == Model::GEMMA_27B) {
std::string mutable_prompt = prompt;
auto [response, n] = s_env->QueryModel(mutable_prompt); // Uses turns.
return response;
}
// Otherwise, don't use turn structure.
const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt); const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
auto [response, n] = s_env->QueryModel(tokens); auto [response, n] = s_env->QueryModel(tokens);
return response; return response;
@ -56,30 +66,38 @@ class GemmaTest : public ::testing::Test {
s_env->SetMaxGeneratedTokens(64); s_env->SetMaxGeneratedTokens(64);
s_env->MutableConfig().temperature = 0.0f; // deterministic s_env->MutableConfig().temperature = 0.0f; // deterministic
s_env->MutableConfig().verbosity = 0; s_env->MutableConfig().verbosity = 0;
// Using the turn structure worsens results.
std::vector<std::unique_ptr<std::vector<int>>> prompts;
prompts.reserve(inputs.size());
for (auto input_string : inputs) {
std::string mutable_input_string = input_string;
prompts.push_back(std::make_unique<std::vector<int>>(
s_env->TokenizeAndPrependBOS(input_string)));
}
std::vector<hwy::Span<int>> prompt_vector;
for (auto& prompt : prompts) {
prompt_vector.push_back(hwy::Span<int>(prompt->data(), prompt->size()));
}
hwy::Span<const hwy::Span<int>> prompt_span =
hwy::Span<const hwy::Span<int>>(prompt_vector.data(),
prompt_vector.size());
std::vector<std::string> replies; std::vector<std::string> replies;
for (auto [response, n] : s_env->BatchQueryModel2(prompt_span)) { // Using the turn structure worsens results sometimes.
replies.push_back(response); // However, gemma-2 27B seems to need the turn structure to work.
// It would be good to make these tests more consistent.
if (s_env->GetModel()->Info().model == Model::GEMMA_27B) {
for (auto [response, n] : s_env->BatchQueryModel(inputs)) {
replies.push_back(response);
}
} else { // Not Gemma-2 27B. Do not use turn structure.
std::vector<std::unique_ptr<std::vector<int>>> prompts;
prompts.reserve(inputs.size());
for (auto input_string : inputs) {
std::string mutable_input_string = input_string;
prompts.push_back(std::make_unique<std::vector<int>>(
s_env->TokenizeAndPrependBOS(input_string)));
}
std::vector<hwy::Span<int>> prompt_vector;
for (auto& prompt : prompts) {
prompt_vector.push_back(hwy::Span<int>(prompt->data(), prompt->size()));
}
hwy::Span<const hwy::Span<int>> prompt_span =
hwy::Span<const hwy::Span<int>>(prompt_vector.data(),
prompt_vector.size());
for (auto [response, n] : s_env->BatchQueryModel2(prompt_span)) {
replies.push_back(response);
}
} }
return replies; return replies;
} }
void TestQuestions(const char* kQA[][2], size_t num_questions, bool batch) { void TestQuestions(const char* kQA[][2], size_t num_questions, bool batch) {
if (!s_env->GetModel()) return; ASSERT_NE(s_env->GetModel(), nullptr);
if (batch) { if (batch) {
std::vector<std::string> inputs; std::vector<std::string> inputs;
for (size_t i = 0; i < num_questions; ++i) { for (size_t i = 0; i < num_questions; ++i) {
@ -171,80 +189,80 @@ static const char kGettysburg[] = {
"people, for the people, shall not perish from the earth.\n"}; "people, for the people, shall not perish from the earth.\n"};
TEST_F(GemmaTest, CrossEntropySmall) { TEST_F(GemmaTest, CrossEntropySmall) {
if (!s_env->GetModel()) return; ASSERT_NE(s_env->GetModel(), nullptr);
static const char kSmall[] = static const char kSmall[] =
"The capital of Hungary is Budapest which is located in Europe."; "The capital of Hungary is Budapest which is located in Europe.";
float entropy = s_env->CrossEntropy(kSmall); float entropy = s_env->CrossEntropy(kSmall);
fprintf(stderr, "per-token entropy: %f\n", entropy); fprintf(stderr, "per-token entropy: %f\n", entropy);
float expected_entropy;
switch (s_env->GetModel()->Info().model) { switch (s_env->GetModel()->Info().model) {
case gcpp::Model::GEMMA_2B: case gcpp::Model::GEMMA_2B:
expected_entropy = 2.56f; // 2B v.1 and v.1.1 produce slightly different results.
EXPECT_NEAR(entropy, 2.6f, 0.2f);
break; break;
case gcpp::Model::GEMMA_7B: case gcpp::Model::GEMMA_7B:
expected_entropy = 2.91f; // 7B v.1 and v.1.1 produce slightly different results.
EXPECT_NEAR(entropy, 2.8f, 0.2f);
break; break;
case gcpp::Model::GEMMA_9B: case gcpp::Model::GEMMA_9B:
expected_entropy = 1.28f; EXPECT_NEAR(entropy, 1.28f, 0.02f);
break; break;
case gcpp::Model::GEMMA_27B: case gcpp::Model::GEMMA_27B:
expected_entropy = 1.30f; EXPECT_NEAR(entropy, 1.30f, 0.02f);
break; break;
default: default:
FAIL() << "no entropy expectation for this model"; FAIL() << "no entropy expectation for this model";
break; break;
} }
EXPECT_NEAR(entropy, expected_entropy, 0.02f);
} }
TEST_F(GemmaTest, CrossEntropyJingleBells) { TEST_F(GemmaTest, CrossEntropyJingleBells) {
if (!s_env->GetModel()) return; ASSERT_NE(s_env->GetModel(), nullptr);
float entropy = s_env->CrossEntropy(kJingleBells); float entropy = s_env->CrossEntropy(kJingleBells);
fprintf(stderr, "per-token entropy: %f\n", entropy); fprintf(stderr, "per-token entropy: %f\n", entropy);
float expected_entropy;
switch (s_env->GetModel()->Info().model) { switch (s_env->GetModel()->Info().model) {
case gcpp::Model::GEMMA_2B: case gcpp::Model::GEMMA_2B:
expected_entropy = 1.85f; // 2B v.1 and v.1.1 produce slightly different results.
EXPECT_NEAR(entropy, 1.9f, 0.2f);
break; break;
case gcpp::Model::GEMMA_7B: case gcpp::Model::GEMMA_7B:
expected_entropy = 1.06f; // 7B v.1 and v.1.1 produce slightly different results.
EXPECT_NEAR(entropy, 1.07f, 0.05f);
break; break;
case gcpp::Model::GEMMA_9B: case gcpp::Model::GEMMA_9B:
expected_entropy = 0.37f; EXPECT_NEAR(entropy, 0.37f, 0.02f);
break; break;
case gcpp::Model::GEMMA_27B: case gcpp::Model::GEMMA_27B:
expected_entropy = 0.33f; EXPECT_NEAR(entropy, 0.33f, 0.02f);
break; break;
default: default:
FAIL() << "no entropy expectation for this model"; FAIL() << "no entropy expectation for this model";
break; break;
} }
EXPECT_NEAR(entropy, expected_entropy, 0.02f);
} }
TEST_F(GemmaTest, CrossEntropyGettysburg) { TEST_F(GemmaTest, CrossEntropyGettysburg) {
if (!s_env->GetModel()) return; ASSERT_NE(s_env->GetModel(), nullptr);
float entropy = s_env->CrossEntropy(kGettysburg); float entropy = s_env->CrossEntropy(kGettysburg);
fprintf(stderr, "per-token entropy: %f\n", entropy); fprintf(stderr, "per-token entropy: %f\n", entropy);
float expected_entropy;
switch (s_env->GetModel()->Info().model) { switch (s_env->GetModel()->Info().model) {
case gcpp::Model::GEMMA_2B: case gcpp::Model::GEMMA_2B:
expected_entropy = 1.05f; // 2B v.1 and v.1.1 produce slightly different results.
EXPECT_NEAR(entropy, 1.1f, 0.1f);
break; break;
case gcpp::Model::GEMMA_7B: case gcpp::Model::GEMMA_7B:
expected_entropy = 0.83f; // 7B v.1 and v.1.1 produce slightly different results.
EXPECT_NEAR(entropy, 0.75f, 0.1f);
break; break;
case gcpp::Model::GEMMA_9B: case gcpp::Model::GEMMA_9B:
expected_entropy = 0.15f; EXPECT_NEAR(entropy, 0.15f, 0.02f);
break; break;
case gcpp::Model::GEMMA_27B: case gcpp::Model::GEMMA_27B:
expected_entropy = 0.14f; EXPECT_NEAR(entropy, 0.14f, 0.02f);
break; break;
default: default:
FAIL() << "no entropy expectation for this model"; FAIL() << "no entropy expectation for this model";
break; break;
} }
EXPECT_NEAR(entropy, expected_entropy, 0.02f);
} }
} // namespace } // namespace