mirror of https://github.com/google/gemma.cpp.git
Factor out addition of ViTConfig to a ModelConfig.
Use ModelConfig values for ImageTokens. Output timing info for image token generation. Add a method to copy image data into Image class directly. Minor changes: pipe ModelTraining to more places. PiperOrigin-RevId: 690572283
This commit is contained in:
parent
19cfe14c76
commit
583bd93e9a
|
|
@ -233,6 +233,7 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":common",
|
":common",
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
|
"//compression:sfp",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
"@highway//:profiler",
|
"@highway//:profiler",
|
||||||
"@com_google_sentencepiece//:sentencepiece_processor",
|
"@com_google_sentencepiece//:sentencepiece_processor",
|
||||||
|
|
@ -376,6 +377,7 @@ cc_binary(
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
":threading",
|
":threading",
|
||||||
# Placeholder for internal dep, do not remove.,
|
# Placeholder for internal dep, do not remove.,
|
||||||
|
"//compression:sfp",
|
||||||
"//paligemma:image",
|
"//paligemma:image",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
"@highway//:profiler",
|
"@highway//:profiler",
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "compression/compress.h"
|
#include "compression/compress.h"
|
||||||
|
#include "compression/shared.h" // ModelTraining
|
||||||
#include "compression/io.h" // Path
|
#include "compression/io.h" // Path
|
||||||
#include "gemma/common.h" // Model
|
#include "gemma/common.h" // Model
|
||||||
#include "gemma/weights.h"
|
#include "gemma/weights.h"
|
||||||
|
|
@ -73,9 +74,8 @@ struct Args : public ArgsBase<Args> {
|
||||||
|
|
||||||
// Returns error string or nullptr if OK.
|
// Returns error string or nullptr if OK.
|
||||||
const char* Validate() {
|
const char* Validate() {
|
||||||
ModelTraining model_training;
|
|
||||||
if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
|
if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
|
||||||
model_training)) {
|
model_training_)) {
|
||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
if (const char* err = ParseType(weight_type_str, weight_type_)) {
|
if (const char* err = ParseType(weight_type_str, weight_type_)) {
|
||||||
|
|
@ -127,10 +127,12 @@ struct Args : public ArgsBase<Args> {
|
||||||
|
|
||||||
// Uninitialized before Validate, must call after that.
|
// Uninitialized before Validate, must call after that.
|
||||||
gcpp::Model ModelType() const { return model_type_; }
|
gcpp::Model ModelType() const { return model_type_; }
|
||||||
|
gcpp::ModelTraining ModelTrainingType() const { return model_training_; }
|
||||||
gcpp::Type WeightType() const { return weight_type_; }
|
gcpp::Type WeightType() const { return weight_type_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Model model_type_;
|
Model model_type_;
|
||||||
|
ModelTraining model_training_;
|
||||||
Type weight_type_;
|
Type weight_type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -210,10 +212,10 @@ namespace gcpp {
|
||||||
|
|
||||||
void Run(Args& args) {
|
void Run(Args& args) {
|
||||||
hwy::ThreadPool pool(args.num_threads);
|
hwy::ThreadPool pool(args.num_threads);
|
||||||
const Model model_type = args.ModelType();
|
if (args.ModelTrainingType() == ModelTraining::PALIGEMMA) {
|
||||||
if (model_type == Model::PALIGEMMA_224) {
|
|
||||||
HWY_ABORT("PaliGemma is not supported in compress_weights.");
|
HWY_ABORT("PaliGemma is not supported in compress_weights.");
|
||||||
}
|
}
|
||||||
|
const Model model_type = args.ModelType();
|
||||||
const Type weight_type = args.WeightType();
|
const Type weight_type = args.WeightType();
|
||||||
switch (weight_type) {
|
switch (weight_type) {
|
||||||
case Type::kF32:
|
case Type::kF32:
|
||||||
|
|
|
||||||
|
|
@ -198,17 +198,15 @@ static ModelConfig ConfigGriffin2B() {
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
static ModelConfig ConfigPaliGemma_224() {
|
// Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config.
|
||||||
ModelConfig config = ConfigGemma2B();
|
static void AddVitConfig(ModelConfig& config) {
|
||||||
config.model_name = "PaliGemma_224";
|
|
||||||
config.model = Model::PALIGEMMA_224;
|
|
||||||
config.vit_model_dim = 1152;
|
config.vit_model_dim = 1152;
|
||||||
config.vocab_size = 256000 + 1024 + 128; // = 257152
|
config.vocab_size = 256000 + 1024 + 128; // = 257152
|
||||||
config.image_size = 224;
|
config.image_size = 224;
|
||||||
config.patch_width = 14;
|
config.patch_width = 14;
|
||||||
const size_t num_patches = config.image_size / config.patch_width;
|
const size_t num_patches = config.image_size / config.patch_width;
|
||||||
config.vit_seq_len = num_patches * num_patches;
|
config.vit_seq_len = num_patches * num_patches;
|
||||||
LayerConfig layer_config = {
|
LayerConfig vit_layer_config = {
|
||||||
.model_dim = config.vit_model_dim,
|
.model_dim = config.vit_model_dim,
|
||||||
.ff_hidden_dim = 4304,
|
.ff_hidden_dim = 4304,
|
||||||
.heads = 16,
|
.heads = 16,
|
||||||
|
|
@ -217,8 +215,15 @@ static ModelConfig ConfigPaliGemma_224() {
|
||||||
.ff_biases = true,
|
.ff_biases = true,
|
||||||
.type = LayerAttentionType::kVit,
|
.type = LayerAttentionType::kVit,
|
||||||
};
|
};
|
||||||
config.vit_layer_configs = {27, layer_config};
|
config.vit_layer_configs = {27, vit_layer_config};
|
||||||
config.num_vit_scales = 4 * config.vit_layer_configs.size();
|
config.num_vit_scales = 4 * config.vit_layer_configs.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
static ModelConfig ConfigPaliGemma_224() {
|
||||||
|
ModelConfig config = ConfigGemma2B();
|
||||||
|
config.model_name = "PaliGemma_224";
|
||||||
|
config.model = Model::PALIGEMMA_224;
|
||||||
|
AddVitConfig(config);
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
|
||||||
21
gemma/run.cc
21
gemma/run.cc
|
|
@ -16,12 +16,14 @@
|
||||||
// Command line text interface to gemma.
|
// Command line text interface to gemma.
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include <memory>
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <string_view>
|
#include <string_view>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// Placeholder for internal header, do not modify.
|
// Placeholder for internal header, do not modify.
|
||||||
|
#include "compression/shared.h" // ModelTraining
|
||||||
#include "evals/benchmark_helper.h"
|
#include "evals/benchmark_helper.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "gemma/gemma.h" // Gemma
|
#include "gemma/gemma.h" // Gemma
|
||||||
|
|
@ -90,13 +92,22 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
||||||
|
|
||||||
const bool have_image = !args.image_file.path.empty();
|
const bool have_image = !args.image_file.path.empty();
|
||||||
Image image;
|
Image image;
|
||||||
ImageTokens image_tokens(256, 2048);
|
std::unique_ptr<ImageTokens> image_tokens;
|
||||||
if (have_image) {
|
if (have_image) {
|
||||||
HWY_ASSERT(model.Info().model == Model::PALIGEMMA_224);
|
image_tokens = std::make_unique<ImageTokens>(
|
||||||
|
model.GetModelConfig().vit_seq_len, model.GetModelConfig().model_dim);
|
||||||
|
HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
|
||||||
HWY_ASSERT(image.ReadPPM(args.image_file.path));
|
HWY_ASSERT(image.ReadPPM(args.image_file.path));
|
||||||
image.Resize();
|
image.Resize();
|
||||||
RuntimeConfig runtime_config = {.verbosity = verbosity, .gen = &gen};
|
RuntimeConfig runtime_config = {.verbosity = verbosity, .gen = &gen};
|
||||||
model.GenerateImageTokens(runtime_config, image, image_tokens);
|
double image_tokens_start = hwy::platform::Now();
|
||||||
|
model.GenerateImageTokens(runtime_config, image, *image_tokens);
|
||||||
|
if (verbosity >= 1) {
|
||||||
|
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
|
||||||
|
fprintf(stderr,
|
||||||
|
"\n\n[ Timing info ] Image token generation took: %d ms\n",
|
||||||
|
static_cast<int>(image_tokens_duration * 1000));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// callback function invoked for each generated token.
|
// callback function invoked for each generated token.
|
||||||
|
|
@ -170,8 +181,8 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
||||||
args.CopyTo(runtime_config);
|
args.CopyTo(runtime_config);
|
||||||
size_t prefix_end = 0;
|
size_t prefix_end = 0;
|
||||||
if (have_image) {
|
if (have_image) {
|
||||||
runtime_config.image_tokens = &image_tokens;
|
runtime_config.image_tokens = image_tokens.get();
|
||||||
prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
|
prompt.insert(prompt.begin(), image_tokens->BatchSize(), 0);
|
||||||
prompt_size = prompt.size();
|
prompt_size = prompt.size();
|
||||||
// The end of the prefix for prefix-LM style attention in Paligemma.
|
// The end of the prefix for prefix-LM style attention in Paligemma.
|
||||||
// See Figure 2 of https://arxiv.org/abs/2407.07726.
|
// See Figure 2 of https://arxiv.org/abs/2407.07726.
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "compression/io.h" // Path
|
#include "compression/io.h" // Path
|
||||||
|
#include "compression/shared.h" // ModelTraining
|
||||||
#include "gemma/common.h" // Wrap
|
#include "gemma/common.h" // Wrap
|
||||||
#include "hwy/base.h" // HWY_ASSERT
|
#include "hwy/base.h" // HWY_ASSERT
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
|
|
@ -109,7 +110,7 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
|
||||||
}
|
}
|
||||||
|
|
||||||
// PaliGemma separator. The SEP token "\n" is always tokenized separately.
|
// PaliGemma separator. The SEP token "\n" is always tokenized separately.
|
||||||
if (info.model == Model::PALIGEMMA_224) {
|
if (info.training == ModelTraining::PALIGEMMA) {
|
||||||
std::vector<int> sep_tokens;
|
std::vector<int> sep_tokens;
|
||||||
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
|
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
|
||||||
tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
|
tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
|
||||||
|
|
|
||||||
|
|
@ -38,11 +38,11 @@ cc_test(
|
||||||
"no_tap",
|
"no_tap",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"@googletest//:gtest_main",
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
"//:benchmark_helper",
|
"//:benchmark_helper",
|
||||||
"//:common",
|
"//:common",
|
||||||
"//:gemma_lib",
|
"//:gemma_lib",
|
||||||
"//:tokenizer",
|
"//compression:sfp",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
"@highway//:hwy_test_util",
|
"@highway//:hwy_test_util",
|
||||||
],
|
],
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include <limits>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
@ -152,6 +153,27 @@ bool Image::ReadPPM(const hwy::Span<const char>& buf) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Image::Set(int width, int height, const float* data) {
|
||||||
|
width_ = width;
|
||||||
|
height_ = height;
|
||||||
|
int num_elements = width * height * 3;
|
||||||
|
data_.resize(num_elements);
|
||||||
|
data_.assign(data, data + num_elements);
|
||||||
|
float min_value = std::numeric_limits<float>::infinity();
|
||||||
|
float max_value = -std::numeric_limits<float>::infinity();
|
||||||
|
for (int i = 0; i < num_elements; ++i) {
|
||||||
|
if (data_[i] < min_value) min_value = data_[i];
|
||||||
|
if (data_[i] > max_value) max_value = data_[i];
|
||||||
|
}
|
||||||
|
// -> out_min + (value - in_min) * (out_max - out_min) / (in_max - in_min)
|
||||||
|
float in_range = max_value - min_value;
|
||||||
|
if (in_range == 0.0f) in_range = 1.0f;
|
||||||
|
float scale = 2.0f / in_range;
|
||||||
|
for (int i = 0; i < num_elements; ++i) {
|
||||||
|
data_[i] = (data_[i] - min_value) * scale - 1.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Image::Resize() {
|
void Image::Resize() {
|
||||||
int new_width = 224;
|
int new_width = 224;
|
||||||
int new_height = kImageSize;
|
int new_height = kImageSize;
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,9 @@ class Image {
|
||||||
// Reads PPM format (P6, binary) data from a hwy::Span, normalizes to [-1, 1].
|
// Reads PPM format (P6, binary) data from a hwy::Span, normalizes to [-1, 1].
|
||||||
// Returns true on success.
|
// Returns true on success.
|
||||||
bool ReadPPM(const hwy::Span<const char>& buf);
|
bool ReadPPM(const hwy::Span<const char>& buf);
|
||||||
|
// Sets the image content to the given data. The data is copied and normalized
|
||||||
|
// to [-1, 1]. The data is expected to be of size width * height * 3.
|
||||||
|
void Set(int width, int height, const float* data);
|
||||||
// Resizes to 224x224 (nearest-neighbor for now, bilinear or antialias would
|
// Resizes to 224x224 (nearest-neighbor for now, bilinear or antialias would
|
||||||
// be better).
|
// be better).
|
||||||
void Resize();
|
void Resize();
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "compression/shared.h"
|
||||||
#include "evals/benchmark_helper.h"
|
#include "evals/benchmark_helper.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/gemma.h"
|
||||||
|
|
@ -50,9 +51,10 @@ class PaliGemmaTest : public ::testing::Test {
|
||||||
void PaliGemmaTest::InitVit(const std::string& path) {
|
void PaliGemmaTest::InitVit(const std::string& path) {
|
||||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||||
Gemma& model = *(s_env->GetModel());
|
Gemma& model = *(s_env->GetModel());
|
||||||
image_tokens_ = std::make_unique<ImageTokens>(256, 2048);
|
image_tokens_ = std::make_unique<ImageTokens>(
|
||||||
|
model.GetModelConfig().vit_seq_len, model.GetModelConfig().model_dim);
|
||||||
Image image;
|
Image image;
|
||||||
HWY_ASSERT(model.Info().model == Model::PALIGEMMA_224);
|
HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
|
||||||
HWY_ASSERT(image.ReadPPM(path));
|
HWY_ASSERT(image.ReadPPM(path));
|
||||||
image.Resize();
|
image.Resize();
|
||||||
RuntimeConfig runtime_config = {.verbosity = 0, .gen = &s_env->MutableGen()};
|
RuntimeConfig runtime_config = {.verbosity = 0, .gen = &s_env->MutableGen()};
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue