Moved the vit config fields to their own config struct

PiperOrigin-RevId: 715692800
This commit is contained in:
Ray Smith 2025-01-15 01:09:16 -08:00 committed by Copybara-Service
parent 9d40f0117e
commit b93231a47d
10 changed files with 119 additions and 84 deletions

View File

@ -253,18 +253,19 @@ static LayerConfig LayerConfigVit(size_t model_dim) {
// Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config. // Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config.
static void AddVitConfig(ModelConfig& config, size_t image_size = 224) { static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
config.vit_model_dim = 1152; config.vit_config.model_dim = 1152;
config.vocab_size = 256000 + 1024 + 128; // = 257152 config.vocab_size = 256000 + 1024 + 128; // = 257152
config.image_size = image_size; config.vit_config.image_size = image_size;
config.patch_width = 14; config.vit_config.patch_width = 14;
const size_t num_patches =
config.vit_config.image_size / config.vit_config.patch_width;
config.vit_config.seq_len = num_patches * num_patches;
for (auto& layer_config : config.layer_configs) { for (auto& layer_config : config.layer_configs) {
layer_config.optimized_gating = false; layer_config.optimized_gating = false;
} }
const size_t num_patches = config.image_size / config.patch_width; LayerConfig vit_layer_config = LayerConfigVit(config.vit_config.model_dim);
config.vit_seq_len = num_patches * num_patches; config.vit_config.layer_configs = {27, vit_layer_config};
LayerConfig vit_layer_config = LayerConfigVit(config.vit_model_dim); config.vit_config.num_scales = 4 * config.vit_config.layer_configs.size();
config.vit_layer_configs = {27, vit_layer_config};
config.num_vit_scales = 4 * config.vit_layer_configs.size();
} }
static ModelConfig ConfigPaliGemma_224() { static ModelConfig ConfigPaliGemma_224() {
@ -283,11 +284,11 @@ static ModelConfig ConfigPaliGemma_448() {
return config; return config;
} }
ModelConfig VitConfig(const ModelConfig& config) { ModelConfig GetVitConfig(const ModelConfig& config) {
ModelConfig vit_config = ConfigNoSSM(); ModelConfig vit_config = ConfigNoSSM();
vit_config.model_dim = config.vit_model_dim; vit_config.model_dim = config.vit_config.model_dim;
vit_config.seq_len = config.vit_seq_len; vit_config.seq_len = config.vit_config.seq_len;
vit_config.layer_configs = config.vit_layer_configs; vit_config.layer_configs = config.vit_config.layer_configs;
// The Vit part does not have a vocabulary, the image patches are embedded. // The Vit part does not have a vocabulary, the image patches are embedded.
vit_config.vocab_size = 0; vit_config.vocab_size = 0;
return vit_config; return vit_config;
@ -402,9 +403,28 @@ bool LayerConfig::TestEqual(const LayerConfig& other, bool partial,
return result; return result;
} }
bool VitConfig::TestEqual(const VitConfig& other, bool partial,
bool debug) const {
bool result = true;
TEST_EQUAL(model_dim, other.model_dim);
TEST_EQUAL(seq_len, other.seq_len);
if (!partial) {
TEST_EQUAL(num_scales, other.num_scales);
}
TEST_EQUAL(patch_width, other.patch_width);
TEST_EQUAL(image_size, other.image_size);
RETURN_IF_NOT_EQUAL(layer_configs.size(), other.layer_configs.size());
for (size_t i = 0; i < layer_configs.size(); ++i) {
result &=
layer_configs[i].TestEqual(other.layer_configs[i], partial, debug);
}
return result;
}
bool ModelConfig::TestEqual(const ModelConfig& other, bool partial, bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
bool debug) const { bool debug) const {
bool result = true; bool result = true;
TEST_EQUAL(model_family_version, other.model_family_version);
// We don't care about model_name, model, wrapping, or weight being different, // We don't care about model_name, model, wrapping, or weight being different,
// but will output in debug mode if they are. // but will output in debug mode if they are.
if (debug) { if (debug) {
@ -415,13 +435,10 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
WARN_IF_NOT_EQUAL(static_cast<int>(weight), static_cast<int>(other.weight)); WARN_IF_NOT_EQUAL(static_cast<int>(weight), static_cast<int>(other.weight));
} }
TEST_EQUAL(model_dim, other.model_dim); TEST_EQUAL(model_dim, other.model_dim);
TEST_EQUAL(vit_model_dim, other.vit_model_dim);
TEST_EQUAL(vocab_size, other.vocab_size); TEST_EQUAL(vocab_size, other.vocab_size);
TEST_EQUAL(seq_len, other.seq_len); TEST_EQUAL(seq_len, other.seq_len);
TEST_EQUAL(vit_seq_len, other.vit_seq_len);
if (!partial) { if (!partial) {
TEST_EQUAL(num_tensor_scales, other.num_tensor_scales); TEST_EQUAL(num_tensor_scales, other.num_tensor_scales);
TEST_EQUAL(num_vit_scales, other.num_vit_scales);
} }
TEST_EQUAL(att_cap, other.att_cap); TEST_EQUAL(att_cap, other.att_cap);
TEST_EQUAL(final_cap, other.final_cap); TEST_EQUAL(final_cap, other.final_cap);
@ -439,11 +456,6 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
for (size_t i = 0; i < attention_window_sizes.size(); ++i) { for (size_t i = 0; i < attention_window_sizes.size(); ++i) {
TEST_EQUAL(attention_window_sizes[i], other.attention_window_sizes[i]); TEST_EQUAL(attention_window_sizes[i], other.attention_window_sizes[i]);
} }
RETURN_IF_NOT_EQUAL(vit_layer_configs.size(), other.vit_layer_configs.size());
for (size_t i = 0; i < vit_layer_configs.size(); ++i) {
result &= vit_layer_configs[i].TestEqual(other.vit_layer_configs[i],
partial, debug);
}
if (!partial) { if (!partial) {
if (scale_names != other.scale_names) { if (scale_names != other.scale_names) {
result = false; result = false;
@ -453,9 +465,7 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
} }
} }
TEST_EQUAL(norm_num_groups, other.norm_num_groups); TEST_EQUAL(norm_num_groups, other.norm_num_groups);
TEST_EQUAL(model_family_version, other.model_family_version); result &= vit_config.TestEqual(other.vit_config, partial, debug);
TEST_EQUAL(patch_width, other.patch_width);
TEST_EQUAL(image_size, other.image_size);
return result; return result;
} }

View File

@ -220,6 +220,33 @@ struct LayerConfig : public IFields {
PostQKType post_qk = PostQKType::Rope; PostQKType post_qk = PostQKType::Rope;
}; };
// Dimensions related to image processing.
struct VitConfig : public IFields {
// Returns true if *this and other are equal.
// If partial is true, then we don't check for items that are only set after
// the tensors are loaded from the checkpoint.
// If debug is true, then we output the mismatched fields to stderr.
bool TestEqual(const VitConfig& other, bool partial, bool debug) const;
const char* Name() const override { return "VitConfig"; }
void VisitFields(IFieldsVisitor& visitor) override {
visitor(model_dim);
visitor(seq_len);
visitor(num_scales);
visitor(patch_width);
visitor(image_size);
visitor(layer_configs);
}
uint32_t model_dim = 0;
uint32_t seq_len = 0;
uint32_t num_scales = 0;
uint32_t patch_width = 14;
uint32_t image_size = 224;
std::vector<LayerConfig> layer_configs;
};
struct ModelConfig : public IFields { struct ModelConfig : public IFields {
// Returns true if *this and other are equal. // Returns true if *this and other are equal.
// If partial is true, then we don't check for items that are only set after // If partial is true, then we don't check for items that are only set after
@ -277,26 +304,21 @@ struct ModelConfig : public IFields {
visitor(layer_configs); visitor(layer_configs);
visitor(attention_window_sizes); visitor(attention_window_sizes);
visitor(norm_num_groups); visitor(norm_num_groups);
visitor(vit_model_dim); visitor(vit_config);
visitor(vit_seq_len);
visitor(num_vit_scales);
visitor(vit_layer_configs);
visitor(patch_width);
visitor(image_size);
} }
// Major version of the model family. It is used as a fallback to distinguish
// between model types when there is no explicit information in the config.
uint32_t model_family_version = 1;
std::string model_name; std::string model_name;
Model model = Model::UNKNOWN; Model model = Model::UNKNOWN;
PromptWrapping wrapping = PromptWrapping::GEMMA_PT; PromptWrapping wrapping = PromptWrapping::GEMMA_PT;
Type weight = Type::kUnknown; Type weight = Type::kUnknown;
uint32_t num_layers = 0; uint32_t num_layers = 0;
uint32_t model_dim = 0; uint32_t model_dim = 0;
uint32_t vit_model_dim = 0;
uint32_t vocab_size = 0; uint32_t vocab_size = 0;
uint32_t seq_len = 0; uint32_t seq_len = 0;
uint32_t vit_seq_len = 0;
uint32_t num_tensor_scales = 0; uint32_t num_tensor_scales = 0;
uint32_t num_vit_scales = 0;
float att_cap = 0.0f; float att_cap = 0.0f;
float final_cap = 0.0f; float final_cap = 0.0f;
bool absolute_pe = false; bool absolute_pe = false;
@ -304,13 +326,10 @@ struct ModelConfig : public IFields {
QueryScaleType query_scale = QueryScaleType::SqrtKeySize; QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
std::vector<LayerConfig> layer_configs; std::vector<LayerConfig> layer_configs;
std::vector<uint32_t> attention_window_sizes; std::vector<uint32_t> attention_window_sizes;
std::vector<LayerConfig> vit_layer_configs;
std::unordered_set<std::string> scale_names; std::unordered_set<std::string> scale_names;
uint32_t norm_num_groups = 1; uint32_t norm_num_groups = 1;
uint32_t model_family_version = 1;
// Dimensions related to image processing. // Dimensions related to image processing.
uint32_t patch_width = 14; VitConfig vit_config;
uint32_t image_size = 224;
}; };
// Returns the config for the given model. // Returns the config for the given model.
@ -320,7 +339,7 @@ ModelConfig ConfigFromModel(Model model);
Model ModelFromConfig(const ModelConfig& config); Model ModelFromConfig(const ModelConfig& config);
// Returns the sub-config for the ViT model of the PaliGemma model. // Returns the sub-config for the ViT model of the PaliGemma model.
ModelConfig VitConfig(const ModelConfig& config); ModelConfig GetVitConfig(const ModelConfig& config);
} // namespace gcpp } // namespace gcpp

