Moved the vit config fields to their own config struct

PiperOrigin-RevId: 715692800
This commit is contained in:
Ray Smith 2025-01-15 01:09:16 -08:00 committed by Copybara-Service
parent 9d40f0117e
commit b93231a47d
10 changed files with 119 additions and 84 deletions

View File

@ -253,18 +253,19 @@ static LayerConfig LayerConfigVit(size_t model_dim) {
// Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config.
static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
config.vit_model_dim = 1152;
config.vit_config.model_dim = 1152;
config.vocab_size = 256000 + 1024 + 128; // = 257152
config.image_size = image_size;
config.patch_width = 14;
config.vit_config.image_size = image_size;
config.vit_config.patch_width = 14;
const size_t num_patches =
config.vit_config.image_size / config.vit_config.patch_width;
config.vit_config.seq_len = num_patches * num_patches;
for (auto& layer_config : config.layer_configs) {
layer_config.optimized_gating = false;
}
const size_t num_patches = config.image_size / config.patch_width;
config.vit_seq_len = num_patches * num_patches;
LayerConfig vit_layer_config = LayerConfigVit(config.vit_model_dim);
config.vit_layer_configs = {27, vit_layer_config};
config.num_vit_scales = 4 * config.vit_layer_configs.size();
LayerConfig vit_layer_config = LayerConfigVit(config.vit_config.model_dim);
config.vit_config.layer_configs = {27, vit_layer_config};
config.vit_config.num_scales = 4 * config.vit_config.layer_configs.size();
}
static ModelConfig ConfigPaliGemma_224() {
@ -283,11 +284,11 @@ static ModelConfig ConfigPaliGemma_448() {
return config;
}
ModelConfig VitConfig(const ModelConfig& config) {
ModelConfig GetVitConfig(const ModelConfig& config) {
ModelConfig vit_config = ConfigNoSSM();
vit_config.model_dim = config.vit_model_dim;
vit_config.seq_len = config.vit_seq_len;
vit_config.layer_configs = config.vit_layer_configs;
vit_config.model_dim = config.vit_config.model_dim;
vit_config.seq_len = config.vit_config.seq_len;
vit_config.layer_configs = config.vit_config.layer_configs;
// The Vit part does not have a vocabulary, the image patches are embedded.
vit_config.vocab_size = 0;
return vit_config;
@ -402,9 +403,28 @@ bool LayerConfig::TestEqual(const LayerConfig& other, bool partial,
return result;
}
bool VitConfig::TestEqual(const VitConfig& other, bool partial,
bool debug) const {
bool result = true;
TEST_EQUAL(model_dim, other.model_dim);
TEST_EQUAL(seq_len, other.seq_len);
if (!partial) {
TEST_EQUAL(num_scales, other.num_scales);
}
TEST_EQUAL(patch_width, other.patch_width);
TEST_EQUAL(image_size, other.image_size);
RETURN_IF_NOT_EQUAL(layer_configs.size(), other.layer_configs.size());
for (size_t i = 0; i < layer_configs.size(); ++i) {
result &=
layer_configs[i].TestEqual(other.layer_configs[i], partial, debug);
}
return result;
}
bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
bool debug) const {
bool result = true;
TEST_EQUAL(model_family_version, other.model_family_version);
// We don't care about model_name, model, wrapping, or weight being different,
// but will output in debug mode if they are.
if (debug) {
@ -415,13 +435,10 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
WARN_IF_NOT_EQUAL(static_cast<int>(weight), static_cast<int>(other.weight));
}
TEST_EQUAL(model_dim, other.model_dim);
TEST_EQUAL(vit_model_dim, other.vit_model_dim);
TEST_EQUAL(vocab_size, other.vocab_size);
TEST_EQUAL(seq_len, other.seq_len);
TEST_EQUAL(vit_seq_len, other.vit_seq_len);
if (!partial) {
TEST_EQUAL(num_tensor_scales, other.num_tensor_scales);
TEST_EQUAL(num_vit_scales, other.num_vit_scales);
}
TEST_EQUAL(att_cap, other.att_cap);
TEST_EQUAL(final_cap, other.final_cap);
@ -439,11 +456,6 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
for (size_t i = 0; i < attention_window_sizes.size(); ++i) {
TEST_EQUAL(attention_window_sizes[i], other.attention_window_sizes[i]);
}
RETURN_IF_NOT_EQUAL(vit_layer_configs.size(), other.vit_layer_configs.size());
for (size_t i = 0; i < vit_layer_configs.size(); ++i) {
result &= vit_layer_configs[i].TestEqual(other.vit_layer_configs[i],
partial, debug);
}
if (!partial) {
if (scale_names != other.scale_names) {
result = false;
@ -453,9 +465,7 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
}
}
TEST_EQUAL(norm_num_groups, other.norm_num_groups);
TEST_EQUAL(model_family_version, other.model_family_version);
TEST_EQUAL(patch_width, other.patch_width);
TEST_EQUAL(image_size, other.image_size);
result &= vit_config.TestEqual(other.vit_config, partial, debug);
return result;
}

View File

@ -220,6 +220,33 @@ struct LayerConfig : public IFields {
PostQKType post_qk = PostQKType::Rope;
};
// Dimensions related to image processing.
struct VitConfig : public IFields {
// Returns true if *this and other are equal.
// If partial is true, then we don't check for items that are only set after
// the tensors are loaded from the checkpoint.
// If debug is true, then we output the mismatched fields to stderr.
bool TestEqual(const VitConfig& other, bool partial, bool debug) const;
const char* Name() const override { return "VitConfig"; }
void VisitFields(IFieldsVisitor& visitor) override {
visitor(model_dim);
visitor(seq_len);
visitor(num_scales);
visitor(patch_width);
visitor(image_size);
visitor(layer_configs);
}
uint32_t model_dim = 0;
uint32_t seq_len = 0;
uint32_t num_scales = 0;
uint32_t patch_width = 14;
uint32_t image_size = 224;
std::vector<LayerConfig> layer_configs;
};
struct ModelConfig : public IFields {
// Returns true if *this and other are equal.
// If partial is true, then we don't check for items that are only set after
@ -277,26 +304,21 @@ struct ModelConfig : public IFields {
visitor(layer_configs);
visitor(attention_window_sizes);
visitor(norm_num_groups);
visitor(vit_model_dim);
visitor(vit_seq_len);
visitor(num_vit_scales);
visitor(vit_layer_configs);
visitor(patch_width);
visitor(image_size);
visitor(vit_config);
}
// Major version of the model family. It is used as a fallback to distinguish
// between model types when there is no explicit information in the config.
uint32_t model_family_version = 1;
std::string model_name;
Model model = Model::UNKNOWN;
PromptWrapping wrapping = PromptWrapping::GEMMA_PT;
Type weight = Type::kUnknown;
uint32_t num_layers = 0;
uint32_t model_dim = 0;
uint32_t vit_model_dim = 0;
uint32_t vocab_size = 0;
uint32_t seq_len = 0;
uint32_t vit_seq_len = 0;
uint32_t num_tensor_scales = 0;
uint32_t num_vit_scales = 0;
float att_cap = 0.0f;
float final_cap = 0.0f;
bool absolute_pe = false;
@ -304,13 +326,10 @@ struct ModelConfig : public IFields {
QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
std::vector<LayerConfig> layer_configs;
std::vector<uint32_t> attention_window_sizes;
std::vector<LayerConfig> vit_layer_configs;
std::unordered_set<std::string> scale_names;
uint32_t norm_num_groups = 1;
uint32_t model_family_version = 1;
// Dimensions related to image processing.
uint32_t patch_width = 14;
uint32_t image_size = 224;
VitConfig vit_config;
};
// Returns the config for the given model.
@ -320,7 +339,7 @@ ModelConfig ConfigFromModel(Model model);
Model ModelFromConfig(const ModelConfig& config);
// Returns the sub-config for the ViT model of the PaliGemma model.
ModelConfig VitConfig(const ModelConfig& config);
ModelConfig GetVitConfig(const ModelConfig& config);
} // namespace gcpp

View File

@ -367,12 +367,13 @@ template <class TConfig>
void AssertMatch(const ModelConfig& config) {
ASSERT_EQ(TConfig::kModelDim, config.model_dim);
if constexpr (TConfig::VitConfig::kModelDim != 0) {
ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_model_dim);
ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_seq_len);
ASSERT_EQ(TConfig::VitConfig::kNumTensorScales, config.num_vit_scales);
for (size_t i = 0; i < config.vit_layer_configs.size(); ++i) {
ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_config.model_dim);
ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_config.seq_len);
ASSERT_EQ(TConfig::VitConfig::kNumTensorScales,
config.vit_config.num_scales);
for (size_t i = 0; i < config.vit_config.layer_configs.size(); ++i) {
ASSERT_EQ(TConfig::VitConfig::kLayerConfig[i],
config.vit_layer_configs[i].type);
config.vit_config.layer_configs[i].type);
}
}
ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);

View File

@ -1042,9 +1042,9 @@ template <typename T>
HWY_NOINLINE void EmbedImagePatches(const Image& image,
const ModelWeightsPtrs<T>& weights,
Activations& activations) {
const size_t model_dim = weights.weights_config.vit_model_dim;
const size_t patch_width = weights.weights_config.patch_width;
const size_t seq_len = weights.weights_config.vit_seq_len;
const size_t model_dim = weights.weights_config.vit_config.model_dim;
const size_t patch_width = weights.weights_config.vit_config.patch_width;
const size_t seq_len = weights.weights_config.vit_config.seq_len;
const size_t patch_size = patch_width * patch_width * 3;
HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==
patch_size * model_dim);
@ -1087,14 +1087,15 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
const Image& image, ImageTokens& image_tokens,
Activations& activations) {
PROFILER_ZONE("Gen.PrefillVit");
const size_t num_tokens = weights.weights_config.vit_seq_len;
const size_t vit_model_dim = weights.weights_config.vit_model_dim;
const size_t num_tokens = weights.weights_config.vit_config.seq_len;
const size_t vit_model_dim = weights.weights_config.vit_config.model_dim;
HWY_ASSERT(num_tokens == activations.x.BatchSize());
// Embed the image patches.
EmbedImagePatches(image, weights, activations);
// Go through all layers.
for (size_t layer = 0;
layer < weights.weights_config.vit_layer_configs.size(); ++layer) {
layer < weights.weights_config.vit_config.layer_configs.size();
++layer) {
const auto* layer_weights = weights.GetVitLayer(layer);
VitTransformerLayer(num_tokens, layer, layer_weights, activations);
}
@ -1413,11 +1414,11 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens,
NestedPools& pools) {
if (model.Config().vit_layer_configs.empty()) {
if (model.Config().vit_config.layer_configs.empty()) {
HWY_ABORT("Model does not support generating image tokens.");
}
RuntimeConfig prefill_runtime_config = runtime_config;
ModelConfig vit_config = VitConfig(model.Config());
ModelConfig vit_config = GetVitConfig(model.Config());
prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len;
Activations prefill_activations(vit_config);
prefill_activations.Allocate(vit_config.seq_len, pools);

View File

@ -94,11 +94,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
Image image;
ImageTokens image_tokens;
if (have_image) {
image_tokens = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
image_tokens =
ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
model.GetModelConfig().model_dim));
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(args.image_file.path));
const size_t image_size = model.GetModelConfig().image_size;
const size_t image_size = model.GetModelConfig().vit_config.image_size;
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {
.gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin};

