mirror of https://github.com/google/gemma.cpp.git
Replace CLIF SbsWriter with pybind-based gcpp extension
Maintains compatibility with previous version. PiperOrigin-RevId: 696181603
This commit is contained in:
parent
719699f132
commit
5674c33dc5
|
|
@ -9,6 +9,7 @@ bazel_dep(name = "googletest", version = "1.15.2")
|
|||
bazel_dep(name = "highway", version = "1.1.0")
|
||||
bazel_dep(name = "nlohmann_json", version = "3.11.3")
|
||||
bazel_dep(name = "platforms", version = "0.0.10")
|
||||
bazel_dep(name = "pybind11_bazel", version = "2.12.0")
|
||||
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||
bazel_dep(name = "rules_license", version = "0.0.7")
|
||||
bazel_dep(name = "google_benchmark", version = "1.8.5")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
load("//devtools/clif/python:clif_build_rule.bzl", "py_clif_cc")
|
||||
# [internal] load strict.bzl
|
||||
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [
|
||||
|
|
@ -12,8 +12,9 @@ cc_library(
|
|||
name = "compression_clif_aux",
|
||||
srcs = ["compression_clif_aux.cc"],
|
||||
hdrs = ["compression_clif_aux.h"],
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
"//third_party/absl/types:span",
|
||||
"@abseil-cpp//absl/types:span",
|
||||
"//compression:compress",
|
||||
"//compression:io",
|
||||
"@highway//:hwy",
|
||||
|
|
@ -21,12 +22,12 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
py_clif_cc(
|
||||
pybind_extension(
|
||||
name = "compression",
|
||||
srcs = ["compression.clif"],
|
||||
srcs = ["compression_extension.cc"],
|
||||
deps = [
|
||||
":compression_clif_aux",
|
||||
"//third_party/absl/python/numpy:span_clif_lib",
|
||||
"@abseil-cpp//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
from "third_party/absl/python/numpy/span.h" import *
|
||||
from "third_party/gemma_cpp/compression/python/compression_clif_aux.h":
|
||||
namespace `gcpp`:
|
||||
class SbsWriter:
|
||||
# NOTE: Individual compression backends may impose constraints on the
|
||||
# array length, such as a minimum of (say) 32 elements.
|
||||
def `Insert` as insert(self, name: str, weights: NumpyArray<float>)
|
||||
def `InsertNUQ` as insert_nuq(self, name: str, weights: NumpyArray<float>)
|
||||
def `InsertBfloat16` as insert_bf16(self, name: str, weights: NumpyArray<float>)
|
||||
def `InsertFloat` as insert_float(self, name: str, weights: NumpyArray<float>)
|
||||
|
||||
def `AddScales` as add_scales(self, scales: list<float>)
|
||||
|
||||
def `Write` as write(self, path: str)
|
||||
|
|
@ -20,7 +20,7 @@
|
|||
#ifndef GEMMA_ONCE
|
||||
#define GEMMA_ONCE
|
||||
|
||||
#include "third_party/absl/types/span.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "compression/io.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/absl/types/span.h"
|
||||
#include "absl/types/span.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,38 @@
|
|||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include <exception>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "compression/python/compression_clif_aux.h"
|
||||
#include "pybind11/numpy.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
using gcpp::SbsWriter;
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace {
|
||||
template <auto Func>
|
||||
void wrap_span(SbsWriter& writer, std::string name, py::array_t<float> data) {
|
||||
if (data.ndim() != 1 || data.strides(0) != sizeof(float)) {
|
||||
throw std::domain_error("Input array must be 1D and densely packed.");
|
||||
}
|
||||
std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Defines the Python extension module `compression`, exposing
// gcpp::SbsWriter as the Python class `SbsWriter`.
//
// Argument names are registered explicitly so callers may pass them as
// keywords (name=..., weights=..., scales=..., path=...), matching the
// parameter names the earlier CLIF wrapper declared. Positional calls are
// unaffected.
PYBIND11_MODULE(compression, m) {
  py::class_<SbsWriter>(m, "SbsWriter")
      .def(py::init<>())
      // NOTE: Individual compression backends may impose constraints on the
      // array length, such as a minimum of (say) 32 elements.
      .def("insert", wrap_span<&SbsWriter::Insert>, py::arg("name"),
           py::arg("weights"))
      .def("insert_nuq", wrap_span<&SbsWriter::InsertNUQ>, py::arg("name"),
           py::arg("weights"))
      .def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>,
           py::arg("name"), py::arg("weights"))
      .def("insert_float", wrap_span<&SbsWriter::InsertFloat>,
           py::arg("name"), py::arg("weights"))
      .def("add_scales", &SbsWriter::AddScales, py::arg("scales"))
      .def("write", &SbsWriter::Write, py::arg("path"));
}
|
||||
Loading…
Reference in New Issue