diff --git a/BUILD.bazel b/BUILD.bazel index b4f1148..467bcf1 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -213,19 +213,20 @@ cc_library( cc_test( name = "gemma_test", srcs = ["evals/gemma_test.cc"], + # Requires model files + tags = [ + "local", + "manual", + "no_tap", + ], deps = [ - ":app", - ":args", ":benchmark_helper", ":common", - ":cross_entropy", ":gemma_lib", - ":ops", + ":tokenizer", "@googletest//:gtest_main", - "//compression:io", "@hwy//:hwy", "@hwy//:hwy_test_util", - "@hwy//:thread_pool", ], ) diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc index d24425f..6436908 100644 --- a/evals/gemma_test.cc +++ b/evals/gemma_test.cc @@ -23,13 +23,15 @@ #include "evals/benchmark_helper.h" #include "gemma/common.h" +#include "gemma/tokenizer.h" #include "hwy/aligned_allocator.h" #include "hwy/tests/hwy_gtest.h" // This test can be run manually with the downloaded gemma weights. // To run the test, pass the following flags: // --model --tokenizer --weights -// It should pass for the following models: 2b-it, 7b-it, 9b-it, 27b-it +// It should pass for the following models: +// 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), 9b-it, 27b-it namespace gcpp { namespace { @@ -45,7 +47,15 @@ class GemmaTest : public ::testing::Test { s_env->SetMaxGeneratedTokens(2048); s_env->MutableConfig().temperature = 0.0f; // deterministic s_env->MutableConfig().verbosity = 0; - // Using the turn structure worsens results. + // Using the turn structure worsens results sometimes. + // However, gemma-2 27B seems to need the turn structure to work. + // It would be good to make these tests more consistent. + if (s_env->GetModel()->Info().model == Model::GEMMA_27B) { + std::string mutable_prompt = prompt; + auto [response, n] = s_env->QueryModel(mutable_prompt); // Uses turns. + return response; + } + // Otherwise, don't use turn structure. const std::vector tokens = s_env->TokenizeAndPrependBOS(prompt); auto [response, n] = s_env->QueryModel(tokens); return response; @@ -56,30 +66,38 @@ class GemmaTest : public ::testing::Test { s_env->SetMaxGeneratedTokens(64); s_env->MutableConfig().temperature = 0.0f; // deterministic s_env->MutableConfig().verbosity = 0; - // Using the turn structure worsens results. - std::vector>> prompts; - prompts.reserve(inputs.size()); - for (auto input_string : inputs) { - std::string mutable_input_string = input_string; - prompts.push_back(std::make_unique>( - s_env->TokenizeAndPrependBOS(input_string))); - } - std::vector> prompt_vector; - for (auto& prompt : prompts) { - prompt_vector.push_back(hwy::Span(prompt->data(), prompt->size())); - } - hwy::Span> prompt_span = - hwy::Span>(prompt_vector.data(), - prompt_vector.size()); std::vector replies; - for (auto [response, n] : s_env->BatchQueryModel2(prompt_span)) { - replies.push_back(response); + // Using the turn structure worsens results sometimes. + // However, gemma-2 27B seems to need the turn structure to work. + // It would be good to make these tests more consistent. + if (s_env->GetModel()->Info().model == Model::GEMMA_27B) { + for (auto [response, n] : s_env->BatchQueryModel(inputs)) { + replies.push_back(response); + } + } else { // Not Gemma-2 27B. Do not use turn structure. + std::vector>> prompts; + prompts.reserve(inputs.size()); + for (auto input_string : inputs) { + std::string mutable_input_string = input_string; + prompts.push_back(std::make_unique>( + s_env->TokenizeAndPrependBOS(input_string))); + } + std::vector> prompt_vector; + for (auto& prompt : prompts) { + prompt_vector.push_back(hwy::Span(prompt->data(), prompt->size())); + } + hwy::Span> prompt_span = + hwy::Span>(prompt_vector.data(), + prompt_vector.size()); + for (auto [response, n] : s_env->BatchQueryModel2(prompt_span)) { + replies.push_back(response); + } } return replies; } void TestQuestions(const char* kQA[][2], size_t num_questions, bool batch) { - if (!s_env->GetModel()) return; + ASSERT_NE(s_env->GetModel(), nullptr); if (batch) { std::vector inputs; for (size_t i = 0; i < num_questions; ++i) { @@ -171,80 +189,80 @@ static const char kGettysburg[] = { "people, for the people, shall not perish from the earth.\n"}; TEST_F(GemmaTest, CrossEntropySmall) { - if (!s_env->GetModel()) return; + ASSERT_NE(s_env->GetModel(), nullptr); static const char kSmall[] = "The capital of Hungary is Budapest which is located in Europe."; float entropy = s_env->CrossEntropy(kSmall); fprintf(stderr, "per-token entropy: %f\n", entropy); - float expected_entropy; switch (s_env->GetModel()->Info().model) { case gcpp::Model::GEMMA_2B: - expected_entropy = 2.56f; + // 2B v.1 and v.1.1 produce slightly different results. + EXPECT_NEAR(entropy, 2.6f, 0.2f); break; case gcpp::Model::GEMMA_7B: - expected_entropy = 2.91f; + // 7B v.1 and v.1.1 produce slightly different results. + EXPECT_NEAR(entropy, 2.8f, 0.2f); break; case gcpp::Model::GEMMA_9B: - expected_entropy = 1.28f; + EXPECT_NEAR(entropy, 1.28f, 0.02f); break; case gcpp::Model::GEMMA_27B: - expected_entropy = 1.30f; + EXPECT_NEAR(entropy, 1.30f, 0.02f); break; default: FAIL() << "no entropy expectation for this model"; break; } - EXPECT_NEAR(entropy, expected_entropy, 0.02f); } TEST_F(GemmaTest, CrossEntropyJingleBells) { - if (!s_env->GetModel()) return; + ASSERT_NE(s_env->GetModel(), nullptr); float entropy = s_env->CrossEntropy(kJingleBells); fprintf(stderr, "per-token entropy: %f\n", entropy); - float expected_entropy; switch (s_env->GetModel()->Info().model) { case gcpp::Model::GEMMA_2B: - expected_entropy = 1.85f; + // 2B v.1 and v.1.1 produce slightly different results. + EXPECT_NEAR(entropy, 1.9f, 0.2f); break; case gcpp::Model::GEMMA_7B: - expected_entropy = 1.06f; + // 7B v.1 and v.1.1 produce slightly different results. + EXPECT_NEAR(entropy, 1.07f, 0.05f); break; case gcpp::Model::GEMMA_9B: - expected_entropy = 0.37f; + EXPECT_NEAR(entropy, 0.37f, 0.02f); break; case gcpp::Model::GEMMA_27B: - expected_entropy = 0.33f; + EXPECT_NEAR(entropy, 0.33f, 0.02f); break; default: FAIL() << "no entropy expectation for this model"; break; } - EXPECT_NEAR(entropy, expected_entropy, 0.02f); } TEST_F(GemmaTest, CrossEntropyGettysburg) { - if (!s_env->GetModel()) return; + ASSERT_NE(s_env->GetModel(), nullptr); float entropy = s_env->CrossEntropy(kGettysburg); fprintf(stderr, "per-token entropy: %f\n", entropy); - float expected_entropy; switch (s_env->GetModel()->Info().model) { case gcpp::Model::GEMMA_2B: - expected_entropy = 1.05f; + // 2B v.1 and v.1.1 produce slightly different results. + EXPECT_NEAR(entropy, 1.1f, 0.1f); break; case gcpp::Model::GEMMA_7B: - expected_entropy = 0.83f; + // 7B v.1 and v.1.1 produce slightly different results. + EXPECT_NEAR(entropy, 0.75f, 0.1f); break; case gcpp::Model::GEMMA_9B: - expected_entropy = 0.15f; + EXPECT_NEAR(entropy, 0.15f, 0.02f); break; case gcpp::Model::GEMMA_27B: - expected_entropy = 0.14f; + EXPECT_NEAR(entropy, 0.14f, 0.02f); break; default: FAIL() << "no entropy expectation for this model"; break; } - EXPECT_NEAR(entropy, expected_entropy, 0.02f); } } // namespace