Add ability to load custom models which are fully described by the ModelConfig blob.

PiperOrigin-RevId: 877265257
This commit is contained in:
Miguel Lobo 2026-03-02 01:17:58 -08:00 committed by Copybara-Service
parent dd268ddbe8
commit f7f5fd5863
3 changed files with 21 additions and 15 deletions

View File

@ -214,6 +214,7 @@ enum class Model {
GEMMA3_12B,
GEMMA3_27B,
GEMMA3_270M,
CUSTOM,
kSentinel,
};
@ -236,13 +237,13 @@ static inline bool IsObsolete(Model model) {
return false;
}
// Visits every valid model enum, skipping `UNKNOWN` and `kSentinel`.
// Visits every valid model enum, skipping `UNKNOWN`, `kSentinel` and `CUSTOM`.
template <class Func>
void ForEachModel(const Func& func) {
for (size_t i = static_cast<size_t>(Model::GEMMA2_9B);
i < static_cast<size_t>(Model::kSentinel); ++i) {
const Model model = static_cast<Model>(i);
if (!IsObsolete(model)) func(model);
if (!IsObsolete(model) && model != Model::CUSTOM) func(model);
}
}

View File

@ -23,6 +23,7 @@
#include <charconv>
#include <cstdlib>
#include <cstring> // strcmp
#include <optional>
#include <string>
#include <system_error> // std::errc // NOLINT
@ -258,23 +259,25 @@ static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
type_prefix.PrintTypeBytes();
}
// Always deduce so we can verify it against the config we read.
const size_t layers = DeduceNumLayers(reader.Keys());
const int layer_types = DeduceLayerTypes(reader);
const Model deduced_model =
DeduceModel(reader.blob_path(), layers, layer_types);
ModelConfig config;
// Check first to prevent `CallWithSpan` from printing a warning.
if (reader.Find(kConfigName)) {
const bool has_config = reader.Find(kConfigName);
if (has_config) {
HWY_ASSERT(reader.CallWithSpan<uint32_t>(
kConfigName, [&config](const SerializedSpan serialized) {
const IFields::ReadResult result = config.Read(serialized, 0);
WarnIfExtra(result, kConfigName);
HWY_ASSERT_M(result.pos != 0, "Error deserializing config");
}));
HWY_ASSERT(config.model != Model::UNKNOWN);
}
// Optionally deduce so we can verify it against the config we read.
std::optional<Model> deduced_model;
if (!has_config || config.model != Model::CUSTOM) {
const size_t layers = DeduceNumLayers(reader.Keys());
const int layer_types = DeduceLayerTypes(reader);
deduced_model = DeduceModel(reader.blob_path(), layers, layer_types);
}
if (has_config) {
HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
HWY_ASSERT(config.weight != Type::kUnknown);
for (const LayerConfig& layer_config : config.layer_configs) {
@ -285,18 +288,19 @@ static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
// We trust the deserialized config, but checking helps to validate the
// deduction, which we rely on below for pre-2025 files.
if (config.model != deduced_model) {
if (deduced_model.has_value() && config.model != *deduced_model) {
const std::string suffix = WrappingSuffix(config.wrapping);
HWY_WARN("Detected model %s does not match config %s.",
(std::string(ModelPrefix(deduced_model)) + suffix).c_str(),
(std::string(ModelPrefix(*deduced_model)) + suffix).c_str(),
(std::string(ModelPrefix(config.model)) + suffix).c_str());
}
return config;
}
// Pre-2025 format: no config, rely on deduction plus `wrapping_override`.
return ModelConfig(deduced_model, deduced_weight,
ChooseWrapping(deduced_model, wrapping_override));
HWY_ASSERT(deduced_model.has_value());
return ModelConfig(*deduced_model, deduced_weight,
ChooseWrapping(*deduced_model, wrapping_override));
}
static std::vector<float> ReadScales(BlobReader& reader,

View File

@ -84,6 +84,7 @@ PYBIND11_MODULE(configs, py_module) {
enum_<Model>(py_module, "Model")
.value("UNKNOWN", Model::UNKNOWN)
.value("CUSTOM", Model::CUSTOM)
.value("GEMMA2_9B", Model::GEMMA2_9B)
.value("GEMMA2_27B", Model::GEMMA2_27B)
.value("GEMMA2_2B", Model::GEMMA2_2B)