diff --git a/gemma/configs.cc b/gemma/configs.cc
index 03fce99..7724c59 100644
--- a/gemma/configs.cc
+++ b/gemma/configs.cc
@@ -15,6 +15,8 @@
 
 #include "gemma/configs.h"
 
+#include <iostream>
+
 #include "hwy/base.h"
 
 namespace gcpp {
@@ -181,6 +183,7 @@ static ModelConfig ConfigGriffin2B() {
       .conv1d_width = 4,
       .ff_biases = true,
       .softmax_attn_output_biases = true,
+      .optimized_gating = false,
       .type = LayerAttentionType::kGriffinRecurrentBlock,
       .activation = ActivationType::Gelu,
       .post_qk = PostQKType::HalfRope,
@@ -204,6 +207,9 @@ static void AddVitConfig(ModelConfig& config) {
   config.vocab_size = 256000 + 1024 + 128;  // = 257152
   config.image_size = 224;
   config.patch_width = 14;
+  for (auto& layer_config : config.layer_configs) {
+    layer_config.optimized_gating = false;
+  }
   const size_t num_patches = config.image_size / config.patch_width;
   config.vit_seq_len = num_patches * num_patches;
   LayerConfig vit_layer_config = {
@@ -260,4 +266,115 @@ ModelConfig ConfigFromModel(Model model) {
   }
 }
 
+#define TEST_EQUAL(a, b)                                                 \
+  if (a != b) {                                                          \
+    if (debug)                                                           \
+      std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n";   \
+    result = false;                                                      \
+  }
+
+#define RETURN_IF_NOT_EQUAL(a, b)                                        \
+  if (a != b) {                                                          \
+    if (debug)                                                           \
+      std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n";   \
+    return false;                                                        \
+  }
+
+#define WARN_IF_NOT_EQUAL(a, b)                                          \
+  if (a != b) {                                                          \
+    std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n";     \
+  }
+
+bool LayerConfig::TestEqual(const LayerConfig& other, bool partial,
+                            bool debug) const {
+  bool result = true;
+  // Optimized gating may not be set correctly in the C++ configs.
+  if (debug) {
+    WARN_IF_NOT_EQUAL(optimized_gating, other.optimized_gating)
+  }
+  TEST_EQUAL(model_dim, other.model_dim);
+  TEST_EQUAL(griffin_dim, other.griffin_dim);
+  TEST_EQUAL(ff_hidden_dim, other.ff_hidden_dim);
+  TEST_EQUAL(heads, other.heads);
+  TEST_EQUAL(kv_heads, other.kv_heads);
+  TEST_EQUAL(qkv_dim, other.qkv_dim);
+  TEST_EQUAL(conv1d_width, other.conv1d_width);
+  if (!partial) {
+    TEST_EQUAL(ff_biases, other.ff_biases);
+    TEST_EQUAL(softmax_attn_output_biases, other.softmax_attn_output_biases);
+  }
+  TEST_EQUAL(static_cast<int>(post_norm), static_cast<int>(other.post_norm));
+  TEST_EQUAL(static_cast<int>(type), static_cast<int>(other.type));
+  TEST_EQUAL(static_cast<int>(activation), static_cast<int>(other.activation));
+  TEST_EQUAL(static_cast<int>(post_qk), static_cast<int>(other.post_qk));
+  return result;
+}
+
+bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
+                            bool debug) const {
+  bool result = true;
+  // We don't care about model_name, model, training, or weight being different,
+  // but will output in debug mode if they are.
+  if (debug) {
+    WARN_IF_NOT_EQUAL(model_name, other.model_name);
+    WARN_IF_NOT_EQUAL(static_cast<int>(model), static_cast<int>(other.model));
+    WARN_IF_NOT_EQUAL(static_cast<int>(training),
+                      static_cast<int>(other.training));
+    WARN_IF_NOT_EQUAL(static_cast<int>(weight), static_cast<int>(other.weight));
+  }
+  TEST_EQUAL(model_dim, other.model_dim);
+  TEST_EQUAL(vit_model_dim, other.vit_model_dim);
+  TEST_EQUAL(vocab_size, other.vocab_size);
+  TEST_EQUAL(seq_len, other.seq_len);
+  TEST_EQUAL(vit_seq_len, other.vit_seq_len);
+  if (!partial) {
+    TEST_EQUAL(num_tensor_scales, other.num_tensor_scales);
+    TEST_EQUAL(num_vit_scales, other.num_vit_scales);
+  }
+  TEST_EQUAL(att_cap, other.att_cap);
+  TEST_EQUAL(final_cap, other.final_cap);
+  TEST_EQUAL(absolute_pe, other.absolute_pe);
+  TEST_EQUAL(use_local_attention, other.use_local_attention);
+  TEST_EQUAL(static_cast<int>(query_scale),
+             static_cast<int>(other.query_scale));
+  RETURN_IF_NOT_EQUAL(layer_configs.size(), other.layer_configs.size());
+  for (size_t i = 0; i < layer_configs.size(); ++i) {
+    result &=
+        layer_configs[i].TestEqual(other.layer_configs[i], partial, debug);
+  }
+  RETURN_IF_NOT_EQUAL(attention_window_sizes.size(),
+                      other.attention_window_sizes.size());
+  for (size_t i = 0; i < attention_window_sizes.size(); ++i) {
+    TEST_EQUAL(attention_window_sizes[i], other.attention_window_sizes[i]);
+  }
+  RETURN_IF_NOT_EQUAL(vit_layer_configs.size(), other.vit_layer_configs.size());
+  for (size_t i = 0; i < vit_layer_configs.size(); ++i) {
+    result &= vit_layer_configs[i].TestEqual(other.vit_layer_configs[i],
+                                             partial, debug);
+  }
+  if (!partial) {
+    if (scale_names != other.scale_names) {
+      result = false;
+      if (debug) {
+        std::cerr << "scale_names mismatch\n";
+      }
+    }
+  }
+  TEST_EQUAL(norm_num_groups, other.norm_num_groups);
+  TEST_EQUAL(model_family_version, other.model_family_version);
+  TEST_EQUAL(patch_width, other.patch_width);
+  TEST_EQUAL(image_size, other.image_size);
+  return result;
+}
+
+Model ModelFromConfig(const ModelConfig& config) {
+  for (Model model : kAllModels) {
+    ModelConfig model_config = ConfigFromModel(model);
+    if (config.TestEqual(model_config, /*partial=*/true, /*debug=*/false)) {
+      return model;
+    }
+  }
+  return Model::UNKNOWN;
+}
+
 }  // namespace gcpp
diff --git a/gemma/configs.h b/gemma/configs.h
index e709df7..6bbbc45 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -116,7 +116,20 @@ enum class Model {
   PALIGEMMA_224,
 };
 
+// Allows the Model enum to be iterated over.
+static constexpr Model kAllModels[] = {
+    Model::GEMMA_2B, Model::GEMMA_7B, Model::GEMMA2_9B, Model::GEMMA2_27B,
+    Model::GRIFFIN_2B, Model::GEMMA_TINY, Model::GEMMA2_2B,
+    Model::PALIGEMMA_224,
+};
+
 struct LayerConfig {
+  // Returns true if *this and other are equal.
+  // If partial is true, then we don't check for items that are only set after
+  // the tensors are loaded from the checkpoint.
+  // If debug is true, then we output the mismatched fields to stderr.
+  bool TestEqual(const LayerConfig& other, bool partial, bool debug) const;
+
   size_t CacheLayerSize() const { return kv_heads * qkv_dim * 2; }
 
   // Multi-Head Attention?
@@ -132,9 +145,10 @@ struct LayerConfig {
   size_t heads = 0;
   size_t kv_heads = 0;
   size_t qkv_dim = 0;
-  size_t conv1d_width = 0;
+  size_t conv1d_width = 0;  // griffin only
   bool ff_biases = false;
   bool softmax_attn_output_biases = false;
+  bool optimized_gating = true;
   PostNormType post_norm = PostNormType::None;
   LayerAttentionType type = LayerAttentionType::kGemma;
   ActivationType activation = ActivationType::Gelu;
@@ -142,6 +156,16 @@
 };
 
 struct ModelConfig {
+  // Returns true if *this and other are equal.
+  // If partial is true, then we don't check for items that are only set after
+  // the tensors are loaded from the checkpoint.
+  // If debug is true, then we output the mismatched fields to stderr.
+  bool TestEqual(const ModelConfig& other, bool partial, bool debug) const;
+
+  void AddLayerConfig(const LayerConfig& layer_config) {
+    layer_configs.push_back(layer_config);
+  }
+
   size_t CachePosSize() const {
     size_t num_layers = layer_configs.size();
     return num_layers * layer_configs[0].CacheLayerSize();
@@ -171,6 +195,7 @@
   Model model;
   ModelTraining training;
   Type weight;
+  size_t num_layers = 0;
   size_t model_dim = 0;
   size_t vit_model_dim = 0;
   size_t vocab_size = 0;
@@ -181,7 +206,7 @@
   float att_cap = 0.0f;
   float final_cap = 0.0f;
   bool absolute_pe = false;
-  bool use_local_attention = false;
+  bool use_local_attention = false;  // griffin only
   QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
   std::vector<LayerConfig> layer_configs;
   std::vector<size_t> attention_window_sizes;
@@ -190,13 +215,16 @@
   int norm_num_groups = 1;
   int model_family_version = 1;
   // Dimensions related to image processing.
-  int patch_width = 14;
-  int image_size = 224;
+  size_t patch_width = 14;
+  size_t image_size = 224;
 };
 
 // Returns the config for the given model.
 ModelConfig ConfigFromModel(Model model);
 
+// Returns the model for the given config, if it matches any standard model.
+Model ModelFromConfig(const ModelConfig& config);
+
 // Returns the sub-config for the ViT model of the PaliGemma model.
 ModelConfig VitConfig(const ModelConfig& config);
 
diff --git a/gemma/python/BUILD.bazel b/gemma/python/BUILD.bazel
new file mode 100644
index 0000000..d6b09b9
--- /dev/null
+++ b/gemma/python/BUILD.bazel
@@ -0,0 +1,17 @@
+load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
+
+package(
+    default_applicable_licenses = [
+        "//:license",  # Placeholder comment, do not modify
+    ],
+    default_visibility = ["//visibility:public"],
+)
+
+pybind_extension(
+    name = "configs",
+    srcs = ["configs.cc"],
+    deps = [
+        "//:common",
+        "//compression:sfp",
+    ],
+)
diff --git a/gemma/python/configs.cc b/gemma/python/configs.cc
new file mode 100644
index 0000000..aff93cc
--- /dev/null
+++ b/gemma/python/configs.cc
@@ -0,0 +1,136 @@
+#include "gemma/configs.h"
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include "compression/shared.h"
+#include "pybind11/cast.h"
+
+using gcpp::ActivationType;
+using gcpp::LayerAttentionType;
+using gcpp::LayerConfig;
+using gcpp::Model;
+using gcpp::ModelConfig;
+using gcpp::ModelTraining;
+using gcpp::PostNormType;
+using gcpp::PostQKType;
+using gcpp::QueryScaleType;
+using gcpp::ResidualType;
+using gcpp::Type;
+
+namespace pybind11 {
+
+PYBIND11_MODULE(configs, py_module) {
+  enum_<ModelTraining>(py_module, "ModelTraining")
+      .value("GEMMA_IT", ModelTraining::GEMMA_IT)
+      .value("GEMMA_PT", ModelTraining::GEMMA_PT)
+      .value("PALIGEMMA", ModelTraining::PALIGEMMA);
+
+  enum_<Type>(py_module, "Type")
+      .value("kUnknown", Type::kUnknown)
+      .value("kF32", Type::kF32)
+      .value("kBF16", Type::kBF16)
+      .value("kSFP", Type::kSFP)
+      .value("kNUQ", Type::kNUQ)
+      .value("kF64", Type::kF64)
+      .value("kC64", Type::kC64)
+      .value("kU128", Type::kU128);
+
+  enum_<LayerAttentionType>(py_module, "LayerAttentionType")
+      .value("kGemma", LayerAttentionType::kGemma)
+      .value("kGriffinRecurrentBlock",
+             LayerAttentionType::kGriffinRecurrentBlock)
+      .value("kVit", LayerAttentionType::kVit);
+
+  enum_<PostNormType>(py_module, "PostNormType")
+      .value("NoPostNorm", PostNormType::None)
+      .value("Scale", PostNormType::Scale);
+
+  enum_<PostQKType>(py_module, "PostQKType")
+      .value("Rope", PostQKType::Rope)
+      .value("HalfRope", PostQKType::HalfRope);
+
+  enum_<ActivationType>(py_module, "ActivationType")
+      .value("Gelu", ActivationType::Gelu);
+
+  enum_<QueryScaleType>(py_module, "QueryScaleType")
+      .value("SqrtKeySize", QueryScaleType::SqrtKeySize)
+      .value("SqrtModelDimDivNumHeads",
+             QueryScaleType::SqrtModelDimDivNumHeads);
+
+  enum_<ResidualType>(py_module, "ResidualType")
+      .value("Add", ResidualType::Add);
+
+  enum_<Model>(py_module, "Model")
+      .value("UNKNOWN", Model::UNKNOWN)
+      .value("GEMMA_2B", Model::GEMMA_2B)
+      .value("GEMMA_7B", Model::GEMMA_7B)
+      .value("GEMMA2_9B", Model::GEMMA2_9B)
+      .value("GEMMA2_27B", Model::GEMMA2_27B)
+      .value("GRIFFIN_2B", Model::GRIFFIN_2B)
+      .value("GEMMA_TINY", Model::GEMMA_TINY)
+      .value("GEMMA2_2B", Model::GEMMA2_2B)
+      .value("PALIGEMMA_224", Model::PALIGEMMA_224);
+
+  class_<LayerConfig>(py_module, "LayerConfig")
+      .def(init<>())
+      .def_readwrite("model_dim", &LayerConfig::model_dim)
+      .def_readwrite("griffin_dim", &LayerConfig::griffin_dim)
+      .def_readwrite("ff_hidden_dim", &LayerConfig::ff_hidden_dim)
+      .def_readwrite("heads", &LayerConfig::heads)
+      .def_readwrite("kv_heads", &LayerConfig::kv_heads)
+      .def_readwrite("qkv_dim", &LayerConfig::qkv_dim)
+      .def_readwrite("conv1d_width", &LayerConfig::conv1d_width)
+      .def_readwrite("ff_biases", &LayerConfig::ff_biases)
+      .def_readwrite("softmax_attn_output_biases",
+                     &LayerConfig::softmax_attn_output_biases)
+      .def_readwrite("optimized_gating", &LayerConfig::optimized_gating)
+      .def_readwrite("post_norm", &LayerConfig::post_norm)
.def_readwrite("type", &LayerConfig::type) + .def_readwrite("activation", &LayerConfig::activation) + .def_readwrite("post_qk", &LayerConfig::post_qk); + + class_(py_module, "ModelConfig") + .def(init()) + .def_readwrite("model_name", &ModelConfig::model_name) + .def_readwrite("model", &ModelConfig::model) + .def_readwrite("training", &ModelConfig::training) + .def_readwrite("weight", &ModelConfig::weight) + .def_readwrite("num_layers", &ModelConfig::num_layers) + .def_readwrite("model_dim", &ModelConfig::model_dim) + .def_readwrite("vit_model_dim", &ModelConfig::vit_model_dim) + .def_readwrite("vocab_size", &ModelConfig::vocab_size) + .def_readwrite("seq_len", &ModelConfig::seq_len) + .def_readwrite("vit_seq_len", &ModelConfig::vit_seq_len) + .def_readwrite("num_tensor_scales", &ModelConfig::num_tensor_scales) + .def_readwrite("num_vit_scales", &ModelConfig::num_vit_scales) + .def_readwrite("att_cap", &ModelConfig::att_cap) + .def_readwrite("final_cap", &ModelConfig::final_cap) + .def_readwrite("absolute_pe", &ModelConfig::absolute_pe) + .def_readwrite("use_local_attention", &ModelConfig::use_local_attention) + .def_readwrite("query_scale", &ModelConfig::query_scale) + .def_readwrite("layer_configs", &ModelConfig::layer_configs) + .def_readwrite("attention_window_sizes", + &ModelConfig::attention_window_sizes) + .def_readwrite("vit_layer_configs", &ModelConfig::vit_layer_configs) + .def_readwrite("scale_names", &ModelConfig::scale_names) + .def_readwrite("norm_num_groups", &ModelConfig::norm_num_groups) + .def_readwrite("model_family_version", &ModelConfig::model_family_version) + .def_readwrite("patch_width", &ModelConfig::patch_width) + .def_readwrite("image_size", &ModelConfig::image_size) + .def("add_layer_config", &ModelConfig::AddLayerConfig, + arg("layer_config")) + .def("test_equal", &ModelConfig::TestEqual, arg("other"), arg("partial"), + arg("debug")); + + // Returns the config for the given model. + py_module.def("config_from_model", &gcpp::ConfigFromModel, arg("model")); + + // Returns the model for the given config, if it matches any standard model. + py_module.def("model_from_config", &gcpp::ModelFromConfig, arg("config")); + + // Returns the sub-config for the ViT model of the PaliGemma model. + py_module.def("vit_config", &gcpp::VitConfig, arg("config")); +} + +} // namespace pybind11