View File

@ -36,29 +36,29 @@ std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
.name = "enc_norm_bias",
.source_names = {"img/Transformer/encoder_norm/bias"},
.axes = {0},
.shape = {config.vit_model_dim},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "enc_norm_scale",
.source_names = {"img/Transformer/encoder_norm/scale"},
.axes = {0},
.shape = {config.vit_model_dim},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "img_emb_bias",
.source_names = {"img/embedding/bias"},
.axes = {0},
.shape = {config.vit_model_dim},
.shape = {config.vit_config.model_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "img_emb_kernel",
.source_names = {"img/embedding/kernel"},
.axes = {3, 0, 1, 2},
.shape = {config.vit_model_dim, config.patch_width,
config.patch_width, 3},
.shape = {config.vit_config.model_dim, config.vit_config.patch_width,
config.vit_config.patch_width, 3},
.min_size = Type::kBF16,
.cols_take_extra_dims = true,
},
@ -73,14 +73,15 @@ std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
.name = "img_head_kernel",
.source_names = {"img/head/kernel"},
.axes = {1, 0},
.shape = {config.model_dim, config.vit_model_dim},
.shape = {config.model_dim, config.vit_config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "img_pos_emb",
.source_names = {"img/pos_embedding"},
.axes = {0, 1},
.shape = {/*1,*/ config.vit_seq_len, config.vit_model_dim},
.shape = {/*1,*/ config.vit_config.seq_len,
config.vit_config.model_dim},
.min_size = Type::kF32,
},
};
@ -95,7 +96,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
.name = "attn_out_w",
.source_names = {"MultiHeadDotProductAttention_0/out/kernel"},
.axes = {2, 0, 1},
.shape = {config.vit_model_dim, layer_config.heads,
.shape = {config.vit_config.model_dim, layer_config.heads,
layer_config.qkv_dim},
.min_size = Type::kBF16,
.cols_take_extra_dims = true,
@ -104,7 +105,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
.name = "attn_out_b",
.source_names = {"MultiHeadDotProductAttention_0/out/bias"},
.axes = {0},
.shape = {config.vit_model_dim},
.shape = {config.vit_config.model_dim},
.min_size = Type::kF32,
},
TensorInfo{
@ -112,7 +113,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
.source_names = {"MultiHeadDotProductAttention_0/query/kernel"},
.axes = {1, 2, 0},
.shape = {layer_config.heads, layer_config.qkv_dim,
config.vit_model_dim},
config.vit_config.model_dim},
.concat_names = {"qkv_ein_w", "k_ein_w", "v_ein_w"},
.concat_axis = 1,
.min_size = Type::kBF16,
@ -122,7 +123,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
.source_names = {"MultiHeadDotProductAttention_0/key/kernel"},
.axes = {1, 2, 0},
.shape = {layer_config.heads, layer_config.qkv_dim,
config.vit_model_dim},
config.vit_config.model_dim},
.concat_names = {""},
.min_size = Type::kBF16,
},
@ -131,7 +132,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
.source_names = {"MultiHeadDotProductAttention_0/value/kernel"},
.axes = {1, 2, 0},
.shape = {layer_config.heads, layer_config.qkv_dim,
config.vit_model_dim},
config.vit_config.model_dim},
.concat_names = {""},
.min_size = Type::kBF16,
},
@ -140,7 +141,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
.source_names = {"MultiHeadDotProductAttention_0/qkv/kernel"},
.axes = {1, 2, 0},
.shape = {layer_config.heads, 3 * layer_config.qkv_dim,
config.vit_model_dim},
config.vit_config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
@ -180,7 +181,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
.name = "linear_0_w",
.source_names = {"MlpBlock_0/Dense_0/kernel"},
.axes = {1, 0},
.shape = {layer_config.ff_hidden_dim, config.vit_model_dim},
.shape = {layer_config.ff_hidden_dim, config.vit_config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
@ -194,42 +195,42 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
.name = "linear_1_w",
.source_names = {"MlpBlock_0/Dense_1/kernel"},
.axes = {1, 0},
.shape = {config.vit_model_dim, layer_config.ff_hidden_dim},
.shape = {config.vit_config.model_dim, layer_config.ff_hidden_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "linear_1_b",
.source_names = {"MlpBlock_0/Dense_1/bias"},
.axes = {0},
.shape = {config.vit_model_dim},
.shape = {config.vit_config.model_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "ln_0_bias",
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias"},
.axes = {0},
.shape = {config.vit_model_dim},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "ln_0_scale",
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale"},
.axes = {0},
.shape = {config.vit_model_dim},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "ln_1_bias",
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias"},
.axes = {0},
.shape = {config.vit_model_dim},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "ln_1_scale",
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale"},
.axes = {0},
.shape = {config.vit_model_dim},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
},
};
@ -526,8 +527,8 @@ TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx,
if (llm_layer_idx < 0 && img_layer_idx < 0) {
tensors_ = ModelTensors(config);
} else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx &&
img_layer_idx < config.vit_layer_configs.size()) {
const auto& layer_config = config.vit_layer_configs[img_layer_idx];
img_layer_idx < config.vit_config.layer_configs.size()) {
const auto& layer_config = config.vit_config.layer_configs[img_layer_idx];
tensors_ = ImageLayerTensors(config, layer_config);
} else if (0 <= llm_layer_idx &&
llm_layer_idx < config.layer_configs.size()) {

View File

@ -35,7 +35,7 @@ TEST(TensorIndexTest, FindName) {
/*split_and_reshape=*/false);
}
for (size_t img_layer_idx = 0;
img_layer_idx < config.vit_layer_configs.size();
img_layer_idx < config.vit_config.layer_configs.size();
++img_layer_idx) {
tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1,
static_cast<int>(img_layer_idx),

View File

@ -86,7 +86,7 @@ BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type,
config_ = ConfigFromModel(model_type);
config_.weight = weight_type;
config_.wrapping = wrapping;
scales.resize(config_.num_tensor_scales + config_.num_vit_scales);
scales.resize(config_.num_tensor_scales + config_.vit_config.num_scales);
}
CreateForType(config_.weight, pool);
CallForModelWeightT<TensorLoader>(fet, loader);

