Factor out the addition of the ViT config to a ModelConfig into AddVitConfig().

Use ModelConfig values for ImageTokens.
Output timing info for image token generation.
Add an Image::Set() method to copy image data into the Image class directly.
Minor changes: pipe ModelTraining through to more places.

PiperOrigin-RevId: 690572283
Daniel Keysers 2024-10-28 05:28:42 -07:00 committed by Copybara-Service
parent 19cfe14c76
commit 583bd93e9a
10 changed files with 74 additions and 25 deletions

View File

@@ -233,6 +233,7 @@ cc_library(
     deps = [
         ":common",
         "//compression:io",
+        "//compression:sfp",
         "@highway//:hwy",
         "@highway//:profiler",
         "@com_google_sentencepiece//:sentencepiece_processor",
@@ -376,6 +377,7 @@ cc_binary(
         ":gemma_lib",
         ":threading",
         # Placeholder for internal dep, do not remove.,
+        "//compression:sfp",
         "//paligemma:image",
         "@highway//:hwy",
         "@highway//:profiler",

View File

@@ -40,8 +40,9 @@
 #include <vector>

 #include "compression/compress.h"
-#include "compression/io.h"  // Path
-#include "gemma/common.h"  // Model
+#include "compression/shared.h"  // ModelTraining
+#include "compression/io.h"  // Path
+#include "gemma/common.h"  // Model
 #include "gemma/weights.h"
 #include "util/allocator.h"
 #include "util/args.h"
@@ -73,9 +74,8 @@ struct Args : public ArgsBase<Args> {

   // Returns error string or nullptr if OK.
   const char* Validate() {
-    ModelTraining model_training;
     if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
-                                                    model_training)) {
+                                                    model_training_)) {
       return err;
     }
     if (const char* err = ParseType(weight_type_str, weight_type_)) {
@@ -127,10 +127,12 @@ struct Args : public ArgsBase<Args> {

   // Uninitialized before Validate, must call after that.
   gcpp::Model ModelType() const { return model_type_; }
+  gcpp::ModelTraining ModelTrainingType() const { return model_training_; }
   gcpp::Type WeightType() const { return weight_type_; }

  private:
   Model model_type_;
+  ModelTraining model_training_;
   Type weight_type_;
 };
@@ -210,10 +212,10 @@ namespace gcpp {

 void Run(Args& args) {
   hwy::ThreadPool pool(args.num_threads);
-  const Model model_type = args.ModelType();
-  if (model_type == Model::PALIGEMMA_224) {
+  if (args.ModelTrainingType() == ModelTraining::PALIGEMMA) {
     HWY_ABORT("PaliGemma is not supported in compress_weights.");
   }
+  const Model model_type = args.ModelType();
   const Type weight_type = args.WeightType();
   switch (weight_type) {
     case Type::kF32:

View File

@@ -198,17 +198,15 @@ static ModelConfig ConfigGriffin2B() {
   return config;
 }

-static ModelConfig ConfigPaliGemma_224() {
-  ModelConfig config = ConfigGemma2B();
-  config.model_name = "PaliGemma_224";
-  config.model = Model::PALIGEMMA_224;
+// Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config.
+static void AddVitConfig(ModelConfig& config) {
   config.vit_model_dim = 1152;
   config.vocab_size = 256000 + 1024 + 128;  // = 257152
   config.image_size = 224;
   config.patch_width = 14;
   const size_t num_patches = config.image_size / config.patch_width;
   config.vit_seq_len = num_patches * num_patches;
-  LayerConfig layer_config = {
+  LayerConfig vit_layer_config = {
       .model_dim = config.vit_model_dim,
       .ff_hidden_dim = 4304,
       .heads = 16,
@@ -217,8 +215,15 @@ static ModelConfig ConfigPaliGemma_224() {
       .ff_biases = true,
       .type = LayerAttentionType::kVit,
   };
-  config.vit_layer_configs = {27, layer_config};
+  config.vit_layer_configs = {27, vit_layer_config};
   config.num_vit_scales = 4 * config.vit_layer_configs.size();
+}
+
+static ModelConfig ConfigPaliGemma_224() {
+  ModelConfig config = ConfigGemma2B();
+  config.model_name = "PaliGemma_224";
+  config.model = Model::PALIGEMMA_224;
+  AddVitConfig(config);
   return config;
 }
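
With the ViT settings factored into AddVitConfig, other PaliGemma-style configs could share them. A minimal sketch, assuming a hypothetical higher-resolution variant (ConfigPaliGemma_448 and its overrides are illustrative, not part of this change; it would sit alongside the existing configs in configs.cc):

static ModelConfig ConfigPaliGemma_448() {
  ModelConfig config = ConfigGemma2B();
  config.model_name = "PaliGemma_448";
  config.model = Model::PALIGEMMA_224;  // hypothetical; a real variant would need its own enum value
  AddVitConfig(config);
  // Override the resolution; the ViT sequence length grows to (448 / 14)^2 = 1024 patches.
  config.image_size = 448;
  const size_t num_patches = config.image_size / config.patch_width;
  config.vit_seq_len = num_patches * num_patches;
  return config;
}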

View File

@@ -20,6 +20,7 @@
 #include <stddef.h>

+#include <algorithm>
 #include <array>
 #include <string>
 #include <unordered_set>

View File

@@ -16,12 +16,14 @@
 // Command line text interface to gemma.

 #include <iostream>
+#include <memory>
 #include <random>
 #include <string>
 #include <string_view>
 #include <vector>

 // Placeholder for internal header, do not modify.
+#include "compression/shared.h"  // ModelTraining
 #include "evals/benchmark_helper.h"
 #include "gemma/common.h"
 #include "gemma/gemma.h"  // Gemma
@@ -90,13 +92,22 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
   const bool have_image = !args.image_file.path.empty();
   Image image;
-  ImageTokens image_tokens(256, 2048);
+  std::unique_ptr<ImageTokens> image_tokens;
   if (have_image) {
-    HWY_ASSERT(model.Info().model == Model::PALIGEMMA_224);
+    image_tokens = std::make_unique<ImageTokens>(
+        model.GetModelConfig().vit_seq_len, model.GetModelConfig().model_dim);
+    HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
     HWY_ASSERT(image.ReadPPM(args.image_file.path));
     image.Resize();
     RuntimeConfig runtime_config = {.verbosity = verbosity, .gen = &gen};
-    model.GenerateImageTokens(runtime_config, image, image_tokens);
+    double image_tokens_start = hwy::platform::Now();
+    model.GenerateImageTokens(runtime_config, image, *image_tokens);
+    if (verbosity >= 1) {
+      double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
+      fprintf(stderr,
+              "\n\n[ Timing info ] Image token generation took: %d ms\n",
+              static_cast<int>(image_tokens_duration * 1000));
+    }
   }

   // callback function invoked for each generated token.
@@ -170,8 +181,8 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
   args.CopyTo(runtime_config);
   size_t prefix_end = 0;
   if (have_image) {
-    runtime_config.image_tokens = &image_tokens;
-    prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
+    runtime_config.image_tokens = image_tokens.get();
+    prompt.insert(prompt.begin(), image_tokens->BatchSize(), 0);
     prompt_size = prompt.size();
     // The end of the prefix for prefix-LM style attention in Paligemma.
     // See Figure 2 of https://arxiv.org/abs/2407.07726.
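
The timing added above is plain wall-clock measurement around GenerateImageTokens using Highway's hwy::platform::Now(). A standalone sketch of the same pattern (the MeasureSeconds helper is illustrative, not part of this change):

#include <cstdio>

#include "hwy/timer.h"  // hwy::platform::Now

// Runs fn() once and returns the elapsed wall-clock time in seconds.
template <typename Func>
double MeasureSeconds(const Func& fn) {
  const double start = hwy::platform::Now();
  fn();
  return hwy::platform::Now() - start;
}

// Hypothetical usage, mirroring the diff above:
//   const double seconds = MeasureSeconds([&] {
//     model.GenerateImageTokens(runtime_config, image, *image_tokens);
//   });
//   fprintf(stderr, "Image token generation took: %d ms\n",
//           static_cast<int>(seconds * 1000));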

View File

@@ -21,9 +21,10 @@
 #include <string>
 #include <vector>

 #include "compression/io.h"  // Path
-#include "gemma/common.h"  // Wrap
-#include "hwy/base.h"  // HWY_ASSERT
+#include "compression/shared.h"  // ModelTraining
+#include "gemma/common.h"  // Wrap
+#include "hwy/base.h"  // HWY_ASSERT
 #include "hwy/profiler.h"
 // copybara:import_next_line:sentencepiece
 #include "src/sentencepiece_processor.h"
@@ -109,7 +110,7 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
   }

   // PaliGemma separator. The SEP token "\n" is always tokenized separately.
-  if (info.model == Model::PALIGEMMA_224) {
+  if (info.training == ModelTraining::PALIGEMMA) {
     std::vector<int> sep_tokens;
     HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
     tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());

View File

@@ -38,11 +38,11 @@ cc_test(
         "no_tap",
     ],
     deps = [
-        "@googletest//:gtest_main",
+        "@googletest//:gtest_main",  # buildcleaner: keep
         "//:benchmark_helper",
         "//:common",
         "//:gemma_lib",
-        "//:tokenizer",
+        "//compression:sfp",
         "@highway//:hwy",
         "@highway//:hwy_test_util",
     ],

View File

@@ -24,6 +24,7 @@
 #include <cstdio>
 #include <fstream>
 #include <iostream>
+#include <limits>
 #include <string>
 #include <utility>
 #include <vector>
@@ -152,6 +153,27 @@ bool Image::ReadPPM(const hwy::Span<const char>& buf) {
   return true;
 }

+void Image::Set(int width, int height, const float* data) {
+  width_ = width;
+  height_ = height;
+  int num_elements = width * height * 3;
+  data_.resize(num_elements);
+  data_.assign(data, data + num_elements);
+  float min_value = std::numeric_limits<float>::infinity();
+  float max_value = -std::numeric_limits<float>::infinity();
+  for (int i = 0; i < num_elements; ++i) {
+    if (data_[i] < min_value) min_value = data_[i];
+    if (data_[i] > max_value) max_value = data_[i];
+  }
+  // -> out_min + (value - in_min) * (out_max - out_min) / (in_max - in_min)
+  float in_range = max_value - min_value;
+  if (in_range == 0.0f) in_range = 1.0f;
+  float scale = 2.0f / in_range;
+  for (int i = 0; i < num_elements; ++i) {
+    data_[i] = (data_[i] - min_value) * scale - 1.0f;
+  }
+}
+
 void Image::Resize() {
   int new_width = 224;
   int new_height = kImageSize;
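
As a quick check of the normalization in Image::Set: with a minimum of 0 and a maximum of 255 the scale is 2 / 255, so 0 maps to -1, 127.5 to 0, and 255 to +1. A minimal usage sketch (the header path and gcpp namespace are assumed here; the buffer is synthetic):

#include "paligemma/image.h"  // assumed path for gcpp::Image

// Fills a 2x1 RGB image from a raw float buffer; Set() copies the data and
// rescales it to [-1, 1] based on its min/max.
void SetExample() {
  const float raw[2 * 1 * 3] = {0.0f, 63.75f, 127.5f, 191.25f, 255.0f, 255.0f};
  gcpp::Image image;
  image.Set(/*width=*/2, /*height=*/1, raw);
  // After Set(): 0 -> -1.0, 63.75 -> -0.5, 127.5 -> 0.0, 255 -> +1.0.
}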

View File

@@ -35,6 +35,9 @@ class Image {
   // Reads PPM format (P6, binary) data from a hwy::Span, normalizes to [-1, 1].
   // Returns true on success.
   bool ReadPPM(const hwy::Span<const char>& buf);
+  // Sets the image content to the given data. The data is copied and normalized
+  // to [-1, 1]. The data is expected to be of size width * height * 3.
+  void Set(int width, int height, const float* data);
   // Resizes to 224x224 (nearest-neighbor for now, bilinear or antialias would
   // be better).
   void Resize();

View File

@@ -18,6 +18,7 @@
 #include <string>
 #include <vector>

+#include "compression/shared.h"
 #include "evals/benchmark_helper.h"
 #include "gemma/common.h"
 #include "gemma/gemma.h"
@@ -50,9 +51,10 @@ class PaliGemmaTest : public ::testing::Test {
 void PaliGemmaTest::InitVit(const std::string& path) {
   ASSERT_NE(s_env->GetModel(), nullptr);
   Gemma& model = *(s_env->GetModel());
-  image_tokens_ = std::make_unique<ImageTokens>(256, 2048);
+  image_tokens_ = std::make_unique<ImageTokens>(
+      model.GetModelConfig().vit_seq_len, model.GetModelConfig().model_dim);
   Image image;
-  HWY_ASSERT(model.Info().model == Model::PALIGEMMA_224);
+  HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
   HWY_ASSERT(image.ReadPPM(path));
   image.Resize();
   RuntimeConfig runtime_config = {.verbosity = 0, .gen = &s_env->MutableGen()};
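
For context, the hard-coded (256, 2048) used before corresponds to vit_seq_len = (224 / 14)^2 = 256 patches and the 2048-dim Gemma-2B backbone; reading both values from the ModelConfig keeps the token buffer in sync with whichever model is loaded. A minimal sketch of the pattern (the MakeImageTokens helper is illustrative, not part of this change):

#include <memory>

#include "gemma/gemma.h"  // Gemma, ImageTokens (as included in run.cc above)

// Sizes the ImageTokens buffer from the loaded model's config instead of
// hard-coded constants.
std::unique_ptr<gcpp::ImageTokens> MakeImageTokens(gcpp::Gemma& model) {
  const auto& config = model.GetModelConfig();
  return std::make_unique<gcpp::ImageTokens>(config.vit_seq_len,
                                             config.model_dim);
}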