mirror of https://github.com/google/gemma.cpp.git
Add Py bindings for weight compression
TODO: this uses clif instead of pybind11, and depends on absl. PiperOrigin-RevId: 649575815
This commit is contained in:
parent
118e802b00
commit
41efec4dba
|
|
@ -0,0 +1,42 @@
|
||||||
|
load("//devtools/clif/python:clif_build_rule.bzl", "py_clif_cc")
|
||||||
|
# [internal] load strict.bzl
|
||||||
|
|
||||||
|
# Package-wide defaults: every target declared in this BUILD file carries the
# repository license and is visible to all other packages.
package(
    default_applicable_licenses = [
        "//:license",  # Placeholder comment, do not modify
    ],
    default_visibility = ["//visibility:public"],
)
|
||||||
|
|
||||||
|
# C++ helper library that the CLIF wrapper binds against. It wraps the
# Highway-dispatched compressor behind the plain `gcpp::SbsWriter` class.
cc_library(
    name = "compression_clif_aux",
    srcs = ["compression_clif_aux.cc"],
    hdrs = ["compression_clif_aux.h"],
    deps = [
        "//third_party/absl/types:span",
        "//compression:compress",
        "//compression:io",
        "@hwy//:hwy",
        "@hwy//:thread_pool",
    ],
)
|
||||||
|
|
||||||
|
# Python extension generated by CLIF from compression.clif. The numpy span
# library supplies the conversion used for the `NumpyArray<float>` arguments.
py_clif_cc(
    name = "compression",
    srcs = ["compression.clif"],
    deps = [
        ":compression_clif_aux",
        "//third_party/absl/python/numpy:span_clif_lib",
    ],
)
|
||||||
|
|
||||||
|
# py_strict
|
||||||
|
# Smoke test that drives the generated `compression` Python module.
py_test(
    name = "compression_test",
    srcs = ["compression_test.py"],
    deps = [
        ":compression",
        "//testing/pybase",
        "//third_party/py/numpy",
    ],
)
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
from "third_party/absl/python/numpy/span.h" import *
|
||||||
|
# Exposes gcpp::SbsWriter to Python with snake_case method names.
from "third_party/gemma_cpp/compression/python/compression_clif_aux.h":
  namespace `gcpp`:
    class SbsWriter:
      # NOTE: Individual compression backends may impose constraints on the
      # array length, such as a minimum of (say) 32 elements.

      # Each insert variant takes a tensor name and a float numpy array; the
      # variant selects which C++ entry point (and thus encoding) is used.
      def `Insert` as insert(self, name: str, weights: NumpyArray<float>)
      def `InsertNUQ` as insert_nuq(self, name: str, weights: NumpyArray<float>)
      def `InsertBfloat16` as insert_bf16(self, name: str, weights: NumpyArray<float>)

      # Forwards a list of scale factors to the C++ writer.
      def `AddScales` as add_scales(self, scales: list<float>)

      # Writes everything inserted so far to the file at `path`.
      def `Write` as write(self, path: str)