View File

@ -367,12 +367,13 @@ template <class TConfig>
void AssertMatch(const ModelConfig& config) { void AssertMatch(const ModelConfig& config) {
ASSERT_EQ(TConfig::kModelDim, config.model_dim); ASSERT_EQ(TConfig::kModelDim, config.model_dim);
if constexpr (TConfig::VitConfig::kModelDim != 0) { if constexpr (TConfig::VitConfig::kModelDim != 0) {
ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_model_dim); ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_config.model_dim);
ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_seq_len); ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_config.seq_len);
ASSERT_EQ(TConfig::VitConfig::kNumTensorScales, config.num_vit_scales); ASSERT_EQ(TConfig::VitConfig::kNumTensorScales,
for (size_t i = 0; i < config.vit_layer_configs.size(); ++i) { config.vit_config.num_scales);
for (size_t i = 0; i < config.vit_config.layer_configs.size(); ++i) {
ASSERT_EQ(TConfig::VitConfig::kLayerConfig[i], ASSERT_EQ(TConfig::VitConfig::kLayerConfig[i],
config.vit_layer_configs[i].type); config.vit_config.layer_configs[i].type);
} }
} }
ASSERT_EQ(TConfig::kVocabSize, config.vocab_size); ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);

View File

@ -1042,9 +1042,9 @@ template <typename T>
HWY_NOINLINE void EmbedImagePatches(const Image& image, HWY_NOINLINE void EmbedImagePatches(const Image& image,
const ModelWeightsPtrs<T>& weights, const ModelWeightsPtrs<T>& weights,
Activations& activations) { Activations& activations) {
const size_t model_dim = weights.weights_config.vit_model_dim; const size_t model_dim = weights.weights_config.vit_config.model_dim;
const size_t patch_width = weights.weights_config.patch_width; const size_t patch_width = weights.weights_config.vit_config.patch_width;
const size_t seq_len = weights.weights_config.vit_seq_len; const size_t seq_len = weights.weights_config.vit_config.seq_len;
const size_t patch_size = patch_width * patch_width * 3; const size_t patch_size = patch_width * patch_width * 3;
HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() == HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==
patch_size * model_dim); patch_size * model_dim);
@ -1087,14 +1087,15 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
const Image& image, ImageTokens& image_tokens, const Image& image, ImageTokens& image_tokens,
Activations& activations) { Activations& activations) {
PROFILER_ZONE("Gen.PrefillVit"); PROFILER_ZONE("Gen.PrefillVit");
const size_t num_tokens = weights.weights_config.vit_seq_len; const size_t num_tokens = weights.weights_config.vit_config.seq_len;
const size_t vit_model_dim = weights.weights_config.vit_model_dim; const size_t vit_model_dim = weights.weights_config.vit_config.model_dim;
HWY_ASSERT(num_tokens == activations.x.BatchSize()); HWY_ASSERT(num_tokens == activations.x.BatchSize());
// Embed the image patches. // Embed the image patches.
EmbedImagePatches(image, weights, activations); EmbedImagePatches(image, weights, activations);
// Go through all layers. // Go through all layers.
for (size_t layer = 0; for (size_t layer = 0;
layer < weights.weights_config.vit_layer_configs.size(); ++layer) { layer < weights.weights_config.vit_config.layer_configs.size();
++layer) {
const auto* layer_weights = weights.GetVitLayer(layer); const auto* layer_weights = weights.GetVitLayer(layer);
VitTransformerLayer(num_tokens, layer, layer_weights, activations); VitTransformerLayer(num_tokens, layer, layer_weights, activations);
} }
@ -1413,11 +1414,11 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens, const Image& image, ImageTokens& image_tokens,
NestedPools& pools) { NestedPools& pools) {
if (model.Config().vit_layer_configs.empty()) { if (model.Config().vit_config.layer_configs.empty()) {
HWY_ABORT("Model does not support generating image tokens."); HWY_ABORT("Model does not support generating image tokens.");
} }
RuntimeConfig prefill_runtime_config = runtime_config; RuntimeConfig prefill_runtime_config = runtime_config;
ModelConfig vit_config = VitConfig(model.Config()); ModelConfig vit_config = GetVitConfig(model.Config());
prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len; prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len;
Activations prefill_activations(vit_config); Activations prefill_activations(vit_config);
prefill_activations.Allocate(vit_config.seq_len, pools); prefill_activations.Allocate(vit_config.seq_len, pools);

View File

@ -94,11 +94,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
Image image; Image image;
ImageTokens image_tokens; ImageTokens image_tokens;
if (have_image) { if (have_image) {
image_tokens = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len, image_tokens =
model.GetModelConfig().model_dim)); ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
model.GetModelConfig().model_dim));
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA); HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(args.image_file.path)); HWY_ASSERT(image.ReadPPM(args.image_file.path));
const size_t image_size = model.GetModelConfig().image_size; const size_t image_size = model.GetModelConfig().vit_config.image_size;
image.Resize(image_size, image_size); image.Resize(image_size, image_size);
RuntimeConfig runtime_config = { RuntimeConfig runtime_config = {
.gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin}; .gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin};

