diff --git a/gemma/configs.h b/gemma/configs.h
index 3df7ec9..53b19b9 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -214,6 +214,7 @@ enum class Model {
   GEMMA3_12B,
   GEMMA3_27B,
   GEMMA3_270M,
+  CUSTOM,
   kSentinel,
 };
 
@@ -236,13 +237,13 @@ static inline bool IsObsolete(Model model) {
   return false;
 }
 
-// Visits every valid model enum, skipping `UNKNOWN` and `kSentinel`.
+// Visits every valid model enum, skipping `UNKNOWN`, `kSentinel` and `CUSTOM`.
 template <class Func>
 void ForEachModel(const Func& func) {
   for (size_t i = static_cast<size_t>(Model::GEMMA2_9B);
        i < static_cast<size_t>(Model::kSentinel); ++i) {
     const Model model = static_cast<Model>(i);
-    if (!IsObsolete(model)) func(model);
+    if (!IsObsolete(model) && model != Model::CUSTOM) func(model);
   }
 }
 
diff --git a/gemma/model_store.cc b/gemma/model_store.cc
index 76f0c75..67aa306 100644
--- a/gemma/model_store.cc
+++ b/gemma/model_store.cc
@@ -23,6 +23,7 @@
 #include <stddef.h>
 #include <stdint.h>
 #include <string.h>  // strcmp
+#include <optional>
 #include <string>
 #include <system_error>  // std::errc  // NOLINT
 
@@ -258,23 +259,25 @@ static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
     type_prefix.PrintTypeBytes();
   }
 
-  // Always deduce so we can verify it against the config we read.
-  const size_t layers = DeduceNumLayers(reader.Keys());
-  const int layer_types = DeduceLayerTypes(reader);
-  const Model deduced_model =
-      DeduceModel(reader.blob_path(), layers, layer_types);
-
   ModelConfig config;
   // Check first to prevent `CallWithSpan` from printing a warning.
-  if (reader.Find(kConfigName)) {
+  const bool has_config = reader.Find(kConfigName);
+  if (has_config) {
     HWY_ASSERT(reader.CallWithSpan(
         kConfigName, [&config](const SerializedSpan serialized) {
           const IFields::ReadResult result = config.Read(serialized, 0);
           WarnIfExtra(result, kConfigName);
           HWY_ASSERT_M(result.pos != 0, "Error deserializing config");
         }));
-
-    HWY_ASSERT(config.model != Model::UNKNOWN);
+  }
+  // Optionally deduce so we can verify it against the config we read.
+  std::optional<Model> deduced_model;
+  if (!has_config || config.model != Model::CUSTOM) {
+    const size_t layers = DeduceNumLayers(reader.Keys());
+    const int layer_types = DeduceLayerTypes(reader);
+    deduced_model = DeduceModel(reader.blob_path(), layers, layer_types);
+  }
+  if (has_config) {
     HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
     HWY_ASSERT(config.weight != Type::kUnknown);
     for (const LayerConfig& layer_config : config.layer_configs) {
@@ -285,18 +288,19 @@
 
     // We trust the deserialized config, but checking helps to validate the
     // deduction, which we rely on below for pre-2025 files.
-    if (config.model != deduced_model) {
+    if (deduced_model.has_value() && config.model != *deduced_model) {
       const std::string suffix = WrappingSuffix(config.wrapping);
       HWY_WARN("Detected model %s does not match config %s.",
-               (std::string(ModelPrefix(deduced_model)) + suffix).c_str(),
+               (std::string(ModelPrefix(*deduced_model)) + suffix).c_str(),
                (std::string(ModelPrefix(config.model)) + suffix).c_str());
     }
     return config;
   }
 
   // Pre-2025 format: no config, rely on deduction plus `wrapping_override`.
-  return ModelConfig(deduced_model, deduced_weight,
-                     ChooseWrapping(deduced_model, wrapping_override));
+  HWY_ASSERT(deduced_model.has_value());
+  return ModelConfig(*deduced_model, deduced_weight,
+                     ChooseWrapping(*deduced_model, wrapping_override));
 }
 
 static std::vector<float> ReadScales(BlobReader& reader,
diff --git a/python/configs.cc b/python/configs.cc
index 0d505dc..2e492ff 100644
--- a/python/configs.cc
+++ b/python/configs.cc
@@ -84,6 +84,7 @@ PYBIND11_MODULE(configs, py_module) {
 
   enum_<Model>(py_module, "Model")
       .value("UNKNOWN", Model::UNKNOWN)
+      .value("CUSTOM", Model::CUSTOM)
      .value("GEMMA2_9B", Model::GEMMA2_9B)
      .value("GEMMA2_27B", Model::GEMMA2_27B)
      .value("GEMMA2_2B", Model::GEMMA2_2B)