From 41efec4dba849ffecc6fe60123d8e84c60d6d6d4 Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Fri, 5 Jul 2024 01:05:17 -0700
Subject: [PATCH] Add Py bindings for weight compression

TODO: this uses clif instead of pybind11, and depends on absl.

PiperOrigin-RevId: 649575815
---
 compression/python/BUILD                   |  42 +++++++
 compression/python/compression.clif        |  13 +++
 compression/python/compression_clif_aux.cc | 128 +++++++++++++++++++++
 compression/python/compression_clif_aux.h  |  37 ++++++
 compression/python/compression_test.py     |  38 ++++++
 5 files changed, 258 insertions(+)
 create mode 100644 compression/python/BUILD
 create mode 100644 compression/python/compression.clif
 create mode 100644 compression/python/compression_clif_aux.cc
 create mode 100644 compression/python/compression_clif_aux.h
 create mode 100644 compression/python/compression_test.py

diff --git a/compression/python/BUILD b/compression/python/BUILD
new file mode 100644
index 0000000..bdf67f9
--- /dev/null
+++ b/compression/python/BUILD
@@ -0,0 +1,42 @@
+load("//devtools/clif/python:clif_build_rule.bzl", "py_clif_cc")
+# [internal] load strict.bzl
+
+package(
+    default_applicable_licenses = [
+        "//:license",  # Placeholder comment, do not modify
+    ],
+    default_visibility = ["//visibility:public"],
+)
+
+cc_library(
+    name = "compression_clif_aux",
+    srcs = ["compression_clif_aux.cc"],
+    hdrs = ["compression_clif_aux.h"],
+    deps = [
+        "//third_party/absl/types:span",
+        "//compression:compress",
+        "//compression:io",
+        "@hwy//:hwy",
+        "@hwy//:thread_pool",
+    ],
+)
+
+py_clif_cc(
+    name = "compression",
+    srcs = ["compression.clif"],
+    deps = [
+        ":compression_clif_aux",
+        "//third_party/absl/python/numpy:span_clif_lib",
+    ],
+)
+
+# py_strict
+py_test(
+    name = "compression_test",
+    srcs = ["compression_test.py"],
+    deps = [
+        ":compression",
+        "//testing/pybase",
+        "//third_party/py/numpy",
+    ],
+)
diff --git a/compression/python/compression.clif b/compression/python/compression.clif
new file mode 100644
index 0000000..ea31862
--- /dev/null
+++ b/compression/python/compression.clif
@@ -0,0 +1,13 @@
+from "third_party/absl/python/numpy/span.h" import *
+from "third_party/gemma_cpp/compression/python/compression_clif_aux.h":
+  namespace `gcpp`:
+    class SbsWriter:
+      # NOTE: Individual compression backends may impose constraints on the
+      # array length, such as a minimum of (say) 32 elements.
+      def `Insert` as insert(self, name: str, weights: NumpyArray<float>)
+      def `InsertNUQ` as insert_nuq(self, name: str, weights: NumpyArray<float>)
+      def `InsertBfloat16` as insert_bf16(self, name: str, weights: NumpyArray<float>)
+
+      def `AddScales` as add_scales(self, scales: list<float>)
+
+      def `Write` as write(self, path: str)
diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc
new file mode 100644
index 0000000..c017e64
--- /dev/null
+++ b/compression/python/compression_clif_aux.cc
@@ -0,0 +1,128 @@
+#include "compression/python/compression_clif_aux.h"
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE \
+  "third_party/gemma_cpp/compression/python/compression_clif_aux.cc"  // NOLINT
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+// Must come after foreach_target.h to avoid redefinition errors.
+#include "compression/compress-inl.h"
+#include "hwy/highway.h"
+
+// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
+// compile pass, whereas we want this defined in the first.
+#ifndef GEMMA_ONCE
+#define GEMMA_ONCE
+
+#include "compression/io.h"
+#include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+
+namespace gcpp {
+
+// Target-independent interface implemented once per SIMD target by
+// HWY_NAMESPACE::SbsWriterImpl; SbsWriter forwards every call through it.
+class WriterInterface {
+ public:
+  virtual ~WriterInterface() = default;
+
+  virtual void Insert(std::string name, absl::Span<const float> weights) = 0;
+  virtual void InsertNUQ(std::string name, absl::Span<const float> weights) = 0;
+  virtual void InsertBfloat16(std::string name,
+                              absl::Span<const float> weights) = 0;
+  virtual void AddScales(const std::vector<float>& scales) = 0;
+
+  virtual void Write(std::string path) = 0;
+};
+
+}  // namespace gcpp
+
+#endif  // GEMMA_ONCE
+
+// SIMD code, compiled once per target.
+HWY_BEFORE_NAMESPACE();
+namespace gcpp {
+namespace HWY_NAMESPACE {
+
+// Per-target implementation. Each Insert* call appends one output buffer to
+// the matching *_streams_ vector and hands compressor_.Insert a raw pointer
+// into it; buffers are never freed before the writer is destroyed.
+// NOTE(review): presumably WriteAll still reads those buffers - confirm.
+class SbsWriterImpl : public WriterInterface {
+ public:
+  SbsWriterImpl() : pool_(0), compressor_(pool_) {}
+
+  void Insert(std::string name, absl::Span<const float> weights) override {
+    const size_t out_size = CompressedArraySize<SfpStream>(weights.size());
+    sfp_streams_.push_back(std::vector<SfpStream>(out_size));
+    compressor_.Insert<SfpStream>(name.data(), weights.data(), weights.size(),
+                                  working_set_, out_size,
+                                  sfp_streams_.back().data(), 0, pool_);
+  }
+
+  void InsertNUQ(std::string name, absl::Span<const float> weights) override {
+    const size_t out_size = CompressedArraySize<NuqStream>(weights.size());
+    nuq_streams_.push_back(std::vector<NuqStream>(out_size));
+    compressor_.Insert<NuqStream>(name.data(), weights.data(), weights.size(),
+                                  working_set_, out_size,
+                                  nuq_streams_.back().data(), 0, pool_);
+  }
+
+  void InsertBfloat16(std::string name,
+                      absl::Span<const float> weights) override {
+    const size_t out_size =
+        CompressedArraySize<hwy::bfloat16_t>(weights.size());
+    bf16_streams_.push_back(std::vector<hwy::bfloat16_t>(out_size));
+    compressor_.Insert<hwy::bfloat16_t>(name.data(), weights.data(),
+                                        weights.size(), working_set_, out_size,
+                                        bf16_streams_.back().data(), 0, pool_);
+  }
+
+  void AddScales(const std::vector<float>& scales) override {
+    compressor_.AddScales(scales.data(), scales.size());
+  }
+
+  void Write(std::string path) override {
+    compressor_.WriteAll(pool_, gcpp::Path(path));
+  }
+
+  hwy::ThreadPool pool_;
+  Compressor compressor_;
+  CompressWorkingSet working_set_;
+  std::vector<std::vector<SfpStream>> sfp_streams_;
+  std::vector<std::vector<NuqStream>> nuq_streams_;
+  std::vector<std::vector<hwy::bfloat16_t>> bf16_streams_;
+};
+
+// Caller takes ownership; SbsWriter stores the result in a unique_ptr.
+WriterInterface* NewSbsWriter() { return new SbsWriterImpl(); }
+
+}  // namespace HWY_NAMESPACE
+}  // namespace gcpp
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace gcpp {
+
+HWY_EXPORT(NewSbsWriter);
+
+SbsWriter::SbsWriter() : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)()) {}
+SbsWriter::~SbsWriter() = default;
+
+void SbsWriter::Insert(std::string name, absl::Span<const float> weights) {
+  impl_->Insert(name, weights);
+}
+void SbsWriter::InsertNUQ(std::string name, absl::Span<const float> weights) {
+  impl_->InsertNUQ(name, weights);
+}
+void SbsWriter::InsertBfloat16(std::string name,
+                               absl::Span<const float> weights) {
+  impl_->InsertBfloat16(name, weights);
+}
+
+void SbsWriter::AddScales(const std::vector<float>& scales) {
+  impl_->AddScales(scales);
+}
+void SbsWriter::Write(std::string path) { impl_->Write(path); }
+
+}  // namespace gcpp
+#endif  // HWY_ONCE
diff --git a/compression/python/compression_clif_aux.h b/compression/python/compression_clif_aux.h
new file mode 100644
index 0000000..34332c3
--- /dev/null
+++ b/compression/python/compression_clif_aux.h
@@ -0,0 +1,37 @@
+#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
+#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "third_party/absl/types/span.h"
+
+namespace gcpp {
+
+class WriterInterface;
+
+// Python-facing (CLIF) wrapper that compresses named float arrays and writes
+// them to a file. Insert() encodes as SfpStream, InsertNUQ() as NuqStream and
+// InsertBfloat16() as hwy::bfloat16_t; Write() passes `path` to
+// Compressor::WriteAll.
+class SbsWriter {
+ public:
+  SbsWriter();
+  ~SbsWriter();
+
+  void Insert(std::string name, absl::Span<const float> weights);
+  void InsertNUQ(std::string name, absl::Span<const float> weights);
+  void InsertBfloat16(std::string name, absl::Span<const float> weights);
+  void AddScales(const std::vector<float>& scales);
+
+  void Write(std::string path);
+
+ private:
+  // Isolates Highway-dispatched types and other internals from CLIF.
+  std::unique_ptr<WriterInterface> impl_;
+};
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
diff --git a/compression/python/compression_test.py b/compression/python/compression_test.py
new file mode 100644
index 0000000..98f6c0d
--- /dev/null
+++ b/compression/python/compression_test.py
@@ -0,0 +1,38 @@
+"""Tests for CLIF wrapped .sbs writer."""
+
+import os
+import tempfile
+import unittest
+
+import numpy as np
+
+from compression.python import compression
+
+
+class CompressionTest(unittest.TestCase):
+
+  def test_sbs_writer(self):
+    # unittest.TestCase has no create_tempfile() (that is an absltest helper),
+    # so build the output path with the standard library instead.
+    temp_dir = tempfile.TemporaryDirectory()
+    self.addCleanup(temp_dir.cleanup)
+    sbs_path = os.path.join(temp_dir.name, "test.sbs")
+
+    writer = compression.SbsWriter()
+    writer.insert(
+        "foo", np.array([0.0012] * 128 + [0.001] * 64, dtype=np.float32)
+    )
+    writer.insert(
+        "bar", np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32)
+    )
+    writer.insert_nuq(
+        "baz", np.array([0.000125] * 128 + [0.00008] * 128, dtype=np.float32)
+    )
+    writer.insert_bf16(
+        "qux", np.array([0.000375] * 128 + [0.00007] * 128, dtype=np.float32)
+    )
+    writer.write(sbs_path)
+
+
+if __name__ == "__main__":
+  unittest.main()