View File

@ -36,29 +36,29 @@ std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
.name = "enc_norm_bias", .name = "enc_norm_bias",
.source_names = {"img/Transformer/encoder_norm/bias"}, .source_names = {"img/Transformer/encoder_norm/bias"},
.axes = {0}, .axes = {0},
.shape = {config.vit_model_dim}, .shape = {config.vit_config.model_dim},
.min_size = Type::kBF16, .min_size = Type::kBF16,
}, },
TensorInfo{ TensorInfo{
.name = "enc_norm_scale", .name = "enc_norm_scale",
.source_names = {"img/Transformer/encoder_norm/scale"}, .source_names = {"img/Transformer/encoder_norm/scale"},
.axes = {0}, .axes = {0},
.shape = {config.vit_model_dim}, .shape = {config.vit_config.model_dim},
.min_size = Type::kBF16, .min_size = Type::kBF16,
}, },
TensorInfo{ TensorInfo{
.name = "img_emb_bias", .name = "img_emb_bias",
.source_names = {"img/embedding/bias"}, .source_names = {"img/embedding/bias"},
.axes = {0}, .axes = {0},
.shape = {config.vit_model_dim}, .shape = {config.vit_config.model_dim},
.min_size = Type::kF32, .min_size = Type::kF32,
}, },
TensorInfo{ TensorInfo{
.name = "img_emb_kernel", .name = "img_emb_kernel",
.source_names = {"img/embedding/kernel"}, .source_names = {"img/embedding/kernel"},
.axes = {3, 0, 1, 2}, .axes = {3, 0, 1, 2},
.shape = {config.vit_model_dim, config.patch_width, .shape = {config.vit_config.model_dim, config.vit_config.patch_width,
config.patch_width, 3}, config.vit_config.patch_width, 3},
.min_size = Type::kBF16, .min_size = Type::kBF16,
.cols_take_extra_dims = true, .cols_take_extra_dims = true,
}, },
@ -73,14 +73,15 @@ std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
.name = "img_head_kernel", .name = "img_head_kernel",
.source_names = {"img/head/kernel"}, .source_names = {"img/head/kernel"},
.axes = {1, 0}, .axes = {1, 0},
.shape = {config.model_dim, config.vit_model_dim}, .shape = {config.model_dim, config.vit_config.model_dim},
.min_size = Type::kBF16, .min_size = Type::kBF16,
}, },
TensorInfo{ TensorInfo{
.name = "img_pos_emb", .name = "img_pos_emb",
.source_names = {"img/pos_embedding"}, .source_names = {"img/pos_embedding"},
.axes = {0, 1}, .axes = {0, 1},
.shape = {/*1,*/ config.vit_seq_len, config.vit_model_dim}, .shape = {/*1,*/ config.vit_config.seq_len,
config.vit_config.model_dim},
.min_size = Type::kF32, .min_size = Type::kF32,
}, },
}; };
@ -95,7 +96,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
.name = "attn_out_w", .name = "attn_out_w",
.source_names = {"MultiHeadDotProductAttention_0/out/kernel"}, .source_names = {"MultiHeadDotProductAttention_0/out/kernel"},
.axes = {2, 0, 1}, .axes = {2, 0, 1},
.shape = {config.vit_model_dim, layer_config.heads, .shape = {config.vit_config.model_dim, layer_config.heads,
layer_config.qkv_dim}, layer_config.qkv_dim},
.min_size = Type::kBF16, .min_size = Type::kBF16,
.cols_take_extra_dims = true, .cols_take_extra_dims = true,
@ -104,7 +105,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
.name = "attn_out_b", .name = "attn_out_b",
.source_names = {"MultiHeadDotProductAttention_0/out/bias"}, .source_names = {"MultiHeadDotProductAttention_0/out/bias"},
.axes = {0}, .axes = {0},
.shape = {config.vit_model_dim}, .shape = {config.vit_config.model_dim},
.min_size = Type::kF32, .min_size = Type::kF32,
}, },
TensorInfo{ TensorInfo{
@ -112,7 +113,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
.source_names = {"MultiHeadDotProductAttention_0/query/kernel"}, .source_names = {"MultiHeadDotProductAttention_0/query/kernel"},
.axes = {1, 2, 0}, .axes = {1, 2, 0},
.shape = {layer_config.heads, layer_config.qkv_dim, .shape = {layer_config.heads, layer_config.qkv_dim,
config.vit_model_dim}, config.vit_config.model_dim},
.concat_names = {"qkv_ein_w", "k_ein_w", "v_ein_w"}, .concat_names = {"qkv_ein_w", "k_ein_w", "v_ein_w"},
.concat_axis = 1, .concat_axis = 1,
.min_size = Type::kBF16, .min_size = Type::kBF16,
@ -122,7 +123,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
.source_names = {"MultiHeadDotProductAttention_0/key/kernel"}, .source_names = {"MultiHeadDotProductAttention_0/key/kernel"},
.axes = {1, 2, 0}, .axes = {1, 2, 0},
.shape = {layer_config.heads, layer_config.qkv_dim, .shape = {layer_config.heads, layer_config.qkv_dim,
config.vit_model_dim}, config.vit_config.model_dim},
.concat_names = {""}, .concat_names = {""},
.min_size = Type::kBF16, .min_size = Type::kBF16,
}, },
@ -131,7 +132,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
.source_names = {"MultiHeadDotProductAttention_0/value/kernel"}, .source_names = {"MultiHeadDotProductAttention_0/value/kernel"},
.axes = {1, 2, 0}, .axes = {1, 2, 0},
.shape = {layer_config.heads, layer_config.qkv_dim, .shape = {layer_config.heads, layer_config.qkv_dim,
config.vit_model_dim}, config.vit_config.model_dim},
.concat_names = {""}, .concat_names = {""},
.min_size = Type::kBF16, .min_size = Type::kBF16,
}, },
@ -140,7 +141,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
.source_names = {"MultiHeadDotProductAttention_0/qkv/kernel"}, .source_names = {"MultiHeadDotProductAttention_0/qkv/kernel"},
.axes = {1, 2, 0}, .axes = {1, 2, 0},
.shape = {layer_config.heads, 3 * layer_config.qkv_dim, .shape = {layer_config.heads, 3 * layer_config.qkv_dim,
config.vit_model_dim}, config.vit_config.model_dim},
.min_size = Type::kBF16, .min_size = Type::kBF16,
}, },
TensorInfo{ TensorInfo{
@ -180,7 +181,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
.name = "linear_0_w", .name = "linear_0_w",
.source_names = {"MlpBlock_0/Dense_0/kernel"}, .source_names = {"MlpBlock_0/Dense_0/kernel"},
.axes = {1, 0}, .axes = {1, 0},
.shape = {layer_config.ff_hidden_dim, config.vit_model_dim}, .shape = {layer_config.ff_hidden_dim, config.vit_config.model_dim},
.min_size = Type::kBF16, .min_size = Type::kBF16,
}, },
TensorInfo{ TensorInfo{
@ -194,42 +195,42 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
.name = "linear_1_w", .name = "linear_1_w",
.source_names = {"MlpBlock_0/Dense_1/kernel"}, .source_names = {"MlpBlock_0/Dense_1/kernel"},
.axes = {1, 0}, .axes = {1, 0},
.shape = {config.vit_model_dim, layer_config.ff_hidden_dim}, .shape = {config.vit_config.model_dim, layer_config.ff_hidden_dim},
.min_size = Type::kBF16, .min_size = Type::kBF16,
}, },
TensorInfo{ TensorInfo{
.name = "linear_1_b", .name = "linear_1_b",
.source_names = {"MlpBlock_0/Dense_1/bias"}, .source_names = {"MlpBlock_0/Dense_1/bias"},
.axes = {0}, .axes = {0},
.shape = {config.vit_model_dim}, .shape = {config.vit_config.model_dim},
.min_size = Type::kF32, .min_size = Type::kF32,
}, },
TensorInfo{ TensorInfo{
.name = "ln_0_bias", .name = "ln_0_bias",
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias"}, .source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias"},
.axes = {0}, .axes = {0},
.shape = {config.vit_model_dim}, .shape = {config.vit_config.model_dim},
.min_size = Type::kBF16, .min_size = Type::kBF16,
}, },
TensorInfo{ TensorInfo{
.name = "ln_0_scale", .name = "ln_0_scale",
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale"}, .source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale"},
.axes = {0}, .axes = {0},
.shape = {config.vit_model_dim}, .shape = {config.vit_config.model_dim},
.min_size = Type::kBF16, .min_size = Type::kBF16,
}, },
TensorInfo{ TensorInfo{
.name = "ln_1_bias", .name = "ln_1_bias",
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias"}, .source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias"},
.axes = {0}, .axes = {0},
.shape = {config.vit_model_dim}, .shape = {config.vit_config.model_dim},
.min_size = Type::kBF16, .min_size = Type::kBF16,
}, },
TensorInfo{ TensorInfo{
.name = "ln_1_scale", .name = "ln_1_scale",
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale"}, .source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale"},
.axes = {0}, .axes = {0},
.shape = {config.vit_model_dim}, .shape = {config.vit_config.model_dim},
.min_size = Type::kBF16, .min_size = Type::kBF16,
}, },
}; };
@ -526,8 +527,8 @@ TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx,
if (llm_layer_idx < 0 && img_layer_idx < 0) { if (llm_layer_idx < 0 && img_layer_idx < 0) {
tensors_ = ModelTensors(config); tensors_ = ModelTensors(config);
} else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx && } else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx &&
img_layer_idx < config.vit_layer_configs.size()) { img_layer_idx < config.vit_config.layer_configs.size()) {
const auto& layer_config = config.vit_layer_configs[img_layer_idx]; const auto& layer_config = config.vit_config.layer_configs[img_layer_idx];
tensors_ = ImageLayerTensors(config, layer_config); tensors_ = ImageLayerTensors(config, layer_config);
} else if (0 <= llm_layer_idx && } else if (0 <= llm_layer_idx &&
llm_layer_idx < config.layer_configs.size()) { llm_layer_idx < config.layer_configs.size()) {

View File

@ -35,7 +35,7 @@ TEST(TensorIndexTest, FindName) {
/*split_and_reshape=*/false); /*split_and_reshape=*/false);
} }
for (size_t img_layer_idx = 0; for (size_t img_layer_idx = 0;
img_layer_idx < config.vit_layer_configs.size(); img_layer_idx < config.vit_config.layer_configs.size();
++img_layer_idx) { ++img_layer_idx) {
tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1, tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1,
static_cast<int>(img_layer_idx), static_cast<int>(img_layer_idx),

View File

@ -86,7 +86,7 @@ BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type,
config_ = ConfigFromModel(model_type); config_ = ConfigFromModel(model_type);
config_.weight = weight_type; config_.weight = weight_type;
config_.wrapping = wrapping; config_.wrapping = wrapping;
scales.resize(config_.num_tensor_scales + config_.num_vit_scales); scales.resize(config_.num_tensor_scales + config_.vit_config.num_scales);
} }
CreateForType(config_.weight, pool); CreateForType(config_.weight, pool);
CallForModelWeightT<TensorLoader>(fet, loader); CallForModelWeightT<TensorLoader>(fet, loader);

