mirror of https://github.com/google/gemma.cpp.git
Add the ability to load custom models that are fully described by the ModelConfig blob.
PiperOrigin-RevId: 877265257
This commit is contained in:
parent
dd268ddbe8
commit
f7f5fd5863
|
|
@ -214,6 +214,7 @@ enum class Model {
|
|||
GEMMA3_12B,
|
||||
GEMMA3_27B,
|
||||
GEMMA3_270M,
|
||||
CUSTOM,
|
||||
kSentinel,
|
||||
};
|
||||
|
||||
|
|
@ -236,13 +237,13 @@ static inline bool IsObsolete(Model model) {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Visits every valid model enum, skipping `UNKNOWN` and `kSentinel`.
|
||||
// Visits every valid model enum, skipping `UNKNOWN`, `kSentinel` and `CUSTOM`.
|
||||
template <class Func>
|
||||
void ForEachModel(const Func& func) {
|
||||
for (size_t i = static_cast<size_t>(Model::GEMMA2_9B);
|
||||
i < static_cast<size_t>(Model::kSentinel); ++i) {
|
||||
const Model model = static_cast<Model>(i);
|
||||
if (!IsObsolete(model)) func(model);
|
||||
if (!IsObsolete(model) && model != Model::CUSTOM) func(model);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@
|
|||
#include <charconv>
|
||||
#include <cstdlib>
|
||||
#include <cstring> // strcmp
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <system_error> // std::errc // NOLINT
|
||||
|
||||
|
|
@ -258,23 +259,25 @@ static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
|
|||
type_prefix.PrintTypeBytes();
|
||||
}
|
||||
|
||||
// Always deduce so we can verify it against the config we read.
|
||||
const size_t layers = DeduceNumLayers(reader.Keys());
|
||||
const int layer_types = DeduceLayerTypes(reader);
|
||||
const Model deduced_model =
|
||||
DeduceModel(reader.blob_path(), layers, layer_types);
|
||||
|
||||
ModelConfig config;
|
||||
// Check first to prevent `CallWithSpan` from printing a warning.
|
||||
if (reader.Find(kConfigName)) {
|
||||
const bool has_config = reader.Find(kConfigName);
|
||||
if (has_config) {
|
||||
HWY_ASSERT(reader.CallWithSpan<uint32_t>(
|
||||
kConfigName, [&config](const SerializedSpan serialized) {
|
||||
const IFields::ReadResult result = config.Read(serialized, 0);
|
||||
WarnIfExtra(result, kConfigName);
|
||||
HWY_ASSERT_M(result.pos != 0, "Error deserializing config");
|
||||
}));
|
||||
|
||||
HWY_ASSERT(config.model != Model::UNKNOWN);
|
||||
}
|
||||
// Optionally deduce so we can verify it against the config we read.
|
||||
std::optional<Model> deduced_model;
|
||||
if (!has_config || config.model != Model::CUSTOM) {
|
||||
const size_t layers = DeduceNumLayers(reader.Keys());
|
||||
const int layer_types = DeduceLayerTypes(reader);
|
||||
deduced_model = DeduceModel(reader.blob_path(), layers, layer_types);
|
||||
}
|
||||
if (has_config) {
|
||||
HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
|
||||
HWY_ASSERT(config.weight != Type::kUnknown);
|
||||
for (const LayerConfig& layer_config : config.layer_configs) {
|
||||
|
|
@ -285,18 +288,19 @@ static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
|
|||
|
||||
// We trust the deserialized config, but checking helps to validate the
|
||||
// deduction, which we rely on below for pre-2025 files.
|
||||
if (config.model != deduced_model) {
|
||||
if (deduced_model.has_value() && config.model != *deduced_model) {
|
||||
const std::string suffix = WrappingSuffix(config.wrapping);
|
||||
HWY_WARN("Detected model %s does not match config %s.",
|
||||
(std::string(ModelPrefix(deduced_model)) + suffix).c_str(),
|
||||
(std::string(ModelPrefix(*deduced_model)) + suffix).c_str(),
|
||||
(std::string(ModelPrefix(config.model)) + suffix).c_str());
|
||||
}
|
||||
return config;
|
||||
}
|
||||
|
||||
// Pre-2025 format: no config, rely on deduction plus `wrapping_override`.
|
||||
return ModelConfig(deduced_model, deduced_weight,
|
||||
ChooseWrapping(deduced_model, wrapping_override));
|
||||
HWY_ASSERT(deduced_model.has_value());
|
||||
return ModelConfig(*deduced_model, deduced_weight,
|
||||
ChooseWrapping(*deduced_model, wrapping_override));
|
||||
}
|
||||
|
||||
static std::vector<float> ReadScales(BlobReader& reader,
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@ PYBIND11_MODULE(configs, py_module) {
|
|||
|
||||
enum_<Model>(py_module, "Model")
|
||||
.value("UNKNOWN", Model::UNKNOWN)
|
||||
.value("CUSTOM", Model::CUSTOM)
|
||||
.value("GEMMA2_9B", Model::GEMMA2_9B)
|
||||
.value("GEMMA2_27B", Model::GEMMA2_27B)
|
||||
.value("GEMMA2_2B", Model::GEMMA2_2B)
|
||||
|
|
|
|||
Loading…
Reference in New Issue