Patch notes (commentary only; patch tools skip text that precedes the
"diff --git" header, and nothing below that header has been altered apart
from restoring the line breaks and indentation of the unified diff).

Scope: evals/gemma_test.cc, two hunks.

Hunk 1 adds a file-level comment saying the test can be run manually with
downloaded gemma weights, naming the flags to pass, and listing the models
it should pass for (2b-it, 7b-it, 9b-it, 27b-it).
NOTE(review): the added flags line names only the bare flags; it looks as
though a value placeholder after each flag was lost (something like
"--model MODEL --tokenizer FILE --weights FILE") - confirm against the
upstream commit before relying on it.

Hunk 2 changes CrossEntropySmall, CrossEntropyJingleBells and
CrossEntropyGettysburg in the same way:
  * the stderr label changes from "per-byte entropy" to "per-token entropy";
  * the one-sided EXPECT_LT bound, which special-cased GEMMA_7B, is replaced
    by a per-model expected value compared with EXPECT_NEAR using a
    tolerance of 0.02f;
  * GEMMA_2B, GEMMA_7B, GEMMA_9B and GEMMA_27B have expected values; any
    other model reaches the default case, which calls FAIL().

diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc
index ac9d8ba..97d1af3 100644
--- a/evals/gemma_test.cc
+++ b/evals/gemma_test.cc
@@ -26,6 +26,11 @@
 #include "hwy/aligned_allocator.h"
 #include "hwy/tests/hwy_gtest.h"
 
+// This test can be run manually with the downloaded gemma weights.
+// To run the test, pass the following flags:
+// --model --tokenizer --weights
+// It should pass for the following models: 2b-it, 7b-it, 9b-it, 27b-it
+
 namespace gcpp {
 namespace {
 
@@ -170,25 +175,76 @@ TEST_F(GemmaTest, CrossEntropySmall) {
   static const char kSmall[] =
       "The capital of Hungary is Budapest which is located in Europe.";
   float entropy = s_env->CrossEntropy(kSmall);
-  fprintf(stderr, "per-byte entropy: %f\n", entropy);
-  const bool is_7b = s_env->GetModel()->Info().model == gcpp::Model::GEMMA_7B;
-  EXPECT_LT(entropy, is_7b ? 2.1f : 2.0f);
+  fprintf(stderr, "per-token entropy: %f\n", entropy);
+  float expected_entropy;
+  switch (s_env->GetModel()->Info().model) {
+    case gcpp::Model::GEMMA_2B:
+      expected_entropy = 2.56f;
+      break;
+    case gcpp::Model::GEMMA_7B:
+      expected_entropy = 2.91f;
+      break;
+    case gcpp::Model::GEMMA_9B:
+      expected_entropy = 1.28f;
+      break;
+    case gcpp::Model::GEMMA_27B:
+      expected_entropy = 1.30f;
+      break;
+    default:
+      FAIL() << "no entropy expectation for this model";
+      break;
+  }
+  EXPECT_NEAR(entropy, expected_entropy, 0.02f);
 }
 
 TEST_F(GemmaTest, CrossEntropyJingleBells) {
   if (!s_env->GetModel()) return;
   float entropy = s_env->CrossEntropy(kJingleBells);
-  fprintf(stderr, "per-byte entropy: %f\n", entropy);
-  const bool is_7b = s_env->GetModel()->Info().model == gcpp::Model::GEMMA_7B;
-  EXPECT_LT(entropy, is_7b ? 0.9f : 1.8f);
+  fprintf(stderr, "per-token entropy: %f\n", entropy);
+  float expected_entropy;
+  switch (s_env->GetModel()->Info().model) {
+    case gcpp::Model::GEMMA_2B:
+      expected_entropy = 1.85f;
+      break;
+    case gcpp::Model::GEMMA_7B:
+      expected_entropy = 1.06f;
+      break;
+    case gcpp::Model::GEMMA_9B:
+      expected_entropy = 0.37f;
+      break;
+    case gcpp::Model::GEMMA_27B:
+      expected_entropy = 0.33f;
+      break;
+    default:
+      FAIL() << "no entropy expectation for this model";
+      break;
+  }
+  EXPECT_NEAR(entropy, expected_entropy, 0.02f);
 }
 
 TEST_F(GemmaTest, CrossEntropyGettysburg) {
   if (!s_env->GetModel()) return;
   float entropy = s_env->CrossEntropy(kGettysburg);
-  fprintf(stderr, "per-byte entropy: %f\n", entropy);
-  const bool is_7b = s_env->GetModel()->Info().model == gcpp::Model::GEMMA_7B;
-  EXPECT_LT(entropy, is_7b ? 0.8f : 1.2f);
+  fprintf(stderr, "per-token entropy: %f\n", entropy);
+  float expected_entropy;
+  switch (s_env->GetModel()->Info().model) {
+    case gcpp::Model::GEMMA_2B:
+      expected_entropy = 1.05f;
+      break;
+    case gcpp::Model::GEMMA_7B:
+      expected_entropy = 0.83f;
+      break;
+    case gcpp::Model::GEMMA_9B:
+      expected_entropy = 0.15f;
+      break;
+    case gcpp::Model::GEMMA_27B:
+      expected_entropy = 0.14f;
+      break;
+    default:
+      FAIL() << "no entropy expectation for this model";
+      break;
+  }
+  EXPECT_NEAR(entropy, expected_entropy, 0.02f);
 }
 
 }  // namespace