|
||||||
|
|
@ -0,0 +1,121 @@
|
||||||
|
#include "compression/python/compression_clif_aux.h"
|
||||||
|
|
||||||
|
#undef HWY_TARGET_INCLUDE
|
||||||
|
#define HWY_TARGET_INCLUDE \
|
||||||
|
"third_party/gemma_cpp/compression/python/compression_clif_aux.cc" // NOLINT
|
||||||
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
|
// Must come after foreach_target.h to avoid redefinition errors.
|
||||||
|
#include "compression/compress-inl.h"
|
||||||
|
#include "hwy/highway.h"
|
||||||
|
|
||||||
|
// Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
|
||||||
|
// compile pass, whereas we want this defined in the first.
|
||||||
|
#ifndef GEMMA_ONCE
|
||||||
|
#define GEMMA_ONCE
|
||||||
|
|
||||||
|
#include "compression/io.h"
|
||||||
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
class WriterInterface {
|
||||||
|
public:
|
||||||
|
virtual ~WriterInterface() = default;
|
||||||
|
|
||||||
|
virtual void Insert(std::string name, absl::Span<const float> weights) = 0;
|
||||||
|
virtual void InsertNUQ(std::string name, absl::Span<const float> weights) = 0;
|
||||||
|
virtual void InsertBfloat16(std::string name,
|
||||||
|
absl::Span<const float> weights) = 0;
|
||||||
|
virtual void AddScales(const std::vector<float>& scales) = 0;
|
||||||
|
|
||||||
|
virtual void Write(std::string path) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
#endif // GEMMA_ONCE
|
||||||
|
|
||||||
|
// SIMD code, compiled once per target.
|
||||||
|
HWY_BEFORE_NAMESPACE();
|
||||||
|
namespace gcpp {
|
||||||
|
namespace HWY_NAMESPACE {
|
||||||
|
|
||||||
|
class SbsWriterImpl : public WriterInterface {
|
||||||
|
public:
|
||||||
|
SbsWriterImpl() : pool_(0), compressor_(pool_) {}
|
||||||
|
|
||||||
|
void Insert(std::string name, absl::Span<const float> weights) override {
|
||||||
|
const size_t out_size = CompressedArraySize<SfpStream>(weights.size());
|
||||||
|
sfp_streams_.push_back(std::vector<SfpStream>(out_size));
|
||||||
|
compressor_.Insert<SfpStream>(name.data(), weights.data(), weights.size(),
|
||||||
|
working_set_, out_size,
|
||||||
|
sfp_streams_.back().data(), 0, pool_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void InsertNUQ(std::string name, absl::Span<const float> weights) override {
|
||||||
|
const size_t out_size = CompressedArraySize<NuqStream>(weights.size());
|
||||||
|
nuq_streams_.push_back(std::vector<NuqStream>(out_size));
|
||||||
|
compressor_.Insert<NuqStream>(name.data(), weights.data(), weights.size(),
|
||||||
|
working_set_, out_size,
|
||||||
|
nuq_streams_.back().data(), 0, pool_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void InsertBfloat16(std::string name,
|
||||||
|
absl::Span<const float> weights) override {
|
||||||
|
const size_t out_size =
|
||||||
|
CompressedArraySize<hwy::bfloat16_t>(weights.size());
|
||||||
|
bf16_streams_.push_back(std::vector<hwy::bfloat16_t>(out_size));
|
||||||
|
compressor_.Insert<hwy::bfloat16_t>(name.data(), weights.data(),
|
||||||
|
weights.size(), working_set_, out_size,
|
||||||
|
bf16_streams_.back().data(), 0, pool_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AddScales(const std::vector<float>& scales) override {
|
||||||
|
compressor_.AddScales(scales.data(), scales.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
void Write(std::string path) override {
|
||||||
|
compressor_.WriteAll(pool_, gcpp::Path(path));
|
||||||
|
}
|
||||||
|
|
||||||
|
hwy::ThreadPool pool_;
|
||||||
|
Compressor compressor_;
|
||||||
|
CompressWorkingSet working_set_;
|
||||||
|
std::vector<std::vector<SfpStream>> sfp_streams_;
|
||||||
|
std::vector<std::vector<NuqStream>> nuq_streams_;
|
||||||
|
std::vector<std::vector<hwy::bfloat16_t>> bf16_streams_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Factory compiled once per Highway target; returns a heap-allocated
// SbsWriterImpl of that target. The caller owns the result (SbsWriter stores
// it in a std::unique_ptr).
WriterInterface* NewSbsWriter() { return new SbsWriterImpl(); }
|
||||||
|
|
||||||
|
} // namespace HWY_NAMESPACE
|
||||||
|
} // namespace gcpp
|
||||||
|
HWY_AFTER_NAMESPACE();
|
||||||
|
|
||||||
|
#if HWY_ONCE
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
HWY_EXPORT(NewSbsWriter);
|
||||||
|
|
||||||
|
// Creates the implementation via Highway's dynamic dispatch of NewSbsWriter
// and takes ownership of the returned pointer.
SbsWriter::SbsWriter() : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)()) {}
|
||||||
|
// Defined here rather than in the header: the header only forward-declares
// WriterInterface, and std::unique_ptr needs the complete type to delete it.
SbsWriter::~SbsWriter() = default;
|
||||||
|
|
||||||
|
// Forwards to the dispatched implementation, which stores `weights` as
// SfpStream under `name`.
void SbsWriter::Insert(std::string name, absl::Span<const float> weights) {
  impl_->Insert(name, weights);
}
|
||||||
|
// Forwards to the dispatched implementation, which stores `weights` as
// NuqStream under `name`.
void SbsWriter::InsertNUQ(std::string name, absl::Span<const float> weights) {
  impl_->InsertNUQ(name, weights);
}
|
||||||
|
void SbsWriter::InsertBfloat16(std::string name,
|
||||||
|
absl::Span<const float> weights) {
|
||||||
|
impl_->InsertBfloat16(name, weights);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forwards the scale factors to the dispatched implementation.
void SbsWriter::AddScales(const std::vector<float>& scales) {
  impl_->AddScales(scales);
}
|
||||||
|
// Forwards to the dispatched implementation, which writes all inserted data
// to the file at `path`.
void SbsWriter::Write(std::string path) { impl_->Write(path); }
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
#endif // HWY_ONCE
|
||||||
|
|
@ -0,0 +1,33 @@
|
||||||
|
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
|
||||||
|
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "third_party/absl/types/span.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
class WriterInterface;
|
||||||
|
|
||||||
|
class SbsWriter {
|
||||||
|
public:
|
||||||
|
SbsWriter();
|
||||||
|
~SbsWriter();
|
||||||
|
|
||||||
|
void Insert(std::string name, absl::Span<const float> weights);
|
||||||
|
void InsertNUQ(std::string name, absl::Span<const float> weights);
|
||||||
|
void InsertBfloat16(std::string name, absl::Span<const float> weights);
|
||||||
|
void AddScales(const std::vector<float>& scales);
|
||||||
|
|
||||||
|
void Write(std::string path);
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Isolates Highway-dispatched types and other internals from CLIF.
|
||||||
|
std::unique_ptr<WriterInterface> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
"""Tests for CLIF wrapped .sbs writer."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from compression.python import compression
|
||||||
|
|
||||||
|
|
||||||
|
class CompressionTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_sbs_writer(self):
|
||||||
|
temp_file = self.create_tempfile("test.sbs")
|
||||||
|
|
||||||
|
writer = compression.SbsWriter()
|
||||||
|
writer.insert(
|
||||||
|
"foo", np.array([0.0012] * 128 + [0.001] * 64, dtype=np.float32)
|
||||||
|
)
|
||||||
|
writer.insert(
|
||||||
|
"bar", np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32)
|
||||||
|
)
|
||||||
|
writer.insert_nuq(
|
||||||
|
"baz", np.array([0.000125] * 128 + [0.00008] * 128, dtype=np.float32)
|
||||||
|
)
|
||||||
|
writer.insert_bf16(
|
||||||
|
"qux", np.array([0.000375] * 128 + [0.00007] * 128, dtype=np.float32)
|
||||||
|
)
|
||||||
|
writer.write(temp_file.full_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Loading…
Reference in New Issue