mirror of https://github.com/google/gemma.cpp.git
Replace CLIF SbsWriter with pybind-based gcpp extension
Maintains compatibility with previous version. PiperOrigin-RevId: 696181603
This commit is contained in:
parent
719699f132
commit
5674c33dc5
|
|
@ -9,6 +9,7 @@ bazel_dep(name = "googletest", version = "1.15.2")
|
||||||
bazel_dep(name = "highway", version = "1.1.0")
|
bazel_dep(name = "highway", version = "1.1.0")
|
||||||
bazel_dep(name = "nlohmann_json", version = "3.11.3")
|
bazel_dep(name = "nlohmann_json", version = "3.11.3")
|
||||||
bazel_dep(name = "platforms", version = "0.0.10")
|
bazel_dep(name = "platforms", version = "0.0.10")
|
||||||
|
bazel_dep(name = "pybind11_bazel", version = "2.12.0")
|
||||||
bazel_dep(name = "rules_cc", version = "0.0.9")
|
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||||
bazel_dep(name = "rules_license", version = "0.0.7")
|
bazel_dep(name = "rules_license", version = "0.0.7")
|
||||||
bazel_dep(name = "google_benchmark", version = "1.8.5")
|
bazel_dep(name = "google_benchmark", version = "1.8.5")
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
load("//devtools/clif/python:clif_build_rule.bzl", "py_clif_cc")
|
|
||||||
# [internal] load strict.bzl
|
# [internal] load strict.bzl
|
||||||
|
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_applicable_licenses = [
|
default_applicable_licenses = [
|
||||||
|
|
@ -12,8 +12,9 @@ cc_library(
|
||||||
name = "compression_clif_aux",
|
name = "compression_clif_aux",
|
||||||
srcs = ["compression_clif_aux.cc"],
|
srcs = ["compression_clif_aux.cc"],
|
||||||
hdrs = ["compression_clif_aux.h"],
|
hdrs = ["compression_clif_aux.h"],
|
||||||
|
visibility = ["//visibility:private"],
|
||||||
deps = [
|
deps = [
|
||||||
"//third_party/absl/types:span",
|
"@abseil-cpp//absl/types:span",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
|
|
@ -21,12 +22,12 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_clif_cc(
|
pybind_extension(
|
||||||
name = "compression",
|
name = "compression",
|
||||||
srcs = ["compression.clif"],
|
srcs = ["compression_extension.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":compression_clif_aux",
|
":compression_clif_aux",
|
||||||
"//third_party/absl/python/numpy:span_clif_lib",
|
"@abseil-cpp//absl/types:span",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,14 +0,0 @@
|
||||||
from "third_party/absl/python/numpy/span.h" import *
|
|
||||||
from "third_party/gemma_cpp/compression/python/compression_clif_aux.h":
|
|
||||||
namespace `gcpp`:
|
|
||||||
class SbsWriter:
|
|
||||||
# NOTE: Individual compression backends may impose constraints on the
|
|
||||||
# array length, such as a minimum of (say) 32 elements.
|
|
||||||
def `Insert` as insert(self, name: str, weights: NumpyArray<float>)
|
|
||||||
def `InsertNUQ` as insert_nuq(self, name: str, weights: NumpyArray<float>)
|
|
||||||
def `InsertBfloat16` as insert_bf16(self, name: str, weights: NumpyArray<float>)
|
|
||||||
def `InsertFloat` as insert_float(self, name: str, weights: NumpyArray<float>)
|
|
||||||
|
|
||||||
def `AddScales` as add_scales(self, scales: list<float>)
|
|
||||||
|
|
||||||
def `Write` as write(self, path: str)
|
|
||||||
|
|
@ -20,7 +20,7 @@
|
||||||
#ifndef GEMMA_ONCE
|
#ifndef GEMMA_ONCE
|
||||||
#define GEMMA_ONCE
|
#define GEMMA_ONCE
|
||||||
|
|
||||||
#include "third_party/absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "compression/io.h"
|
#include "compression/io.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "third_party/absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,38 @@
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
|
||||||
|
#include <exception>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/types/span.h"
|
||||||
|
#include "compression/python/compression_clif_aux.h"
|
||||||
|
#include "pybind11/numpy.h"
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
|
#include "pybind11/stl.h"
|
||||||
|
|
||||||
|
using gcpp::SbsWriter;
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <auto Func>
|
||||||
|
void wrap_span(SbsWriter& writer, std::string name, py::array_t<float> data) {
|
||||||
|
if (data.ndim() != 1 || data.strides(0) != sizeof(float)) {
|
||||||
|
throw std::domain_error("Input array must be 1D and densely packed.");
|
||||||
|
}
|
||||||
|
std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()));
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
PYBIND11_MODULE(compression, m) {
|
||||||
|
py::class_<SbsWriter>(m, "SbsWriter")
|
||||||
|
.def(py::init<>())
|
||||||
|
// NOTE: Individual compression backends may impose constraints on the
|
||||||
|
// array length, such as a minimum of (say) 32 elements.
|
||||||
|
.def("insert", wrap_span<&SbsWriter::Insert>)
|
||||||
|
.def("insert_nuq", wrap_span<&SbsWriter::InsertNUQ>)
|
||||||
|
.def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>)
|
||||||
|
.def("insert_float", wrap_span<&SbsWriter::InsertFloat>)
|
||||||
|
.def("add_scales", &SbsWriter::AddScales)
|
||||||
|
.def("write", &SbsWriter::Write);
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue