gemma.cpp/paligemma/paligemma_helper.cc

64 lines
2.2 KiB
C++

#include "paligemma/paligemma_helper.h"
#include <cstddef>
#include <memory>
#include <string>
#include <vector>
#include "compression/types.h"
#include "evals/benchmark_helper.h"
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "util/allocator.h"
#include "hwy/base.h"
namespace gcpp {
void PaliGemmaHelper::InitVit(const std::string& path) {
HWY_ASSERT(env_->GetGemma() != nullptr);
const Gemma& gemma = *(env_->GetGemma());
const ModelConfig& config = gemma.Config();
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA);
image_tokens_ = std::make_unique<ImageTokens>(
"image", Extents2D(config.vit_config.seq_len, config.model_dim),
env_->Env().ctx.allocator, MatPadding::kPacked);
image_tokens_->AllocateAndAttachRowPtrs(env_->Env().row_ptrs);
Image image;
HWY_ASSERT(image.ReadPPM(path));
const size_t image_size = config.vit_config.image_size;
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {.verbosity = 0};
gemma.GenerateImageTokens(runtime_config, env_->MutableKVCache().SeqLen(),
image, *image_tokens_, env_->MutableEnv());
}
std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const {
const Gemma& model = *(env_->GetGemma());
std::string response;
auto stream_token = [&](int token, float) {
std::string token_text;
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
response += token_text;
return true;
};
std::vector<int> tokens = env_->WrapAndTokenize(prompt_text);
tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
// PrefixLM sees/attends to all tokens.
.prefill_tbatch_size = tokens.size(),
.verbosity = 0,
.stream_token = stream_token,
.image_tokens = image_tokens_.get()};
const size_t prefix_end = tokens.size();
TimingInfo timing_info = {.verbosity = 0};
model.Generate(runtime_config, tokens, /*pos=*/0, prefix_end,
env_->MutableKVCache(), env_->MutableEnv(), timing_info);
return response;
}
} // namespace gcpp