mirror of https://github.com/google/gemma.cpp.git
Add python wrappers for configs and inference.
Enable building compression/python/compression_test using bazel. Add default image path for image_test and paligemma_test. PiperOrigin-RevId: 720583438
This commit is contained in:
parent
bcdb0d65bd
commit
7af2e70321
30
MODULE.bazel
30
MODULE.bazel
|
|
@ -4,14 +4,15 @@ module(
|
|||
)
|
||||
|
||||
bazel_dep(name = "abseil-cpp", version = "20240722.0")
|
||||
bazel_dep(name = "bazel_skylib", version = "1.6.1")
|
||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||
bazel_dep(name = "googletest", version = "1.15.2")
|
||||
bazel_dep(name = "highway", version = "1.1.0")
|
||||
bazel_dep(name = "nlohmann_json", version = "3.11.3")
|
||||
bazel_dep(name = "platforms", version = "0.0.10")
|
||||
bazel_dep(name = "pybind11_bazel", version = "2.12.0")
|
||||
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||
bazel_dep(name = "rules_license", version = "0.0.7")
|
||||
bazel_dep(name = "rules_cc", version = "0.0.16")
|
||||
bazel_dep(name = "rules_license", version = "1.0.0")
|
||||
bazel_dep(name = "rules_python", version = "1.0.0")
|
||||
bazel_dep(name = "google_benchmark", version = "1.8.5")
|
||||
|
||||
# Require a more recent version.
|
||||
|
|
@ -23,6 +24,15 @@ git_override(
|
|||
|
||||
http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
|
||||
http_archive(
|
||||
name = "com_google_absl_py",
|
||||
sha256 = "8a3d0830e4eb4f66c4fa907c06edf6ce1c719ced811a12e26d9d3162f8471758",
|
||||
strip_prefix = "abseil-py-2.1.0",
|
||||
urls = [
|
||||
"https://github.com/abseil/abseil-py/archive/refs/tags/v2.1.0.tar.gz",
|
||||
],
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "com_google_sentencepiece",
|
||||
build_file = "@//bazel:sentencepiece.bazel",
|
||||
|
|
@ -53,3 +63,17 @@ cc_library(
|
|||
"https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip",
|
||||
],
|
||||
)
|
||||
|
||||
pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip")
|
||||
pip.parse(
|
||||
hub_name = "compression_deps",
|
||||
python_version = "3.11",
|
||||
requirements_lock = "//compression/python:requirements.txt",
|
||||
)
|
||||
use_repo(pip, "compression_deps")
|
||||
pip.parse(
|
||||
hub_name = "python_deps",
|
||||
python_version = "3.11",
|
||||
requirements_lock = "//python:requirements.txt",
|
||||
)
|
||||
use_repo(pip, "python_deps")
|
||||
|
|
|
|||
|
|
@ -1,3 +1,18 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
|
|
|
|||
|
|
@ -41,8 +41,8 @@ py_test(
|
|||
srcs = ["compression_test.py"],
|
||||
deps = [
|
||||
":compression",
|
||||
"//testing/pybase",
|
||||
"@com_google_absl_py//absl/testing:absltest",
|
||||
"//python:configs",
|
||||
"//third_party/py/numpy",
|
||||
"@compression_deps//numpy",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,18 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "compression/python/compression_clif_aux.h"
|
||||
|
||||
#include <cstddef>
|
||||
|
|
|
|||
|
|
@ -1,3 +1,18 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,18 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
|
|
|||
|
|
@ -1,13 +1,28 @@
|
|||
# Copyright 2024 Google LLC
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for CLIF wrapped .sbs writer."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import unittest
|
||||
from absl.testing import absltest
|
||||
from compression.python import compression
|
||||
from gemma.python import configs
|
||||
from python import configs
|
||||
|
||||
|
||||
class CompressionTest(unittest.TestCase):
|
||||
class CompressionTest(absltest.TestCase):
|
||||
|
||||
def test_sbs_writer(self):
|
||||
temp_file = self.create_tempfile("test.sbs")
|
||||
|
|
@ -41,4 +56,4 @@ class CompressionTest(unittest.TestCase):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
absltest.main()
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
numpy>=1.26.4
|
||||
|
|
@ -29,8 +29,7 @@ float Normalize(float value, float max_value = 255.0f) {
|
|||
}
|
||||
|
||||
TEST(ImageTest, LoadResize224GetPatch) {
|
||||
return; // Need to figure out how to get the external path for the test file.
|
||||
std::string path;
|
||||
std::string path = "paligemma/testdata/image.ppm";
|
||||
Image image;
|
||||
EXPECT_EQ(image.width(), 0);
|
||||
EXPECT_EQ(image.height(), 0);
|
||||
|
|
|
|||
|
|
@ -93,8 +93,7 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
|
|||
|
||||
void PaliGemmaTest::TestQuestions(const char* kQA[][2], size_t num_questions) {
|
||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||
return; // Need to figure out how to get the external path for the test file.
|
||||
std::string path;
|
||||
std::string path = "paligemma/testdata/image.ppm";
|
||||
InitVit(path);
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
fprintf(stderr, "Question %zu\n\n", i + 1);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,43 @@
|
|||
# [internal] load py_binary
|
||||
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [
|
||||
"//:license", # Placeholder comment, do not modify
|
||||
],
|
||||
default_visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "configs",
|
||||
srcs = ["configs.cc"],
|
||||
deps = [
|
||||
"//:common",
|
||||
"//compression:sfp",
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "gemma",
|
||||
srcs = ["gemma_py.cc"],
|
||||
deps = [
|
||||
"//:app",
|
||||
"//:benchmark_helper",
|
||||
"//:gemma_lib",
|
||||
"//compression:sfp",
|
||||
"@highway//:hwy",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "run_example",
|
||||
srcs = ["run_example.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
":gemma",
|
||||
"@python_deps//absl_py",
|
||||
# placeholder for absl/flags
|
||||
"@compression_deps//numpy",
|
||||
],
|
||||
)
|
||||
|
|
@ -0,0 +1,184 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "gemma/configs.h"
|
||||
|
||||
#include <pybind11/cast.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/tensor_index.h"
|
||||
|
||||
using gcpp::ActivationType;
|
||||
using gcpp::LayerAttentionType;
|
||||
using gcpp::LayerConfig;
|
||||
using gcpp::Model;
|
||||
using gcpp::ModelConfig;
|
||||
using gcpp::PostNormType;
|
||||
using gcpp::PostQKType;
|
||||
using gcpp::PromptWrapping;
|
||||
using gcpp::QueryScaleType;
|
||||
using gcpp::ResidualType;
|
||||
using gcpp::TensorIndex;
|
||||
using gcpp::TensorInfo;
|
||||
using gcpp::Type;
|
||||
using gcpp::VitConfig;
|
||||
|
||||
namespace pybind11 {
|
||||
|
||||
PYBIND11_MODULE(configs, py_module) {
|
||||
enum_<PromptWrapping>(py_module, "PromptWrapping")
|
||||
.value("GEMMA_IT", PromptWrapping::GEMMA_IT)
|
||||
.value("GEMMA_PT", PromptWrapping::GEMMA_PT)
|
||||
.value("PALIGEMMA", PromptWrapping::PALIGEMMA);
|
||||
|
||||
enum_<Type>(py_module, "Type")
|
||||
.value("kUnknown", Type::kUnknown)
|
||||
.value("kF32", Type::kF32)
|
||||
.value("kBF16", Type::kBF16)
|
||||
.value("kSFP", Type::kSFP)
|
||||
.value("kNUQ", Type::kNUQ)
|
||||
.value("kF64", Type::kF64)
|
||||
.value("kC64", Type::kC64)
|
||||
.value("kU128", Type::kU128);
|
||||
|
||||
enum_<LayerAttentionType>(py_module, "LayerAttentionType")
|
||||
.value("kGemma", LayerAttentionType::kGemma)
|
||||
.value("kGriffinRecurrentBlock",
|
||||
LayerAttentionType::kGriffinRecurrentBlock)
|
||||
.value("kVit", LayerAttentionType::kVit);
|
||||
|
||||
enum_<PostNormType>(py_module, "PostNormType")
|
||||
.value("NoPostNorm", PostNormType::None)
|
||||
.value("Scale", PostNormType::Scale);
|
||||
|
||||
enum_<PostQKType>(py_module, "PostQKType")
|
||||
.value("Rope", PostQKType::Rope)
|
||||
.value("HalfRope", PostQKType::HalfRope);
|
||||
|
||||
enum_<ActivationType>(py_module, "ActivationType")
|
||||
.value("Gelu", ActivationType::Gelu);
|
||||
|
||||
enum_<QueryScaleType>(py_module, "QueryScaleType")
|
||||
.value("SqrtKeySize", QueryScaleType::SqrtKeySize)
|
||||
.value("SqrtModelDimDivNumHeads",
|
||||
QueryScaleType::SqrtModelDimDivNumHeads);
|
||||
|
||||
enum_<ResidualType>(py_module, "ResidualType")
|
||||
.value("Add", ResidualType::Add);
|
||||
|
||||
enum_<Model>(py_module, "Model")
|
||||
.value("UNKNOWN", Model::UNKNOWN)
|
||||
.value("GEMMA_2B", Model::GEMMA_2B)
|
||||
.value("GEMMA_7B", Model::GEMMA_7B)
|
||||
.value("GEMMA2_9B", Model::GEMMA2_9B)
|
||||
.value("GEMMA2_27B", Model::GEMMA2_27B)
|
||||
.value("GRIFFIN_2B", Model::GRIFFIN_2B)
|
||||
.value("GEMMA_TINY", Model::GEMMA_TINY)
|
||||
.value("GEMMA2_2B", Model::GEMMA2_2B)
|
||||
.value("PALIGEMMA2_3B_224", Model::PALIGEMMA2_3B_224)
|
||||
.value("PALIGEMMA2_10B_224", Model::PALIGEMMA2_10B_224)
|
||||
.value("PALIGEMMA2_3B_448", Model::PALIGEMMA2_3B_448)
|
||||
.value("PALIGEMMA2_10B_448", Model::PALIGEMMA2_10B_448)
|
||||
.value("PALIGEMMA_224", Model::PALIGEMMA_224)
|
||||
.value("PALIGEMMA_448", Model::PALIGEMMA_448);
|
||||
|
||||
class_<TensorInfo>(py_module, "TensorInfo")
|
||||
.def(init())
|
||||
.def_readwrite("name", &TensorInfo::name)
|
||||
.def_readwrite("source_names", &TensorInfo::source_names)
|
||||
.def_readwrite("preshape", &TensorInfo::preshape)
|
||||
.def_readwrite("axes", &TensorInfo::axes)
|
||||
.def_readwrite("shape", &TensorInfo::shape)
|
||||
.def_readwrite("concat_names", &TensorInfo::concat_names)
|
||||
.def_readwrite("concat_axis", &TensorInfo::concat_axis)
|
||||
.def_readwrite("min_size", &TensorInfo::min_size)
|
||||
.def_readwrite("scaled_softplus", &TensorInfo::scaled_softplus)
|
||||
.def_readwrite("cols_take_extra_dims", &TensorInfo::cols_take_extra_dims);
|
||||
|
||||
class_<TensorIndex>(py_module, "TensorIndex")
|
||||
.def(init<const ModelConfig&, int, int, bool>())
|
||||
.def("tensor_info_from_source_path",
|
||||
&TensorIndex::TensorInfoFromSourcePath, arg("path"))
|
||||
.def("tensor_info_from_name", &TensorIndex::TensorInfoFromName,
|
||||
arg("name"));
|
||||
|
||||
class_<LayerConfig>(py_module, "LayerConfig")
|
||||
.def(init())
|
||||
.def_readwrite("model_dim", &LayerConfig::model_dim)
|
||||
.def_readwrite("griffin_dim", &LayerConfig::griffin_dim)
|
||||
.def_readwrite("ff_hidden_dim", &LayerConfig::ff_hidden_dim)
|
||||
.def_readwrite("heads", &LayerConfig::heads)
|
||||
.def_readwrite("kv_heads", &LayerConfig::kv_heads)
|
||||
.def_readwrite("qkv_dim", &LayerConfig::qkv_dim)
|
||||
.def_readwrite("conv1d_width", &LayerConfig::conv1d_width)
|
||||
.def_readwrite("ff_biases", &LayerConfig::ff_biases)
|
||||
.def_readwrite("softmax_attn_output_biases",
|
||||
&LayerConfig::softmax_attn_output_biases)
|
||||
.def_readwrite("optimized_gating", &LayerConfig::optimized_gating)
|
||||
.def_readwrite("post_norm", &LayerConfig::post_norm)
|
||||
.def_readwrite("type", &LayerConfig::type)
|
||||
.def_readwrite("activation", &LayerConfig::activation)
|
||||
.def_readwrite("post_qk", &LayerConfig::post_qk);
|
||||
|
||||
class_<VitConfig>(py_module, "VitConfig")
|
||||
.def(init())
|
||||
.def_readwrite("model_dim", &VitConfig::model_dim)
|
||||
.def_readwrite("seq_len", &VitConfig::seq_len)
|
||||
.def_readwrite("num_scales", &VitConfig::num_scales)
|
||||
.def_readwrite("patch_width", &VitConfig::patch_width)
|
||||
.def_readwrite("image_size", &VitConfig::image_size)
|
||||
.def_readwrite("layer_configs", &VitConfig::layer_configs);
|
||||
|
||||
class_<ModelConfig>(py_module, "ModelConfig")
|
||||
.def(init())
|
||||
.def_readwrite("model_family_version", &ModelConfig::model_family_version)
|
||||
.def_readwrite("model_name", &ModelConfig::model_name)
|
||||
.def_readwrite("model", &ModelConfig::model)
|
||||
.def_readwrite("wrapping", &ModelConfig::wrapping)
|
||||
.def_readwrite("weight", &ModelConfig::weight)
|
||||
.def_readwrite("num_layers", &ModelConfig::num_layers)
|
||||
.def_readwrite("model_dim", &ModelConfig::model_dim)
|
||||
.def_readwrite("vocab_size", &ModelConfig::vocab_size)
|
||||
.def_readwrite("seq_len", &ModelConfig::seq_len)
|
||||
.def_readwrite("num_tensor_scales", &ModelConfig::num_tensor_scales)
|
||||
.def_readwrite("att_cap", &ModelConfig::att_cap)
|
||||
.def_readwrite("final_cap", &ModelConfig::final_cap)
|
||||
.def_readwrite("absolute_pe", &ModelConfig::absolute_pe)
|
||||
.def_readwrite("use_local_attention", &ModelConfig::use_local_attention)
|
||||
.def_readwrite("query_scale", &ModelConfig::query_scale)
|
||||
.def_readwrite("layer_configs", &ModelConfig::layer_configs)
|
||||
.def_readwrite("attention_window_sizes",
|
||||
&ModelConfig::attention_window_sizes)
|
||||
.def_readwrite("scale_names", &ModelConfig::scale_names)
|
||||
.def_readwrite("norm_num_groups", &ModelConfig::norm_num_groups)
|
||||
.def_readwrite("vit_config", &ModelConfig::vit_config)
|
||||
.def("add_layer_config", &ModelConfig::AddLayerConfig,
|
||||
arg("layer_config"))
|
||||
.def("test_equal", &ModelConfig::TestEqual, arg("other"), arg("partial"),
|
||||
arg("debug"));
|
||||
|
||||
// Returns the config for the given model.
|
||||
py_module.def("config_from_model", &gcpp::ConfigFromModel, arg("model"));
|
||||
|
||||
// Returns the model for the given config, if it matches any standard model.
|
||||
py_module.def("model_from_config", &gcpp::ModelFromConfig, arg("config"));
|
||||
|
||||
// Returns the sub-config for the ViT model of the PaliGemma model.
|
||||
py_module.def("vit_config", &gcpp::GetVitConfig, arg("config"));
|
||||
}
|
||||
|
||||
} // namespace pybind11
|
||||
|
|
@ -0,0 +1,303 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <Python.h>
|
||||
#include <pybind11/cast.h>
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/shared.h"
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "util/app.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
static void RemoveTrailingZeros(std::vector<int> &vec) {
|
||||
auto it =
|
||||
std::find_if(vec.rbegin(), vec.rend(), [](int v) { return v != 0; });
|
||||
vec.erase(it.base(), vec.end());
|
||||
}
|
||||
|
||||
// Wrapper around GemmaEnv to expose to Python.
|
||||
class GemmaModel {
|
||||
public:
|
||||
GemmaModel(const gcpp::LoaderArgs& loader,
|
||||
const gcpp::InferenceArgs& inference, const gcpp::AppArgs& app)
|
||||
: gemma_(loader, inference, app), last_prob_(0.0f) {}
|
||||
|
||||
// Generates a single example, given a prompt and a callback to stream the
|
||||
// generated tokens.
|
||||
void GenerateEx(std::string prompt, gcpp::StreamFunc stream,
|
||||
size_t max_generated_tokens, float temperature, float seed,
|
||||
gcpp::AcceptFunc accept, bool skip_prompt) {
|
||||
gemma_.MutableGen().seed(seed);
|
||||
std::vector<int> prompt_tokens = gemma_.WrapAndTokenize(prompt);
|
||||
gcpp::RuntimeConfig& config = gemma_.MutableConfig();
|
||||
config.max_generated_tokens = max_generated_tokens;
|
||||
config.temperature = temperature;
|
||||
config.verbosity = 0;
|
||||
config.accept_token = accept;
|
||||
// If skip_prompt is true, we skip the prompt tokens and only stream the
|
||||
// generated tokens.
|
||||
int count_down = prompt_tokens.size();
|
||||
auto stream_with_skipping = [&stream, &count_down](int token, float score) {
|
||||
if (count_down > 0) {
|
||||
count_down--;
|
||||
return true;
|
||||
}
|
||||
return stream(token, score);
|
||||
};
|
||||
gemma_.QueryModel(prompt_tokens,
|
||||
skip_prompt ? stream_with_skipping : stream);
|
||||
}
|
||||
|
||||
// Generates a single example, given a prompt, and returns the result.
|
||||
std::string Generate(std::string prompt, size_t max_generated_tokens,
|
||||
float temperature, float seed,
|
||||
const std::vector<std::string>& accept,
|
||||
const std::vector<std::string>& end) {
|
||||
std::set<int> end_token_set{};
|
||||
for (const std::string& end_token : end) {
|
||||
std::vector<int> end_token_ids = gemma_.Tokenize(end_token);
|
||||
end_token_set.insert(end_token_ids.begin(), end_token_ids.end());
|
||||
}
|
||||
|
||||
std::vector<int> predicted_token_ids;
|
||||
predicted_token_ids.reserve(max_generated_tokens);
|
||||
std::vector<int> prompt_token_ids = gemma_.WrapAndTokenize(prompt);
|
||||
int generated = 0;
|
||||
auto stream_token = [&generated, &prompt_token_ids, &predicted_token_ids,
|
||||
&end_token_set, this](int token, float proba) {
|
||||
++generated;
|
||||
if (generated > prompt_token_ids.size()) {
|
||||
predicted_token_ids.push_back(token);
|
||||
if (!end_token_set.empty()) {
|
||||
return end_token_set.find(token) == end_token_set.end();
|
||||
}
|
||||
}
|
||||
last_prob_ = proba;
|
||||
return true;
|
||||
};
|
||||
|
||||
std::set<int> accept_token_set{};
|
||||
for (const std::string& accept_token : accept) {
|
||||
std::vector<int> accept_token_ids = gemma_.Tokenize(accept_token);
|
||||
accept_token_set.insert(accept_token_ids.begin(), accept_token_ids.end());
|
||||
}
|
||||
|
||||
auto accept_token = [&predicted_token_ids, &prompt_token_ids,
|
||||
&accept_token_set](int token, float) {
|
||||
// i.e. we have no constraints on accepted tokens
|
||||
if (accept_token_set.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (predicted_token_ids.size() >= prompt_token_ids.size()) {
|
||||
return accept_token_set.find(token) != accept_token_set.end();
|
||||
} else {
|
||||
// auto-accept prompt tokens
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
gemma_.MutableGen().seed(seed);
|
||||
gcpp::RuntimeConfig& config = gemma_.MutableConfig();
|
||||
config.max_generated_tokens = max_generated_tokens;
|
||||
config.temperature = temperature;
|
||||
config.verbosity = 0;
|
||||
config.accept_token = accept_token;
|
||||
|
||||
gemma_.QueryModel(prompt_token_ids, stream_token);
|
||||
|
||||
if (!predicted_token_ids.empty()) {
|
||||
return gemma_.StringFromTokens(predicted_token_ids);
|
||||
} else {
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
// Generates a batch of examples, given a list of prompts, and returns the
|
||||
// results.
|
||||
std::vector<std::string> GenerateBatch(const std::vector<std::string>& inputs,
|
||||
size_t max_generated_tokens,
|
||||
float temperature, float seed,
|
||||
size_t top_k) {
|
||||
gcpp::RuntimeConfig& config = gemma_.MutableConfig();
|
||||
config.max_generated_tokens = max_generated_tokens;
|
||||
config.temperature = temperature;
|
||||
config.top_k = top_k;
|
||||
config.verbosity = 0;
|
||||
gemma_.MutableGen().seed(seed);
|
||||
|
||||
std::vector<gcpp::QueryResult> outputs = gemma_.BatchQueryModel(inputs);
|
||||
std::vector<std::string> result;
|
||||
result.reserve(outputs.size());
|
||||
for (const gcpp::QueryResult& output : outputs) {
|
||||
result.push_back(output.response.substr(output.response_start_pos));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// For a PaliGemma model, sets the image to run on. Subsequent calls to
|
||||
// Generate* will use this image. Throws an error for other models.
|
||||
void SetImage(const py::array_t<float, py::array::c_style |
|
||||
py::array::forcecast>& image) {
|
||||
gcpp::Gemma& model = *(gemma_.GetModel());
|
||||
if (model.Info().wrapping != gcpp::PromptWrapping::PALIGEMMA) {
|
||||
throw std::invalid_argument("Not a PaliGemma model.");
|
||||
}
|
||||
py::buffer_info buffer = image.request();
|
||||
if (buffer.ndim != 3 || buffer.shape[2] != 3)
|
||||
throw std::runtime_error(
|
||||
"Expected a 3D numpy array with shape (height, width, 3)");
|
||||
int height = buffer.shape[0];
|
||||
int width = buffer.shape[1];
|
||||
float* ptr = static_cast<float*>(buffer.ptr);
|
||||
gcpp::Image c_image;
|
||||
c_image.Set(height, width, ptr);
|
||||
const size_t image_size = model.GetModelConfig().vit_config.image_size;
|
||||
c_image.Resize(image_size, image_size);
|
||||
image_tokens_ = gcpp::ImageTokens(gcpp::Extents2D(
|
||||
model.GetModelConfig().vit_config.seq_len,
|
||||
model.GetModelConfig().model_dim));
|
||||
gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
|
||||
.verbosity = 0};
|
||||
model.GenerateImageTokens(runtime_config, c_image, image_tokens_);
|
||||
}
|
||||
|
||||
// Generates a response to the given prompt, using the last set image.
|
||||
// Uses the prompt_tokens if provided, otherwise tokenizes the prompt string.
|
||||
std::pair<std::string, std::vector<int>> GenerateWithImage(
|
||||
std::string prompt, size_t max_generated_tokens, float temperature,
|
||||
float seed, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) {
|
||||
if (image_tokens_.Cols() == 0) {
|
||||
throw std::invalid_argument("No image set.");
|
||||
}
|
||||
gcpp::Gemma& model = *(gemma_.GetModel());
|
||||
gemma_.MutableGen().seed(seed);
|
||||
gcpp::RuntimeConfig& config = gemma_.MutableConfig();
|
||||
config.max_generated_tokens = max_generated_tokens;
|
||||
config.temperature = temperature;
|
||||
config.verbosity = 0;
|
||||
config.accept_token = accept;
|
||||
config.image_tokens = &image_tokens_;
|
||||
std::vector<int> tokens;
|
||||
if (!prompt_tokens.empty()) {
|
||||
if (!prompt.empty()) {
|
||||
throw std::invalid_argument(
|
||||
"Cannot pass both prompt and prompt_tokens.");
|
||||
}
|
||||
tokens = prompt_tokens;
|
||||
RemoveTrailingZeros(tokens); // Remove padding, if any.
|
||||
} else {
|
||||
tokens = gemma_.WrapAndTokenize(prompt);
|
||||
}
|
||||
tokens.insert(tokens.begin(), image_tokens_.BatchSize(), 0);
|
||||
size_t num_tokens = tokens.size();
|
||||
size_t prefix_end = num_tokens;
|
||||
config.prefill_tbatch_size = num_tokens;
|
||||
int count_down = static_cast<int>(num_tokens);
|
||||
std::vector<int> response_tokens;
|
||||
auto stream_token = [&](int token, float) {
|
||||
if (count_down > 0) {
|
||||
count_down--;
|
||||
return true;
|
||||
}
|
||||
response_tokens.push_back(token);
|
||||
return true;
|
||||
};
|
||||
config.stream_token = stream_token;
|
||||
gcpp::TimingInfo timing_info = {.verbosity = 0};
|
||||
model.Generate(config, tokens, /*pos=*/0, prefix_end,
|
||||
gemma_.MutableKVCache(), timing_info);
|
||||
std::string response;
|
||||
model.Tokenizer().Decode(response_tokens, &response);
|
||||
return {response, response_tokens};
|
||||
}
|
||||
|
||||
float GetLastProb() const { return last_prob_; }
|
||||
|
||||
std::string Detokenize(const std::vector<int>& token_ids) const {
|
||||
return gemma_.StringFromTokens(token_ids);
|
||||
}
|
||||
|
||||
bool ModelIsLoaded() const { return gemma_.GetModel() != nullptr; }
|
||||
|
||||
private:
|
||||
gcpp::GemmaEnv gemma_;
|
||||
gcpp::ImageTokens image_tokens_;
|
||||
float last_prob_;
|
||||
};
|
||||
|
||||
PYBIND11_MODULE(gemma, mod) {
|
||||
py::class_<GemmaModel>(mod, "GemmaModel")
|
||||
.def(py::init([](std::string tokenizer, std::string weights,
|
||||
std::string model, std::string weight_type,
|
||||
size_t max_threads) {
|
||||
gcpp::LoaderArgs loader(tokenizer, weights, model);
|
||||
if (const char* err = loader.Validate()) {
|
||||
throw std::invalid_argument(err);
|
||||
}
|
||||
loader.weight_type_str = weight_type;
|
||||
gcpp::InferenceArgs inference;
|
||||
inference.max_generated_tokens = 512;
|
||||
gcpp::AppArgs app;
|
||||
app.max_threads = max_threads;
|
||||
auto gemma =
|
||||
std::make_unique<GemmaModel>(loader, inference, app);
|
||||
if (!gemma->ModelIsLoaded()) {
|
||||
throw std::invalid_argument("Could not load model.");
|
||||
}
|
||||
return gemma;
|
||||
}),
|
||||
py::arg("tokenizer_path"), py::arg("weights_path"),
|
||||
py::arg("model_flag"), py::arg("weight_type") = "sfp",
|
||||
py::arg("max_threads") = 0)
|
||||
.def("generate_ex", &GemmaModel::GenerateEx, py::arg("prompt"),
|
||||
py::arg("stream"), py::arg("max_generated_tokens") = 1024,
|
||||
py::arg("temperature") = 0.9, py::arg("seed") = 123456789,
|
||||
py::arg("accept") = gcpp::AcceptFunc(),
|
||||
py::arg("skip_prompt") = false)
|
||||
.def("generate", &GemmaModel::Generate, py::arg("prompt"),
|
||||
py::arg("max_generated_tokens") = 1024, py::arg("temperature") = 0.9,
|
||||
py::arg("seed") = 123456789,
|
||||
py::arg("accept") = std::vector<std::string>(),
|
||||
py::arg("end") = std::vector<std::string>())
|
||||
.def("generate_batch", &GemmaModel::GenerateBatch, py::arg("inputs"),
|
||||
py::arg("max_generated_tokens") = 1024, py::arg("temperature") = 0.9,
|
||||
py::arg("seed") = 123456789, py::arg("top_k") = 5)
|
||||
.def("set_image", &GemmaModel::SetImage, py::arg("image"))
|
||||
.def("generate_with_image", &GemmaModel::GenerateWithImage,
|
||||
py::arg("prompt") = "", py::arg("max_generated_tokens") = 1024,
|
||||
py::arg("temperature") = 0.9, py::arg("seed") = 123456789,
|
||||
py::arg("accept") = gcpp::AcceptFunc(),
|
||||
py::arg("prompt_tokens") = std::vector<int>())
|
||||
.def("get_last_prob", &GemmaModel::GetLastProb)
|
||||
.def("detokenize", &GemmaModel::Detokenize, py::arg("token_ids"));
|
||||
}
|
||||
|
|
@ -0,0 +1 @@
|
|||
absl-py
|
||||
|
|
@ -0,0 +1,108 @@
|
|||
# Copyright 2024 Google LLC
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A simple example of using the gemma.cpp Python wrapper."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
import os
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import numpy as np
|
||||
|
||||
from python import gemma
|
||||
|
||||
|
||||
_MODEL_DIR = flags.DEFINE_string(
|
||||
"model_dir",
|
||||
"",
|
||||
"Path to the Gemma model directory.",
|
||||
)
|
||||
|
||||
_PROMPT = flags.DEFINE_string(
|
||||
"prompt",
|
||||
"Write an email to the moon.",
|
||||
"Prompt to generate text with.",
|
||||
)
|
||||
|
||||
|
||||
def main(argv: Sequence[str]) -> None:
|
||||
if len(argv) > 1:
|
||||
raise app.UsageError("Too many command-line arguments.")
|
||||
|
||||
tokenizer_path = os.path.join(_MODEL_DIR.value, "tokenizer.spm")
|
||||
weights_path = os.path.join(_MODEL_DIR.value, "gemma2-2b-it-sfp.sbs")
|
||||
print(f"Loading model from {tokenizer_path} and {weights_path}")
|
||||
model = gemma.GemmaModel(
|
||||
tokenizer_path=tokenizer_path,
|
||||
weights_path=weights_path,
|
||||
model_flag="gemma2-2b-it",
|
||||
max_threads=24,
|
||||
)
|
||||
|
||||
prompt = _PROMPT.value
|
||||
print(f"Running example with prompt='{prompt}'")
|
||||
output = model.generate(prompt)
|
||||
print(f"Generated output:\n{output}")
|
||||
|
||||
def callback(tok, _):
|
||||
s = model.detokenize([tok])
|
||||
print(s, end="", flush=True)
|
||||
return True
|
||||
|
||||
print(f"\n\nRunning example with streaming callback, prompt='{prompt}'")
|
||||
print("Generating output:\n")
|
||||
model.generate_ex(prompt, callback, skip_prompt=True)
|
||||
|
||||
prompts = [
|
||||
prompt,
|
||||
"Tell me a joke.",
|
||||
"Please recite the first paragraph of the Declaration of Independence.",
|
||||
prompt,
|
||||
]
|
||||
print("\n\n\nRunning example with batch generation")
|
||||
outputs = model.generate_batch(
|
||||
prompts, max_generated_tokens=16, temperature=2.0, top_k=30, seed=123456,
|
||||
)
|
||||
print("Generated outputs:")
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
print(f"Prompt: '{prompt}' --->\nOutput: {output}\n")
|
||||
|
||||
# PaliGemma example.
|
||||
tokenizer_path = os.path.join(_MODEL_DIR.value, "paligemma_tokenizer.model")
|
||||
weights_path = os.path.join(_MODEL_DIR.value, "paligemma-3b-mix-224-sfp.sbs")
|
||||
print(f"Loading model from {tokenizer_path} and {weights_path}")
|
||||
model = gemma.GemmaModel(
|
||||
tokenizer_path=tokenizer_path,
|
||||
weights_path=weights_path,
|
||||
model_flag="paligemma-224",
|
||||
max_threads=24,
|
||||
)
|
||||
image = np.array(
|
||||
[
|
||||
[[255, 0, 0], [0, 255, 0]], # Red, Green
|
||||
[[0, 0, 255], [255, 255, 255]], # Blue, White
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
model.set_image(image)
|
||||
prompt = "Describe this image."
|
||||
print(f"Running example with a tiny image and prompt='{prompt}'.")
|
||||
output, tokens = model.generate_with_image(prompt)
|
||||
print(f"Generated {len(tokens)} tokens, output:\n{output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
Loading…
Reference in New Issue