mirror of https://github.com/google/gemma.cpp.git
Adds insert_float() to SbsWriter to store a float array directly.
PiperOrigin-RevId: 673982528
This commit is contained in:
parent
13a9f76f64
commit
1c8ddcdffe
|
|
@ -7,6 +7,7 @@ from "third_party/gemma_cpp/compression/python/compression_clif_aux.h":
|
|||
def `Insert` as insert(self, name: str, weights: NumpyArray<float>)
|
||||
def `InsertNUQ` as insert_nuq(self, name: str, weights: NumpyArray<float>)
|
||||
def `InsertBfloat16` as insert_bf16(self, name: str, weights: NumpyArray<float>)
|
||||
def `InsertFloat` as insert_float(self, name: str, weights: NumpyArray<float>)
|
||||
|
||||
def `AddScales` as add_scales(self, scales: list<float>)
|
||||
|
||||
|
|
|
|||
|
|
@ -35,6 +35,8 @@ class WriterInterface {
|
|||
virtual void InsertNUQ(std::string name, absl::Span<const float> weights) = 0;
|
||||
virtual void InsertBfloat16(std::string name,
|
||||
absl::Span<const float> weights) = 0;
|
||||
// Pure virtual; takes a name and a read-only span of float weights.
virtual void InsertFloat(std::string name,
|
||||
absl::Span<const float> weights) = 0;
|
||||
virtual void AddScales(const std::vector<float>& scales) = 0;
|
||||
|
||||
virtual void Write(std::string path) = 0;
|
||||
|
|
@ -77,6 +79,10 @@ class SbsWriterImpl : public WriterInterface {
|
|||
bf16_streams_.push_back(AllocateAndCompress<BF16>(name, weights));
|
||||
}
|
||||
|
||||
// Passes `name` and `weights` to AllocateAndCompress<float> and appends the
// returned stream to f32_streams_.
void InsertFloat(std::string name, absl::Span<const float> weights) override {
|
||||
f32_streams_.push_back(AllocateAndCompress<float>(name, weights));
|
||||
}
|
||||
|
||||
void AddScales(const std::vector<float>& scales) override {
|
||||
HWY_ASSERT(scales_.empty());
|
||||
scales_ = scales;
|
||||
|
|
@ -93,6 +99,7 @@ class SbsWriterImpl : public WriterInterface {
|
|||
std::vector<hwy::AlignedFreeUniquePtr<SfpStream[]>> sfp_streams_;
|
||||
std::vector<hwy::AlignedFreeUniquePtr<NuqStream[]>> nuq_streams_;
|
||||
std::vector<hwy::AlignedFreeUniquePtr<BF16[]>> bf16_streams_;
|
||||
std::vector<hwy::AlignedFreeUniquePtr<float[]>> f32_streams_;
|
||||
std::vector<float> scales_;
|
||||
};
|
||||
|
||||
|
|
@ -120,6 +127,9 @@ void SbsWriter::InsertBfloat16(std::string name,
|
|||
absl::Span<const float> weights) {
|
||||
impl_->InsertBfloat16(name, weights);
|
||||
}
|
||||
// Forwards `name` and `weights` unchanged to impl_->InsertFloat().
void SbsWriter::InsertFloat(std::string name, absl::Span<const float> weights) {
|
||||
impl_->InsertFloat(name, weights);
|
||||
}
|
||||
|
||||
void SbsWriter::AddScales(const std::vector<float>& scales) {
|
||||
impl_->AddScales(scales);
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ class SbsWriter {
|
|||
void Insert(std::string name, absl::Span<const float> weights);
|
||||
void InsertNUQ(std::string name, absl::Span<const float> weights);
|
||||
void InsertBfloat16(std::string name, absl::Span<const float> weights);
|
||||
// Takes a name and a read-only span of float weights.
void InsertFloat(std::string name, absl::Span<const float> weights);
|
||||
void AddScales(const std::vector<float>& scales);
|
||||
|
||||
void Write(std::string path);
|
||||
|
|
|
|||
|
|
@ -24,6 +24,9 @@ class CompressionTest(unittest.TestCase):
|
|||
writer.insert_bf16(
|
||||
"qux", np.array([0.000375] * 128 + [0.00007] * 128, dtype=np.float32)
|
||||
)
|
||||
writer.insert_float(
|
||||
"quux", np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32)
|
||||
)
|
||||
writer.write(temp_file.full_path)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue