gemma.cpp/gemma/model_store.cc

// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/model_store.h"
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <array>
#include <cstdlib>
#include <cstring> // strcmp
#include <string>
#include "compression/types.h"
#include "gemma/configs.h" // ModelConfig
#include "gemma/tensor_info.h"
#include "gemma/tokenizer.h"
#include "io/blob_store.h"
#include "io/fields.h"
#include "io/io.h" // Path
#include "util/basics.h"
#include "util/threading_context.h"
#include "hwy/base.h"
namespace gcpp {
// Single-file format contains blobs with these names:
static constexpr char kConfigName[] = "config";
static constexpr char kTokenizerName[] = "tokenizer";
static constexpr char kMatPtrsName[] = "toc";
// Pre-2025 format has one metadata blob. 'F' denoted f32.
static constexpr char kDecoratedScalesName[] = "Fscales";
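// For example (an illustrative name, following the prefix scheme in
// `TypePrefix::TypeFromChar` below), a pre-2025 key "$qkv_einsum_w_3" would
// denote an SFP-compressed tensor for layer 3.
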
static void WarnIfExtra(const IFields::ReadResult& result, const char* name) {
  // No warning if missing_fields > 0: those fields are default-initialized.
  if (result.extra_u32) {
    HWY_WARN(
        "Serialized blob %s has %u extra fields the code is not aware of. "
        "Consider updating to the latest code from GitHub.",
        name, result.extra_u32);
  }
}

// Returns the serialized tokenizer (std::string is required for proto).
// Reads it from a blob, or from a separate file for pre-2025 models.
static std::string ReadTokenizer(BlobReader& reader,
                                 const Path& tokenizer_path) {
  std::string tokenizer;
  // Check prevents `CallWithSpan` from printing a warning.
  if (reader.Find(kTokenizerName)) {
    if (!reader.CallWithSpan<char>(
            kTokenizerName, [&tokenizer](const hwy::Span<const char> bytes) {
              tokenizer.assign(bytes.data(), bytes.size());
            })) {
      HWY_WARN(
          "Reading tokenizer blob failed, please raise an issue. You can "
          "instead specify a tokenizer file via --tokenizer.");
    }
  }
  if (!tokenizer.empty() && tokenizer != kMockTokenizer) {
    return tokenizer;  // Read actual tokenizer from blob.
  }
  // No usable blob, but the user specified a tokenizer file: read it or abort.
  if (!tokenizer_path.Empty()) {
    return ReadFileToString(tokenizer_path);
  }
  HWY_WARN(
      "BlobStore does not contain a tokenizer and no --tokenizer was "
      "specified. Tests may continue but inference will fail.");
  return kMockTokenizer;
}

using KeyVec = std::vector<std::string>;

class TypePrefix {
 public:
  static Type TypeFromChar(char c) {
    switch (c) {
      case 'F':
        return Type::kF32;
      case 'B':
        return Type::kBF16;
      case '$':
        return Type::kSFP;
      case '2':
        return Type::kNUQ;
      default:
        // The other types were not written to pre-2025 files, hence no need
        // to encode and check for them here.
        return Type::kUnknown;
    }
  }

  TypePrefix(const KeyVec& keys, const BlobReader& reader) {
    for (size_t key_idx = 0; key_idx < keys.size(); ++key_idx) {
      const std::string& key = keys[key_idx];
      const Type type = TypeFromChar(key[0]);
      const uint64_t bytes = reader.Range(key_idx).bytes;
      bytes_[static_cast<size_t>(type)] += bytes;
      blobs_[static_cast<size_t>(type)]++;
      total_bytes_ += bytes;
    }
  }

  // Returns true for the pre-2025 format, whose keys have type prefixes; only
  // then may the functions below be used.
  bool HasPrefixes() const {
    return bytes_[static_cast<size_t>(Type::kUnknown)] != total_bytes_;
  }

  // Returns the weight type deduced from the histogram of blobs per type.
  // Rationale: We expect a mix of types due to varying precision requirements
  // for each tensor. The preferred weight type might not even be the most
  // common, because we prioritize higher compression for the *large* tensors.
  // Ignore types which only have a few blobs (might be metadata), and assume
  // that there are at least 4 of the large tensors (in particular, global
  // attention layers). Hence return the smallest type with >= 4 blobs.
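  //
  // Illustrative example (hypothetical histogram): 2 f32 blobs, 40 bf16 blobs
  // and 120 SFP blobs. f32 is skipped because it has fewer than 4 blobs, and
  // 8-bit SFP wins over 16-bit bf16, so the deduced weight type is kSFP.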
  Type DeduceWeightType() const {
    size_t min_bits = ~size_t{0};
    Type weight_type = Type::kUnknown;
    for (size_t i = 0; i < kNumTypes; ++i) {
      if (blobs_[i] < 4) continue;
      const size_t bits = TypeBits(static_cast<Type>(i));
      if (bits < min_bits) {
        min_bits = bits;
        weight_type = static_cast<Type>(i);
      }
    }
    return weight_type;
  }

  // Prints statistics on the total size of tensors by type.
  void PrintTypeBytes() const {
    for (size_t type_idx = 0; type_idx < kNumTypes; ++type_idx) {
      const Type type = static_cast<Type>(type_idx);
      const uint64_t bytes = bytes_[type_idx];
      if (bytes == 0) continue;
      const double percent = 100.0 * bytes / total_bytes_;
      fprintf(stderr, "%zu blob bytes (%.2f%%) of %s\n",
              static_cast<size_t>(bytes), percent, TypeName(type));
    }
  }

 private:
  uint64_t total_bytes_ = 0;
  // uint64_t so that byte totals do not overflow on 32-bit targets.
  std::array<uint64_t, kNumTypes> bytes_{};
  std::array<size_t, kNumTypes> blobs_{};
};

// Returns the number of layers based on the largest blob name suffix seen.
// This works with or without type prefixes because it searches for suffixes.
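// For example (hypothetical keys), "att_w_0" .. "att_w_25" would yield a
// maximum layer index of 25 and hence 26 layers.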
static size_t DeduceNumLayers(const KeyVec& keys) {
  size_t max_layer_idx = 0;
  for (const std::string& key : keys) {
    const size_t suffix_pos = key.rfind('_');
    if (suffix_pos == std::string::npos) continue;
    char* end;
    auto layer_idx = strtoul(key.c_str() + suffix_pos + 1, &end, 10);  // NOLINT
    // Also catches the `ULONG_MAX` that `strtoul` returns on overflow.
    HWY_ASSERT(layer_idx < 999);
    // Ignore unless the number spans the rest of the key. Some names contain
    // '_' without a layer number (e.g. the historical "c_" prefix); `strtoul`
    // then parses 0 and `end` stops before the end of the key.
    if (static_cast<size_t>(end - key.c_str()) != key.size()) continue;
    max_layer_idx = HWY_MAX(max_layer_idx, layer_idx);
  }
  return max_layer_idx + 1;
}

// Looks for known tensor names associated with model families.
// This works with or without type prefixes because it searches for substrings.
static int DeduceLayerTypes(const KeyVec& keys) {
  int layer_types = 0;
  for (const std::string& key : keys) {
    if (key.find("gr_conv_w") != std::string::npos) {  // NOLINT
      return kDeducedGriffin;
    }
    if (key.find("qkv_einsum_w") != std::string::npos) {  // NOLINT
      layer_types |= kDeducedViT;
    }
  }
  return layer_types;
}

// `wrapping_override` is forwarded from the command line. For pre-2025 files
// without a serialized `ModelConfig`, it is the only way to force PT
// (non-instruction-tuned) prompt wrapping.
static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
                                      Tristate wrapping_override) {
  const TypePrefix type_prefix(reader.Keys(), reader);
  Type deduced_weight = Type::kUnknown;
  if (type_prefix.HasPrefixes()) {
    deduced_weight = type_prefix.DeduceWeightType();
    type_prefix.PrintTypeBytes();
  }

  // Always deduce so we can verify it against the config we read.
  const size_t layers = DeduceNumLayers(reader.Keys());
  const int layer_types = DeduceLayerTypes(reader.Keys());
  const Model deduced_model = DeduceModel(layers, layer_types);

  ModelConfig config;
  // Check first to prevent `CallWithSpan` from printing a warning.
  if (reader.Find(kConfigName)) {
    HWY_ASSERT(reader.CallWithSpan<uint32_t>(
        kConfigName, [&config](const SerializedSpan serialized) {
          const IFields::ReadResult result = config.Read(serialized, 0);
          WarnIfExtra(result, kConfigName);
          HWY_ASSERT_M(result.pos != 0, "Error deserializing config");
        }));
    HWY_ASSERT(config.model != Model::UNKNOWN);
    HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
    HWY_ASSERT(config.weight != Type::kUnknown);
    // We trust the deserialized config, but checking helps to validate the
    // deduction, which we rely on below for pre-2025 files.
    if (config.model != deduced_model) {
      const std::string suffix = WrappingSuffix(config.wrapping);
      HWY_WARN("Detected model %s does not match config %s.",
               (std::string(ModelPrefix(deduced_model)) + suffix).c_str(),
               (std::string(ModelPrefix(config.model)) + suffix).c_str());
    }
    return config;
  }

  // Pre-2025 format: no config blob; rely on deduction plus
  // `wrapping_override`. Note that `config` is still default-initialized on
  // this path, hence we must use `deduced_model`.
  return ModelConfig(deduced_model, deduced_weight,
                     ChooseWrapping(deduced_model, wrapping_override));
}

static std::vector<float> ReadScales(BlobReader& reader,
                                     const ModelConfig& config) {
  std::vector<float> scales;
  // Check first to prevent `CallWithSpan` from printing a warning. This blob
  // is optional even in pre-2025 format; Griffin was the first to include it.
  if (reader.Find(kDecoratedScalesName)) {
    HWY_ASSERT(reader.CallWithSpan<float>(
        kDecoratedScalesName,
        [&scales](const hwy::Span<const float> scales_blob) {
          scales.assign(scales_blob.cbegin(), scales_blob.cend());
        }));
  }
  return scales;
}

// Single-file format: reads `MatPtr` from the blob; returns false if not found.
bool ModelStore::ReadMatPtrs(BlobReader& reader) {
  // Check first to prevent `CallWithSpan` from printing a warning.
  if (!reader.Find(kMatPtrsName)) return false;

  // For verifying `config_.weight`.
  size_t min_bits = ~size_t{0};
  Type weight_type = Type::kUnknown;
  HWY_ASSERT(reader.CallWithSpan<uint32_t>(
      kMatPtrsName, [&, this](SerializedSpan serialized) {
        for (size_t pos = 0; pos < serialized.size();) {
          MatPtr mat;
          const IFields::ReadResult result = mat.Read(serialized, pos);
          WarnIfExtra(result, mat.Name());
          if (result.pos == 0) {
            HWY_ABORT("Deserializing MatPtr %s failed (pos %zu of %zu).",
                      mat.Name(), pos, serialized.size());
          }
          pos = result.pos + result.extra_u32;
          // Retrieve the actual key index because a writer may have written
          // other blobs before the tensor data.
          const BlobRange* range = reader.Find(mat.Name());
          HWY_ASSERT(range);
          const size_t key_idx = range->key_idx;
          AddMatPtr(key_idx, mat);
          const size_t bits = TypeBits(mat.GetType());
          if (bits < min_bits) {
            min_bits = bits;
            weight_type = mat.GetType();
          }
        }
      }));
  HWY_ASSERT(weight_type != Type::kUnknown);
  HWY_ASSERT(weight_type == config_.weight);
  return true;
}

// Pre-2025 format: synthesizes `MatPtr`s from the blob names if `ReadMatPtrs`
// returned false.
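// For example (illustrative), the key "Fgr_conv_w_0" would become a `MatPtr`
// named "gr_conv_w_0" with type kF32 and extents looked up via
// `TensorInfoRegistry`.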
void ModelStore::CreateMatPtrs(BlobReader& reader) {
  const TensorInfoRegistry tensors(config_);
  const KeyVec& keys = reader.Keys();
  mat_ptrs_.reserve(keys.size());
  // `key_idx` is the blob index. It is not the same as the index of the
  // `MatPtr` in `mat_ptrs_` because not all blobs are tensors.
  for (size_t key_idx = 0; key_idx < keys.size(); ++key_idx) {
    const Type type = TypePrefix::TypeFromChar(keys[key_idx][0]);
    if (type == Type::kUnknown) continue;  // likely not a tensor
    // Strip the type prefix from the key. Still includes the layer suffix.
    const std::string name = keys[key_idx].substr(1);
    const TensorInfo* info = tensors.Find(name);
    if (HWY_UNLIKELY(!info)) {
      if (name == "scales") continue;  // ignore, not a tensor.
      HWY_ABORT("Unknown tensor %s.", name.c_str());
    }
    // We cannot set the scale yet because scales are ordered according to
    // `ForEachTensor`, whose order is unknown here. The initial value is 1.0f
    // and the correct value is set in `FindAndUpdateMatPtr`.
    AddMatPtr(key_idx, MatPtr(name.c_str(), type, ExtentsFromInfo(info)));
  }
  HWY_ASSERT(mat_ptrs_.size() <= keys.size());
  HWY_ASSERT(mat_ptrs_.size() == key_idx_.size());
}

ModelStore::ModelStore(BlobReader& reader, const Path& tokenizer_path,
                       Tristate wrapping)
    : config_(ReadOrDeduceConfig(reader, wrapping)),
      tokenizer_(ReadTokenizer(reader, tokenizer_path)) {
  if (!ReadMatPtrs(reader)) {  // Pre-2025 format.
    CreateMatPtrs(reader);
    scales_ = ReadScales(reader, config_);
    // `ModelConfig` serialized a vector of strings; unpack it into a set for
    // more efficient lookup.
    for (const std::string& name : config_.scale_base_names) {
      scale_base_names_.insert(name);
    }
    // If the model has scales, the config must know about it.
    HWY_ASSERT(scales_.empty() || !scale_base_names_.empty());
  }
  HWY_ASSERT(key_idx_.size() == mat_ptrs_.size());
}

ModelStore::~ModelStore() {
  // Sanity check: ensure all scales were consumed.
  HWY_ASSERT(scales_consumed_ == scales_.size());
}

const MatPtr* ModelStore::FindMat(const char* name) const {
  auto it = mat_idx_for_name_.find(name);
  if (it == mat_idx_for_name_.end()) return nullptr;
  const size_t mat_idx = it->second;
  const MatPtr* file_mat = &mat_ptrs_[mat_idx];
  HWY_ASSERT(!strcmp(file_mat->Name(), name));
  return file_mat;
}

bool ModelStore::FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const {
  const MatPtr* file_mat = FindMat(mat.Name());
  if (!file_mat) return false;
  if (file_mat->Rows() != mat.Rows() || file_mat->Cols() != mat.Cols()) {
    HWY_ABORT("Tensor %s shape %zu %zu mismatches file %zu %zu.", mat.Name(),
              mat.Rows(), mat.Cols(), file_mat->Rows(), file_mat->Cols());
  }
  // `Compress()` output is always packed because it assumes a 1D array.
  HWY_ASSERT(mat.IsPacked());
  // Update fields. The name already matched, otherwise `FindMat` would not
  // have returned it.
  mat.SetType(file_mat->GetType());
  if (scales_.empty()) {
    // `file_mat->Scale()` is either read from file, or we have the pre-2025
    // format without the optional scales, and it is default-initialized to
    // 1.0f.
    mat.SetScale(file_mat->Scale());
  } else {
    // Pre-2025 with scaling factors: consume the next one if `mat` expects a
    // scale.
    if (scale_base_names_.find(StripLayerSuffix(mat.Name())) !=
        scale_base_names_.end()) {
      HWY_ASSERT(scales_consumed_ < scales_.size());
      mat.SetScale(scales_[scales_consumed_++]);
    }
  }
  key_idx = key_idx_[file_mat - mat_ptrs_.data()];
  return true;
}

static void AddBlob(const char* name, const std::vector<uint32_t>& data,
                    BlobWriter& writer) {
  HWY_ASSERT(!data.empty());
  writer.Add(name, data.data(), data.size() * sizeof(data[0]));
}

void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
                     const std::vector<uint32_t>& serialized_mat_ptrs,
                     BlobWriter& writer, hwy::ThreadPool& pool,
                     const Path& path) {
  HWY_ASSERT(config.model != Model::UNKNOWN);
  HWY_ASSERT(config.weight != Type::kUnknown);
  HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);

  const std::vector<uint32_t> serialized_config = config.Write();
  AddBlob(kConfigName, serialized_config, writer);

  const std::string serialized_tokenizer = tokenizer.Serialize();
  HWY_ASSERT(!serialized_tokenizer.empty());
  writer.Add(kTokenizerName, serialized_tokenizer.data(),
             serialized_tokenizer.size());

  AddBlob(kMatPtrsName, serialized_mat_ptrs, writer);
  writer.WriteAll(pool, path);
}
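
// Example usage (a sketch; the writer/pool construction and output path below
// are illustrative, not prescribed by this file):
//   BlobWriter writer;
//   hwy::ThreadPool pool(0);  // run on the main thread only
//   WriteSingleFile(config, tokenizer, serialized_mat_ptrs, writer, pool,
//                   Path("/tmp/model.sbs"));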
} // namespace gcpp