mirror of https://github.com/google/gemma.cpp.git
Added ability to load/save a complete model file, including tokenizer.
PiperOrigin-RevId: 707914366
This commit is contained in:
parent
5bc356f18f
commit
9d40f0117e
|
|
@ -245,6 +245,7 @@ cc_library(
|
||||||
"gemma/tensor_index.h",
|
"gemma/tensor_index.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//compression:fields",
|
||||||
"//compression:sfp",
|
"//compression:sfp",
|
||||||
"@highway//:hwy", # base.h
|
"@highway//:hwy", # base.h
|
||||||
"@highway//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
|
|
@ -257,6 +258,7 @@ cc_test(
|
||||||
deps = [
|
deps = [
|
||||||
":common",
|
":common",
|
||||||
"@googletest//:gtest_main",
|
"@googletest//:gtest_main",
|
||||||
|
"@highway//:hwy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -388,6 +390,7 @@ cc_library(
|
||||||
":ops",
|
":ops",
|
||||||
":threading",
|
":threading",
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
|
"//compression:sfp",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -390,13 +390,12 @@ static ModelConfig TestConfig() {
|
||||||
config.model_dim = 32;
|
config.model_dim = 32;
|
||||||
config.vocab_size = 12;
|
config.vocab_size = 12;
|
||||||
config.seq_len = 18;
|
config.seq_len = 18;
|
||||||
LayerConfig layer_config = {
|
LayerConfig layer_config;
|
||||||
.model_dim = config.model_dim,
|
layer_config.model_dim = config.model_dim;
|
||||||
.ff_hidden_dim = 48,
|
layer_config.ff_hidden_dim = 48;
|
||||||
.heads = 3,
|
layer_config.heads = 3;
|
||||||
.kv_heads = 1,
|
layer_config.kv_heads = 1;
|
||||||
.qkv_dim = 12,
|
layer_config.qkv_dim = 12;
|
||||||
};
|
|
||||||
config.layer_configs = {2, layer_config};
|
config.layer_configs = {2, layer_config};
|
||||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||||
|
|
|
||||||
|
|
@ -191,13 +191,12 @@ static ModelConfig TestConfig() {
|
||||||
config.model_dim = 32;
|
config.model_dim = 32;
|
||||||
config.vocab_size = 16;
|
config.vocab_size = 16;
|
||||||
config.seq_len = 24;
|
config.seq_len = 24;
|
||||||
LayerConfig layer_config = {
|
LayerConfig layer_config;
|
||||||
.model_dim = config.model_dim,
|
layer_config.model_dim = config.model_dim;
|
||||||
.ff_hidden_dim = 64,
|
layer_config.ff_hidden_dim = 64;
|
||||||
.heads = 3,
|
layer_config.heads = 3;
|
||||||
.kv_heads = 1,
|
layer_config.kv_heads = 1;
|
||||||
.qkv_dim = 16,
|
layer_config.qkv_dim = 16;
|
||||||
};
|
|
||||||
config.layer_configs = {2, layer_config};
|
config.layer_configs = {2, layer_config};
|
||||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,6 @@ cc_test(
|
||||||
deps = [
|
deps = [
|
||||||
":fields",
|
":fields",
|
||||||
"@googletest//:gtest_main", # buildcleaner: keep
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
"@highway//:hwy",
|
|
||||||
"@highway//:hwy_test_util",
|
"@highway//:hwy_test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -202,6 +201,7 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":blob_store",
|
":blob_store",
|
||||||
":distortion",
|
":distortion",
|
||||||
|
":fields",
|
||||||
":io",
|
":io",
|
||||||
":nuq",
|
":nuq",
|
||||||
":sfp",
|
":sfp",
|
||||||
|
|
@ -210,7 +210,6 @@ cc_library(
|
||||||
"//:common",
|
"//:common",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
"@highway//:nanobenchmark",
|
"@highway//:nanobenchmark",
|
||||||
"@highway//:profiler",
|
|
||||||
"@highway//:stats",
|
"@highway//:stats",
|
||||||
"@highway//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
|
|
@ -261,6 +260,7 @@ cc_binary(
|
||||||
"//:allocator",
|
"//:allocator",
|
||||||
"//:args",
|
"//:args",
|
||||||
"//:common",
|
"//:common",
|
||||||
|
"//:tokenizer",
|
||||||
"//:weights",
|
"//:weights",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
"@highway//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
|
|
@ -277,3 +277,14 @@ cc_binary(
|
||||||
"@highway//:hwy_test_util",
|
"@highway//:hwy_test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Binary target built from migrate_weights.cc.
# NOTE(review): presumably a tool that rewrites existing weights into the
# single-file format (TOC + config + tokenizer) - confirm against
# migrate_weights.cc, which is not visible here.
cc_binary(
|
||||||
|
name = "migrate_weights",
|
||||||
|
srcs = ["migrate_weights.cc"],
|
||||||
|
deps = [
|
||||||
|
"//:app",
|
||||||
|
"//:args",
|
||||||
|
"//:benchmark_helper",
|
||||||
|
"//:gemma_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -22,10 +22,13 @@
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include <cmath> // lroundf, only if COMPRESS_STATS
|
#include <cmath> // lroundf, only if COMPRESS_STATS
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "compression/blob_store.h"
|
#include "compression/blob_store.h"
|
||||||
#include "compression/compress.h" // IWYU pragma: export
|
#include "compression/compress.h" // IWYU pragma: export
|
||||||
#include "compression/distortion.h"
|
#include "compression/distortion.h"
|
||||||
|
#include "gemma/configs.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
@ -673,36 +676,37 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan<const VT> v,
|
||||||
// their scaling factors to BlobStore.
|
// their scaling factors to BlobStore.
|
||||||
class Compressor {
|
class Compressor {
|
||||||
public:
|
public:
|
||||||
explicit Compressor(hwy::ThreadPool& pool) : pool_(pool) {}
|
explicit Compressor(hwy::ThreadPool& pool) : writer_(pool) {}
|
||||||
|
|
||||||
template <typename Packed>
|
template <typename Packed>
|
||||||
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name,
|
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name,
|
||||||
const float* HWY_RESTRICT weights) {
|
const float* HWY_RESTRICT weights) {
|
||||||
size_t num_weights = compressed->NumElements();
|
size_t num_weights = compressed->NumElements();
|
||||||
|
if (num_weights == 0 || weights == nullptr || compressed->Ptr() == nullptr)
|
||||||
|
return;
|
||||||
size_t num_compressed = compressed->NumElements();
|
size_t num_compressed = compressed->NumElements();
|
||||||
PackedSpan<Packed> packed = MakeSpan(compressed->data(), num_compressed);
|
PackedSpan<Packed> packed = MakeSpan(compressed->data(), num_compressed);
|
||||||
fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name,
|
fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name,
|
||||||
num_weights / (1000 * 1000));
|
num_weights / (1000 * 1000));
|
||||||
Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0, pool_);
|
Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0,
|
||||||
const size_t num_bytes = packed.num * sizeof(Packed);
|
writer_.pool());
|
||||||
writer_.Add(MakeKey(decorated_name), packed.ptr, num_bytes);
|
writer_(compressed, decorated_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AddTokenizer(const std::string& tokenizer) {
|
||||||
|
writer_.AddTokenizer(tokenizer);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AddScales(const float* scales, size_t len) {
|
void AddScales(const float* scales, size_t len) {
|
||||||
if (len) {
|
writer_.AddScales(scales, len);
|
||||||
MatPtrT<float> scales_ptr("scales", 0, 1);
|
|
||||||
writer_.Add(MakeKey(scales_ptr.CacheName().c_str()), scales,
|
|
||||||
len * sizeof(scales[0]));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BlobError WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) {
|
// Writes all blobs to disk in the given order. The config is optional and
|
||||||
const BlobError err = writer_.WriteAll(pool, blob_filename);
|
// if given, it is written to the file, along with the TOC, making it
|
||||||
if (err != 0) {
|
// single-file format. Otherwise, the file is written in the multi-file format
|
||||||
fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
|
// without a TOC.
|
||||||
blob_filename.path.c_str(), err);
|
BlobError WriteAll(const Path& blob_filename, const ModelConfig* config) {
|
||||||
}
|
return writer_.WriteAll(blob_filename, config);
|
||||||
return err;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the number of blobs added.
|
// Returns the number of blobs added.
|
||||||
|
|
@ -710,8 +714,7 @@ class Compressor {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
CompressWorkingSet work_;
|
CompressWorkingSet work_;
|
||||||
hwy::ThreadPool& pool_;
|
WriteToBlobStore writer_;
|
||||||
BlobWriter writer_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include <cstdio>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
@ -32,11 +33,13 @@
|
||||||
|
|
||||||
// IWYU pragma: begin_exports
|
// IWYU pragma: begin_exports
|
||||||
#include "compression/blob_store.h"
|
#include "compression/blob_store.h"
|
||||||
|
#include "compression/fields.h"
|
||||||
#include "compression/io.h"
|
#include "compression/io.h"
|
||||||
#include "compression/shared.h"
|
#include "compression/shared.h"
|
||||||
#include "gemma/tensor_index.h"
|
#include "gemma/tensor_index.h"
|
||||||
#include "util/basics.h"
|
#include "util/basics.h"
|
||||||
// IWYU pragma: end_exports
|
// IWYU pragma: end_exports
|
||||||
|
#include "gemma/configs.h"
|
||||||
#include "util/allocator.h"
|
#include "util/allocator.h"
|
||||||
#include "hwy/per_target.h"
|
#include "hwy/per_target.h"
|
||||||
#if COMPRESS_STATS
|
#if COMPRESS_STATS
|
||||||
|
|
@ -55,7 +58,7 @@ namespace gcpp {
|
||||||
// fixed inner dimension and type.
|
// fixed inner dimension and type.
|
||||||
// It is designed to be put in a vector, and has default copy and operator=, so
|
// It is designed to be put in a vector, and has default copy and operator=, so
|
||||||
// it is easy to read/write a blob_store file.
|
// it is easy to read/write a blob_store file.
|
||||||
class MatPtr {
|
class MatPtr : public IFields {
|
||||||
public:
|
public:
|
||||||
// Full constructor for dynamic sizing.
|
// Full constructor for dynamic sizing.
|
||||||
MatPtr(const std::string& name, Type type, size_t element_size, size_t rows,
|
MatPtr(const std::string& name, Type type, size_t element_size, size_t rows,
|
||||||
|
|
@ -73,36 +76,6 @@ class MatPtr {
|
||||||
MatPtr() = default;
|
MatPtr() = default;
|
||||||
virtual ~MatPtr();
|
virtual ~MatPtr();
|
||||||
|
|
||||||
// Number of hwy::uint128_t in a TOC entry.
|
|
||||||
// Note that the old-style BlobStore files only have a list of keys and size.
|
|
||||||
// The new-style BlobStore files have an entry called "toc" that contains a
|
|
||||||
// vector of 4-tuples of
|
|
||||||
// (name, type, (num_elements, element_size), (rows, cols)).
|
|
||||||
// The listed blobs can be read directly into MatPtr from the BlobStore
|
|
||||||
// file, without needing any external knowledge of the number of elements,
|
|
||||||
// element size or type of the data.
|
|
||||||
static constexpr size_t kNumU128InTocEntry = 4;
|
|
||||||
|
|
||||||
// Construct from a TOC entry.
|
|
||||||
MatPtr(const hwy::uint128_t& key0, const hwy::uint128_t& key1,
|
|
||||||
const hwy::uint128_t& key2, const hwy::uint128_t& key3)
|
|
||||||
: name_(StringFromKey(key0)),
|
|
||||||
type_(static_cast<Type>(key1.lo)),
|
|
||||||
element_size_(key2.hi),
|
|
||||||
num_elements_(key2.lo),
|
|
||||||
rows_(key3.lo),
|
|
||||||
cols_(key3.hi) {
|
|
||||||
stride_ = cols_;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Adds the contents entry to the table of contents.
|
|
||||||
void AddToToc(std::vector<hwy::uint128_t>& toc) const {
|
|
||||||
toc.push_back(MakeKey(name_.c_str()));
|
|
||||||
toc.push_back({static_cast<uint64_t>(type_), 0});
|
|
||||||
toc.push_back({num_elements_, element_size_});
|
|
||||||
toc.push_back({rows_, cols_});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compatibility interface for CompressedArray.
|
// Compatibility interface for CompressedArray.
|
||||||
// TODO: remove.
|
// TODO: remove.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
@ -124,7 +97,7 @@ class MatPtr {
|
||||||
MatPtr& operator=(const MatPtr& other) = default;
|
MatPtr& operator=(const MatPtr& other) = default;
|
||||||
|
|
||||||
// Returns the name of the blob.
|
// Returns the name of the blob.
|
||||||
const std::string& Name() const { return name_; }
|
const char* Name() const override { return name_.c_str(); }
|
||||||
void SetName(const std::string& name) { name_ = name; }
|
void SetName(const std::string& name) { name_ = name; }
|
||||||
|
|
||||||
// Returns the type of the blob.
|
// Returns the type of the blob.
|
||||||
|
|
@ -163,12 +136,6 @@ class MatPtr {
|
||||||
return name;
|
return name;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adds the blob to the writer.
|
|
||||||
void AddToWriter(BlobWriter& writer) const {
|
|
||||||
fprintf(stderr, "Adding %s to writer\n", name_.c_str());
|
|
||||||
writer.Add(MakeKey(name_.c_str()), ptr_, SizeBytes());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sets all data to zero.
|
// Sets all data to zero.
|
||||||
void ZeroInit() {
|
void ZeroInit() {
|
||||||
if (ptr_ == nullptr)
|
if (ptr_ == nullptr)
|
||||||
|
|
@ -176,6 +143,17 @@ class MatPtr {
|
||||||
hwy::ZeroBytes(ptr_, SizeBytes());
|
hwy::ZeroBytes(ptr_, SizeBytes());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IFields override: passes each metadata member to `visitor`, in this exact
// order: name, type, element size, number of elements, rows, cols, scale,
// stride. The data pointer `ptr_` is not visited.
// NOTE(review): the visit order presumably defines the serialized layout used
// for TOC entries; do not reorder or insert fields without confirming against
// the IFields read/write implementation.
void VisitFields(IFieldsVisitor& visitor) override {
|
||||||
|
visitor(name_);
|
||||||
|
visitor(type_);
|
||||||
|
visitor(element_size_);
|
||||||
|
visitor(num_elements_);
|
||||||
|
visitor(rows_);
|
||||||
|
visitor(cols_);
|
||||||
|
visitor(scale_);
|
||||||
|
visitor(stride_);
|
||||||
|
}
|
||||||
|
|
||||||
// Calls func on the upcasted type. Since MatPtr by design is not templated,
|
// Calls func on the upcasted type. Since MatPtr by design is not templated,
|
||||||
// here we provide a way to get to the derived type, provided that `Type()`
|
// here we provide a way to get to the derived type, provided that `Type()`
|
||||||
// is one of the strings returned by `TypeName()`.
|
// is one of the strings returned by `TypeName()`.
|
||||||
|
|
@ -188,13 +166,13 @@ class MatPtr {
|
||||||
// Should be the result of TypeEnum<T> for CallUpcasted() to work.
|
// Should be the result of TypeEnum<T> for CallUpcasted() to work.
|
||||||
Type type_;
|
Type type_;
|
||||||
// sizeof(T)
|
// sizeof(T)
|
||||||
size_t element_size_ = 0;
|
uint32_t element_size_ = 0;
|
||||||
// Number of elements in the array.
|
// Number of elements in the array.
|
||||||
size_t num_elements_ = 0; // In element_size units.
|
uint32_t num_elements_ = 0; // In element_size units.
|
||||||
// Number of rows in the 2-d array (outer dimension).
|
// Number of rows in the 2-d array (outer dimension).
|
||||||
size_t rows_ = 0;
|
uint32_t rows_ = 0;
|
||||||
// Number of columns in the 2-d array (inner dimension).
|
// Number of columns in the 2-d array (inner dimension).
|
||||||
size_t cols_ = 0;
|
uint32_t cols_ = 0;
|
||||||
// Scaling to apply to each element.
|
// Scaling to apply to each element.
|
||||||
float scale_ = 1.0f;
|
float scale_ = 1.0f;
|
||||||
// Aligned data array. This is always a borrowed pointer. It should never be
|
// Aligned data array. This is always a borrowed pointer. It should never be
|
||||||
|
|
@ -202,7 +180,7 @@ class MatPtr {
|
||||||
// and must outlive this object.
|
// and must outlive this object.
|
||||||
void* ptr_ = nullptr;
|
void* ptr_ = nullptr;
|
||||||
|
|
||||||
size_t stride_;
|
uint32_t stride_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// MatPtrT adds a single template argument to MatPtr for an explicit type.
|
// MatPtrT adds a single template argument to MatPtr for an explicit type.
|
||||||
|
|
@ -394,31 +372,28 @@ class BlobToc {
|
||||||
public:
|
public:
|
||||||
BlobToc() = default;
|
BlobToc() = default;
|
||||||
|
|
||||||
// Adds all blobs to the blob writer. Note that the blobs must have unique
|
|
||||||
// names.
|
|
||||||
static void AddAllToBlobWriter(const std::vector<MatStorage>& blobs,
|
|
||||||
BlobWriter& writer) {
|
|
||||||
std::vector<hwy::uint128_t> toc;
|
|
||||||
for (const auto& blob : blobs) {
|
|
||||||
blob.AddToToc(toc);
|
|
||||||
blob.AddToWriter(writer);
|
|
||||||
}
|
|
||||||
writer.Add(MakeKey(kTocName), toc.data(), toc.size() * sizeof(toc[0]));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Loads the table of contents from the given reader.
|
// Loads the table of contents from the given reader.
|
||||||
BlobError LoadToc(BlobReader& reader) {
|
BlobError LoadToc(BlobReader& reader) {
|
||||||
hwy::uint128_t toc_key = MakeKey(kTocName);
|
hwy::uint128_t toc_key = MakeKey(kTocName);
|
||||||
size_t toc_size = reader.BlobSize(toc_key);
|
size_t toc_size = reader.BlobSize(toc_key);
|
||||||
if (toc_size != 0) {
|
if (toc_size != 0) {
|
||||||
std::vector<hwy::uint128_t> toc(toc_size / sizeof(hwy::uint128_t));
|
std::vector<uint32_t> toc(toc_size / sizeof(uint32_t));
|
||||||
BlobError err = reader.ReadOne(toc_key, toc.data(), toc_size);
|
BlobError err = reader.ReadOne(toc_key, toc.data(), toc_size);
|
||||||
if (err != 0) {
|
if (err != 0) {
|
||||||
fprintf(stderr, "Failed to read toc (error %d)\n", err);
|
fprintf(stderr, "Failed to read toc (error %d)\n", err);
|
||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < toc.size(); i += MatPtr::kNumU128InTocEntry) {
|
size_t consumed = 0;
|
||||||
AddToToc(MatPtr(toc[i], toc[i + 1], toc[i + 2], toc[i + 3]));
|
size_t prev_consumed = static_cast<size_t>(-1);
|
||||||
|
while (consumed < toc.size() && prev_consumed != consumed) {
|
||||||
|
MatPtr blob;
|
||||||
|
const IFields::ReadResult result =
|
||||||
|
blob.Read(hwy::Span<const uint32_t>(toc), consumed);
|
||||||
|
prev_consumed = consumed;
|
||||||
|
consumed = result.pos;
|
||||||
|
if (blob.NumElements() > 0) {
|
||||||
|
AddToToc(blob);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
|
|
@ -437,11 +412,16 @@ class BlobToc {
|
||||||
if (it == toc_map_.end()) return nullptr;
|
if (it == toc_map_.end()) return nullptr;
|
||||||
return &toc_[it->second];
|
return &toc_[it->second];
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
// The name of the toc in the blob store file.
|
// The name of the toc in the blob store file.
|
||||||
static constexpr char kTocName[] = "toc";
|
static constexpr char kTocName[] = "toc";
|
||||||
|
|
||||||
|
// The name of the config in the blob store file.
|
||||||
|
static constexpr char kConfigName[] = "config";
|
||||||
|
|
||||||
|
// The name of the tokenizer in the blob store file.
|
||||||
|
static constexpr char kTokenizerName[] = "tokenizer";
|
||||||
|
|
||||||
|
private:
|
||||||
// Adds the blob to the table of contents.
|
// Adds the blob to the table of contents.
|
||||||
void AddToToc(const MatPtr& blob) {
|
void AddToToc(const MatPtr& blob) {
|
||||||
HWY_ASSERT(!Contains(blob.Name()));
|
HWY_ASSERT(!Contains(blob.Name()));
|
||||||
|
|
@ -519,6 +499,68 @@ struct CompressWorkingSet {
|
||||||
std::vector<CompressPerThread> tls;
|
std::vector<CompressPerThread> tls;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Class to collect and write a set of tensors to a blob store file.
|
||||||
|
class WriteToBlobStore {
|
||||||
|
public:
|
||||||
|
explicit WriteToBlobStore(hwy::ThreadPool& pool) : pool_(pool) {}
|
||||||
|
|
||||||
|
template <typename Packed>
|
||||||
|
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name) {
|
||||||
|
if (compressed->Ptr() == nullptr) return;
|
||||||
|
writer_.Add(MakeKey(decorated_name), compressed->Ptr(),
|
||||||
|
compressed->SizeBytes());
|
||||||
|
MatPtr renamed_tensor(*compressed);
|
||||||
|
renamed_tensor.SetName(decorated_name);
|
||||||
|
renamed_tensor.AppendTo(toc_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AddTokenizer(const std::string& tokenizer) {
|
||||||
|
writer_.Add(MakeKey(BlobToc::kTokenizerName), tokenizer.data(),
|
||||||
|
tokenizer.size() * sizeof(tokenizer[0]));
|
||||||
|
}
|
||||||
|
|
||||||
|
void AddScales(const float* scales, size_t len) {
|
||||||
|
if (len) {
|
||||||
|
MatPtrT<float> scales_ptr("scales", 0, 1);
|
||||||
|
writer_.Add(MakeKey(scales_ptr.CacheName().c_str()), scales,
|
||||||
|
len * sizeof(scales[0]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Writes all blobs to disk in the given order. The config is optional and
|
||||||
|
// if given, it is written to the file, along with the TOC, making it
|
||||||
|
// single-file format. Otherwise, the file is written in the multi-file format
|
||||||
|
// without a TOC.
|
||||||
|
BlobError WriteAll(const Path& blob_filename, const ModelConfig* config) {
|
||||||
|
if (config) {
|
||||||
|
writer_.Add(MakeKey(BlobToc::kTocName), toc_.data(),
|
||||||
|
toc_.size() * sizeof(toc_[0]));
|
||||||
|
config_buffer_ = config->Write();
|
||||||
|
writer_.Add(MakeKey(BlobToc::kConfigName), config_buffer_.data(),
|
||||||
|
config_buffer_.size() * sizeof(config_buffer_[0]));
|
||||||
|
}
|
||||||
|
const BlobError err = writer_.WriteAll(pool_, blob_filename);
|
||||||
|
if (err != 0) {
|
||||||
|
fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
|
||||||
|
blob_filename.path.c_str(), err);
|
||||||
|
}
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the number of blobs added.
|
||||||
|
size_t DebugNumBlobsAdded() const { return writer_.DebugNumBlobsAdded(); }
|
||||||
|
|
||||||
|
hwy::ThreadPool& pool() { return pool_; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
hwy::ThreadPool& pool_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<uint32_t> toc_;
|
||||||
|
BlobWriter writer_;
|
||||||
|
std::vector<uint32_t> config_buffer_;
|
||||||
|
};
|
||||||
|
|
||||||
// Functor called for each tensor, which loads them and their scaling factors
|
// Functor called for each tensor, which loads them and their scaling factors
|
||||||
// from BlobStore.
|
// from BlobStore.
|
||||||
class ReadFromBlobStore {
|
class ReadFromBlobStore {
|
||||||
|
|
@ -539,11 +581,40 @@ class ReadFromBlobStore {
|
||||||
// Returns true if there is a TOC.
|
// Returns true if there is a TOC.
|
||||||
bool HaveToc() const { return !file_toc_.Empty(); }
|
bool HaveToc() const { return !file_toc_.Empty(); }
|
||||||
|
|
||||||
|
// Reads the config from the blob store file.
|
||||||
|
BlobError LoadConfig(ModelConfig& config) {
|
||||||
|
hwy::uint128_t config_key = MakeKey(BlobToc::kConfigName);
|
||||||
|
size_t config_size = reader_.BlobSize(config_key);
|
||||||
|
if (config_size == 0) return __LINE__;
|
||||||
|
std::vector<uint32_t> config_buffer(config_size / sizeof(uint32_t));
|
||||||
|
BlobError err =
|
||||||
|
reader_.ReadOne(config_key, config_buffer.data(), config_size);
|
||||||
|
if (err != 0) {
|
||||||
|
fprintf(stderr, "Failed to read config (error %d)\n", err);
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
config.Read(hwy::Span<const uint32_t>(config_buffer), 0);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reads the tokenizer from the blob store file.
|
||||||
|
BlobError LoadTokenizer(std::string& tokenizer) {
|
||||||
|
hwy::uint128_t key = MakeKey(BlobToc::kTokenizerName);
|
||||||
|
size_t tokenizer_size = reader_.BlobSize(key);
|
||||||
|
if (tokenizer_size == 0) return __LINE__;
|
||||||
|
tokenizer.resize(tokenizer_size);
|
||||||
|
;
|
||||||
|
BlobError err = reader_.ReadOne(key, tokenizer.data(), tokenizer_size);
|
||||||
|
if (err != 0) {
|
||||||
|
fprintf(stderr, "Failed to read tokenizer (error %d)\n", err);
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
// Called for each tensor, enqueues read requests.
|
// Called for each tensor, enqueues read requests.
|
||||||
void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
|
void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
|
||||||
if (file_toc_.Empty() || file_toc_.Contains(name)) {
|
if (file_toc_.Empty() || file_toc_.Contains(name)) {
|
||||||
if (tensors[0]->NumElements() == 0)
|
|
||||||
fprintf(stderr, "Zero elements for %s\n", name);
|
|
||||||
model_toc_.push_back(tensors[0]);
|
model_toc_.push_back(tensors[0]);
|
||||||
file_keys_.push_back(name);
|
file_keys_.push_back(name);
|
||||||
}
|
}
|
||||||
|
|
@ -579,12 +650,12 @@ class ReadFromBlobStore {
|
||||||
fprintf(stderr, "Blob %s has size mismatch TOC\n", file_key.c_str());
|
fprintf(stderr, "Blob %s has size mismatch TOC\n", file_key.c_str());
|
||||||
return __LINE__;
|
return __LINE__;
|
||||||
}
|
}
|
||||||
MatStorage toc_blob_array(*toc_blob);
|
std::string name = blob->Name();
|
||||||
model_memory.push_back(std::move(toc_blob_array));
|
*blob = *toc_blob;
|
||||||
} else {
|
blob->SetName(name);
|
||||||
model_memory.emplace_back(*blob);
|
|
||||||
model_memory.back().SetName(file_key);
|
|
||||||
}
|
}
|
||||||
|
model_memory.emplace_back(*blob);
|
||||||
|
model_memory.back().SetName(file_key);
|
||||||
}
|
}
|
||||||
// Allocate in parallel using the pool.
|
// Allocate in parallel using the pool.
|
||||||
pool.Run(0, model_memory.size(),
|
pool.Run(0, model_memory.size(),
|
||||||
|
|
@ -594,12 +665,12 @@ class ReadFromBlobStore {
|
||||||
});
|
});
|
||||||
// Enqueue the read requests.
|
// Enqueue the read requests.
|
||||||
for (auto& blob : model_memory) {
|
for (auto& blob : model_memory) {
|
||||||
err_ = reader_.Enqueue(MakeKey(blob.Name().c_str()), blob.data(),
|
err_ =
|
||||||
blob.SizeBytes());
|
reader_.Enqueue(MakeKey(blob.Name()), blob.data(), blob.SizeBytes());
|
||||||
if (err_ != 0) {
|
if (err_ != 0) {
|
||||||
fprintf(stderr,
|
fprintf(stderr,
|
||||||
"Failed to read blob %s (error %d) of size %zu x %zu x %zu\n",
|
"Failed to read blob %s (error %d) of size %zu x %zu x %zu\n",
|
||||||
blob.Name().c_str(), err_, blob.Rows(), blob.Cols(),
|
blob.Name(), err_, blob.Rows(), blob.Cols(),
|
||||||
blob.ElementSize());
|
blob.ElementSize());
|
||||||
return err_;
|
return err_;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@
|
||||||
// After highway.h
|
// After highway.h
|
||||||
#include "compression/compress-inl.h"
|
#include "compression/compress-inl.h"
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
|
#include "gemma/tokenizer.h"
|
||||||
|
|
||||||
#ifndef GEMMA_COMPRESS_WEIGHTS_ONCE
|
#ifndef GEMMA_COMPRESS_WEIGHTS_ONCE
|
||||||
#define GEMMA_COMPRESS_WEIGHTS_ONCE
|
#define GEMMA_COMPRESS_WEIGHTS_ONCE
|
||||||
|
|
@ -99,6 +100,9 @@ struct Args : public ArgsBase<Args> {
|
||||||
std::string model_type_str;
|
std::string model_type_str;
|
||||||
std::string weight_type_str;
|
std::string weight_type_str;
|
||||||
size_t num_threads;
|
size_t num_threads;
|
||||||
|
// Path to the tokenizer file. If non-empty, the tokenizer is written to the
|
||||||
|
// output file together with the config and TOC.
|
||||||
|
Path tokenizer;
|
||||||
|
|
||||||
template <class Visitor>
|
template <class Visitor>
|
||||||
void ForEach(const Visitor& visitor) {
|
void ForEach(const Visitor& visitor) {
|
||||||
|
|
@ -123,6 +127,9 @@ struct Args : public ArgsBase<Args> {
|
||||||
"Number of threads to use.\n Default = Estimate of the "
|
"Number of threads to use.\n Default = Estimate of the "
|
||||||
"number of supported concurrent threads.",
|
"number of supported concurrent threads.",
|
||||||
2);
|
2);
|
||||||
|
visitor(tokenizer, "tokenizer", Path(),
|
||||||
|
"Path to tokenizer file. If given, the config and TOC are also "
|
||||||
|
"added to the output file.");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Uninitialized before Validate, must call after that.
|
// Uninitialized before Validate, must call after that.
|
||||||
|
|
@ -156,7 +163,8 @@ namespace HWY_NAMESPACE {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void CompressWeights(const Path& weights_path,
|
void CompressWeights(const Path& weights_path,
|
||||||
const Path& compressed_weights_path, Model model_type,
|
const Path& compressed_weights_path, Model model_type,
|
||||||
hwy::ThreadPool& pool) {
|
Type weight_type, PromptWrapping wrapping,
|
||||||
|
const Path& tokenizer_path, hwy::ThreadPool& pool) {
|
||||||
if (!weights_path.Exists()) {
|
if (!weights_path.Exists()) {
|
||||||
HWY_ABORT("The model weights file '%s' does not exist.",
|
HWY_ABORT("The model weights file '%s' does not exist.",
|
||||||
weights_path.path.c_str());
|
weights_path.path.c_str());
|
||||||
|
|
@ -164,6 +172,8 @@ void CompressWeights(const Path& weights_path,
|
||||||
printf("Compressing weights from %s to %s\n", weights_path.path.c_str(),
|
printf("Compressing weights from %s to %s\n", weights_path.path.c_str(),
|
||||||
compressed_weights_path.path.c_str());
|
compressed_weights_path.path.c_str());
|
||||||
ModelConfig config = ConfigFromModel(model_type);
|
ModelConfig config = ConfigFromModel(model_type);
|
||||||
|
config.weight = weight_type;
|
||||||
|
config.wrapping = wrapping;
|
||||||
std::vector<MatStorage> model_storage;
|
std::vector<MatStorage> model_storage;
|
||||||
ModelWeightsPtrs<T> c_weights(config);
|
ModelWeightsPtrs<T> c_weights(config);
|
||||||
c_weights.Allocate(model_storage, pool);
|
c_weights.Allocate(model_storage, pool);
|
||||||
|
|
@ -185,6 +195,9 @@ void CompressWeights(const Path& weights_path,
|
||||||
ok &= 1 == fread(tensors[0]->Ptr(), tensors[0]->SizeBytes(), 1, fptr);
|
ok &= 1 == fread(tensors[0]->Ptr(), tensors[0]->SizeBytes(), 1, fptr);
|
||||||
total_size += tensors[0]->SizeBytes();
|
total_size += tensors[0]->SizeBytes();
|
||||||
});
|
});
|
||||||
|
if (!tokenizer_path.path.empty()) {
|
||||||
|
uc_weights.AllocAndCopyWithTranspose(pool, model_storage);
|
||||||
|
}
|
||||||
const bool scale_for_compression = config.num_tensor_scales > 0;
|
const bool scale_for_compression = config.num_tensor_scales > 0;
|
||||||
std::vector<float> scales;
|
std::vector<float> scales;
|
||||||
if (scale_for_compression) {
|
if (scale_for_compression) {
|
||||||
|
|
@ -193,14 +206,21 @@ void CompressWeights(const Path& weights_path,
|
||||||
Compressor compressor(pool);
|
Compressor compressor(pool);
|
||||||
ModelWeightsPtrs<T>::ForEachTensor(
|
ModelWeightsPtrs<T>::ForEachTensor(
|
||||||
{reinterpret_cast<ModelWeightsPtrs<T>*>(&uc_weights), &c_weights},
|
{reinterpret_cast<ModelWeightsPtrs<T>*>(&uc_weights), &c_weights},
|
||||||
ForEachType::kLoadNoToc,
|
tokenizer_path.path.empty() ? ForEachType::kLoadNoToc
|
||||||
|
: ForEachType::kLoadWithToc,
|
||||||
[&compressor](const char* name, hwy::Span<MatPtr*> tensors) {
|
[&compressor](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||||
tensors[1]->CallUpcasted(
|
tensors[1]->CallUpcasted(
|
||||||
compressor, name,
|
compressor, name,
|
||||||
reinterpret_cast<const float*>(tensors[0]->Ptr()));
|
reinterpret_cast<const float*>(tensors[0]->Ptr()));
|
||||||
});
|
});
|
||||||
compressor.AddScales(scales.data(), scales.size() * sizeof(scales[0]));
|
if (!tokenizer_path.path.empty()) {
|
||||||
compressor.WriteAll(pool, compressed_weights_path);
|
std::string tokenizer_proto = ReadFileToString(tokenizer_path);
|
||||||
|
compressor.AddTokenizer(tokenizer_proto);
|
||||||
|
} else {
|
||||||
|
compressor.AddScales(scales.data(), scales.size() * sizeof(scales[0]));
|
||||||
|
}
|
||||||
|
compressor.WriteAll(compressed_weights_path,
|
||||||
|
tokenizer_path.path.empty() ? nullptr : &config);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace HWY_NAMESPACE
|
} // namespace HWY_NAMESPACE
|
||||||
|
|
@ -220,19 +240,23 @@ void Run(Args& args) {
|
||||||
switch (weight_type) {
|
switch (weight_type) {
|
||||||
case Type::kF32:
|
case Type::kF32:
|
||||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<float>)
|
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<float>)
|
||||||
(args.weights, args.compressed_weights, model_type, pool);
|
(args.weights, args.compressed_weights, model_type, weight_type,
|
||||||
|
args.PromptWrappingType(), args.tokenizer, pool);
|
||||||
break;
|
break;
|
||||||
case Type::kBF16:
|
case Type::kBF16:
|
||||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<BF16>)
|
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<BF16>)
|
||||||
(args.weights, args.compressed_weights, model_type, pool);
|
(args.weights, args.compressed_weights, model_type, weight_type,
|
||||||
|
args.PromptWrappingType(), args.tokenizer, pool);
|
||||||
break;
|
break;
|
||||||
case Type::kSFP:
|
case Type::kSFP:
|
||||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<SfpStream>)
|
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<SfpStream>)
|
||||||
(args.weights, args.compressed_weights, model_type, pool);
|
(args.weights, args.compressed_weights, model_type, weight_type,
|
||||||
|
args.PromptWrappingType(), args.tokenizer, pool);
|
||||||
break;
|
break;
|
||||||
case Type::kNUQ:
|
case Type::kNUQ:
|
||||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<NuqStream>)
|
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<NuqStream>)
|
||||||
(args.weights, args.compressed_weights, model_type, pool);
|
(args.weights, args.compressed_weights, model_type, weight_type,
|
||||||
|
args.PromptWrappingType(), args.tokenizer, pool);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
HWY_ABORT("Weight type %d unsupported.", static_cast<int>(weight_type));
|
HWY_ABORT("Weight type %d unsupported.", static_cast<int>(weight_type));
|
||||||
|
|
|
||||||
|
|
@ -83,6 +83,14 @@ class PrintVisitor : public VisitorBase {
|
||||||
fprintf(stderr, "%sU32 %u\n", indent_.c_str(), value);
|
fprintf(stderr, "%sU32 %u\n", indent_.c_str(), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Prints a signed 32-bit field to stderr, prefixed by the current indent.
void operator()(int32_t& value) override {
  const char* const prefix = indent_.c_str();
  fprintf(stderr, "%sI32 %d\n", prefix, value);
}
|
||||||
|
|
||||||
|
// Prints an unsigned 64-bit field to stderr, prefixed by the current indent.
// `%zu` is the conversion for `size_t`, which is not guaranteed to be the
// same type as `uint64_t`; the value is therefore converted to
// `unsigned long long` (at least 64 bits) and printed with `%llu`.
void operator()(uint64_t& value) override {
  fprintf(stderr, "%sU64 %llu\n", indent_.c_str(),
          static_cast<unsigned long long>(value));
}
|
||||||
|
|
||||||
void operator()(float& value) override {
|
void operator()(float& value) override {
|
||||||
fprintf(stderr, "%sF32 %f\n", indent_.c_str(), value);
|
fprintf(stderr, "%sF32 %f\n", indent_.c_str(), value);
|
||||||
}
|
}
|
||||||
|
|
@ -120,6 +128,21 @@ class ReadVisitor : public VisitorBase {
|
||||||
value = span_[result_.pos++];
|
value = span_[result_.pos++];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reads the next element of `span_` into a signed 32-bit field and advances
// the read position. `value` is left untouched if `SkipField()` returns true.
void operator()(int32_t& value) override {
  if (HWY_UNLIKELY(SkipField())) return;

  const auto word = span_[result_.pos++];
  value = static_cast<int32_t>(word);
}
|
||||||
|
|
||||||
|
// Reads an unsigned 64-bit field as two 32-bit reads: lower half first, then
// upper half. Both halves start out as the corresponding bits of the current
// `value`, so a half that the 32-bit overload does not overwrite keeps its
// previous contents. `value` is left untouched if `SkipField()` returns true.
void operator()(uint64_t& value) override {
  if (HWY_UNLIKELY(SkipField())) return;

  const uint64_t previous = value;
  uint32_t lower = static_cast<uint32_t>(previous);
  uint32_t upper = static_cast<uint32_t>(previous >> 32);
  operator()(lower);
  operator()(upper);
  value = (static_cast<uint64_t>(upper) << 32) | lower;
}
|
||||||
|
|
||||||
void operator()(float& value) override {
|
void operator()(float& value) override {
|
||||||
if (HWY_UNLIKELY(SkipField())) return;
|
if (HWY_UNLIKELY(SkipField())) return;
|
||||||
|
|
||||||
|
|
@ -229,6 +252,15 @@ class WriteVisitor : public VisitorBase {
|
||||||
|
|
||||||
void operator()(uint32_t& value) override { storage_.push_back(value); }
|
void operator()(uint32_t& value) override { storage_.push_back(value); }
|
||||||
|
|
||||||
|
// Appends a signed 32-bit field to `storage_`, converted to `uint32_t`.
void operator()(int32_t& value) override {
  const uint32_t bits = static_cast<uint32_t>(value);
  storage_.push_back(bits);
}
|
||||||
|
|
||||||
|
// Appends an unsigned 64-bit field to `storage_` as two 32-bit entries:
// the lower 32 bits first, followed by the upper 32 bits.
void operator()(uint64_t& value) override {
  const uint32_t lower = static_cast<uint32_t>(value);
  const uint32_t upper = static_cast<uint32_t>(value >> 32);
  storage_.push_back(lower);
  storage_.push_back(upper);
}
|
||||||
|
|
||||||
void operator()(float& value) override {
|
void operator()(float& value) override {
|
||||||
storage_.push_back(hwy::BitCastScalar<uint32_t>(value));
|
storage_.push_back(hwy::BitCastScalar<uint32_t>(value));
|
||||||
CheckF32(value);
|
CheckF32(value);
|
||||||
|
|
|
||||||
|
|
@ -55,8 +55,9 @@ class IFields; // breaks circular dependency
|
||||||
// Visitors are internal-only, but their base class is visible to user code
|
// Visitors are internal-only, but their base class is visible to user code
|
||||||
// because their `IFields::VisitFields` calls `visitor.operator()`.
|
// because their `IFields::VisitFields` calls `visitor.operator()`.
|
||||||
//
|
//
|
||||||
// Supported field types `T`: `uint32_t`, `float`, `std::string`, classes
|
// Supported field types `T`: `uint32_t`, `int32_t`, `uint64_t`, `float`,
|
||||||
// derived from `IFields`, `bool`, `enum`, `std::vector<T>`.
|
// `std::string`,
|
||||||
|
// classes derived from `IFields`, `bool`, `enum`, `std::vector<T>`.
|
||||||
class IFieldsVisitor {
|
class IFieldsVisitor {
|
||||||
public:
|
public:
|
||||||
virtual ~IFieldsVisitor();
|
virtual ~IFieldsVisitor();
|
||||||
|
|
@ -69,6 +70,8 @@ class IFieldsVisitor {
|
||||||
// is out of range. A single generic/overloaded function is required to
|
// is out of range. A single generic/overloaded function is required to
|
||||||
// support `std::vector<T>`.
|
// support `std::vector<T>`.
|
||||||
virtual void operator()(uint32_t& value) = 0;
|
virtual void operator()(uint32_t& value) = 0;
|
||||||
|
virtual void operator()(int32_t& value) = 0;
|
||||||
|
virtual void operator()(uint64_t& value) = 0;
|
||||||
virtual void operator()(float& value) = 0;
|
virtual void operator()(float& value) = 0;
|
||||||
virtual void operator()(std::string& value) = 0;
|
virtual void operator()(std::string& value) = 0;
|
||||||
virtual void operator()(IFields& fields) = 0; // recurse into nested fields
|
virtual void operator()(IFields& fields) = 0; // recurse into nested fields
|
||||||
|
|
@ -92,7 +95,7 @@ class IFieldsVisitor {
|
||||||
uint32_t u32 = static_cast<uint32_t>(value);
|
uint32_t u32 = static_cast<uint32_t>(value);
|
||||||
operator()(u32);
|
operator()(u32);
|
||||||
if (HWY_UNLIKELY(!EnumValid(static_cast<EnumT>(u32)))) {
|
if (HWY_UNLIKELY(!EnumValid(static_cast<EnumT>(u32)))) {
|
||||||
return NotifyInvalid("Invalid enum %u\n");
|
return NotifyInvalid("Invalid enum %u\n", u32);
|
||||||
}
|
}
|
||||||
value = static_cast<EnumT>(u32);
|
value = static_cast<EnumT>(u32);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -97,6 +97,8 @@ struct OldFields : public IFields {
|
||||||
visitor(old_str);
|
visitor(old_str);
|
||||||
visitor(old_nested);
|
visitor(old_nested);
|
||||||
visitor(old1);
|
visitor(old1);
|
||||||
|
visitor(oldi);
|
||||||
|
visitor(oldl);
|
||||||
visitor(old_vec_str);
|
visitor(old_vec_str);
|
||||||
visitor(old_vec_nested);
|
visitor(old_vec_nested);
|
||||||
visitor(old_f);
|
visitor(old_f);
|
||||||
|
|
@ -110,6 +112,8 @@ struct OldFields : public IFields {
|
||||||
EXPECT_EQ(old_str, n.old_str);
|
EXPECT_EQ(old_str, n.old_str);
|
||||||
old_nested.CheckEqual(n.old_nested);
|
old_nested.CheckEqual(n.old_nested);
|
||||||
EXPECT_EQ(old1, n.old1);
|
EXPECT_EQ(old1, n.old1);
|
||||||
|
EXPECT_EQ(oldi, n.oldi);
|
||||||
|
EXPECT_EQ(oldl, n.oldl);
|
||||||
CheckVectorEqual(old_vec_str, n.old_vec_str);
|
CheckVectorEqual(old_vec_str, n.old_vec_str);
|
||||||
CheckVectorEqual(old_vec_nested, n.old_vec_nested);
|
CheckVectorEqual(old_vec_nested, n.old_vec_nested);
|
||||||
EXPECT_EQ(old_f, n.old_f);
|
EXPECT_EQ(old_f, n.old_f);
|
||||||
|
|
@ -120,6 +124,8 @@ struct OldFields : public IFields {
|
||||||
std::string old_str = "old";
|
std::string old_str = "old";
|
||||||
Nested old_nested = Nested(0);
|
Nested old_nested = Nested(0);
|
||||||
uint32_t old1 = 1;
|
uint32_t old1 = 1;
|
||||||
|
int32_t oldi = -1;
|
||||||
|
uint64_t oldl = 1234567890123456789;
|
||||||
std::vector<std::string> old_vec_str = {"abc", "1234"};
|
std::vector<std::string> old_vec_str = {"abc", "1234"};
|
||||||
std::vector<Nested> old_vec_nested = {Nested(1), Nested(4)};
|
std::vector<Nested> old_vec_nested = {Nested(1), Nested(4)};
|
||||||
float old_f = 1.125f;
|
float old_f = 1.125f;
|
||||||
|
|
@ -134,6 +140,8 @@ struct NewFields : public IFields {
|
||||||
visitor(old_str);
|
visitor(old_str);
|
||||||
visitor(old_nested);
|
visitor(old_nested);
|
||||||
visitor(old1);
|
visitor(old1);
|
||||||
|
visitor(oldi);
|
||||||
|
visitor(oldl);
|
||||||
visitor(old_vec_str);
|
visitor(old_vec_str);
|
||||||
visitor(old_vec_nested);
|
visitor(old_vec_nested);
|
||||||
visitor(old_f);
|
visitor(old_f);
|
||||||
|
|
@ -149,6 +157,8 @@ struct NewFields : public IFields {
|
||||||
visitor(new_enum);
|
visitor(new_enum);
|
||||||
visitor(new2);
|
visitor(new2);
|
||||||
visitor(new_str);
|
visitor(new_str);
|
||||||
|
visitor(new_i);
|
||||||
|
visitor(new_l);
|
||||||
}
|
}
|
||||||
|
|
||||||
void CheckEqual(const NewFields& n) const {
|
void CheckEqual(const NewFields& n) const {
|
||||||
|
|
@ -176,6 +186,8 @@ struct NewFields : public IFields {
|
||||||
std::string old_str = "old";
|
std::string old_str = "old";
|
||||||
Nested old_nested = Nested(0);
|
Nested old_nested = Nested(0);
|
||||||
uint32_t old1 = 1;
|
uint32_t old1 = 1;
|
||||||
|
int32_t oldi = -1;
|
||||||
|
uint64_t oldl = 1234567890123456789;
|
||||||
std::vector<std::string> old_vec_str = {"abc", "1234"};
|
std::vector<std::string> old_vec_str = {"abc", "1234"};
|
||||||
std::vector<Nested> old_vec_nested = {Nested(1), Nested(4)};
|
std::vector<Nested> old_vec_nested = {Nested(1), Nested(4)};
|
||||||
float old_f = 1.125f;
|
float old_f = 1.125f;
|
||||||
|
|
@ -190,6 +202,8 @@ struct NewFields : public IFields {
|
||||||
Enum new_enum = Enum::k3;
|
Enum new_enum = Enum::k3;
|
||||||
uint32_t new2 = 2;
|
uint32_t new2 = 2;
|
||||||
std::string new_str = std::string(); // empty is allowed
|
std::string new_str = std::string(); // empty is allowed
|
||||||
|
int32_t new_i = 123456789;
|
||||||
|
uint64_t new_l = 876543210987654321;
|
||||||
}; // NewFields
|
}; // NewFields
|
||||||
|
|
||||||
// Changes all fields to non-default values.
|
// Changes all fields to non-default values.
|
||||||
|
|
@ -212,6 +226,8 @@ NewFields ModifiedNewFields() {
|
||||||
n.new_enum = Enum::k8;
|
n.new_enum = Enum::k8;
|
||||||
n.new2 = 22;
|
n.new2 = 22;
|
||||||
n.new_str = "new and even longer";
|
n.new_str = "new and even longer";
|
||||||
|
n.new_i = 246810121;
|
||||||
|
n.new_l = 1357913579113579135;
|
||||||
|
|
||||||
return n;
|
return n;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "evals/benchmark_helper.h"
|
||||||
|
#include "gemma/gemma.h"
|
||||||
|
#include "util/args.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Command-line arguments specific to this tool: the output weights path.
struct WriterArgs : public ArgsBase<WriterArgs> {
  // Parses the command line. --output_weights is required; see `Validate`.
  WriterArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

  // Returns an error string, or nullptr if the required flag was provided.
  const char* Validate() {
    if (!output_weights.path.empty()) return nullptr;
    return "Missing --output_weights flag, a file for the model weights.";
  }

  Path output_weights;  // weights file location

  // Presents each argument to `visitor` together with its flag name,
  // default value and help text.
  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(output_weights, "output_weights", Path(),
            "Path name of output weights (.sbs) file.\n Required argument.");
  }
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
// Loads a model in the multi-file format and saves it in single-file format.
|
||||||
|
gcpp::WriterArgs args(argc, argv);
|
||||||
|
if (const char* err = args.Validate()) {
|
||||||
|
fprintf(stderr, "Skipping model load because: %s\n", err);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
gcpp::GemmaEnv env(argc, argv, /*required=*/true);
|
||||||
|
hwy::ThreadPool pool(0);
|
||||||
|
env.GetModel()->Save(args.output_weights, pool);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
@ -16,6 +16,7 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
"@abseil-cpp//absl/types:span",
|
"@abseil-cpp//absl/types:span",
|
||||||
"//:common",
|
"//:common",
|
||||||
|
"//:tokenizer",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,9 @@
|
||||||
|
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "compression/io.h"
|
#include "compression/io.h"
|
||||||
|
#include "gemma/configs.h"
|
||||||
#include "gemma/tensor_index.h"
|
#include "gemma/tensor_index.h"
|
||||||
|
#include "gemma/tokenizer.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
|
|
@ -44,10 +46,11 @@ class WriterInterface {
|
||||||
virtual void InsertFloat(std::string name,
|
virtual void InsertFloat(std::string name,
|
||||||
absl::Span<const float> weights) = 0;
|
absl::Span<const float> weights) = 0;
|
||||||
virtual void AddScales(const std::vector<float>& scales) = 0;
|
virtual void AddScales(const std::vector<float>& scales) = 0;
|
||||||
|
virtual void AddTokenizer(const std::string& tokenizer_path) = 0;
|
||||||
|
|
||||||
virtual size_t DebugNumBlobsAdded() const = 0;
|
virtual size_t DebugNumBlobsAdded() const = 0;
|
||||||
|
|
||||||
virtual int Write(std::string path) = 0;
|
virtual int WriteWithConfig(std::string path, const ModelConfig* config) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
@ -133,14 +136,21 @@ class SbsWriterImpl : public WriterInterface {
|
||||||
compressor_.AddScales(scales_.data(), scales_.size());
|
compressor_.AddScales(scales_.data(), scales_.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AddTokenizer(const std::string& tokenizer_path) override {
|
||||||
|
Path path(tokenizer_path);
|
||||||
|
GemmaTokenizer tokenizer(path);
|
||||||
|
tokenizer_proto_ = tokenizer.Serialize();
|
||||||
|
compressor_.AddTokenizer(tokenizer_proto_);
|
||||||
|
}
|
||||||
|
|
||||||
// Returns the number of blobs added.
|
// Returns the number of blobs added.
|
||||||
size_t DebugNumBlobsAdded() const {
|
size_t DebugNumBlobsAdded() const {
|
||||||
if (mode_ == CompressorMode::kTEST_ONLY) return model_memory_.size();
|
if (mode_ == CompressorMode::kTEST_ONLY) return model_memory_.size();
|
||||||
return compressor_.DebugNumBlobsAdded();
|
return compressor_.DebugNumBlobsAdded();
|
||||||
}
|
}
|
||||||
|
|
||||||
int Write(std::string path) override {
|
int WriteWithConfig(std::string path, const ModelConfig* config) override {
|
||||||
return compressor_.WriteAll(pool_, gcpp::Path(path));
|
return compressor_.WriteAll(gcpp::Path(path), config);
|
||||||
}
|
}
|
||||||
|
|
||||||
hwy::ThreadPool pool_;
|
hwy::ThreadPool pool_;
|
||||||
|
|
@ -149,6 +159,7 @@ class SbsWriterImpl : public WriterInterface {
|
||||||
std::vector<MatStorage> model_memory_;
|
std::vector<MatStorage> model_memory_;
|
||||||
std::vector<float> scales_;
|
std::vector<float> scales_;
|
||||||
CompressorMode mode_;
|
CompressorMode mode_;
|
||||||
|
std::string tokenizer_proto_;
|
||||||
};
|
};
|
||||||
|
|
||||||
WriterInterface* NewSbsWriter(CompressorMode mode) {
|
WriterInterface* NewSbsWriter(CompressorMode mode) {
|
||||||
|
|
@ -190,11 +201,17 @@ void SbsWriter::AddScales(const std::vector<float>& scales) {
|
||||||
impl_->AddScales(scales);
|
impl_->AddScales(scales);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SbsWriter::AddTokenizer(const std::string& tokenizer_path) {
|
||||||
|
impl_->AddTokenizer(tokenizer_path);
|
||||||
|
}
|
||||||
|
|
||||||
size_t SbsWriter::DebugNumBlobsAdded() const {
|
size_t SbsWriter::DebugNumBlobsAdded() const {
|
||||||
return impl_->DebugNumBlobsAdded();
|
return impl_->DebugNumBlobsAdded();
|
||||||
}
|
}
|
||||||
|
|
||||||
int SbsWriter::Write(std::string path) { return impl_->Write(path); }
|
int SbsWriter::WriteWithConfig(std::string path, const ModelConfig* config) {
|
||||||
|
return impl_->WriteWithConfig(path, config);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
#endif // HWY_ONCE
|
#endif // HWY_ONCE
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@
|
||||||
|
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "compression/shared.h"
|
#include "compression/shared.h"
|
||||||
|
#include "gemma/configs.h"
|
||||||
#include "gemma/tensor_index.h"
|
#include "gemma/tensor_index.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
@ -36,10 +37,12 @@ class SbsWriter {
|
||||||
void InsertBfloat16(std::string name, absl::Span<const float> weights);
|
void InsertBfloat16(std::string name, absl::Span<const float> weights);
|
||||||
void InsertFloat(std::string name, absl::Span<const float> weights);
|
void InsertFloat(std::string name, absl::Span<const float> weights);
|
||||||
void AddScales(const std::vector<float>& scales);
|
void AddScales(const std::vector<float>& scales);
|
||||||
|
void AddTokenizer(const std::string& tokenizer_path);
|
||||||
|
|
||||||
size_t DebugNumBlobsAdded() const;
|
size_t DebugNumBlobsAdded() const;
|
||||||
|
|
||||||
int Write(std::string path);
|
int Write(std::string path) { return WriteWithConfig(path, nullptr); }
|
||||||
|
int WriteWithConfig(std::string path, const ModelConfig* config);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Isolates Highway-dispatched types and other internals from CLIF.
|
// Isolates Highway-dispatched types and other internals from CLIF.
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,8 @@ PYBIND11_MODULE(compression, m) {
|
||||||
.def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>)
|
.def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>)
|
||||||
.def("insert_float", wrap_span<&SbsWriter::InsertFloat>)
|
.def("insert_float", wrap_span<&SbsWriter::InsertFloat>)
|
||||||
.def("add_scales", &SbsWriter::AddScales)
|
.def("add_scales", &SbsWriter::AddScales)
|
||||||
|
.def("add_tokenizer", &SbsWriter::AddTokenizer)
|
||||||
.def("debug_num_blobs_added", &SbsWriter::DebugNumBlobsAdded)
|
.def("debug_num_blobs_added", &SbsWriter::DebugNumBlobsAdded)
|
||||||
.def("write", &SbsWriter::Write);
|
.def("write", &SbsWriter::Write)
|
||||||
|
.def("write_with_config", &SbsWriter::WriteWithConfig);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -198,6 +198,11 @@ constexpr bool IsNuqStream() {
|
||||||
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
|
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
|
||||||
enum class PromptWrapping { GEMMA_IT, GEMMA_PT, PALIGEMMA };
|
enum class PromptWrapping { GEMMA_IT, GEMMA_PT, PALIGEMMA };
|
||||||
|
|
||||||
|
// Returns true if `type`, viewed as an int, lies within
// [0, PromptWrapping::PALIGEMMA].
inline bool EnumValid(PromptWrapping type) {
  const int value = static_cast<int>(type);
  constexpr int kLast = static_cast<int>(PromptWrapping::PALIGEMMA);
  return 0 <= value && value <= kLast;
}
|
||||||
|
|
||||||
// Tensor types for loading weights. Note that not all types are supported as
|
// Tensor types for loading weights. Note that not all types are supported as
|
||||||
// weights for a model, but can be used for other purposes, such as types for
|
// weights for a model, but can be used for other purposes, such as types for
|
||||||
// ModelWeightsPtrs. When adding a new type that is supported, also
|
// ModelWeightsPtrs. When adding a new type that is supported, also
|
||||||
|
|
@ -206,6 +211,11 @@ enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 };
|
||||||
constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
|
constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
|
||||||
"nuq", "f64", "c64", "u128"};
|
"nuq", "f64", "c64", "u128"};
|
||||||
|
|
||||||
|
// Returns true if `type`, viewed as an int, lies within [0, Type::kU128].
inline bool EnumValid(Type type) {
  const int value = static_cast<int>(type);
  constexpr int kLast = static_cast<int>(Type::kU128);
  return 0 <= value && value <= kLast;
}
|
||||||
|
|
||||||
// Returns a Type enum for the type of the template parameter.
|
// Returns a Type enum for the type of the template parameter.
|
||||||
template <typename PackedT>
|
template <typename PackedT>
|
||||||
Type TypeEnum() {
|
Type TypeEnum() {
|
||||||
|
|
|
||||||
|
|
@ -92,9 +92,9 @@ static AppArgs MakeAppArgs(int argc, char** argv) {
|
||||||
return AppArgs(argc, argv);
|
return AppArgs(argc, argv);
|
||||||
}
|
}
|
||||||
|
|
||||||
GemmaEnv::GemmaEnv(int argc, char** argv)
|
GemmaEnv::GemmaEnv(int argc, char** argv, bool model_type_required)
|
||||||
: GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
|
: GemmaEnv(LoaderArgs(argc, argv, model_type_required),
|
||||||
MakeAppArgs(argc, argv)) {}
|
InferenceArgs(argc, argv), MakeAppArgs(argc, argv)) {}
|
||||||
|
|
||||||
QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
|
QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
|
||||||
QueryResult result;
|
QueryResult result;
|
||||||
|
|
@ -270,7 +270,9 @@ void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||||
"specify 3 required model loading arguments:\n"
|
"specify 3 required model loading arguments:\n"
|
||||||
" --tokenizer\n"
|
" --tokenizer\n"
|
||||||
" --weights\n"
|
" --weights\n"
|
||||||
" --model.\n";
|
" --model,\n"
|
||||||
|
" or with the newer weights format, specify just:\n"
|
||||||
|
" --weights\n";
|
||||||
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
|
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
|
||||||
"--weights 2b-it-sfp.sbs --model 2b-it\n";
|
"--weights 2b-it-sfp.sbs --model 2b-it\n";
|
||||||
std::cerr << "\n*Model Loading Arguments*\n\n";
|
std::cerr << "\n*Model Loading Arguments*\n\n";
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ struct QueryResult {
|
||||||
class GemmaEnv {
|
class GemmaEnv {
|
||||||
public:
|
public:
|
||||||
// Calls the other constructor with *Args arguments initialized from argv.
|
// Calls the other constructor with *Args arguments initialized from argv.
|
||||||
GemmaEnv(int argc, char** argv);
|
GemmaEnv(int argc, char** argv, bool model_type_required = false);
|
||||||
GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
|
GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||||
const AppArgs& app);
|
const AppArgs& app);
|
||||||
|
|
||||||
|
|
|
||||||
177
gemma/configs.cc
177
gemma/configs.cc
|
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
@ -22,9 +23,9 @@
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
static ModelConfig ConfigNoSSM() {
|
static ModelConfig ConfigNoSSM() {
|
||||||
ModelConfig config = {.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w",
|
ModelConfig config;
|
||||||
"gr_lin_y_w", "gr_lin_out_w",
|
config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
|
||||||
"gr_gate_w", "gating_ein", "linear_w"}};
|
"gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -37,6 +38,18 @@ static ModelConfig ConfigBaseGemmaV2() {
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds the per-layer configuration that ConfigGemma2_27B replicates across
// its layers. `model_dim` is supplied by the caller.
static LayerConfig LayerConfigGemma2_27B(size_t model_dim) {
  LayerConfig layer;
  layer.model_dim = model_dim;
  // Attention shape.
  layer.heads = 32;
  layer.kv_heads = 16;
  layer.qkv_dim = 128;
  // Feed-forward width.
  layer.ff_hidden_dim = 16 * 4608 / 2;  // = 36864
  layer.optimized_gating = false;
  layer.post_norm = PostNormType::Scale;
  return layer;
}
|
||||||
|
|
||||||
static ModelConfig ConfigGemma2_27B() {
|
static ModelConfig ConfigGemma2_27B() {
|
||||||
ModelConfig config = ConfigBaseGemmaV2();
|
ModelConfig config = ConfigBaseGemmaV2();
|
||||||
config.model_name = "Gemma2_27B";
|
config.model_name = "Gemma2_27B";
|
||||||
|
|
@ -44,13 +57,7 @@ static ModelConfig ConfigGemma2_27B() {
|
||||||
config.model_dim = 4608;
|
config.model_dim = 4608;
|
||||||
config.vocab_size = kVocabSize;
|
config.vocab_size = kVocabSize;
|
||||||
config.seq_len = 8192;
|
config.seq_len = 8192;
|
||||||
LayerConfig layer_config = {.model_dim = config.model_dim,
|
LayerConfig layer_config = LayerConfigGemma2_27B(config.model_dim);
|
||||||
.ff_hidden_dim = 16 * 4608 / 2, // = 36864
|
|
||||||
.heads = 32,
|
|
||||||
.kv_heads = 16,
|
|
||||||
.qkv_dim = 128,
|
|
||||||
.optimized_gating = false,
|
|
||||||
.post_norm = PostNormType::Scale};
|
|
||||||
config.layer_configs = {46, layer_config};
|
config.layer_configs = {46, layer_config};
|
||||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||||
config.query_scale = QueryScaleType::SqrtModelDimDivNumHeads;
|
config.query_scale = QueryScaleType::SqrtModelDimDivNumHeads;
|
||||||
|
|
@ -59,6 +66,18 @@ static ModelConfig ConfigGemma2_27B() {
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds the per-layer configuration that ConfigGemma2_9B replicates across
// its layers. `model_dim` is supplied by the caller.
static LayerConfig LayerConfigGemma2_9B(size_t model_dim) {
  LayerConfig layer;
  layer.model_dim = model_dim;
  // Attention shape.
  layer.heads = 16;
  layer.kv_heads = 8;
  layer.qkv_dim = 256;
  // Feed-forward width.
  layer.ff_hidden_dim = 8 * 3584 / 2;  // = 14336
  layer.optimized_gating = false;
  layer.post_norm = PostNormType::Scale;
  return layer;
}
|
||||||
|
|
||||||
static ModelConfig ConfigGemma2_9B() {
|
static ModelConfig ConfigGemma2_9B() {
|
||||||
ModelConfig config = ConfigBaseGemmaV2();
|
ModelConfig config = ConfigBaseGemmaV2();
|
||||||
config.model_name = "Gemma2_9B";
|
config.model_name = "Gemma2_9B";
|
||||||
|
|
@ -66,13 +85,7 @@ static ModelConfig ConfigGemma2_9B() {
|
||||||
config.model_dim = 3584;
|
config.model_dim = 3584;
|
||||||
config.vocab_size = kVocabSize;
|
config.vocab_size = kVocabSize;
|
||||||
config.seq_len = 8192;
|
config.seq_len = 8192;
|
||||||
LayerConfig layer_config = {.model_dim = config.model_dim,
|
LayerConfig layer_config = LayerConfigGemma2_9B(config.model_dim);
|
||||||
.ff_hidden_dim = 8 * 3584 / 2, // = 14336
|
|
||||||
.heads = 16,
|
|
||||||
.kv_heads = 8,
|
|
||||||
.qkv_dim = 256,
|
|
||||||
.optimized_gating = false,
|
|
||||||
.post_norm = PostNormType::Scale};
|
|
||||||
config.layer_configs = {42, layer_config};
|
config.layer_configs = {42, layer_config};
|
||||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||||
|
|
@ -81,6 +94,18 @@ static ModelConfig ConfigGemma2_9B() {
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds the per-layer configuration that ConfigGemma2_2B replicates across
// its layers. `model_dim` is supplied by the caller.
static LayerConfig LayerConfigGemma2_2B(size_t model_dim) {
  LayerConfig layer;
  layer.model_dim = model_dim;
  // Attention shape.
  layer.heads = 8;
  layer.kv_heads = 4;
  layer.qkv_dim = 256;
  // Feed-forward width.
  layer.ff_hidden_dim = 8 * 2304 / 2;  // = 9216
  layer.optimized_gating = false;
  layer.post_norm = PostNormType::Scale;
  return layer;
}
|
||||||
|
|
||||||
static ModelConfig ConfigGemma2_2B() {
|
static ModelConfig ConfigGemma2_2B() {
|
||||||
ModelConfig config = ConfigBaseGemmaV2();
|
ModelConfig config = ConfigBaseGemmaV2();
|
||||||
config.model_name = "Gemma2_2B";
|
config.model_name = "Gemma2_2B";
|
||||||
|
|
@ -88,13 +113,7 @@ static ModelConfig ConfigGemma2_2B() {
|
||||||
config.model_dim = 2304;
|
config.model_dim = 2304;
|
||||||
config.vocab_size = kVocabSize;
|
config.vocab_size = kVocabSize;
|
||||||
config.seq_len = 8192;
|
config.seq_len = 8192;
|
||||||
LayerConfig layer_config = {.model_dim = config.model_dim,
|
LayerConfig layer_config = LayerConfigGemma2_2B(config.model_dim);
|
||||||
.ff_hidden_dim = 8 * 2304 / 2, // = 9216
|
|
||||||
.heads = 8,
|
|
||||||
.kv_heads = 4,
|
|
||||||
.qkv_dim = 256,
|
|
||||||
.optimized_gating = false,
|
|
||||||
.post_norm = PostNormType::Scale};
|
|
||||||
config.layer_configs = {26, layer_config};
|
config.layer_configs = {26, layer_config};
|
||||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||||
|
|
@ -103,6 +122,16 @@ static ModelConfig ConfigGemma2_2B() {
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds the per-layer configuration that ConfigGemma7B replicates across
// its layers. `model_dim` is supplied by the caller; fields not assigned
// here keep whatever a default-constructed LayerConfig holds.
static LayerConfig LayerConfigGemma7B(size_t model_dim) {
  LayerConfig layer;
  layer.model_dim = model_dim;
  // Attention shape.
  layer.heads = 16;
  layer.kv_heads = 16;
  layer.qkv_dim = 256;
  // Feed-forward width.
  layer.ff_hidden_dim = 16 * 3072 / 2;  // = 24576
  return layer;
}
|
||||||
|
|
||||||
static ModelConfig ConfigGemma7B() {
|
static ModelConfig ConfigGemma7B() {
|
||||||
ModelConfig config = ConfigBaseGemmaV1();
|
ModelConfig config = ConfigBaseGemmaV1();
|
||||||
config.model_name = "Gemma7B";
|
config.model_name = "Gemma7B";
|
||||||
|
|
@ -110,13 +139,7 @@ static ModelConfig ConfigGemma7B() {
|
||||||
config.model_dim = 3072;
|
config.model_dim = 3072;
|
||||||
config.vocab_size = kVocabSize;
|
config.vocab_size = kVocabSize;
|
||||||
config.seq_len = kSeqLen;
|
config.seq_len = kSeqLen;
|
||||||
LayerConfig layer_config = {
|
LayerConfig layer_config = LayerConfigGemma7B(config.model_dim);
|
||||||
.model_dim = config.model_dim,
|
|
||||||
.ff_hidden_dim = 16 * 3072 / 2, // = 24576
|
|
||||||
.heads = 16,
|
|
||||||
.kv_heads = 16,
|
|
||||||
.qkv_dim = 256,
|
|
||||||
};
|
|
||||||
config.layer_configs = {28, layer_config};
|
config.layer_configs = {28, layer_config};
|
||||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||||
|
|
@ -124,6 +147,16 @@ static ModelConfig ConfigGemma7B() {
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds the per-layer settings used for Gemma 2B layers.
// `model_dim` is stored as the layer's model dimension.
static LayerConfig LayerConfigGemma2B(size_t model_dim) {
  LayerConfig layer;
  layer.model_dim = model_dim;
  layer.qkv_dim = 256;
  layer.heads = 8;
  layer.kv_heads = 1;  // all query heads share one KV head
  layer.ff_hidden_dim = 16 * 2048 / 2;  // = 16384
  return layer;
}
|
||||||
|
|
||||||
static ModelConfig ConfigGemma2B() {
|
static ModelConfig ConfigGemma2B() {
|
||||||
ModelConfig config = ConfigBaseGemmaV1();
|
ModelConfig config = ConfigBaseGemmaV1();
|
||||||
config.model_name = "Gemma2B";
|
config.model_name = "Gemma2B";
|
||||||
|
|
@ -131,19 +164,23 @@ static ModelConfig ConfigGemma2B() {
|
||||||
config.model_dim = 2048;
|
config.model_dim = 2048;
|
||||||
config.vocab_size = kVocabSize;
|
config.vocab_size = kVocabSize;
|
||||||
config.seq_len = kSeqLen;
|
config.seq_len = kSeqLen;
|
||||||
LayerConfig layer_config = {
|
LayerConfig layer_config = LayerConfigGemma2B(config.model_dim);
|
||||||
.model_dim = config.model_dim,
|
|
||||||
.ff_hidden_dim = 16 * 2048 / 2, // = 16384
|
|
||||||
.heads = 8,
|
|
||||||
.kv_heads = 1,
|
|
||||||
.qkv_dim = 256,
|
|
||||||
};
|
|
||||||
config.layer_configs = {18, layer_config};
|
config.layer_configs = {18, layer_config};
|
||||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||||
config.attention_window_sizes = FixedAttentionWindowSizes<18>(kSeqLen);
|
config.attention_window_sizes = FixedAttentionWindowSizes<18>(kSeqLen);
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds the per-layer settings used for the GemmaTiny model.
// `model_dim` is stored as the layer's model dimension.
static LayerConfig LayerConfigGemmaTiny(size_t model_dim) {
  LayerConfig layer;
  layer.model_dim = model_dim;
  layer.qkv_dim = 16;
  layer.heads = 4;
  layer.kv_heads = 1;
  layer.ff_hidden_dim = 256;
  return layer;
}
|
||||||
|
|
||||||
static ModelConfig ConfigGemmaTiny() {
|
static ModelConfig ConfigGemmaTiny() {
|
||||||
ModelConfig config = ConfigNoSSM();
|
ModelConfig config = ConfigNoSSM();
|
||||||
config.model_name = "GemmaTiny";
|
config.model_name = "GemmaTiny";
|
||||||
|
|
@ -151,13 +188,7 @@ static ModelConfig ConfigGemmaTiny() {
|
||||||
config.model_dim = 128;
|
config.model_dim = 128;
|
||||||
config.vocab_size = 64;
|
config.vocab_size = 64;
|
||||||
config.seq_len = 32;
|
config.seq_len = 32;
|
||||||
LayerConfig layer_config = {
|
LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim);
|
||||||
.model_dim = config.model_dim,
|
|
||||||
.ff_hidden_dim = 256,
|
|
||||||
.heads = 4,
|
|
||||||
.kv_heads = 1,
|
|
||||||
.qkv_dim = 16,
|
|
||||||
};
|
|
||||||
config.layer_configs = {3, layer_config};
|
config.layer_configs = {3, layer_config};
|
||||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||||
|
|
@ -167,6 +198,24 @@ static ModelConfig ConfigGemmaTiny() {
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds the base per-layer settings for Griffin 2B. The returned layer is a
// Griffin recurrent block; callers may override `type` for individual layers.
static LayerConfig LayerConfigGriffin2B(size_t model_dim) {
  LayerConfig layer;

  // Dimensions. The Griffin dimension equals the model dimension.
  layer.model_dim = model_dim;
  layer.griffin_dim = model_dim;
  layer.ff_hidden_dim = 7680;
  layer.heads = 10;
  layer.kv_heads = 1;
  layer.qkv_dim = 256;
  layer.conv1d_width = 4;

  // Flags.
  layer.ff_biases = true;
  layer.softmax_attn_output_biases = true;
  layer.optimized_gating = false;

  // Layer kind and operation types.
  layer.type = LayerAttentionType::kGriffinRecurrentBlock;
  layer.activation = ActivationType::Gelu;
  layer.post_qk = PostQKType::HalfRope;
  return layer;
}
|
||||||
|
|
||||||
static ModelConfig ConfigGriffin2B() {
|
static ModelConfig ConfigGriffin2B() {
|
||||||
ModelConfig config = ConfigNoSSM();
|
ModelConfig config = ConfigNoSSM();
|
||||||
config.model_name = "Griffin2B";
|
config.model_name = "Griffin2B";
|
||||||
|
|
@ -176,21 +225,7 @@ static ModelConfig ConfigGriffin2B() {
|
||||||
config.model_dim = 2560;
|
config.model_dim = 2560;
|
||||||
config.vocab_size = kVocabSize;
|
config.vocab_size = kVocabSize;
|
||||||
config.seq_len = 2048;
|
config.seq_len = 2048;
|
||||||
LayerConfig layer_config = {
|
LayerConfig layer_config = LayerConfigGriffin2B(config.model_dim);
|
||||||
.model_dim = config.model_dim,
|
|
||||||
.griffin_dim = config.model_dim,
|
|
||||||
.ff_hidden_dim = 7680,
|
|
||||||
.heads = 10,
|
|
||||||
.kv_heads = 1,
|
|
||||||
.qkv_dim = 256,
|
|
||||||
.conv1d_width = 4,
|
|
||||||
.ff_biases = true,
|
|
||||||
.softmax_attn_output_biases = true,
|
|
||||||
.optimized_gating = false,
|
|
||||||
.type = LayerAttentionType::kGriffinRecurrentBlock,
|
|
||||||
.activation = ActivationType::Gelu,
|
|
||||||
.post_qk = PostQKType::HalfRope,
|
|
||||||
};
|
|
||||||
config.layer_configs = {26, layer_config};
|
config.layer_configs = {26, layer_config};
|
||||||
for (size_t i = 2; i < config.layer_configs.size(); i += 3) {
|
for (size_t i = 2; i < config.layer_configs.size(); i += 3) {
|
||||||
config.layer_configs[i].type = LayerAttentionType::kGemma;
|
config.layer_configs[i].type = LayerAttentionType::kGemma;
|
||||||
|
|
@ -204,6 +239,18 @@ static ModelConfig ConfigGriffin2B() {
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds the per-layer settings for ViT layers (type `kVit`).
// `model_dim` is stored as the layer's model dimension.
static LayerConfig LayerConfigVit(size_t model_dim) {
  LayerConfig layer;
  layer.type = LayerAttentionType::kVit;
  layer.model_dim = model_dim;
  layer.qkv_dim = 72;
  layer.heads = 16;
  layer.kv_heads = 16;  // same as `heads`
  layer.ff_hidden_dim = 4304;
  layer.ff_biases = true;
  return layer;
}
|
||||||
|
|
||||||
// Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config.
|
// Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config.
|
||||||
static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
|
static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
|
||||||
config.vit_model_dim = 1152;
|
config.vit_model_dim = 1152;
|
||||||
|
|
@ -215,15 +262,7 @@ static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
|
||||||
}
|
}
|
||||||
const size_t num_patches = config.image_size / config.patch_width;
|
const size_t num_patches = config.image_size / config.patch_width;
|
||||||
config.vit_seq_len = num_patches * num_patches;
|
config.vit_seq_len = num_patches * num_patches;
|
||||||
LayerConfig vit_layer_config = {
|
LayerConfig vit_layer_config = LayerConfigVit(config.vit_model_dim);
|
||||||
.model_dim = config.vit_model_dim,
|
|
||||||
.ff_hidden_dim = 4304,
|
|
||||||
.heads = 16,
|
|
||||||
.kv_heads = 16,
|
|
||||||
.qkv_dim = 72,
|
|
||||||
.ff_biases = true,
|
|
||||||
.type = LayerAttentionType::kVit,
|
|
||||||
};
|
|
||||||
config.vit_layer_configs = {27, vit_layer_config};
|
config.vit_layer_configs = {27, vit_layer_config};
|
||||||
config.num_vit_scales = 4 * config.vit_layer_configs.size();
|
config.num_vit_scales = 4 * config.vit_layer_configs.size();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
155
gemma/configs.h
155
gemma/configs.h
|
|
@ -26,6 +26,7 @@
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "compression/fields.h" // IFieldsVisitor
|
||||||
#include "compression/shared.h" // BF16
|
#include "compression/shared.h" // BF16
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
@ -52,52 +53,83 @@ enum class LayerAttentionType {
|
||||||
kVit,
|
kVit,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Returns true if the integer value of `type` lies in [0, kVit], where kVit
// is the last enumerator of LayerAttentionType.
inline bool EnumValid(LayerAttentionType type) {
  const int value = static_cast<int>(type);
  constexpr int kLast = static_cast<int>(LayerAttentionType::kVit);
  return 0 <= value && value <= kLast;
}
|
||||||
|
|
||||||
// Post attention and ffw normalization type.
|
// Post attention and ffw normalization type.
|
||||||
enum class PostNormType {
  None,
  Scale,
};

// Returns true if the integer value of `type` lies in [0, Scale], where Scale
// is the last enumerator of PostNormType.
inline bool EnumValid(PostNormType type) {
  const int value = static_cast<int>(type);
  constexpr int kLast = static_cast<int>(PostNormType::Scale);
  return 0 <= value && value <= kLast;
}
|
||||||
|
|
||||||
// Post qk projection operation type.
|
// Post qk projection operation type.
|
||||||
enum class PostQKType {
  Rope,
  HalfRope,
};

// Returns true if the integer value of `type` lies in [0, HalfRope], where
// HalfRope is the last enumerator of PostQKType.
inline bool EnumValid(PostQKType type) {
  const int value = static_cast<int>(type);
  constexpr int kLast = static_cast<int>(PostQKType::HalfRope);
  return 0 <= value && value <= kLast;
}
|
||||||
|
|
||||||
// FFW activation function.
|
// FFW activation function.
|
||||||
enum class ActivationType {
  Gelu,
};

// Returns true if the integer value of `type` lies in [0, Gelu], where Gelu
// is the only (and therefore last) enumerator of ActivationType.
inline bool EnumValid(ActivationType type) {
  const int value = static_cast<int>(type);
  constexpr int kLast = static_cast<int>(ActivationType::Gelu);
  return 0 <= value && value <= kLast;
}
|
||||||
|
|
||||||
// Attention query scale.
|
// Attention query scale.
|
||||||
enum class QueryScaleType {
  SqrtKeySize,
  SqrtModelDimDivNumHeads,
};

// Returns true if the integer value of `type` lies in
// [0, SqrtModelDimDivNumHeads], the last enumerator of QueryScaleType.
inline bool EnumValid(QueryScaleType type) {
  const int value = static_cast<int>(type);
  constexpr int kLast =
      static_cast<int>(QueryScaleType::SqrtModelDimDivNumHeads);
  return 0 <= value && value <= kLast;
}
|
||||||
|
|
||||||
// Residual connection type.
|
// Residual connection type.
|
||||||
enum class ResidualType {
  Add,
};

// Returns true if the integer value of `type` lies in [0, Add], where Add is
// the only (and therefore last) enumerator of ResidualType.
inline bool EnumValid(ResidualType type) {
  const int value = static_cast<int>(type);
  constexpr int kLast = static_cast<int>(ResidualType::Add);
  return 0 <= value && value <= kLast;
}
|
||||||
|
|
||||||
template <size_t kNum>
|
template <size_t kNum>
|
||||||
std::vector<LayerAttentionType> FixedLayerConfig(LayerAttentionType type) {
|
std::vector<LayerAttentionType> FixedLayerConfig(LayerAttentionType type) {
|
||||||
return std::vector<LayerAttentionType>(kNum, type);
|
return std::vector<LayerAttentionType>(kNum, type);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns a vector holding `kNum` copies of `window_size`.
template <uint32_t kNum>
std::vector<uint32_t> FixedAttentionWindowSizes(uint32_t window_size) {
  std::vector<uint32_t> window_sizes;
  window_sizes.assign(kNum, window_size);
  return window_sizes;
}
|
||||||
|
|
||||||
// Repeat window_size_pattern for kNum / kPatternSize times.
|
// Repeat window_size_pattern for kNum / kPatternSize times.
|
||||||
template <size_t kNum, size_t kPatternSize>
|
template <uint32_t kNum, uint32_t kPatternSize>
|
||||||
std::vector<size_t> RepeatedAttentionWindowSizes(
|
std::vector<uint32_t> RepeatedAttentionWindowSizes(
|
||||||
const std::array<size_t, kPatternSize>& window_size_pattern) {
|
const std::array<uint32_t, kPatternSize>& window_size_pattern) {
|
||||||
static_assert(kNum % kPatternSize == 0,
|
static_assert(kNum % kPatternSize == 0,
|
||||||
"kNum must be a multiple of kPatternSize");
|
"kNum must be a multiple of kPatternSize");
|
||||||
std::vector<size_t> window_size_configs(kNum);
|
std::vector<uint32_t> window_size_configs(kNum);
|
||||||
for (size_t i = 0; i < kNum; ++i) {
|
for (uint32_t i = 0; i < kNum; ++i) {
|
||||||
window_size_configs[i] = window_size_pattern[i % kPatternSize];
|
window_size_configs[i] = window_size_pattern[i % kPatternSize];
|
||||||
}
|
}
|
||||||
return window_size_configs;
|
return window_size_configs;
|
||||||
|
|
@ -130,7 +162,14 @@ static constexpr Model kAllModels[] = {
|
||||||
Model::PALIGEMMA2_10B_224, Model::PALIGEMMA2_10B_448,
|
Model::PALIGEMMA2_10B_224, Model::PALIGEMMA2_10B_448,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LayerConfig {
|
// Returns true if `model` is one of the entries of `kAllModels`.
inline bool EnumValid(Model model) {
  bool found = false;
  for (const Model candidate : kAllModels) {
    found = found || (candidate == model);
  }
  return found;
}
|
||||||
|
|
||||||
|
struct LayerConfig : public IFields {
|
||||||
// Returns true if *this and other are equal.
|
// Returns true if *this and other are equal.
|
||||||
// If partial is true, then we don't check for items that are only set after
|
// If partial is true, then we don't check for items that are only set after
|
||||||
// the tensors are loaded from the checkpoint.
|
// the tensors are loaded from the checkpoint.
|
||||||
|
|
@ -146,13 +185,32 @@ struct LayerConfig {
|
||||||
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
|
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
|
||||||
size_t QStride() const { return qkv_dim * (IsMHA() ? 3 : 1); }
|
size_t QStride() const { return qkv_dim * (IsMHA() ? 3 : 1); }
|
||||||
|
|
||||||
size_t model_dim = 0;
|
const char* Name() const override { return "LayerConfig"; }
|
||||||
size_t griffin_dim = 0;
|
|
||||||
size_t ff_hidden_dim = 0;
|
void VisitFields(IFieldsVisitor& visitor) override {
|
||||||
size_t heads = 0;
|
visitor(model_dim);
|
||||||
size_t kv_heads = 0;
|
visitor(griffin_dim);
|
||||||
size_t qkv_dim = 0;
|
visitor(ff_hidden_dim);
|
||||||
size_t conv1d_width = 0; // griffin only
|
visitor(heads);
|
||||||
|
visitor(kv_heads);
|
||||||
|
visitor(qkv_dim);
|
||||||
|
visitor(conv1d_width);
|
||||||
|
visitor(ff_biases);
|
||||||
|
visitor(softmax_attn_output_biases);
|
||||||
|
visitor(optimized_gating);
|
||||||
|
visitor(post_norm);
|
||||||
|
visitor(type);
|
||||||
|
visitor(activation);
|
||||||
|
visitor(post_qk);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t model_dim = 0;
|
||||||
|
uint32_t griffin_dim = 0;
|
||||||
|
uint32_t ff_hidden_dim = 0;
|
||||||
|
uint32_t heads = 0;
|
||||||
|
uint32_t kv_heads = 0;
|
||||||
|
uint32_t qkv_dim = 0;
|
||||||
|
uint32_t conv1d_width = 0; // griffin only
|
||||||
bool ff_biases = false;
|
bool ff_biases = false;
|
||||||
bool softmax_attn_output_biases = false;
|
bool softmax_attn_output_biases = false;
|
||||||
bool optimized_gating = true;
|
bool optimized_gating = true;
|
||||||
|
|
@ -162,7 +220,7 @@ struct LayerConfig {
|
||||||
PostQKType post_qk = PostQKType::Rope;
|
PostQKType post_qk = PostQKType::Rope;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ModelConfig {
|
struct ModelConfig : public IFields {
|
||||||
// Returns true if *this and other are equal.
|
// Returns true if *this and other are equal.
|
||||||
// If partial is true, then we don't check for items that are only set after
|
// If partial is true, then we don't check for items that are only set after
|
||||||
// the tensors are loaded from the checkpoint.
|
// the tensors are loaded from the checkpoint.
|
||||||
|
|
@ -191,39 +249,68 @@ struct ModelConfig {
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t NumHeads() const {
|
size_t NumHeads() const {
|
||||||
size_t num_heads = 0;
|
uint32_t num_heads = 0;
|
||||||
for (const auto& layer_config : layer_configs) {
|
for (const auto& layer_config : layer_configs) {
|
||||||
num_heads = std::max(num_heads, layer_config.heads);
|
num_heads = std::max(num_heads, layer_config.heads);
|
||||||
}
|
}
|
||||||
return num_heads;
|
return num_heads;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const char* Name() const override { return "ModelConfig"; }
|
||||||
|
|
||||||
|
void VisitFields(IFieldsVisitor& visitor) override {
|
||||||
|
visitor(model_family_version);
|
||||||
|
visitor(model_name);
|
||||||
|
visitor(model);
|
||||||
|
visitor(wrapping);
|
||||||
|
visitor(weight);
|
||||||
|
visitor(num_layers);
|
||||||
|
visitor(model_dim);
|
||||||
|
visitor(vocab_size);
|
||||||
|
visitor(seq_len);
|
||||||
|
visitor(num_tensor_scales);
|
||||||
|
visitor(att_cap);
|
||||||
|
visitor(final_cap);
|
||||||
|
visitor(absolute_pe);
|
||||||
|
visitor(use_local_attention);
|
||||||
|
visitor(query_scale);
|
||||||
|
visitor(layer_configs);
|
||||||
|
visitor(attention_window_sizes);
|
||||||
|
visitor(norm_num_groups);
|
||||||
|
visitor(vit_model_dim);
|
||||||
|
visitor(vit_seq_len);
|
||||||
|
visitor(num_vit_scales);
|
||||||
|
visitor(vit_layer_configs);
|
||||||
|
visitor(patch_width);
|
||||||
|
visitor(image_size);
|
||||||
|
}
|
||||||
|
|
||||||
std::string model_name;
|
std::string model_name;
|
||||||
Model model;
|
Model model = Model::UNKNOWN;
|
||||||
PromptWrapping wrapping;
|
PromptWrapping wrapping = PromptWrapping::GEMMA_PT;
|
||||||
Type weight;
|
Type weight = Type::kUnknown;
|
||||||
size_t num_layers = 0;
|
uint32_t num_layers = 0;
|
||||||
size_t model_dim = 0;
|
uint32_t model_dim = 0;
|
||||||
size_t vit_model_dim = 0;
|
uint32_t vit_model_dim = 0;
|
||||||
size_t vocab_size = 0;
|
uint32_t vocab_size = 0;
|
||||||
size_t seq_len = 0;
|
uint32_t seq_len = 0;
|
||||||
size_t vit_seq_len = 0;
|
uint32_t vit_seq_len = 0;
|
||||||
size_t num_tensor_scales = 0;
|
uint32_t num_tensor_scales = 0;
|
||||||
size_t num_vit_scales = 0;
|
uint32_t num_vit_scales = 0;
|
||||||
float att_cap = 0.0f;
|
float att_cap = 0.0f;
|
||||||
float final_cap = 0.0f;
|
float final_cap = 0.0f;
|
||||||
bool absolute_pe = false;
|
bool absolute_pe = false;
|
||||||
bool use_local_attention = false; // griffin only
|
bool use_local_attention = false; // griffin only
|
||||||
QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
|
QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
|
||||||
std::vector<LayerConfig> layer_configs;
|
std::vector<LayerConfig> layer_configs;
|
||||||
std::vector<size_t> attention_window_sizes;
|
std::vector<uint32_t> attention_window_sizes;
|
||||||
std::vector<LayerConfig> vit_layer_configs;
|
std::vector<LayerConfig> vit_layer_configs;
|
||||||
std::unordered_set<std::string> scale_names;
|
std::unordered_set<std::string> scale_names;
|
||||||
int norm_num_groups = 1;
|
uint32_t norm_num_groups = 1;
|
||||||
int model_family_version = 1;
|
uint32_t model_family_version = 1;
|
||||||
// Dimensions related to image processing.
|
// Dimensions related to image processing.
|
||||||
size_t patch_width = 14;
|
uint32_t patch_width = 14;
|
||||||
size_t image_size = 224;
|
uint32_t image_size = 224;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Returns the config for the given model.
|
// Returns the config for the given model.
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,12 @@
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
|
#include "hwy/aligned_allocator.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@ -412,8 +415,17 @@ void AssertMatch(const ModelConfig& config) {
|
||||||
ASSERT_EQ(TConfig::kNumTensorScales, config.num_tensor_scales);
|
ASSERT_EQ(TConfig::kNumTensorScales, config.num_tensor_scales);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Serializes `config` via Write() and reads the result back into a fresh
// ModelConfig, which is returned.
ModelConfig RoundTripSerialize(const ModelConfig& config) {
  std::vector<uint32_t> serialized = config.Write();
  ModelConfig restored;
  restored.Read(hwy::Span<const uint32_t>(serialized), 0);
  return restored;
}
|
||||||
|
|
||||||
// Checks the GEMMA_2B config against the old compile-time config, both
// directly and after a serialization round trip.
TEST(ConfigsTest, OldConfigGemma2B) {
  using OldConfig = OldConfigGemma2B<float>;
  AssertMatch<OldConfig>(ConfigFromModel(Model::GEMMA_2B));
  ModelConfig round_tripped =
      RoundTripSerialize(ConfigFromModel(Model::GEMMA_2B));
  AssertMatch<OldConfig>(round_tripped);
}
|
||||||
|
|
||||||
TEST(ConfigsTest, OldConfigGemma7B) {
|
TEST(ConfigsTest, OldConfigGemma7B) {
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
#include <utility> // std::move
|
#include <utility> // std::move
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
@ -40,13 +41,21 @@ namespace gcpp {
|
||||||
|
|
||||||
// Loads the tokenizer from `tokenizer_path` and the weights from `weights`.
// Model type, weight type and prompt wrapping are taken from `info`; no
// tokenizer is requested from the weights file.
Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
             const ModelInfo& info, NestedPools& pools)
    : pools_(pools), tokenizer_(tokenizer_path) {
  hwy::ThreadPool& pool = pools_.Pool();
  model_.Load(weights, info.model, info.weight, info.wrapping, pool,
              /*tokenizer_proto=*/nullptr);
}
|
||||||
|
|
||||||
|
// Loads a model from a single weights file that also stores the model config
// and the serialized tokenizer. Aborts if loading fails.
Gemma::Gemma(const Path& weights, NestedPools& pools) : pools_(pools) {
  std::string tokenizer_proto;
  // Model and weight type are passed as unknown: Load() reads them from the
  // config stored in the file and reports an error if the file has none.
  const auto err =
      model_.Load(weights, Model::UNKNOWN, Type::kUnknown,
                  PromptWrapping::GEMMA_IT, pools_.Pool(), &tokenizer_proto);
  // Stop here on failure; otherwise the (empty) tokenizer proto below would
  // abort with a misleading tokenizer error message.
  if (err != 0) {
    HWY_ABORT("Failed to load model from single-file weights, error %d.",
              static_cast<int>(err));
  }
  tokenizer_.Deserialize(tokenizer_proto);
}
|
||||||
|
|
||||||
// Takes ownership of `tokenizer` and allocates weights for `info.model`
// without loading them. Only F32 weights are accepted.
Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
             NestedPools& pools)
    : pools_(pools), tokenizer_(std::move(tokenizer)) {
  HWY_ASSERT(info.weight == Type::kF32);
  hwy::ThreadPool& pool = pools_.Pool();
  model_.Allocate(info.model, info.weight, pool);
}
|
||||||
|
|
@ -166,7 +175,7 @@ void RangeChecks(const ModelConfig& weights_config,
|
||||||
if (!weights_config.use_local_attention) {
|
if (!weights_config.use_local_attention) {
|
||||||
if (max_generated_tokens > weights_config.seq_len) {
|
if (max_generated_tokens > weights_config.seq_len) {
|
||||||
fprintf(stderr,
|
fprintf(stderr,
|
||||||
"WARNING: max_generated_tokens %zu > kSeqLen %zu, truncating.\n",
|
"WARNING: max_generated_tokens %zu > kSeqLen %u, truncating.\n",
|
||||||
max_generated_tokens, weights_config.seq_len);
|
max_generated_tokens, weights_config.seq_len);
|
||||||
max_generated_tokens = weights_config.seq_len;
|
max_generated_tokens = weights_config.seq_len;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -190,18 +190,28 @@ struct TimingInfo {
|
||||||
|
|
||||||
class Gemma {
|
class Gemma {
|
||||||
public:
|
public:
|
||||||
|
// Reads old format weights file and tokenizer file.
|
||||||
Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
|
Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
|
||||||
NestedPools& pools);
|
NestedPools& pools);
|
||||||
|
// Reads new format weights file that contains everything in a single file.
|
||||||
|
Gemma(const Path& weights, NestedPools& pools);
|
||||||
// Allocates weights, caller is responsible for filling them.
|
// Allocates weights, caller is responsible for filling them.
|
||||||
Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, NestedPools& pools);
|
Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, NestedPools& pools);
|
||||||
~Gemma();
|
~Gemma();
|
||||||
|
|
||||||
const ModelConfig& GetModelConfig() const { return model_.Config(); }
|
const ModelConfig& GetModelConfig() const { return model_.Config(); }
|
||||||
// Returns the model type, prompt wrapping and weight type, taken from the
// config held by the weights storage.
ModelInfo Info() const {
  const ModelConfig& config = model_.Config();
  return ModelInfo({.model = config.model,
                    .wrapping = config.wrapping,
                    .weight = config.weight});
}
|
||||||
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
|
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
|
||||||
const ModelWeightsStorage& Weights() const { return model_; }
|
const ModelWeightsStorage& Weights() const { return model_; }
|
||||||
ModelWeightsStorage& MutableWeights() { return model_; }
|
ModelWeightsStorage& MutableWeights() { return model_; }
|
||||||
|
void Save(const Path& weights, hwy::ThreadPool& pool) {
|
||||||
|
std::string tokenizer_proto = tokenizer_.Serialize();
|
||||||
|
model_.Save(tokenizer_proto, weights, pool);
|
||||||
|
}
|
||||||
|
|
||||||
// `pos` is the position in the KV cache. Users are responsible for
|
// `pos` is the position in the KV cache. Users are responsible for
|
||||||
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
|
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
|
||||||
|
|
@ -241,7 +251,6 @@ class Gemma {
|
||||||
GemmaTokenizer tokenizer_;
|
GemmaTokenizer tokenizer_;
|
||||||
// Type-erased so that this can be defined in the header.
|
// Type-erased so that this can be defined in the header.
|
||||||
ModelWeightsStorage model_;
|
ModelWeightsStorage model_;
|
||||||
ModelInfo info_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Adds BOS token and possibly 'turn' annotations, which depend on `info`
|
// Adds BOS token and possibly 'turn' annotations, which depend on `info`
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ KVCache KVCache::Create(const ModelConfig& weights_config,
|
||||||
LayerAttentionType::kGriffinRecurrentBlock);
|
LayerAttentionType::kGriffinRecurrentBlock);
|
||||||
// TODO(patrickms): Add query batching support for Griffin.
|
// TODO(patrickms): Add query batching support for Griffin.
|
||||||
if (num_griffin_layers > 0) {
|
if (num_griffin_layers > 0) {
|
||||||
size_t conv1d_width = 0;
|
uint32_t conv1d_width = 0;
|
||||||
for (const auto& layer_config : weights_config.layer_configs) {
|
for (const auto& layer_config : weights_config.layer_configs) {
|
||||||
conv1d_width = std::max(conv1d_width, layer_config.conv1d_width);
|
conv1d_width = std::max(conv1d_width, layer_config.conv1d_width);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -482,6 +482,8 @@ std::vector<TensorInfo> LLMLayerTensors(const ModelConfig& config,
|
||||||
.name = "att_w",
|
.name = "att_w",
|
||||||
.source_names = {"attn/attn_vec_einsum/w",
|
.source_names = {"attn/attn_vec_einsum/w",
|
||||||
"attention_block/proj_final/kernel"},
|
"attention_block/proj_final/kernel"},
|
||||||
|
.preshape = {layer_config.heads, layer_config.qkv_dim,
|
||||||
|
config.model_dim},
|
||||||
.axes = {2, 0, 1},
|
.axes = {2, 0, 1},
|
||||||
.shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
|
.shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
|
||||||
.cols_take_extra_dims = true,
|
.cols_take_extra_dims = true,
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,8 @@ TEST(TensorIndexTest, FindName) {
|
||||||
// Test that the MatPtr can be constructed from the TensorInfo,
|
// Test that the MatPtr can be constructed from the TensorInfo,
|
||||||
// and that the dimensions match.
|
// and that the dimensions match.
|
||||||
MatPtrT<SfpStream> mat_ptr(tensor.Name(), tensor_index);
|
MatPtrT<SfpStream> mat_ptr(tensor.Name(), tensor_index);
|
||||||
EXPECT_EQ(tensor.Name(), mat_ptr.Name()) << "on tensor " << name;
|
EXPECT_STREQ(tensor.Name(), mat_ptr.Name())
|
||||||
|
<< "on tensor " << name;
|
||||||
EXPECT_EQ(tensor.Rows(), mat_ptr.Rows()) << "on tensor " << name;
|
EXPECT_EQ(tensor.Rows(), mat_ptr.Rows()) << "on tensor " << name;
|
||||||
EXPECT_EQ(tensor.Cols(), mat_ptr.Cols()) << "on tensor " << name;
|
EXPECT_EQ(tensor.Cols(), mat_ptr.Cols()) << "on tensor " << name;
|
||||||
++num_found;
|
++num_found;
|
||||||
|
|
|
||||||
|
|
@ -44,6 +44,17 @@ class GemmaTokenizer::Impl {
|
||||||
HWY_ABORT("Failed to load the tokenizer file.");
|
HWY_ABORT("Failed to load the tokenizer file.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Loads the tokenizer from a serialized proto.
|
||||||
|
explicit Impl(const std::string& tokenizer_proto) {
|
||||||
|
PROFILER_ZONE("Startup.tokenizer");
|
||||||
|
spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
|
||||||
|
if (!spp_->LoadFromSerializedProto(tokenizer_proto).ok()) {
|
||||||
|
fprintf(stderr, "serialized proto size=%zu.\n", tokenizer_proto.size());
|
||||||
|
HWY_ABORT("Failed to load the tokenizer from serialized proto.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the serialized model proto held by the SentencePiece processor.
std::string Serialize() const {
  return spp_->serialized_model_proto();
}
|
||||||
|
|
||||||
bool Encode(const std::string& input,
|
bool Encode(const std::string& input,
|
||||||
std::vector<std::string>* pieces) const {
|
std::vector<std::string>* pieces) const {
|
||||||
|
|
@ -81,6 +92,12 @@ GemmaTokenizer::~GemmaTokenizer() = default;
|
||||||
GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default;
|
GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default;
|
||||||
GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default;
|
GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default;
|
||||||
|
|
||||||
|
// Forwards to the implementation object's Serialize().
std::string GemmaTokenizer::Serialize() const {
  return impl_->Serialize();
}
|
||||||
|
|
||||||
|
void GemmaTokenizer::Deserialize(const std::string& tokenizer_proto) {
|
||||||
|
impl_ = std::make_unique<Impl>(tokenizer_proto);
|
||||||
|
}
|
||||||
|
|
||||||
bool GemmaTokenizer::Encode(const std::string& input,
|
bool GemmaTokenizer::Encode(const std::string& input,
|
||||||
std::vector<std::string>* pieces) const {
|
std::vector<std::string>* pieces) const {
|
||||||
return impl_->Encode(input, pieces);
|
return impl_->Encode(input, pieces);
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,9 @@ class GemmaTokenizer {
|
||||||
GemmaTokenizer(GemmaTokenizer&& other);
|
GemmaTokenizer(GemmaTokenizer&& other);
|
||||||
GemmaTokenizer& operator=(GemmaTokenizer&& other);
|
GemmaTokenizer& operator=(GemmaTokenizer&& other);
|
||||||
|
|
||||||
|
// Returns the tokenizer model as a serialized proto string, suitable for
// passing to `Deserialize`.
std::string Serialize() const;
|
||||||
|
// Replaces the current tokenizer with one loaded from `tokenizer_proto`, a
// serialized proto such as the string returned by `Serialize`. Aborts if the
// proto cannot be loaded.
void Deserialize(const std::string& tokenizer_proto);
|
||||||
|
|
||||||
bool Encode(const std::string& input, std::vector<std::string>* pieces) const;
|
bool Encode(const std::string& input, std::vector<std::string>* pieces) const;
|
||||||
bool Encode(const std::string& input, std::vector<int>* ids) const;
|
bool Encode(const std::string& input, std::vector<int>* ids) const;
|
||||||
bool Decode(const std::vector<int>& ids, std::string* detokenized) const;
|
bool Decode(const std::vector<int>& ids, std::string* detokenized) const;
|
||||||
|
|
|
||||||
|
|
@ -19,11 +19,13 @@
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "compression/blob_store.h"
|
#include "compression/blob_store.h"
|
||||||
#include "compression/compress.h"
|
#include "compression/compress.h"
|
||||||
#include "compression/io.h" // Path
|
#include "compression/io.h" // Path
|
||||||
|
#include "compression/shared.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
|
|
@ -47,7 +49,9 @@ struct TensorLoader {
|
||||||
};
|
};
|
||||||
|
|
||||||
BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type,
|
BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type,
|
||||||
Type weight_type, hwy::ThreadPool& pool) {
|
Type weight_type, PromptWrapping wrapping,
|
||||||
|
hwy::ThreadPool& pool,
|
||||||
|
std::string* tokenizer_proto) {
|
||||||
PROFILER_ZONE("Startup.LoadModelWeightsPtrs");
|
PROFILER_ZONE("Startup.LoadModelWeightsPtrs");
|
||||||
if (!weights.Exists()) {
|
if (!weights.Exists()) {
|
||||||
HWY_ABORT("The model weights file '%s' does not exist.",
|
HWY_ABORT("The model weights file '%s' does not exist.",
|
||||||
|
|
@ -56,17 +60,36 @@ BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type,
|
||||||
ReadFromBlobStore loader(weights);
|
ReadFromBlobStore loader(weights);
|
||||||
ForEachType fet =
|
ForEachType fet =
|
||||||
loader.HaveToc() ? ForEachType::kLoadWithToc : ForEachType::kLoadNoToc;
|
loader.HaveToc() ? ForEachType::kLoadWithToc : ForEachType::kLoadNoToc;
|
||||||
|
std::vector<float> scales;
|
||||||
if (fet == ForEachType::kLoadWithToc) {
|
if (fet == ForEachType::kLoadWithToc) {
|
||||||
// TODO(rays): Load the config from the file.
|
BlobError err = loader.LoadConfig(config_);
|
||||||
HWY_ABORT("TOC not supported yet.");
|
if (err != 0 || config_.model_dim == 0) {
|
||||||
|
fprintf(stderr, "Failed to load model config: %d\n", err);
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
if (tokenizer_proto != nullptr) {
|
||||||
|
err = loader.LoadTokenizer(*tokenizer_proto);
|
||||||
|
if (err != 0) {
|
||||||
|
fprintf(stderr, "Failed to load tokenizer: %d\n", err);
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
|
if (weight_type == Type::kUnknown || model_type == Model::UNKNOWN) {
|
||||||
|
fprintf(stderr,
|
||||||
|
"weight type (%d) and model type (%d) must be specified when "
|
||||||
|
"no config is present in weights file\n",
|
||||||
|
static_cast<int>(weight_type), static_cast<int>(model_type));
|
||||||
|
return __LINE__;
|
||||||
|
}
|
||||||
// No Toc-> no config.
|
// No Toc-> no config.
|
||||||
config_ = ConfigFromModel(model_type);
|
config_ = ConfigFromModel(model_type);
|
||||||
config_.weight = weight_type;
|
config_.weight = weight_type;
|
||||||
|
config_.wrapping = wrapping;
|
||||||
|
scales.resize(config_.num_tensor_scales + config_.num_vit_scales);
|
||||||
}
|
}
|
||||||
CreateForType(weight_type, pool);
|
CreateForType(config_.weight, pool);
|
||||||
CallForModelWeightT<TensorLoader>(fet, loader);
|
CallForModelWeightT<TensorLoader>(fet, loader);
|
||||||
std::vector<float> scales(config_.num_tensor_scales + config_.num_vit_scales);
|
|
||||||
if (!scales.empty()) {
|
if (!scales.empty()) {
|
||||||
loader.LoadScales(scales.data(), scales.size());
|
loader.LoadScales(scales.data(), scales.size());
|
||||||
}
|
}
|
||||||
|
|
@ -85,6 +108,34 @@ BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type,
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct TensorSaver {
|
||||||
|
// Adds all the tensors to the blob writer.
|
||||||
|
void operator()(ModelWeightsPtrs<T>& weights, ForEachType fet,
|
||||||
|
WriteToBlobStore& writer) {
|
||||||
|
weights.ForEachTensor(
|
||||||
|
{&weights}, fet,
|
||||||
|
[&writer](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||||
|
tensors[0]->CallUpcasted(writer, name);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Writes the weights, the tokenizer and the model config to the blob store
// file `weights`, enumerating tensors with ForEachType::kLoadWithToc.
//
// tokenizer: serialized tokenizer, added to the output via AddTokenizer.
// weights:   path of the file to write.
// pool:      thread pool handed to the blob writer.
//
// Returns 0 on success, otherwise the nonzero error code reported by
// WriteToBlobStore::WriteAll (also logged to stderr).
BlobError ModelWeightsStorage::Save(const std::string& tokenizer,
                                    const Path& weights,
                                    hwy::ThreadPool& pool) {
  WriteToBlobStore writer(pool);
  ForEachType fet = ForEachType::kLoadWithToc;
  // Queue all tensors first, then the tokenizer; nothing is written until
  // WriteAll is called.
  CallForModelWeightT<TensorSaver>(fet, writer);
  writer.AddTokenizer(tokenizer);
  const int err = writer.WriteAll(weights, &config_);
  if (err != 0) {
    // This is the write path; the message previously said "load".
    fprintf(stderr, "Failed to save model weights: %d\n", err);
    return err;
  }
  return 0;
}
|
||||||
|
|
||||||
void ModelWeightsStorage::Allocate(const ModelConfig& config, Type weight_type,
|
void ModelWeightsStorage::Allocate(const ModelConfig& config, Type weight_type,
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
PROFILER_ZONE("Startup.AllocateModelWeightsPtrs");
|
PROFILER_ZONE("Startup.AllocateModelWeightsPtrs");
|
||||||
|
|
|
||||||
|
|
@ -522,7 +522,18 @@ class ModelWeightsStorage {
|
||||||
ModelWeightsStorage() = default;
|
ModelWeightsStorage() = default;
|
||||||
~ModelWeightsStorage() = default;
|
~ModelWeightsStorage() = default;
|
||||||
|
|
||||||
|
// Loads the weights from a blob store file. Supports multi-file or
|
||||||
|
// single-file format. If the weights file contains a TOC, then it is in
|
||||||
|
// single-file format, and model_type, weight_type, wrapping are ignored,
|
||||||
|
// and tokenizer_proto is required and written to.
|
||||||
|
// With a multi-file format, file, model_type, weight_type, wrapping are
|
||||||
|
// required and tokenizer_proto is ignored.
|
||||||
BlobError Load(const Path& weights, Model model_type, Type weight_type,
|
BlobError Load(const Path& weights, Model model_type, Type weight_type,
|
||||||
|
PromptWrapping wrapping, hwy::ThreadPool& pool,
|
||||||
|
std::string* tokenizer_proto);
|
||||||
|
// Writes the weights to a blob store file, using the single-file format with
|
||||||
|
// a TOC and config included.
|
||||||
|
BlobError Save(const std::string& tokenizer, const Path& weights,
|
||||||
hwy::ThreadPool& pool);
|
hwy::ThreadPool& pool);
|
||||||
void Allocate(Model model_type, Type weight_type, hwy::ThreadPool& pool) {
|
void Allocate(Model model_type, Type weight_type, hwy::ThreadPool& pool) {
|
||||||
Allocate(ConfigFromModel(model_type), weight_type, pool);
|
Allocate(ConfigFromModel(model_type), weight_type, pool);
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,6 @@
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
|
||||||
#include "compression/compress.h"
|
|
||||||
#include "compression/shared.h"
|
#include "compression/shared.h"
|
||||||
#include "util/allocator.h"
|
#include "util/allocator.h"
|
||||||
#include "util/test_util.h"
|
#include "util/test_util.h"
|
||||||
|
|
|
||||||
41
util/app.h
41
util/app.h
|
|
@ -25,6 +25,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "compression/io.h" // Path
|
#include "compression/io.h" // Path
|
||||||
|
#include "compression/shared.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "gemma/gemma.h" // For CreateGemma
|
#include "gemma/gemma.h" // For CreateGemma
|
||||||
#include "ops/matmul.h"
|
#include "ops/matmul.h"
|
||||||
|
|
@ -125,7 +126,10 @@ static inline NestedPools CreatePools(const AppArgs& app) {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
LoaderArgs(int argc, char* argv[], bool required = true)
|
||||||
|
: model_type_required(required) {
|
||||||
|
InitAndParse(argc, argv);
|
||||||
|
}
|
||||||
LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
|
LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
|
||||||
const std::string& model) {
|
const std::string& model) {
|
||||||
Init(); // Init sets to defaults, so assignments must come after Init().
|
Init(); // Init sets to defaults, so assignments must come after Init().
|
||||||
|
|
@ -136,18 +140,24 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
|
|
||||||
// Returns error string or nullptr if OK.
|
// Returns error string or nullptr if OK.
|
||||||
const char* Validate() {
|
const char* Validate() {
|
||||||
|
info_.model = Model::UNKNOWN;
|
||||||
|
info_.wrapping = PromptWrapping::GEMMA_PT;
|
||||||
|
info_.weight = Type::kUnknown;
|
||||||
if (const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
|
if (const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
|
||||||
info_.wrapping)) {
|
info_.wrapping)) {
|
||||||
return err;
|
if (model_type_required) return err;
|
||||||
}
|
}
|
||||||
if (const char* err = ParseType(weight_type_str, info_.weight)) {
|
if (const char* err = ParseType(weight_type_str, info_.weight)) {
|
||||||
return err;
|
if (model_type_required) return err;
|
||||||
}
|
}
|
||||||
if (tokenizer.path.empty()) {
|
if (model_type_required) {
|
||||||
return "Missing --tokenizer flag, a file for the tokenizer is required.";
|
if (tokenizer.path.empty()) {
|
||||||
}
|
return "Missing --tokenizer flag, a file for the tokenizer is "
|
||||||
if (!tokenizer.Exists()) {
|
"required.";
|
||||||
return "Can't open file specified with --tokenizer flag.";
|
}
|
||||||
|
if (!tokenizer.Exists()) {
|
||||||
|
return "Can't open file specified with --tokenizer flag.";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (!compressed_weights.path.empty()) {
|
if (!compressed_weights.path.empty()) {
|
||||||
if (weights.path.empty()) {
|
if (weights.path.empty()) {
|
||||||
|
|
@ -172,11 +182,12 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
Path compressed_weights;
|
Path compressed_weights;
|
||||||
std::string model_type_str;
|
std::string model_type_str;
|
||||||
std::string weight_type_str;
|
std::string weight_type_str;
|
||||||
|
bool model_type_required = true;
|
||||||
|
|
||||||
template <class Visitor>
|
template <class Visitor>
|
||||||
void ForEach(const Visitor& visitor) {
|
void ForEach(const Visitor& visitor) {
|
||||||
visitor(tokenizer, "tokenizer", Path(),
|
visitor(tokenizer, "tokenizer", Path(),
|
||||||
"Path name of tokenizer model file.\n Required argument.");
|
"Path name of tokenizer model file.");
|
||||||
visitor(weights, "weights", Path(),
|
visitor(weights, "weights", Path(),
|
||||||
"Path name of model weights (.sbs) file.\n Required argument.");
|
"Path name of model weights (.sbs) file.\n Required argument.");
|
||||||
visitor(compressed_weights, "compressed_weights", Path(),
|
visitor(compressed_weights, "compressed_weights", Path(),
|
||||||
|
|
@ -186,11 +197,9 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
|
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
|
||||||
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
|
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
|
||||||
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
|
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
|
||||||
"gr2b-pt = griffin 2B parameters, pretrained\n "
|
"gr2b-pt = griffin 2B parameters, pretrained.");
|
||||||
" Required argument.");
|
|
||||||
visitor(weight_type_str, "weight_type", std::string("sfp"),
|
visitor(weight_type_str, "weight_type", std::string("sfp"),
|
||||||
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit FP\n"
|
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit FP.");
|
||||||
" Required argument.");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Uninitialized before Validate, must call after that.
|
// Uninitialized before Validate, must call after that.
|
||||||
|
|
@ -208,6 +217,12 @@ static inline Gemma CreateGemma(const LoaderArgs& loader, NestedPools& pools) {
|
||||||
|
|
||||||
static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
|
static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
|
||||||
NestedPools& pools) {
|
NestedPools& pools) {
|
||||||
|
if (Type::kUnknown == loader.Info().weight ||
|
||||||
|
Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
|
||||||
|
// Newer weights file format doesn't need tokenizer path or model/weight
|
||||||
|
// info.
|
||||||
|
return std::make_unique<Gemma>(loader.weights, pools);
|
||||||
|
}
|
||||||
return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
|
return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
|
||||||
loader.Info(), pools);
|
loader.Info(), pools);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue