Replace CLIF SbsWriter with pybind-based gcpp extension

Maintains compatibility with previous version.

PiperOrigin-RevId: 696181603
This commit is contained in:
Paul Chang 2024-11-13 10:18:11 -08:00 committed by Copybara-Service
parent 719699f132
commit 5674c33dc5
6 changed files with 47 additions and 21 deletions

View File

@ -9,6 +9,7 @@ bazel_dep(name = "googletest", version = "1.15.2")
bazel_dep(name = "highway", version = "1.1.0")
bazel_dep(name = "nlohmann_json", version = "3.11.3")
bazel_dep(name = "platforms", version = "0.0.10")
bazel_dep(name = "pybind11_bazel", version = "2.12.0")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "rules_license", version = "0.0.7")
bazel_dep(name = "google_benchmark", version = "1.8.5")

View File

@ -1,5 +1,5 @@
load("//devtools/clif/python:clif_build_rule.bzl", "py_clif_cc")
# [internal] load strict.bzl
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
package(
default_applicable_licenses = [
@ -12,8 +12,9 @@ cc_library(
name = "compression_clif_aux",
srcs = ["compression_clif_aux.cc"],
hdrs = ["compression_clif_aux.h"],
visibility = ["//visibility:private"],
deps = [
"//third_party/absl/types:span",
"@abseil-cpp//absl/types:span",
"//compression:compress",
"//compression:io",
"@highway//:hwy",
@ -21,12 +22,12 @@ cc_library(
],
)
py_clif_cc(
pybind_extension(
name = "compression",
srcs = ["compression.clif"],
srcs = ["compression_extension.cc"],
deps = [
":compression_clif_aux",
"//third_party/absl/python/numpy:span_clif_lib",
"@abseil-cpp//absl/types:span",
],
)

View File

@ -1,14 +0,0 @@
from "third_party/absl/python/numpy/span.h" import *
from "third_party/gemma_cpp/compression/python/compression_clif_aux.h":
namespace `gcpp`:
class SbsWriter:
# NOTE: Individual compression backends may impose constraints on the
# array length, such as a minimum of (say) 32 elements.
def `Insert` as insert(self, name: str, weights: NumpyArray<float>)
def `InsertNUQ` as insert_nuq(self, name: str, weights: NumpyArray<float>)
def `InsertBfloat16` as insert_bf16(self, name: str, weights: NumpyArray<float>)
def `InsertFloat` as insert_float(self, name: str, weights: NumpyArray<float>)
def `AddScales` as add_scales(self, scales: list<float>)
def `Write` as write(self, path: str)

View File

@ -20,7 +20,7 @@
#ifndef GEMMA_ONCE
#define GEMMA_ONCE
#include "third_party/absl/types/span.h"
#include "absl/types/span.h"
#include "compression/io.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

View File

@ -5,7 +5,7 @@
#include <string>
#include <vector>
#include "third_party/absl/types/span.h"
#include "absl/types/span.h"
namespace gcpp {

View File

@ -0,0 +1,38 @@
#include <pybind11/pybind11.h>
#include <exception>
#include <functional>
#include <stdexcept>
#include <string>
#include "absl/types/span.h"
#include "compression/python/compression_clif_aux.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
using gcpp::SbsWriter;
namespace py = pybind11;
namespace {

// Adapts an SbsWriter member function taking (name, span of floats) so that it
// can be bound with a NumPy array argument.
//
// Func is called as std::invoke(Func, writer, name, span), where the span
// views the array's own buffer starting at element 0 and covering
// data.size() elements.
//
// Throws std::domain_error if `data` is not one-dimensional, or if it holds
// more than one element and consecutive elements are not exactly
// sizeof(float) bytes apart.
template <auto Func>
void wrap_span(SbsWriter& writer, std::string name, py::array_t<float> data) {
  // Check the rank first: strides(0) is only meaningful for a 1D array.
  if (data.ndim() != 1) {
    throw std::domain_error("Input array must be 1D and densely packed.");
  }
  // A stride only matters when there is a second element to step to. Arrays
  // with at most one element are trivially dense, whatever stride NumPy
  // reports for them (e.g. a length-1 slice of a strided view).
  const bool has_gaps =
      data.size() > 1 &&
      data.strides(0) != static_cast<py::ssize_t>(sizeof(float));
  if (has_gaps) {
    throw std::domain_error("Input array must be 1D and densely packed.");
  }
  std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()));
}

}  // namespace
// Python module `compression`: exposes gcpp::SbsWriter.
//
// Method and parameter names mirror the earlier CLIF interface
// (insert/insert_nuq/insert_bf16/insert_float take `name` and `weights`,
// add_scales takes `scales`, write takes `path`). The py::arg annotations
// allow callers to pass these arguments by keyword as well as by position.
PYBIND11_MODULE(compression, m) {
  py::class_<SbsWriter>(m, "SbsWriter")
      .def(py::init<>())
      // NOTE: Individual compression backends may impose constraints on the
      // array length, such as a minimum of (say) 32 elements.
      .def("insert", wrap_span<&SbsWriter::Insert>, py::arg("name"),
           py::arg("weights"))
      .def("insert_nuq", wrap_span<&SbsWriter::InsertNUQ>, py::arg("name"),
           py::arg("weights"))
      .def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>,
           py::arg("name"), py::arg("weights"))
      .def("insert_float", wrap_span<&SbsWriter::InsertFloat>,
           py::arg("name"), py::arg("weights"))
      .def("add_scales", &SbsWriter::AddScales, py::arg("scales"))
      .def("write", &SbsWriter::Write, py::arg("path"));
}