View File

@ -344,8 +344,9 @@ struct ModelWeightsPtrs {
c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config, tensor_index));
}
for (int index = 0;
index < static_cast<int>(config.vit_layer_configs.size()); ++index) {
const auto& layer_config = config.vit_layer_configs[index];
index < static_cast<int>(config.vit_config.layer_configs.size());
++index) {
const auto& layer_config = config.vit_config.layer_configs[index];
TensorIndex tensor_index(config, /*llm_layer_idx=*/-1, index,
/*reshape_att=*/false);
vit_layers.push_back(
@ -479,7 +480,7 @@ struct ModelWeightsPtrs {
int sep_index = -1;
GEMMA_CALL_FUNC(embedder_input_embedding);
GEMMA_CALL_FUNC(final_norm_scale);
if (ptrs[0]->weights_config.vit_layer_configs.size() > 0) {
if (ptrs[0]->weights_config.vit_config.layer_configs.size() > 0) {
// Vit parts.
GEMMA_CALL_FUNC(vit_encoder_norm_bias);
GEMMA_CALL_FUNC(vit_encoder_norm_scale);
@ -498,7 +499,7 @@ struct ModelWeightsPtrs {
}
// Vit layers. Not supported for compress_weights.
if (ptrs[0]->weights_config.vit_layer_configs.size() > 0) {
if (ptrs[0]->weights_config.vit_config.layer_configs.size() > 0) {
for (int layer_idx = 0; layer_idx < ptrs[0]->vit_layers.size();
++layer_idx) {
auto type = ptrs[0]->vit_layers[layer_idx].layer_config.type;

View File

@ -50,12 +50,13 @@ class PaliGemmaTest : public ::testing::Test {
void PaliGemmaTest::InitVit(const std::string& path) {
ASSERT_NE(s_env->GetModel(), nullptr);
Gemma& model = *(s_env->GetModel());
image_tokens_ = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
image_tokens_ =
ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
model.GetModelConfig().model_dim));
Image image;
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(path));
const size_t image_size = model.GetModelConfig().image_size;
const size_t image_size = model.GetModelConfig().vit_config.image_size;
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0};
model.GenerateImageTokens(runtime_config, image, image_tokens_);