View File

@ -344,8 +344,9 @@ struct ModelWeightsPtrs {
c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config, tensor_index)); c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config, tensor_index));
} }
for (int index = 0; for (int index = 0;
index < static_cast<int>(config.vit_layer_configs.size()); ++index) { index < static_cast<int>(config.vit_config.layer_configs.size());
const auto& layer_config = config.vit_layer_configs[index]; ++index) {
const auto& layer_config = config.vit_config.layer_configs[index];
TensorIndex tensor_index(config, /*llm_layer_idx=*/-1, index, TensorIndex tensor_index(config, /*llm_layer_idx=*/-1, index,
/*reshape_att=*/false); /*reshape_att=*/false);
vit_layers.push_back( vit_layers.push_back(
@ -479,7 +480,7 @@ struct ModelWeightsPtrs {
int sep_index = -1; int sep_index = -1;
GEMMA_CALL_FUNC(embedder_input_embedding); GEMMA_CALL_FUNC(embedder_input_embedding);
GEMMA_CALL_FUNC(final_norm_scale); GEMMA_CALL_FUNC(final_norm_scale);
if (ptrs[0]->weights_config.vit_layer_configs.size() > 0) { if (ptrs[0]->weights_config.vit_config.layer_configs.size() > 0) {
// Vit parts. // Vit parts.
GEMMA_CALL_FUNC(vit_encoder_norm_bias); GEMMA_CALL_FUNC(vit_encoder_norm_bias);
GEMMA_CALL_FUNC(vit_encoder_norm_scale); GEMMA_CALL_FUNC(vit_encoder_norm_scale);
@ -498,7 +499,7 @@ struct ModelWeightsPtrs {
} }
// Vit layers. Not supported for compress_weights. // Vit layers. Not supported for compress_weights.
if (ptrs[0]->weights_config.vit_layer_configs.size() > 0) { if (ptrs[0]->weights_config.vit_config.layer_configs.size() > 0) {
for (int layer_idx = 0; layer_idx < ptrs[0]->vit_layers.size(); for (int layer_idx = 0; layer_idx < ptrs[0]->vit_layers.size();
++layer_idx) { ++layer_idx) {
auto type = ptrs[0]->vit_layers[layer_idx].layer_config.type; auto type = ptrs[0]->vit_layers[layer_idx].layer_config.type;

View File

@ -50,12 +50,13 @@ class PaliGemmaTest : public ::testing::Test {
void PaliGemmaTest::InitVit(const std::string& path) { void PaliGemmaTest::InitVit(const std::string& path) {
ASSERT_NE(s_env->GetModel(), nullptr); ASSERT_NE(s_env->GetModel(), nullptr);
Gemma& model = *(s_env->GetModel()); Gemma& model = *(s_env->GetModel());
image_tokens_ = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len, image_tokens_ =
model.GetModelConfig().model_dim)); ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
model.GetModelConfig().model_dim));
Image image; Image image;
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA); HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(path)); HWY_ASSERT(image.ReadPPM(path));
const size_t image_size = model.GetModelConfig().image_size; const size_t image_size = model.GetModelConfig().vit_config.image_size;
image.Resize(image_size, image_size); image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0}; RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0};
model.GenerateImageTokens(runtime_config, image, image_tokens_); model.GenerateImageTokens(runtime_config, image, image_tokens_);