mirror of https://github.com/google/gemma.cpp.git
Add entropy expectations for Griffin-2b model in gemma_test and make sure it passes.
PiperOrigin-RevId: 675564389
This commit is contained in:
parent
e4ba93412a
commit
760a69449e
|
|
@ -29,7 +29,8 @@
|
||||||
// To run the test, pass the following flags:
|
// To run the test, pass the following flags:
|
||||||
// --model <model> --tokenizer <tokenizer_path> --weights <weights_path>
|
// --model <model> --tokenizer <tokenizer_path> --weights <weights_path>
|
||||||
// It should pass for the following models:
|
// It should pass for the following models:
|
||||||
// 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), gemma2-2b-it, 9b-it, 27b-it
|
// Gemma1: 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), gr2b-it,
|
||||||
|
// Gemma2: gemma2-2b-it, 9b-it, 27b-it,
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
@ -46,14 +47,15 @@ class GemmaTest : public ::testing::Test {
|
||||||
s_env->MutableConfig().temperature = 0.0f; // deterministic
|
s_env->MutableConfig().temperature = 0.0f; // deterministic
|
||||||
s_env->MutableConfig().verbosity = 0;
|
s_env->MutableConfig().verbosity = 0;
|
||||||
// Using the turn structure worsens results sometimes.
|
// Using the turn structure worsens results sometimes.
|
||||||
// However, gemma-2 27B seems to need the turn structure to work.
|
// However, some models need the turn structure to work.
|
||||||
// It would be good to make these tests more consistent.
|
// It would be good to make these tests more consistent.
|
||||||
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B) {
|
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
|
||||||
|
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
|
||||||
std::string mutable_prompt = prompt;
|
std::string mutable_prompt = prompt;
|
||||||
auto [response, n] = s_env->QueryModel(mutable_prompt); // Uses turns.
|
auto [response, n] = s_env->QueryModel(mutable_prompt); // Uses turns.
|
||||||
return response;
|
return response;
|
||||||
}
|
}
|
||||||
// Otherwise, don't use turn structure.
|
// Otherwise, do not use turn structure.
|
||||||
const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
|
const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
|
||||||
auto [response, n] = s_env->QueryModel(tokens);
|
auto [response, n] = s_env->QueryModel(tokens);
|
||||||
return response;
|
return response;
|
||||||
|
|
@ -66,15 +68,16 @@ class GemmaTest : public ::testing::Test {
|
||||||
s_env->MutableConfig().verbosity = 0;
|
s_env->MutableConfig().verbosity = 0;
|
||||||
std::vector<std::string> replies;
|
std::vector<std::string> replies;
|
||||||
// Using the turn structure worsens results sometimes.
|
// Using the turn structure worsens results sometimes.
|
||||||
// However, gemma-2 27B seems to need the turn structure to work.
|
// However, some models need the turn structure to work.
|
||||||
// It would be good to make these tests more consistent.
|
// It would be good to make these tests more consistent.
|
||||||
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B) {
|
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
|
||||||
|
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
|
||||||
for (auto [response, n] : s_env->BatchQueryModel(inputs)) {
|
for (auto [response, n] : s_env->BatchQueryModel(inputs)) {
|
||||||
replies.push_back(response);
|
replies.push_back(response);
|
||||||
}
|
}
|
||||||
return replies;
|
return replies;
|
||||||
}
|
}
|
||||||
// Not Gemma-2 27B. Do not use turn structure.
|
// Otherwise, do not use turn structure.
|
||||||
std::vector<std::vector<int>> prompts_vector;
|
std::vector<std::vector<int>> prompts_vector;
|
||||||
prompts_vector.reserve(inputs.size());
|
prompts_vector.reserve(inputs.size());
|
||||||
for (const auto& input_string : inputs) {
|
for (const auto& input_string : inputs) {
|
||||||
|
|
@ -243,6 +246,9 @@ TEST_F(GemmaTest, CrossEntropySmall) {
|
||||||
// 7B v.1 and v.1.1 produce slightly different results.
|
// 7B v.1 and v.1.1 produce slightly different results.
|
||||||
EXPECT_NEAR(entropy, 2.8f, 0.2f);
|
EXPECT_NEAR(entropy, 2.8f, 0.2f);
|
||||||
break;
|
break;
|
||||||
|
case gcpp::Model::GRIFFIN_2B:
|
||||||
|
EXPECT_NEAR(entropy, 2.25f, 0.02f);
|
||||||
|
break;
|
||||||
case gcpp::Model::GEMMA2_2B:
|
case gcpp::Model::GEMMA2_2B:
|
||||||
EXPECT_NEAR(entropy, 1.14f, 0.02f);
|
EXPECT_NEAR(entropy, 1.14f, 0.02f);
|
||||||
break;
|
break;
|
||||||
|
|
@ -271,6 +277,9 @@ TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
||||||
// 7B v.1 and v.1.1 produce slightly different results.
|
// 7B v.1 and v.1.1 produce slightly different results.
|
||||||
EXPECT_NEAR(entropy, 1.07f, 0.05f);
|
EXPECT_NEAR(entropy, 1.07f, 0.05f);
|
||||||
break;
|
break;
|
||||||
|
case gcpp::Model::GRIFFIN_2B:
|
||||||
|
EXPECT_NEAR(entropy, 1.95f, 0.02f);
|
||||||
|
break;
|
||||||
case gcpp::Model::GEMMA2_2B:
|
case gcpp::Model::GEMMA2_2B:
|
||||||
EXPECT_NEAR(entropy, 0.49f, 0.02f);
|
EXPECT_NEAR(entropy, 0.49f, 0.02f);
|
||||||
break;
|
break;
|
||||||
|
|
@ -299,6 +308,9 @@ TEST_F(GemmaTest, CrossEntropyGettysburg) {
|
||||||
// 7B v.1 and v.1.1 produce slightly different results.
|
// 7B v.1 and v.1.1 produce slightly different results.
|
||||||
EXPECT_NEAR(entropy, 0.75f, 0.1f);
|
EXPECT_NEAR(entropy, 0.75f, 0.1f);
|
||||||
break;
|
break;
|
||||||
|
case gcpp::Model::GRIFFIN_2B:
|
||||||
|
EXPECT_NEAR(entropy, 0.82f, 0.02f);
|
||||||
|
break;
|
||||||
case gcpp::Model::GEMMA2_2B:
|
case gcpp::Model::GEMMA2_2B:
|
||||||
EXPECT_NEAR(entropy, 0.20f, 0.02f);
|
EXPECT_NEAR(entropy, 0.20f, 0.02f);
|
||||||
break;
|
break;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue