Fix paligemma_test, refs #588

- Detect PaliGemma models from layer names.
- Remove the unused allocator argument from CreateInvTimescale.
- matmul: warn only once about dimension divisibility.
- Also print the config in tests when --verbosity is at least 2.
PiperOrigin-RevId: 766605131
This commit is contained in:
Jan Wassenberg 2025-06-03 04:44:50 -07:00 committed by Copybara-Service
parent 209009b57e
commit 839a642992
10 changed files with 91 additions and 71 deletions

View File

@ -47,13 +47,16 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
} }
} }
GemmaEnv::GemmaEnv(const LoaderArgs& loader, GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const ThreadingArgs& threading_args,
const InferenceArgs& inference) const InferenceArgs& inference)
: env_(MakeMatMulEnv(threading_args)), gemma_(loader, env_) { : env_(MakeMatMulEnv(threading)), gemma_(loader, env_) {
const ModelConfig& config = gemma_.GetModelConfig();
// Only allocate one for starters because GenerateBatch might not be called. // Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.push_back( kv_caches_.push_back(KVCache(config, inference.prefill_tbatch_size));
KVCache(gemma_.GetModelConfig(), inference.prefill_tbatch_size));
if (inference.verbosity >= 2) {
ShowConfig(loader, threading, inference, config);
}
InitGenerator(inference, gen_); InitGenerator(inference, gen_);

View File

@ -40,9 +40,12 @@ struct Activations {
x("x", Extents2D(batch_size, config.model_dim), pad_), x("x", Extents2D(batch_size, config.model_dim), pad_),
// `vocab_size == 0` means it is for Vit part, VitAttention is still MHA // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
// and does not use an external KV cache. // and does not use an external KV cache.
q("q", Extents2D(batch_size, config.vocab_size == 0 ? q("q",
layer_config.heads * 3 * layer_config.qkv_dim : Extents2D(batch_size,
layer_config.heads * layer_config.qkv_dim), pad_), config.vocab_size == 0
? layer_config.heads * 3 * layer_config.qkv_dim
: layer_config.heads * layer_config.qkv_dim),
pad_),
logits("logits", Extents2D(batch_size, config.vocab_size), pad_), logits("logits", Extents2D(batch_size, config.vocab_size), pad_),
pre_att_rms_out("pre_att_rms_out", pre_att_rms_out("pre_att_rms_out",
@ -74,12 +77,12 @@ struct Activations {
"griffin_mul", "griffin_mul",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_), is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_),
inv_timescale(CreateInvTimescale( inv_timescale(
ThreadingContext::Get().allocator, layer_config.qkv_dim, CreateInvTimescale(layer_config.qkv_dim,
layer_config.post_qk == PostQKType::HalfRope)), layer_config.post_qk == PostQKType::HalfRope)),
inv_timescale_global(CreateInvTimescale( inv_timescale_global(CreateInvTimescale(
ThreadingContext::Get().allocator, layer_config.qkv_dim, layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope,
layer_config.post_qk == PostQKType::HalfRope, 1000000.0)), 1000000.0)),
env(env) { env(env) {
HWY_ASSERT(batch_size != 0); HWY_ASSERT(batch_size != 0);

View File

@ -720,9 +720,16 @@ Model DeduceModel(size_t layers, int layer_types) {
if (layer_types & kDeducedGriffin) return Model::GRIFFIN_2B; if (layer_types & kDeducedGriffin) return Model::GRIFFIN_2B;
if (layer_types & kDeducedViT) return Model::GEMMA3_1B; if (layer_types & kDeducedViT) return Model::GEMMA3_1B;
return Model::GEMMA2_2B; return Model::GEMMA2_2B;
case 27:
return (layer_types & kDeduced448) ? Model::PALIGEMMA2_3B_448
: Model::PALIGEMMA2_3B_224;
case 34: case 34:
return Model::GEMMA3_4B; return Model::GEMMA3_4B;
case 42: case 42:
if (layer_types & kDeducedViT) {
return (layer_types & kDeduced448) ? Model::PALIGEMMA2_10B_448
: Model::PALIGEMMA2_10B_224;
}
return Model::GEMMA2_9B; return Model::GEMMA2_9B;
case 46: case 46:
return Model::GEMMA2_27B; return Model::GEMMA2_27B;
@ -735,12 +742,6 @@ Model DeduceModel(size_t layers, int layer_types) {
/* /*
return Model::GEMMA2_772M; return Model::GEMMA2_772M;
return Model::PALIGEMMA2_772M_224; return Model::PALIGEMMA2_772M_224;
return Model::PALIGEMMA_224;
return Model::PALIGEMMA_448;
return Model::PALIGEMMA2_3B_224;
return Model::PALIGEMMA2_3B_448;
return Model::PALIGEMMA2_10B_224;
return Model::PALIGEMMA2_10B_448;
*/ */
default: default:
HWY_WARN("Failed to deduce model type from layer count %zu types %x.", HWY_WARN("Failed to deduce model type from layer count %zu types %x.",

View File

@ -473,6 +473,7 @@ ModelConfig GetVitConfig(const ModelConfig& config);
enum DeducedLayerTypes { enum DeducedLayerTypes {
kDeducedGriffin = 1, kDeducedGriffin = 1,
kDeducedViT = 2, kDeducedViT = 2,
kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224.
}; };
// layer_types is one or more of `DeducedLayerTypes`. // layer_types is one or more of `DeducedLayerTypes`.

View File

@ -152,7 +152,7 @@ class TypePrefix {
const uint64_t bytes = bytes_[type_idx]; const uint64_t bytes = bytes_[type_idx];
if (bytes == 0) continue; if (bytes == 0) continue;
const double percent = 100.0 * bytes / total_bytes_; const double percent = 100.0 * bytes / total_bytes_;
fprintf(stderr, "%zu blob bytes (%.2f%%) of %s\n", fprintf(stderr, "%12zu blob bytes (%5.2f%%) of %4s\n",
static_cast<size_t>(bytes), percent, TypeName(type)); static_cast<size_t>(bytes), percent, TypeName(type));
} }
} }
@ -185,15 +185,22 @@ static size_t DeduceNumLayers(const KeyVec& keys) {
// Looks for known tensor names associated with model families. // Looks for known tensor names associated with model families.
// This works with or without type prefixes because it searches for substrings. // This works with or without type prefixes because it searches for substrings.
static int DeduceLayerTypes(const KeyVec& keys) { static int DeduceLayerTypes(const BlobReader& reader) {
int layer_types = 0; int layer_types = 0;
for (const std::string& key : keys) { for (size_t key_idx = 0; key_idx < reader.Keys().size(); ++key_idx) {
const std::string& key = reader.Keys()[key_idx];
if (key.find("gr_conv_w") != std::string::npos) { // NOLINT if (key.find("gr_conv_w") != std::string::npos) { // NOLINT
return kDeducedGriffin; return kDeducedGriffin;
} }
if (key.find("qkv_einsum_w") != std::string::npos) { // NOLINT if (key.find("qkv_ein_w") != std::string::npos) { // NOLINT
layer_types |= kDeducedViT; layer_types |= kDeducedViT;
} }
if (key.find("img_pos_emb") != std::string::npos) { // NOLINT
// About 5.88 elements per pixel; assume at least bf16.
if (reader.Range(key_idx).bytes > 448 * 448 * 5 * sizeof(BF16)) {
layer_types |= kDeduced448;
}
}
} }
return layer_types; return layer_types;
} }
@ -211,7 +218,7 @@ static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
// Always deduce so we can verify it against the config we read. // Always deduce so we can verify it against the config we read.
const size_t layers = DeduceNumLayers(reader.Keys()); const size_t layers = DeduceNumLayers(reader.Keys());
const int layer_types = DeduceLayerTypes(reader.Keys()); const int layer_types = DeduceLayerTypes(reader);
const Model deduced_model = DeduceModel(layers, layer_types); const Model deduced_model = DeduceModel(layers, layer_types);
ModelConfig config; ModelConfig config;

View File

@ -19,6 +19,10 @@
#include <vector> #include <vector>
#pragma push_macro("PROFILER_ENABLED")
#undef PROFILER_ENABLED
#define PROFILER_ENABLED 0
#include "compression/types.h" #include "compression/types.h"
#include "ops/matmul.h" // IWYU pragma: export #include "ops/matmul.h" // IWYU pragma: export
#include "util/allocator.h" #include "util/allocator.h"
@ -1322,7 +1326,8 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
if constexpr (HWY_IS_DEBUG_BUILD) { if constexpr (HWY_IS_DEBUG_BUILD) {
fprintf(stderr, fprintf(stderr,
"MatMul perf warning: setting row pointers because " "MatMul perf warning: setting row pointers because "
"C.AttachRowPtrs() was not called.\n"); "%s.AttachRowPtrs() was not called.\n",
C.Name());
} }
HWY_DASSERT(C.HasPtr()); HWY_DASSERT(C.HasPtr());
for (size_t r = 0; r < C.Rows(); ++r) { for (size_t r = 0; r < C.Rows(); ++r) {
@ -1416,3 +1421,5 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
HWY_AFTER_NAMESPACE(); HWY_AFTER_NAMESPACE();
#endif // NOLINT #endif // NOLINT
#pragma pop_macro("PROFILER_ENABLED")

View File

@ -401,10 +401,14 @@ static size_t NPMultiple(const Allocator& allocator, size_t N,
} }
// This happens in tests with small N, hence do not assert. // This happens in tests with small N, hence do not assert.
if (N % (np_multiple * num_packages) && N >= 128) { if (N % (np_multiple * num_packages) && N >= 128) {
static bool warned = false;
if (!warned) {
warned = true;
HWY_WARN( HWY_WARN(
"NPMultiple: N=%zu still not divisible by np_multiple=%zu " "NPMultiple: N=%zu still not divisible by np_multiple=%zu * "
"num_packages=%zu\n", "num_packages=%zu\n",
N, np_multiple, num_packages); N, np_multiple, num_packages);
}
np_multiple = nr; np_multiple = nr;
} }
} }

View File

@ -26,8 +26,7 @@
namespace gcpp { namespace gcpp {
static inline HWY_MAYBE_UNUSED MatStorageT<float> CreateInvTimescale( static inline HWY_MAYBE_UNUSED MatStorageT<float> CreateInvTimescale(
const Allocator& allocator, size_t qkv_dim, bool half_rope, size_t qkv_dim, bool half_rope, double base_frequency = 10000.0) {
double base_frequency = 10000.0) {
const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim; const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim;
MatStorageT<float> inv_timescale("inv_timescale", rope_dim / 2); MatStorageT<float> inv_timescale("inv_timescale", rope_dim / 2);
for (size_t dim = 0; dim < rope_dim / 2; ++dim) { for (size_t dim = 0; dim < rope_dim / 2; ++dim) {

View File

@ -386,8 +386,6 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
} }
void TestRopeAndMulBy() { void TestRopeAndMulBy() {
const Allocator& allocator = ThreadingContext::Get().allocator;
ModelConfig config(Model::GEMMA2_9B, Type::kSFP, ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
ChooseWrapping(Model::GEMMA2_9B)); ChooseWrapping(Model::GEMMA2_9B));
int dim_qkv = config.layer_configs[0].qkv_dim; int dim_qkv = config.layer_configs[0].qkv_dim;
@ -410,7 +408,7 @@ void TestRopeAndMulBy() {
MatStorageT<float> kexpected("kexpected", dim_qkv); MatStorageT<float> kexpected("kexpected", dim_qkv);
MatStorageT<float> kactual("kactual", dim_qkv); MatStorageT<float> kactual("kactual", dim_qkv);
MatStorageT<float> inv_timescale = CreateInvTimescale( MatStorageT<float> inv_timescale = CreateInvTimescale(
allocator, config.layer_configs[0].qkv_dim, config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope); config.layer_configs[0].post_qk == PostQKType::HalfRope);
// Assert VectorizedRope computation is same as regular rope at different pos. // Assert VectorizedRope computation is same as regular rope at different pos.
for (int pos = 1; pos < 500; pos++) { for (int pos = 1; pos < 500; pos++) {

View File

@ -14,6 +14,7 @@
// limitations under the License. // limitations under the License.
#include <cstdio> #include <cstdio>
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
@ -40,25 +41,31 @@ class PaliGemmaTest : public ::testing::Test {
protected: protected:
void InitVit(const std::string& path); void InitVit(const std::string& path);
std::string GemmaReply(const std::string& prompt_text) const; std::string GemmaReply(const std::string& prompt_text) const;
void TestQuestions(const char* kQA[][2], size_t num_questions); void TestQuestion(const char* question, const char* expected_substring);
ImageTokens image_tokens_; std::unique_ptr<ImageTokens> image_tokens_;
std::vector<uint8_t*> image_row_ptrs_;
}; };
void PaliGemmaTest::InitVit(const std::string& path) { void PaliGemmaTest::InitVit(const std::string& path) {
ASSERT_NE(s_env->GetGemma(), nullptr); ASSERT_NE(s_env->GetGemma(), nullptr);
const Allocator& allocator = s_env->Env().ctx.allocator;
const Gemma& gemma = *(s_env->GetGemma()); const Gemma& gemma = *(s_env->GetGemma());
image_tokens_ = ImageTokens( const ModelConfig& config = gemma.GetModelConfig();
allocator, Extents2D(gemma.GetModelConfig().vit_config.seq_len, image_tokens_ = std::make_unique<ImageTokens>(
gemma.GetModelConfig().model_dim)); "image", Extents2D(config.vit_config.seq_len, config.model_dim),
MatPadding::kPacked);
image_row_ptrs_.resize(image_tokens_->Rows());
for (size_t r = 0; r < image_tokens_->Rows(); ++r) {
image_row_ptrs_[r] = image_tokens_->RowBytes(r);
}
image_tokens_->AttachRowPtrs(image_row_ptrs_.data());
Image image; Image image;
HWY_ASSERT(gemma.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA); HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(path)); HWY_ASSERT(image.ReadPPM(path));
const size_t image_size = gemma.GetModelConfig().vit_config.image_size; const size_t image_size = config.vit_config.image_size;
image.Resize(image_size, image_size); image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0}; RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0};
gemma.GenerateImageTokens(runtime_config, image, image_tokens_); gemma.GenerateImageTokens(runtime_config, image, *image_tokens_);
} }
std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
@ -67,7 +74,7 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
RuntimeConfig runtime_config = {.max_generated_tokens = 512, RuntimeConfig runtime_config = {.max_generated_tokens = 512,
.gen = &s_env->MutableGen(), .gen = &s_env->MutableGen(),
.verbosity = 0}; .verbosity = 0};
runtime_config.image_tokens = &image_tokens_; runtime_config.image_tokens = image_tokens_.get();
size_t abs_pos = 0; size_t abs_pos = 0;
std::string mutable_prompt = prompt_text; std::string mutable_prompt = prompt_text;
std::vector<int> tokens = s_env->WrapAndTokenize(mutable_prompt); std::vector<int> tokens = s_env->WrapAndTokenize(mutable_prompt);
@ -79,7 +86,7 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
return true; return true;
}; };
runtime_config.stream_token = stream_token, runtime_config.stream_token = stream_token,
tokens.insert(tokens.begin(), image_tokens_.BatchSize(), 0); tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
size_t num_tokens = tokens.size(); size_t num_tokens = tokens.size();
size_t prefix_end = num_tokens; size_t prefix_end = num_tokens;
runtime_config.prefill_tbatch_size = num_tokens; runtime_config.prefill_tbatch_size = num_tokens;
@ -89,39 +96,29 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
return response; return response;
} }
void PaliGemmaTest::TestQuestions(const char* kQA[][2], size_t num_questions) { void PaliGemmaTest::TestQuestion(const char* question,
const char* expected_substring) {
ASSERT_NE(s_env->GetGemma(), nullptr); ASSERT_NE(s_env->GetGemma(), nullptr);
std::string path = "paligemma/testdata/image.ppm"; std::string path = "paligemma/testdata/image.ppm";
InitVit(path); InitVit(path);
for (size_t i = 0; i < num_questions; ++i) { const std::string reply = GemmaReply(question);
fprintf(stderr, "Question %zu\n\n", i + 1); fprintf(stderr, "'%s'\n\n", reply.c_str());
std::string response = GemmaReply(kQA[i][0]); EXPECT_TRUE(reply.find(expected_substring) != std::string::npos); // NOLINT
fprintf(stderr, "'%s'\n\n", response.c_str());
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
}
} }
TEST_F(PaliGemmaTest, General) { TEST_F(PaliGemmaTest, QueryObjects) {
ASSERT_NE(s_env->GetGemma(), nullptr); ASSERT_NE(s_env->GetGemma(), nullptr);
static const char* kQA_2_3B_pt_448[][2] = { const char* question = "answer en What objects are in the image?";
{"describe this image", "The Grossmünster in Zürich"}, const char* expected_substring = "Building, Tower"; // 3B PT 224, 10B Mix 224
{"describe image briefly", "The Grossmünster"}, const Model model = s_env->GetGemma()->GetModelConfig().model;
{"answer en What objects are in the image?", "Building, Tower"}, if (model == Model::PALIGEMMA2_3B_448) {
{"segment water", "<loc1023> water"}, expected_substring = "Lake.";
}; } else if (model == Model::PALIGEMMA2_3B_224) {
const char* (*qa)[2]; expected_substring = "Cloud, Water.";
size_t num; } else if (model == Model::PALIGEMMA2_10B_224) {
switch (s_env->GetGemma()->GetModelConfig().model) { expected_substring = "Building.";
case Model::PALIGEMMA2_3B_448:
qa = kQA_2_3B_pt_448;
num = sizeof(kQA_2_3B_pt_448) / sizeof(kQA_2_3B_pt_448[0]);
break;
default:
FAIL() << "Unsupported model: "
<< s_env->GetGemma()->GetModelConfig().display_name;
break;
} }
TestQuestions(qa, num); TestQuestion(question, expected_substring);
} }
} // namespace } // namespace