// Patch summary: unified git diff touching four files under paligemma/. It
// moves the PaliGemma image-prefill and reply logic out of the test fixture
// into a reusable helper class.
//
//  * BUILD.bazel: adds a `paligemma_helper` cc_library built from
//    paligemma_helper.cc / paligemma_helper.h, and adds ":paligemma_helper" and
//    "//devtools/build/runtime:get_runfiles_dir" to the deps of `paligemma_test`.
//
//  * paligemma_helper.h / paligemma_helper.cc (new files): declare and define
//    gcpp::PaliGemmaHelper, which stores the GemmaEnv* passed to its explicit
//    constructor.
//      - InitVit(path): asserts that the env has a Gemma and that its config
//        uses PromptWrapping::PALIGEMMA; allocates image_tokens_ with extents
//        (vit_config.seq_len, model_dim); reads a PPM image from `path`
//        (asserting success); resizes it to vit_config.image_size on both
//        sides; then calls Gemma::GenerateImageTokens into image_tokens_.
//      - GemmaReply(prompt_text): reseeds the env's generator with 0x12345678;
//        tokenizes a copy of the prompt via WrapAndTokenize; inserts
//        image_tokens_->Rows() zero-valued tokens at the front; calls Generate
//        with max_generated_tokens = 512, prefill_tbatch_size = tokens.size()
//        and prefix_end = tokens.size(); returns the concatenation of every
//        decoded streamed token.
//
//  * paligemma_test.cc: deletes the fixture's own InitVit / GemmaReply (the
//    same bodies, written against s_env) and makes TestQuestion construct a
//    PaliGemmaHelper from s_env and call InitVit / GemmaReply on it. It also
//    drops the "compression/types.h" and "hwy/base.h" includes and adds
//    "paligemma/paligemma_helper.h".
//
// NOTE(review): this patch text is flattened (newlines appear collapsed into
//   spaces), so it presumably cannot be applied as-is; re-export it from the
//   original commit before use.
// NOTE(review): several `#include` lines have no target, and
//   `std::make_unique(`, `std::vector{token}`, `std::vector tokens` and
//   `std::unique_ptr image_tokens_` carry no template arguments. This looks like
//   angle-bracket text lost in extraction - confirm against the original commit.
// NOTE(review): GemmaReply dereferences image_tokens_, which is only assigned
//   in InitVit; callers presumably must call InitVit first - confirm, or guard
//   with an assert.
// NOTE(review): the constructor definition in the header ends with `{};`; the
//   trailing `;` is redundant.
// NOTE(review): paligemma_helper.cc uses `Image`, but no image header is among
//   the visible quoted includes; presumably it arrives transitively - confirm.
// NOTE(review): verify that "//devtools/build/runtime:get_runfiles_dir"
//   resolves in every build configuration this BUILD file is used in.
diff --git a/paligemma/BUILD.bazel b/paligemma/BUILD.bazel index d60c745..27bf062 100644 --- a/paligemma/BUILD.bazel +++ b/paligemma/BUILD.bazel @@ -29,6 +29,24 @@ cc_test( ], ) +cc_library( + name = "paligemma_helper", + srcs = ["paligemma_helper.cc"], + hdrs = ["paligemma_helper.h"], + deps = [ + ":image", + "//:allocator", + "//:benchmark_helper", + "//:configs", + "//:gemma_args", + "//:gemma_lib", + "//compression:types", + "//io", + "@highway//:hwy", + "@highway//:profiler", + ], +) + cc_test( name = "paligemma_test", srcs = ["paligemma_test.cc"], @@ -39,6 +57,8 @@ cc_test( "no_tap", ], deps = [ + ":paligemma_helper", + "//devtools/build/runtime:get_runfiles_dir", "@googletest//:gtest_main", # buildcleaner: keep "//:allocator", "//:benchmark_helper", diff --git a/paligemma/paligemma_helper.cc b/paligemma/paligemma_helper.cc new file mode 100644 index 0000000..d5255c1 --- /dev/null +++ b/paligemma/paligemma_helper.cc @@ -0,0 +1,68 @@ +#include "paligemma/paligemma_helper.h" +#include +#include +#include +#include +#include "compression/types.h" +#include "evals/benchmark_helper.h" +#include "gemma/configs.h" +#include "gemma/gemma.h" +#include "util/allocator.h" +#include "hwy/base.h" + + +namespace gcpp { + +void PaliGemmaHelper::InitVit(const std::string& path) { + HWY_ASSERT(env_->GetGemma() != nullptr); + const Gemma& gemma = *(env_->GetGemma()); + const ModelConfig& config = gemma.GetModelConfig(); + HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA); + + image_tokens_ = std::make_unique( + "image", Extents2D(config.vit_config.seq_len, config.model_dim), + MatPadding::kPacked); + image_tokens_->AllocateAndAttachRowPtrs(env_->Env().row_ptrs); + Image image; + HWY_ASSERT(image.ReadPPM(path)); + const size_t image_size = config.vit_config.image_size; + image.Resize(image_size, image_size); + RuntimeConfig runtime_config = {.gen = &env_->MutableGen(), + .verbosity = 0}; + gemma.GenerateImageTokens(runtime_config, env_->MutableKVCache().SeqLen(), 
image, *image_tokens_); +} + +std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const { + const Gemma& model = *(env_->GetGemma()); + env_->MutableGen().seed(0x12345678); + + std::string response; + auto stream_token = [&](int token, float) { + std::string token_text; + HWY_ASSERT( + model.Tokenizer().Decode(std::vector{token}, &token_text)); + response += token_text; + return true; + }; + + std::string mutable_prompt = prompt_text; + std::vector tokens = env_->WrapAndTokenize(mutable_prompt); + tokens.insert(tokens.begin(), image_tokens_->Rows(), 0); + + RuntimeConfig runtime_config = {.max_generated_tokens = 512, + // PrefixLM sees/attends to all tokens. + .prefill_tbatch_size = tokens.size(), + .gen = &env_->MutableGen(), + .verbosity = 0, + .stream_token = stream_token, + .image_tokens = image_tokens_.get()}; + + const size_t prefix_end = tokens.size(); + TimingInfo timing_info = {.verbosity = 0}; + model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end, + env_->MutableKVCache(), timing_info); + return response; +} + +} // namespace gcpp diff --git a/paligemma/paligemma_helper.h b/paligemma/paligemma_helper.h new file mode 100644 index 0000000..4994c43 --- /dev/null +++ b/paligemma/paligemma_helper.h @@ -0,0 +1,25 @@ +#ifndef THIRD_PARTY_GEMMA_CPP_PALIGEMMA_PALIGEMMA_HELPER_H_ +#define THIRD_PARTY_GEMMA_CPP_PALIGEMMA_PALIGEMMA_HELPER_H_ + +#include +#include +#include "evals/benchmark_helper.h" +#include "gemma/gemma_args.h" + +namespace gcpp { + +class PaliGemmaHelper { + public: + explicit PaliGemmaHelper(GemmaEnv* env) : env_(env) {}; + + void InitVit(const std::string& path); + std::string GemmaReply(const std::string& prompt_text) const; + + private: + std::unique_ptr image_tokens_; + GemmaEnv* env_; +}; + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_PALIGEMMA_PALIGEMMA_HELPER_H_ diff --git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc index e883379..56f618f 100644 --- a/paligemma/paligemma_test.cc +++ 
b/paligemma/paligemma_test.cc @@ -17,16 +17,14 @@ #include #include -#include -#include "compression/types.h" #include "evals/benchmark_helper.h" #include "gemma/configs.h" #include "gemma/gemma.h" #include "io/io.h" #include "util/allocator.h" -#include "hwy/base.h" #include "hwy/tests/hwy_gtest.h" +#include "paligemma/paligemma_helper.h" // This test can be run manually with the downloaded PaliGemma weights. // It should pass for `paligemma-3b-mix-224` and `paligemma2-3b-pt-448`. @@ -41,63 +39,13 @@ GemmaEnv* s_env = nullptr; class PaliGemmaTest : public ::testing::Test { protected: - void InitVit(const std::string& path) { - ASSERT_NE(s_env->GetGemma(), nullptr); - const Gemma& gemma = *(s_env->GetGemma()); - const ModelConfig& config = gemma.GetModelConfig(); - HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA); - - image_tokens_ = std::make_unique( - "image", Extents2D(config.vit_config.seq_len, config.model_dim), - MatPadding::kPacked); - image_tokens_->AllocateAndAttachRowPtrs(s_env->Env().row_ptrs); - Image image; - HWY_ASSERT(image.ReadPPM(path)); - const size_t image_size = config.vit_config.image_size; - image.Resize(image_size, image_size); - RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), - .verbosity = 0}; - gemma.GenerateImageTokens(runtime_config, s_env->MutableKVCache().SeqLen(), - image, *image_tokens_); - } - - std::string GemmaReply(const std::string& prompt_text) const { - const Gemma& model = *(s_env->GetGemma()); - s_env->MutableGen().seed(0x12345678); - - std::string response; - auto stream_token = [&](int token, float) { - std::string token_text; - HWY_ASSERT( - model.Tokenizer().Decode(std::vector{token}, &token_text)); - response += token_text; - return true; - }; - - std::string mutable_prompt = prompt_text; - std::vector tokens = s_env->WrapAndTokenize(mutable_prompt); - tokens.insert(tokens.begin(), image_tokens_->Rows(), 0); - - RuntimeConfig runtime_config = {.max_generated_tokens = 512, - // PrefixLM sees/attends 
to all tokens. - .prefill_tbatch_size = tokens.size(), - .gen = &s_env->MutableGen(), - .verbosity = 0, - .stream_token = stream_token, - .image_tokens = image_tokens_.get()}; - - const size_t prefix_end = tokens.size(); - TimingInfo timing_info = {.verbosity = 0}; - model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end, - s_env->MutableKVCache(), timing_info); - return response; - } - void TestQuestion(const char* question, const char* expected_substring) { ASSERT_NE(s_env->GetGemma(), nullptr); std::string path = "paligemma/testdata/image.ppm"; - InitVit(path); - const std::string reply = GemmaReply(question); + + PaliGemmaHelper paligemma_helper(s_env); + paligemma_helper.InitVit(path); + const std::string reply = paligemma_helper.GemmaReply(question); fprintf(stderr, "'%s'\n\n", reply.c_str()); EXPECT_TRUE(reply.find(expected_substring) != std::string::npos); // NOLINT }