diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc
index 6f0514f..5d29f4f 100644
--- a/evals/gemma_test.cc
+++ b/evals/gemma_test.cc
@@ -43,92 +43,26 @@ class GemmaTest : public ::testing::Test {
     HWY_ASSERT(s_env == nullptr);  // Should only be called once.
     s_env = new GemmaaEnv(argc, argv);
     const gcpp::ModelConfig& config = s_env->GetGemma()->GetModelConfig();
-    fprintf(stderr, "Using %s)\n", config.Specifier().c_str());
+    fprintf(stderr, "Using %s\n", config.Specifier().c_str());
   }
 
   static void DeleteEnv() { delete s_env; }
 
  protected:
-  std::string GemmaReply(const std::string& prompt) {
-    HWY_ASSERT(s_env);  // must have called InitEnv()
-    s_env->SetMaxGeneratedTokens(2048);
-    s_env->MutableConfig().temperature = 0.0f;  // deterministic
-    s_env->MutableConfig().verbosity = 0;
-    const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
-    // Using the turn structure worsens results sometimes.
-    // However, some models need the turn structure to work.
-    // It would be good to make these tests more consistent.
-    if (config.model == Model::GEMMA2_27B ||
-        config.model == Model::GRIFFIN_2B) {
-      std::string mutable_prompt = prompt;
-      QueryResult result = s_env->QueryModel(mutable_prompt);  // Uses turns.
-      return result.response;
-    }
-    // Otherwise, do not use turn structure.
-    const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
-    QueryResult result = s_env->QueryModel(tokens);
-    return result.response;
-  }
-
   std::vector<std::string> BatchGemmaReply(
       const std::vector<std::string>& inputs) {
     HWY_ASSERT(s_env);  // must have called InitEnv()
     s_env->SetMaxGeneratedTokens(64);
     s_env->MutableConfig().temperature = 0.0f;  // deterministic
     s_env->MutableConfig().verbosity = 0;
-    const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
+    // Always use turn structure (WrapAndTokenize).
     std::vector<std::string> replies;
-    // Using the turn structure worsens results sometimes.
-    // However, some models need the turn structure to work.
-    // It would be good to make these tests more consistent.
-    if (config.model == Model::GEMMA2_27B ||
-        config.model == Model::GRIFFIN_2B) {
-      for (QueryResult result : s_env->BatchQueryModel(inputs)) {
-        replies.push_back(result.response);
-      }
-      return replies;
-    }
-    // Otherwise, do not use turn structure.
-    std::vector<std::vector<int>> prompts_vector;
-    prompts_vector.reserve(inputs.size());
-    for (const auto& input_string : inputs) {
-      prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
-    }
-    std::vector<PromptTokens> prompt_spans;
-    for (const auto& prompt : prompts_vector) {
-      prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
-    }
-    QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
-    for (const QueryResult& result : s_env->BatchQueryModel(prompts)) {
+    for (QueryResult result : s_env->BatchQueryModel(inputs)) {
       replies.push_back(result.response);
     }
     return replies;
   }
 
-  void TestQuestions(const char* kQA[][2], size_t num_questions, bool batch) {
-    HWY_ASSERT(s_env->GetGemma() != nullptr);
-    if (batch) {
-      std::vector<std::string> inputs;
-      for (size_t i = 0; i < num_questions; ++i) {
-        fprintf(stderr, "Batch Question %zu\n\n", i + 1);
-        inputs.push_back(kQA[i][0]);
-      }
-      std::vector<std::string> responses = BatchGemmaReply(inputs);
-      for (size_t i = 0; i < num_questions; ++i) {
-        std::string response = responses.at(i);
-        fprintf(stderr, "Batch answer %zu '%s'\n\n", i + 1, response.c_str());
-        EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos);  // NOLINT
-      }
-    } else {
-      for (size_t i = 0; i < num_questions; ++i) {
-        fprintf(stderr, "Question %zu\n\n", i + 1);
-        std::string response = GemmaReply(kQA[i][0]);
-        fprintf(stderr, "'%s'\n\n", response.c_str());
-        EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos);  // NOLINT
-      }
-    }
-  }
-
   // Shared state. Requires argc/argv, so construct in main via InitEnv.
   // Note that the style guide forbids non-local static variables with dtors.
   static GemmaEnv* s_env;
