diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc
index 9c56ed2..d70788c 100644
--- a/evals/gemma_test.cc
+++ b/evals/gemma_test.cc
@@ -29,7 +29,8 @@
 // To run the test, pass the following flags:
 // --model --tokenizer --weights
 // It should pass for the following models:
-// 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), gemma2-2b-it, 9b-it, 27b-it
+// Gemma1: 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), gr2b-it,
+// Gemma2: gemma2-2b-it, 9b-it, 27b-it,
 
 namespace gcpp {
 namespace {
@@ -46,14 +47,15 @@ class GemmaTest : public ::testing::Test {
     s_env->MutableConfig().temperature = 0.0f;  // deterministic
     s_env->MutableConfig().verbosity = 0;
     // Using the turn structure worsens results sometimes.
-    // However, gemma-2 27B seems to need the turn structure to work.
+    // However, some models need the turn structure to work.
     // It would be good to make these tests more consistent.
-    if (s_env->GetModel()->Info().model == Model::GEMMA2_27B) {
+    if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
+        s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
       std::string mutable_prompt = prompt;
       auto [response, n] = s_env->QueryModel(mutable_prompt);  // Uses turns.
       return response;
     }
-    // Otherwise, don't use turn structure.
+    // Otherwise, do not use turn structure.
     const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
     auto [response, n] = s_env->QueryModel(tokens);
     return response;
@@ -66,15 +68,16 @@ class GemmaTest : public ::testing::Test {
     s_env->MutableConfig().verbosity = 0;
     std::vector<std::string> replies;
     // Using the turn structure worsens results sometimes.
-    // However, gemma-2 27B seems to need the turn structure to work.
+    // However, some models need the turn structure to work.
     // It would be good to make these tests more consistent.
-    if (s_env->GetModel()->Info().model == Model::GEMMA2_27B) {
+    if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
+        s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
       for (auto [response, n] : s_env->BatchQueryModel(inputs)) {
         replies.push_back(response);
       }
       return replies;
     }
-    // Not Gemma-2 27B. Do not use turn structure.
+    // Otherwise, do not use turn structure.
     std::vector<std::vector<int>> prompts_vector;
     prompts_vector.reserve(inputs.size());
     for (const auto& input_string : inputs) {
@@ -243,6 +246,9 @@ TEST_F(GemmaTest, CrossEntropySmall) {
       // 7B v.1 and v.1.1 produce slightly different results.
       EXPECT_NEAR(entropy, 2.8f, 0.2f);
       break;
+    case gcpp::Model::GRIFFIN_2B:
+      EXPECT_NEAR(entropy, 2.25f, 0.02f);
+      break;
     case gcpp::Model::GEMMA2_2B:
       EXPECT_NEAR(entropy, 1.14f, 0.02f);
       break;
@@ -271,6 +277,9 @@ TEST_F(GemmaTest, CrossEntropyJingleBells) {
       // 7B v.1 and v.1.1 produce slightly different results.
       EXPECT_NEAR(entropy, 1.07f, 0.05f);
       break;
+    case gcpp::Model::GRIFFIN_2B:
+      EXPECT_NEAR(entropy, 1.95f, 0.02f);
+      break;
     case gcpp::Model::GEMMA2_2B:
       EXPECT_NEAR(entropy, 0.49f, 0.02f);
       break;
@@ -299,6 +308,9 @@ TEST_F(GemmaTest, CrossEntropyGettysburg) {
       // 7B v.1 and v.1.1 produce slightly different results.
       EXPECT_NEAR(entropy, 0.75f, 0.1f);
       break;
+    case gcpp::Model::GRIFFIN_2B:
+      EXPECT_NEAR(entropy, 0.82f, 0.02f);
+      break;
     case gcpp::Model::GEMMA2_2B:
       EXPECT_NEAR(entropy, 0.20f, 0.02f);
       break;