mirror of https://github.com/google/gemma.cpp.git
Gemma CPP: move PaliGemma tests' helper to a separate class
This helps to be able to use PaliGemma functionalities directly for inference by just providing tokenizer and weight paths. Added @mukundagg to allowed authors list. PiperOrigin-RevId: 772705238
This commit is contained in:
parent
f2adbfbcab
commit
606e22155a
|
|
@ -29,6 +29,24 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "paligemma_helper",
|
||||
srcs = ["paligemma_helper.cc"],
|
||||
hdrs = ["paligemma_helper.h"],
|
||||
deps = [
|
||||
":image",
|
||||
"//:allocator",
|
||||
"//:benchmark_helper",
|
||||
"//:configs",
|
||||
"//:gemma_args",
|
||||
"//:gemma_lib",
|
||||
"//compression:types",
|
||||
"//io",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "paligemma_test",
|
||||
srcs = ["paligemma_test.cc"],
|
||||
|
|
@ -39,6 +57,8 @@ cc_test(
|
|||
"no_tap",
|
||||
],
|
||||
deps = [
|
||||
":paligemma_helper",
|
||||
"//devtools/build/runtime:get_runfiles_dir",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//:allocator",
|
||||
"//:benchmark_helper",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,68 @@
|
|||
#include "paligemma/paligemma_helper.h"
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "compression/types.h"
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
void PaliGemmaHelper::InitVit(const std::string& path) {
|
||||
HWY_ASSERT(env_->GetGemma() != nullptr);
|
||||
const Gemma& gemma = *(env_->GetGemma());
|
||||
const ModelConfig& config = gemma.GetModelConfig();
|
||||
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
|
||||
|
||||
image_tokens_ = std::make_unique<ImageTokens>(
|
||||
"image", Extents2D(config.vit_config.seq_len, config.model_dim),
|
||||
MatPadding::kPacked);
|
||||
image_tokens_->AllocateAndAttachRowPtrs(env_->Env().row_ptrs);
|
||||
Image image;
|
||||
HWY_ASSERT(image.ReadPPM(path));
|
||||
const size_t image_size = config.vit_config.image_size;
|
||||
image.Resize(image_size, image_size);
|
||||
RuntimeConfig runtime_config = {.gen = &env_->MutableGen(),
|
||||
.verbosity = 0};
|
||||
gemma.GenerateImageTokens(runtime_config, env_->MutableKVCache().SeqLen(),
|
||||
image, *image_tokens_);
|
||||
}
|
||||
|
||||
std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const {
|
||||
const Gemma& model = *(env_->GetGemma());
|
||||
env_->MutableGen().seed(0x12345678);
|
||||
|
||||
std::string response;
|
||||
auto stream_token = [&](int token, float) {
|
||||
std::string token_text;
|
||||
HWY_ASSERT(
|
||||
model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
response += token_text;
|
||||
return true;
|
||||
};
|
||||
|
||||
std::string mutable_prompt = prompt_text;
|
||||
std::vector<int> tokens = env_->WrapAndTokenize(mutable_prompt);
|
||||
tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
|
||||
|
||||
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
|
||||
// PrefixLM sees/attends to all tokens.
|
||||
.prefill_tbatch_size = tokens.size(),
|
||||
.gen = &env_->MutableGen(),
|
||||
.verbosity = 0,
|
||||
.stream_token = stream_token,
|
||||
.image_tokens = image_tokens_.get()};
|
||||
|
||||
const size_t prefix_end = tokens.size();
|
||||
TimingInfo timing_info = {.verbosity = 0};
|
||||
model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end,
|
||||
env_->MutableKVCache(), timing_info);
|
||||
return response;
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_PALIGEMMA_PALIGEMMA_HELPER_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_PALIGEMMA_PALIGEMMA_HELPER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
class PaliGemmaHelper {
|
||||
public:
|
||||
explicit PaliGemmaHelper(GemmaEnv* env) : env_(env) {};
|
||||
|
||||
void InitVit(const std::string& path);
|
||||
std::string GemmaReply(const std::string& prompt_text) const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<ImageTokens> image_tokens_;
|
||||
GemmaEnv* env_;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_PALIGEMMA_PALIGEMMA_HELPER_H_
|
||||
|
|
@ -17,16 +17,14 @@
|
|||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/types.h"
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "io/io.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
#include "paligemma/paligemma_helper.h"
|
||||
|
||||
// This test can be run manually with the downloaded PaliGemma weights.
|
||||
// It should pass for `paligemma-3b-mix-224` and `paligemma2-3b-pt-448`.
|
||||
|
|
@ -41,63 +39,13 @@ GemmaEnv* s_env = nullptr;
|
|||
|
||||
class PaliGemmaTest : public ::testing::Test {
|
||||
protected:
|
||||
void InitVit(const std::string& path) {
|
||||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||
const Gemma& gemma = *(s_env->GetGemma());
|
||||
const ModelConfig& config = gemma.GetModelConfig();
|
||||
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
|
||||
|
||||
image_tokens_ = std::make_unique<ImageTokens>(
|
||||
"image", Extents2D(config.vit_config.seq_len, config.model_dim),
|
||||
MatPadding::kPacked);
|
||||
image_tokens_->AllocateAndAttachRowPtrs(s_env->Env().row_ptrs);
|
||||
Image image;
|
||||
HWY_ASSERT(image.ReadPPM(path));
|
||||
const size_t image_size = config.vit_config.image_size;
|
||||
image.Resize(image_size, image_size);
|
||||
RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(),
|
||||
.verbosity = 0};
|
||||
gemma.GenerateImageTokens(runtime_config, s_env->MutableKVCache().SeqLen(),
|
||||
image, *image_tokens_);
|
||||
}
|
||||
|
||||
std::string GemmaReply(const std::string& prompt_text) const {
|
||||
const Gemma& model = *(s_env->GetGemma());
|
||||
s_env->MutableGen().seed(0x12345678);
|
||||
|
||||
std::string response;
|
||||
auto stream_token = [&](int token, float) {
|
||||
std::string token_text;
|
||||
HWY_ASSERT(
|
||||
model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
response += token_text;
|
||||
return true;
|
||||
};
|
||||
|
||||
std::string mutable_prompt = prompt_text;
|
||||
std::vector<int> tokens = s_env->WrapAndTokenize(mutable_prompt);
|
||||
tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
|
||||
|
||||
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
|
||||
// PrefixLM sees/attends to all tokens.
|
||||
.prefill_tbatch_size = tokens.size(),
|
||||
.gen = &s_env->MutableGen(),
|
||||
.verbosity = 0,
|
||||
.stream_token = stream_token,
|
||||
.image_tokens = image_tokens_.get()};
|
||||
|
||||
const size_t prefix_end = tokens.size();
|
||||
TimingInfo timing_info = {.verbosity = 0};
|
||||
model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end,
|
||||
s_env->MutableKVCache(), timing_info);
|
||||
return response;
|
||||
}
|
||||
|
||||
void TestQuestion(const char* question, const char* expected_substring) {
|
||||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||
std::string path = "paligemma/testdata/image.ppm";
|
||||
InitVit(path);
|
||||
const std::string reply = GemmaReply(question);
|
||||
|
||||
PaliGemmaHelper paligemma_helper(s_env);
|
||||
paligemma_helper.InitVit(path);
|
||||
const std::string reply = paligemma_helper.GemmaReply(question);
|
||||
fprintf(stderr, "'%s'\n\n", reply.c_str());
|
||||
EXPECT_TRUE(reply.find(expected_substring) != std::string::npos); // NOLINT
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue