diff --git a/compression/python/compression.clif b/compression/python/compression.clif index ea31862..69dfc9b 100644 --- a/compression/python/compression.clif +++ b/compression/python/compression.clif @@ -7,6 +7,7 @@ from "third_party/gemma_cpp/compression/python/compression_clif_aux.h": def `Insert` as insert(self, name: str, weights: NumpyArray) def `InsertNUQ` as insert_nuq(self, name: str, weights: NumpyArray) def `InsertBfloat16` as insert_bf16(self, name: str, weights: NumpyArray) + def `InsertFloat` as insert_float(self, name: str, weights: NumpyArray) def `AddScales` as add_scales(self, scales: list) diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc index 1be4ed2..a6dcd11 100644 --- a/compression/python/compression_clif_aux.cc +++ b/compression/python/compression_clif_aux.cc @@ -35,6 +35,8 @@ class WriterInterface { virtual void InsertNUQ(std::string name, absl::Span weights) = 0; virtual void InsertBfloat16(std::string name, absl::Span weights) = 0; + virtual void InsertFloat(std::string name, + absl::Span weights) = 0; virtual void AddScales(const std::vector& scales) = 0; virtual void Write(std::string path) = 0; @@ -77,6 +79,10 @@ class SbsWriterImpl : public WriterInterface { bf16_streams_.push_back(AllocateAndCompress(name, weights)); } + void InsertFloat(std::string name, absl::Span weights) override { + f32_streams_.push_back(AllocateAndCompress(name, weights)); + } + void AddScales(const std::vector& scales) override { HWY_ASSERT(scales_.empty()); scales_ = scales; @@ -93,6 +99,7 @@ class SbsWriterImpl : public WriterInterface { std::vector> sfp_streams_; std::vector> nuq_streams_; std::vector> bf16_streams_; + std::vector> f32_streams_; std::vector scales_; }; @@ -120,6 +127,9 @@ void SbsWriter::InsertBfloat16(std::string name, absl::Span weights) { impl_->InsertBfloat16(name, weights); } +void SbsWriter::InsertFloat(std::string name, absl::Span weights) { + impl_->InsertFloat(name, weights); +} void SbsWriter::AddScales(const std::vector& scales) { impl_->AddScales(scales); diff --git a/compression/python/compression_clif_aux.h b/compression/python/compression_clif_aux.h index 34332c3..8dc7a9d 100644 --- a/compression/python/compression_clif_aux.h +++ b/compression/python/compression_clif_aux.h @@ -19,6 +19,7 @@ class SbsWriter { void Insert(std::string name, absl::Span weights); void InsertNUQ(std::string name, absl::Span weights); void InsertBfloat16(std::string name, absl::Span weights); + void InsertFloat(std::string name, absl::Span weights); void AddScales(const std::vector& scales); void Write(std::string path); diff --git a/compression/python/compression_test.py b/compression/python/compression_test.py index 98f6c0d..7b7ff12 100644 --- a/compression/python/compression_test.py +++ b/compression/python/compression_test.py @@ -24,6 +24,9 @@ class CompressionTest(unittest.TestCase): writer.insert_bf16( "qux", np.array([0.000375] * 128 + [0.00007] * 128, dtype=np.float32) ) + writer.insert_float( + "quux", np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32) + ) writer.write(temp_file.full_path)