Adds insert_float() to SbsWriter() to store a float array directly.

PiperOrigin-RevId: 673982528
This commit is contained in:
Daniel Keysers 2024-09-12 13:26:45 -07:00 committed by Copybara-Service
parent 13a9f76f64
commit 1c8ddcdffe
4 changed files with 15 additions and 0 deletions

View File

@ -7,6 +7,7 @@ from "third_party/gemma_cpp/compression/python/compression_clif_aux.h":
def `Insert` as insert(self, name: str, weights: NumpyArray<float>) def `Insert` as insert(self, name: str, weights: NumpyArray<float>)
def `InsertNUQ` as insert_nuq(self, name: str, weights: NumpyArray<float>) def `InsertNUQ` as insert_nuq(self, name: str, weights: NumpyArray<float>)
def `InsertBfloat16` as insert_bf16(self, name: str, weights: NumpyArray<float>) def `InsertBfloat16` as insert_bf16(self, name: str, weights: NumpyArray<float>)
def `InsertFloat` as insert_float(self, name: str, weights: NumpyArray<float>)
def `AddScales` as add_scales(self, scales: list<float>) def `AddScales` as add_scales(self, scales: list<float>)

View File

@ -35,6 +35,8 @@ class WriterInterface {
virtual void InsertNUQ(std::string name, absl::Span<const float> weights) = 0; virtual void InsertNUQ(std::string name, absl::Span<const float> weights) = 0;
virtual void InsertBfloat16(std::string name, virtual void InsertBfloat16(std::string name,
absl::Span<const float> weights) = 0; absl::Span<const float> weights) = 0;
virtual void InsertFloat(std::string name,
absl::Span<const float> weights) = 0;
virtual void AddScales(const std::vector<float>& scales) = 0; virtual void AddScales(const std::vector<float>& scales) = 0;
virtual void Write(std::string path) = 0; virtual void Write(std::string path) = 0;
@ -77,6 +79,10 @@ class SbsWriterImpl : public WriterInterface {
bf16_streams_.push_back(AllocateAndCompress<BF16>(name, weights)); bf16_streams_.push_back(AllocateAndCompress<BF16>(name, weights));
} }
void InsertFloat(std::string name, absl::Span<const float> weights) override {
f32_streams_.push_back(AllocateAndCompress<float>(name, weights));
}
void AddScales(const std::vector<float>& scales) override { void AddScales(const std::vector<float>& scales) override {
HWY_ASSERT(scales_.empty()); HWY_ASSERT(scales_.empty());
scales_ = scales; scales_ = scales;
@ -93,6 +99,7 @@ class SbsWriterImpl : public WriterInterface {
std::vector<hwy::AlignedFreeUniquePtr<SfpStream[]>> sfp_streams_; std::vector<hwy::AlignedFreeUniquePtr<SfpStream[]>> sfp_streams_;
std::vector<hwy::AlignedFreeUniquePtr<NuqStream[]>> nuq_streams_; std::vector<hwy::AlignedFreeUniquePtr<NuqStream[]>> nuq_streams_;
std::vector<hwy::AlignedFreeUniquePtr<BF16[]>> bf16_streams_; std::vector<hwy::AlignedFreeUniquePtr<BF16[]>> bf16_streams_;
std::vector<hwy::AlignedFreeUniquePtr<float[]>> f32_streams_;
std::vector<float> scales_; std::vector<float> scales_;
}; };
@ -120,6 +127,9 @@ void SbsWriter::InsertBfloat16(std::string name,
absl::Span<const float> weights) { absl::Span<const float> weights) {
impl_->InsertBfloat16(name, weights); impl_->InsertBfloat16(name, weights);
} }
void SbsWriter::InsertFloat(std::string name, absl::Span<const float> weights) {
impl_->InsertFloat(name, weights);
}
void SbsWriter::AddScales(const std::vector<float>& scales) { void SbsWriter::AddScales(const std::vector<float>& scales) {
impl_->AddScales(scales); impl_->AddScales(scales);

View File

@ -19,6 +19,7 @@ class SbsWriter {
void Insert(std::string name, absl::Span<const float> weights); void Insert(std::string name, absl::Span<const float> weights);
void InsertNUQ(std::string name, absl::Span<const float> weights); void InsertNUQ(std::string name, absl::Span<const float> weights);
void InsertBfloat16(std::string name, absl::Span<const float> weights); void InsertBfloat16(std::string name, absl::Span<const float> weights);
void InsertFloat(std::string name, absl::Span<const float> weights);
void AddScales(const std::vector<float>& scales); void AddScales(const std::vector<float>& scales);
void Write(std::string path); void Write(std::string path);

View File

@ -24,6 +24,9 @@ class CompressionTest(unittest.TestCase):
writer.insert_bf16( writer.insert_bf16(
"qux", np.array([0.000375] * 128 + [0.00007] * 128, dtype=np.float32) "qux", np.array([0.000375] * 128 + [0.00007] * 128, dtype=np.float32)
) )
writer.insert_float(
"quux", np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32)
)
writer.write(temp_file.full_path) writer.write(temp_file.full_path)