Adds insert_float() to SbsWriter() to store a float array directly.

PiperOrigin-RevId: 673982528
This commit is contained in:
Daniel Keysers 2024-09-12 13:26:45 -07:00 committed by Copybara-Service
parent 13a9f76f64
commit 1c8ddcdffe
4 changed files with 15 additions and 0 deletions

View File

@ -7,6 +7,7 @@ from "third_party/gemma_cpp/compression/python/compression_clif_aux.h":
def `Insert` as insert(self, name: str, weights: NumpyArray<float>)
def `InsertNUQ` as insert_nuq(self, name: str, weights: NumpyArray<float>)
def `InsertBfloat16` as insert_bf16(self, name: str, weights: NumpyArray<float>)
def `InsertFloat` as insert_float(self, name: str, weights: NumpyArray<float>)
def `AddScales` as add_scales(self, scales: list<float>)

View File

@ -35,6 +35,8 @@ class WriterInterface {
virtual void InsertNUQ(std::string name, absl::Span<const float> weights) = 0;
virtual void InsertBfloat16(std::string name,
absl::Span<const float> weights) = 0;
virtual void InsertFloat(std::string name,
absl::Span<const float> weights) = 0;
virtual void AddScales(const std::vector<float>& scales) = 0;
virtual void Write(std::string path) = 0;
@ -77,6 +79,10 @@ class SbsWriterImpl : public WriterInterface {
bf16_streams_.push_back(AllocateAndCompress<BF16>(name, weights));
}
void InsertFloat(std::string name, absl::Span<const float> weights) override {
f32_streams_.push_back(AllocateAndCompress<float>(name, weights));
}
void AddScales(const std::vector<float>& scales) override {
HWY_ASSERT(scales_.empty());
scales_ = scales;
@ -93,6 +99,7 @@ class SbsWriterImpl : public WriterInterface {
std::vector<hwy::AlignedFreeUniquePtr<SfpStream[]>> sfp_streams_;
std::vector<hwy::AlignedFreeUniquePtr<NuqStream[]>> nuq_streams_;
std::vector<hwy::AlignedFreeUniquePtr<BF16[]>> bf16_streams_;
std::vector<hwy::AlignedFreeUniquePtr<float[]>> f32_streams_;
std::vector<float> scales_;
};
@ -120,6 +127,9 @@ void SbsWriter::InsertBfloat16(std::string name,
absl::Span<const float> weights) {
impl_->InsertBfloat16(name, weights);
}
void SbsWriter::InsertFloat(std::string name, absl::Span<const float> weights) {
impl_->InsertFloat(name, weights);
}
void SbsWriter::AddScales(const std::vector<float>& scales) {
impl_->AddScales(scales);

View File

@ -19,6 +19,7 @@ class SbsWriter {
void Insert(std::string name, absl::Span<const float> weights);
void InsertNUQ(std::string name, absl::Span<const float> weights);
void InsertBfloat16(std::string name, absl::Span<const float> weights);
void InsertFloat(std::string name, absl::Span<const float> weights);
void AddScales(const std::vector<float>& scales);
void Write(std::string path);

View File

@ -24,6 +24,9 @@ class CompressionTest(unittest.TestCase):
writer.insert_bf16(
"qux", np.array([0.000375] * 128 + [0.00007] * 128, dtype=np.float32)
)
writer.insert_float(
"quux", np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32)
)
writer.write(temp_file.full_path)