From 7af2e70321c58ed7fd69f067424e84eea176e145 Mon Sep 17 00:00:00 2001 From: Daniel Keysers Date: Tue, 28 Jan 2025 08:21:24 -0800 Subject: [PATCH] Add python wrappers for configs and inference. Enable building compression/python/compression_test using bazel. Add default image path for image_test and paligemma_test. PiperOrigin-RevId: 720583438 --- MODULE.bazel | 30 +- compression/blob_compare.cc | 15 + compression/python/BUILD.bazel | 4 +- compression/python/compression_clif_aux.cc | 15 + compression/python/compression_clif_aux.h | 15 + compression/python/compression_extension.cc | 15 + compression/python/compression_test.py | 23 +- compression/python/requirements.txt | 1 + paligemma/image_test.cc | 3 +- paligemma/paligemma_test.cc | 3 +- python/BUILD.bazel | 43 +++ python/configs.cc | 184 ++++++++++++ python/gemma_py.cc | 303 ++++++++++++++++++++ python/requirements.txt | 1 + python/run_example.py | 108 +++++++ 15 files changed, 750 insertions(+), 13 deletions(-) create mode 100644 compression/python/requirements.txt create mode 100644 python/BUILD.bazel create mode 100644 python/configs.cc create mode 100644 python/gemma_py.cc create mode 100644 python/requirements.txt create mode 100644 python/run_example.py diff --git a/MODULE.bazel b/MODULE.bazel index e835941..6bd5a0a 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -4,14 +4,15 @@ module( ) bazel_dep(name = "abseil-cpp", version = "20240722.0") -bazel_dep(name = "bazel_skylib", version = "1.6.1") +bazel_dep(name = "bazel_skylib", version = "1.7.1") bazel_dep(name = "googletest", version = "1.15.2") bazel_dep(name = "highway", version = "1.1.0") bazel_dep(name = "nlohmann_json", version = "3.11.3") bazel_dep(name = "platforms", version = "0.0.10") bazel_dep(name = "pybind11_bazel", version = "2.12.0") -bazel_dep(name = "rules_cc", version = "0.0.9") -bazel_dep(name = "rules_license", version = "0.0.7") +bazel_dep(name = "rules_cc", version = "0.0.16") +bazel_dep(name = "rules_license", version = "1.0.0") +bazel_dep(name = "rules_python", version = "1.0.0") bazel_dep(name = "google_benchmark", version = "1.8.5") # Require a more recent version. @@ -23,6 +24,15 @@ git_override( http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +http_archive( + name = "com_google_absl_py", + sha256 = "8a3d0830e4eb4f66c4fa907c06edf6ce1c719ced811a12e26d9d3162f8471758", + strip_prefix = "abseil-py-2.1.0", + urls = [ + "https://github.com/abseil/abseil-py/archive/refs/tags/v2.1.0.tar.gz", + ], +) + http_archive( name = "com_google_sentencepiece", build_file = "@//bazel:sentencepiece.bazel", @@ -53,3 +63,17 @@ cc_library( "https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip", ], ) + +pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip") +pip.parse( + hub_name = "compression_deps", + python_version = "3.11", + requirements_lock = "//compression/python:requirements.txt", +) +use_repo(pip, "compression_deps") +pip.parse( + hub_name = "python_deps", + python_version = "3.11", + requirements_lock = "//python:requirements.txt", +) +use_repo(pip, "python_deps") diff --git a/compression/blob_compare.cc b/compression/blob_compare.cc index c6f0a00..ae40582 100644 --- a/compression/blob_compare.cc +++ b/compression/blob_compare.cc @@ -1,3 +1,18 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include #include #include diff --git a/compression/python/BUILD.bazel b/compression/python/BUILD.bazel index 546705b..8bfb391 100644 --- a/compression/python/BUILD.bazel +++ b/compression/python/BUILD.bazel @@ -41,8 +41,8 @@ py_test( srcs = ["compression_test.py"], deps = [ ":compression", - "//testing/pybase", + "@com_google_absl_py//absl/testing:absltest", "//python:configs", - "//third_party/py/numpy", + "@compression_deps//numpy", ], ) diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc index cad0d14..b843a93 100644 --- a/compression/python/compression_clif_aux.cc +++ b/compression/python/compression_clif_aux.cc @@ -1,3 +1,18 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "compression/python/compression_clif_aux.h" #include diff --git a/compression/python/compression_clif_aux.h b/compression/python/compression_clif_aux.h index cb8eb8c..4ea5b16 100644 --- a/compression/python/compression_clif_aux.h +++ b/compression/python/compression_clif_aux.h @@ -1,3 +1,18 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_ #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_PYTHON_COMPRESSION_CLIF_AUX_H_ diff --git a/compression/python/compression_extension.cc b/compression/python/compression_extension.cc index 17266e9..c873a23 100644 --- a/compression/python/compression_extension.cc +++ b/compression/python/compression_extension.cc @@ -1,3 +1,18 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include #include #include diff --git a/compression/python/compression_test.py b/compression/python/compression_test.py index 077e513..33cd055 100644 --- a/compression/python/compression_test.py +++ b/compression/python/compression_test.py @@ -1,13 +1,28 @@ +# Copyright 2024 Google LLC +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Tests for CLIF wrapped .sbs writer.""" import numpy as np -import unittest +from absl.testing import absltest from compression.python import compression -from gemma.python import configs +from python import configs -class CompressionTest(unittest.TestCase): +class CompressionTest(absltest.TestCase): def test_sbs_writer(self): temp_file = self.create_tempfile("test.sbs") @@ -41,4 +56,4 @@ class CompressionTest(unittest.TestCase): if __name__ == "__main__": - unittest.main() + absltest.main() diff --git a/compression/python/requirements.txt b/compression/python/requirements.txt new file mode 100644 index 0000000..f0d1480 --- /dev/null +++ b/compression/python/requirements.txt @@ -0,0 +1 @@ +numpy>=1.26.4 diff --git a/paligemma/image_test.cc b/paligemma/image_test.cc index f114fe5..e2c4bbf 100644 --- a/paligemma/image_test.cc +++ b/paligemma/image_test.cc @@ -29,8 +29,7 @@ float Normalize(float value, float max_value = 255.0f) { } TEST(ImageTest, LoadResize224GetPatch) { - return; // Need to figure out how to get the external path for the test file. - std::string path; + std::string path = "paligemma/testdata/image.ppm"; Image image; EXPECT_EQ(image.width(), 0); EXPECT_EQ(image.height(), 0); diff --git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc index bd8fb2d..fe7fcc9 100644 --- a/paligemma/paligemma_test.cc +++ b/paligemma/paligemma_test.cc @@ -93,8 +93,7 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{ void PaliGemmaTest::TestQuestions(const char* kQA[][2], size_t num_questions) { ASSERT_NE(s_env->GetModel(), nullptr); - return; // Need to figure out how to get the external path for the test file. - std::string path; + std::string path = "paligemma/testdata/image.ppm"; InitVit(path); for (size_t i = 0; i < num_questions; ++i) { fprintf(stderr, "Question %zu\n\n", i + 1); diff --git a/python/BUILD.bazel b/python/BUILD.bazel new file mode 100644 index 0000000..73a0724 --- /dev/null +++ b/python/BUILD.bazel @@ -0,0 +1,43 @@ +# [internal] load py_binary +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + default_applicable_licenses = [ + "//:license", # Placeholder comment, do not modify + ], + default_visibility = ["//visibility:public"], +) + +pybind_extension( + name = "configs", + srcs = ["configs.cc"], + deps = [ + "//:common", + "//compression:sfp", + ], +) + +pybind_extension( + name = "gemma", + srcs = ["gemma_py.cc"], + deps = [ + "//:app", + "//:benchmark_helper", + "//:gemma_lib", + "//compression:sfp", + "@highway//:hwy", + "@highway//:thread_pool", + ], +) + +py_binary( + name = "run_example", + srcs = ["run_example.py"], + python_version = "PY3", + deps = [ + ":gemma", + "@python_deps//absl_py", + # placeholder forabsl/flags + "@compression_deps//numpy", + ], +) diff --git a/python/configs.cc b/python/configs.cc new file mode 100644 index 0000000..53ba5c4 --- /dev/null +++ b/python/configs.cc @@ -0,0 +1,184 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gemma/configs.h" + +#include +#include +#include + +#include "compression/shared.h" +#include "gemma/tensor_index.h" + +using gcpp::ActivationType; +using gcpp::LayerAttentionType; +using gcpp::LayerConfig; +using gcpp::Model; +using gcpp::ModelConfig; +using gcpp::PostNormType; +using gcpp::PostQKType; +using gcpp::PromptWrapping; +using gcpp::QueryScaleType; +using gcpp::ResidualType; +using gcpp::TensorIndex; +using gcpp::TensorInfo; +using gcpp::Type; +using gcpp::VitConfig; + +namespace pybind11 { + +PYBIND11_MODULE(configs, py_module) { + enum_(py_module, "PromptWrapping") + .value("GEMMA_IT", PromptWrapping::GEMMA_IT) + .value("GEMMA_PT", PromptWrapping::GEMMA_PT) + .value("PALIGEMMA", PromptWrapping::PALIGEMMA); + + enum_(py_module, "Type") + .value("kUnknown", Type::kUnknown) + .value("kF32", Type::kF32) + .value("kBF16", Type::kBF16) + .value("kSFP", Type::kSFP) + .value("kNUQ", Type::kNUQ) + .value("kF64", Type::kF64) + .value("kC64", Type::kC64) + .value("kU128", Type::kU128); + + enum_(py_module, "LayerAttentionType") + .value("kGemma", LayerAttentionType::kGemma) + .value("kGriffinRecurrentBlock", + LayerAttentionType::kGriffinRecurrentBlock) + .value("kVit", LayerAttentionType::kVit); + + enum_(py_module, "PostNormType") + .value("NoPostNorm", PostNormType::None) + .value("Scale", PostNormType::Scale); + + enum_(py_module, "PostQKType") + .value("Rope", PostQKType::Rope) + .value("HalfRope", PostQKType::HalfRope); + + enum_(py_module, "ActivationType") + .value("Gelu", ActivationType::Gelu); + + enum_(py_module, "QueryScaleType") + .value("SqrtKeySize", QueryScaleType::SqrtKeySize) + .value("SqrtModelDimDivNumHeads", + QueryScaleType::SqrtModelDimDivNumHeads); + + enum_(py_module, "ResidualType") + .value("Add", ResidualType::Add); + + enum_(py_module, "Model") + .value("UNKNOWN", Model::UNKNOWN) + .value("GEMMA_2B", Model::GEMMA_2B) + .value("GEMMA_7B", Model::GEMMA_7B) + .value("GEMMA2_9B", Model::GEMMA2_9B) + .value("GEMMA2_27B", Model::GEMMA2_27B) + .value("GRIFFIN_2B", Model::GRIFFIN_2B) + .value("GEMMA_TINY", Model::GEMMA_TINY) + .value("GEMMA2_2B", Model::GEMMA2_2B) + .value("PALIGEMMA2_3B_224", Model::PALIGEMMA2_3B_224) + .value("PALIGEMMA2_10B_224", Model::PALIGEMMA2_10B_224) + .value("PALIGEMMA2_3B_448", Model::PALIGEMMA2_3B_448) + .value("PALIGEMMA2_10B_448", Model::PALIGEMMA2_10B_448) + .value("PALIGEMMA_224", Model::PALIGEMMA_224) + .value("PALIGEMMA_448", Model::PALIGEMMA_448); + + class_(py_module, "TensorInfo") + .def(init()) + .def_readwrite("name", &TensorInfo::name) + .def_readwrite("source_names", &TensorInfo::source_names) + .def_readwrite("preshape", &TensorInfo::preshape) + .def_readwrite("axes", &TensorInfo::axes) + .def_readwrite("shape", &TensorInfo::shape) + .def_readwrite("concat_names", &TensorInfo::concat_names) + .def_readwrite("concat_axis", &TensorInfo::concat_axis) + .def_readwrite("min_size", &TensorInfo::min_size) + .def_readwrite("scaled_softplus", &TensorInfo::scaled_softplus) + .def_readwrite("cols_take_extra_dims", &TensorInfo::cols_take_extra_dims); + + class_(py_module, "TensorIndex") + .def(init()) + .def("tensor_info_from_source_path", + &TensorIndex::TensorInfoFromSourcePath, arg("path")) + .def("tensor_info_from_name", &TensorIndex::TensorInfoFromName, + arg("name")); + + class_(py_module, "LayerConfig") + .def(init()) + .def_readwrite("model_dim", &LayerConfig::model_dim) + .def_readwrite("griffin_dim", &LayerConfig::griffin_dim) + .def_readwrite("ff_hidden_dim", &LayerConfig::ff_hidden_dim) + .def_readwrite("heads", &LayerConfig::heads) + .def_readwrite("kv_heads", &LayerConfig::kv_heads) + .def_readwrite("qkv_dim", &LayerConfig::qkv_dim) + .def_readwrite("conv1d_width", &LayerConfig::conv1d_width) + .def_readwrite("ff_biases", &LayerConfig::ff_biases) + .def_readwrite("softmax_attn_output_biases", + &LayerConfig::softmax_attn_output_biases) + .def_readwrite("optimized_gating", &LayerConfig::optimized_gating) + .def_readwrite("post_norm", &LayerConfig::post_norm) + .def_readwrite("type", &LayerConfig::type) + .def_readwrite("activation", &LayerConfig::activation) + .def_readwrite("post_qk", &LayerConfig::post_qk); + + class_(py_module, "VitConfig") + .def(init()) + .def_readwrite("model_dim", &VitConfig::model_dim) + .def_readwrite("seq_len", &VitConfig::seq_len) + .def_readwrite("num_scales", &VitConfig::num_scales) + .def_readwrite("patch_width", &VitConfig::patch_width) + .def_readwrite("image_size", &VitConfig::image_size) + .def_readwrite("layer_configs", &VitConfig::layer_configs); + + class_(py_module, "ModelConfig") + .def(init()) + .def_readwrite("model_family_version", &ModelConfig::model_family_version) + .def_readwrite("model_name", &ModelConfig::model_name) + .def_readwrite("model", &ModelConfig::model) + .def_readwrite("wrapping", &ModelConfig::wrapping) + .def_readwrite("weight", &ModelConfig::weight) + .def_readwrite("num_layers", &ModelConfig::num_layers) + .def_readwrite("model_dim", &ModelConfig::model_dim) + .def_readwrite("vocab_size", &ModelConfig::vocab_size) + .def_readwrite("seq_len", &ModelConfig::seq_len) + .def_readwrite("num_tensor_scales", &ModelConfig::num_tensor_scales) + .def_readwrite("att_cap", &ModelConfig::att_cap) + .def_readwrite("final_cap", &ModelConfig::final_cap) + .def_readwrite("absolute_pe", &ModelConfig::absolute_pe) + .def_readwrite("use_local_attention", &ModelConfig::use_local_attention) + .def_readwrite("query_scale", &ModelConfig::query_scale) + .def_readwrite("layer_configs", &ModelConfig::layer_configs) + .def_readwrite("attention_window_sizes", + &ModelConfig::attention_window_sizes) + .def_readwrite("scale_names", &ModelConfig::scale_names) + .def_readwrite("norm_num_groups", &ModelConfig::norm_num_groups) + .def_readwrite("vit_config", &ModelConfig::vit_config) + .def("add_layer_config", &ModelConfig::AddLayerConfig, + arg("layer_config")) + .def("test_equal", &ModelConfig::TestEqual, arg("other"), arg("partial"), + arg("debug")); + + // Returns the config for the given model. + py_module.def("config_from_model", &gcpp::ConfigFromModel, arg("model")); + + // Returns the model for the given config, if it matches any standard model. + py_module.def("model_from_config", &gcpp::ModelFromConfig, arg("config")); + + // Returns the sub-config for the ViT model of the PaliGemma model. + py_module.def("vit_config", &gcpp::GetVitConfig, arg("config")); +} + +} // namespace pybind11 diff --git a/python/gemma_py.cc b/python/gemma_py.cc new file mode 100644 index 0000000..a7ce022 --- /dev/null +++ b/python/gemma_py.cc @@ -0,0 +1,303 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "compression/shared.h" +#include "evals/benchmark_helper.h" +#include "gemma/gemma.h" +#include "util/app.h" +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + +namespace py = pybind11; + +static void RemoveTrailingZeros(std::vector &vec) { + auto it = + std::find_if(vec.rbegin(), vec.rend(), [](int v) { return v != 0; }); + vec.erase(it.base(), vec.end()); +} + +// Wrapper around GemmaEnv to expose to Python. +class GemmaModel { + public: + GemmaModel(const gcpp::LoaderArgs& loader, + const gcpp::InferenceArgs& inference, const gcpp::AppArgs& app) + : gemma_(loader, inference, app), last_prob_(0.0f) {} + + // Generates a single example, given a prompt and a callback to stream the + // generated tokens. + void GenerateEx(std::string prompt, gcpp::StreamFunc stream, + size_t max_generated_tokens, float temperature, float seed, + gcpp::AcceptFunc accept, bool skip_prompt) { + gemma_.MutableGen().seed(seed); + std::vector prompt_tokens = gemma_.WrapAndTokenize(prompt); + gcpp::RuntimeConfig& config = gemma_.MutableConfig(); + config.max_generated_tokens = max_generated_tokens; + config.temperature = temperature; + config.verbosity = 0; + config.accept_token = accept; + // If skip_prompt is true, we skip the prompt tokens and only stream the + // generated tokens. + int count_down = prompt_tokens.size(); + auto stream_with_skipping = [&stream, &count_down](int token, float score) { + if (count_down > 0) { + count_down--; + return true; + } + return stream(token, score); + }; + gemma_.QueryModel(prompt_tokens, + skip_prompt ? stream_with_skipping : stream); + } + + // Generates a single example, given a prompt, and returns the result. + std::string Generate(std::string prompt, size_t max_generated_tokens, + float temperature, float seed, + const std::vector& accept, + const std::vector& end) { + std::set end_token_set{}; + for (const std::string& end_token : end) { + std::vector end_token_ids = gemma_.Tokenize(end_token); + end_token_set.insert(end_token_ids.begin(), end_token_ids.end()); + } + + std::vector predicted_token_ids; + predicted_token_ids.reserve(max_generated_tokens); + std::vector prompt_token_ids = gemma_.WrapAndTokenize(prompt); + int generated = 0; + auto stream_token = [&generated, &prompt_token_ids, &predicted_token_ids, + &end_token_set, this](int token, float proba) { + ++generated; + if (generated > prompt_token_ids.size()) { + predicted_token_ids.push_back(token); + if (!end_token_set.empty()) { + return end_token_set.find(token) == end_token_set.end(); + } + } + last_prob_ = proba; + return true; + }; + + std::set accept_token_set{}; + for (const std::string& accept_token : accept) { + std::vector accept_token_ids = gemma_.Tokenize(accept_token); + accept_token_set.insert(accept_token_ids.begin(), accept_token_ids.end()); + } + + auto accept_token = [&predicted_token_ids, &prompt_token_ids, + &accept_token_set](int token, float) { + // i.e. we have no constraints on accepted tokens + if (accept_token_set.empty()) { + return true; + } + + if (predicted_token_ids.size() >= prompt_token_ids.size()) { + return accept_token_set.find(token) != accept_token_set.end(); + } else { + // auto-accept prompt tokens + return true; + } + }; + + gemma_.MutableGen().seed(seed); + gcpp::RuntimeConfig& config = gemma_.MutableConfig(); + config.max_generated_tokens = max_generated_tokens; + config.temperature = temperature; + config.verbosity = 0; + config.accept_token = accept_token; + + gemma_.QueryModel(prompt_token_ids, stream_token); + + if (!predicted_token_ids.empty()) { + return gemma_.StringFromTokens(predicted_token_ids); + } else { + return ""; + } + } + + // Generates a batch of examples, given a list of prompts, and returns the + // results. + std::vector GenerateBatch(const std::vector& inputs, + size_t max_generated_tokens, + float temperature, float seed, + size_t top_k) { + gcpp::RuntimeConfig& config = gemma_.MutableConfig(); + config.max_generated_tokens = max_generated_tokens; + config.temperature = temperature; + config.top_k = top_k; + config.verbosity = 0; + gemma_.MutableGen().seed(seed); + + std::vector outputs = gemma_.BatchQueryModel(inputs); + std::vector result; + result.reserve(outputs.size()); + for (const gcpp::QueryResult& output : outputs) { + result.push_back(output.response.substr(output.response_start_pos)); + } + return result; + } + + // For a PaliGemma model, sets the image to run on. Subseqent calls to + // Generate* will use this image. Throws an error for other models. + void SetImage(const py::array_t& image) { + gcpp::Gemma& model = *(gemma_.GetModel()); + if (model.Info().wrapping != gcpp::PromptWrapping::PALIGEMMA) { + throw std::invalid_argument("Not a PaliGemma model."); + } + py::buffer_info buffer = image.request(); + if (buffer.ndim != 3 || buffer.shape[2] != 3) + throw std::runtime_error( + "Expected a 3D numpy array with shape (height, width, 3)"); + int height = buffer.shape[0]; + int width = buffer.shape[1]; + float* ptr = static_cast(buffer.ptr); + gcpp::Image c_image; + c_image.Set(height, width, ptr); + const size_t image_size = model.GetModelConfig().vit_config.image_size; + c_image.Resize(image_size, image_size); + image_tokens_ = gcpp::ImageTokens(gcpp::Extents2D( + model.GetModelConfig().vit_config.seq_len, + model.GetModelConfig().model_dim)); + gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(), + .verbosity = 0}; + model.GenerateImageTokens(runtime_config, c_image, image_tokens_); + } + + // Generates a response to the given prompt, using the last set image. + // Uses the prompt_tokens if provided, otherwise tokenizes the prompt string. + std::pair> GenerateWithImage( + std::string prompt, size_t max_generated_tokens, float temperature, + float seed, gcpp::AcceptFunc accept, std::vector prompt_tokens) { + if (image_tokens_.Cols() == 0) { + throw std::invalid_argument("No image set."); + } + gcpp::Gemma& model = *(gemma_.GetModel()); + gemma_.MutableGen().seed(seed); + gcpp::RuntimeConfig& config = gemma_.MutableConfig(); + config.max_generated_tokens = max_generated_tokens; + config.temperature = temperature; + config.verbosity = 0; + config.accept_token = accept; + config.image_tokens = &image_tokens_; + std::vector tokens; + if (!prompt_tokens.empty()) { + if (!prompt.empty()) { + throw std::invalid_argument( + "Cannot pass both prompt and prompt_tokens."); + } + tokens = prompt_tokens; + RemoveTrailingZeros(tokens); // Remove padding, if any. + } else { + tokens = gemma_.WrapAndTokenize(prompt); + } + tokens.insert(tokens.begin(), image_tokens_.BatchSize(), 0); + size_t num_tokens = tokens.size(); + size_t prefix_end = num_tokens; + config.prefill_tbatch_size = num_tokens; + int count_down = static_cast(num_tokens); + std::vector response_tokens; + auto stream_token = [&](int token, float) { + if (count_down > 0) { + count_down--; + return true; + } + response_tokens.push_back(token); + return true; + }; + config.stream_token = stream_token; + gcpp::TimingInfo timing_info = {.verbosity = 0}; + model.Generate(config, tokens, /*pos=*/0, prefix_end, + gemma_.MutableKVCache(), timing_info); + std::string response; + model.Tokenizer().Decode(response_tokens, &response); + return {response, response_tokens}; + } + + float GetLastProb() const { return last_prob_; } + + std::string Detokenize(const std::vector& token_ids) const { + return gemma_.StringFromTokens(token_ids); + } + + bool ModelIsLoaded() const { return gemma_.GetModel() != nullptr; } + + private: + gcpp::GemmaEnv gemma_; + gcpp::ImageTokens image_tokens_; + float last_prob_; +}; + +PYBIND11_MODULE(gemma, mod) { + py::class_(mod, "GemmaModel") + .def(py::init([](std::string tokenizer, std::string weights, + std::string model, std::string weight_type, + size_t max_threads) { + gcpp::LoaderArgs loader(tokenizer, weights, model); + if (const char* err = loader.Validate()) { + throw std::invalid_argument(err); + } + loader.weight_type_str = weight_type; + gcpp::InferenceArgs inference; + inference.max_generated_tokens = 512; + gcpp::AppArgs app; + app.max_threads = max_threads; + auto gemma = + std::make_unique(loader, inference, app); + if (!gemma->ModelIsLoaded()) { + throw std::invalid_argument("Could not load model."); + } + return gemma; + }), + py::arg("tokenizer_path"), py::arg("weights_path"), + py::arg("model_flag"), py::arg("weight_type") = "sfp", + py::arg("max_threads") = 0) + .def("generate_ex", &GemmaModel::GenerateEx, py::arg("prompt"), + py::arg("stream"), py::arg("max_generated_tokens") = 1024, + py::arg("temperature") = 0.9, py::arg("seed") = 123456789, + py::arg("accept") = gcpp::AcceptFunc(), + py::arg("skip_prompt") = false) + .def("generate", &GemmaModel::Generate, py::arg("prompt"), + py::arg("max_generated_tokens") = 1024, py::arg("temperature") = 0.9, + py::arg("seed") = 123456789, + py::arg("accept") = std::vector(), + py::arg("end") = std::vector()) + .def("generate_batch", &GemmaModel::GenerateBatch, py::arg("inputs"), + py::arg("max_generated_tokens") = 1024, py::arg("temperature") = 0.9, + py::arg("seed") = 123456789, py::arg("top_k") = 5) + .def("set_image", &GemmaModel::SetImage, py::arg("image")) + .def("generate_with_image", &GemmaModel::GenerateWithImage, + py::arg("prompt") = "", py::arg("max_generated_tokens") = 1024, + py::arg("temperature") = 0.9, py::arg("seed") = 123456789, + py::arg("accept") = gcpp::AcceptFunc(), + py::arg("prompt_tokens") = std::vector()) + .def("get_last_prob", &GemmaModel::GetLastProb) + .def("detokenize", &GemmaModel::Detokenize, py::arg("token_ids")); +} diff --git a/python/requirements.txt b/python/requirements.txt new file mode 100644 index 0000000..b998a06 --- /dev/null +++ b/python/requirements.txt @@ -0,0 +1 @@ +absl-py diff --git a/python/run_example.py b/python/run_example.py new file mode 100644 index 0000000..ba87dc0 --- /dev/null +++ b/python/run_example.py @@ -0,0 +1,108 @@ +# Copyright 2024 Google LLC +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A simple example of using the gemma.cpp Python wrapper.""" + +from collections.abc import Sequence +import os + +from absl import app +from absl import flags +import numpy as np + +from python import gemma + + +_MODEL_DIR = flags.DEFINE_string( + "model_dir", + "", + "Path to the Gemma model directory.", +) + +_PROMPT = flags.DEFINE_string( + "prompt", + "Write an email to the moon.", + "Prompt to generate text with.", +) + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + tokenizer_path = os.path.join(_MODEL_DIR.value, "tokenizer.spm") + weights_path = os.path.join(_MODEL_DIR.value, "gemma2-2b-it-sfp.sbs") + print(f"Loading model from {tokenizer_path} and {weights_path}") + model = gemma.GemmaModel( + tokenizer_path=tokenizer_path, + weights_path=weights_path, + model_flag="gemma2-2b-it", + max_threads=24, + ) + + prompt = _PROMPT.value + print(f"Running example with prompt='{prompt}'") + output = model.generate(prompt) + print(f"Generated output:\n{output}") + + def callback(tok, _): + s = model.detokenize([tok]) + print(s, end="", flush=True) + return True + + print(f"\n\nRunning example with streaming callback, prompt='{prompt}'") + print("Generating output:\n") + model.generate_ex(prompt, callback, skip_prompt=True) + + prompts = [ + prompt, + "Tell me a joke.", + "Please recite the first paragraph of the Declaration of Independence.", + prompt, + ] + print("\n\n\nRunning example with batch generation") + outputs = model.generate_batch( + prompts, max_generated_tokens=16, temperature=2.0, top_k=30, seed=123456, + ) + print("Generated outputs:") + for prompt, output in zip(prompts, outputs): + print(f"Prompt: '{prompt}' --->\nOutput: {output}\n") + + # PaliGemma example. + tokenizer_path = os.path.join(_MODEL_DIR.value, "paligemma_tokenizer.model") + weights_path = os.path.join(_MODEL_DIR.value, "paligemma-3b-mix-224-sfp.sbs") + print(f"Loading model from {tokenizer_path} and {weights_path}") + model = gemma.GemmaModel( + tokenizer_path=tokenizer_path, + weights_path=weights_path, + model_flag="paligemma-224", + max_threads=24, + ) + image = np.array( + [ + [[255, 0, 0], [0, 255, 0]], # Red, Green + [[0, 0, 255], [255, 255, 255]], # Blue, White + ], + dtype=np.float32, + ) + model.set_image(image) + prompt = "Describe this image." + print(f"Running example with a tiny image and prompt='{prompt}'.") + output, tokens = model.generate_with_image(prompt) + print(f"Generated {len(tokens)} tokens, output:\n{output}") + + +if __name__ == "__main__": + app.run(main)