Added ability to load/save a complete model file, including tokenizer.

PiperOrigin-RevId: 707914366
This commit is contained in:
Ray Smith 2024-12-19 07:59:08 -08:00 committed by Copybara-Service
parent 5bc356f18f
commit 9d40f0117e
32 changed files with 770 additions and 257 deletions

View File

@ -245,6 +245,7 @@ cc_library(
"gemma/tensor_index.h",
],
deps = [
"//compression:fields",
"//compression:sfp",
"@highway//:hwy", # base.h
"@highway//:thread_pool",
@ -257,6 +258,7 @@ cc_test(
deps = [
":common",
"@googletest//:gtest_main",
"@highway//:hwy",
],
)
@ -388,6 +390,7 @@ cc_library(
":ops",
":threading",
"//compression:io",
"//compression:sfp",
"@highway//:hwy",
],
)

View File

@ -390,13 +390,12 @@ static ModelConfig TestConfig() {
config.model_dim = 32;
config.vocab_size = 12;
config.seq_len = 18;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.ff_hidden_dim = 48,
.heads = 3,
.kv_heads = 1,
.qkv_dim = 12,
};
LayerConfig layer_config;
layer_config.model_dim = config.model_dim;
layer_config.ff_hidden_dim = 48;
layer_config.heads = 3;
layer_config.kv_heads = 1;
layer_config.qkv_dim = 12;
config.layer_configs = {2, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;

View File

@ -191,13 +191,12 @@ static ModelConfig TestConfig() {
config.model_dim = 32;
config.vocab_size = 16;
config.seq_len = 24;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.ff_hidden_dim = 64,
.heads = 3,
.kv_heads = 1,
.qkv_dim = 16,
};
LayerConfig layer_config;
layer_config.model_dim = config.model_dim;
layer_config.ff_hidden_dim = 64;
layer_config.heads = 3;
layer_config.kv_heads = 1;
layer_config.qkv_dim = 16;
config.layer_configs = {2, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;

View File

@ -58,7 +58,6 @@ cc_test(
deps = [
":fields",
"@googletest//:gtest_main", # buildcleaner: keep
"@highway//:hwy",
"@highway//:hwy_test_util",
],
)
@ -202,6 +201,7 @@ cc_library(
deps = [
":blob_store",
":distortion",
":fields",
":io",
":nuq",
":sfp",
@ -210,7 +210,6 @@ cc_library(
"//:common",
"@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:profiler",
"@highway//:stats",
"@highway//:thread_pool",
],
@ -261,6 +260,7 @@ cc_binary(
"//:allocator",
"//:args",
"//:common",
"//:tokenizer",
"//:weights",
"@highway//:hwy",
"@highway//:thread_pool",
@ -277,3 +277,14 @@ cc_binary(
"@highway//:hwy_test_util",
],
)
cc_binary(
name = "migrate_weights",
srcs = ["migrate_weights.cc"],
deps = [
"//:app",
"//:args",
"//:benchmark_helper",
"//:gemma_lib",
],
)

View File

@ -22,10 +22,13 @@
#include <stdio.h>
#include <cmath> // lroundf, only if COMPRESS_STATS
#include <string>
#include <vector>
#include "compression/blob_store.h"
#include "compression/compress.h" // IWYU pragma: export
#include "compression/distortion.h"
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -673,36 +676,37 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan<const VT> v,
// their scaling factors to BlobStore.
class Compressor {
public:
explicit Compressor(hwy::ThreadPool& pool) : pool_(pool) {}
explicit Compressor(hwy::ThreadPool& pool) : writer_(pool) {}
template <typename Packed>
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name,
const float* HWY_RESTRICT weights) {
size_t num_weights = compressed->NumElements();
if (num_weights == 0 || weights == nullptr || compressed->Ptr() == nullptr)
return;
size_t num_compressed = compressed->NumElements();
PackedSpan<Packed> packed = MakeSpan(compressed->data(), num_compressed);
fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name,
num_weights / (1000 * 1000));
Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0, pool_);
const size_t num_bytes = packed.num * sizeof(Packed);
writer_.Add(MakeKey(decorated_name), packed.ptr, num_bytes);
Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0,
writer_.pool());
writer_(compressed, decorated_name);
}
void AddTokenizer(const std::string& tokenizer) {
writer_.AddTokenizer(tokenizer);
}
void AddScales(const float* scales, size_t len) {
if (len) {
MatPtrT<float> scales_ptr("scales", 0, 1);
writer_.Add(MakeKey(scales_ptr.CacheName().c_str()), scales,
len * sizeof(scales[0]));
}
writer_.AddScales(scales, len);
}
BlobError WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) {
const BlobError err = writer_.WriteAll(pool, blob_filename);
if (err != 0) {
fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
blob_filename.path.c_str(), err);
}
return err;
// Writes all blobs to disk in the given order. The config is optional and
// if given, it is written to the file, along with the TOC, making it
// single-file format. Otherwise, the file is written in the multi-file format
// without a TOC.
BlobError WriteAll(const Path& blob_filename, const ModelConfig* config) {
return writer_.WriteAll(blob_filename, config);
}
// Returns the number of blobs added.
@ -710,8 +714,7 @@ class Compressor {
private:
CompressWorkingSet work_;
hwy::ThreadPool& pool_;
BlobWriter writer_;
WriteToBlobStore writer_;
};
// NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@ -24,6 +24,7 @@
#include <stdint.h>
#include <stdio.h>
#include <cstdio>
#include <cstring>
#include <string>
#include <unordered_map>
@ -32,11 +33,13 @@
// IWYU pragma: begin_exports
#include "compression/blob_store.h"
#include "compression/fields.h"
#include "compression/io.h"
#include "compression/shared.h"
#include "gemma/tensor_index.h"
#include "util/basics.h"
// IWYU pragma: end_exports
#include "gemma/configs.h"
#include "util/allocator.h"
#include "hwy/per_target.h"
#if COMPRESS_STATS
@ -55,7 +58,7 @@ namespace gcpp {
// fixed inner dimension and type.
// It is designed to be put in a vector, and has default copy and operator=, so
// it is easy to read/write a blob_store file.
class MatPtr {
class MatPtr : public IFields {
public:
// Full constructor for dynamic sizing.
MatPtr(const std::string& name, Type type, size_t element_size, size_t rows,
@ -73,36 +76,6 @@ class MatPtr {
MatPtr() = default;
virtual ~MatPtr();
// Number of hwy::uint128_t in a TOC entry.
// Note that the old-style BlobStore files only have a list of keys and size.
// The new-style BlobStore files have an entry called "toc" that contains a
// vector of 4-tuples of
// (name, type, (num_elements, element_size), (rows, cols)).
// The listed blobs can be read directly into MatPtr from the BlobStore
// file, without needing any external knowledge of the number of elements,
// element size or type of the data.
static constexpr size_t kNumU128InTocEntry = 4;
// Construct from a TOC entry.
MatPtr(const hwy::uint128_t& key0, const hwy::uint128_t& key1,
const hwy::uint128_t& key2, const hwy::uint128_t& key3)
: name_(StringFromKey(key0)),
type_(static_cast<Type>(key1.lo)),
element_size_(key2.hi),
num_elements_(key2.lo),
rows_(key3.lo),
cols_(key3.hi) {
stride_ = cols_;
}
// Adds the contents entry to the table of contents.
void AddToToc(std::vector<hwy::uint128_t>& toc) const {
toc.push_back(MakeKey(name_.c_str()));
toc.push_back({static_cast<uint64_t>(type_), 0});
toc.push_back({num_elements_, element_size_});
toc.push_back({rows_, cols_});
}
// Compatibility interface for CompressedArray.
// TODO: remove.
template <typename T>
@ -124,7 +97,7 @@ class MatPtr {
MatPtr& operator=(const MatPtr& other) = default;
// Returns the name of the blob.
const std::string& Name() const { return name_; }
const char* Name() const override { return name_.c_str(); }
void SetName(const std::string& name) { name_ = name; }
// Returns the type of the blob.
@ -163,12 +136,6 @@ class MatPtr {
return name;
}
// Adds the blob to the writer.
void AddToWriter(BlobWriter& writer) const {
fprintf(stderr, "Adding %s to writer\n", name_.c_str());
writer.Add(MakeKey(name_.c_str()), ptr_, SizeBytes());
}
// Sets all data to zero.
void ZeroInit() {
if (ptr_ == nullptr)
@ -176,6 +143,17 @@ class MatPtr {
hwy::ZeroBytes(ptr_, SizeBytes());
}
// Serializes/deserializes the metadata fields via the IFields visitor
// mechanism. NOTE: the visitation order defines the on-disk layout of a TOC
// entry; presumably appending new fields at the end is safe, but reordering
// existing ones would break previously written files — confirm against the
// fields framework before changing.
void VisitFields(IFieldsVisitor& visitor) override {
visitor(name_);
visitor(type_);
visitor(element_size_);
visitor(num_elements_);
visitor(rows_);
visitor(cols_);
visitor(scale_);
visitor(stride_);
}
// Calls func on the upcasted type. Since MatPtr by design is not templated,
// here we provide a way to get to the derived type, provided that `Type()`
// is one of the strings returned by `TypeName()`.
@ -188,13 +166,13 @@ class MatPtr {
// Should be the result of TypeEnum<T> for CallUpcasted() to work.
Type type_;
// sizeof(T)
size_t element_size_ = 0;
uint32_t element_size_ = 0;
// Number of elements in the array.
size_t num_elements_ = 0; // In element_size units.
uint32_t num_elements_ = 0; // In element_size units.
// Number of rows in the 2-d array (outer dimension).
size_t rows_ = 0;
uint32_t rows_ = 0;
// Number of columns in the 2-d array (inner dimension).
size_t cols_ = 0;
uint32_t cols_ = 0;
// Scaling to apply to each element.
float scale_ = 1.0f;
// Aligned data array. This is always a borrowed pointer. It should never be
@ -202,7 +180,7 @@ class MatPtr {
// and must outlive this object.
void* ptr_ = nullptr;
size_t stride_;
uint32_t stride_;
};
// MatPtrT adds a single template argument to MatPtr for an explicit type.
@ -394,31 +372,28 @@ class BlobToc {
public:
BlobToc() = default;
// Adds all blobs to the blob writer. Note that the blobs must have unique
// names.
static void AddAllToBlobWriter(const std::vector<MatStorage>& blobs,
BlobWriter& writer) {
std::vector<hwy::uint128_t> toc;
for (const auto& blob : blobs) {
blob.AddToToc(toc);
blob.AddToWriter(writer);
}
writer.Add(MakeKey(kTocName), toc.data(), toc.size() * sizeof(toc[0]));
}
// Loads the table of contents from the given reader.
BlobError LoadToc(BlobReader& reader) {
hwy::uint128_t toc_key = MakeKey(kTocName);
size_t toc_size = reader.BlobSize(toc_key);
if (toc_size != 0) {
std::vector<hwy::uint128_t> toc(toc_size / sizeof(hwy::uint128_t));
std::vector<uint32_t> toc(toc_size / sizeof(uint32_t));
BlobError err = reader.ReadOne(toc_key, toc.data(), toc_size);
if (err != 0) {
fprintf(stderr, "Failed to read toc (error %d)\n", err);
return err;
}
for (size_t i = 0; i < toc.size(); i += MatPtr::kNumU128InTocEntry) {
AddToToc(MatPtr(toc[i], toc[i + 1], toc[i + 2], toc[i + 3]));
size_t consumed = 0;
size_t prev_consumed = static_cast<size_t>(-1);
while (consumed < toc.size() && prev_consumed != consumed) {
MatPtr blob;
const IFields::ReadResult result =
blob.Read(hwy::Span<const uint32_t>(toc), consumed);
prev_consumed = consumed;
consumed = result.pos;
if (blob.NumElements() > 0) {
AddToToc(blob);
}
}
}
return 0;
@ -437,11 +412,16 @@ class BlobToc {
if (it == toc_map_.end()) return nullptr;
return &toc_[it->second];
}
private:
// The name of the toc in the blob store file.
static constexpr char kTocName[] = "toc";
// The name of the config in the blob store file.
static constexpr char kConfigName[] = "config";
// The name of the tokenizer in the blob store file.
static constexpr char kTokenizerName[] = "tokenizer";
private:
// Adds the blob to the table of contents.
void AddToToc(const MatPtr& blob) {
HWY_ASSERT(!Contains(blob.Name()));
@ -519,6 +499,68 @@ struct CompressWorkingSet {
std::vector<CompressPerThread> tls;
};
// Class to collect and write a set of tensors to a blob store file.
// Accumulates tensor blobs (plus their TOC entries) and optionally the
// tokenizer/config, then writes everything in one WriteAll() call.
class WriteToBlobStore {
public:
explicit WriteToBlobStore(hwy::ThreadPool& pool) : pool_(pool) {}
// Adds one tensor blob under `decorated_name` and records a matching TOC
// entry. Tensors with a null data pointer are silently skipped.
template <typename Packed>
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name) {
if (compressed->Ptr() == nullptr) return;
writer_.Add(MakeKey(decorated_name), compressed->Ptr(),
compressed->SizeBytes());
// The TOC entry must carry the decorated (file) name, not the tensor's
// in-memory name, so copy and rename before serializing the metadata.
MatPtr renamed_tensor(*compressed);
renamed_tensor.SetName(decorated_name);
renamed_tensor.AppendTo(toc_);
}
// Adds the serialized tokenizer as its own blob. NOTE: `tokenizer` must
// outlive the WriteAll() call; the writer stores only a pointer.
void AddTokenizer(const std::string& tokenizer) {
writer_.Add(MakeKey(BlobToc::kTokenizerName), tokenizer.data(),
tokenizer.size() * sizeof(tokenizer[0]));
}
// Adds the per-tensor scaling factors blob; no-op when `len` == 0.
// `len` is the number of floats, not bytes.
void AddScales(const float* scales, size_t len) {
if (len) {
// Only used to produce the canonical cache name for the scales blob.
MatPtrT<float> scales_ptr("scales", 0, 1);
writer_.Add(MakeKey(scales_ptr.CacheName().c_str()), scales,
len * sizeof(scales[0]));
}
}
// Writes all blobs to disk in the given order. The config is optional and
// if given, it is written to the file, along with the TOC, making it
// single-file format. Otherwise, the file is written in the multi-file format
// without a TOC.
BlobError WriteAll(const Path& blob_filename, const ModelConfig* config) {
if (config) {
// TOC and config are appended after the tensor blobs already queued.
writer_.Add(MakeKey(BlobToc::kTocName), toc_.data(),
toc_.size() * sizeof(toc_[0]));
// Keep the serialized config alive in a member until writer_ flushes it.
config_buffer_ = config->Write();
writer_.Add(MakeKey(BlobToc::kConfigName), config_buffer_.data(),
config_buffer_.size() * sizeof(config_buffer_[0]));
}
const BlobError err = writer_.WriteAll(pool_, blob_filename);
if (err != 0) {
fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
blob_filename.path.c_str(), err);
}
return err;
}
// Returns the number of blobs added.
size_t DebugNumBlobsAdded() const { return writer_.DebugNumBlobsAdded(); }
hwy::ThreadPool& pool() { return pool_; }
protected:
hwy::ThreadPool& pool_;
private:
// Serialized MatPtr metadata for each added tensor, in add order.
std::vector<uint32_t> toc_;
BlobWriter writer_;
// Backing storage for the serialized config; must outlive WriteAll's flush.
std::vector<uint32_t> config_buffer_;
};
// Functor called for each tensor, which loads them and their scaling factors
// from BlobStore.
class ReadFromBlobStore {
@ -539,11 +581,40 @@ class ReadFromBlobStore {
// Returns true if there is a TOC.
bool HaveToc() const { return !file_toc_.Empty(); }
// Reads the config from the blob store file.
// Returns 0 on success, __LINE__ when the file has no config blob (i.e. it is
// an old multi-file-format file), or the reader's error code on a failed read.
BlobError LoadConfig(ModelConfig& config) {
hwy::uint128_t config_key = MakeKey(BlobToc::kConfigName);
// A missing blob is reported as size 0.
size_t config_size = reader_.BlobSize(config_key);
if (config_size == 0) return __LINE__;
std::vector<uint32_t> config_buffer(config_size / sizeof(uint32_t));
BlobError err =
reader_.ReadOne(config_key, config_buffer.data(), config_size);
if (err != 0) {
fprintf(stderr, "Failed to read config (error %d)\n", err);
return err;
}
// Deserialize starting at word 0. NOTE(review): the ReadResult is
// discarded, so a truncated/invalid config is not detected here.
config.Read(hwy::Span<const uint32_t>(config_buffer), 0);
return 0;
}
// Reads the tokenizer from the blob store file into `tokenizer`.
// Returns 0 on success, __LINE__ when the file has no tokenizer blob (i.e. it
// is an old multi-file-format file), or the reader's error code on failure.
BlobError LoadTokenizer(std::string& tokenizer) {
hwy::uint128_t key = MakeKey(BlobToc::kTokenizerName);
// A missing blob is reported as size 0.
size_t tokenizer_size = reader_.BlobSize(key);
if (tokenizer_size == 0) return __LINE__;
// Size the string first so ReadOne can fill it in place (non-const
// data() requires C++17). Removed a stray empty statement here.
tokenizer.resize(tokenizer_size);
BlobError err = reader_.ReadOne(key, tokenizer.data(), tokenizer_size);
if (err != 0) {
fprintf(stderr, "Failed to read tokenizer (error %d)\n", err);
return err;
}
return 0;
}
// Called for each tensor, enqueues read requests.
void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
if (file_toc_.Empty() || file_toc_.Contains(name)) {
if (tensors[0]->NumElements() == 0)
fprintf(stderr, "Zero elements for %s\n", name);
model_toc_.push_back(tensors[0]);
file_keys_.push_back(name);
}
@ -579,13 +650,13 @@ class ReadFromBlobStore {
fprintf(stderr, "Blob %s has size mismatch TOC\n", file_key.c_str());
return __LINE__;
}
MatStorage toc_blob_array(*toc_blob);
model_memory.push_back(std::move(toc_blob_array));
} else {
std::string name = blob->Name();
*blob = *toc_blob;
blob->SetName(name);
}
model_memory.emplace_back(*blob);
model_memory.back().SetName(file_key);
}
}
// Allocate in parallel using the pool.
pool.Run(0, model_memory.size(),
[this, &model_memory](uint64_t task, size_t /*thread*/) {
@ -594,12 +665,12 @@ class ReadFromBlobStore {
});
// Enqueue the read requests.
for (auto& blob : model_memory) {
err_ = reader_.Enqueue(MakeKey(blob.Name().c_str()), blob.data(),
blob.SizeBytes());
err_ =
reader_.Enqueue(MakeKey(blob.Name()), blob.data(), blob.SizeBytes());
if (err_ != 0) {
fprintf(stderr,
"Failed to read blob %s (error %d) of size %zu x %zu x %zu\n",
blob.Name().c_str(), err_, blob.Rows(), blob.Cols(),
blob.Name(), err_, blob.Rows(), blob.Cols(),
blob.ElementSize());
return err_;
}

View File

@ -25,6 +25,7 @@
// After highway.h
#include "compression/compress-inl.h"
#include "gemma/configs.h"
#include "gemma/tokenizer.h"
#ifndef GEMMA_COMPRESS_WEIGHTS_ONCE
#define GEMMA_COMPRESS_WEIGHTS_ONCE
@ -99,6 +100,9 @@ struct Args : public ArgsBase<Args> {
std::string model_type_str;
std::string weight_type_str;
size_t num_threads;
// Path to the tokenizer file. If non-empty, the tokenizer is written into
// the output file, along with the config and TOC (single-file format).
Path tokenizer;
template <class Visitor>
void ForEach(const Visitor& visitor) {
@ -123,6 +127,9 @@ struct Args : public ArgsBase<Args> {
"Number of threads to use.\n Default = Estimate of the "
"number of supported concurrent threads.",
2);
visitor(tokenizer, "tokenizer", Path(),
"Path to tokenizer file. If given, the config and TOC are also "
"added to the output file.");
}
// Uninitialized before Validate, must call after that.
@ -156,7 +163,8 @@ namespace HWY_NAMESPACE {
template <typename T>
void CompressWeights(const Path& weights_path,
const Path& compressed_weights_path, Model model_type,
hwy::ThreadPool& pool) {
Type weight_type, PromptWrapping wrapping,
const Path& tokenizer_path, hwy::ThreadPool& pool) {
if (!weights_path.Exists()) {
HWY_ABORT("The model weights file '%s' does not exist.",
weights_path.path.c_str());
@ -164,6 +172,8 @@ void CompressWeights(const Path& weights_path,
printf("Compressing weights from %s to %s\n", weights_path.path.c_str(),
compressed_weights_path.path.c_str());
ModelConfig config = ConfigFromModel(model_type);
config.weight = weight_type;
config.wrapping = wrapping;
std::vector<MatStorage> model_storage;
ModelWeightsPtrs<T> c_weights(config);
c_weights.Allocate(model_storage, pool);
@ -185,6 +195,9 @@ void CompressWeights(const Path& weights_path,
ok &= 1 == fread(tensors[0]->Ptr(), tensors[0]->SizeBytes(), 1, fptr);
total_size += tensors[0]->SizeBytes();
});
if (!tokenizer_path.path.empty()) {
uc_weights.AllocAndCopyWithTranspose(pool, model_storage);
}
const bool scale_for_compression = config.num_tensor_scales > 0;
std::vector<float> scales;
if (scale_for_compression) {
@ -193,14 +206,21 @@ void CompressWeights(const Path& weights_path,
Compressor compressor(pool);
ModelWeightsPtrs<T>::ForEachTensor(
{reinterpret_cast<ModelWeightsPtrs<T>*>(&uc_weights), &c_weights},
ForEachType::kLoadNoToc,
tokenizer_path.path.empty() ? ForEachType::kLoadNoToc
: ForEachType::kLoadWithToc,
[&compressor](const char* name, hwy::Span<MatPtr*> tensors) {
tensors[1]->CallUpcasted(
compressor, name,
reinterpret_cast<const float*>(tensors[0]->Ptr()));
});
if (!tokenizer_path.path.empty()) {
std::string tokenizer_proto = ReadFileToString(tokenizer_path);
compressor.AddTokenizer(tokenizer_proto);
} else {
compressor.AddScales(scales.data(), scales.size() * sizeof(scales[0]));
compressor.WriteAll(pool, compressed_weights_path);
}
compressor.WriteAll(compressed_weights_path,
tokenizer_path.path.empty() ? nullptr : &config);
}
} // namespace HWY_NAMESPACE
@ -220,19 +240,23 @@ void Run(Args& args) {
switch (weight_type) {
case Type::kF32:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<float>)
(args.weights, args.compressed_weights, model_type, pool);
(args.weights, args.compressed_weights, model_type, weight_type,
args.PromptWrappingType(), args.tokenizer, pool);
break;
case Type::kBF16:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<BF16>)
(args.weights, args.compressed_weights, model_type, pool);
(args.weights, args.compressed_weights, model_type, weight_type,
args.PromptWrappingType(), args.tokenizer, pool);
break;
case Type::kSFP:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<SfpStream>)
(args.weights, args.compressed_weights, model_type, pool);
(args.weights, args.compressed_weights, model_type, weight_type,
args.PromptWrappingType(), args.tokenizer, pool);
break;
case Type::kNUQ:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<NuqStream>)
(args.weights, args.compressed_weights, model_type, pool);
(args.weights, args.compressed_weights, model_type, weight_type,
args.PromptWrappingType(), args.tokenizer, pool);
break;
default:
HWY_ABORT("Weight type %d unsupported.", static_cast<int>(weight_type));

View File

@ -83,6 +83,14 @@ class PrintVisitor : public VisitorBase {
fprintf(stderr, "%sU32 %u\n", indent_.c_str(), value);
}
void operator()(int32_t& value) override {
fprintf(stderr, "%sI32 %d\n", indent_.c_str(), value);
}
// Prints a uint64_t field for debugging.
// %zu is the conversion for size_t, which is not guaranteed to be 64-bit;
// print via unsigned long long (always >= 64 bits) for portability.
void operator()(uint64_t& value) override {
fprintf(stderr, "%sU64 %llu\n", indent_.c_str(),
static_cast<unsigned long long>(value));
}
void operator()(float& value) override {
fprintf(stderr, "%sF32 %f\n", indent_.c_str(), value);
}
@ -120,6 +128,21 @@ class ReadVisitor : public VisitorBase {
value = span_[result_.pos++];
}
// Reads one serialized u32 word and reinterprets it as a signed int32.
// When the field is absent from older serialized data, SkipField() returns
// true and `value` keeps its caller-provided default.
void operator()(int32_t& value) override {
if (HWY_UNLIKELY(SkipField())) return;
value = static_cast<int32_t>(span_[result_.pos++]);
}
// Reads a uint64_t stored as two u32 words, low half first (matching
// WriteVisitor). The halves are seeded from the current (default) `value` so
// that if the inner reads skip (field absent in older data), the default is
// preserved when the halves are recombined.
void operator()(uint64_t& value) override {
if (HWY_UNLIKELY(SkipField())) return;
uint32_t lower = static_cast<uint32_t>(value);
operator()(lower);
uint32_t upper = static_cast<uint32_t>(value >> 32);
operator()(upper);
value = lower | (static_cast<uint64_t>(upper) << 32);
}
void operator()(float& value) override {
if (HWY_UNLIKELY(SkipField())) return;
@ -229,6 +252,15 @@ class WriteVisitor : public VisitorBase {
void operator()(uint32_t& value) override { storage_.push_back(value); }
// Serializes an int32 as a single u32 word (bit pattern preserved; the
// matching ReadVisitor casts it back to signed).
void operator()(int32_t& value) override {
storage_.push_back(static_cast<uint32_t>(value));
}
// Serializes a uint64_t as two u32 words, low half first — must stay in
// sync with ReadVisitor's uint64_t overload.
void operator()(uint64_t& value) override {
storage_.push_back(static_cast<uint32_t>(value));
storage_.push_back(static_cast<uint32_t>(value >> 32));
}
void operator()(float& value) override {
storage_.push_back(hwy::BitCastScalar<uint32_t>(value));
CheckF32(value);

View File

@ -55,8 +55,9 @@ class IFields; // breaks circular dependency
// Visitors are internal-only, but their base class is visible to user code
// because their `IFields::VisitFields` calls `visitor.operator()`.
//
// Supported field types `T`: `uint32_t`, `float`, `std::string`, classes
// derived from `IFields`, `bool`, `enum`, `std::vector<T>`.
// Supported field types `T`: `uint32_t`, `int32_t`, `uint64_t`, `float`,
// `std::string`, classes derived from `IFields`, `bool`, `enum`,
// `std::vector<T>`.
class IFieldsVisitor {
public:
virtual ~IFieldsVisitor();
@ -69,6 +70,8 @@ class IFieldsVisitor {
// is out of range. A single generic/overloaded function is required to
// support `std::vector<T>`.
virtual void operator()(uint32_t& value) = 0;
virtual void operator()(int32_t& value) = 0;
virtual void operator()(uint64_t& value) = 0;
virtual void operator()(float& value) = 0;
virtual void operator()(std::string& value) = 0;
virtual void operator()(IFields& fields) = 0; // recurse into nested fields
@ -92,7 +95,7 @@ class IFieldsVisitor {
uint32_t u32 = static_cast<uint32_t>(value);
operator()(u32);
if (HWY_UNLIKELY(!EnumValid(static_cast<EnumT>(u32)))) {
return NotifyInvalid("Invalid enum %u\n");
return NotifyInvalid("Invalid enum %u\n", u32);
}
value = static_cast<EnumT>(u32);
}

View File

@ -97,6 +97,8 @@ struct OldFields : public IFields {
visitor(old_str);
visitor(old_nested);
visitor(old1);
visitor(oldi);
visitor(oldl);
visitor(old_vec_str);
visitor(old_vec_nested);
visitor(old_f);
@ -110,6 +112,8 @@ struct OldFields : public IFields {
EXPECT_EQ(old_str, n.old_str);
old_nested.CheckEqual(n.old_nested);
EXPECT_EQ(old1, n.old1);
EXPECT_EQ(oldi, n.oldi);
EXPECT_EQ(oldl, n.oldl);
CheckVectorEqual(old_vec_str, n.old_vec_str);
CheckVectorEqual(old_vec_nested, n.old_vec_nested);
EXPECT_EQ(old_f, n.old_f);
@ -120,6 +124,8 @@ struct OldFields : public IFields {
std::string old_str = "old";
Nested old_nested = Nested(0);
uint32_t old1 = 1;
int32_t oldi = -1;
uint64_t oldl = 1234567890123456789;
std::vector<std::string> old_vec_str = {"abc", "1234"};
std::vector<Nested> old_vec_nested = {Nested(1), Nested(4)};
float old_f = 1.125f;
@ -134,6 +140,8 @@ struct NewFields : public IFields {
visitor(old_str);
visitor(old_nested);
visitor(old1);
visitor(oldi);
visitor(oldl);
visitor(old_vec_str);
visitor(old_vec_nested);
visitor(old_f);
@ -149,6 +157,8 @@ struct NewFields : public IFields {
visitor(new_enum);
visitor(new2);
visitor(new_str);
visitor(new_i);
visitor(new_l);
}
void CheckEqual(const NewFields& n) const {
@ -176,6 +186,8 @@ struct NewFields : public IFields {
std::string old_str = "old";
Nested old_nested = Nested(0);
uint32_t old1 = 1;
int32_t oldi = -1;
uint64_t oldl = 1234567890123456789;
std::vector<std::string> old_vec_str = {"abc", "1234"};
std::vector<Nested> old_vec_nested = {Nested(1), Nested(4)};
float old_f = 1.125f;
@ -190,6 +202,8 @@ struct NewFields : public IFields {
Enum new_enum = Enum::k3;
uint32_t new2 = 2;
std::string new_str = std::string(); // empty is allowed
int32_t new_i = 123456789;
uint64_t new_l = 876543210987654321;
}; // NewFields
// Changes all fields to non-default values.
@ -212,6 +226,8 @@ NewFields ModifiedNewFields() {
n.new_enum = Enum::k8;
n.new2 = 22;
n.new_str = "new and even longer";
n.new_i = 246810121;
n.new_l = 1357913579113579135;
return n;
}

View File

@ -0,0 +1,62 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <string>
#include "evals/benchmark_helper.h"
#include "gemma/gemma.h"
#include "util/args.h"
namespace gcpp {
namespace {
// Command-line arguments for the weights-migration tool. Only one flag is
// defined here; model-loading flags are parsed separately by GemmaEnv.
struct WriterArgs : public ArgsBase<WriterArgs> {
// --output_weights is required.
WriterArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
// Returns error string or nullptr if OK.
const char* Validate() {
if (output_weights.path.empty()) {
return "Missing --output_weights flag, a file for the model weights.";
}
return nullptr;
}
Path output_weights; // weights file location
// Visitor hook used by ArgsBase to declare each flag: member, flag name,
// default value, and help text.
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(output_weights, "output_weights", Path(),
"Path name of output weights (.sbs) file.\n Required argument.");
}
};
} // namespace
} // namespace gcpp
int main(int argc, char** argv) {
// Loads a model in the multi-file format and saves it in single-file format.
// Validate our own flag before paying for the model load.
gcpp::WriterArgs args(argc, argv);
if (const char* err = args.Validate()) {
fprintf(stderr, "Skipping model load because: %s\n", err);
return 1;
}
// GemmaEnv re-parses argv for the loader/inference/app flags and loads the
// model; /*required=*/true demands the explicit model-type flags of the
// old multi-file format.
gcpp::GemmaEnv env(argc, argv, /*required=*/true);
hwy::ThreadPool pool(0);
// Writes weights + TOC + config + tokenizer as one .sbs file.
env.GetModel()->Save(args.output_weights, pool);
return 0;
}

View File

@ -16,6 +16,7 @@ cc_library(
deps = [
"@abseil-cpp//absl/types:span",
"//:common",
"//:tokenizer",
"//compression:compress",
"//compression:io",
"@highway//:hwy",

View File

@ -24,7 +24,9 @@
#include "absl/types/span.h"
#include "compression/io.h"
#include "gemma/configs.h"
#include "gemma/tensor_index.h"
#include "gemma/tokenizer.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -44,10 +46,11 @@ class WriterInterface {
virtual void InsertFloat(std::string name,
absl::Span<const float> weights) = 0;
virtual void AddScales(const std::vector<float>& scales) = 0;
virtual void AddTokenizer(const std::string& tokenizer_path) = 0;
virtual size_t DebugNumBlobsAdded() const = 0;
virtual int Write(std::string path) = 0;
virtual int WriteWithConfig(std::string path, const ModelConfig* config) = 0;
};
} // namespace gcpp
@ -133,14 +136,21 @@ class SbsWriterImpl : public WriterInterface {
compressor_.AddScales(scales_.data(), scales_.size());
}
void AddTokenizer(const std::string& tokenizer_path) override {
Path path(tokenizer_path);
GemmaTokenizer tokenizer(path);
tokenizer_proto_ = tokenizer.Serialize();
compressor_.AddTokenizer(tokenizer_proto_);
}
// Returns the number of blobs added.
size_t DebugNumBlobsAdded() const {
if (mode_ == CompressorMode::kTEST_ONLY) return model_memory_.size();
return compressor_.DebugNumBlobsAdded();
}
int Write(std::string path) override {
return compressor_.WriteAll(pool_, gcpp::Path(path));
int WriteWithConfig(std::string path, const ModelConfig* config) override {
return compressor_.WriteAll(gcpp::Path(path), config);
}
hwy::ThreadPool pool_;
@ -149,6 +159,7 @@ class SbsWriterImpl : public WriterInterface {
std::vector<MatStorage> model_memory_;
std::vector<float> scales_;
CompressorMode mode_;
std::string tokenizer_proto_;
};
WriterInterface* NewSbsWriter(CompressorMode mode) {
@ -190,11 +201,17 @@ void SbsWriter::AddScales(const std::vector<float>& scales) {
impl_->AddScales(scales);
}
void SbsWriter::AddTokenizer(const std::string& tokenizer_path) {
impl_->AddTokenizer(tokenizer_path);
}
size_t SbsWriter::DebugNumBlobsAdded() const {
return impl_->DebugNumBlobsAdded();
}
int SbsWriter::Write(std::string path) { return impl_->Write(path); }
int SbsWriter::WriteWithConfig(std::string path, const ModelConfig* config) {
return impl_->WriteWithConfig(path, config);
}
} // namespace gcpp
#endif // HWY_ONCE

View File

@ -8,6 +8,7 @@
#include "absl/types/span.h"
#include "compression/shared.h"
#include "gemma/configs.h"
#include "gemma/tensor_index.h"
namespace gcpp {
@ -36,10 +37,12 @@ class SbsWriter {
void InsertBfloat16(std::string name, absl::Span<const float> weights);
void InsertFloat(std::string name, absl::Span<const float> weights);
void AddScales(const std::vector<float>& scales);
void AddTokenizer(const std::string& tokenizer_path);
size_t DebugNumBlobsAdded() const;
int Write(std::string path);
int Write(std::string path) { return WriteWithConfig(path, nullptr); }
int WriteWithConfig(std::string path, const ModelConfig* config);
private:
// Isolates Highway-dispatched types and other internals from CLIF.

View File

@ -50,6 +50,8 @@ PYBIND11_MODULE(compression, m) {
.def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>)
.def("insert_float", wrap_span<&SbsWriter::InsertFloat>)
.def("add_scales", &SbsWriter::AddScales)
.def("add_tokenizer", &SbsWriter::AddTokenizer)
.def("debug_num_blobs_added", &SbsWriter::DebugNumBlobsAdded)
.def("write", &SbsWriter::Write);
.def("write", &SbsWriter::Write)
.def("write_with_config", &SbsWriter::WriteWithConfig);
}

View File

@ -198,6 +198,11 @@ constexpr bool IsNuqStream() {
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class PromptWrapping { GEMMA_IT, GEMMA_PT, PALIGEMMA };
// Returns whether `type` is one of the declared PromptWrapping values,
// i.e. lies in [GEMMA_IT, PALIGEMMA].
inline bool EnumValid(PromptWrapping type) {
const int raw = static_cast<int>(type);
const int last = static_cast<int>(PromptWrapping::PALIGEMMA);
return 0 <= raw && raw <= last;
}
// Tensor types for loading weights. Note that not all types are supported as
// weights for a model, but can be used for other purposes, such as types for
// ModelWeightsPtrs. When adding a new type that is supported, also
@ -206,6 +211,11 @@ enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 };
constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
"nuq", "f64", "c64", "u128"};
// Returns whether `type` is one of the declared Type values,
// i.e. lies in [kUnknown, kU128].
inline bool EnumValid(Type type) {
const int raw = static_cast<int>(type);
const int last = static_cast<int>(Type::kU128);
return 0 <= raw && raw <= last;
}
// Returns a Type enum for the type of the template parameter.
template <typename PackedT>
Type TypeEnum() {

View File

@ -92,9 +92,9 @@ static AppArgs MakeAppArgs(int argc, char** argv) {
return AppArgs(argc, argv);
}
GemmaEnv::GemmaEnv(int argc, char** argv)
: GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
MakeAppArgs(argc, argv)) {}
GemmaEnv::GemmaEnv(int argc, char** argv, bool model_type_required)
: GemmaEnv(LoaderArgs(argc, argv, model_type_required),
InferenceArgs(argc, argv), MakeAppArgs(argc, argv)) {}
QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
QueryResult result;
@ -270,7 +270,9 @@ void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
"specify 3 required model loading arguments:\n"
" --tokenizer\n"
" --weights\n"
" --model.\n";
" --model,\n"
" or with the newer weights format, specify just:\n"
" --weights\n";
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
"--weights 2b-it-sfp.sbs --model 2b-it\n";
std::cerr << "\n*Model Loading Arguments*\n\n";

View File

@ -44,7 +44,7 @@ struct QueryResult {
class GemmaEnv {
public:
// Calls the other constructor with *Args arguments initialized from argv.
GemmaEnv(int argc, char** argv);
GemmaEnv(int argc, char** argv, bool model_type_required = false);
GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
const AppArgs& app);

View File

@ -15,6 +15,7 @@
#include "gemma/configs.h"
#include <cstddef>
#include <iostream>
#include "hwy/base.h"
@ -22,9 +23,9 @@
namespace gcpp {
static ModelConfig ConfigNoSSM() {
ModelConfig config = {.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w",
"gr_lin_y_w", "gr_lin_out_w",
"gr_gate_w", "gating_ein", "linear_w"}};
ModelConfig config;
config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
"gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
return config;
}
@ -37,6 +38,18 @@ static ModelConfig ConfigBaseGemmaV2() {
return config;
}
static LayerConfig LayerConfigGemma2_27B(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
config.ff_hidden_dim = 16 * 4608 / 2; // = 36864
config.heads = 32;
config.kv_heads = 16;
config.qkv_dim = 128;
config.optimized_gating = false;
config.post_norm = PostNormType::Scale;
return config;
}
static ModelConfig ConfigGemma2_27B() {
ModelConfig config = ConfigBaseGemmaV2();
config.model_name = "Gemma2_27B";
@ -44,13 +57,7 @@ static ModelConfig ConfigGemma2_27B() {
config.model_dim = 4608;
config.vocab_size = kVocabSize;
config.seq_len = 8192;
LayerConfig layer_config = {.model_dim = config.model_dim,
.ff_hidden_dim = 16 * 4608 / 2, // = 36864
.heads = 32,
.kv_heads = 16,
.qkv_dim = 128,
.optimized_gating = false,
.post_norm = PostNormType::Scale};
LayerConfig layer_config = LayerConfigGemma2_27B(config.model_dim);
config.layer_configs = {46, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtModelDimDivNumHeads;
@ -59,6 +66,18 @@ static ModelConfig ConfigGemma2_27B() {
return config;
}
static LayerConfig LayerConfigGemma2_9B(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
config.ff_hidden_dim = 8 * 3584 / 2; // = 14336
config.heads = 16;
config.kv_heads = 8;
config.qkv_dim = 256;
config.optimized_gating = false;
config.post_norm = PostNormType::Scale;
return config;
}
static ModelConfig ConfigGemma2_9B() {
ModelConfig config = ConfigBaseGemmaV2();
config.model_name = "Gemma2_9B";
@ -66,13 +85,7 @@ static ModelConfig ConfigGemma2_9B() {
config.model_dim = 3584;
config.vocab_size = kVocabSize;
config.seq_len = 8192;
LayerConfig layer_config = {.model_dim = config.model_dim,
.ff_hidden_dim = 8 * 3584 / 2, // = 14336
.heads = 16,
.kv_heads = 8,
.qkv_dim = 256,
.optimized_gating = false,
.post_norm = PostNormType::Scale};
LayerConfig layer_config = LayerConfigGemma2_9B(config.model_dim);
config.layer_configs = {42, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
@ -81,6 +94,18 @@ static ModelConfig ConfigGemma2_9B() {
return config;
}
static LayerConfig LayerConfigGemma2_2B(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
config.ff_hidden_dim = 8 * 2304 / 2; // = 9216
config.heads = 8;
config.kv_heads = 4;
config.qkv_dim = 256;
config.optimized_gating = false;
config.post_norm = PostNormType::Scale;
return config;
}
static ModelConfig ConfigGemma2_2B() {
ModelConfig config = ConfigBaseGemmaV2();
config.model_name = "Gemma2_2B";
@ -88,13 +113,7 @@ static ModelConfig ConfigGemma2_2B() {
config.model_dim = 2304;
config.vocab_size = kVocabSize;
config.seq_len = 8192;
LayerConfig layer_config = {.model_dim = config.model_dim,
.ff_hidden_dim = 8 * 2304 / 2, // = 9216
.heads = 8,
.kv_heads = 4,
.qkv_dim = 256,
.optimized_gating = false,
.post_norm = PostNormType::Scale};
LayerConfig layer_config = LayerConfigGemma2_2B(config.model_dim);
config.layer_configs = {26, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
@ -103,6 +122,16 @@ static ModelConfig ConfigGemma2_2B() {
return config;
}
static LayerConfig LayerConfigGemma7B(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
config.ff_hidden_dim = 16 * 3072 / 2; // = 24576
config.heads = 16;
config.kv_heads = 16;
config.qkv_dim = 256;
return config;
}
static ModelConfig ConfigGemma7B() {
ModelConfig config = ConfigBaseGemmaV1();
config.model_name = "Gemma7B";
@ -110,13 +139,7 @@ static ModelConfig ConfigGemma7B() {
config.model_dim = 3072;
config.vocab_size = kVocabSize;
config.seq_len = kSeqLen;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.ff_hidden_dim = 16 * 3072 / 2, // = 24576
.heads = 16,
.kv_heads = 16,
.qkv_dim = 256,
};
LayerConfig layer_config = LayerConfigGemma7B(config.model_dim);
config.layer_configs = {28, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
@ -124,6 +147,16 @@ static ModelConfig ConfigGemma7B() {
return config;
}
static LayerConfig LayerConfigGemma2B(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
config.ff_hidden_dim = 16 * 2048 / 2; // = 16384
config.heads = 8;
config.kv_heads = 1;
config.qkv_dim = 256;
return config;
}
static ModelConfig ConfigGemma2B() {
ModelConfig config = ConfigBaseGemmaV1();
config.model_name = "Gemma2B";
@ -131,19 +164,23 @@ static ModelConfig ConfigGemma2B() {
config.model_dim = 2048;
config.vocab_size = kVocabSize;
config.seq_len = kSeqLen;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.ff_hidden_dim = 16 * 2048 / 2, // = 16384
.heads = 8,
.kv_heads = 1,
.qkv_dim = 256,
};
LayerConfig layer_config = LayerConfigGemma2B(config.model_dim);
config.layer_configs = {18, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.attention_window_sizes = FixedAttentionWindowSizes<18>(kSeqLen);
return config;
}
static LayerConfig LayerConfigGemmaTiny(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
config.ff_hidden_dim = 256;
config.heads = 4;
config.kv_heads = 1;
config.qkv_dim = 16;
return config;
}
static ModelConfig ConfigGemmaTiny() {
ModelConfig config = ConfigNoSSM();
config.model_name = "GemmaTiny";
@ -151,13 +188,7 @@ static ModelConfig ConfigGemmaTiny() {
config.model_dim = 128;
config.vocab_size = 64;
config.seq_len = 32;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.ff_hidden_dim = 256,
.heads = 4,
.kv_heads = 1,
.qkv_dim = 16,
};
LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim);
config.layer_configs = {3, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
@ -167,6 +198,24 @@ static ModelConfig ConfigGemmaTiny() {
return config;
}
static LayerConfig LayerConfigGriffin2B(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
config.griffin_dim = model_dim;
config.ff_hidden_dim = 7680;
config.heads = 10;
config.kv_heads = 1;
config.qkv_dim = 256;
config.conv1d_width = 4;
config.ff_biases = true;
config.softmax_attn_output_biases = true;
config.optimized_gating = false;
config.type = LayerAttentionType::kGriffinRecurrentBlock;
config.activation = ActivationType::Gelu;
config.post_qk = PostQKType::HalfRope;
return config;
}
static ModelConfig ConfigGriffin2B() {
ModelConfig config = ConfigNoSSM();
config.model_name = "Griffin2B";
@ -176,21 +225,7 @@ static ModelConfig ConfigGriffin2B() {
config.model_dim = 2560;
config.vocab_size = kVocabSize;
config.seq_len = 2048;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.griffin_dim = config.model_dim,
.ff_hidden_dim = 7680,
.heads = 10,
.kv_heads = 1,
.qkv_dim = 256,
.conv1d_width = 4,
.ff_biases = true,
.softmax_attn_output_biases = true,
.optimized_gating = false,
.type = LayerAttentionType::kGriffinRecurrentBlock,
.activation = ActivationType::Gelu,
.post_qk = PostQKType::HalfRope,
};
LayerConfig layer_config = LayerConfigGriffin2B(config.model_dim);
config.layer_configs = {26, layer_config};
for (size_t i = 2; i < config.layer_configs.size(); i += 3) {
config.layer_configs[i].type = LayerAttentionType::kGemma;
@ -204,6 +239,18 @@ static ModelConfig ConfigGriffin2B() {
return config;
}
static LayerConfig LayerConfigVit(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
config.ff_hidden_dim = 4304;
config.heads = 16;
config.kv_heads = 16;
config.qkv_dim = 72;
config.ff_biases = true;
config.type = LayerAttentionType::kVit;
return config;
}
// Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config.
static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
config.vit_model_dim = 1152;
@ -215,15 +262,7 @@ static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
}
const size_t num_patches = config.image_size / config.patch_width;
config.vit_seq_len = num_patches * num_patches;
LayerConfig vit_layer_config = {
.model_dim = config.vit_model_dim,
.ff_hidden_dim = 4304,
.heads = 16,
.kv_heads = 16,
.qkv_dim = 72,
.ff_biases = true,
.type = LayerAttentionType::kVit,
};
LayerConfig vit_layer_config = LayerConfigVit(config.vit_model_dim);
config.vit_layer_configs = {27, vit_layer_config};
config.num_vit_scales = 4 * config.vit_layer_configs.size();
}

View File

@ -26,6 +26,7 @@
#include <unordered_set>
#include <vector>
#include "compression/fields.h" // IFieldsVisitor
#include "compression/shared.h" // BF16
namespace gcpp {
@ -52,52 +53,83 @@ enum class LayerAttentionType {
kVit,
};
inline bool EnumValid(LayerAttentionType type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <= static_cast<int>(LayerAttentionType::kVit);
}
// Post attention and ffw normalization type.
enum class PostNormType {
None,
Scale,
};
inline bool EnumValid(PostNormType type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <= static_cast<int>(PostNormType::Scale);
}
// Post qk projection operation type.
enum class PostQKType {
Rope,
HalfRope,
};
inline bool EnumValid(PostQKType type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <= static_cast<int>(PostQKType::HalfRope);
}
// FFW activation function.
enum class ActivationType {
Gelu,
};
inline bool EnumValid(ActivationType type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <= static_cast<int>(ActivationType::Gelu);
}
// Attention query scale.
enum class QueryScaleType {
SqrtKeySize,
SqrtModelDimDivNumHeads,
};
inline bool EnumValid(QueryScaleType type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <=
static_cast<int>(QueryScaleType::SqrtModelDimDivNumHeads);
}
// Residual connection type.
enum class ResidualType {
Add,
};
inline bool EnumValid(ResidualType type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <= static_cast<int>(ResidualType::Add);
}
template <size_t kNum>
std::vector<LayerAttentionType> FixedLayerConfig(LayerAttentionType type) {
return std::vector<LayerAttentionType>(kNum, type);
}
template <size_t kNum>
std::vector<size_t> FixedAttentionWindowSizes(size_t window_size) {
return std::vector<size_t>(kNum, window_size);
template <uint32_t kNum>
std::vector<uint32_t> FixedAttentionWindowSizes(uint32_t window_size) {
return std::vector<uint32_t>(kNum, window_size);
}
// Repeat window_size_pattern for kNum / kPatternSize times.
template <size_t kNum, size_t kPatternSize>
std::vector<size_t> RepeatedAttentionWindowSizes(
const std::array<size_t, kPatternSize>& window_size_pattern) {
template <uint32_t kNum, uint32_t kPatternSize>
std::vector<uint32_t> RepeatedAttentionWindowSizes(
const std::array<uint32_t, kPatternSize>& window_size_pattern) {
static_assert(kNum % kPatternSize == 0,
"kNum must be a multiple of kPatternSize");
std::vector<size_t> window_size_configs(kNum);
for (size_t i = 0; i < kNum; ++i) {
std::vector<uint32_t> window_size_configs(kNum);
for (uint32_t i = 0; i < kNum; ++i) {
window_size_configs[i] = window_size_pattern[i % kPatternSize];
}
return window_size_configs;
@ -130,7 +162,14 @@ static constexpr Model kAllModels[] = {
Model::PALIGEMMA2_10B_224, Model::PALIGEMMA2_10B_448,
};
struct LayerConfig {
inline bool EnumValid(Model model) {
for (Model m : kAllModels) {
if (m == model) return true;
}
return false;
}
struct LayerConfig : public IFields {
// Returns true if *this and other are equal.
// If partial is true, then we don't check for items that are only set after
// the tensors are loaded from the checkpoint.
@ -146,13 +185,32 @@ struct LayerConfig {
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
size_t QStride() const { return qkv_dim * (IsMHA() ? 3 : 1); }
size_t model_dim = 0;
size_t griffin_dim = 0;
size_t ff_hidden_dim = 0;
size_t heads = 0;
size_t kv_heads = 0;
size_t qkv_dim = 0;
size_t conv1d_width = 0; // griffin only
const char* Name() const override { return "LayerConfig"; }
void VisitFields(IFieldsVisitor& visitor) override {
visitor(model_dim);
visitor(griffin_dim);
visitor(ff_hidden_dim);
visitor(heads);
visitor(kv_heads);
visitor(qkv_dim);
visitor(conv1d_width);
visitor(ff_biases);
visitor(softmax_attn_output_biases);
visitor(optimized_gating);
visitor(post_norm);
visitor(type);
visitor(activation);
visitor(post_qk);
}
uint32_t model_dim = 0;
uint32_t griffin_dim = 0;
uint32_t ff_hidden_dim = 0;
uint32_t heads = 0;
uint32_t kv_heads = 0;
uint32_t qkv_dim = 0;
uint32_t conv1d_width = 0; // griffin only
bool ff_biases = false;
bool softmax_attn_output_biases = false;
bool optimized_gating = true;
@ -162,7 +220,7 @@ struct LayerConfig {
PostQKType post_qk = PostQKType::Rope;
};
struct ModelConfig {
struct ModelConfig : public IFields {
// Returns true if *this and other are equal.
// If partial is true, then we don't check for items that are only set after
// the tensors are loaded from the checkpoint.
@ -191,39 +249,68 @@ struct ModelConfig {
}
size_t NumHeads() const {
size_t num_heads = 0;
uint32_t num_heads = 0;
for (const auto& layer_config : layer_configs) {
num_heads = std::max(num_heads, layer_config.heads);
}
return num_heads;
}
const char* Name() const override { return "ModelConfig"; }
void VisitFields(IFieldsVisitor& visitor) override {
visitor(model_family_version);
visitor(model_name);
visitor(model);
visitor(wrapping);
visitor(weight);
visitor(num_layers);
visitor(model_dim);
visitor(vocab_size);
visitor(seq_len);
visitor(num_tensor_scales);
visitor(att_cap);
visitor(final_cap);
visitor(absolute_pe);
visitor(use_local_attention);
visitor(query_scale);
visitor(layer_configs);
visitor(attention_window_sizes);
visitor(norm_num_groups);
visitor(vit_model_dim);
visitor(vit_seq_len);
visitor(num_vit_scales);
visitor(vit_layer_configs);
visitor(patch_width);
visitor(image_size);
}
std::string model_name;
Model model;
PromptWrapping wrapping;
Type weight;
size_t num_layers = 0;
size_t model_dim = 0;
size_t vit_model_dim = 0;
size_t vocab_size = 0;
size_t seq_len = 0;
size_t vit_seq_len = 0;
size_t num_tensor_scales = 0;
size_t num_vit_scales = 0;
Model model = Model::UNKNOWN;
PromptWrapping wrapping = PromptWrapping::GEMMA_PT;
Type weight = Type::kUnknown;
uint32_t num_layers = 0;
uint32_t model_dim = 0;
uint32_t vit_model_dim = 0;
uint32_t vocab_size = 0;
uint32_t seq_len = 0;
uint32_t vit_seq_len = 0;
uint32_t num_tensor_scales = 0;
uint32_t num_vit_scales = 0;
float att_cap = 0.0f;
float final_cap = 0.0f;
bool absolute_pe = false;
bool use_local_attention = false; // griffin only
QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
std::vector<LayerConfig> layer_configs;
std::vector<size_t> attention_window_sizes;
std::vector<uint32_t> attention_window_sizes;
std::vector<LayerConfig> vit_layer_configs;
std::unordered_set<std::string> scale_names;
int norm_num_groups = 1;
int model_family_version = 1;
uint32_t norm_num_groups = 1;
uint32_t model_family_version = 1;
// Dimensions related to image processing.
size_t patch_width = 14;
size_t image_size = 224;
uint32_t patch_width = 14;
uint32_t image_size = 224;
};
// Returns the config for the given model.

View File

@ -2,9 +2,12 @@
#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <vector>
#include "gtest/gtest.h"
#include "hwy/aligned_allocator.h"
namespace gcpp {
@ -412,8 +415,17 @@ void AssertMatch(const ModelConfig& config) {
ASSERT_EQ(TConfig::kNumTensorScales, config.num_tensor_scales);
}
ModelConfig RoundTripSerialize(const ModelConfig& config) {
std::vector<uint32_t> config_buffer = config.Write();
ModelConfig deserialized;
deserialized.Read(hwy::Span<const uint32_t>(config_buffer), 0);
return deserialized;
}
TEST(ConfigsTest, OldConfigGemma2B) {
AssertMatch<OldConfigGemma2B<float>>(ConfigFromModel(Model::GEMMA_2B));
ModelConfig config = RoundTripSerialize(ConfigFromModel(Model::GEMMA_2B));
AssertMatch<OldConfigGemma2B<float>>(config);
}
TEST(ConfigsTest, OldConfigGemma7B) {

View File

@ -23,6 +23,7 @@
#include <stdlib.h>
#include <string.h>
#include <string>
#include <utility> // std::move
#include <vector>
@ -40,13 +41,21 @@ namespace gcpp {
Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
const ModelInfo& info, NestedPools& pools)
: pools_(pools), tokenizer_(tokenizer_path), info_(info) {
model_.Load(weights, info.model, info.weight, pools_.Pool());
: pools_(pools), tokenizer_(tokenizer_path) {
model_.Load(weights, info.model, info.weight, info.wrapping, pools_.Pool(),
/*tokenizer_proto=*/nullptr);
}
Gemma::Gemma(const Path& weights, NestedPools& pools) : pools_(pools) {
std::string tokenizer_proto;
model_.Load(weights, Model::UNKNOWN, Type::kUnknown, PromptWrapping::GEMMA_IT,
pools_.Pool(), &tokenizer_proto);
tokenizer_.Deserialize(tokenizer_proto);
}
Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
NestedPools& pools)
: pools_(pools), tokenizer_(std::move(tokenizer)), info_(info) {
: pools_(pools), tokenizer_(std::move(tokenizer)) {
HWY_ASSERT(info.weight == Type::kF32);
model_.Allocate(info.model, info.weight, pools_.Pool());
}
@ -166,7 +175,7 @@ void RangeChecks(const ModelConfig& weights_config,
if (!weights_config.use_local_attention) {
if (max_generated_tokens > weights_config.seq_len) {
fprintf(stderr,
"WARNING: max_generated_tokens %zu > kSeqLen %zu, truncating.\n",
"WARNING: max_generated_tokens %zu > kSeqLen %u, truncating.\n",
max_generated_tokens, weights_config.seq_len);
max_generated_tokens = weights_config.seq_len;
}

View File

@ -190,18 +190,28 @@ struct TimingInfo {
class Gemma {
public:
// Reads old format weights file and tokenizer file.
Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
NestedPools& pools);
// Reads new format weights file that contains everything in a single file.
Gemma(const Path& weights, NestedPools& pools);
// Allocates weights, caller is responsible for filling them.
Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, NestedPools& pools);
~Gemma();
const ModelConfig& GetModelConfig() const { return model_.Config(); }
const ModelInfo& Info() const { return info_; }
ModelInfo Info() const {
return ModelInfo({.model = model_.Config().model,
.wrapping = model_.Config().wrapping,
.weight = model_.Config().weight});
}
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
const ModelWeightsStorage& Weights() const { return model_; }
ModelWeightsStorage& MutableWeights() { return model_; }
void Save(const Path& weights, hwy::ThreadPool& pool) {
std::string tokenizer_proto = tokenizer_.Serialize();
model_.Save(tokenizer_proto, weights, pool);
}
// `pos` is the position in the KV cache. Users are responsible for
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
@ -241,7 +251,6 @@ class Gemma {
GemmaTokenizer tokenizer_;
// Type-erased so that this can be defined in the header.
ModelWeightsStorage model_;
ModelInfo info_;
};
// Adds BOS token and possibly 'turn' annotations, which depend on `info`

View File

@ -53,7 +53,7 @@ KVCache KVCache::Create(const ModelConfig& weights_config,
LayerAttentionType::kGriffinRecurrentBlock);
// TODO(patrickms): Add query batching support for Griffin.
if (num_griffin_layers > 0) {
size_t conv1d_width = 0;
uint32_t conv1d_width = 0;
for (const auto& layer_config : weights_config.layer_configs) {
conv1d_width = std::max(conv1d_width, layer_config.conv1d_width);
}

View File

@ -482,6 +482,8 @@ std::vector<TensorInfo> LLMLayerTensors(const ModelConfig& config,
.name = "att_w",
.source_names = {"attn/attn_vec_einsum/w",
"attention_block/proj_final/kernel"},
.preshape = {layer_config.heads, layer_config.qkv_dim,
config.model_dim},
.axes = {2, 0, 1},
.shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
.cols_take_extra_dims = true,

View File

@ -56,7 +56,8 @@ TEST(TensorIndexTest, FindName) {
// Test that the MatPtr can be constructed from the TensorInfo,
// and that the dimensions match.
MatPtrT<SfpStream> mat_ptr(tensor.Name(), tensor_index);
EXPECT_EQ(tensor.Name(), mat_ptr.Name()) << "on tensor " << name;
EXPECT_STREQ(tensor.Name(), mat_ptr.Name())
<< "on tensor " << name;
EXPECT_EQ(tensor.Rows(), mat_ptr.Rows()) << "on tensor " << name;
EXPECT_EQ(tensor.Cols(), mat_ptr.Cols()) << "on tensor " << name;
++num_found;

View File

@ -44,6 +44,17 @@ class GemmaTokenizer::Impl {
HWY_ABORT("Failed to load the tokenizer file.");
}
}
// Loads the tokenizer from a serialized proto.
explicit Impl(const std::string& tokenizer_proto) {
PROFILER_ZONE("Startup.tokenizer");
spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
if (!spp_->LoadFromSerializedProto(tokenizer_proto).ok()) {
fprintf(stderr, "serialized proto size=%zu.\n", tokenizer_proto.size());
HWY_ABORT("Failed to load the tokenizer from serialized proto.");
}
}
std::string Serialize() const { return spp_->serialized_model_proto(); }
bool Encode(const std::string& input,
std::vector<std::string>* pieces) const {
@ -81,6 +92,12 @@ GemmaTokenizer::~GemmaTokenizer() = default;
GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default;
GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default;
std::string GemmaTokenizer::Serialize() const { return impl_->Serialize(); }
void GemmaTokenizer::Deserialize(const std::string& tokenizer_proto) {
impl_ = std::make_unique<Impl>(tokenizer_proto);
}
bool GemmaTokenizer::Encode(const std::string& input,
std::vector<std::string>* pieces) const {
return impl_->Encode(input, pieces);

View File

@ -41,6 +41,9 @@ class GemmaTokenizer {
GemmaTokenizer(GemmaTokenizer&& other);
GemmaTokenizer& operator=(GemmaTokenizer&& other);
std::string Serialize() const;
void Deserialize(const std::string& tokenizer_proto);
bool Encode(const std::string& input, std::vector<std::string>* pieces) const;
bool Encode(const std::string& input, std::vector<int>* ids) const;
bool Decode(const std::vector<int>& ids, std::string* detokenized) const;

View File

@ -19,11 +19,13 @@
#include <cstdlib>
#include <memory>
#include <random>
#include <string>
#include <vector>
#include "compression/blob_store.h"
#include "compression/compress.h"
#include "compression/io.h" // Path
#include "compression/shared.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h"
@ -47,7 +49,9 @@ struct TensorLoader {
};
BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool) {
Type weight_type, PromptWrapping wrapping,
hwy::ThreadPool& pool,
std::string* tokenizer_proto) {
PROFILER_ZONE("Startup.LoadModelWeightsPtrs");
if (!weights.Exists()) {
HWY_ABORT("The model weights file '%s' does not exist.",
@ -56,17 +60,36 @@ BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type,
ReadFromBlobStore loader(weights);
ForEachType fet =
loader.HaveToc() ? ForEachType::kLoadWithToc : ForEachType::kLoadNoToc;
std::vector<float> scales;
if (fet == ForEachType::kLoadWithToc) {
// TODO(rays): Load the config from the file.
HWY_ABORT("TOC not supported yet.");
BlobError err = loader.LoadConfig(config_);
if (err != 0 || config_.model_dim == 0) {
fprintf(stderr, "Failed to load model config: %d\n", err);
return err;
}
if (tokenizer_proto != nullptr) {
err = loader.LoadTokenizer(*tokenizer_proto);
if (err != 0) {
fprintf(stderr, "Failed to load tokenizer: %d\n", err);
return err;
}
}
} else {
if (weight_type == Type::kUnknown || model_type == Model::UNKNOWN) {
fprintf(stderr,
"weight type (%d) and model type (%d) must be specified when "
"no config is present in weights file\n",
static_cast<int>(weight_type), static_cast<int>(model_type));
return __LINE__;
}
// No Toc-> no config.
config_ = ConfigFromModel(model_type);
config_.weight = weight_type;
config_.wrapping = wrapping;
scales.resize(config_.num_tensor_scales + config_.num_vit_scales);
}
CreateForType(weight_type, pool);
CreateForType(config_.weight, pool);
CallForModelWeightT<TensorLoader>(fet, loader);
std::vector<float> scales(config_.num_tensor_scales + config_.num_vit_scales);
if (!scales.empty()) {
loader.LoadScales(scales.data(), scales.size());
}
@ -85,6 +108,34 @@ BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type,
return 0;
}
template <typename T>
struct TensorSaver {
// Adds all the tensors to the blob writer.
void operator()(ModelWeightsPtrs<T>& weights, ForEachType fet,
WriteToBlobStore& writer) {
weights.ForEachTensor(
{&weights}, fet,
[&writer](const char* name, hwy::Span<MatPtr*> tensors) {
tensors[0]->CallUpcasted(writer, name);
});
}
};
BlobError ModelWeightsStorage::Save(const std::string& tokenizer,
const Path& weights,
hwy::ThreadPool& pool) {
WriteToBlobStore writer(pool);
ForEachType fet = ForEachType::kLoadWithToc;
CallForModelWeightT<TensorSaver>(fet, writer);
writer.AddTokenizer(tokenizer);
int err = writer.WriteAll(weights, &config_);
if (err != 0) {
fprintf(stderr, "Failed to load model weights: %d\n", err);
return err;
}
return 0;
}
void ModelWeightsStorage::Allocate(const ModelConfig& config, Type weight_type,
hwy::ThreadPool& pool) {
PROFILER_ZONE("Startup.AllocateModelWeightsPtrs");

View File

@ -522,7 +522,18 @@ class ModelWeightsStorage {
ModelWeightsStorage() = default;
~ModelWeightsStorage() = default;
// Loads the weights from a blob store file. Supports multi-file or
// single-file format. If the weights file contains a TOC, then it is in
// single-file format, and model_type, weight_type, training are ignored,
// and tokenizer_proto is required and written to.
// With a multi-file format, file, model_type, weight_type, training are
// required and tokenizer_proto is ignored.
BlobError Load(const Path& weights, Model model_type, Type weight_type,
PromptWrapping wrapping, hwy::ThreadPool& pool,
std::string* tokenizer_proto);
// Writes the weights to a blob store file, using the single-file format with
// a TOC and config included.
BlobError Save(const std::string& tokenizer, const Path& weights,
hwy::ThreadPool& pool);
void Allocate(Model model_type, Type weight_type, hwy::ThreadPool& pool) {
Allocate(ConfigFromModel(model_type), weight_type, pool);

View File

@ -26,7 +26,6 @@
#include <cmath>
#include <random>
#include "compression/compress.h"
#include "compression/shared.h"
#include "util/allocator.h"
#include "util/test_util.h"

View File

@ -25,6 +25,7 @@
#include <string>
#include "compression/io.h" // Path
#include "compression/shared.h"
#include "gemma/common.h"
#include "gemma/gemma.h" // For CreateGemma
#include "ops/matmul.h"
@ -125,7 +126,10 @@ static inline NestedPools CreatePools(const AppArgs& app) {
}
struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
LoaderArgs(int argc, char* argv[], bool required = true)
: model_type_required(required) {
InitAndParse(argc, argv);
}
LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
const std::string& model) {
Init(); // Init sets to defaults, so assignments must come after Init().
@ -136,19 +140,25 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
// Returns error string or nullptr if OK.
const char* Validate() {
info_.model = Model::UNKNOWN;
info_.wrapping = PromptWrapping::GEMMA_PT;
info_.weight = Type::kUnknown;
if (const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
info_.wrapping)) {
return err;
if (model_type_required) return err;
}
if (const char* err = ParseType(weight_type_str, info_.weight)) {
return err;
if (model_type_required) return err;
}
if (model_type_required) {
if (tokenizer.path.empty()) {
return "Missing --tokenizer flag, a file for the tokenizer is required.";
return "Missing --tokenizer flag, a file for the tokenizer is "
"required.";
}
if (!tokenizer.Exists()) {
return "Can't open file specified with --tokenizer flag.";
}
}
if (!compressed_weights.path.empty()) {
if (weights.path.empty()) {
weights = compressed_weights;
@ -172,11 +182,12 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
Path compressed_weights;
std::string model_type_str;
std::string weight_type_str;
bool model_type_required = true;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(tokenizer, "tokenizer", Path(),
"Path name of tokenizer model file.\n Required argument.");
"Path name of tokenizer model file.");
visitor(weights, "weights", Path(),
"Path name of model weights (.sbs) file.\n Required argument.");
visitor(compressed_weights, "compressed_weights", Path(),
@ -186,11 +197,9 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
"gr2b-pt = griffin 2B parameters, pretrained\n "
" Required argument.");
"gr2b-pt = griffin 2B parameters, pretrained.");
visitor(weight_type_str, "weight_type", std::string("sfp"),
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit FP\n"
" Required argument.");
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit FP.");
}
// Uninitialized before Validate, must call after that.
@ -208,6 +217,12 @@ static inline Gemma CreateGemma(const LoaderArgs& loader, NestedPools& pools) {
static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
NestedPools& pools) {
if (Type::kUnknown == loader.Info().weight ||
Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
// Newer weights file format doesn't need tokenizer path or model/weight
// info.
return std::make_unique<Gemma>(loader.weights, pools);
}
return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
loader.Info(), pools);
}