Fix paligemma_test, refs #588

Detect PaliGemma models from layer names
Remove unused allocator arg from CreateInvTimescale
matmul: only warn once about dim divisibility
Print config also in tests if --verbosity >= 2
PiperOrigin-RevId: 766605131
This commit is contained in:
Jan Wassenberg 2025-06-03 04:44:50 -07:00 committed by Copybara-Service
parent 209009b57e
commit 839a642992
10 changed files with 91 additions and 71 deletions

View File

@ -47,13 +47,16 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
}
}
GemmaEnv::GemmaEnv(const LoaderArgs& loader,
const ThreadingArgs& threading_args,
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference)
: env_(MakeMatMulEnv(threading_args)), gemma_(loader, env_) {
: env_(MakeMatMulEnv(threading)), gemma_(loader, env_) {
const ModelConfig& config = gemma_.GetModelConfig();
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.push_back(
KVCache(gemma_.GetModelConfig(), inference.prefill_tbatch_size));
kv_caches_.push_back(KVCache(config, inference.prefill_tbatch_size));
if (inference.verbosity >= 2) {
ShowConfig(loader, threading, inference, config);
}
InitGenerator(inference, gen_);

View File

@ -40,9 +40,12 @@ struct Activations {
x("x", Extents2D(batch_size, config.model_dim), pad_),
// `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
// and does not use an external KV cache.
q("q", Extents2D(batch_size, config.vocab_size == 0 ?
layer_config.heads * 3 * layer_config.qkv_dim :
layer_config.heads * layer_config.qkv_dim), pad_),
q("q",
Extents2D(batch_size,
config.vocab_size == 0
? layer_config.heads * 3 * layer_config.qkv_dim
: layer_config.heads * layer_config.qkv_dim),
pad_),
logits("logits", Extents2D(batch_size, config.vocab_size), pad_),
pre_att_rms_out("pre_att_rms_out",
@ -74,12 +77,12 @@ struct Activations {
"griffin_mul",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_),
inv_timescale(CreateInvTimescale(
ThreadingContext::Get().allocator, layer_config.qkv_dim,
inv_timescale(
CreateInvTimescale(layer_config.qkv_dim,
layer_config.post_qk == PostQKType::HalfRope)),
inv_timescale_global(CreateInvTimescale(
ThreadingContext::Get().allocator, layer_config.qkv_dim,
layer_config.post_qk == PostQKType::HalfRope, 1000000.0)),
layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope,
1000000.0)),
env(env) {
HWY_ASSERT(batch_size != 0);

View File

@ -720,9 +720,16 @@ Model DeduceModel(size_t layers, int layer_types) {
if (layer_types & kDeducedGriffin) return Model::GRIFFIN_2B;
if (layer_types & kDeducedViT) return Model::GEMMA3_1B;
return Model::GEMMA2_2B;
case 27:
return (layer_types & kDeduced448) ? Model::PALIGEMMA2_3B_448
: Model::PALIGEMMA2_3B_224;
case 34:
return Model::GEMMA3_4B;
case 42:
if (layer_types & kDeducedViT) {
return (layer_types & kDeduced448) ? Model::PALIGEMMA2_10B_448
: Model::PALIGEMMA2_10B_224;
}
return Model::GEMMA2_9B;
case 46:
return Model::GEMMA2_27B;
@ -735,12 +742,6 @@ Model DeduceModel(size_t layers, int layer_types) {
/*
return Model::GEMMA2_772M;
return Model::PALIGEMMA2_772M_224;
return Model::PALIGEMMA_224;
return Model::PALIGEMMA_448;
return Model::PALIGEMMA2_3B_224;
return Model::PALIGEMMA2_3B_448;
return Model::PALIGEMMA2_10B_224;
return Model::PALIGEMMA2_10B_448;
*/
default:
HWY_WARN("Failed to deduce model type from layer count %zu types %x.",

View File

@ -473,6 +473,7 @@ ModelConfig GetVitConfig(const ModelConfig& config);
enum DeducedLayerTypes {
kDeducedGriffin = 1,
kDeducedViT = 2,
kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224.
};
// layer_types is one or more of `DeducedLayerTypes`.

View File

@ -152,7 +152,7 @@ class TypePrefix {
const uint64_t bytes = bytes_[type_idx];
if (bytes == 0) continue;
const double percent = 100.0 * bytes / total_bytes_;
fprintf(stderr, "%zu blob bytes (%.2f%%) of %s\n",
fprintf(stderr, "%12zu blob bytes (%5.2f%%) of %4s\n",
static_cast<size_t>(bytes), percent, TypeName(type));
}
}
@ -185,15 +185,22 @@ static size_t DeduceNumLayers(const KeyVec& keys) {
// Looks for known tensor names associated with model families.
// This works with or without type prefixes because it searches for substrings.
static int DeduceLayerTypes(const KeyVec& keys) {
static int DeduceLayerTypes(const BlobReader& reader) {
int layer_types = 0;
for (const std::string& key : keys) {
for (size_t key_idx = 0; key_idx < reader.Keys().size(); ++key_idx) {
const std::string& key = reader.Keys()[key_idx];
if (key.find("gr_conv_w") != std::string::npos) { // NOLINT
return kDeducedGriffin;
}
if (key.find("qkv_einsum_w") != std::string::npos) { // NOLINT
if (key.find("qkv_ein_w") != std::string::npos) { // NOLINT
layer_types |= kDeducedViT;
}
if (key.find("img_pos_emb") != std::string::npos) { // NOLINT
// About 5.88 elements per pixel; assume at least bf16.
if (reader.Range(key_idx).bytes > 448 * 448 * 5 * sizeof(BF16)) {
layer_types |= kDeduced448;
}
}
}
return layer_types;
}
@ -211,7 +218,7 @@ static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
// Always deduce so we can verify it against the config we read.
const size_t layers = DeduceNumLayers(reader.Keys());
const int layer_types = DeduceLayerTypes(reader.Keys());
const int layer_types = DeduceLayerTypes(reader);
const Model deduced_model = DeduceModel(layers, layer_types);
ModelConfig config;

View File

@ -19,6 +19,10 @@
#include <vector>
#pragma push_macro("PROFILER_ENABLED")
#undef PROFILER_ENABLED
#define PROFILER_ENABLED 0
#include "compression/types.h"
#include "ops/matmul.h" // IWYU pragma: export
#include "util/allocator.h"
@ -1322,7 +1326,8 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
if constexpr (HWY_IS_DEBUG_BUILD) {
fprintf(stderr,
"MatMul perf warning: setting row pointers because "
"C.AttachRowPtrs() was not called.\n");
"%s.AttachRowPtrs() was not called.\n",
C.Name());
}
HWY_DASSERT(C.HasPtr());
for (size_t r = 0; r < C.Rows(); ++r) {
@ -1416,3 +1421,5 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
HWY_AFTER_NAMESPACE();
#endif // NOLINT
#pragma pop_macro("PROFILER_ENABLED")

View File

@ -401,10 +401,14 @@ static size_t NPMultiple(const Allocator& allocator, size_t N,
}
// This happens in tests with small N, hence do not assert.
if (N % (np_multiple * num_packages) && N >= 128) {
static bool warned = false;
if (!warned) {
warned = true;
HWY_WARN(
"NPMultiple: N=%zu still not divisible by np_multiple=%zu "
"NPMultiple: N=%zu still not divisible by np_multiple=%zu * "
"num_packages=%zu\n",
N, np_multiple, num_packages);
}
np_multiple = nr;
}
}

View File

@ -26,8 +26,7 @@
namespace gcpp {
static inline HWY_MAYBE_UNUSED MatStorageT<float> CreateInvTimescale(
const Allocator& allocator, size_t qkv_dim, bool half_rope,
double base_frequency = 10000.0) {
size_t qkv_dim, bool half_rope, double base_frequency = 10000.0) {
const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim;
MatStorageT<float> inv_timescale("inv_timescale", rope_dim / 2);
for (size_t dim = 0; dim < rope_dim / 2; ++dim) {

View File

@ -386,8 +386,6 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
}
void TestRopeAndMulBy() {
const Allocator& allocator = ThreadingContext::Get().allocator;
ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
ChooseWrapping(Model::GEMMA2_9B));
int dim_qkv = config.layer_configs[0].qkv_dim;
@ -410,7 +408,7 @@ void TestRopeAndMulBy() {
MatStorageT<float> kexpected("kexpected", dim_qkv);
MatStorageT<float> kactual("kactual", dim_qkv);
MatStorageT<float> inv_timescale = CreateInvTimescale(
allocator, config.layer_configs[0].qkv_dim,
config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope);
// Assert VectorizedRope computation is same as regular rope at different pos.
for (int pos = 1; pos < 500; pos++) {

View File

@ -14,6 +14,7 @@
// limitations under the License.
#include <cstdio>
#include <memory>
#include <string>
#include <vector>
@ -40,25 +41,31 @@ class PaliGemmaTest : public ::testing::Test {
protected:
void InitVit(const std::string& path);
std::string GemmaReply(const std::string& prompt_text) const;
void TestQuestions(const char* kQA[][2], size_t num_questions);
void TestQuestion(const char* question, const char* expected_substring);
ImageTokens image_tokens_;
std::unique_ptr<ImageTokens> image_tokens_;
std::vector<uint8_t*> image_row_ptrs_;
};
void PaliGemmaTest::InitVit(const std::string& path) {
ASSERT_NE(s_env->GetGemma(), nullptr);
const Allocator& allocator = s_env->Env().ctx.allocator;
const Gemma& gemma = *(s_env->GetGemma());
image_tokens_ = ImageTokens(
allocator, Extents2D(gemma.GetModelConfig().vit_config.seq_len,
gemma.GetModelConfig().model_dim));
const ModelConfig& config = gemma.GetModelConfig();
image_tokens_ = std::make_unique<ImageTokens>(
"image", Extents2D(config.vit_config.seq_len, config.model_dim),
MatPadding::kPacked);
image_row_ptrs_.resize(image_tokens_->Rows());
for (size_t r = 0; r < image_tokens_->Rows(); ++r) {
image_row_ptrs_[r] = image_tokens_->RowBytes(r);
}
image_tokens_->AttachRowPtrs(image_row_ptrs_.data());
Image image;
HWY_ASSERT(gemma.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA);
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(path));
const size_t image_size = gemma.GetModelConfig().vit_config.image_size;
const size_t image_size = config.vit_config.image_size;
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0};
gemma.GenerateImageTokens(runtime_config, image, image_tokens_);
gemma.GenerateImageTokens(runtime_config, image, *image_tokens_);
}
std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
@ -67,7 +74,7 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
.gen = &s_env->MutableGen(),
.verbosity = 0};
runtime_config.image_tokens = &image_tokens_;
runtime_config.image_tokens = image_tokens_.get();
size_t abs_pos = 0;
std::string mutable_prompt = prompt_text;
std::vector<int> tokens = s_env->WrapAndTokenize(mutable_prompt);
@ -79,7 +86,7 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
return true;
};
runtime_config.stream_token = stream_token,
tokens.insert(tokens.begin(), image_tokens_.BatchSize(), 0);
tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
size_t num_tokens = tokens.size();
size_t prefix_end = num_tokens;
runtime_config.prefill_tbatch_size = num_tokens;
@ -89,39 +96,29 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
return response;
}
void PaliGemmaTest::TestQuestions(const char* kQA[][2], size_t num_questions) {
void PaliGemmaTest::TestQuestion(const char* question,
const char* expected_substring) {
ASSERT_NE(s_env->GetGemma(), nullptr);
std::string path = "paligemma/testdata/image.ppm";
InitVit(path);
for (size_t i = 0; i < num_questions; ++i) {
fprintf(stderr, "Question %zu\n\n", i + 1);
std::string response = GemmaReply(kQA[i][0]);
fprintf(stderr, "'%s'\n\n", response.c_str());
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
}
const std::string reply = GemmaReply(question);
fprintf(stderr, "'%s'\n\n", reply.c_str());
EXPECT_TRUE(reply.find(expected_substring) != std::string::npos); // NOLINT
}
TEST_F(PaliGemmaTest, General) {
TEST_F(PaliGemmaTest, QueryObjects) {
ASSERT_NE(s_env->GetGemma(), nullptr);
static const char* kQA_2_3B_pt_448[][2] = {
{"describe this image", "The Grossmünster in Zürich"},
{"describe image briefly", "The Grossmünster"},
{"answer en What objects are in the image?", "Building, Tower"},
{"segment water", "<loc1023> water"},
};
const char* (*qa)[2];
size_t num;
switch (s_env->GetGemma()->GetModelConfig().model) {
case Model::PALIGEMMA2_3B_448:
qa = kQA_2_3B_pt_448;
num = sizeof(kQA_2_3B_pt_448) / sizeof(kQA_2_3B_pt_448[0]);
break;
default:
FAIL() << "Unsupported model: "
<< s_env->GetGemma()->GetModelConfig().display_name;
break;
const char* question = "answer en What objects are in the image?";
const char* expected_substring = "Building, Tower"; // 3B PT 224, 10B Mix 224
const Model model = s_env->GetGemma()->GetModelConfig().model;
if (model == Model::PALIGEMMA2_3B_448) {
expected_substring = "Lake.";
} else if (model == Model::PALIGEMMA2_3B_224) {
expected_substring = "Cloud, Water.";
} else if (model == Model::PALIGEMMA2_10B_224) {
expected_substring = "Building.";
}
TestQuestions(qa, num);
TestQuestion(question, expected_substring);
}
} // namespace