mirror of https://github.com/google/gemma.cpp.git
Update gemma_test with the expected entropy values for the IT models of size 2B/7B/9B/27B.
PiperOrigin-RevId: 649662047
This commit is contained in:
parent
438b1bace2
commit
cdebcc3533
|
|
@ -26,6 +26,11 @@
|
|||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
|
||||
// This test can be run manually with the downloaded gemma weights.
|
||||
// To run the test, pass the following flags:
|
||||
// --model <model> --tokenizer <tokenizer_path> --weights <weights_path>
|
||||
// It should pass for the following models: 2b-it, 7b-it, 9b-it, 27b-it
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
||||
|
|
@ -170,25 +175,76 @@ TEST_F(GemmaTest, CrossEntropySmall) {
|
|||
static const char kSmall[] =
|
||||
"The capital of Hungary is Budapest which is located in Europe.";
|
||||
float entropy = s_env->CrossEntropy(kSmall);
|
||||
fprintf(stderr, "per-byte entropy: %f\n", entropy);
|
||||
const bool is_7b = s_env->GetModel()->Info().model == gcpp::Model::GEMMA_7B;
|
||||
EXPECT_LT(entropy, is_7b ? 2.1f : 2.0f);
|
||||
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
||||
float expected_entropy;
|
||||
switch (s_env->GetModel()->Info().model) {
|
||||
case gcpp::Model::GEMMA_2B:
|
||||
expected_entropy = 2.56f;
|
||||
break;
|
||||
case gcpp::Model::GEMMA_7B:
|
||||
expected_entropy = 2.91f;
|
||||
break;
|
||||
case gcpp::Model::GEMMA_9B:
|
||||
expected_entropy = 1.28f;
|
||||
break;
|
||||
case gcpp::Model::GEMMA_27B:
|
||||
expected_entropy = 1.30f;
|
||||
break;
|
||||
default:
|
||||
FAIL() << "no entropy expectation for this model";
|
||||
break;
|
||||
}
|
||||
EXPECT_NEAR(entropy, expected_entropy, 0.02f);
|
||||
}
|
||||
|
||||
// Verifies that the cross-entropy reported for kJingleBells by the loaded
// model is within 0.02 (absolute) of the value recorded for that model size.
// Expectations exist for GEMMA_2B, GEMMA_7B, GEMMA_9B and GEMMA_27B; any other
// model fails the test explicitly rather than passing vacuously.
TEST_F(GemmaTest, CrossEntropyJingleBells) {
  // Without a model there is nothing to evaluate; return without failing.
  if (!s_env->GetModel()) return;

  const float entropy = s_env->CrossEntropy(kJingleBells);
  fprintf(stderr, "per-token entropy: %f\n", entropy);

  // Zero-initialized so the variable is never read indeterminate, even if the
  // default branch below is ever changed to a non-fatal failure.
  float expected_entropy = 0.0f;
  switch (s_env->GetModel()->Info().model) {
    case gcpp::Model::GEMMA_2B:
      expected_entropy = 1.85f;
      break;
    case gcpp::Model::GEMMA_7B:
      expected_entropy = 1.06f;
      break;
    case gcpp::Model::GEMMA_9B:
      expected_entropy = 0.37f;
      break;
    case gcpp::Model::GEMMA_27B:
      expected_entropy = 0.33f;
      break;
    default:
      // FAIL() is a fatal GoogleTest failure: it returns from the test body,
      // so the EXPECT_NEAR below is not reached for unknown models.
      FAIL() << "no entropy expectation for this model";
      break;
  }
  EXPECT_NEAR(entropy, expected_entropy, 0.02f);
}
|
||||
|
||||
// Verifies that the cross-entropy reported for kGettysburg by the loaded
// model is within 0.02 (absolute) of the value recorded for that model size.
// Expectations exist for GEMMA_2B, GEMMA_7B, GEMMA_9B and GEMMA_27B; any other
// model fails the test explicitly rather than passing vacuously.
TEST_F(GemmaTest, CrossEntropyGettysburg) {
  // Without a model there is nothing to evaluate; return without failing.
  if (!s_env->GetModel()) return;

  const float entropy = s_env->CrossEntropy(kGettysburg);
  fprintf(stderr, "per-token entropy: %f\n", entropy);

  // Zero-initialized so the variable is never read indeterminate, even if the
  // default branch below is ever changed to a non-fatal failure.
  float expected_entropy = 0.0f;
  switch (s_env->GetModel()->Info().model) {
    case gcpp::Model::GEMMA_2B:
      expected_entropy = 1.05f;
      break;
    case gcpp::Model::GEMMA_7B:
      expected_entropy = 0.83f;
      break;
    case gcpp::Model::GEMMA_9B:
      expected_entropy = 0.15f;
      break;
    case gcpp::Model::GEMMA_27B:
      expected_entropy = 0.14f;
      break;
    default:
      // FAIL() is a fatal GoogleTest failure: it returns from the test body,
      // so the EXPECT_NEAR below is not reached for unknown models.
      FAIL() << "no entropy expectation for this model";
      break;
  }
  EXPECT_NEAR(entropy, expected_entropy, 0.02f);
}
|
||||
|
||||
} // namespace
|
||||
|
|
|
|||
Loading…
Reference in New Issue