@@ -136,44 +70,34 @@ class GemmaTest : public ::testing::Test {
 
 GemmaEnv* GemmaTest::s_env = nullptr;
 
-TEST_F(GemmaTest, GeographyBatched) {
-  s_env->MutableConfig().decode_qbatch_size = 3;
-  // 6 are enough to test batching and the loop.
+TEST_F(GemmaTest, Batched) {
+  // Test remainder handling in MatMul (four rows per tile), but avoid a
+  // second batch in debug builds to speed up the test.
+  s_env->MutableConfig().decode_qbatch_size = HWY_IS_DEBUG_BUILD ? 6 : 3;
   static const char* kQA[][2] = {
       {"What is the capital of Australia?", "Canberra"},
-      {"What is the capital of Denmark?", "Copenhagen"},
-      {"Ljubljana is the capital of which country?", "Slovenia"},
-      {"Is Chicago a country?", "city"},
       {"How many states does the US have?", "50"},
      {"What is the Pacific?", "ocean"},
-  };
-  static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
-  TestQuestions(kQA, HWY_MIN(kNum, 3), /*batch=*/false);
-  TestQuestions(kQA, 1, /*batch=*/true);
-  TestQuestions(kQA, kNum, /*batch=*/true);
-}
-
-TEST_F(GemmaTest, History) {
-  static const char* kQA[][2] = {
       {"When was the battle of Hastings?", "1066"},
-  };
-  static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
-  TestQuestions(kQA, kNum, /*batch=*/false);
-}
-
-TEST_F(GemmaTest, Arithmetic) {
-  static const char* kQA[][2] = {
       {"what is 13 + 14?", "27"},
       {"what is 7 * 8?", "56"},
   };
-  static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
-  TestQuestions(kQA, kNum, /*batch=*/false);
+  const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
+  std::vector<std::string> inputs;
+  for (size_t i = 0; i < kNum; ++i) {
+    inputs.push_back(kQA[i][0]);
+  }
+  std::vector<std::string> responses = BatchGemmaReply(inputs);
+  HWY_ASSERT(responses.size() == kNum);
+  for (size_t i = 0; i < kNum; ++i) {
+    fprintf(stderr, "#%zu: '%s'\n\n", i, responses[i].c_str());
+    EXPECT_TRUE(responses[i].find(kQA[i][1]) != std::string::npos);  // NOLINT
+  }
 }
 
 TEST_F(GemmaTest, Multiturn) {
   const Gemma* model = s_env->GetGemma();
   const ModelConfig& config = model->GetModelConfig();
-  HWY_ASSERT(model != nullptr);
   size_t abs_pos = 0;
   std::string response;
   auto stream_token = [&](int token, float) {
@@ -220,41 +144,6 @@ TEST_F(GemmaTest, Multiturn) {
   EXPECT_TRUE(remembered_turquoise || remembered_car);
 }
 
-static const char kJingleBells[] = R"(
-Dashing through the snow
-In a one-horse open sleigh
-O'er the fields we go
-Laughing all the way
-Bells on bobtails ring
-Making spirits bright
-What fun it is to ride and sing
-A sleighing song tonight
-)";
-
-// The "Hay Draft" of the Gettysburg Address.
-static const char kGettysburg[] = {
-    "Four score and seven years ago our fathers brought forth, upon this "
-    "continent, a new nation, conceived in Liberty, and dedicated to the "
-    "proposition that all men are created equal.\n\nNow we are engaged in a "
-    "great civil war, testing whether that nation, or any nation, so "
-    "conceived, and so dedicated, can long endure. We are met here on a great "
-    "battlefield of that war. We have come to dedicate a portion of it as a "
-    "final resting place for those who here gave their lives that that nation "
-    "might live. It is altogether fitting and proper that we should do "
-    "this.\n\nBut in a larger sense we can not dedicate -- we can not "
-    "consecrate -- we can not hallow this ground. The brave men, living and "
-    "dead, who struggled, here, have consecrated it far above our poor power "
-    "to add or detract. The world will little note, nor long remember, what we "
-    "say here, but can never forget what they did here. It is for us, the "
-    "living, rather to be dedicated here to the unfinished work which they "
-    "have, thus far, so nobly carried on. It is rather for us to be here "
-    "dedicated to the great task remaining before us -- that from these "
-    "honored dead we take increased devotion to that cause for which they here "
-    "gave the last full measure of devotion -- that we here highly resolve "
-    "that these dead shall not have died in vain; that this nation shall have "
-    "a new birth of freedom; and that this government of the people, by the "
-    "people, for the people, shall not perish from the earth.\n"};
-
 TEST_F(GemmaTest, CrossEntropySmall) {
   HWY_ASSERT(s_env->GetGemma() != nullptr);
   const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
@@ -281,54 +170,6 @@ TEST_F(GemmaTest, CrossEntropySmall) {
   }
 }
 
-TEST_F(GemmaTest, CrossEntropyJingleBells) {
-  HWY_ASSERT(s_env->GetGemma() != nullptr);
-  const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
-  float entropy = s_env->CrossEntropy(kJingleBells);
-  fprintf(stderr, "per-token entropy: %f\n", entropy);
-  switch (config.model) {
-    case gcpp::Model::GRIFFIN_2B:
-      EXPECT_NEAR(entropy, 1.62f, 0.02f);
-      break;
-    case gcpp::Model::GEMMA2_2B:
-      EXPECT_NEAR(entropy, 0.49f, 0.02f);
-      break;
-    case gcpp::Model::GEMMA2_9B:
-      EXPECT_NEAR(entropy, 0.37f, 0.02f);
-      break;
-    case gcpp::Model::GEMMA2_27B:
-      EXPECT_NEAR(entropy, 0.33f, 0.02f);
-      break;
-    default:
-      FAIL() << "no entropy expectation for this model";
-      break;
-  }
-}
-
-TEST_F(GemmaTest, CrossEntropyGettysburg) {
-  HWY_ASSERT(s_env->GetGemma() != nullptr);
-  const ModelConfig& config = s_env->GetGemma()->GetModelConfig();
-  float entropy = s_env->CrossEntropy(kGettysburg);
-  fprintf(stderr, "per-token entropy: %f\n", entropy);
-  switch (config.model) {
-    case gcpp::Model::GRIFFIN_2B:
-      EXPECT_NEAR(entropy, 0.71f, 0.02f);
-      break;
-    case gcpp::Model::GEMMA2_2B:
-      EXPECT_NEAR(entropy, 0.20f, 0.02f);
-      break;
-    case gcpp::Model::GEMMA2_9B:
-      EXPECT_NEAR(entropy, 0.15f, 0.02f);
-      break;
-    case gcpp::Model::GEMMA2_27B:
-      EXPECT_NEAR(entropy, 0.14f, 0.02f);
-      break;
-    default:
-      FAIL() << "no entropy expectation for this model";
-      break;
-  }
-}
-
 }  // namespace
 }  // namespace gcpp