mirror of https://github.com/google/gemma.cpp.git
Moved the vit config fields to their own config struct
PiperOrigin-RevId: 715692800
parent 9d40f0117e
commit b93231a47d
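In effect, call sites move from flat fields on ModelConfig to the nested VitConfig member. A minimal before/after sketch (field names are taken from the diff below; the config object is assumed):

    // Before this commit: ViT dimensions lived directly on ModelConfig.
    config.vit_model_dim = 1152;
    config.image_size = 224;
    // After: the same values are grouped under ModelConfig::vit_config.
    config.vit_config.model_dim = 1152;
    config.vit_config.image_size = 224;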
@@ -253,18 +253,19 @@ static LayerConfig LayerConfigVit(size_t model_dim) {
 // Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config.
 static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
-  config.vit_model_dim = 1152;
+  config.vit_config.model_dim = 1152;
   config.vocab_size = 256000 + 1024 + 128;  // = 257152
-  config.image_size = image_size;
-  config.patch_width = 14;
+  config.vit_config.image_size = image_size;
+  config.vit_config.patch_width = 14;
+  const size_t num_patches =
+      config.vit_config.image_size / config.vit_config.patch_width;
+  config.vit_config.seq_len = num_patches * num_patches;
   for (auto& layer_config : config.layer_configs) {
     layer_config.optimized_gating = false;
   }
-  const size_t num_patches = config.image_size / config.patch_width;
-  config.vit_seq_len = num_patches * num_patches;
-  LayerConfig vit_layer_config = LayerConfigVit(config.vit_model_dim);
-  config.vit_layer_configs = {27, vit_layer_config};
-  config.num_vit_scales = 4 * config.vit_layer_configs.size();
+  LayerConfig vit_layer_config = LayerConfigVit(config.vit_config.model_dim);
+  config.vit_config.layer_configs = {27, vit_layer_config};
+  config.vit_config.num_scales = 4 * config.vit_config.layer_configs.size();
 }

 static ModelConfig ConfigPaliGemma_224() {
@@ -283,11 +284,11 @@ static ModelConfig ConfigPaliGemma_448() {
   return config;
 }

-ModelConfig VitConfig(const ModelConfig& config) {
+ModelConfig GetVitConfig(const ModelConfig& config) {
   ModelConfig vit_config = ConfigNoSSM();
-  vit_config.model_dim = config.vit_model_dim;
-  vit_config.seq_len = config.vit_seq_len;
-  vit_config.layer_configs = config.vit_layer_configs;
+  vit_config.model_dim = config.vit_config.model_dim;
+  vit_config.seq_len = config.vit_config.seq_len;
+  vit_config.layer_configs = config.vit_config.layer_configs;
   // The Vit part does not have a vocabulary, the image patches are embedded.
   vit_config.vocab_size = 0;
   return vit_config;
@@ -402,9 +403,28 @@ bool LayerConfig::TestEqual(const LayerConfig& other, bool partial,
   return result;
 }

+bool VitConfig::TestEqual(const VitConfig& other, bool partial,
+                          bool debug) const {
+  bool result = true;
+  TEST_EQUAL(model_dim, other.model_dim);
+  TEST_EQUAL(seq_len, other.seq_len);
+  if (!partial) {
+    TEST_EQUAL(num_scales, other.num_scales);
+  }
+  TEST_EQUAL(patch_width, other.patch_width);
+  TEST_EQUAL(image_size, other.image_size);
+  RETURN_IF_NOT_EQUAL(layer_configs.size(), other.layer_configs.size());
+  for (size_t i = 0; i < layer_configs.size(); ++i) {
+    result &=
+        layer_configs[i].TestEqual(other.layer_configs[i], partial, debug);
+  }
+  return result;
+}
+
 bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
                             bool debug) const {
   bool result = true;
+  TEST_EQUAL(model_family_version, other.model_family_version);
   // We don't care about model_name, model, wrapping, or weight being different,
   // but will output in debug mode if they are.
   if (debug) {
@@ -415,13 +435,10 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
     WARN_IF_NOT_EQUAL(static_cast<int>(weight), static_cast<int>(other.weight));
   }
   TEST_EQUAL(model_dim, other.model_dim);
-  TEST_EQUAL(vit_model_dim, other.vit_model_dim);
   TEST_EQUAL(vocab_size, other.vocab_size);
   TEST_EQUAL(seq_len, other.seq_len);
-  TEST_EQUAL(vit_seq_len, other.vit_seq_len);
   if (!partial) {
     TEST_EQUAL(num_tensor_scales, other.num_tensor_scales);
-    TEST_EQUAL(num_vit_scales, other.num_vit_scales);
   }
   TEST_EQUAL(att_cap, other.att_cap);
   TEST_EQUAL(final_cap, other.final_cap);
@@ -439,11 +456,6 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
   for (size_t i = 0; i < attention_window_sizes.size(); ++i) {
     TEST_EQUAL(attention_window_sizes[i], other.attention_window_sizes[i]);
   }
-  RETURN_IF_NOT_EQUAL(vit_layer_configs.size(), other.vit_layer_configs.size());
-  for (size_t i = 0; i < vit_layer_configs.size(); ++i) {
-    result &= vit_layer_configs[i].TestEqual(other.vit_layer_configs[i],
-                                             partial, debug);
-  }
   if (!partial) {
     if (scale_names != other.scale_names) {
       result = false;
@@ -453,9 +465,7 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
     }
   }
   TEST_EQUAL(norm_num_groups, other.norm_num_groups);
-  TEST_EQUAL(model_family_version, other.model_family_version);
-  TEST_EQUAL(patch_width, other.patch_width);
-  TEST_EQUAL(image_size, other.image_size);
+  result &= vit_config.TestEqual(other.vit_config, partial, debug);
   return result;
 }
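The new VitConfig::TestEqual mirrors the LayerConfig/ModelConfig pattern above: partial skips num_scales, which is only known once tensors are loaded, and debug reports mismatched fields on stderr. A hypothetical check, assuming the defaults from the struct definition below:

    VitConfig a, b;      // both default-constructed
    b.model_dim = 1152;  // introduce one mismatch
    // false: model_dim differs; num_scales would be ignored under partial=true.
    bool same = a.TestEqual(b, /*partial=*/true, /*debug=*/true);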
@@ -220,6 +220,33 @@ struct LayerConfig : public IFields {
   PostQKType post_qk = PostQKType::Rope;
 };

+// Dimensions related to image processing.
+struct VitConfig : public IFields {
+  // Returns true if *this and other are equal.
+  // If partial is true, then we don't check for items that are only set after
+  // the tensors are loaded from the checkpoint.
+  // If debug is true, then we output the mismatched fields to stderr.
+  bool TestEqual(const VitConfig& other, bool partial, bool debug) const;
+
+  const char* Name() const override { return "VitConfig"; }
+
+  void VisitFields(IFieldsVisitor& visitor) override {
+    visitor(model_dim);
+    visitor(seq_len);
+    visitor(num_scales);
+    visitor(patch_width);
+    visitor(image_size);
+    visitor(layer_configs);
+  }
+
+  uint32_t model_dim = 0;
+  uint32_t seq_len = 0;
+  uint32_t num_scales = 0;
+  uint32_t patch_width = 14;
+  uint32_t image_size = 224;
+  std::vector<LayerConfig> layer_configs;
+};
+
 struct ModelConfig : public IFields {
   // Returns true if *this and other are equal.
   // If partial is true, then we don't check for items that are only set after

@@ -277,26 +304,21 @@ struct ModelConfig : public IFields {
     visitor(layer_configs);
     visitor(attention_window_sizes);
     visitor(norm_num_groups);
-    visitor(vit_model_dim);
-    visitor(vit_seq_len);
-    visitor(num_vit_scales);
-    visitor(vit_layer_configs);
-    visitor(patch_width);
-    visitor(image_size);
+    visitor(vit_config);
   }

+  // Major version of the model family. It is used as a fallback to distinguish
+  // between model types when there is no explicit information in the config.
+  uint32_t model_family_version = 1;
   std::string model_name;
   Model model = Model::UNKNOWN;
   PromptWrapping wrapping = PromptWrapping::GEMMA_PT;
   Type weight = Type::kUnknown;
   uint32_t num_layers = 0;
   uint32_t model_dim = 0;
-  uint32_t vit_model_dim = 0;
   uint32_t vocab_size = 0;
   uint32_t seq_len = 0;
-  uint32_t vit_seq_len = 0;
   uint32_t num_tensor_scales = 0;
-  uint32_t num_vit_scales = 0;
   float att_cap = 0.0f;
   float final_cap = 0.0f;
   bool absolute_pe = false;

@@ -304,13 +326,10 @@ struct ModelConfig : public IFields {
   QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
   std::vector<LayerConfig> layer_configs;
   std::vector<uint32_t> attention_window_sizes;
-  std::vector<LayerConfig> vit_layer_configs;
   std::unordered_set<std::string> scale_names;
   uint32_t norm_num_groups = 1;
-  uint32_t model_family_version = 1;
   // Dimensions related to image processing.
-  uint32_t patch_width = 14;
-  uint32_t image_size = 224;
+  VitConfig vit_config;
 };

 // Returns the config for the given model.

@@ -320,7 +339,7 @@ ModelConfig ConfigFromModel(Model model);
 Model ModelFromConfig(const ModelConfig& config);

 // Returns the sub-config for the ViT model of the PaliGemma model.
-ModelConfig VitConfig(const ModelConfig& config);
+ModelConfig GetVitConfig(const ModelConfig& config);

 }  // namespace gcpp
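With the struct in place, derived quantities compose directly from the nested fields. A worked sketch for the 224px PaliGemma preset (calling AddVitConfig directly is illustrative; the diff declares it static):

    ModelConfig config;  // assume base fields are already populated
    AddVitConfig(config, /*image_size=*/224);
    // 224 / 14 = 16 patches per side, so 16 * 16 = 256 image tokens.
    const size_t per_side =
        config.vit_config.image_size / config.vit_config.patch_width;
    HWY_ASSERT(config.vit_config.seq_len == per_side * per_side);  // 256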
@@ -367,12 +367,13 @@ template <class TConfig>
 void AssertMatch(const ModelConfig& config) {
   ASSERT_EQ(TConfig::kModelDim, config.model_dim);
   if constexpr (TConfig::VitConfig::kModelDim != 0) {
-    ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_model_dim);
-    ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_seq_len);
-    ASSERT_EQ(TConfig::VitConfig::kNumTensorScales, config.num_vit_scales);
-    for (size_t i = 0; i < config.vit_layer_configs.size(); ++i) {
+    ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_config.model_dim);
+    ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_config.seq_len);
+    ASSERT_EQ(TConfig::VitConfig::kNumTensorScales,
+              config.vit_config.num_scales);
+    for (size_t i = 0; i < config.vit_config.layer_configs.size(); ++i) {
       ASSERT_EQ(TConfig::VitConfig::kLayerConfig[i],
-                config.vit_layer_configs[i].type);
+                config.vit_config.layer_configs[i].type);
     }
   }
   ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);
@@ -1042,9 +1042,9 @@ template <typename T>
 HWY_NOINLINE void EmbedImagePatches(const Image& image,
                                     const ModelWeightsPtrs<T>& weights,
                                     Activations& activations) {
-  const size_t model_dim = weights.weights_config.vit_model_dim;
-  const size_t patch_width = weights.weights_config.patch_width;
-  const size_t seq_len = weights.weights_config.vit_seq_len;
+  const size_t model_dim = weights.weights_config.vit_config.model_dim;
+  const size_t patch_width = weights.weights_config.vit_config.patch_width;
+  const size_t seq_len = weights.weights_config.vit_config.seq_len;
   const size_t patch_size = patch_width * patch_width * 3;
   HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==
               patch_size * model_dim);

@@ -1087,14 +1087,15 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
                              const Image& image, ImageTokens& image_tokens,
                              Activations& activations) {
   PROFILER_ZONE("Gen.PrefillVit");
-  const size_t num_tokens = weights.weights_config.vit_seq_len;
-  const size_t vit_model_dim = weights.weights_config.vit_model_dim;
+  const size_t num_tokens = weights.weights_config.vit_config.seq_len;
+  const size_t vit_model_dim = weights.weights_config.vit_config.model_dim;
   HWY_ASSERT(num_tokens == activations.x.BatchSize());
   // Embed the image patches.
   EmbedImagePatches(image, weights, activations);
   // Go through all layers.
   for (size_t layer = 0;
-       layer < weights.weights_config.vit_layer_configs.size(); ++layer) {
+       layer < weights.weights_config.vit_config.layer_configs.size();
+       ++layer) {
     const auto* layer_weights = weights.GetVitLayer(layer);
     VitTransformerLayer(num_tokens, layer, layer_weights, activations);
   }

@@ -1413,11 +1414,11 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
                           const RuntimeConfig& runtime_config,
                           const Image& image, ImageTokens& image_tokens,
                           NestedPools& pools) {
-  if (model.Config().vit_layer_configs.empty()) {
+  if (model.Config().vit_config.layer_configs.empty()) {
     HWY_ABORT("Model does not support generating image tokens.");
   }
   RuntimeConfig prefill_runtime_config = runtime_config;
-  ModelConfig vit_config = VitConfig(model.Config());
+  ModelConfig vit_config = GetVitConfig(model.Config());
   prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len;
   Activations prefill_activations(vit_config);
   prefill_activations.Allocate(vit_config.seq_len, pools);
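The free function VitConfig() is renamed to GetVitConfig(), presumably because the unqualified name would now collide with the VitConfig struct. A condensed sketch of the updated call pattern, with model and pools assumed in scope as in GenerateImageTokensT:

    // Extract the ViT sub-model's config and size prefill activations from it.
    ModelConfig vit = GetVitConfig(model.Config());
    Activations prefill_activations(vit);
    prefill_activations.Allocate(vit.seq_len, pools);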
@@ -94,11 +94,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
   Image image;
   ImageTokens image_tokens;
   if (have_image) {
-    image_tokens = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
-                                         model.GetModelConfig().model_dim));
+    image_tokens =
+        ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
+                              model.GetModelConfig().model_dim));
     HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
     HWY_ASSERT(image.ReadPPM(args.image_file.path));
-    const size_t image_size = model.GetModelConfig().image_size;
+    const size_t image_size = model.GetModelConfig().vit_config.image_size;
     image.Resize(image_size, image_size);
     RuntimeConfig runtime_config = {
         .gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin};
@@ -36,29 +36,29 @@ std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
           .name = "enc_norm_bias",
           .source_names = {"img/Transformer/encoder_norm/bias"},
           .axes = {0},
-          .shape = {config.vit_model_dim},
+          .shape = {config.vit_config.model_dim},
           .min_size = Type::kBF16,
       },
       TensorInfo{
           .name = "enc_norm_scale",
           .source_names = {"img/Transformer/encoder_norm/scale"},
           .axes = {0},
-          .shape = {config.vit_model_dim},
+          .shape = {config.vit_config.model_dim},
           .min_size = Type::kBF16,
       },
       TensorInfo{
           .name = "img_emb_bias",
           .source_names = {"img/embedding/bias"},
           .axes = {0},
-          .shape = {config.vit_model_dim},
+          .shape = {config.vit_config.model_dim},
           .min_size = Type::kF32,
       },
       TensorInfo{
           .name = "img_emb_kernel",
           .source_names = {"img/embedding/kernel"},
           .axes = {3, 0, 1, 2},
-          .shape = {config.vit_model_dim, config.patch_width,
-                    config.patch_width, 3},
+          .shape = {config.vit_config.model_dim, config.vit_config.patch_width,
+                    config.vit_config.patch_width, 3},
           .min_size = Type::kBF16,
           .cols_take_extra_dims = true,
       },

@@ -73,14 +73,15 @@ std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
           .name = "img_head_kernel",
           .source_names = {"img/head/kernel"},
           .axes = {1, 0},
-          .shape = {config.model_dim, config.vit_model_dim},
+          .shape = {config.model_dim, config.vit_config.model_dim},
           .min_size = Type::kBF16,
       },
       TensorInfo{
           .name = "img_pos_emb",
           .source_names = {"img/pos_embedding"},
           .axes = {0, 1},
-          .shape = {/*1,*/ config.vit_seq_len, config.vit_model_dim},
+          .shape = {/*1,*/ config.vit_config.seq_len,
+                    config.vit_config.model_dim},
           .min_size = Type::kF32,
       },
   };

@@ -95,7 +96,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
           .name = "attn_out_w",
           .source_names = {"MultiHeadDotProductAttention_0/out/kernel"},
           .axes = {2, 0, 1},
-          .shape = {config.vit_model_dim, layer_config.heads,
+          .shape = {config.vit_config.model_dim, layer_config.heads,
                     layer_config.qkv_dim},
           .min_size = Type::kBF16,
           .cols_take_extra_dims = true,

@@ -104,7 +105,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
           .name = "attn_out_b",
           .source_names = {"MultiHeadDotProductAttention_0/out/bias"},
           .axes = {0},
-          .shape = {config.vit_model_dim},
+          .shape = {config.vit_config.model_dim},
           .min_size = Type::kF32,
       },
       TensorInfo{

@@ -112,7 +113,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
           .source_names = {"MultiHeadDotProductAttention_0/query/kernel"},
           .axes = {1, 2, 0},
           .shape = {layer_config.heads, layer_config.qkv_dim,
-                    config.vit_model_dim},
+                    config.vit_config.model_dim},
           .concat_names = {"qkv_ein_w", "k_ein_w", "v_ein_w"},
           .concat_axis = 1,
           .min_size = Type::kBF16,

@@ -122,7 +123,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
           .source_names = {"MultiHeadDotProductAttention_0/key/kernel"},
           .axes = {1, 2, 0},
           .shape = {layer_config.heads, layer_config.qkv_dim,
-                    config.vit_model_dim},
+                    config.vit_config.model_dim},
           .concat_names = {""},
           .min_size = Type::kBF16,
       },

@@ -131,7 +132,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
           .source_names = {"MultiHeadDotProductAttention_0/value/kernel"},
           .axes = {1, 2, 0},
           .shape = {layer_config.heads, layer_config.qkv_dim,
-                    config.vit_model_dim},
+                    config.vit_config.model_dim},
           .concat_names = {""},
           .min_size = Type::kBF16,
       },

@@ -140,7 +141,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
           .source_names = {"MultiHeadDotProductAttention_0/qkv/kernel"},
           .axes = {1, 2, 0},
           .shape = {layer_config.heads, 3 * layer_config.qkv_dim,
-                    config.vit_model_dim},
+                    config.vit_config.model_dim},
           .min_size = Type::kBF16,
       },
       TensorInfo{

@@ -180,7 +181,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
           .name = "linear_0_w",
           .source_names = {"MlpBlock_0/Dense_0/kernel"},
           .axes = {1, 0},
-          .shape = {layer_config.ff_hidden_dim, config.vit_model_dim},
+          .shape = {layer_config.ff_hidden_dim, config.vit_config.model_dim},
           .min_size = Type::kBF16,
       },
       TensorInfo{

@@ -194,42 +195,42 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
           .name = "linear_1_w",
           .source_names = {"MlpBlock_0/Dense_1/kernel"},
           .axes = {1, 0},
-          .shape = {config.vit_model_dim, layer_config.ff_hidden_dim},
+          .shape = {config.vit_config.model_dim, layer_config.ff_hidden_dim},
           .min_size = Type::kBF16,
       },
       TensorInfo{
           .name = "linear_1_b",
           .source_names = {"MlpBlock_0/Dense_1/bias"},
           .axes = {0},
-          .shape = {config.vit_model_dim},
+          .shape = {config.vit_config.model_dim},
           .min_size = Type::kF32,
       },
       TensorInfo{
           .name = "ln_0_bias",
           .source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias"},
           .axes = {0},
-          .shape = {config.vit_model_dim},
+          .shape = {config.vit_config.model_dim},
           .min_size = Type::kBF16,
       },
       TensorInfo{
           .name = "ln_0_scale",
           .source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale"},
           .axes = {0},
-          .shape = {config.vit_model_dim},
+          .shape = {config.vit_config.model_dim},
           .min_size = Type::kBF16,
       },
       TensorInfo{
           .name = "ln_1_bias",
           .source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias"},
           .axes = {0},
-          .shape = {config.vit_model_dim},
+          .shape = {config.vit_config.model_dim},
           .min_size = Type::kBF16,
       },
       TensorInfo{
           .name = "ln_1_scale",
           .source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale"},
           .axes = {0},
-          .shape = {config.vit_model_dim},
+          .shape = {config.vit_config.model_dim},
           .min_size = Type::kBF16,
       },
   };

@@ -526,8 +527,8 @@ TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx,
   if (llm_layer_idx < 0 && img_layer_idx < 0) {
     tensors_ = ModelTensors(config);
   } else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx &&
-             img_layer_idx < config.vit_layer_configs.size()) {
-    const auto& layer_config = config.vit_layer_configs[img_layer_idx];
+             img_layer_idx < config.vit_config.layer_configs.size()) {
+    const auto& layer_config = config.vit_config.layer_configs[img_layer_idx];
     tensors_ = ImageLayerTensors(config, layer_config);
   } else if (0 <= llm_layer_idx &&
              llm_layer_idx < config.layer_configs.size()) {
@@ -35,7 +35,7 @@ TEST(TensorIndexTest, FindName) {
                                 /*split_and_reshape=*/false);
   }
   for (size_t img_layer_idx = 0;
-       img_layer_idx < config.vit_layer_configs.size();
+       img_layer_idx < config.vit_config.layer_configs.size();
        ++img_layer_idx) {
     tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1,
                                 static_cast<int>(img_layer_idx),
@@ -86,7 +86,7 @@ BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type,
     config_ = ConfigFromModel(model_type);
     config_.weight = weight_type;
     config_.wrapping = wrapping;
-    scales.resize(config_.num_tensor_scales + config_.num_vit_scales);
+    scales.resize(config_.num_tensor_scales + config_.vit_config.num_scales);
   }
   CreateForType(config_.weight, pool);
   CallForModelWeightT<TensorLoader>(fet, loader);
@@ -344,8 +344,9 @@ struct ModelWeightsPtrs {
       c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config, tensor_index));
     }
     for (int index = 0;
-         index < static_cast<int>(config.vit_layer_configs.size()); ++index) {
-      const auto& layer_config = config.vit_layer_configs[index];
+         index < static_cast<int>(config.vit_config.layer_configs.size());
+         ++index) {
+      const auto& layer_config = config.vit_config.layer_configs[index];
       TensorIndex tensor_index(config, /*llm_layer_idx=*/-1, index,
                                /*reshape_att=*/false);
       vit_layers.push_back(

@@ -479,7 +480,7 @@ struct ModelWeightsPtrs {
     int sep_index = -1;
     GEMMA_CALL_FUNC(embedder_input_embedding);
    GEMMA_CALL_FUNC(final_norm_scale);
-    if (ptrs[0]->weights_config.vit_layer_configs.size() > 0) {
+    if (ptrs[0]->weights_config.vit_config.layer_configs.size() > 0) {
       // Vit parts.
       GEMMA_CALL_FUNC(vit_encoder_norm_bias);
       GEMMA_CALL_FUNC(vit_encoder_norm_scale);

@@ -498,7 +499,7 @@ struct ModelWeightsPtrs {
     }

     // Vit layers. Not supported for compress_weights.
-    if (ptrs[0]->weights_config.vit_layer_configs.size() > 0) {
+    if (ptrs[0]->weights_config.vit_config.layer_configs.size() > 0) {
       for (int layer_idx = 0; layer_idx < ptrs[0]->vit_layers.size();
            ++layer_idx) {
         auto type = ptrs[0]->vit_layers[layer_idx].layer_config.type;
@@ -50,12 +50,13 @@ class PaliGemmaTest : public ::testing::Test {
 void PaliGemmaTest::InitVit(const std::string& path) {
   ASSERT_NE(s_env->GetModel(), nullptr);
   Gemma& model = *(s_env->GetModel());
-  image_tokens_ = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
-                                        model.GetModelConfig().model_dim));
+  image_tokens_ =
+      ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
+                            model.GetModelConfig().model_dim));
   Image image;
   HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
   HWY_ASSERT(image.ReadPPM(path));
-  const size_t image_size = model.GetModelConfig().image_size;
+  const size_t image_size = model.GetModelConfig().vit_config.image_size;
   image.Resize(image_size, image_size);
   RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0};
   model.GenerateImageTokens(runtime_config, image, image_tokens_);