mirror of https://github.com/google/gemma.cpp.git
Added the TensorInfo arg to the compressor so the shape and scale can be output correctly to the file in future.
Corrected some errors in the TensorIndex. PiperOrigin-RevId: 705014619
This commit is contained in:
parent
7b77909427
commit
e69bc3bc1c
|
|
@ -104,6 +104,9 @@ class BlobWriter {
|
||||||
// Stores all blobs to disk in the given order with padding for alignment.
|
// Stores all blobs to disk in the given order with padding for alignment.
|
||||||
BlobError WriteAll(hwy::ThreadPool& pool, const Path& filename);
|
BlobError WriteAll(hwy::ThreadPool& pool, const Path& filename);
|
||||||
|
|
||||||
|
// Returns the number of blobs added.
|
||||||
|
size_t DebugNumBlobsAdded() const { return keys_.size(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<hwy::uint128_t> keys_;
|
std::vector<hwy::uint128_t> keys_;
|
||||||
std::vector<hwy::Span<const uint8_t>> blobs_;
|
std::vector<hwy::Span<const uint8_t>> blobs_;
|
||||||
|
|
|
||||||
|
|
@ -705,6 +705,9 @@ class Compressor {
|
||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the number of blobs added.
|
||||||
|
size_t DebugNumBlobsAdded() const { return writer_.DebugNumBlobsAdded(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
CompressWorkingSet work_;
|
CompressWorkingSet work_;
|
||||||
hwy::ThreadPool& pool_;
|
hwy::ThreadPool& pool_;
|
||||||
|
|
|
||||||
|
|
@ -216,8 +216,9 @@ class MatPtrT : public MatPtr {
|
||||||
: MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), rows, cols) {}
|
: MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), rows, cols) {}
|
||||||
// Construction from TensorIndex entry to remove duplication of sizes.
|
// Construction from TensorIndex entry to remove duplication of sizes.
|
||||||
MatPtrT(const std::string& name, const TensorIndex& tensor_index)
|
MatPtrT(const std::string& name, const TensorIndex& tensor_index)
|
||||||
|
: MatPtrT<MatT>(name, tensor_index.FindName(name)) {}
|
||||||
|
MatPtrT(const std::string& name, const TensorInfo* tensor)
|
||||||
: MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), 0, 0) {
|
: MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), 0, 0) {
|
||||||
const TensorInfo* tensor = tensor_index.FindName(name);
|
|
||||||
HWY_ASSERT(tensor != nullptr);
|
HWY_ASSERT(tensor != nullptr);
|
||||||
cols_ = tensor->shape.back();
|
cols_ = tensor->shape.back();
|
||||||
rows_ = 1;
|
rows_ = 1;
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ cc_library(
|
||||||
visibility = ["//visibility:private"],
|
visibility = ["//visibility:private"],
|
||||||
deps = [
|
deps = [
|
||||||
"@abseil-cpp//absl/types:span",
|
"@abseil-cpp//absl/types:span",
|
||||||
|
"//:common",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
|
|
@ -28,6 +29,7 @@ pybind_extension(
|
||||||
deps = [
|
deps = [
|
||||||
":compression_clif_aux",
|
":compression_clif_aux",
|
||||||
"@abseil-cpp//absl/types:span",
|
"@abseil-cpp//absl/types:span",
|
||||||
|
"//:common",
|
||||||
"//compression:sfp",
|
"//compression:sfp",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
#include "compression/python/compression_clif_aux.h"
|
#include "compression/python/compression_clif_aux.h"
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdio>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
@ -22,6 +24,7 @@
|
||||||
|
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "compression/io.h"
|
#include "compression/io.h"
|
||||||
|
#include "gemma/tensor_index.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
|
|
@ -32,7 +35,8 @@ class WriterInterface {
|
||||||
virtual ~WriterInterface() = default;
|
virtual ~WriterInterface() = default;
|
||||||
|
|
||||||
virtual void Insert(std::string name, absl::Span<const float> weights,
|
virtual void Insert(std::string name, absl::Span<const float> weights,
|
||||||
Type type) = 0;
|
Type type, const TensorInfo& tensor_info,
|
||||||
|
float scale) = 0;
|
||||||
virtual void InsertSfp(std::string name, absl::Span<const float> weights) = 0;
|
virtual void InsertSfp(std::string name, absl::Span<const float> weights) = 0;
|
||||||
virtual void InsertNUQ(std::string name, absl::Span<const float> weights) = 0;
|
virtual void InsertNUQ(std::string name, absl::Span<const float> weights) = 0;
|
||||||
virtual void InsertBfloat16(std::string name,
|
virtual void InsertBfloat16(std::string name,
|
||||||
|
|
@ -41,6 +45,8 @@ class WriterInterface {
|
||||||
absl::Span<const float> weights) = 0;
|
absl::Span<const float> weights) = 0;
|
||||||
virtual void AddScales(const std::vector<float>& scales) = 0;
|
virtual void AddScales(const std::vector<float>& scales) = 0;
|
||||||
|
|
||||||
|
virtual size_t DebugNumBlobsAdded() const = 0;
|
||||||
|
|
||||||
virtual int Write(std::string path) = 0;
|
virtual int Write(std::string path) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -65,24 +71,39 @@ class SbsWriterImpl : public WriterInterface {
|
||||||
std::string decorated_name = storage.CacheName();
|
std::string decorated_name = storage.CacheName();
|
||||||
compressor_(&storage, decorated_name.c_str(), weights.data());
|
compressor_(&storage, decorated_name.c_str(), weights.data());
|
||||||
}
|
}
|
||||||
|
template <typename Packed>
|
||||||
|
void AllocateWithShape(const std::string& name,
|
||||||
|
absl::Span<const float> weights,
|
||||||
|
const TensorInfo& tensor_info, float scale) {
|
||||||
|
MatPtrT<Packed> storage(name, &tensor_info);
|
||||||
|
storage.set_scale(scale);
|
||||||
|
storage.SetNumElements(CompressedArrayElements<Packed>(weights.size()));
|
||||||
|
model_memory_.push_back(storage);
|
||||||
|
if (mode_ == CompressorMode::kTEST_ONLY) return;
|
||||||
|
model_memory_.back().Allocate();
|
||||||
|
storage.SetPtr(model_memory_.back());
|
||||||
|
std::string decorated_name = storage.CacheName();
|
||||||
|
compressor_(&storage, decorated_name.c_str(), weights.data());
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
SbsWriterImpl() : pool_(0), compressor_(pool_) {}
|
explicit SbsWriterImpl(CompressorMode mode)
|
||||||
|
: pool_(0), compressor_(pool_), mode_(mode) {}
|
||||||
|
|
||||||
void Insert(std::string name, absl::Span<const float> weights,
|
void Insert(std::string name, absl::Span<const float> weights, Type type,
|
||||||
Type type) override {
|
const TensorInfo& tensor_info, float scale) override {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case Type::kSFP:
|
case Type::kSFP:
|
||||||
AllocateAndCompress<SfpStream>(name, weights);
|
AllocateWithShape<SfpStream>(name, weights, tensor_info, scale);
|
||||||
break;
|
break;
|
||||||
case Type::kNUQ:
|
case Type::kNUQ:
|
||||||
AllocateAndCompress<NuqStream>(name, weights);
|
AllocateWithShape<NuqStream>(name, weights, tensor_info, scale);
|
||||||
break;
|
break;
|
||||||
case Type::kBF16:
|
case Type::kBF16:
|
||||||
AllocateAndCompress<BF16>(name, weights);
|
AllocateWithShape<BF16>(name, weights, tensor_info, scale);
|
||||||
break;
|
break;
|
||||||
case Type::kF32:
|
case Type::kF32:
|
||||||
AllocateAndCompress<float>(name, weights);
|
AllocateWithShape<float>(name, weights, tensor_info, scale);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
HWY_ABORT("Unsupported type");
|
HWY_ABORT("Unsupported type");
|
||||||
|
|
@ -112,6 +133,12 @@ class SbsWriterImpl : public WriterInterface {
|
||||||
compressor_.AddScales(scales_.data(), scales_.size());
|
compressor_.AddScales(scales_.data(), scales_.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the number of blobs added.
|
||||||
|
size_t DebugNumBlobsAdded() const {
|
||||||
|
if (mode_ == CompressorMode::kTEST_ONLY) return model_memory_.size();
|
||||||
|
return compressor_.DebugNumBlobsAdded();
|
||||||
|
}
|
||||||
|
|
||||||
int Write(std::string path) override {
|
int Write(std::string path) override {
|
||||||
return compressor_.WriteAll(pool_, gcpp::Path(path));
|
return compressor_.WriteAll(pool_, gcpp::Path(path));
|
||||||
}
|
}
|
||||||
|
|
@ -121,9 +148,12 @@ class SbsWriterImpl : public WriterInterface {
|
||||||
CompressWorkingSet working_set_;
|
CompressWorkingSet working_set_;
|
||||||
std::vector<MatStorage> model_memory_;
|
std::vector<MatStorage> model_memory_;
|
||||||
std::vector<float> scales_;
|
std::vector<float> scales_;
|
||||||
|
CompressorMode mode_;
|
||||||
};
|
};
|
||||||
|
|
||||||
WriterInterface* NewSbsWriter() { return new SbsWriterImpl(); }
|
WriterInterface* NewSbsWriter(CompressorMode mode) {
|
||||||
|
return new SbsWriterImpl(mode);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace HWY_NAMESPACE
|
} // namespace HWY_NAMESPACE
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
@ -134,12 +164,13 @@ namespace gcpp {
|
||||||
|
|
||||||
HWY_EXPORT(NewSbsWriter);
|
HWY_EXPORT(NewSbsWriter);
|
||||||
|
|
||||||
SbsWriter::SbsWriter() : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)()) {}
|
SbsWriter::SbsWriter(CompressorMode mode)
|
||||||
|
: impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)(mode)) {}
|
||||||
SbsWriter::~SbsWriter() = default;
|
SbsWriter::~SbsWriter() = default;
|
||||||
|
|
||||||
void SbsWriter::Insert(std::string name, absl::Span<const float> weights,
|
void SbsWriter::Insert(std::string name, absl::Span<const float> weights,
|
||||||
Type type) {
|
Type type, const TensorInfo& tensor_info, float scale) {
|
||||||
impl_->Insert(name, weights, type);
|
impl_->Insert(name, weights, type, tensor_info, scale);
|
||||||
}
|
}
|
||||||
void SbsWriter::InsertSfp(std::string name, absl::Span<const float> weights) {
|
void SbsWriter::InsertSfp(std::string name, absl::Span<const float> weights) {
|
||||||
impl_->InsertSfp(name, weights);
|
impl_->InsertSfp(name, weights);
|
||||||
|
|
@ -158,6 +189,11 @@ void SbsWriter::InsertFloat(std::string name, absl::Span<const float> weights) {
|
||||||
void SbsWriter::AddScales(const std::vector<float>& scales) {
|
void SbsWriter::AddScales(const std::vector<float>& scales) {
|
||||||
impl_->AddScales(scales);
|
impl_->AddScales(scales);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t SbsWriter::DebugNumBlobsAdded() const {
|
||||||
|
return impl_->DebugNumBlobsAdded();
|
||||||
|
}
|
||||||
|
|
||||||
int SbsWriter::Write(std::string path) { return impl_->Write(path); }
|
int SbsWriter::Write(std::string path) { return impl_->Write(path); }
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
|
|
@ -1,29 +1,44 @@
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
|
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "compression/shared.h"
|
#include "compression/shared.h"
|
||||||
|
#include "gemma/tensor_index.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
// How to process the data.
|
||||||
|
enum class CompressorMode {
|
||||||
|
// No compression, no write to file, just for testing.
|
||||||
|
kTEST_ONLY,
|
||||||
|
// Old-style compression, no table of contents.
|
||||||
|
kNO_TOC,
|
||||||
|
// New-style compression, with table of contents.
|
||||||
|
kWITH_TOC,
|
||||||
|
};
|
||||||
|
|
||||||
class WriterInterface;
|
class WriterInterface;
|
||||||
|
|
||||||
class SbsWriter {
|
class SbsWriter {
|
||||||
public:
|
public:
|
||||||
SbsWriter();
|
explicit SbsWriter(CompressorMode mode);
|
||||||
~SbsWriter();
|
~SbsWriter();
|
||||||
|
|
||||||
void Insert(std::string name, absl::Span<const float> weights, Type type);
|
void Insert(std::string name, absl::Span<const float> weights, Type type,
|
||||||
|
const TensorInfo& tensor_info, float scale);
|
||||||
void InsertSfp(std::string name, absl::Span<const float> weights);
|
void InsertSfp(std::string name, absl::Span<const float> weights);
|
||||||
void InsertNUQ(std::string name, absl::Span<const float> weights);
|
void InsertNUQ(std::string name, absl::Span<const float> weights);
|
||||||
void InsertBfloat16(std::string name, absl::Span<const float> weights);
|
void InsertBfloat16(std::string name, absl::Span<const float> weights);
|
||||||
void InsertFloat(std::string name, absl::Span<const float> weights);
|
void InsertFloat(std::string name, absl::Span<const float> weights);
|
||||||
void AddScales(const std::vector<float>& scales);
|
void AddScales(const std::vector<float>& scales);
|
||||||
|
|
||||||
|
size_t DebugNumBlobsAdded() const;
|
||||||
|
|
||||||
int Write(std::string path);
|
int Write(std::string path);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@
|
||||||
#include "compression/python/compression_clif_aux.h"
|
#include "compression/python/compression_clif_aux.h"
|
||||||
#include "compression/shared.h"
|
#include "compression/shared.h"
|
||||||
|
|
||||||
|
using gcpp::CompressorMode;
|
||||||
using gcpp::SbsWriter;
|
using gcpp::SbsWriter;
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
@ -23,18 +24,24 @@ void wrap_span(SbsWriter& writer, std::string name, py::array_t<float> data) {
|
||||||
}
|
}
|
||||||
template <auto Func>
|
template <auto Func>
|
||||||
void wrap_span_typed(SbsWriter& writer, std::string name,
|
void wrap_span_typed(SbsWriter& writer, std::string name,
|
||||||
py::array_t<float> data, gcpp::Type type) {
|
py::array_t<float> data, gcpp::Type type,
|
||||||
|
gcpp::TensorInfo tensor_info, float scale) {
|
||||||
if (data.ndim() != 1 || data.strides(0) != sizeof(float)) {
|
if (data.ndim() != 1 || data.strides(0) != sizeof(float)) {
|
||||||
throw std::domain_error("Input array must be 1D and densely packed.");
|
throw std::domain_error("Input array must be 1D and densely packed.");
|
||||||
}
|
}
|
||||||
std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()),
|
std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()),
|
||||||
type);
|
type, tensor_info, scale);
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
PYBIND11_MODULE(compression, m) {
|
PYBIND11_MODULE(compression, m) {
|
||||||
|
py::enum_<CompressorMode>(m, "CompressorMode")
|
||||||
|
.value("TEST_ONLY", CompressorMode::kTEST_ONLY)
|
||||||
|
.value("NO_TOC", CompressorMode::kNO_TOC)
|
||||||
|
.value("WITH_TOC", CompressorMode::kWITH_TOC);
|
||||||
|
|
||||||
py::class_<SbsWriter>(m, "SbsWriter")
|
py::class_<SbsWriter>(m, "SbsWriter")
|
||||||
.def(py::init<>())
|
.def(py::init<CompressorMode>())
|
||||||
// NOTE: Individual compression backends may impose constraints on the
|
// NOTE: Individual compression backends may impose constraints on the
|
||||||
// array length, such as a minimum of (say) 32 elements.
|
// array length, such as a minimum of (say) 32 elements.
|
||||||
.def("insert", wrap_span_typed<&SbsWriter::Insert>)
|
.def("insert", wrap_span_typed<&SbsWriter::Insert>)
|
||||||
|
|
@ -43,5 +50,6 @@ PYBIND11_MODULE(compression, m) {
|
||||||
.def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>)
|
.def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>)
|
||||||
.def("insert_float", wrap_span<&SbsWriter::InsertFloat>)
|
.def("insert_float", wrap_span<&SbsWriter::InsertFloat>)
|
||||||
.def("add_scales", &SbsWriter::AddScales)
|
.def("add_scales", &SbsWriter::AddScales)
|
||||||
|
.def("debug_num_blobs_added", &SbsWriter::DebugNumBlobsAdded)
|
||||||
.def("write", &SbsWriter::Write);
|
.def("write", &SbsWriter::Write);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -11,12 +11,18 @@ class CompressionTest(unittest.TestCase):
|
||||||
|
|
||||||
def test_sbs_writer(self):
|
def test_sbs_writer(self):
|
||||||
temp_file = self.create_tempfile("test.sbs")
|
temp_file = self.create_tempfile("test.sbs")
|
||||||
|
tensor_info = configs.TensorInfo()
|
||||||
|
tensor_info.name = "foo"
|
||||||
|
tensor_info.axes = [0]
|
||||||
|
tensor_info.shape = [192]
|
||||||
|
|
||||||
writer = compression.SbsWriter()
|
writer = compression.SbsWriter(compression.CompressorMode.NO_TOC)
|
||||||
writer.insert(
|
writer.insert(
|
||||||
"foo",
|
"foo",
|
||||||
np.array([0.0012] * 128 + [0.001] * 64, dtype=np.float32),
|
np.array([0.0012] * 128 + [0.001] * 64, dtype=np.float32),
|
||||||
configs.Type.kSFP,
|
configs.Type.kSFP,
|
||||||
|
tensor_info,
|
||||||
|
1.0,
|
||||||
)
|
)
|
||||||
writer.insert_sfp(
|
writer.insert_sfp(
|
||||||
"bar", np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32)
|
"bar", np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32)
|
||||||
|
|
@ -30,6 +36,7 @@ class CompressionTest(unittest.TestCase):
|
||||||
writer.insert_float(
|
writer.insert_float(
|
||||||
"quux", np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32)
|
"quux", np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32)
|
||||||
)
|
)
|
||||||
|
self.assertEqual(writer.debug_num_blobs_added(), 5)
|
||||||
self.assertEqual(writer.write(temp_file.full_path), 0)
|
self.assertEqual(writer.write(temp_file.full_path), 0)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -138,8 +138,8 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
|
||||||
TensorInfo{
|
TensorInfo{
|
||||||
.name = "qkv_ein_w",
|
.name = "qkv_ein_w",
|
||||||
.source_names = {"MultiHeadDotProductAttention_0/qkv/kernel"},
|
.source_names = {"MultiHeadDotProductAttention_0/qkv/kernel"},
|
||||||
.axes = {2, 0, 3, 1},
|
.axes = {1, 2, 0},
|
||||||
.shape = {layer_config.heads, 3, layer_config.qkv_dim,
|
.shape = {layer_config.heads, 3 * layer_config.qkv_dim,
|
||||||
config.vit_model_dim},
|
config.vit_model_dim},
|
||||||
.min_size = Type::kBF16,
|
.min_size = Type::kBF16,
|
||||||
},
|
},
|
||||||
|
|
@ -156,7 +156,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
|
||||||
.name = "k_ein_b",
|
.name = "k_ein_b",
|
||||||
.source_names = {"MultiHeadDotProductAttention_0/key/bias"},
|
.source_names = {"MultiHeadDotProductAttention_0/key/bias"},
|
||||||
.axes = {0, 1},
|
.axes = {0, 1},
|
||||||
.shape = {layer_config.heads, layer_config.qkv_dim},
|
.shape = {layer_config.kv_heads, layer_config.qkv_dim},
|
||||||
.concat_names = {""},
|
.concat_names = {""},
|
||||||
.min_size = Type::kF32,
|
.min_size = Type::kF32,
|
||||||
},
|
},
|
||||||
|
|
@ -164,15 +164,16 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
|
||||||
.name = "v_ein_b",
|
.name = "v_ein_b",
|
||||||
.source_names = {"MultiHeadDotProductAttention_0/value/bias"},
|
.source_names = {"MultiHeadDotProductAttention_0/value/bias"},
|
||||||
.axes = {0, 1},
|
.axes = {0, 1},
|
||||||
.shape = {layer_config.heads, layer_config.qkv_dim},
|
.shape = {layer_config.kv_heads, layer_config.qkv_dim},
|
||||||
.concat_names = {""},
|
.concat_names = {""},
|
||||||
.min_size = Type::kF32,
|
.min_size = Type::kF32,
|
||||||
},
|
},
|
||||||
TensorInfo{
|
TensorInfo{
|
||||||
.name = "qkv_ein_b",
|
.name = "qkv_ein_b",
|
||||||
.source_names = {"MultiHeadDotProductAttention_0/qkv/bias"},
|
.source_names = {"MultiHeadDotProductAttention_0/qkv/bias"},
|
||||||
.axes = {1, 0, 2},
|
.axes = {0, 1},
|
||||||
.shape = {layer_config.heads * 3, layer_config.qkv_dim},
|
.shape = {layer_config.heads + layer_config.kv_heads * 2,
|
||||||
|
layer_config.qkv_dim},
|
||||||
.min_size = Type::kF32,
|
.min_size = Type::kF32,
|
||||||
},
|
},
|
||||||
TensorInfo{
|
TensorInfo{
|
||||||
|
|
@ -243,14 +244,15 @@ std::vector<TensorInfo> LLMLayerTensors(const ModelConfig& config,
|
||||||
.name = "qkv1_w",
|
.name = "qkv1_w",
|
||||||
.source_names = {"attn/q_einsum/w"},
|
.source_names = {"attn/q_einsum/w"},
|
||||||
.axes = {0, 2, 1},
|
.axes = {0, 2, 1},
|
||||||
.shape = {layer_config.heads, layer_config.qkv_dim, config.model_dim},
|
.shape = {layer_config.heads * layer_config.qkv_dim,
|
||||||
|
config.model_dim},
|
||||||
.concat_names = {"qkv_ein", "qkv2_w"},
|
.concat_names = {"qkv_ein", "qkv2_w"},
|
||||||
},
|
},
|
||||||
TensorInfo{
|
TensorInfo{
|
||||||
.name = "qkv2_w",
|
.name = "qkv2_w",
|
||||||
.source_names = {"attn/kv_einsum/w"},
|
.source_names = {"attn/kv_einsum/w"},
|
||||||
.axes = {1, 0, 3, 2},
|
.axes = {1, 0, 3, 2},
|
||||||
.shape = {2 * layer_config.kv_heads, layer_config.qkv_dim,
|
.shape = {2 * layer_config.kv_heads * layer_config.qkv_dim,
|
||||||
config.model_dim},
|
config.model_dim},
|
||||||
.concat_names = {""},
|
.concat_names = {""},
|
||||||
},
|
},
|
||||||
|
|
@ -279,8 +281,9 @@ std::vector<TensorInfo> LLMLayerTensors(const ModelConfig& config,
|
||||||
.name = "qkv_ein",
|
.name = "qkv_ein",
|
||||||
.source_names = {"attn/qkv_einsum/w"},
|
.source_names = {"attn/qkv_einsum/w"},
|
||||||
.axes = {1, 0, 3, 2},
|
.axes = {1, 0, 3, 2},
|
||||||
.shape = {(layer_config.heads + 2 * layer_config.kv_heads),
|
.shape = {(layer_config.heads + 2 * layer_config.kv_heads) *
|
||||||
layer_config.qkv_dim, config.model_dim},
|
layer_config.qkv_dim,
|
||||||
|
config.model_dim},
|
||||||
},
|
},
|
||||||
TensorInfo{
|
TensorInfo{
|
||||||
.name = "attn_ob",
|
.name = "attn_ob",
|
||||||
|
|
@ -535,7 +538,8 @@ TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TensorInfo TensorIndex::GetTensorInfo(const std::string& path) const {
|
TensorInfo TensorIndex::TensorInfoFromSourcePath(
|
||||||
|
const std::string& path) const {
|
||||||
for (const auto& tensor : tensors_) {
|
for (const auto& tensor : tensors_) {
|
||||||
for (const auto& source_name : tensor.source_names) {
|
for (const auto& source_name : tensor.source_names) {
|
||||||
auto pos = path.rfind(source_name);
|
auto pos = path.rfind(source_name);
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,17 @@ class TensorIndex {
|
||||||
// or an empty TensorInfo if not found.
|
// or an empty TensorInfo if not found.
|
||||||
// NOTE: that the returned TensorInfo is a copy, so that the source
|
// NOTE: that the returned TensorInfo is a copy, so that the source
|
||||||
// TensorIndex can be destroyed without affecting the returned TensorInfo.
|
// TensorIndex can be destroyed without affecting the returned TensorInfo.
|
||||||
TensorInfo GetTensorInfo(const std::string& path) const;
|
TensorInfo TensorInfoFromSourcePath(const std::string& path) const;
|
||||||
|
|
||||||
|
// Returns the TensorInfo whose name matches the given name,
|
||||||
|
// or an empty TensorInfo if not found.
|
||||||
|
// NOTE: that the returned TensorInfo is a copy, so that the source
|
||||||
|
// TensorIndex can be destroyed without affecting the returned TensorInfo.
|
||||||
|
TensorInfo TensorInfoFromName(const std::string& name) const {
|
||||||
|
const TensorInfo* info = FindName(name);
|
||||||
|
if (info == nullptr) return TensorInfo();
|
||||||
|
return *info;
|
||||||
|
}
|
||||||
|
|
||||||
// Returns the TensorInfo for the given tensor name, for concise construction
|
// Returns the TensorInfo for the given tensor name, for concise construction
|
||||||
// of ModelWeightsPtrs/LayerWeightsPtrs.
|
// of ModelWeightsPtrs/LayerWeightsPtrs.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue