mirror of https://github.com/google/gemma.cpp.git
Added ability to load/save a complete model file, including tokenizer.
PiperOrigin-RevId: 707914366
This commit is contained in:
parent
5bc356f18f
commit
9d40f0117e
|
|
@ -245,6 +245,7 @@ cc_library(
|
|||
"gemma/tensor_index.h",
|
||||
],
|
||||
deps = [
|
||||
"//compression:fields",
|
||||
"//compression:sfp",
|
||||
"@highway//:hwy", # base.h
|
||||
"@highway//:thread_pool",
|
||||
|
|
@ -257,6 +258,7 @@ cc_test(
|
|||
deps = [
|
||||
":common",
|
||||
"@googletest//:gtest_main",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -388,6 +390,7 @@ cc_library(
|
|||
":ops",
|
||||
":threading",
|
||||
"//compression:io",
|
||||
"//compression:sfp",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -390,13 +390,12 @@ static ModelConfig TestConfig() {
|
|||
config.model_dim = 32;
|
||||
config.vocab_size = 12;
|
||||
config.seq_len = 18;
|
||||
LayerConfig layer_config = {
|
||||
.model_dim = config.model_dim,
|
||||
.ff_hidden_dim = 48,
|
||||
.heads = 3,
|
||||
.kv_heads = 1,
|
||||
.qkv_dim = 12,
|
||||
};
|
||||
LayerConfig layer_config;
|
||||
layer_config.model_dim = config.model_dim;
|
||||
layer_config.ff_hidden_dim = 48;
|
||||
layer_config.heads = 3;
|
||||
layer_config.kv_heads = 1;
|
||||
layer_config.qkv_dim = 12;
|
||||
config.layer_configs = {2, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
|
|
|
|||
|
|
@ -191,13 +191,12 @@ static ModelConfig TestConfig() {
|
|||
config.model_dim = 32;
|
||||
config.vocab_size = 16;
|
||||
config.seq_len = 24;
|
||||
LayerConfig layer_config = {
|
||||
.model_dim = config.model_dim,
|
||||
.ff_hidden_dim = 64,
|
||||
.heads = 3,
|
||||
.kv_heads = 1,
|
||||
.qkv_dim = 16,
|
||||
};
|
||||
LayerConfig layer_config;
|
||||
layer_config.model_dim = config.model_dim;
|
||||
layer_config.ff_hidden_dim = 64;
|
||||
layer_config.heads = 3;
|
||||
layer_config.kv_heads = 1;
|
||||
layer_config.qkv_dim = 16;
|
||||
config.layer_configs = {2, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
|
|
|
|||
|
|
@ -58,7 +58,6 @@ cc_test(
|
|||
deps = [
|
||||
":fields",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
],
|
||||
)
|
||||
|
|
@ -202,6 +201,7 @@ cc_library(
|
|||
deps = [
|
||||
":blob_store",
|
||||
":distortion",
|
||||
":fields",
|
||||
":io",
|
||||
":nuq",
|
||||
":sfp",
|
||||
|
|
@ -210,7 +210,6 @@ cc_library(
|
|||
"//:common",
|
||||
"@highway//:hwy",
|
||||
"@highway//:nanobenchmark",
|
||||
"@highway//:profiler",
|
||||
"@highway//:stats",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
|
|
@ -261,6 +260,7 @@ cc_binary(
|
|||
"//:allocator",
|
||||
"//:args",
|
||||
"//:common",
|
||||
"//:tokenizer",
|
||||
"//:weights",
|
||||
"@highway//:hwy",
|
||||
"@highway//:thread_pool",
|
||||
|
|
@ -277,3 +277,14 @@ cc_binary(
|
|||
"@highway//:hwy_test_util",
|
||||
],
|
||||
)
|
||||
|
||||
# Binary target built from the single source file migrate_weights.cc.
cc_binary(
|
||||
name = "migrate_weights",
|
||||
srcs = ["migrate_weights.cc"],
|
||||
deps = [
|
||||
"//:app",
|
||||
"//:args",
|
||||
"//:benchmark_helper",
|
||||
"//:gemma_lib",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -22,10 +22,13 @@
|
|||
#include <stdio.h>
|
||||
|
||||
#include <cmath> // lroundf, only if COMPRESS_STATS
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/blob_store.h"
|
||||
#include "compression/compress.h" // IWYU pragma: export
|
||||
#include "compression/distortion.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
|
@ -673,36 +676,37 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan<const VT> v,
|
|||
// their scaling factors to BlobStore.
|
||||
class Compressor {
|
||||
public:
|
||||
explicit Compressor(hwy::ThreadPool& pool) : pool_(pool) {}
|
||||
explicit Compressor(hwy::ThreadPool& pool) : writer_(pool) {}
|
||||
|
||||
template <typename Packed>
|
||||
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name,
|
||||
const float* HWY_RESTRICT weights) {
|
||||
size_t num_weights = compressed->NumElements();
|
||||
if (num_weights == 0 || weights == nullptr || compressed->Ptr() == nullptr)
|
||||
return;
|
||||
size_t num_compressed = compressed->NumElements();
|
||||
PackedSpan<Packed> packed = MakeSpan(compressed->data(), num_compressed);
|
||||
fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name,
|
||||
num_weights / (1000 * 1000));
|
||||
Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0, pool_);
|
||||
const size_t num_bytes = packed.num * sizeof(Packed);
|
||||
writer_.Add(MakeKey(decorated_name), packed.ptr, num_bytes);
|
||||
Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0,
|
||||
writer_.pool());
|
||||
writer_(compressed, decorated_name);
|
||||
}
|
||||
|
||||
// Forwards `tokenizer` unchanged to `writer_.AddTokenizer`.
void AddTokenizer(const std::string& tokenizer) {
|
||||
writer_.AddTokenizer(tokenizer);
|
||||
}
|
||||
|
||||
void AddScales(const float* scales, size_t len) {
|
||||
if (len) {
|
||||
MatPtrT<float> scales_ptr("scales", 0, 1);
|
||||
writer_.Add(MakeKey(scales_ptr.CacheName().c_str()), scales,
|
||||
len * sizeof(scales[0]));
|
||||
}
|
||||
writer_.AddScales(scales, len);
|
||||
}
|
||||
|
||||
BlobError WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) {
|
||||
const BlobError err = writer_.WriteAll(pool, blob_filename);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
|
||||
blob_filename.path.c_str(), err);
|
||||
}
|
||||
return err;
|
||||
// Writes all blobs to disk in the given order. The config is optional and
|
||||
// if given, it is written to the file, along with the TOC, making it
|
||||
// single-file format. Otherwise, the file is written in the multi-file format
|
||||
// without a TOC.
|
||||
BlobError WriteAll(const Path& blob_filename, const ModelConfig* config) {
|
||||
return writer_.WriteAll(blob_filename, config);
|
||||
}
|
||||
|
||||
// Returns the number of blobs added.
|
||||
|
|
@ -710,8 +714,7 @@ class Compressor {
|
|||
|
||||
private:
|
||||
CompressWorkingSet work_;
|
||||
hwy::ThreadPool& pool_;
|
||||
BlobWriter writer_;
|
||||
WriteToBlobStore writer_;
|
||||
};
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@
|
|||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
|
@ -32,11 +33,13 @@
|
|||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "compression/blob_store.h"
|
||||
#include "compression/fields.h"
|
||||
#include "compression/io.h"
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/tensor_index.h"
|
||||
#include "util/basics.h"
|
||||
// IWYU pragma: end_exports
|
||||
#include "gemma/configs.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/per_target.h"
|
||||
#if COMPRESS_STATS
|
||||
|
|
@ -55,7 +58,7 @@ namespace gcpp {
|
|||
// fixed inner dimension and type.
|
||||
// It is designed to be put in a vector, and has default copy and operator=, so
|
||||
// it is easy to read/write a blob_store file.
|
||||
class MatPtr {
|
||||
class MatPtr : public IFields {
|
||||
public:
|
||||
// Full constructor for dynamic sizing.
|
||||
MatPtr(const std::string& name, Type type, size_t element_size, size_t rows,
|
||||
|
|
@ -73,36 +76,6 @@ class MatPtr {
|
|||
MatPtr() = default;
|
||||
virtual ~MatPtr();
|
||||
|
||||
// Number of hwy::uint128_t in a TOC entry.
|
||||
// Note that the old-style BlobStore files only have a list of keys and size.
|
||||
// The new-style BlobStore files have an entry called "toc" that contains a
|
||||
// vector of 4-tuples of
|
||||
// (name, type, (num_elements, element_size), (rows, cols)).
|
||||
// The listed blobs can be read directly into MatPtr from the BlobStore
|
||||
// file, without needing any external knowledge of the number of elements,
|
||||
// element size or type of the data.
|
||||
static constexpr size_t kNumU128InTocEntry = 4;
|
||||
|
||||
// Construct from a TOC entry.
|
||||
MatPtr(const hwy::uint128_t& key0, const hwy::uint128_t& key1,
|
||||
const hwy::uint128_t& key2, const hwy::uint128_t& key3)
|
||||
: name_(StringFromKey(key0)),
|
||||
type_(static_cast<Type>(key1.lo)),
|
||||
element_size_(key2.hi),
|
||||
num_elements_(key2.lo),
|
||||
rows_(key3.lo),
|
||||
cols_(key3.hi) {
|
||||
stride_ = cols_;
|
||||
}
|
||||
|
||||
// Adds the contents entry to the table of contents.
|
||||
void AddToToc(std::vector<hwy::uint128_t>& toc) const {
|
||||
toc.push_back(MakeKey(name_.c_str()));
|
||||
toc.push_back({static_cast<uint64_t>(type_), 0});
|
||||
toc.push_back({num_elements_, element_size_});
|
||||
toc.push_back({rows_, cols_});
|
||||
}
|
||||
|
||||
// Compatibility interface for CompressedArray.
|
||||
// TODO: remove.
|
||||
template <typename T>
|
||||
|
|
@ -124,7 +97,7 @@ class MatPtr {
|
|||
MatPtr& operator=(const MatPtr& other) = default;
|
||||
|
||||
// Returns the name of the blob.
|
||||
const std::string& Name() const { return name_; }
|
||||
const char* Name() const override { return name_.c_str(); }
|
||||
void SetName(const std::string& name) { name_ = name; }
|
||||
|
||||
// Returns the type of the blob.
|
||||
|
|
@ -163,12 +136,6 @@ class MatPtr {
|
|||
return name;
|
||||
}
|
||||
|
||||
// Adds the blob to the writer.
|
||||
void AddToWriter(BlobWriter& writer) const {
|
||||
fprintf(stderr, "Adding %s to writer\n", name_.c_str());
|
||||
writer.Add(MakeKey(name_.c_str()), ptr_, SizeBytes());
|
||||
}
|
||||
|
||||
// Sets all data to zero.
|
||||
void ZeroInit() {
|
||||
if (ptr_ == nullptr)
|
||||
|
|
@ -176,6 +143,17 @@ class MatPtr {
|
|||
hwy::ZeroBytes(ptr_, SizeBytes());
|
||||
}
|
||||
|
||||
// Presents the metadata members to `visitor`, one call per member, in the
// order listed below. `ptr_` is not visited.
// NOTE(review): presumably this order defines the serialized layout, so
// reordering or inserting calls would change the format - confirm against
// IFields before editing.
void VisitFields(IFieldsVisitor& visitor) override {
|
||||
visitor(name_);
|
||||
visitor(type_);
|
||||
visitor(element_size_);
|
||||
visitor(num_elements_);
|
||||
visitor(rows_);
|
||||
visitor(cols_);
|
||||
visitor(scale_);
|
||||
visitor(stride_);
|
||||
}
|
||||
|
||||
// Calls func on the upcasted type. Since MatPtr by design is not templated,
|
||||
// here we provide a way to get to the derived type, provided that `Type()`
|
||||
// is one of the strings returned by `TypeName()`.
|
||||
|
|
@ -188,13 +166,13 @@ class MatPtr {
|
|||
// Should be the result of TypeEnum<T> for CallUpcasted() to work.
|
||||
Type type_;
|
||||
// sizeof(T)
|
||||
size_t element_size_ = 0;
|
||||
uint32_t element_size_ = 0;
|
||||
// Number of elements in the array.
|
||||
size_t num_elements_ = 0; // In element_size units.
|
||||
uint32_t num_elements_ = 0; // In element_size units.
|
||||
// Number of rows in the 2-d array (outer dimension).
|
||||
size_t rows_ = 0;
|
||||
uint32_t rows_ = 0;
|
||||
// Number of columns in the 2-d array (inner dimension).
|
||||
size_t cols_ = 0;
|
||||
uint32_t cols_ = 0;
|
||||
// Scaling to apply to each element.
|
||||
float scale_ = 1.0f;
|
||||
// Aligned data array. This is always a borrowed pointer. It should never be
|
||||
|
|
@ -202,7 +180,7 @@ class MatPtr {
|
|||
// and must outlive this object.
|
||||
void* ptr_ = nullptr;
|
||||
|
||||
size_t stride_;
|
||||
uint32_t stride_;
|
||||
};
|
||||
|
||||
// MatPtrT adds a single template argument to MatPtr for an explicit type.
|
||||
|
|
@ -394,31 +372,28 @@ class BlobToc {
|
|||
public:
|
||||
BlobToc() = default;
|
||||
|
||||
// Adds all blobs to the blob writer. Note that the blobs must have unique
|
||||
// names.
|
||||
static void AddAllToBlobWriter(const std::vector<MatStorage>& blobs,
|
||||
BlobWriter& writer) {
|
||||
std::vector<hwy::uint128_t> toc;
|
||||
for (const auto& blob : blobs) {
|
||||
blob.AddToToc(toc);
|
||||
blob.AddToWriter(writer);
|
||||
}
|
||||
writer.Add(MakeKey(kTocName), toc.data(), toc.size() * sizeof(toc[0]));
|
||||
}
|
||||
|
||||
// Loads the table of contents from the given reader.
|
||||
BlobError LoadToc(BlobReader& reader) {
|
||||
hwy::uint128_t toc_key = MakeKey(kTocName);
|
||||
size_t toc_size = reader.BlobSize(toc_key);
|
||||
if (toc_size != 0) {
|
||||
std::vector<hwy::uint128_t> toc(toc_size / sizeof(hwy::uint128_t));
|
||||
std::vector<uint32_t> toc(toc_size / sizeof(uint32_t));
|
||||
BlobError err = reader.ReadOne(toc_key, toc.data(), toc_size);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "Failed to read toc (error %d)\n", err);
|
||||
return err;
|
||||
}
|
||||
for (size_t i = 0; i < toc.size(); i += MatPtr::kNumU128InTocEntry) {
|
||||
AddToToc(MatPtr(toc[i], toc[i + 1], toc[i + 2], toc[i + 3]));
|
||||
size_t consumed = 0;
|
||||
size_t prev_consumed = static_cast<size_t>(-1);
|
||||
while (consumed < toc.size() && prev_consumed != consumed) {
|
||||
MatPtr blob;
|
||||
const IFields::ReadResult result =
|
||||
blob.Read(hwy::Span<const uint32_t>(toc), consumed);
|
||||
prev_consumed = consumed;
|
||||
consumed = result.pos;
|
||||
if (blob.NumElements() > 0) {
|
||||
AddToToc(blob);
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
|
|
@ -437,11 +412,16 @@ class BlobToc {
|
|||
if (it == toc_map_.end()) return nullptr;
|
||||
return &toc_[it->second];
|
||||
}
|
||||
|
||||
private:
|
||||
// The name of the toc in the blob store file.
|
||||
static constexpr char kTocName[] = "toc";
|
||||
|
||||
// The name of the config in the blob store file.
|
||||
static constexpr char kConfigName[] = "config";
|
||||
|
||||
// The name of the tokenizer in the blob store file.
|
||||
static constexpr char kTokenizerName[] = "tokenizer";
|
||||
|
||||
private:
|
||||
// Adds the blob to the table of contents.
|
||||
void AddToToc(const MatPtr& blob) {
|
||||
HWY_ASSERT(!Contains(blob.Name()));
|
||||
|
|
@ -519,6 +499,68 @@ struct CompressWorkingSet {
|
|||
std::vector<CompressPerThread> tls;
|
||||
};
|
||||
|
||||
// Class to collect and write a set of tensors to a blob store file.
|
||||
class WriteToBlobStore {
|
||||
public:
|
||||
explicit WriteToBlobStore(hwy::ThreadPool& pool) : pool_(pool) {}
|
||||
|
||||
template <typename Packed>
|
||||
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name) {
|
||||
if (compressed->Ptr() == nullptr) return;
|
||||
writer_.Add(MakeKey(decorated_name), compressed->Ptr(),
|
||||
compressed->SizeBytes());
|
||||
MatPtr renamed_tensor(*compressed);
|
||||
renamed_tensor.SetName(decorated_name);
|
||||
renamed_tensor.AppendTo(toc_);
|
||||
}
|
||||
|
||||
void AddTokenizer(const std::string& tokenizer) {
|
||||
writer_.Add(MakeKey(BlobToc::kTokenizerName), tokenizer.data(),
|
||||
tokenizer.size() * sizeof(tokenizer[0]));
|
||||
}
|
||||
|
||||
void AddScales(const float* scales, size_t len) {
|
||||
if (len) {
|
||||
MatPtrT<float> scales_ptr("scales", 0, 1);
|
||||
writer_.Add(MakeKey(scales_ptr.CacheName().c_str()), scales,
|
||||
len * sizeof(scales[0]));
|
||||
}
|
||||
}
|
||||
|
||||
// Writes all blobs to disk in the given order. The config is optional and
|
||||
// if given, it is written to the file, along with the TOC, making it
|
||||
// single-file format. Otherwise, the file is written in the multi-file format
|
||||
// without a TOC.
|
||||
BlobError WriteAll(const Path& blob_filename, const ModelConfig* config) {
|
||||
if (config) {
|
||||
writer_.Add(MakeKey(BlobToc::kTocName), toc_.data(),
|
||||
toc_.size() * sizeof(toc_[0]));
|
||||
config_buffer_ = config->Write();
|
||||
writer_.Add(MakeKey(BlobToc::kConfigName), config_buffer_.data(),
|
||||
config_buffer_.size() * sizeof(config_buffer_[0]));
|
||||
}
|
||||
const BlobError err = writer_.WriteAll(pool_, blob_filename);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
|
||||
blob_filename.path.c_str(), err);
|
||||
}
|
||||
return err;
|
||||
}
|
||||
|
||||
// Returns the number of blobs added.
|
||||
size_t DebugNumBlobsAdded() const { return writer_.DebugNumBlobsAdded(); }
|
||||
|
||||
hwy::ThreadPool& pool() { return pool_; }
|
||||
|
||||
protected:
|
||||
hwy::ThreadPool& pool_;
|
||||
|
||||
private:
|
||||
std::vector<uint32_t> toc_;
|
||||
BlobWriter writer_;
|
||||
std::vector<uint32_t> config_buffer_;
|
||||
};
|
||||
|
||||
// Functor called for each tensor, which loads them and their scaling factors
|
||||
// from BlobStore.
|
||||
class ReadFromBlobStore {
|
||||
|
|
@ -539,11 +581,40 @@ class ReadFromBlobStore {
|
|||
// Returns true if there is a TOC.
|
||||
bool HaveToc() const { return !file_toc_.Empty(); }
|
||||
|
||||
// Reads the config from the blob store file.
|
||||
BlobError LoadConfig(ModelConfig& config) {
|
||||
hwy::uint128_t config_key = MakeKey(BlobToc::kConfigName);
|
||||
size_t config_size = reader_.BlobSize(config_key);
|
||||
if (config_size == 0) return __LINE__;
|
||||
std::vector<uint32_t> config_buffer(config_size / sizeof(uint32_t));
|
||||
BlobError err =
|
||||
reader_.ReadOne(config_key, config_buffer.data(), config_size);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "Failed to read config (error %d)\n", err);
|
||||
return err;
|
||||
}
|
||||
config.Read(hwy::Span<const uint32_t>(config_buffer), 0);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Reads the tokenizer from the blob store file.
|
||||
BlobError LoadTokenizer(std::string& tokenizer) {
|
||||
hwy::uint128_t key = MakeKey(BlobToc::kTokenizerName);
|
||||
size_t tokenizer_size = reader_.BlobSize(key);
|
||||
if (tokenizer_size == 0) return __LINE__;
|
||||
tokenizer.resize(tokenizer_size);
|
||||
;
|
||||
BlobError err = reader_.ReadOne(key, tokenizer.data(), tokenizer_size);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "Failed to read tokenizer (error %d)\n", err);
|
||||
return err;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Called for each tensor, enqueues read requests.
|
||||
void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
if (file_toc_.Empty() || file_toc_.Contains(name)) {
|
||||
if (tensors[0]->NumElements() == 0)
|
||||
fprintf(stderr, "Zero elements for %s\n", name);
|
||||
model_toc_.push_back(tensors[0]);
|
||||
file_keys_.push_back(name);
|
||||
}
|
||||
|
|
@ -579,12 +650,12 @@ class ReadFromBlobStore {
|
|||
fprintf(stderr, "Blob %s has size mismatch TOC\n", file_key.c_str());
|
||||
return __LINE__;
|
||||
}
|
||||
MatStorage toc_blob_array(*toc_blob);
|
||||
model_memory.push_back(std::move(toc_blob_array));
|
||||
} else {
|
||||
model_memory.emplace_back(*blob);
|
||||
model_memory.back().SetName(file_key);
|
||||
std::string name = blob->Name();
|
||||
*blob = *toc_blob;
|
||||
blob->SetName(name);
|
||||
}
|
||||
model_memory.emplace_back(*blob);
|
||||
model_memory.back().SetName(file_key);
|
||||
}
|
||||
// Allocate in parallel using the pool.
|
||||
pool.Run(0, model_memory.size(),
|
||||
|
|
@ -594,12 +665,12 @@ class ReadFromBlobStore {
|
|||
});
|
||||
// Enqueue the read requests.
|
||||
for (auto& blob : model_memory) {
|
||||
err_ = reader_.Enqueue(MakeKey(blob.Name().c_str()), blob.data(),
|
||||
blob.SizeBytes());
|
||||
err_ =
|
||||
reader_.Enqueue(MakeKey(blob.Name()), blob.data(), blob.SizeBytes());
|
||||
if (err_ != 0) {
|
||||
fprintf(stderr,
|
||||
"Failed to read blob %s (error %d) of size %zu x %zu x %zu\n",
|
||||
blob.Name().c_str(), err_, blob.Rows(), blob.Cols(),
|
||||
blob.Name(), err_, blob.Rows(), blob.Cols(),
|
||||
blob.ElementSize());
|
||||
return err_;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@
|
|||
// After highway.h
|
||||
#include "compression/compress-inl.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/tokenizer.h"
|
||||
|
||||
#ifndef GEMMA_COMPRESS_WEIGHTS_ONCE
|
||||
#define GEMMA_COMPRESS_WEIGHTS_ONCE
|
||||
|
|
@ -99,6 +100,9 @@ struct Args : public ArgsBase<Args> {
|
|||
std::string model_type_str;
|
||||
std::string weight_type_str;
|
||||
size_t num_threads;
|
||||
// Path to the tokenizer file. If non-empty, the tokenizer is written to the
|
||||
// output file together with the config and TOC.
|
||||
Path tokenizer;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
|
|
@ -123,6 +127,9 @@ struct Args : public ArgsBase<Args> {
|
|||
"Number of threads to use.\n Default = Estimate of the "
|
||||
"number of supported concurrent threads.",
|
||||
2);
|
||||
visitor(tokenizer, "tokenizer", Path(),
|
||||
"Path to tokenizer file. If given, the config and TOC are also "
|
||||
"added to the output file.");
|
||||
}
|
||||
|
||||
// Uninitialized before Validate, must call after that.
|
||||
|
|
@ -156,7 +163,8 @@ namespace HWY_NAMESPACE {
|
|||
template <typename T>
|
||||
void CompressWeights(const Path& weights_path,
|
||||
const Path& compressed_weights_path, Model model_type,
|
||||
hwy::ThreadPool& pool) {
|
||||
Type weight_type, PromptWrapping wrapping,
|
||||
const Path& tokenizer_path, hwy::ThreadPool& pool) {
|
||||
if (!weights_path.Exists()) {
|
||||
HWY_ABORT("The model weights file '%s' does not exist.",
|
||||
weights_path.path.c_str());
|
||||
|
|
@ -164,6 +172,8 @@ void CompressWeights(const Path& weights_path,
|
|||
printf("Compressing weights from %s to %s\n", weights_path.path.c_str(),
|
||||
compressed_weights_path.path.c_str());
|
||||
ModelConfig config = ConfigFromModel(model_type);
|
||||
config.weight = weight_type;
|
||||
config.wrapping = wrapping;
|
||||
std::vector<MatStorage> model_storage;
|
||||
ModelWeightsPtrs<T> c_weights(config);
|
||||
c_weights.Allocate(model_storage, pool);
|
||||
|
|
@ -185,6 +195,9 @@ void CompressWeights(const Path& weights_path,
|
|||
ok &= 1 == fread(tensors[0]->Ptr(), tensors[0]->SizeBytes(), 1, fptr);
|
||||
total_size += tensors[0]->SizeBytes();
|
||||
});
|
||||
if (!tokenizer_path.path.empty()) {
|
||||
uc_weights.AllocAndCopyWithTranspose(pool, model_storage);
|
||||
}
|
||||
const bool scale_for_compression = config.num_tensor_scales > 0;
|
||||
std::vector<float> scales;
|
||||
if (scale_for_compression) {
|
||||
|
|
@ -193,14 +206,21 @@ void CompressWeights(const Path& weights_path,
|
|||
Compressor compressor(pool);
|
||||
ModelWeightsPtrs<T>::ForEachTensor(
|
||||
{reinterpret_cast<ModelWeightsPtrs<T>*>(&uc_weights), &c_weights},
|
||||
ForEachType::kLoadNoToc,
|
||||
tokenizer_path.path.empty() ? ForEachType::kLoadNoToc
|
||||
: ForEachType::kLoadWithToc,
|
||||
[&compressor](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
tensors[1]->CallUpcasted(
|
||||
compressor, name,
|
||||
reinterpret_cast<const float*>(tensors[0]->Ptr()));
|
||||
});
|
||||
compressor.AddScales(scales.data(), scales.size() * sizeof(scales[0]));
|
||||
compressor.WriteAll(pool, compressed_weights_path);
|
||||
if (!tokenizer_path.path.empty()) {
|
||||
std::string tokenizer_proto = ReadFileToString(tokenizer_path);
|
||||
compressor.AddTokenizer(tokenizer_proto);
|
||||
} else {
|
||||
compressor.AddScales(scales.data(), scales.size() * sizeof(scales[0]));
|
||||
}
|
||||
compressor.WriteAll(compressed_weights_path,
|
||||
tokenizer_path.path.empty() ? nullptr : &config);
|
||||
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
|
|
@ -220,19 +240,23 @@ void Run(Args& args) {
|
|||
switch (weight_type) {
|
||||
case Type::kF32:
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<float>)
|
||||
(args.weights, args.compressed_weights, model_type, pool);
|
||||
(args.weights, args.compressed_weights, model_type, weight_type,
|
||||
args.PromptWrappingType(), args.tokenizer, pool);
|
||||
break;
|
||||
case Type::kBF16:
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<BF16>)
|
||||
(args.weights, args.compressed_weights, model_type, pool);
|
||||
(args.weights, args.compressed_weights, model_type, weight_type,
|
||||
args.PromptWrappingType(), args.tokenizer, pool);
|
||||
break;
|
||||
case Type::kSFP:
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<SfpStream>)
|
||||
(args.weights, args.compressed_weights, model_type, pool);
|
||||
(args.weights, args.compressed_weights, model_type, weight_type,
|
||||
args.PromptWrappingType(), args.tokenizer, pool);
|
||||
break;
|
||||
case Type::kNUQ:
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<NuqStream>)
|
||||
(args.weights, args.compressed_weights, model_type, pool);
|
||||
(args.weights, args.compressed_weights, model_type, weight_type,
|
||||
args.PromptWrappingType(), args.tokenizer, pool);
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Weight type %d unsupported.", static_cast<int>(weight_type));
|
||||
|
|
|
|||
|
|
@ -83,6 +83,14 @@ class PrintVisitor : public VisitorBase {
|
|||
fprintf(stderr, "%sU32 %u\n", indent_.c_str(), value);
|
||||
}
|
||||
|
||||
// Prints a signed 32-bit field to stderr as "I32 <value>", prefixed with the
// current indent string.
void operator()(int32_t& value) override {
|
||||
fprintf(stderr, "%sI32 %d\n", indent_.c_str(), value);
|
||||
}
|
||||
|
||||
// Prints an unsigned 64-bit field to stderr as "U64 <value>", prefixed with
// the current indent string.
void operator()(uint64_t& value) override {
  // "%zu" is the conversion for size_t, which need not be the same type as
  // uint64_t. Casting to unsigned long long and using "%llu" is well-defined
  // on every platform and prints the same digits.
  fprintf(stderr, "%sU64 %llu\n", indent_.c_str(),
          static_cast<unsigned long long>(value));
}
|
||||
|
||||
void operator()(float& value) override {
|
||||
fprintf(stderr, "%sF32 %f\n", indent_.c_str(), value);
|
||||
}
|
||||
|
|
@ -120,6 +128,21 @@ class ReadVisitor : public VisitorBase {
|
|||
value = span_[result_.pos++];
|
||||
}
|
||||
|
||||
// Reads one signed 32-bit field: takes the uint32_t at the current span
// position, advances the position by one, and stores the value converted to
// int32_t. Leaves `value` untouched if SkipField() returns true.
void operator()(int32_t& value) override {
|
||||
if (HWY_UNLIKELY(SkipField())) return;
|
||||

||||
value = static_cast<int32_t>(span_[result_.pos++]);
|
||||
}
|
||||
|
||||
// Reads one unsigned 64-bit field as two consecutive 32-bit fields, lower
// half first and upper half second, by delegating to the uint32_t overload.
// Leaves `value` untouched if SkipField() returns true.
void operator()(uint64_t& value) override {
|
||||
if (HWY_UNLIKELY(SkipField())) return;
|
||||
// Both halves start from the bits currently in `value`.
// NOTE(review): presumably a half keeps these bits if its nested read is
// skipped - depends on the uint32_t overload, which is not visible here.
uint32_t lower = static_cast<uint32_t>(value);
|
||||
operator()(lower);
|
||||
uint32_t upper = static_cast<uint32_t>(value >> 32);
|
||||
operator()(upper);
|
||||
// Reassemble the 64-bit value from the two halves.
value = lower | (static_cast<uint64_t>(upper) << 32);
|
||||
}
|
||||
|
||||
void operator()(float& value) override {
|
||||
if (HWY_UNLIKELY(SkipField())) return;
|
||||
|
||||
|
|
@ -229,6 +252,15 @@ class WriteVisitor : public VisitorBase {
|
|||
|
||||
void operator()(uint32_t& value) override { storage_.push_back(value); }
|
||||
|
||||
// Appends a signed 32-bit field to `storage_`, converted to uint32_t.
void operator()(int32_t& value) override {
|
||||
storage_.push_back(static_cast<uint32_t>(value));
|
||||
}
|
||||
|
||||
// Appends an unsigned 64-bit field to `storage_` as two 32-bit words: the low
// 32 bits first, followed by the high 32 bits.
void operator()(uint64_t& value) override {
  const uint32_t low_word = static_cast<uint32_t>(value);
  const uint32_t high_word = static_cast<uint32_t>(value >> 32);
  storage_.push_back(low_word);
  storage_.push_back(high_word);
}
|
||||
|
||||
void operator()(float& value) override {
|
||||
storage_.push_back(hwy::BitCastScalar<uint32_t>(value));
|
||||
CheckF32(value);
|
||||
|
|
|
|||
|
|
@ -55,8 +55,9 @@ class IFields; // breaks circular dependency
|
|||
// Visitors are internal-only, but their base class is visible to user code
|
||||
// because their `IFields::VisitFields` calls `visitor.operator()`.
|
||||
//
|
||||
// Supported field types `T`: `uint32_t`, `float`, `std::string`, classes
|
||||
// derived from `IFields`, `bool`, `enum`, `std::vector<T>`.
|
||||
// Supported field types `T`: `uint32_t`, `int32_t`, `uint64_t`, `float`,
|
||||
// `std::string`,
|
||||
// classes derived from `IFields`, `bool`, `enum`, `std::vector<T>`.
|
||||
class IFieldsVisitor {
|
||||
public:
|
||||
virtual ~IFieldsVisitor();
|
||||
|
|
@ -69,6 +70,8 @@ class IFieldsVisitor {
|
|||
// is out of range. A single generic/overloaded function is required to
|
||||
// support `std::vector<T>`.
|
||||
virtual void operator()(uint32_t& value) = 0;
|
||||
virtual void operator()(int32_t& value) = 0;
|
||||
virtual void operator()(uint64_t& value) = 0;
|
||||
virtual void operator()(float& value) = 0;
|
||||
virtual void operator()(std::string& value) = 0;
|
||||
virtual void operator()(IFields& fields) = 0; // recurse into nested fields
|
||||
|
|
@ -92,7 +95,7 @@ class IFieldsVisitor {
|
|||
uint32_t u32 = static_cast<uint32_t>(value);
|
||||
operator()(u32);
|
||||
if (HWY_UNLIKELY(!EnumValid(static_cast<EnumT>(u32)))) {
|
||||
return NotifyInvalid("Invalid enum %u\n");
|
||||
return NotifyInvalid("Invalid enum %u\n", u32);
|
||||
}
|
||||
value = static_cast<EnumT>(u32);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -97,6 +97,8 @@ struct OldFields : public IFields {
|
|||
visitor(old_str);
|
||||
visitor(old_nested);
|
||||
visitor(old1);
|
||||
visitor(oldi);
|
||||
visitor(oldl);
|
||||
visitor(old_vec_str);
|
||||
visitor(old_vec_nested);
|
||||
visitor(old_f);
|
||||
|
|
@ -110,6 +112,8 @@ struct OldFields : public IFields {
|
|||
EXPECT_EQ(old_str, n.old_str);
|
||||
old_nested.CheckEqual(n.old_nested);
|
||||
EXPECT_EQ(old1, n.old1);
|
||||
EXPECT_EQ(oldi, n.oldi);
|
||||
EXPECT_EQ(oldl, n.oldl);
|
||||
CheckVectorEqual(old_vec_str, n.old_vec_str);
|
||||
CheckVectorEqual(old_vec_nested, n.old_vec_nested);
|
||||
EXPECT_EQ(old_f, n.old_f);
|
||||
|
|
@ -120,6 +124,8 @@ struct OldFields : public IFields {
|
|||
std::string old_str = "old";
|
||||
Nested old_nested = Nested(0);
|
||||
uint32_t old1 = 1;
|
||||
int32_t oldi = -1;
|
||||
uint64_t oldl = 1234567890123456789;
|
||||
std::vector<std::string> old_vec_str = {"abc", "1234"};
|
||||
std::vector<Nested> old_vec_nested = {Nested(1), Nested(4)};
|
||||
float old_f = 1.125f;
|
||||
|
|
@ -134,6 +140,8 @@ struct NewFields : public IFields {
|
|||
visitor(old_str);
|
||||
visitor(old_nested);
|
||||
visitor(old1);
|
||||
visitor(oldi);
|
||||
visitor(oldl);
|
||||
visitor(old_vec_str);
|
||||
visitor(old_vec_nested);
|
||||
visitor(old_f);
|
||||
|
|
@ -149,6 +157,8 @@ struct NewFields : public IFields {
|
|||
visitor(new_enum);
|
||||
visitor(new2);
|
||||
visitor(new_str);
|
||||
visitor(new_i);
|
||||
visitor(new_l);
|
||||
}
|
||||
|
||||
void CheckEqual(const NewFields& n) const {
|
||||
|
|
@ -176,6 +186,8 @@ struct NewFields : public IFields {
|
|||
std::string old_str = "old";
|
||||
Nested old_nested = Nested(0);
|
||||
uint32_t old1 = 1;
|
||||
int32_t oldi = -1;
|
||||
uint64_t oldl = 1234567890123456789;
|
||||
std::vector<std::string> old_vec_str = {"abc", "1234"};
|
||||
std::vector<Nested> old_vec_nested = {Nested(1), Nested(4)};
|
||||
float old_f = 1.125f;
|
||||
|
|
@ -190,6 +202,8 @@ struct NewFields : public IFields {
|
|||
Enum new_enum = Enum::k3;
|
||||
uint32_t new2 = 2;
|
||||
std::string new_str = std::string(); // empty is allowed
|
||||
int32_t new_i = 123456789;
|
||||
uint64_t new_l = 876543210987654321;
|
||||
}; // NewFields
|
||||
|
||||
// Changes all fields to non-default values.
|
||||
|
|
@ -212,6 +226,8 @@ NewFields ModifiedNewFields() {
|
|||
n.new_enum = Enum::k8;
|
||||
n.new2 = 22;
|
||||
n.new_str = "new and even longer";
|
||||
n.new_i = 246810121;
|
||||
n.new_l = 1357913579113579135;
|
||||
|
||||
return n;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,62 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "util/args.h"
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
||||
// Command-line arguments for the weight migration tool: the single flag
// --output_weights, which names the file to write.
struct WriterArgs : public ArgsBase<WriterArgs> {
  // --output_weights is required.
  WriterArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

  // Returns error string or nullptr if OK.
  const char* Validate() {
    const bool has_output = !output_weights.path.empty();
    return has_output
               ? nullptr
               : "Missing --output_weights flag, a file for the model weights.";
  }

  Path output_weights;  // weights file location

  // Presents each flag to `visitor` as (member, flag name, default, help).
  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(output_weights, "output_weights", Path(),
            "Path name of output weights (.sbs) file.\n Required argument.");
  }
};
|
||||
|
||||
} // namespace
|
||||
} // namespace gcpp
|
||||
|
||||
// Loads a model in the multi-file format and saves it in single-file format.
// Returns 1 without loading anything if --output_weights is missing, and 0
// after Save() returns.
int main(int argc, char** argv) {
  gcpp::WriterArgs args(argc, argv);
  if (const char* err = args.Validate()) {
    // The process exits here, so say so rather than claiming to merely skip
    // the model load.
    fprintf(stderr, "Aborting weight migration because: %s\n", err);
    return 1;
  }
  // The remaining flags in argv describe the model to load.
  gcpp::GemmaEnv env(argc, argv, /*model_type_required=*/true);
  hwy::ThreadPool pool(0);
  env.GetModel()->Save(args.output_weights, pool);
  return 0;
}
|
||||
|
|
@ -16,6 +16,7 @@ cc_library(
|
|||
deps = [
|
||||
"@abseil-cpp//absl/types:span",
|
||||
"//:common",
|
||||
"//:tokenizer",
|
||||
"//compression:compress",
|
||||
"//compression:io",
|
||||
"@highway//:hwy",
|
||||
|
|
|
|||
|
|
@ -24,7 +24,9 @@
|
|||
|
||||
#include "absl/types/span.h"
|
||||
#include "compression/io.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/tensor_index.h"
|
||||
#include "gemma/tokenizer.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
|
|
@ -44,10 +46,11 @@ class WriterInterface {
|
|||
virtual void InsertFloat(std::string name,
|
||||
absl::Span<const float> weights) = 0;
|
||||
virtual void AddScales(const std::vector<float>& scales) = 0;
|
||||
virtual void AddTokenizer(const std::string& tokenizer_path) = 0;
|
||||
|
||||
virtual size_t DebugNumBlobsAdded() const = 0;
|
||||
|
||||
virtual int Write(std::string path) = 0;
|
||||
virtual int WriteWithConfig(std::string path, const ModelConfig* config) = 0;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -133,14 +136,21 @@ class SbsWriterImpl : public WriterInterface {
|
|||
compressor_.AddScales(scales_.data(), scales_.size());
|
||||
}
|
||||
|
||||
void AddTokenizer(const std::string& tokenizer_path) override {
|
||||
Path path(tokenizer_path);
|
||||
GemmaTokenizer tokenizer(path);
|
||||
tokenizer_proto_ = tokenizer.Serialize();
|
||||
compressor_.AddTokenizer(tokenizer_proto_);
|
||||
}
|
||||
|
||||
// Returns the number of blobs added.
|
||||
size_t DebugNumBlobsAdded() const {
|
||||
if (mode_ == CompressorMode::kTEST_ONLY) return model_memory_.size();
|
||||
return compressor_.DebugNumBlobsAdded();
|
||||
}
|
||||
|
||||
int Write(std::string path) override {
|
||||
return compressor_.WriteAll(pool_, gcpp::Path(path));
|
||||
int WriteWithConfig(std::string path, const ModelConfig* config) override {
|
||||
return compressor_.WriteAll(gcpp::Path(path), config);
|
||||
}
|
||||
|
||||
hwy::ThreadPool pool_;
|
||||
|
|
@ -149,6 +159,7 @@ class SbsWriterImpl : public WriterInterface {
|
|||
std::vector<MatStorage> model_memory_;
|
||||
std::vector<float> scales_;
|
||||
CompressorMode mode_;
|
||||
std::string tokenizer_proto_;
|
||||
};
|
||||
|
||||
WriterInterface* NewSbsWriter(CompressorMode mode) {
|
||||
|
|
@ -190,11 +201,17 @@ void SbsWriter::AddScales(const std::vector<float>& scales) {
|
|||
impl_->AddScales(scales);
|
||||
}
|
||||
|
||||
// Forwards `tokenizer_path` unchanged to the implementation object.
void SbsWriter::AddTokenizer(const std::string& tokenizer_path) {
  impl_->AddTokenizer(tokenizer_path);
}
|
||||
|
||||
size_t SbsWriter::DebugNumBlobsAdded() const {
|
||||
return impl_->DebugNumBlobsAdded();
|
||||
}
|
||||
|
||||
int SbsWriter::Write(std::string path) { return impl_->Write(path); }
|
||||
int SbsWriter::WriteWithConfig(std::string path, const ModelConfig* config) {
|
||||
return impl_->WriteWithConfig(path, config);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // HWY_ONCE
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include "absl/types/span.h"
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/tensor_index.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
|
@ -36,10 +37,12 @@ class SbsWriter {
|
|||
void InsertBfloat16(std::string name, absl::Span<const float> weights);
|
||||
void InsertFloat(std::string name, absl::Span<const float> weights);
|
||||
void AddScales(const std::vector<float>& scales);
|
||||
void AddTokenizer(const std::string& tokenizer_path);
|
||||
|
||||
size_t DebugNumBlobsAdded() const;
|
||||
|
||||
int Write(std::string path);
|
||||
int Write(std::string path) { return WriteWithConfig(path, nullptr); }
|
||||
int WriteWithConfig(std::string path, const ModelConfig* config);
|
||||
|
||||
private:
|
||||
// Isolates Highway-dispatched types and other internals from CLIF.
|
||||
|
|
|
|||
|
|
@ -50,6 +50,8 @@ PYBIND11_MODULE(compression, m) {
|
|||
.def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>)
|
||||
.def("insert_float", wrap_span<&SbsWriter::InsertFloat>)
|
||||
.def("add_scales", &SbsWriter::AddScales)
|
||||
.def("add_tokenizer", &SbsWriter::AddTokenizer)
|
||||
.def("debug_num_blobs_added", &SbsWriter::DebugNumBlobsAdded)
|
||||
.def("write", &SbsWriter::Write);
|
||||
.def("write", &SbsWriter::Write)
|
||||
.def("write_with_config", &SbsWriter::WriteWithConfig);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -198,6 +198,11 @@ constexpr bool IsNuqStream() {
|
|||
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class PromptWrapping { GEMMA_IT, GEMMA_PT, PALIGEMMA };

// Returns true if the integer value of `type` lies between 0 and the last
// declared enumerator (PALIGEMMA), inclusive.
inline bool EnumValid(PromptWrapping type) {
  const int value = static_cast<int>(type);
  const int last = static_cast<int>(PromptWrapping::PALIGEMMA);
  return 0 <= value && value <= last;
}
|
||||
|
||||
// Tensor types for loading weights. Note that not all types are supported as
|
||||
// weights for a model, but can be used for other purposes, such as types for
|
||||
// ModelWeightsPtrs. When adding a new type that is supported, also
|
||||
|
|
@ -206,6 +211,11 @@ enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 };
|
|||
constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
|
||||
"nuq", "f64", "c64", "u128"};
|
||||
|
||||
// Returns true if the integer value of `type` lies between 0 and Type::kU128,
// inclusive.
inline bool EnumValid(Type type) {
  const int value = static_cast<int>(type);
  const int last = static_cast<int>(Type::kU128);
  return 0 <= value && value <= last;
}
|
||||
|
||||
// Returns a Type enum for the type of the template parameter.
|
||||
template <typename PackedT>
|
||||
Type TypeEnum() {
|
||||
|
|
|
|||
|
|
@ -92,9 +92,9 @@ static AppArgs MakeAppArgs(int argc, char** argv) {
|
|||
return AppArgs(argc, argv);
|
||||
}
|
||||
|
||||
GemmaEnv::GemmaEnv(int argc, char** argv)
|
||||
: GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
|
||||
MakeAppArgs(argc, argv)) {}
|
||||
// Builds the loader, inference and app argument objects from the command line
// and delegates to the constructor that takes them. `model_type_required` is
// passed through to LoaderArgs.
GemmaEnv::GemmaEnv(int argc, char** argv, bool model_type_required)
    : GemmaEnv(LoaderArgs(argc, argv, model_type_required),
               InferenceArgs(argc, argv),
               MakeAppArgs(argc, argv)) {}
|
||||
|
||||
QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
|
||||
QueryResult result;
|
||||
|
|
@ -270,7 +270,9 @@ void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
"specify 3 required model loading arguments:\n"
|
||||
" --tokenizer\n"
|
||||
" --weights\n"
|
||||
" --model.\n";
|
||||
" --model,\n"
|
||||
" or with the newer weights format, specify just:\n"
|
||||
" --weights\n";
|
||||
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
|
||||
"--weights 2b-it-sfp.sbs --model 2b-it\n";
|
||||
std::cerr << "\n*Model Loading Arguments*\n\n";
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ struct QueryResult {
|
|||
class GemmaEnv {
|
||||
public:
|
||||
// Calls the other constructor with *Args arguments initialized from argv.
|
||||
GemmaEnv(int argc, char** argv);
|
||||
GemmaEnv(int argc, char** argv, bool model_type_required = false);
|
||||
GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||
const AppArgs& app);
|
||||
|
||||
|
|
|
|||
177
gemma/configs.cc
177
gemma/configs.cc
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
#include "gemma/configs.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <iostream>
|
||||
|
||||
#include "hwy/base.h"
|
||||
|
|
@ -22,9 +23,9 @@
|
|||
namespace gcpp {
|
||||
|
||||
static ModelConfig ConfigNoSSM() {
|
||||
ModelConfig config = {.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w",
|
||||
"gr_lin_y_w", "gr_lin_out_w",
|
||||
"gr_gate_w", "gating_ein", "linear_w"}};
|
||||
ModelConfig config;
|
||||
config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
|
||||
"gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
|
||||
return config;
|
||||
}
|
||||
|
||||
|
|
@ -37,6 +38,18 @@ static ModelConfig ConfigBaseGemmaV2() {
|
|||
return config;
|
||||
}
|
||||
|
||||
// Returns the per-layer settings that ConfigGemma2_27B replicates across its
// layers. `model_dim` is stored as the layer's model dimension; fields not
// assigned here keep their LayerConfig defaults.
static LayerConfig LayerConfigGemma2_27B(size_t model_dim) {
  LayerConfig layer;
  layer.model_dim = model_dim;
  layer.heads = 32;
  layer.kv_heads = 16;
  layer.qkv_dim = 128;
  layer.ff_hidden_dim = 16 * 4608 / 2;  // = 36864
  layer.optimized_gating = false;
  layer.post_norm = PostNormType::Scale;
  return layer;
}
|
||||
|
||||
static ModelConfig ConfigGemma2_27B() {
|
||||
ModelConfig config = ConfigBaseGemmaV2();
|
||||
config.model_name = "Gemma2_27B";
|
||||
|
|
@ -44,13 +57,7 @@ static ModelConfig ConfigGemma2_27B() {
|
|||
config.model_dim = 4608;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = 8192;
|
||||
LayerConfig layer_config = {.model_dim = config.model_dim,
|
||||
.ff_hidden_dim = 16 * 4608 / 2, // = 36864
|
||||
.heads = 32,
|
||||
.kv_heads = 16,
|
||||
.qkv_dim = 128,
|
||||
.optimized_gating = false,
|
||||
.post_norm = PostNormType::Scale};
|
||||
LayerConfig layer_config = LayerConfigGemma2_27B(config.model_dim);
|
||||
config.layer_configs = {46, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.query_scale = QueryScaleType::SqrtModelDimDivNumHeads;
|
||||
|
|
@ -59,6 +66,18 @@ static ModelConfig ConfigGemma2_27B() {
|
|||
return config;
|
||||
}
|
||||
|
||||
// Returns the per-layer settings that ConfigGemma2_9B replicates across its
// layers. `model_dim` is stored as the layer's model dimension; fields not
// assigned here keep their LayerConfig defaults.
static LayerConfig LayerConfigGemma2_9B(size_t model_dim) {
  LayerConfig layer;
  layer.model_dim = model_dim;
  layer.heads = 16;
  layer.kv_heads = 8;
  layer.qkv_dim = 256;
  layer.ff_hidden_dim = 8 * 3584 / 2;  // = 14336
  layer.optimized_gating = false;
  layer.post_norm = PostNormType::Scale;
  return layer;
}
|
||||
|
||||
static ModelConfig ConfigGemma2_9B() {
|
||||
ModelConfig config = ConfigBaseGemmaV2();
|
||||
config.model_name = "Gemma2_9B";
|
||||
|
|
@ -66,13 +85,7 @@ static ModelConfig ConfigGemma2_9B() {
|
|||
config.model_dim = 3584;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = 8192;
|
||||
LayerConfig layer_config = {.model_dim = config.model_dim,
|
||||
.ff_hidden_dim = 8 * 3584 / 2, // = 14336
|
||||
.heads = 16,
|
||||
.kv_heads = 8,
|
||||
.qkv_dim = 256,
|
||||
.optimized_gating = false,
|
||||
.post_norm = PostNormType::Scale};
|
||||
LayerConfig layer_config = LayerConfigGemma2_9B(config.model_dim);
|
||||
config.layer_configs = {42, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
|
|
@ -81,6 +94,18 @@ static ModelConfig ConfigGemma2_9B() {
|
|||
return config;
|
||||
}
|
||||
|
||||
// Returns the per-layer settings that ConfigGemma2_2B replicates across its
// layers. `model_dim` is stored as the layer's model dimension; fields not
// assigned here keep their LayerConfig defaults.
static LayerConfig LayerConfigGemma2_2B(size_t model_dim) {
  LayerConfig layer;
  layer.model_dim = model_dim;
  layer.heads = 8;
  layer.kv_heads = 4;
  layer.qkv_dim = 256;
  layer.ff_hidden_dim = 8 * 2304 / 2;  // = 9216
  layer.optimized_gating = false;
  layer.post_norm = PostNormType::Scale;
  return layer;
}
|
||||
|
||||
static ModelConfig ConfigGemma2_2B() {
|
||||
ModelConfig config = ConfigBaseGemmaV2();
|
||||
config.model_name = "Gemma2_2B";
|
||||
|
|
@ -88,13 +113,7 @@ static ModelConfig ConfigGemma2_2B() {
|
|||
config.model_dim = 2304;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = 8192;
|
||||
LayerConfig layer_config = {.model_dim = config.model_dim,
|
||||
.ff_hidden_dim = 8 * 2304 / 2, // = 9216
|
||||
.heads = 8,
|
||||
.kv_heads = 4,
|
||||
.qkv_dim = 256,
|
||||
.optimized_gating = false,
|
||||
.post_norm = PostNormType::Scale};
|
||||
LayerConfig layer_config = LayerConfigGemma2_2B(config.model_dim);
|
||||
config.layer_configs = {26, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
|
|
@ -103,6 +122,16 @@ static ModelConfig ConfigGemma2_2B() {
|
|||
return config;
|
||||
}
|
||||
|
||||
// Returns the per-layer settings that ConfigGemma7B replicates across its
// layers. `model_dim` is stored as the layer's model dimension; fields not
// assigned here keep their LayerConfig defaults.
static LayerConfig LayerConfigGemma7B(size_t model_dim) {
  LayerConfig layer;
  layer.model_dim = model_dim;
  layer.heads = 16;
  layer.kv_heads = 16;
  layer.qkv_dim = 256;
  layer.ff_hidden_dim = 16 * 3072 / 2;  // = 24576
  return layer;
}
|
||||
|
||||
static ModelConfig ConfigGemma7B() {
|
||||
ModelConfig config = ConfigBaseGemmaV1();
|
||||
config.model_name = "Gemma7B";
|
||||
|
|
@ -110,13 +139,7 @@ static ModelConfig ConfigGemma7B() {
|
|||
config.model_dim = 3072;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = kSeqLen;
|
||||
LayerConfig layer_config = {
|
||||
.model_dim = config.model_dim,
|
||||
.ff_hidden_dim = 16 * 3072 / 2, // = 24576
|
||||
.heads = 16,
|
||||
.kv_heads = 16,
|
||||
.qkv_dim = 256,
|
||||
};
|
||||
LayerConfig layer_config = LayerConfigGemma7B(config.model_dim);
|
||||
config.layer_configs = {28, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
|
|
@ -124,6 +147,16 @@ static ModelConfig ConfigGemma7B() {
|
|||
return config;
|
||||
}
|
||||
|
||||
// Returns the per-layer settings that ConfigGemma2B replicates across its
// layers. `model_dim` is stored as the layer's model dimension; fields not
// assigned here keep their LayerConfig defaults.
static LayerConfig LayerConfigGemma2B(size_t model_dim) {
  LayerConfig layer;
  layer.model_dim = model_dim;
  layer.heads = 8;
  layer.kv_heads = 1;
  layer.qkv_dim = 256;
  layer.ff_hidden_dim = 16 * 2048 / 2;  // = 16384
  return layer;
}
|
||||
|
||||
static ModelConfig ConfigGemma2B() {
|
||||
ModelConfig config = ConfigBaseGemmaV1();
|
||||
config.model_name = "Gemma2B";
|
||||
|
|
@ -131,19 +164,23 @@ static ModelConfig ConfigGemma2B() {
|
|||
config.model_dim = 2048;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = kSeqLen;
|
||||
LayerConfig layer_config = {
|
||||
.model_dim = config.model_dim,
|
||||
.ff_hidden_dim = 16 * 2048 / 2, // = 16384
|
||||
.heads = 8,
|
||||
.kv_heads = 1,
|
||||
.qkv_dim = 256,
|
||||
};
|
||||
LayerConfig layer_config = LayerConfigGemma2B(config.model_dim);
|
||||
config.layer_configs = {18, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.attention_window_sizes = FixedAttentionWindowSizes<18>(kSeqLen);
|
||||
return config;
|
||||
}
|
||||
|
||||
// Returns the per-layer settings that ConfigGemmaTiny replicates across its
// layers. `model_dim` is stored as the layer's model dimension; fields not
// assigned here keep their LayerConfig defaults.
static LayerConfig LayerConfigGemmaTiny(size_t model_dim) {
  LayerConfig layer;
  layer.model_dim = model_dim;
  layer.heads = 4;
  layer.kv_heads = 1;
  layer.qkv_dim = 16;
  layer.ff_hidden_dim = 256;
  return layer;
}
|
||||
|
||||
static ModelConfig ConfigGemmaTiny() {
|
||||
ModelConfig config = ConfigNoSSM();
|
||||
config.model_name = "GemmaTiny";
|
||||
|
|
@ -151,13 +188,7 @@ static ModelConfig ConfigGemmaTiny() {
|
|||
config.model_dim = 128;
|
||||
config.vocab_size = 64;
|
||||
config.seq_len = 32;
|
||||
LayerConfig layer_config = {
|
||||
.model_dim = config.model_dim,
|
||||
.ff_hidden_dim = 256,
|
||||
.heads = 4,
|
||||
.kv_heads = 1,
|
||||
.qkv_dim = 16,
|
||||
};
|
||||
LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim);
|
||||
config.layer_configs = {3, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
|
|
@ -167,6 +198,24 @@ static ModelConfig ConfigGemmaTiny() {
|
|||
return config;
|
||||
}
|
||||
|
||||
// Returns the per-layer settings that ConfigGriffin2B starts from. `model_dim`
// is used for both the model dimension and the griffin dimension. The layer
// type set here is kGriffinRecurrentBlock; fields not assigned keep their
// LayerConfig defaults.
static LayerConfig LayerConfigGriffin2B(size_t model_dim) {
  LayerConfig layer;
  // Dimensions.
  layer.model_dim = model_dim;
  layer.griffin_dim = model_dim;
  layer.ff_hidden_dim = 7680;
  layer.heads = 10;
  layer.kv_heads = 1;
  layer.qkv_dim = 256;
  layer.conv1d_width = 4;
  // Flags.
  layer.ff_biases = true;
  layer.softmax_attn_output_biases = true;
  layer.optimized_gating = false;
  // Enumerated choices.
  layer.type = LayerAttentionType::kGriffinRecurrentBlock;
  layer.activation = ActivationType::Gelu;
  layer.post_qk = PostQKType::HalfRope;
  return layer;
}
|
||||
|
||||
static ModelConfig ConfigGriffin2B() {
|
||||
ModelConfig config = ConfigNoSSM();
|
||||
config.model_name = "Griffin2B";
|
||||
|
|
@ -176,21 +225,7 @@ static ModelConfig ConfigGriffin2B() {
|
|||
config.model_dim = 2560;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = 2048;
|
||||
LayerConfig layer_config = {
|
||||
.model_dim = config.model_dim,
|
||||
.griffin_dim = config.model_dim,
|
||||
.ff_hidden_dim = 7680,
|
||||
.heads = 10,
|
||||
.kv_heads = 1,
|
||||
.qkv_dim = 256,
|
||||
.conv1d_width = 4,
|
||||
.ff_biases = true,
|
||||
.softmax_attn_output_biases = true,
|
||||
.optimized_gating = false,
|
||||
.type = LayerAttentionType::kGriffinRecurrentBlock,
|
||||
.activation = ActivationType::Gelu,
|
||||
.post_qk = PostQKType::HalfRope,
|
||||
};
|
||||
LayerConfig layer_config = LayerConfigGriffin2B(config.model_dim);
|
||||
config.layer_configs = {26, layer_config};
|
||||
for (size_t i = 2; i < config.layer_configs.size(); i += 3) {
|
||||
config.layer_configs[i].type = LayerAttentionType::kGemma;
|
||||
|
|
@ -204,6 +239,18 @@ static ModelConfig ConfigGriffin2B() {
|
|||
return config;
|
||||
}
|
||||
|
||||
// Returns the per-layer settings that AddVitConfig replicates across the ViT
// layers. `model_dim` is stored as the layer's model dimension and the layer
// type is kVit; fields not assigned here keep their LayerConfig defaults.
static LayerConfig LayerConfigVit(size_t model_dim) {
  LayerConfig layer;
  layer.model_dim = model_dim;
  layer.heads = 16;
  layer.kv_heads = 16;
  layer.qkv_dim = 72;
  layer.ff_hidden_dim = 4304;
  layer.ff_biases = true;
  layer.type = LayerAttentionType::kVit;
  return layer;
}
|
||||
|
||||
// Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config.
|
||||
static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
|
||||
config.vit_model_dim = 1152;
|
||||
|
|
@ -215,15 +262,7 @@ static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
|
|||
}
|
||||
const size_t num_patches = config.image_size / config.patch_width;
|
||||
config.vit_seq_len = num_patches * num_patches;
|
||||
LayerConfig vit_layer_config = {
|
||||
.model_dim = config.vit_model_dim,
|
||||
.ff_hidden_dim = 4304,
|
||||
.heads = 16,
|
||||
.kv_heads = 16,
|
||||
.qkv_dim = 72,
|
||||
.ff_biases = true,
|
||||
.type = LayerAttentionType::kVit,
|
||||
};
|
||||
LayerConfig vit_layer_config = LayerConfigVit(config.vit_model_dim);
|
||||
config.vit_layer_configs = {27, vit_layer_config};
|
||||
config.num_vit_scales = 4 * config.vit_layer_configs.size();
|
||||
}
|
||||
|
|
|
|||
155
gemma/configs.h
155
gemma/configs.h
|
|
@ -26,6 +26,7 @@
|
|||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/fields.h" // IFieldsVisitor
|
||||
#include "compression/shared.h" // BF16
|
||||
|
||||
namespace gcpp {
|
||||
|
|
@ -52,52 +53,83 @@ enum class LayerAttentionType {
|
|||
kVit,
|
||||
};
|
||||
|
||||
// Returns true if the integer value of `type` lies between 0 and
// LayerAttentionType::kVit, inclusive.
inline bool EnumValid(LayerAttentionType type) {
  const int value = static_cast<int>(type);
  const int last = static_cast<int>(LayerAttentionType::kVit);
  return 0 <= value && value <= last;
}
|
||||
|
||||
// Post attention and ffw normalization type.
enum class PostNormType {
  None,
  Scale,
};

// Returns true if the integer value of `type` lies between 0 and
// PostNormType::Scale, inclusive.
inline bool EnumValid(PostNormType type) {
  const int value = static_cast<int>(type);
  const int last = static_cast<int>(PostNormType::Scale);
  return 0 <= value && value <= last;
}
|
||||
|
||||
// Post qk projection operation type.
enum class PostQKType {
  Rope,
  HalfRope,
};

// Returns true if the integer value of `type` lies between 0 and
// PostQKType::HalfRope, inclusive.
inline bool EnumValid(PostQKType type) {
  const int value = static_cast<int>(type);
  const int last = static_cast<int>(PostQKType::HalfRope);
  return 0 <= value && value <= last;
}
|
||||
|
||||
// FFW activation function.
enum class ActivationType {
  Gelu,
};

// Returns true if the integer value of `type` lies between 0 and
// ActivationType::Gelu, inclusive.
inline bool EnumValid(ActivationType type) {
  const int value = static_cast<int>(type);
  const int last = static_cast<int>(ActivationType::Gelu);
  return 0 <= value && value <= last;
}
|
||||
|
||||
// Attention query scale.
enum class QueryScaleType {
  SqrtKeySize,
  SqrtModelDimDivNumHeads,
};

// Returns true if the integer value of `type` lies between 0 and
// QueryScaleType::SqrtModelDimDivNumHeads, inclusive.
inline bool EnumValid(QueryScaleType type) {
  const int value = static_cast<int>(type);
  const int last = static_cast<int>(QueryScaleType::SqrtModelDimDivNumHeads);
  return 0 <= value && value <= last;
}
|
||||
|
||||
// Residual connection type.
enum class ResidualType {
  Add,
};

// Returns true if the integer value of `type` lies between 0 and
// ResidualType::Add, inclusive.
inline bool EnumValid(ResidualType type) {
  const int value = static_cast<int>(type);
  const int last = static_cast<int>(ResidualType::Add);
  return 0 <= value && value <= last;
}
|
||||
|
||||
template <size_t kNum>
|
||||
std::vector<LayerAttentionType> FixedLayerConfig(LayerAttentionType type) {
|
||||
return std::vector<LayerAttentionType>(kNum, type);
|
||||
}
|
||||
|
||||
template <size_t kNum>
|
||||
std::vector<size_t> FixedAttentionWindowSizes(size_t window_size) {
|
||||
return std::vector<size_t>(kNum, window_size);
|
||||
template <uint32_t kNum>
|
||||
std::vector<uint32_t> FixedAttentionWindowSizes(uint32_t window_size) {
|
||||
return std::vector<uint32_t>(kNum, window_size);
|
||||
}
|
||||
|
||||
// Repeat window_size_pattern for kNum / kPatternSize times.
|
||||
template <size_t kNum, size_t kPatternSize>
|
||||
std::vector<size_t> RepeatedAttentionWindowSizes(
|
||||
const std::array<size_t, kPatternSize>& window_size_pattern) {
|
||||
template <uint32_t kNum, uint32_t kPatternSize>
|
||||
std::vector<uint32_t> RepeatedAttentionWindowSizes(
|
||||
const std::array<uint32_t, kPatternSize>& window_size_pattern) {
|
||||
static_assert(kNum % kPatternSize == 0,
|
||||
"kNum must be a multiple of kPatternSize");
|
||||
std::vector<size_t> window_size_configs(kNum);
|
||||
for (size_t i = 0; i < kNum; ++i) {
|
||||
std::vector<uint32_t> window_size_configs(kNum);
|
||||
for (uint32_t i = 0; i < kNum; ++i) {
|
||||
window_size_configs[i] = window_size_pattern[i % kPatternSize];
|
||||
}
|
||||
return window_size_configs;
|
||||
|
|
@ -130,7 +162,14 @@ static constexpr Model kAllModels[] = {
|
|||
Model::PALIGEMMA2_10B_224, Model::PALIGEMMA2_10B_448,
|
||||
};
|
||||
|
||||
struct LayerConfig {
|
||||
// Returns true if `model` equals one of the entries of kAllModels. Unlike the
// range-based EnumValid overloads, this one checks membership in that list.
inline bool EnumValid(Model model) {
  for (const Model known : kAllModels) {
    if (known == model) {
      return true;
    }
  }
  return false;
}
|
||||
|
||||
struct LayerConfig : public IFields {
|
||||
// Returns true if *this and other are equal.
|
||||
// If partial is true, then we don't check for items that are only set after
|
||||
// the tensors are loaded from the checkpoint.
|
||||
|
|
@ -146,13 +185,32 @@ struct LayerConfig {
|
|||
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
|
||||
size_t QStride() const { return qkv_dim * (IsMHA() ? 3 : 1); }
|
||||
|
||||
size_t model_dim = 0;
|
||||
size_t griffin_dim = 0;
|
||||
size_t ff_hidden_dim = 0;
|
||||
size_t heads = 0;
|
||||
size_t kv_heads = 0;
|
||||
size_t qkv_dim = 0;
|
||||
size_t conv1d_width = 0; // griffin only
|
||||
// Returns the fixed string "LayerConfig" as this struct's IFields name.
const char* Name() const override {
  return "LayerConfig";
}
|
||||
|
||||
// Presents every LayerConfig member to `visitor`, one call per member.
// NOTE(review): the call order presumably determines the serialized layout;
// confirm against compression/fields.h before reordering or inserting.
void VisitFields(IFieldsVisitor& visitor) override {
  // Integer dimensions.
  visitor(model_dim);
  visitor(griffin_dim);
  visitor(ff_hidden_dim);
  visitor(heads);
  visitor(kv_heads);
  visitor(qkv_dim);
  visitor(conv1d_width);
  // Boolean flags.
  visitor(ff_biases);
  visitor(softmax_attn_output_biases);
  visitor(optimized_gating);
  // Enum-valued members.
  visitor(post_norm);
  visitor(type);
  visitor(activation);
  visitor(post_qk);
}
|
||||
|
||||
uint32_t model_dim = 0;
|
||||
uint32_t griffin_dim = 0;
|
||||
uint32_t ff_hidden_dim = 0;
|
||||
uint32_t heads = 0;
|
||||
uint32_t kv_heads = 0;
|
||||
uint32_t qkv_dim = 0;
|
||||
uint32_t conv1d_width = 0; // griffin only
|
||||
bool ff_biases = false;
|
||||
bool softmax_attn_output_biases = false;
|
||||
bool optimized_gating = true;
|
||||
|
|
@ -162,7 +220,7 @@ struct LayerConfig {
|
|||
PostQKType post_qk = PostQKType::Rope;
|
||||
};
|
||||
|
||||
struct ModelConfig {
|
||||
struct ModelConfig : public IFields {
|
||||
// Returns true if *this and other are equal.
|
||||
// If partial is true, then we don't check for items that are only set after
|
||||
// the tensors are loaded from the checkpoint.
|
||||
|
|
@ -191,39 +249,68 @@ struct ModelConfig {
|
|||
}
|
||||
|
||||
size_t NumHeads() const {
|
||||
size_t num_heads = 0;
|
||||
uint32_t num_heads = 0;
|
||||
for (const auto& layer_config : layer_configs) {
|
||||
num_heads = std::max(num_heads, layer_config.heads);
|
||||
}
|
||||
return num_heads;
|
||||
}
|
||||
|
||||
// Returns the fixed string "ModelConfig" as this struct's IFields name.
const char* Name() const override {
  return "ModelConfig";
}
|
||||
|
||||
// Presents the ModelConfig members to `visitor`, one call per member.
// NOTE(review): the call order presumably determines the serialized layout;
// confirm against compression/fields.h before reordering or inserting.
// NOTE(review): `scale_names` is not visited, so it is presumably not part of
// the serialized config; confirm this is intended.
void VisitFields(IFieldsVisitor& visitor) override {
  // Identification.
  visitor(model_family_version);
  visitor(model_name);
  visitor(model);
  visitor(wrapping);
  visitor(weight);
  // Text model members.
  visitor(num_layers);
  visitor(model_dim);
  visitor(vocab_size);
  visitor(seq_len);
  visitor(num_tensor_scales);
  visitor(att_cap);
  visitor(final_cap);
  visitor(absolute_pe);
  visitor(use_local_attention);
  visitor(query_scale);
  visitor(layer_configs);
  visitor(attention_window_sizes);
  visitor(norm_num_groups);
  // ViT and image members.
  visitor(vit_model_dim);
  visitor(vit_seq_len);
  visitor(num_vit_scales);
  visitor(vit_layer_configs);
  visitor(patch_width);
  visitor(image_size);
}
|
||||
|
||||
std::string model_name;
|
||||
Model model;
|
||||
PromptWrapping wrapping;
|
||||
Type weight;
|
||||
size_t num_layers = 0;
|
||||
size_t model_dim = 0;
|
||||
size_t vit_model_dim = 0;
|
||||
size_t vocab_size = 0;
|
||||
size_t seq_len = 0;
|
||||
size_t vit_seq_len = 0;
|
||||
size_t num_tensor_scales = 0;
|
||||
size_t num_vit_scales = 0;
|
||||
Model model = Model::UNKNOWN;
|
||||
PromptWrapping wrapping = PromptWrapping::GEMMA_PT;
|
||||
Type weight = Type::kUnknown;
|
||||
uint32_t num_layers = 0;
|
||||
uint32_t model_dim = 0;
|
||||
uint32_t vit_model_dim = 0;
|
||||
uint32_t vocab_size = 0;
|
||||
uint32_t seq_len = 0;
|
||||
uint32_t vit_seq_len = 0;
|
||||
uint32_t num_tensor_scales = 0;
|
||||
uint32_t num_vit_scales = 0;
|
||||
float att_cap = 0.0f;
|
||||
float final_cap = 0.0f;
|
||||
bool absolute_pe = false;
|
||||
bool use_local_attention = false; // griffin only
|
||||
QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
|
||||
std::vector<LayerConfig> layer_configs;
|
||||
std::vector<size_t> attention_window_sizes;
|
||||
std::vector<uint32_t> attention_window_sizes;
|
||||
std::vector<LayerConfig> vit_layer_configs;
|
||||
std::unordered_set<std::string> scale_names;
|
||||
int norm_num_groups = 1;
|
||||
int model_family_version = 1;
|
||||
uint32_t norm_num_groups = 1;
|
||||
uint32_t model_family_version = 1;
|
||||
// Dimensions related to image processing.
|
||||
size_t patch_width = 14;
|
||||
size_t image_size = 224;
|
||||
uint32_t patch_width = 14;
|
||||
uint32_t image_size = 224;
|
||||
};
|
||||
|
||||
// Returns the config for the given model.
|
||||
|
|
|
|||
|
|
@ -2,9 +2,12 @@
|
|||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
@ -412,8 +415,17 @@ void AssertMatch(const ModelConfig& config) {
|
|||
ASSERT_EQ(TConfig::kNumTensorScales, config.num_tensor_scales);
|
||||
}
|
||||
|
||||
// Writes `config` to a uint32_t buffer, reads that buffer (starting at
// position 0) into a default-constructed ModelConfig and returns the result.
ModelConfig RoundTripSerialize(const ModelConfig& config) {
  std::vector<uint32_t> serialized = config.Write();
  const hwy::Span<const uint32_t> serialized_span(serialized);
  ModelConfig restored;
  restored.Read(serialized_span, 0);
  return restored;
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemma2B) {
|
||||
AssertMatch<OldConfigGemma2B<float>>(ConfigFromModel(Model::GEMMA_2B));
|
||||
ModelConfig config = RoundTripSerialize(ConfigFromModel(Model::GEMMA_2B));
|
||||
AssertMatch<OldConfigGemma2B<float>>(config);
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemma7B) {
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@
|
|||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <string>
|
||||
#include <utility> // std::move
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -40,13 +41,21 @@ namespace gcpp {
|
|||
|
||||
Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
|
||||
const ModelInfo& info, NestedPools& pools)
|
||||
: pools_(pools), tokenizer_(tokenizer_path), info_(info) {
|
||||
model_.Load(weights, info.model, info.weight, pools_.Pool());
|
||||
: pools_(pools), tokenizer_(tokenizer_path) {
|
||||
model_.Load(weights, info.model, info.weight, info.wrapping, pools_.Pool(),
|
||||
/*tokenizer_proto=*/nullptr);
|
||||
}
|
||||
|
||||
// Constructs a Gemma from a weights file alone. Load() is called with
// Model::UNKNOWN and Type::kUnknown, and receives a string out-parameter from
// which the tokenizer is then deserialized.
Gemma::Gemma(const Path& weights, NestedPools& pools) : pools_(pools) {
  std::string serialized_tokenizer;
  model_.Load(weights, Model::UNKNOWN, Type::kUnknown, PromptWrapping::GEMMA_IT,
              pools_.Pool(), &serialized_tokenizer);
  tokenizer_.Deserialize(serialized_tokenizer);
}
|
||||
|
||||
Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
|
||||
NestedPools& pools)
|
||||
: pools_(pools), tokenizer_(std::move(tokenizer)), info_(info) {
|
||||
: pools_(pools), tokenizer_(std::move(tokenizer)) {
|
||||
HWY_ASSERT(info.weight == Type::kF32);
|
||||
model_.Allocate(info.model, info.weight, pools_.Pool());
|
||||
}
|
||||
|
|
@ -166,7 +175,7 @@ void RangeChecks(const ModelConfig& weights_config,
|
|||
if (!weights_config.use_local_attention) {
|
||||
if (max_generated_tokens > weights_config.seq_len) {
|
||||
fprintf(stderr,
|
||||
"WARNING: max_generated_tokens %zu > kSeqLen %zu, truncating.\n",
|
||||
"WARNING: max_generated_tokens %zu > kSeqLen %u, truncating.\n",
|
||||
max_generated_tokens, weights_config.seq_len);
|
||||
max_generated_tokens = weights_config.seq_len;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -190,18 +190,28 @@ struct TimingInfo {
|
|||
|
||||
class Gemma {
|
||||
public:
|
||||
// Reads old format weights file and tokenizer file.
|
||||
Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
|
||||
NestedPools& pools);
|
||||
|
||||
// Reads new format weights file that contains everything in a single file.
|
||||
Gemma(const Path& weights, NestedPools& pools);
|
||||
// Allocates weights, caller is responsible for filling them.
|
||||
Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, NestedPools& pools);
|
||||
~Gemma();
|
||||
|
||||
// Returns the configuration held by the weights storage.
const ModelConfig& GetModelConfig() const {
  return model_.Config();
}
|
||||
const ModelInfo& Info() const { return info_; }
|
||||
// Returns a ModelInfo assembled from the model, wrapping and weight type in
// the weights storage's config.
ModelInfo Info() const {
  const ModelConfig& config = model_.Config();
  ModelInfo info{};
  info.model = config.model;
  info.wrapping = config.wrapping;
  info.weight = config.weight;
  return info;
}
|
||||
// Read-only access to the tokenizer.
const GemmaTokenizer& Tokenizer() const {
  return tokenizer_;
}
|
||||
// Read-only access to the weights storage.
const ModelWeightsStorage& Weights() const {
  return model_;
}
|
||||
// Mutable access to the weights storage.
ModelWeightsStorage& MutableWeights() {
  return model_;
}
|
||||
void Save(const Path& weights, hwy::ThreadPool& pool) {
|
||||
std::string tokenizer_proto = tokenizer_.Serialize();
|
||||
model_.Save(tokenizer_proto, weights, pool);
|
||||
}
|
||||
|
||||
// `pos` is the position in the KV cache. Users are responsible for
|
||||
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
|
||||
|
|
@ -241,7 +251,6 @@ class Gemma {
|
|||
GemmaTokenizer tokenizer_;
|
||||
// Type-erased so that this can be defined in the header.
|
||||
ModelWeightsStorage model_;
|
||||
ModelInfo info_;
|
||||
};
|
||||
|
||||
// Adds BOS token and possibly 'turn' annotations, which depend on `info`
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ KVCache KVCache::Create(const ModelConfig& weights_config,
|
|||
LayerAttentionType::kGriffinRecurrentBlock);
|
||||
// TODO(patrickms): Add query batching support for Griffin.
|
||||
if (num_griffin_layers > 0) {
|
||||
size_t conv1d_width = 0;
|
||||
uint32_t conv1d_width = 0;
|
||||
for (const auto& layer_config : weights_config.layer_configs) {
|
||||
conv1d_width = std::max(conv1d_width, layer_config.conv1d_width);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -482,6 +482,8 @@ std::vector<TensorInfo> LLMLayerTensors(const ModelConfig& config,
|
|||
.name = "att_w",
|
||||
.source_names = {"attn/attn_vec_einsum/w",
|
||||
"attention_block/proj_final/kernel"},
|
||||
.preshape = {layer_config.heads, layer_config.qkv_dim,
|
||||
config.model_dim},
|
||||
.axes = {2, 0, 1},
|
||||
.shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
|
||||
.cols_take_extra_dims = true,
|
||||
|
|
|
|||
|
|
@ -56,7 +56,8 @@ TEST(TensorIndexTest, FindName) {
|
|||
// Test that the MatPtr can be constructed from the TensorInfo,
|
||||
// and that the dimensions match.
|
||||
MatPtrT<SfpStream> mat_ptr(tensor.Name(), tensor_index);
|
||||
EXPECT_EQ(tensor.Name(), mat_ptr.Name()) << "on tensor " << name;
|
||||
EXPECT_STREQ(tensor.Name(), mat_ptr.Name())
|
||||
<< "on tensor " << name;
|
||||
EXPECT_EQ(tensor.Rows(), mat_ptr.Rows()) << "on tensor " << name;
|
||||
EXPECT_EQ(tensor.Cols(), mat_ptr.Cols()) << "on tensor " << name;
|
||||
++num_found;
|
||||
|
|
|
|||
|
|
@ -44,6 +44,17 @@ class GemmaTokenizer::Impl {
|
|||
HWY_ABORT("Failed to load the tokenizer file.");
|
||||
}
|
||||
}
|
||||
// Loads the tokenizer from a serialized proto.
|
||||
explicit Impl(const std::string& tokenizer_proto) {
|
||||
PROFILER_ZONE("Startup.tokenizer");
|
||||
spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
|
||||
if (!spp_->LoadFromSerializedProto(tokenizer_proto).ok()) {
|
||||
fprintf(stderr, "serialized proto size=%zu.\n", tokenizer_proto.size());
|
||||
HWY_ABORT("Failed to load the tokenizer from serialized proto.");
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the serialized model proto reported by the SentencePiece processor.
std::string Serialize() const {
  return spp_->serialized_model_proto();
}
|
||||
|
||||
bool Encode(const std::string& input,
|
||||
std::vector<std::string>* pieces) const {
|
||||
|
|
@ -81,6 +92,12 @@ GemmaTokenizer::~GemmaTokenizer() = default;
|
|||
GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default;
|
||||
GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default;
|
||||
|
||||
// Returns the tokenizer as a serialized proto string; forwards to the
// implementation object.
std::string GemmaTokenizer::Serialize() const {
  return impl_->Serialize();
}
|
||||
|
||||
// Replaces the current implementation object with a new one constructed from
// the serialized proto in `tokenizer_proto`.
void GemmaTokenizer::Deserialize(const std::string& tokenizer_proto) {
  auto rebuilt = std::make_unique<Impl>(tokenizer_proto);
  impl_ = std::move(rebuilt);
}
|
||||
|
||||
bool GemmaTokenizer::Encode(const std::string& input,
|
||||
std::vector<std::string>* pieces) const {
|
||||
return impl_->Encode(input, pieces);
|
||||
|
|
|
|||
|
|
@ -41,6 +41,9 @@ class GemmaTokenizer {
|
|||
GemmaTokenizer(GemmaTokenizer&& other);
|
||||
GemmaTokenizer& operator=(GemmaTokenizer&& other);
|
||||
|
||||
std::string Serialize() const;
|
||||
void Deserialize(const std::string& tokenizer_proto);
|
||||
|
||||
bool Encode(const std::string& input, std::vector<std::string>* pieces) const;
|
||||
bool Encode(const std::string& input, std::vector<int>* ids) const;
|
||||
bool Decode(const std::vector<int>& ids, std::string* detokenized) const;
|
||||
|
|
|
|||
|
|
@ -19,11 +19,13 @@
|
|||
#include <cstdlib>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/blob_store.h"
|
||||
#include "compression/compress.h"
|
||||
#include "compression/io.h" // Path
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
|
|
@ -47,7 +49,9 @@ struct TensorLoader {
|
|||
};
|
||||
|
||||
BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type,
|
||||
Type weight_type, hwy::ThreadPool& pool) {
|
||||
Type weight_type, PromptWrapping wrapping,
|
||||
hwy::ThreadPool& pool,
|
||||
std::string* tokenizer_proto) {
|
||||
PROFILER_ZONE("Startup.LoadModelWeightsPtrs");
|
||||
if (!weights.Exists()) {
|
||||
HWY_ABORT("The model weights file '%s' does not exist.",
|
||||
|
|
@ -56,17 +60,36 @@ BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type,
|
|||
ReadFromBlobStore loader(weights);
|
||||
ForEachType fet =
|
||||
loader.HaveToc() ? ForEachType::kLoadWithToc : ForEachType::kLoadNoToc;
|
||||
std::vector<float> scales;
|
||||
if (fet == ForEachType::kLoadWithToc) {
|
||||
// TODO(rays): Load the config from the file.
|
||||
HWY_ABORT("TOC not supported yet.");
|
||||
BlobError err = loader.LoadConfig(config_);
|
||||
if (err != 0 || config_.model_dim == 0) {
|
||||
fprintf(stderr, "Failed to load model config: %d\n", err);
|
||||
return err;
|
||||
}
|
||||
if (tokenizer_proto != nullptr) {
|
||||
err = loader.LoadTokenizer(*tokenizer_proto);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "Failed to load tokenizer: %d\n", err);
|
||||
return err;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (weight_type == Type::kUnknown || model_type == Model::UNKNOWN) {
|
||||
fprintf(stderr,
|
||||
"weight type (%d) and model type (%d) must be specified when "
|
||||
"no config is present in weights file\n",
|
||||
static_cast<int>(weight_type), static_cast<int>(model_type));
|
||||
return __LINE__;
|
||||
}
|
||||
// No Toc-> no config.
|
||||
config_ = ConfigFromModel(model_type);
|
||||
config_.weight = weight_type;
|
||||
config_.wrapping = wrapping;
|
||||
scales.resize(config_.num_tensor_scales + config_.num_vit_scales);
|
||||
}
|
||||
CreateForType(weight_type, pool);
|
||||
CreateForType(config_.weight, pool);
|
||||
CallForModelWeightT<TensorLoader>(fet, loader);
|
||||
std::vector<float> scales(config_.num_tensor_scales + config_.num_vit_scales);
|
||||
if (!scales.empty()) {
|
||||
loader.LoadScales(scales.data(), scales.size());
|
||||
}
|
||||
|
|
@ -85,6 +108,34 @@ BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type,
|
|||
return 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct TensorSaver {
|
||||
// Adds all the tensors to the blob writer.
|
||||
void operator()(ModelWeightsPtrs<T>& weights, ForEachType fet,
|
||||
WriteToBlobStore& writer) {
|
||||
weights.ForEachTensor(
|
||||
{&weights}, fet,
|
||||
[&writer](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
tensors[0]->CallUpcasted(writer, name);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Writes the model to `weights`: every tensor is added to the blob writer via
// TensorSaver, then the serialized `tokenizer`, and finally WriteAll is called
// with the config. Returns 0 on success; otherwise logs to stderr and returns
// the error code reported by the writer.
BlobError ModelWeightsStorage::Save(const std::string& tokenizer,
                                    const Path& weights,
                                    hwy::ThreadPool& pool) {
  WriteToBlobStore writer(pool);
  ForEachType fet = ForEachType::kLoadWithToc;
  CallForModelWeightT<TensorSaver>(fet, writer);
  writer.AddTokenizer(tokenizer);
  const int err = writer.WriteAll(weights, &config_);
  if (err != 0) {
    // This is the write path; the message previously claimed a load failure.
    fprintf(stderr, "Failed to save model weights: %d\n", err);
  }
  return err;
}
|
||||
|
||||
void ModelWeightsStorage::Allocate(const ModelConfig& config, Type weight_type,
|
||||
hwy::ThreadPool& pool) {
|
||||
PROFILER_ZONE("Startup.AllocateModelWeightsPtrs");
|
||||
|
|
|
|||
|
|
@ -522,7 +522,18 @@ class ModelWeightsStorage {
|
|||
ModelWeightsStorage() = default;
|
||||
~ModelWeightsStorage() = default;
|
||||
|
||||
// Loads the weights from a blob store file. Supports multi-file or
|
||||
// single-file format. If the weights file contains a TOC, then it is in
|
||||
// single-file format, and model_type, weight_type, wrapping are ignored,
|
||||
// and tokenizer_proto, if non-null, is written to.
|
||||
// With a multi-file format, file, model_type, weight_type, wrapping are
|
||||
// required and tokenizer_proto is ignored.
|
||||
BlobError Load(const Path& weights, Model model_type, Type weight_type,
|
||||
PromptWrapping wrapping, hwy::ThreadPool& pool,
|
||||
std::string* tokenizer_proto);
|
||||
// Writes the weights to a blob store file, using the single-file format with
|
||||
// a TOC and config included.
|
||||
BlobError Save(const std::string& tokenizer, const Path& weights,
|
||||
hwy::ThreadPool& pool);
|
||||
void Allocate(Model model_type, Type weight_type, hwy::ThreadPool& pool) {
|
||||
Allocate(ConfigFromModel(model_type), weight_type, pool);
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@
|
|||
#include <cmath>
|
||||
#include <random>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "compression/shared.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/test_util.h"
|
||||
|
|
|
|||
41
util/app.h
41
util/app.h
|
|
@ -25,6 +25,7 @@
|
|||
#include <string>
|
||||
|
||||
#include "compression/io.h" // Path
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h" // For CreateGemma
|
||||
#include "ops/matmul.h"
|
||||
|
|
@ -125,7 +126,10 @@ static inline NestedPools CreatePools(const AppArgs& app) {
|
|||
}
|
||||
|
||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
LoaderArgs(int argc, char* argv[], bool required = true)
|
||||
: model_type_required(required) {
|
||||
InitAndParse(argc, argv);
|
||||
}
|
||||
LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
|
||||
const std::string& model) {
|
||||
Init(); // Init sets to defaults, so assignments must come after Init().
|
||||
|
|
@ -136,18 +140,24 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
|||
|
||||
// Returns error string or nullptr if OK.
|
||||
const char* Validate() {
|
||||
info_.model = Model::UNKNOWN;
|
||||
info_.wrapping = PromptWrapping::GEMMA_PT;
|
||||
info_.weight = Type::kUnknown;
|
||||
if (const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
|
||||
info_.wrapping)) {
|
||||
return err;
|
||||
if (model_type_required) return err;
|
||||
}
|
||||
if (const char* err = ParseType(weight_type_str, info_.weight)) {
|
||||
return err;
|
||||
if (model_type_required) return err;
|
||||
}
|
||||
if (tokenizer.path.empty()) {
|
||||
return "Missing --tokenizer flag, a file for the tokenizer is required.";
|
||||
}
|
||||
if (!tokenizer.Exists()) {
|
||||
return "Can't open file specified with --tokenizer flag.";
|
||||
if (model_type_required) {
|
||||
if (tokenizer.path.empty()) {
|
||||
return "Missing --tokenizer flag, a file for the tokenizer is "
|
||||
"required.";
|
||||
}
|
||||
if (!tokenizer.Exists()) {
|
||||
return "Can't open file specified with --tokenizer flag.";
|
||||
}
|
||||
}
|
||||
if (!compressed_weights.path.empty()) {
|
||||
if (weights.path.empty()) {
|
||||
|
|
@ -172,11 +182,12 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
|||
Path compressed_weights;
|
||||
std::string model_type_str;
|
||||
std::string weight_type_str;
|
||||
bool model_type_required = true;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(tokenizer, "tokenizer", Path(),
|
||||
"Path name of tokenizer model file.\n Required argument.");
|
||||
"Path name of tokenizer model file.");
|
||||
visitor(weights, "weights", Path(),
|
||||
"Path name of model weights (.sbs) file.\n Required argument.");
|
||||
visitor(compressed_weights, "compressed_weights", Path(),
|
||||
|
|
@ -186,11 +197,9 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
|||
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
|
||||
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
|
||||
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
|
||||
"gr2b-pt = griffin 2B parameters, pretrained\n "
|
||||
" Required argument.");
|
||||
"gr2b-pt = griffin 2B parameters, pretrained.");
|
||||
visitor(weight_type_str, "weight_type", std::string("sfp"),
|
||||
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit FP\n"
|
||||
" Required argument.");
|
||||
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit FP.");
|
||||
}
|
||||
|
||||
// Uninitialized before Validate, must call after that.
|
||||
|
|
@ -208,6 +217,12 @@ static inline Gemma CreateGemma(const LoaderArgs& loader, NestedPools& pools) {
|
|||
|
||||
static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
|
||||
NestedPools& pools) {
|
||||
if (Type::kUnknown == loader.Info().weight ||
|
||||
Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
|
||||
// Newer weights file format doesn't need tokenizer path or model/weight
|
||||
// info.
|
||||
return std::make_unique<Gemma>(loader.weights, pools);
|
||||
}
|
||||
return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
|
||||
loader.Info(), pools);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue