Factor out addition of ViTConfig to a ModelConfig.

Use ModelConfig values for ImageTokens.
Output timing info for image token generation.
Add a method to copy image data into Image class directly.
Minor changes: pipe ModelTraining to more places.

PiperOrigin-RevId: 690572283
This commit is contained in:
Daniel Keysers 2024-10-28 05:28:42 -07:00 committed by Copybara-Service
parent 19cfe14c76
commit 583bd93e9a
10 changed files with 74 additions and 25 deletions

View File

@ -233,6 +233,7 @@ cc_library(
deps = [
":common",
"//compression:io",
"//compression:sfp",
"@highway//:hwy",
"@highway//:profiler",
"@com_google_sentencepiece//:sentencepiece_processor",
@ -376,6 +377,7 @@ cc_binary(
":gemma_lib",
":threading",
# Placeholder for internal dep, do not remove.,
"//compression:sfp",
"//paligemma:image",
"@highway//:hwy",
"@highway//:profiler",

View File

@ -40,8 +40,9 @@
#include <vector>
#include "compression/compress.h"
#include "compression/io.h" // Path
#include "gemma/common.h" // Model
#include "compression/shared.h" // ModelTraining
#include "compression/io.h" // Path
#include "gemma/common.h" // Model
#include "gemma/weights.h"
#include "util/allocator.h"
#include "util/args.h"
@ -73,9 +74,8 @@ struct Args : public ArgsBase<Args> {
// Returns error string or nullptr if OK.
const char* Validate() {
ModelTraining model_training;
if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
model_training)) {
model_training_)) {
return err;
}
if (const char* err = ParseType(weight_type_str, weight_type_)) {
@ -127,10 +127,12 @@ struct Args : public ArgsBase<Args> {
// Uninitialized before Validate, must call after that.
gcpp::Model ModelType() const { return model_type_; }
gcpp::ModelTraining ModelTrainingType() const { return model_training_; }
gcpp::Type WeightType() const { return weight_type_; }
private:
Model model_type_;
ModelTraining model_training_;
Type weight_type_;
};
@ -210,10 +212,10 @@ namespace gcpp {
void Run(Args& args) {
hwy::ThreadPool pool(args.num_threads);
const Model model_type = args.ModelType();
if (model_type == Model::PALIGEMMA_224) {
if (args.ModelTrainingType() == ModelTraining::PALIGEMMA) {
HWY_ABORT("PaliGemma is not supported in compress_weights.");
}
const Model model_type = args.ModelType();
const Type weight_type = args.WeightType();
switch (weight_type) {
case Type::kF32:

View File

@ -198,17 +198,15 @@ static ModelConfig ConfigGriffin2B() {
return config;
}
static ModelConfig ConfigPaliGemma_224() {
ModelConfig config = ConfigGemma2B();
config.model_name = "PaliGemma_224";
config.model = Model::PALIGEMMA_224;
// Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config.
static void AddVitConfig(ModelConfig& config) {
config.vit_model_dim = 1152;
config.vocab_size = 256000 + 1024 + 128; // = 257152
config.image_size = 224;
config.patch_width = 14;
const size_t num_patches = config.image_size / config.patch_width;
config.vit_seq_len = num_patches * num_patches;
LayerConfig layer_config = {
LayerConfig vit_layer_config = {
.model_dim = config.vit_model_dim,
.ff_hidden_dim = 4304,
.heads = 16,
@ -217,8 +215,15 @@ static ModelConfig ConfigPaliGemma_224() {
.ff_biases = true,
.type = LayerAttentionType::kVit,
};
config.vit_layer_configs = {27, layer_config};
config.vit_layer_configs = {27, vit_layer_config};
config.num_vit_scales = 4 * config.vit_layer_configs.size();
}
static ModelConfig ConfigPaliGemma_224() {
ModelConfig config = ConfigGemma2B();
config.model_name = "PaliGemma_224";
config.model = Model::PALIGEMMA_224;
AddVitConfig(config);
return config;
}

View File

@ -20,6 +20,7 @@
#include <stddef.h>
#include <algorithm>
#include <array>
#include <string>
#include <unordered_set>

View File

@ -16,12 +16,14 @@
// Command line text interface to gemma.
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <string_view>
#include <vector>
// Placeholder for internal header, do not modify.
#include "compression/shared.h" // ModelTraining
#include "evals/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/gemma.h" // Gemma
@ -90,13 +92,22 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
const bool have_image = !args.image_file.path.empty();
Image image;
ImageTokens image_tokens(256, 2048);
std::unique_ptr<ImageTokens> image_tokens;
if (have_image) {
HWY_ASSERT(model.Info().model == Model::PALIGEMMA_224);
image_tokens = std::make_unique<ImageTokens>(
model.GetModelConfig().vit_seq_len, model.GetModelConfig().model_dim);
HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(args.image_file.path));
image.Resize();
RuntimeConfig runtime_config = {.verbosity = verbosity, .gen = &gen};
model.GenerateImageTokens(runtime_config, image, image_tokens);
double image_tokens_start = hwy::platform::Now();
model.GenerateImageTokens(runtime_config, image, *image_tokens);
if (verbosity >= 1) {
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
fprintf(stderr,
"\n\n[ Timing info ] Image token generation took: %d ms\n",
static_cast<int>(image_tokens_duration * 1000));
}
}
// callback function invoked for each generated token.
@ -170,8 +181,8 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
args.CopyTo(runtime_config);
size_t prefix_end = 0;
if (have_image) {
runtime_config.image_tokens = &image_tokens;
prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
runtime_config.image_tokens = image_tokens.get();
prompt.insert(prompt.begin(), image_tokens->BatchSize(), 0);
prompt_size = prompt.size();
// The end of the prefix for prefix-LM style attention in Paligemma.
// See Figure 2 of https://arxiv.org/abs/2407.07726.

View File

@ -21,9 +21,10 @@
#include <string>
#include <vector>
#include "compression/io.h" // Path
#include "gemma/common.h" // Wrap
#include "hwy/base.h" // HWY_ASSERT
#include "compression/io.h" // Path
#include "compression/shared.h" // ModelTraining
#include "gemma/common.h" // Wrap
#include "hwy/base.h" // HWY_ASSERT
#include "hwy/profiler.h"
// copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h"
@ -109,7 +110,7 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
}
// PaliGemma separator. The SEP token "\n" is always tokenized separately.
if (info.model == Model::PALIGEMMA_224) {
if (info.training == ModelTraining::PALIGEMMA) {
std::vector<int> sep_tokens;
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());

View File

@ -38,11 +38,11 @@ cc_test(
"no_tap",
],
deps = [
"@googletest//:gtest_main",
"@googletest//:gtest_main", # buildcleaner: keep
"//:benchmark_helper",
"//:common",
"//:gemma_lib",
"//:tokenizer",
"//compression:sfp",
"@highway//:hwy",
"@highway//:hwy_test_util",
],

View File

@ -24,6 +24,7 @@
#include <cstdio>
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include <utility>
#include <vector>
@ -152,6 +153,27 @@ bool Image::ReadPPM(const hwy::Span<const char>& buf) {
return true;
}
void Image::Set(int width, int height, const float* data) {
width_ = width;
height_ = height;
int num_elements = width * height * 3;
data_.resize(num_elements);
data_.assign(data, data + num_elements);
float min_value = std::numeric_limits<float>::infinity();
float max_value = -std::numeric_limits<float>::infinity();
for (int i = 0; i < num_elements; ++i) {
if (data_[i] < min_value) min_value = data_[i];
if (data_[i] > max_value) max_value = data_[i];
}
// -> out_min + (value - in_min) * (out_max - out_min) / (in_max - in_min)
float in_range = max_value - min_value;
if (in_range == 0.0f) in_range = 1.0f;
float scale = 2.0f / in_range;
for (int i = 0; i < num_elements; ++i) {
data_[i] = (data_[i] - min_value) * scale - 1.0f;
}
}
void Image::Resize() {
int new_width = 224;
int new_height = kImageSize;

View File

@ -35,6 +35,9 @@ class Image {
// Reads PPM format (P6, binary) data from a hwy::Span, normalizes to [-1, 1].
// Returns true on success.
bool ReadPPM(const hwy::Span<const char>& buf);
// Sets the image content to the given data. The data is copied and normalized
// to [-1, 1]. The data is expected to be of size width * height * 3.
void Set(int width, int height, const float* data);
// Resizes to 224x224 (nearest-neighbor for now, bilinear or antialias would
// be better).
void Resize();

View File

@ -18,6 +18,7 @@
#include <string>
#include <vector>
#include "compression/shared.h"
#include "evals/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/gemma.h"
@ -50,9 +51,10 @@ class PaliGemmaTest : public ::testing::Test {
void PaliGemmaTest::InitVit(const std::string& path) {
ASSERT_NE(s_env->GetModel(), nullptr);
Gemma& model = *(s_env->GetModel());
image_tokens_ = std::make_unique<ImageTokens>(256, 2048);
image_tokens_ = std::make_unique<ImageTokens>(
model.GetModelConfig().vit_seq_len, model.GetModelConfig().model_dim);
Image image;
HWY_ASSERT(model.Info().model == Model::PALIGEMMA_224);
HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(path));
image.Resize();
RuntimeConfig runtime_config = {.verbosity = 0, .gen = &s_env->MutableGen()};