mirror of https://github.com/google/gemma.cpp.git
Update gemma_test to also pass for the v1.1 models.
Make it an error if the model cannot be loaded. PiperOrigin-RevId: 650232602
This commit is contained in:
parent
6a3f7cf3ea
commit
cf76f0a401
13
BUILD.bazel
13
BUILD.bazel
|
|
@ -213,19 +213,20 @@ cc_library(
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "gemma_test",
|
name = "gemma_test",
|
||||||
srcs = ["evals/gemma_test.cc"],
|
srcs = ["evals/gemma_test.cc"],
|
||||||
|
# Requires model files
|
||||||
|
tags = [
|
||||||
|
"local",
|
||||||
|
"manual",
|
||||||
|
"no_tap",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":app",
|
|
||||||
":args",
|
|
||||||
":benchmark_helper",
|
":benchmark_helper",
|
||||||
":common",
|
":common",
|
||||||
":cross_entropy",
|
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
":ops",
|
":tokenizer",
|
||||||
"@googletest//:gtest_main",
|
"@googletest//:gtest_main",
|
||||||
"//compression:io",
|
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@hwy//:hwy_test_util",
|
||||||
"@hwy//:thread_pool",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,13 +23,15 @@
|
||||||
|
|
||||||
#include "evals/benchmark_helper.h"
|
#include "evals/benchmark_helper.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
|
#include "gemma/tokenizer.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/tests/hwy_gtest.h"
|
#include "hwy/tests/hwy_gtest.h"
|
||||||
|
|
||||||
// This test can be run manually with the downloaded gemma weights.
|
// This test can be run manually with the downloaded gemma weights.
|
||||||
// To run the test, pass the following flags:
|
// To run the test, pass the following flags:
|
||||||
// --model <model> --tokenizer <tokenizer_path> --weights <weights_path>
|
// --model <model> --tokenizer <tokenizer_path> --weights <weights_path>
|
||||||
// It should pass for the following models: 2b-it, 7b-it, 9b-it, 27b-it
|
// It should pass for the following models:
|
||||||
|
// 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), 9b-it, 27b-it
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
@ -45,7 +47,15 @@ class GemmaTest : public ::testing::Test {
|
||||||
s_env->SetMaxGeneratedTokens(2048);
|
s_env->SetMaxGeneratedTokens(2048);
|
||||||
s_env->MutableConfig().temperature = 0.0f; // deterministic
|
s_env->MutableConfig().temperature = 0.0f; // deterministic
|
||||||
s_env->MutableConfig().verbosity = 0;
|
s_env->MutableConfig().verbosity = 0;
|
||||||
// Using the turn structure worsens results.
|
// Using the turn structure worsens results sometimes.
|
||||||
|
// However, gemma-2 27B seems to need the turn structure to work.
|
||||||
|
// It would be good to make these tests more consistent.
|
||||||
|
if (s_env->GetModel()->Info().model == Model::GEMMA_27B) {
|
||||||
|
std::string mutable_prompt = prompt;
|
||||||
|
auto [response, n] = s_env->QueryModel(mutable_prompt); // Uses turns.
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
// Otherwise, don't use turn structure.
|
||||||
const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
|
const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
|
||||||
auto [response, n] = s_env->QueryModel(tokens);
|
auto [response, n] = s_env->QueryModel(tokens);
|
||||||
return response;
|
return response;
|
||||||
|
|
@ -56,7 +66,15 @@ class GemmaTest : public ::testing::Test {
|
||||||
s_env->SetMaxGeneratedTokens(64);
|
s_env->SetMaxGeneratedTokens(64);
|
||||||
s_env->MutableConfig().temperature = 0.0f; // deterministic
|
s_env->MutableConfig().temperature = 0.0f; // deterministic
|
||||||
s_env->MutableConfig().verbosity = 0;
|
s_env->MutableConfig().verbosity = 0;
|
||||||
// Using the turn structure worsens results.
|
std::vector<std::string> replies;
|
||||||
|
// Using the turn structure worsens results sometimes.
|
||||||
|
// However, gemma-2 27B seems to need the turn structure to work.
|
||||||
|
// It would be good to make these tests more consistent.
|
||||||
|
if (s_env->GetModel()->Info().model == Model::GEMMA_27B) {
|
||||||
|
for (auto [response, n] : s_env->BatchQueryModel(inputs)) {
|
||||||
|
replies.push_back(response);
|
||||||
|
}
|
||||||
|
} else { // Not Gemma-2 27B. Do not use turn structure.
|
||||||
std::vector<std::unique_ptr<std::vector<int>>> prompts;
|
std::vector<std::unique_ptr<std::vector<int>>> prompts;
|
||||||
prompts.reserve(inputs.size());
|
prompts.reserve(inputs.size());
|
||||||
for (auto input_string : inputs) {
|
for (auto input_string : inputs) {
|
||||||
|
|
@ -71,15 +89,15 @@ class GemmaTest : public ::testing::Test {
|
||||||
hwy::Span<const hwy::Span<int>> prompt_span =
|
hwy::Span<const hwy::Span<int>> prompt_span =
|
||||||
hwy::Span<const hwy::Span<int>>(prompt_vector.data(),
|
hwy::Span<const hwy::Span<int>>(prompt_vector.data(),
|
||||||
prompt_vector.size());
|
prompt_vector.size());
|
||||||
std::vector<std::string> replies;
|
|
||||||
for (auto [response, n] : s_env->BatchQueryModel2(prompt_span)) {
|
for (auto [response, n] : s_env->BatchQueryModel2(prompt_span)) {
|
||||||
replies.push_back(response);
|
replies.push_back(response);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return replies;
|
return replies;
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestQuestions(const char* kQA[][2], size_t num_questions, bool batch) {
|
void TestQuestions(const char* kQA[][2], size_t num_questions, bool batch) {
|
||||||
if (!s_env->GetModel()) return;
|
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||||
if (batch) {
|
if (batch) {
|
||||||
std::vector<std::string> inputs;
|
std::vector<std::string> inputs;
|
||||||
for (size_t i = 0; i < num_questions; ++i) {
|
for (size_t i = 0; i < num_questions; ++i) {
|
||||||
|
|
@ -171,80 +189,80 @@ static const char kGettysburg[] = {
|
||||||
"people, for the people, shall not perish from the earth.\n"};
|
"people, for the people, shall not perish from the earth.\n"};
|
||||||
|
|
||||||
TEST_F(GemmaTest, CrossEntropySmall) {
|
TEST_F(GemmaTest, CrossEntropySmall) {
|
||||||
if (!s_env->GetModel()) return;
|
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||||
static const char kSmall[] =
|
static const char kSmall[] =
|
||||||
"The capital of Hungary is Budapest which is located in Europe.";
|
"The capital of Hungary is Budapest which is located in Europe.";
|
||||||
float entropy = s_env->CrossEntropy(kSmall);
|
float entropy = s_env->CrossEntropy(kSmall);
|
||||||
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
||||||
float expected_entropy;
|
|
||||||
switch (s_env->GetModel()->Info().model) {
|
switch (s_env->GetModel()->Info().model) {
|
||||||
case gcpp::Model::GEMMA_2B:
|
case gcpp::Model::GEMMA_2B:
|
||||||
expected_entropy = 2.56f;
|
// 2B v.1 and v.1.1 produce slightly different results.
|
||||||
|
EXPECT_NEAR(entropy, 2.6f, 0.2f);
|
||||||
break;
|
break;
|
||||||
case gcpp::Model::GEMMA_7B:
|
case gcpp::Model::GEMMA_7B:
|
||||||
expected_entropy = 2.91f;
|
// 7B v.1 and v.1.1 produce slightly different results.
|
||||||
|
EXPECT_NEAR(entropy, 2.8f, 0.2f);
|
||||||
break;
|
break;
|
||||||
case gcpp::Model::GEMMA_9B:
|
case gcpp::Model::GEMMA_9B:
|
||||||
expected_entropy = 1.28f;
|
EXPECT_NEAR(entropy, 1.28f, 0.02f);
|
||||||
break;
|
break;
|
||||||
case gcpp::Model::GEMMA_27B:
|
case gcpp::Model::GEMMA_27B:
|
||||||
expected_entropy = 1.30f;
|
EXPECT_NEAR(entropy, 1.30f, 0.02f);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
FAIL() << "no entropy expectation for this model";
|
FAIL() << "no entropy expectation for this model";
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
EXPECT_NEAR(entropy, expected_entropy, 0.02f);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
||||||
if (!s_env->GetModel()) return;
|
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||||
float entropy = s_env->CrossEntropy(kJingleBells);
|
float entropy = s_env->CrossEntropy(kJingleBells);
|
||||||
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
||||||
float expected_entropy;
|
|
||||||
switch (s_env->GetModel()->Info().model) {
|
switch (s_env->GetModel()->Info().model) {
|
||||||
case gcpp::Model::GEMMA_2B:
|
case gcpp::Model::GEMMA_2B:
|
||||||
expected_entropy = 1.85f;
|
// 2B v.1 and v.1.1 produce slightly different results.
|
||||||
|
EXPECT_NEAR(entropy, 1.9f, 0.2f);
|
||||||
break;
|
break;
|
||||||
case gcpp::Model::GEMMA_7B:
|
case gcpp::Model::GEMMA_7B:
|
||||||
expected_entropy = 1.06f;
|
// 7B v.1 and v.1.1 produce slightly different results.
|
||||||
|
EXPECT_NEAR(entropy, 1.07f, 0.05f);
|
||||||
break;
|
break;
|
||||||
case gcpp::Model::GEMMA_9B:
|
case gcpp::Model::GEMMA_9B:
|
||||||
expected_entropy = 0.37f;
|
EXPECT_NEAR(entropy, 0.37f, 0.02f);
|
||||||
break;
|
break;
|
||||||
case gcpp::Model::GEMMA_27B:
|
case gcpp::Model::GEMMA_27B:
|
||||||
expected_entropy = 0.33f;
|
EXPECT_NEAR(entropy, 0.33f, 0.02f);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
FAIL() << "no entropy expectation for this model";
|
FAIL() << "no entropy expectation for this model";
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
EXPECT_NEAR(entropy, expected_entropy, 0.02f);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GemmaTest, CrossEntropyGettysburg) {
|
TEST_F(GemmaTest, CrossEntropyGettysburg) {
|
||||||
if (!s_env->GetModel()) return;
|
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||||
float entropy = s_env->CrossEntropy(kGettysburg);
|
float entropy = s_env->CrossEntropy(kGettysburg);
|
||||||
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
||||||
float expected_entropy;
|
|
||||||
switch (s_env->GetModel()->Info().model) {
|
switch (s_env->GetModel()->Info().model) {
|
||||||
case gcpp::Model::GEMMA_2B:
|
case gcpp::Model::GEMMA_2B:
|
||||||
expected_entropy = 1.05f;
|
// 2B v.1 and v.1.1 produce slightly different results.
|
||||||
|
EXPECT_NEAR(entropy, 1.1f, 0.1f);
|
||||||
break;
|
break;
|
||||||
case gcpp::Model::GEMMA_7B:
|
case gcpp::Model::GEMMA_7B:
|
||||||
expected_entropy = 0.83f;
|
// 7B v.1 and v.1.1 produce slightly different results.
|
||||||
|
EXPECT_NEAR(entropy, 0.75f, 0.1f);
|
||||||
break;
|
break;
|
||||||
case gcpp::Model::GEMMA_9B:
|
case gcpp::Model::GEMMA_9B:
|
||||||
expected_entropy = 0.15f;
|
EXPECT_NEAR(entropy, 0.15f, 0.02f);
|
||||||
break;
|
break;
|
||||||
case gcpp::Model::GEMMA_27B:
|
case gcpp::Model::GEMMA_27B:
|
||||||
expected_entropy = 0.14f;
|
EXPECT_NEAR(entropy, 0.14f, 0.02f);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
FAIL() << "no entropy expectation for this model";
|
FAIL() << "no entropy expectation for this model";
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
EXPECT_NEAR(entropy, expected_entropy, 0.02f);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue