Add Py bindings for weight compression

TODO: this uses clif instead of pybind11, and depends on absl.

PiperOrigin-RevId: 649575815
This commit is contained in:
Jan Wassenberg 2024-07-05 01:05:17 -07:00 committed by Copybara-Service
parent 118e802b00
commit 41efec4dba
5 changed files with 240 additions and 0 deletions

42
compression/python/BUILD Normal file
View File

@ -0,0 +1,42 @@
load("//devtools/clif/python:clif_build_rule.bzl", "py_clif_cc")
# [internal] load strict.bzl
package(
default_applicable_licenses = [
"//:license", # Placeholder comment, do not modify
],
default_visibility = ["//visibility:public"],
)
cc_library(
name = "compression_clif_aux",
srcs = ["compression_clif_aux.cc"],
hdrs = ["compression_clif_aux.h"],
deps = [
"//third_party/absl/types:span",
"//compression:compress",
"//compression:io",
"@hwy//:hwy",
"@hwy//:thread_pool",
],
)
py_clif_cc(
name = "compression",
srcs = ["compression.clif"],
deps = [
":compression_clif_aux",
"//third_party/absl/python/numpy:span_clif_lib",
],
)
# py_strict
py_test(
name = "compression_test",
srcs = ["compression_test.py"],
deps = [
":compression",
"//testing/pybase",
"//third_party/py/numpy",
],
)

View File

@ -0,0 +1,13 @@
from "third_party/absl/python/numpy/span.h" import *
from "third_party/gemma_cpp/compression/python/compression_clif_aux.h":
namespace `gcpp`:
class SbsWriter:
# NOTE: Individual compression backends may impose constraints on the
# array length, such as a minimum of (say) 32 elements.
def `Insert` as insert(self, name: str, weights: NumpyArray<float>)
def `InsertNUQ` as insert_nuq(self, name: str, weights: NumpyArray<float>)
def `InsertBfloat16` as insert_bf16(self, name: str, weights: NumpyArray<float>)
def `AddScales` as add_scales(self, scales: list<float>)
def `Write` as write(self, path: str)

View File

@ -0,0 +1,121 @@
#include "compression/python/compression_clif_aux.h"
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"third_party/gemma_cpp/compression/python/compression_clif_aux.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Must come after foreach_target.h to avoid redefinition errors.
#include "compression/compress-inl.h"
#include "hwy/highway.h"
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
// compile pass, whereas we want this defined in the first.
#ifndef GEMMA_ONCE
#define GEMMA_ONCE
#include "compression/io.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
class WriterInterface {
public:
virtual ~WriterInterface() = default;
virtual void Insert(std::string name, absl::Span<const float> weights) = 0;
virtual void InsertNUQ(std::string name, absl::Span<const float> weights) = 0;
virtual void InsertBfloat16(std::string name,
absl::Span<const float> weights) = 0;
virtual void AddScales(const std::vector<float>& scales) = 0;
virtual void Write(std::string path) = 0;
};
} // namespace gcpp
#endif // GEMMA_ONCE
// SIMD code, compiled once per target.
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
class SbsWriterImpl : public WriterInterface {
public:
SbsWriterImpl() : pool_(0), compressor_(pool_) {}
void Insert(std::string name, absl::Span<const float> weights) override {
const size_t out_size = CompressedArraySize<SfpStream>(weights.size());
sfp_streams_.push_back(std::vector<SfpStream>(out_size));
compressor_.Insert<SfpStream>(name.data(), weights.data(), weights.size(),
working_set_, out_size,
sfp_streams_.back().data(), 0, pool_);
}
void InsertNUQ(std::string name, absl::Span<const float> weights) override {
const size_t out_size = CompressedArraySize<NuqStream>(weights.size());
nuq_streams_.push_back(std::vector<NuqStream>(out_size));
compressor_.Insert<NuqStream>(name.data(), weights.data(), weights.size(),
working_set_, out_size,
nuq_streams_.back().data(), 0, pool_);
}
void InsertBfloat16(std::string name,
absl::Span<const float> weights) override {
const size_t out_size =
CompressedArraySize<hwy::bfloat16_t>(weights.size());
bf16_streams_.push_back(std::vector<hwy::bfloat16_t>(out_size));
compressor_.Insert<hwy::bfloat16_t>(name.data(), weights.data(),
weights.size(), working_set_, out_size,
bf16_streams_.back().data(), 0, pool_);
}
void AddScales(const std::vector<float>& scales) override {
compressor_.AddScales(scales.data(), scales.size());
}
void Write(std::string path) override {
compressor_.WriteAll(pool_, gcpp::Path(path));
}
hwy::ThreadPool pool_;
Compressor compressor_;
CompressWorkingSet working_set_;
std::vector<std::vector<SfpStream>> sfp_streams_;
std::vector<std::vector<NuqStream>> nuq_streams_;
std::vector<std::vector<hwy::bfloat16_t>> bf16_streams_;
};
WriterInterface* NewSbsWriter() { return new SbsWriterImpl(); }
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(NewSbsWriter);
SbsWriter::SbsWriter() : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)()) {}
SbsWriter::~SbsWriter() = default;
void SbsWriter::Insert(std::string name, absl::Span<const float> weights) {
impl_->Insert(name, weights);
}
void SbsWriter::InsertNUQ(std::string name, absl::Span<const float> weights) {
impl_->InsertNUQ(name, weights);
}
void SbsWriter::InsertBfloat16(std::string name,
absl::Span<const float> weights) {
impl_->InsertBfloat16(name, weights);
}
void SbsWriter::AddScales(const std::vector<float>& scales) {
impl_->AddScales(scales);
}
void SbsWriter::Write(std::string path) { impl_->Write(path); }
} // namespace gcpp
#endif // HWY_ONCE

View File

@ -0,0 +1,33 @@
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
#include <memory>
#include <string>
#include <vector>
#include "third_party/absl/types/span.h"
namespace gcpp {
class WriterInterface;
class SbsWriter {
public:
SbsWriter();
~SbsWriter();
void Insert(std::string name, absl::Span<const float> weights);
void InsertNUQ(std::string name, absl::Span<const float> weights);
void InsertBfloat16(std::string name, absl::Span<const float> weights);
void AddScales(const std::vector<float>& scales);
void Write(std::string path);
private:
// Isolates Highway-dispatched types and other internals from CLIF.
std::unique_ptr<WriterInterface> impl_;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_

View File

@ -0,0 +1,31 @@
"""Tests for CLIF wrapped .sbs writer."""
import numpy as np
import unittest
from compression.python import compression
class CompressionTest(unittest.TestCase):
def test_sbs_writer(self):
temp_file = self.create_tempfile("test.sbs")
writer = compression.SbsWriter()
writer.insert(
"foo", np.array([0.0012] * 128 + [0.001] * 64, dtype=np.float32)
)
writer.insert(
"bar", np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32)
)
writer.insert_nuq(
"baz", np.array([0.000125] * 128 + [0.00008] * 128, dtype=np.float32)
)
writer.insert_bf16(
"qux", np.array([0.000375] * 128 + [0.00007] * 128, dtype=np.float32)
)
writer.write(temp_file.full_path)
if __name__ == "__main__":
unittest.main()