mirror of https://github.com/google/gemma.cpp.git
Added pybind for configs.

Added ability to test configs for equality.

PiperOrigin-RevId: 697572671
parent 36f02ef892
commit 7d685a267f

gemma/configs.cc | 117
@@ -15,6 +15,8 @@
 #include "gemma/configs.h"

+#include <iostream>
+
 #include "hwy/base.h"

 namespace gcpp {
@@ -181,6 +183,7 @@ static ModelConfig ConfigGriffin2B() {
       .conv1d_width = 4,
       .ff_biases = true,
       .softmax_attn_output_biases = true,
+      .optimized_gating = false,
       .type = LayerAttentionType::kGriffinRecurrentBlock,
       .activation = ActivationType::Gelu,
       .post_qk = PostQKType::HalfRope,
@@ -204,6 +207,9 @@ static void AddVitConfig(ModelConfig& config) {
   config.vocab_size = 256000 + 1024 + 128;  // = 257152
   config.image_size = 224;
   config.patch_width = 14;
+  for (auto& layer_config : config.layer_configs) {
+    layer_config.optimized_gating = false;
+  }
   const size_t num_patches = config.image_size / config.patch_width;
   config.vit_seq_len = num_patches * num_patches;
   LayerConfig vit_layer_config = {
@@ -260,4 +266,115 @@ ModelConfig ConfigFromModel(Model model) {
   }
 }

+#define TEST_EQUAL(a, b)                                                \
+  if (a != b) {                                                         \
+    if (debug)                                                          \
+      std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n";  \
+    result = false;                                                     \
+  }
+
+#define RETURN_IF_NOT_EQUAL(a, b)                                       \
+  if (a != b) {                                                         \
+    if (debug)                                                          \
+      std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n";  \
+    return false;                                                       \
+  }
+
+#define WARN_IF_NOT_EQUAL(a, b)                                         \
+  if (a != b) {                                                         \
+    std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n";    \
+  }
+
+bool LayerConfig::TestEqual(const LayerConfig& other, bool partial,
+                            bool debug) const {
+  bool result = true;
+  // Optimized gating may not be set correctly in the c++ configs.
+  if (debug) {
+    WARN_IF_NOT_EQUAL(optimized_gating, other.optimized_gating)
+  }
+  TEST_EQUAL(model_dim, other.model_dim);
+  TEST_EQUAL(griffin_dim, other.griffin_dim);
+  TEST_EQUAL(ff_hidden_dim, other.ff_hidden_dim);
+  TEST_EQUAL(heads, other.heads);
+  TEST_EQUAL(kv_heads, other.kv_heads);
+  TEST_EQUAL(qkv_dim, other.qkv_dim);
+  TEST_EQUAL(conv1d_width, other.conv1d_width);
+  if (!partial) {
+    TEST_EQUAL(ff_biases, other.ff_biases);
+    TEST_EQUAL(softmax_attn_output_biases, other.softmax_attn_output_biases);
+  }
+  TEST_EQUAL(static_cast<int>(post_norm), static_cast<int>(other.post_norm));
+  TEST_EQUAL(static_cast<int>(type), static_cast<int>(other.type));
+  TEST_EQUAL(static_cast<int>(activation), static_cast<int>(other.activation));
+  TEST_EQUAL(static_cast<int>(post_qk), static_cast<int>(other.post_qk));
+  return result;
+}
+
+bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
+                            bool debug) const {
+  bool result = true;
+  // We don't care about model_name, model, training, or weight being
+  // different, but will output in debug mode if they are.
+  if (debug) {
+    WARN_IF_NOT_EQUAL(model_name, other.model_name);
+    WARN_IF_NOT_EQUAL(static_cast<int>(model), static_cast<int>(other.model));
+    WARN_IF_NOT_EQUAL(static_cast<int>(training),
+                      static_cast<int>(other.training));
+    WARN_IF_NOT_EQUAL(static_cast<int>(weight), static_cast<int>(other.weight));
+  }
+  TEST_EQUAL(model_dim, other.model_dim);
+  TEST_EQUAL(vit_model_dim, other.vit_model_dim);
+  TEST_EQUAL(vocab_size, other.vocab_size);
+  TEST_EQUAL(seq_len, other.seq_len);
+  TEST_EQUAL(vit_seq_len, other.vit_seq_len);
+  if (!partial) {
+    TEST_EQUAL(num_tensor_scales, other.num_tensor_scales);
+    TEST_EQUAL(num_vit_scales, other.num_vit_scales);
+  }
+  TEST_EQUAL(att_cap, other.att_cap);
+  TEST_EQUAL(final_cap, other.final_cap);
+  TEST_EQUAL(absolute_pe, other.absolute_pe);
+  TEST_EQUAL(use_local_attention, other.use_local_attention);
+  TEST_EQUAL(static_cast<int>(query_scale),
+             static_cast<int>(other.query_scale));
+  RETURN_IF_NOT_EQUAL(layer_configs.size(), other.layer_configs.size());
+  for (size_t i = 0; i < layer_configs.size(); ++i) {
+    result &=
+        layer_configs[i].TestEqual(other.layer_configs[i], partial, debug);
+  }
+  RETURN_IF_NOT_EQUAL(attention_window_sizes.size(),
+                      other.attention_window_sizes.size());
+  for (size_t i = 0; i < attention_window_sizes.size(); ++i) {
+    TEST_EQUAL(attention_window_sizes[i], other.attention_window_sizes[i]);
+  }
+  RETURN_IF_NOT_EQUAL(vit_layer_configs.size(), other.vit_layer_configs.size());
+  for (size_t i = 0; i < vit_layer_configs.size(); ++i) {
+    result &= vit_layer_configs[i].TestEqual(other.vit_layer_configs[i],
+                                             partial, debug);
+  }
+  if (!partial) {
+    if (scale_names != other.scale_names) {
+      result = false;
+      if (debug) {
+        std::cerr << "scale_names mismatch\n";
+      }
+    }
+  }
+  TEST_EQUAL(norm_num_groups, other.norm_num_groups);
+  TEST_EQUAL(model_family_version, other.model_family_version);
+  TEST_EQUAL(patch_width, other.patch_width);
+  TEST_EQUAL(image_size, other.image_size);
+  return result;
+}
+
+Model ModelFromConfig(const ModelConfig& config) {
+  for (Model model : kAllModels) {
+    ModelConfig model_config = ConfigFromModel(model);
+    if (config.TestEqual(model_config, /*partial=*/true, /*debug=*/false)) {
+      return model;
+    }
+  }
+  return Model::UNKNOWN;
+}
+
 }  // namespace gcpp
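The partial and debug flags drive everything above: partial=true skips fields that are only known after checkpoint tensors are loaded (num_tensor_scales, ff_biases, and so on), which is exactly what lets ModelFromConfig match a config built elsewhere against the canonical ones. A minimal sketch of the intended behavior, using the Python bindings added later in this commit and assuming the built extension is importable as configs:

# Sketch only: assumes the pybind11 extension added below is built and
# importable as `configs`; the real import path depends on the build setup.
import configs

a = configs.config_from_model(configs.Model.GRIFFIN_2B)
b = configs.config_from_model(configs.Model.GRIFFIN_2B)

# num_tensor_scales is only set after tensors are loaded from a checkpoint,
# so it is skipped when partial=True.
b.num_tensor_scales = 140

assert a.test_equal(b, partial=True, debug=False)      # mismatch ignored
assert not a.test_equal(b, partial=False, debug=True)  # mismatch logged

# ModelFromConfig relies on the same partial matching:
assert configs.model_from_config(b) == configs.Model.GRIFFIN_2B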

gemma/configs.h
@@ -116,7 +116,20 @@ enum class Model {
   PALIGEMMA_224,
 };

+// Allows the Model enum to be iterated over.
+static constexpr Model kAllModels[] = {
+    Model::GEMMA_2B,   Model::GEMMA_7B,   Model::GEMMA2_9B, Model::GEMMA2_27B,
+    Model::GRIFFIN_2B, Model::GEMMA_TINY, Model::GEMMA2_2B,
+    Model::PALIGEMMA_224,
+};
+
 struct LayerConfig {
+  // Returns true if *this and other are equal.
+  // If partial is true, then we don't check for items that are only set after
+  // the tensors are loaded from the checkpoint.
+  // If debug is true, then we output the mismatched fields to stderr.
+  bool TestEqual(const LayerConfig& other, bool partial, bool debug) const;
+
   size_t CacheLayerSize() const { return kv_heads * qkv_dim * 2; }

   // Multi-Head Attention?
@@ -132,9 +145,10 @@ struct LayerConfig {
   size_t heads = 0;
   size_t kv_heads = 0;
   size_t qkv_dim = 0;
-  size_t conv1d_width = 0;
+  size_t conv1d_width = 0;  // griffin only
   bool ff_biases = false;
   bool softmax_attn_output_biases = false;
+  bool optimized_gating = true;
   PostNormType post_norm = PostNormType::None;
   LayerAttentionType type = LayerAttentionType::kGemma;
   ActivationType activation = ActivationType::Gelu;
@@ -142,6 +156,16 @@ struct LayerConfig {
 };

 struct ModelConfig {
+  // Returns true if *this and other are equal.
+  // If partial is true, then we don't check for items that are only set after
+  // the tensors are loaded from the checkpoint.
+  // If debug is true, then we output the mismatched fields to stderr.
+  bool TestEqual(const ModelConfig& other, bool partial, bool debug) const;
+
+  void AddLayerConfig(const LayerConfig& layer_config) {
+    layer_configs.push_back(layer_config);
+  }
+
   size_t CachePosSize() const {
     size_t num_layers = layer_configs.size();
     return num_layers * layer_configs[0].CacheLayerSize();
@@ -171,6 +195,7 @@ struct ModelConfig {
   Model model;
   ModelTraining training;
   Type weight;
   size_t num_layers = 0;
   size_t model_dim = 0;
   size_t vit_model_dim = 0;
   size_t vocab_size = 0;
@@ -181,7 +206,7 @@ struct ModelConfig {
   float att_cap = 0.0f;
   float final_cap = 0.0f;
   bool absolute_pe = false;
-  bool use_local_attention = false;
+  bool use_local_attention = false;  // griffin only
   QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
   std::vector<LayerConfig> layer_configs;
   std::vector<size_t> attention_window_sizes;
@@ -190,13 +215,16 @@ struct ModelConfig {
   int norm_num_groups = 1;
   int model_family_version = 1;
   // Dimensions related to image processing.
-  int patch_width = 14;
-  int image_size = 224;
+  size_t patch_width = 14;
+  size_t image_size = 224;
 };

 // Returns the config for the given model.
 ModelConfig ConfigFromModel(Model model);

+// Returns the model for the given config, if it matches any standard model.
+Model ModelFromConfig(const ModelConfig& config);
+
 // Returns the sub-config for the ViT model of the PaliGemma model.
 ModelConfig VitConfig(const ModelConfig& config);
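The header now exposes the full round trip: ConfigFromModel builds the canonical config and ModelFromConfig recovers the model via partial TestEqual over kAllModels. A sketch of that property from Python (same import-path assumption as above):

# Round-trip every standard model through its canonical config and back,
# assuming the extension is importable as `configs`.
import configs

for model in configs.Model.__members__.values():
    if model == configs.Model.UNKNOWN:
        continue
    cfg = configs.config_from_model(model)
    assert configs.model_from_config(cfg) == model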

New file: BUILD (Bazel)

@@ -0,0 +1,17 @@
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [
|
||||
"//:license", # Placeholder comment, do not modify
|
||||
],
|
||||
default_visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "configs",
|
||||
srcs = ["configs.cc"],
|
||||
deps = [
|
||||
"//:common",
|
||||
"//compression:sfp",
|
||||
],
|
||||
)
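pybind_extension compiles configs.cc into a Python extension module whose import name matches the target name, so PYBIND11_MODULE(configs, ...) below must agree with name = "configs". A hedged smoke test once the target is built; the exact import path depends on how the package is laid out in the tree:

# Smoke test for the built extension, assumed importable as `configs`.
import configs

print(configs.Model.GEMMA2_2B)   # Model.GEMMA2_2B
print(configs.Type.kSFP)         # Type.kSFP
print(configs.config_from_model(configs.Model.GEMMA_TINY).model_dim)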

New file: configs.cc (pybind11 bindings)

@@ -0,0 +1,136 @@
#include "gemma/configs.h"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "compression/shared.h"
|
||||
#include "pybind11/cast.h"
|
||||
|
||||
using gcpp::ActivationType;
|
||||
using gcpp::LayerAttentionType;
|
||||
using gcpp::LayerConfig;
|
||||
using gcpp::Model;
|
||||
using gcpp::ModelConfig;
|
||||
using gcpp::ModelTraining;
|
||||
using gcpp::PostNormType;
|
||||
using gcpp::PostQKType;
|
||||
using gcpp::QueryScaleType;
|
||||
using gcpp::ResidualType;
|
||||
using gcpp::Type;
|
||||
|
||||
namespace pybind11 {
|
||||
|
||||
PYBIND11_MODULE(configs, py_module) {
|
||||
enum_<ModelTraining>(py_module, "ModelTraining")
|
||||
.value("GEMMA_IT", ModelTraining::GEMMA_IT)
|
||||
.value("GEMMA_PT", ModelTraining::GEMMA_PT)
|
||||
.value("PALIGEMMA", ModelTraining::PALIGEMMA);
|
||||
|
||||
enum_<Type>(py_module, "Type")
|
||||
.value("kUnknown", Type::kUnknown)
|
||||
.value("kF32", Type::kF32)
|
||||
.value("kBF16", Type::kBF16)
|
||||
.value("kSFP", Type::kSFP)
|
||||
.value("kNUQ", Type::kNUQ)
|
||||
.value("kF64", Type::kF64)
|
||||
.value("kC64", Type::kC64)
|
||||
.value("kU128", Type::kU128);
|
||||
|
||||
enum_<LayerAttentionType>(py_module, "LayerAttentionType")
|
||||
.value("kGemma", LayerAttentionType::kGemma)
|
||||
.value("kGriffinRecurrentBlock",
|
||||
LayerAttentionType::kGriffinRecurrentBlock)
|
||||
.value("kVit", LayerAttentionType::kVit);
|
||||
|
||||
enum_<PostNormType>(py_module, "PostNormType")
|
||||
.value("NoPostNorm", PostNormType::None)
|
||||
.value("Scale", PostNormType::Scale);
|
||||
|
||||
enum_<PostQKType>(py_module, "PostQKType")
|
||||
.value("Rope", PostQKType::Rope)
|
||||
.value("HalfRope", PostQKType::HalfRope);
|
||||
|
||||
enum_<ActivationType>(py_module, "ActivationType")
|
||||
.value("Gelu", ActivationType::Gelu);
|
||||
|
||||
enum_<QueryScaleType>(py_module, "QueryScaleType")
|
||||
.value("SqrtKeySize", QueryScaleType::SqrtKeySize)
|
||||
.value("SqrtModelDimDivNumHeads",
|
||||
QueryScaleType::SqrtModelDimDivNumHeads);
|
||||
|
||||
enum_<ResidualType>(py_module, "ResidualType")
|
||||
.value("Add", ResidualType::Add);
|
||||
|
||||
enum_<Model>(py_module, "Model")
|
||||
.value("UNKNOWN", Model::UNKNOWN)
|
||||
.value("GEMMA_2B", Model::GEMMA_2B)
|
||||
.value("GEMMA_7B", Model::GEMMA_7B)
|
||||
.value("GEMMA2_9B", Model::GEMMA2_9B)
|
||||
.value("GEMMA2_27B", Model::GEMMA2_27B)
|
||||
.value("GRIFFIN_2B", Model::GRIFFIN_2B)
|
||||
.value("GEMMA_TINY", Model::GEMMA_TINY)
|
||||
.value("GEMMA2_2B", Model::GEMMA2_2B)
|
||||
.value("PALIGEMMA_224", Model::PALIGEMMA_224);
|
||||
|
||||
class_<LayerConfig>(py_module, "LayerConfig")
|
||||
.def(init())
|
||||
.def_readwrite("model_dim", &LayerConfig::model_dim)
|
||||
.def_readwrite("griffin_dim", &LayerConfig::griffin_dim)
|
||||
.def_readwrite("ff_hidden_dim", &LayerConfig::ff_hidden_dim)
|
||||
.def_readwrite("heads", &LayerConfig::heads)
|
||||
.def_readwrite("kv_heads", &LayerConfig::kv_heads)
|
||||
.def_readwrite("qkv_dim", &LayerConfig::qkv_dim)
|
||||
.def_readwrite("conv1d_width", &LayerConfig::conv1d_width)
|
||||
.def_readwrite("ff_biases", &LayerConfig::ff_biases)
|
||||
.def_readwrite("softmax_attn_output_biases",
|
||||
&LayerConfig::softmax_attn_output_biases)
|
||||
.def_readwrite("optimized_gating", &LayerConfig::optimized_gating)
|
||||
.def_readwrite("post_norm", &LayerConfig::post_norm)
|
||||
.def_readwrite("type", &LayerConfig::type)
|
||||
.def_readwrite("activation", &LayerConfig::activation)
|
||||
.def_readwrite("post_qk", &LayerConfig::post_qk);
|
||||
|
||||
class_<ModelConfig>(py_module, "ModelConfig")
|
||||
.def(init())
|
||||
.def_readwrite("model_name", &ModelConfig::model_name)
|
||||
.def_readwrite("model", &ModelConfig::model)
|
||||
.def_readwrite("training", &ModelConfig::training)
|
||||
.def_readwrite("weight", &ModelConfig::weight)
|
||||
.def_readwrite("num_layers", &ModelConfig::num_layers)
|
||||
.def_readwrite("model_dim", &ModelConfig::model_dim)
|
||||
.def_readwrite("vit_model_dim", &ModelConfig::vit_model_dim)
|
||||
.def_readwrite("vocab_size", &ModelConfig::vocab_size)
|
||||
.def_readwrite("seq_len", &ModelConfig::seq_len)
|
||||
.def_readwrite("vit_seq_len", &ModelConfig::vit_seq_len)
|
||||
.def_readwrite("num_tensor_scales", &ModelConfig::num_tensor_scales)
|
||||
.def_readwrite("num_vit_scales", &ModelConfig::num_vit_scales)
|
||||
.def_readwrite("att_cap", &ModelConfig::att_cap)
|
||||
.def_readwrite("final_cap", &ModelConfig::final_cap)
|
||||
.def_readwrite("absolute_pe", &ModelConfig::absolute_pe)
|
||||
.def_readwrite("use_local_attention", &ModelConfig::use_local_attention)
|
||||
.def_readwrite("query_scale", &ModelConfig::query_scale)
|
||||
.def_readwrite("layer_configs", &ModelConfig::layer_configs)
|
||||
.def_readwrite("attention_window_sizes",
|
||||
&ModelConfig::attention_window_sizes)
|
||||
.def_readwrite("vit_layer_configs", &ModelConfig::vit_layer_configs)
|
||||
.def_readwrite("scale_names", &ModelConfig::scale_names)
|
||||
.def_readwrite("norm_num_groups", &ModelConfig::norm_num_groups)
|
||||
.def_readwrite("model_family_version", &ModelConfig::model_family_version)
|
||||
.def_readwrite("patch_width", &ModelConfig::patch_width)
|
||||
.def_readwrite("image_size", &ModelConfig::image_size)
|
||||
.def("add_layer_config", &ModelConfig::AddLayerConfig,
|
||||
arg("layer_config"))
|
||||
.def("test_equal", &ModelConfig::TestEqual, arg("other"), arg("partial"),
|
||||
arg("debug"));
|
||||
|
||||
// Returns the config for the given model.
|
||||
py_module.def("config_from_model", &gcpp::ConfigFromModel, arg("model"));
|
||||
|
||||
// Returns the model for the given config, if it matches any standard model.
|
||||
py_module.def("model_from_config", &gcpp::ModelFromConfig, arg("config"));
|
||||
|
||||
// Returns the sub-config for the ViT model of the PaliGemma model.
|
||||
py_module.def("vit_config", &gcpp::VitConfig, arg("config"));
|
||||
}
|
||||
|
||||
} // namespace pybind11