mirror of https://github.com/google/gemma.cpp.git
Factor out addition of ViTConfig to a ModelConfig.
Use ModelConfig values for ImageTokens. Output timing info for image token generation. Add a method to copy image data into Image class directly. Minor changes: pipe ModelTraining to more places. PiperOrigin-RevId: 690572283
This commit is contained in:
parent
19cfe14c76
commit
583bd93e9a
|
|
@ -233,6 +233,7 @@ cc_library(
|
|||
deps = [
|
||||
":common",
|
||||
"//compression:io",
|
||||
"//compression:sfp",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@com_google_sentencepiece//:sentencepiece_processor",
|
||||
|
|
@ -376,6 +377,7 @@ cc_binary(
|
|||
":gemma_lib",
|
||||
":threading",
|
||||
# Placeholder for internal dep, do not remove.,
|
||||
"//compression:sfp",
|
||||
"//paligemma:image",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
|
|
|
|||
|
|
@ -40,8 +40,9 @@
|
|||
#include <vector>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/common.h" // Model
|
||||
#include "compression/shared.h" // ModelTraining
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/common.h" // Model
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/args.h"
|
||||
|
|
@ -73,9 +74,8 @@ struct Args : public ArgsBase<Args> {
|
|||
|
||||
// Returns error string or nullptr if OK.
|
||||
const char* Validate() {
|
||||
ModelTraining model_training;
|
||||
if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
|
||||
model_training)) {
|
||||
model_training_)) {
|
||||
return err;
|
||||
}
|
||||
if (const char* err = ParseType(weight_type_str, weight_type_)) {
|
||||
|
|
@ -127,10 +127,12 @@ struct Args : public ArgsBase<Args> {
|
|||
|
||||
// Uninitialized before Validate, must call after that.
|
||||
gcpp::Model ModelType() const { return model_type_; }
|
||||
gcpp::ModelTraining ModelTrainingType() const { return model_training_; }
|
||||
gcpp::Type WeightType() const { return weight_type_; }
|
||||
|
||||
private:
|
||||
Model model_type_;
|
||||
ModelTraining model_training_;
|
||||
Type weight_type_;
|
||||
};
|
||||
|
||||
|
|
@ -210,10 +212,10 @@ namespace gcpp {
|
|||
|
||||
void Run(Args& args) {
|
||||
hwy::ThreadPool pool(args.num_threads);
|
||||
const Model model_type = args.ModelType();
|
||||
if (model_type == Model::PALIGEMMA_224) {
|
||||
if (args.ModelTrainingType() == ModelTraining::PALIGEMMA) {
|
||||
HWY_ABORT("PaliGemma is not supported in compress_weights.");
|
||||
}
|
||||
const Model model_type = args.ModelType();
|
||||
const Type weight_type = args.WeightType();
|
||||
switch (weight_type) {
|
||||
case Type::kF32:
|
||||
|
|
|
|||
|
|
@ -198,17 +198,15 @@ static ModelConfig ConfigGriffin2B() {
|
|||
return config;
|
||||
}
|
||||
|
||||
static ModelConfig ConfigPaliGemma_224() {
|
||||
ModelConfig config = ConfigGemma2B();
|
||||
config.model_name = "PaliGemma_224";
|
||||
config.model = Model::PALIGEMMA_224;
|
||||
// Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config.
|
||||
static void AddVitConfig(ModelConfig& config) {
|
||||
config.vit_model_dim = 1152;
|
||||
config.vocab_size = 256000 + 1024 + 128; // = 257152
|
||||
config.image_size = 224;
|
||||
config.patch_width = 14;
|
||||
const size_t num_patches = config.image_size / config.patch_width;
|
||||
config.vit_seq_len = num_patches * num_patches;
|
||||
LayerConfig layer_config = {
|
||||
LayerConfig vit_layer_config = {
|
||||
.model_dim = config.vit_model_dim,
|
||||
.ff_hidden_dim = 4304,
|
||||
.heads = 16,
|
||||
|
|
@ -217,8 +215,15 @@ static ModelConfig ConfigPaliGemma_224() {
|
|||
.ff_biases = true,
|
||||
.type = LayerAttentionType::kVit,
|
||||
};
|
||||
config.vit_layer_configs = {27, layer_config};
|
||||
config.vit_layer_configs = {27, vit_layer_config};
|
||||
config.num_vit_scales = 4 * config.vit_layer_configs.size();
|
||||
}
|
||||
|
||||
static ModelConfig ConfigPaliGemma_224() {
|
||||
ModelConfig config = ConfigGemma2B();
|
||||
config.model_name = "PaliGemma_224";
|
||||
config.model = Model::PALIGEMMA_224;
|
||||
AddVitConfig(config);
|
||||
return config;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
|
|
|||
21
gemma/run.cc
21
gemma/run.cc
|
|
@ -16,12 +16,14 @@
|
|||
// Command line text interface to gemma.
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
// Placeholder for internal header, do not modify.
|
||||
#include "compression/shared.h" // ModelTraining
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h" // Gemma
|
||||
|
|
@ -90,13 +92,22 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
|
||||
const bool have_image = !args.image_file.path.empty();
|
||||
Image image;
|
||||
ImageTokens image_tokens(256, 2048);
|
||||
std::unique_ptr<ImageTokens> image_tokens;
|
||||
if (have_image) {
|
||||
HWY_ASSERT(model.Info().model == Model::PALIGEMMA_224);
|
||||
image_tokens = std::make_unique<ImageTokens>(
|
||||
model.GetModelConfig().vit_seq_len, model.GetModelConfig().model_dim);
|
||||
HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
|
||||
HWY_ASSERT(image.ReadPPM(args.image_file.path));
|
||||
image.Resize();
|
||||
RuntimeConfig runtime_config = {.verbosity = verbosity, .gen = &gen};
|
||||
model.GenerateImageTokens(runtime_config, image, image_tokens);
|
||||
double image_tokens_start = hwy::platform::Now();
|
||||
model.GenerateImageTokens(runtime_config, image, *image_tokens);
|
||||
if (verbosity >= 1) {
|
||||
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
|
||||
fprintf(stderr,
|
||||
"\n\n[ Timing info ] Image token generation took: %d ms\n",
|
||||
static_cast<int>(image_tokens_duration * 1000));
|
||||
}
|
||||
}
|
||||
|
||||
// callback function invoked for each generated token.
|
||||
|
|
@ -170,8 +181,8 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
args.CopyTo(runtime_config);
|
||||
size_t prefix_end = 0;
|
||||
if (have_image) {
|
||||
runtime_config.image_tokens = &image_tokens;
|
||||
prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
|
||||
runtime_config.image_tokens = image_tokens.get();
|
||||
prompt.insert(prompt.begin(), image_tokens->BatchSize(), 0);
|
||||
prompt_size = prompt.size();
|
||||
// The end of the prefix for prefix-LM style attention in Paligemma.
|
||||
// See Figure 2 of https://arxiv.org/abs/2407.07726.
|
||||
|
|
|
|||
|
|
@ -21,9 +21,10 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/common.h" // Wrap
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
#include "compression/io.h" // Path
|
||||
#include "compression/shared.h" // ModelTraining
|
||||
#include "gemma/common.h" // Wrap
|
||||
#include "hwy/base.h" // HWY_ASSERT
|
||||
#include "hwy/profiler.h"
|
||||
// copybara:import_next_line:sentencepiece
|
||||
#include "src/sentencepiece_processor.h"
|
||||
|
|
@ -109,7 +110,7 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
|
|||
}
|
||||
|
||||
// PaliGemma separator. The SEP token "\n" is always tokenized separately.
|
||||
if (info.model == Model::PALIGEMMA_224) {
|
||||
if (info.training == ModelTraining::PALIGEMMA) {
|
||||
std::vector<int> sep_tokens;
|
||||
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
|
||||
tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
|
||||
|
|
|
|||
|
|
@ -38,11 +38,11 @@ cc_test(
|
|||
"no_tap",
|
||||
],
|
||||
deps = [
|
||||
"@googletest//:gtest_main",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//:benchmark_helper",
|
||||
"//:common",
|
||||
"//:gemma_lib",
|
||||
"//:tokenizer",
|
||||
"//compression:sfp",
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@
|
|||
#include <cstdio>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
|
@ -152,6 +153,27 @@ bool Image::ReadPPM(const hwy::Span<const char>& buf) {
|
|||
return true;
|
||||
}
|
||||
|
||||
void Image::Set(int width, int height, const float* data) {
|
||||
width_ = width;
|
||||
height_ = height;
|
||||
int num_elements = width * height * 3;
|
||||
data_.resize(num_elements);
|
||||
data_.assign(data, data + num_elements);
|
||||
float min_value = std::numeric_limits<float>::infinity();
|
||||
float max_value = -std::numeric_limits<float>::infinity();
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
if (data_[i] < min_value) min_value = data_[i];
|
||||
if (data_[i] > max_value) max_value = data_[i];
|
||||
}
|
||||
// -> out_min + (value - in_min) * (out_max - out_min) / (in_max - in_min)
|
||||
float in_range = max_value - min_value;
|
||||
if (in_range == 0.0f) in_range = 1.0f;
|
||||
float scale = 2.0f / in_range;
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
data_[i] = (data_[i] - min_value) * scale - 1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
void Image::Resize() {
|
||||
int new_width = 224;
|
||||
int new_height = kImageSize;
|
||||
|
|
|
|||
|
|
@ -35,6 +35,9 @@ class Image {
|
|||
// Reads PPM format (P6, binary) data from a hwy::Span, normalizes to [-1, 1].
|
||||
// Returns true on success.
|
||||
bool ReadPPM(const hwy::Span<const char>& buf);
|
||||
// Sets the image content to the given data. The data is copied and normalized
|
||||
// to [-1, 1]. The data is expected to be of size width * height * 3.
|
||||
void Set(int width, int height, const float* data);
|
||||
// Resizes to 224x224 (nearest-neighbor for now, bilinear or antialias would
|
||||
// be better).
|
||||
void Resize();
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/shared.h"
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h"
|
||||
|
|
@ -50,9 +51,10 @@ class PaliGemmaTest : public ::testing::Test {
|
|||
void PaliGemmaTest::InitVit(const std::string& path) {
|
||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||
Gemma& model = *(s_env->GetModel());
|
||||
image_tokens_ = std::make_unique<ImageTokens>(256, 2048);
|
||||
image_tokens_ = std::make_unique<ImageTokens>(
|
||||
model.GetModelConfig().vit_seq_len, model.GetModelConfig().model_dim);
|
||||
Image image;
|
||||
HWY_ASSERT(model.Info().model == Model::PALIGEMMA_224);
|
||||
HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
|
||||
HWY_ASSERT(image.ReadPPM(path));
|
||||
image.Resize();
|
||||
RuntimeConfig runtime_config = {.verbosity = 0, .gen = &s_env->MutableGen()};
|
||||
|
|
|
|||
Loading…
Reference in New Issue