Gemma CPP: move PaliGemma tests' helper to a separate class

This makes it possible to use PaliGemma functionality directly for inference by providing only the tokenizer and weight paths.

Added @mukundagg to allowed authors list.

PiperOrigin-RevId: 772705238
This commit is contained in:
Mukund Aggarwal 2025-06-17 18:36:52 -07:00 committed by Copybara-Service
parent f2adbfbcab
commit 606e22155a
4 changed files with 118 additions and 57 deletions

View File

@ -29,6 +29,24 @@ cc_test(
], ],
) )
# Library target for paligemma_helper.{h,cc}, which declare and define the
# PaliGemmaHelper class used by paligemma_test.
cc_library(
name = "paligemma_helper",
srcs = ["paligemma_helper.cc"],
hdrs = ["paligemma_helper.h"],
deps = [
":image",
"//:allocator",
"//:benchmark_helper",
"//:configs",
"//:gemma_args",
"//:gemma_lib",
"//compression:types",
"//io",
"@highway//:hwy",
"@highway//:profiler",
],
)
cc_test( cc_test(
name = "paligemma_test", name = "paligemma_test",
srcs = ["paligemma_test.cc"], srcs = ["paligemma_test.cc"],
@ -39,6 +57,8 @@ cc_test(
"no_tap", "no_tap",
], ],
deps = [ deps = [
":paligemma_helper",
"//devtools/build/runtime:get_runfiles_dir",
"@googletest//:gtest_main", # buildcleaner: keep "@googletest//:gtest_main", # buildcleaner: keep
"//:allocator", "//:allocator",
"//:benchmark_helper", "//:benchmark_helper",

View File

@ -0,0 +1,68 @@
#include "paligemma/paligemma_helper.h"
#include <cstddef>
#include <memory>
#include <string>
#include <vector>
#include "compression/types.h"
#include "evals/benchmark_helper.h"
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "util/allocator.h"
#include "hwy/base.h"
namespace gcpp {
void PaliGemmaHelper::InitVit(const std::string& path) {
HWY_ASSERT(env_->GetGemma() != nullptr);
const Gemma& gemma = *(env_->GetGemma());
const ModelConfig& config = gemma.GetModelConfig();
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
image_tokens_ = std::make_unique<ImageTokens>(
"image", Extents2D(config.vit_config.seq_len, config.model_dim),
MatPadding::kPacked);
image_tokens_->AllocateAndAttachRowPtrs(env_->Env().row_ptrs);
Image image;
HWY_ASSERT(image.ReadPPM(path));
const size_t image_size = config.vit_config.image_size;
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {.gen = &env_->MutableGen(),
.verbosity = 0};
gemma.GenerateImageTokens(runtime_config, env_->MutableKVCache().SeqLen(),
image, *image_tokens_);
}
// Generates a text reply to `prompt_text`, conditioned on the image tokens
// produced by the most recent InitVit() call, and returns the concatenation of
// all decoded output tokens.
//
// Preconditions (checked with HWY_ASSERT): the env has a loaded model and
// InitVit() has been called, so `image_tokens_` is populated.
//
// Although this method is const, it mutates state reachable through `env_`:
// the env's generator is re-seeded with a fixed constant on every call, and
// the env's KV cache is passed to Generate().
std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const {
  // Fail with a clear assertion rather than dereferencing null below. This
  // mirrors the model check in InitVit() and guards against calling
  // GemmaReply() before InitVit().
  HWY_ASSERT(env_->GetGemma() != nullptr);
  HWY_ASSERT(image_tokens_ != nullptr);

  const Gemma& model = *(env_->GetGemma());
  env_->MutableGen().seed(0x12345678);

  std::string response;
  // Decodes each streamed token individually and appends its text to
  // `response`. Always returns true (presumably "continue generating";
  // confirm against the stream_token contract).
  auto stream_token = [&](int token, float) {
    std::string token_text;
    HWY_ASSERT(
        model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
    response += token_text;
    return true;
  };

  // WrapAndTokenize is given a local copy because `prompt_text` is const.
  std::string mutable_prompt = prompt_text;
  std::vector<int> tokens = env_->WrapAndTokenize(mutable_prompt);
  // Prepend one placeholder token (value 0) per image-token row so the prompt
  // has a slot for each image token.
  tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);

  RuntimeConfig runtime_config = {.max_generated_tokens = 512,
                                  // PrefixLM sees/attends to all tokens.
                                  .prefill_tbatch_size = tokens.size(),
                                  .gen = &env_->MutableGen(),
                                  .verbosity = 0,
                                  .stream_token = stream_token,
                                  .image_tokens = image_tokens_.get()};
  // The entire prompt (image placeholders + text) is the prefix.
  const size_t prefix_end = tokens.size();
  TimingInfo timing_info = {.verbosity = 0};
  model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end,
                 env_->MutableKVCache(), timing_info);
  return response;
}
} // namespace gcpp

View File

@ -0,0 +1,25 @@
#ifndef THIRD_PARTY_GEMMA_CPP_PALIGEMMA_PALIGEMMA_HELPER_H_
#define THIRD_PARTY_GEMMA_CPP_PALIGEMMA_PALIGEMMA_HELPER_H_
#include <memory>
#include <string>
#include "evals/benchmark_helper.h"
#include "gemma/gemma_args.h"
namespace gcpp {
// Convenience wrapper for running PaliGemma (image + text) inference on top of
// an existing GemmaEnv.
//
// Typical usage:
//   PaliGemmaHelper helper(&env);
//   helper.InitVit("image.ppm");
//   std::string reply = helper.GemmaReply("describe this image");
class PaliGemmaHelper {
 public:
  // Stores `env` without taking ownership; it must remain valid for as long as
  // this helper is used.
  explicit PaliGemmaHelper(GemmaEnv* env) : env_(env) {}

  // Reads the PPM image at `path` and computes its image tokens, replacing any
  // previously computed tokens. Must be called before GemmaReply().
  void InitVit(const std::string& path);

  // Returns the model's decoded text reply to `prompt_text`, using the image
  // tokens from the last InitVit() call. Const with respect to this object,
  // but uses the mutable generator and KV cache of the env.
  std::string GemmaReply(const std::string& prompt_text) const;

 private:
  // Image tokens produced by InitVit(); null until InitVit() has been called.
  std::unique_ptr<ImageTokens> image_tokens_;
  // Not owned.
  GemmaEnv* env_;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_PALIGEMMA_PALIGEMMA_HELPER_H_

View File

@ -17,16 +17,14 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector>
#include "compression/types.h"
#include "evals/benchmark_helper.h" #include "evals/benchmark_helper.h"
#include "gemma/configs.h" #include "gemma/configs.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "io/io.h" #include "io/io.h"
#include "util/allocator.h" #include "util/allocator.h"
#include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/hwy_gtest.h"
#include "paligemma/paligemma_helper.h"
// This test can be run manually with the downloaded PaliGemma weights. // This test can be run manually with the downloaded PaliGemma weights.
// It should pass for `paligemma-3b-mix-224` and `paligemma2-3b-pt-448`. // It should pass for `paligemma-3b-mix-224` and `paligemma2-3b-pt-448`.
@ -41,63 +39,13 @@ GemmaEnv* s_env = nullptr;
class PaliGemmaTest : public ::testing::Test { class PaliGemmaTest : public ::testing::Test {
protected: protected:
void InitVit(const std::string& path) {
ASSERT_NE(s_env->GetGemma(), nullptr);
const Gemma& gemma = *(s_env->GetGemma());
const ModelConfig& config = gemma.GetModelConfig();
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
image_tokens_ = std::make_unique<ImageTokens>(
"image", Extents2D(config.vit_config.seq_len, config.model_dim),
MatPadding::kPacked);
image_tokens_->AllocateAndAttachRowPtrs(s_env->Env().row_ptrs);
Image image;
HWY_ASSERT(image.ReadPPM(path));
const size_t image_size = config.vit_config.image_size;
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(),
.verbosity = 0};
gemma.GenerateImageTokens(runtime_config, s_env->MutableKVCache().SeqLen(),
image, *image_tokens_);
}
std::string GemmaReply(const std::string& prompt_text) const {
const Gemma& model = *(s_env->GetGemma());
s_env->MutableGen().seed(0x12345678);
std::string response;
auto stream_token = [&](int token, float) {
std::string token_text;
HWY_ASSERT(
model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
response += token_text;
return true;
};
std::string mutable_prompt = prompt_text;
std::vector<int> tokens = s_env->WrapAndTokenize(mutable_prompt);
tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
// PrefixLM sees/attends to all tokens.
.prefill_tbatch_size = tokens.size(),
.gen = &s_env->MutableGen(),
.verbosity = 0,
.stream_token = stream_token,
.image_tokens = image_tokens_.get()};
const size_t prefix_end = tokens.size();
TimingInfo timing_info = {.verbosity = 0};
model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end,
s_env->MutableKVCache(), timing_info);
return response;
}
void TestQuestion(const char* question, const char* expected_substring) { void TestQuestion(const char* question, const char* expected_substring) {
ASSERT_NE(s_env->GetGemma(), nullptr); ASSERT_NE(s_env->GetGemma(), nullptr);
std::string path = "paligemma/testdata/image.ppm"; std::string path = "paligemma/testdata/image.ppm";
InitVit(path);
const std::string reply = GemmaReply(question); PaliGemmaHelper paligemma_helper(s_env);
paligemma_helper.InitVit(path);
const std::string reply = paligemma_helper.GemmaReply(question);
fprintf(stderr, "'%s'\n\n", reply.c_str()); fprintf(stderr, "'%s'\n\n", reply.c_str());
EXPECT_TRUE(reply.find(expected_substring) != std::string::npos); // NOLINT EXPECT_TRUE(reply.find(expected_substring) != std::string::npos); // NOLINT
} }