gemma.cpp/python/gemma_py.cc

// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <Python.h>

#include <pybind11/cast.h>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <algorithm>
#include <memory>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "compression/shared.h"
#include "evals/benchmark_helper.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace py = pybind11;
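
// Removes trailing zero-valued (padding) tokens from the end of a token
// sequence, leaving any interior zeros untouched.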
static void RemoveTrailingZeros(std::vector<int> &vec) {
  auto it =
      std::find_if(vec.rbegin(), vec.rend(), [](int v) { return v != 0; });
  vec.erase(it.base(), vec.end());
}

// Wrapper around GemmaEnv to expose to Python.
class GemmaModel {
 public:
  GemmaModel(const gcpp::LoaderArgs& loader,
             const gcpp::InferenceArgs& inference, const gcpp::AppArgs& app)
      : gemma_(loader, inference, app), last_prob_(0.0f) {}

  // Generates a single example, given a prompt and a callback to stream the
  // generated tokens.
  void GenerateEx(std::string prompt, gcpp::StreamFunc stream,
                  size_t max_generated_tokens, float temperature, float seed,
                  gcpp::AcceptFunc accept, bool skip_prompt) {
    gemma_.MutableGen().seed(seed);
    std::vector<int> prompt_tokens = gemma_.WrapAndTokenize(prompt);
    gcpp::RuntimeConfig& config = gemma_.MutableConfig();
    config.max_generated_tokens = max_generated_tokens;
    config.temperature = temperature;
    config.verbosity = 0;
    config.accept_token = accept;
    // If skip_prompt is true, we skip the prompt tokens and only stream the
    // generated tokens.
    int count_down = prompt_tokens.size();
    auto stream_with_skipping = [&stream, &count_down](int token, float score) {
      if (count_down > 0) {
        count_down--;
        return true;
      }
      return stream(token, score);
    };
    gemma_.QueryModel(prompt_tokens,
                      skip_prompt ? stream_with_skipping : stream);
  }
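
  // Illustrative Python call via the `generate_ex` binding below (the prompt
  // and callback are placeholders for whatever the caller supplies):
  //   model.generate_ex("Tell me a joke.",
  //                     lambda token, score: print(token) or True,
  //                     skip_prompt=True)
  // The callback receives raw token ids plus a score; use detokenize() to map
  // ids back to text.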

  // Generates a single example, given a prompt, and returns the result.
  std::string Generate(std::string prompt, size_t max_generated_tokens,
                       float temperature, float seed,
                       const std::vector<std::string>& accept,
                       const std::vector<std::string>& end) {
    std::set<int> end_token_set{};
    for (const std::string& end_token : end) {
      std::vector<int> end_token_ids = gemma_.Tokenize(end_token);
      end_token_set.insert(end_token_ids.begin(), end_token_ids.end());
    }
    std::vector<int> predicted_token_ids;
    predicted_token_ids.reserve(max_generated_tokens);
    std::vector<int> prompt_token_ids = gemma_.WrapAndTokenize(prompt);
    int generated = 0;
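    // Stream callback: the first prompt_token_ids.size() calls correspond to
    // the echoed prompt and are not recorded; later tokens are collected, and
    // generation stops as soon as a token from `end` is produced.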
    auto stream_token = [&generated, &prompt_token_ids, &predicted_token_ids,
                         &end_token_set, this](int token, float proba) {
      ++generated;
      if (generated > prompt_token_ids.size()) {
        predicted_token_ids.push_back(token);
        if (!end_token_set.empty()) {
          return end_token_set.find(token) == end_token_set.end();
        }
      }
      last_prob_ = proba;
      return true;
    };
    std::set<int> accept_token_set{};
    for (const std::string& accept_token : accept) {
      std::vector<int> accept_token_ids = gemma_.Tokenize(accept_token);
      accept_token_set.insert(accept_token_ids.begin(), accept_token_ids.end());
    }
    auto accept_token = [&predicted_token_ids, &prompt_token_ids,
                         &accept_token_set](int token, float) {
      // An empty accept set means there are no constraints on accepted tokens.
      if (accept_token_set.empty()) {
        return true;
      }
      if (predicted_token_ids.size() >= prompt_token_ids.size()) {
        return accept_token_set.find(token) != accept_token_set.end();
      } else {
        // auto-accept prompt tokens
        return true;
      }
    };
    gemma_.MutableGen().seed(seed);
    gcpp::RuntimeConfig& config = gemma_.MutableConfig();
    config.max_generated_tokens = max_generated_tokens;
    config.temperature = temperature;
    config.verbosity = 0;
    config.accept_token = accept_token;
    gemma_.QueryModel(prompt_token_ids, stream_token);
    if (!predicted_token_ids.empty()) {
      return gemma_.StringFromTokens(predicted_token_ids);
    } else {
      return "";
    }
  }
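
  // Illustrative Python call via the `generate` binding below (the prompt and
  // end tokens are placeholders):
  //   text = model.generate("What is the capital of France?",
  //                         max_generated_tokens=64,
  //                         end=["<end_of_turn>"])
  // Returns the decoded completion, or "" if nothing was generated.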

  // Generates a batch of examples, given a list of prompts, and returns the
  // results.
  std::vector<std::string> GenerateBatch(const std::vector<std::string>& inputs,
                                         size_t max_generated_tokens,
                                         float temperature, float seed,
                                         size_t top_k) {
    gcpp::RuntimeConfig& config = gemma_.MutableConfig();
    config.max_generated_tokens = max_generated_tokens;
    config.temperature = temperature;
    config.top_k = top_k;
    config.verbosity = 0;
    gemma_.MutableGen().seed(seed);
    std::vector<gcpp::QueryResult> outputs = gemma_.BatchQueryModel(inputs);
    std::vector<std::string> result;
    result.reserve(outputs.size());
    for (const gcpp::QueryResult& output : outputs) {
      result.push_back(output.response.substr(output.response_start_pos));
    }
    return result;
  }
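
  // Illustrative Python call via the `generate_batch` binding below (the
  // prompts are placeholders):
  //   replies = model.generate_batch(["Hi!", "Summarize: ..."], top_k=5)
  // Returns one decoded string per input prompt, in the same order.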

  // For a PaliGemma model, sets the image to run on. Subsequent calls to
  // Generate* will use this image. Throws an error for other models.
  void SetImage(const py::array_t<float, py::array::c_style |
                                             py::array::forcecast>& image) {
    gcpp::Gemma& model = *(gemma_.GetModel());
    if (model.Info().wrapping != gcpp::PromptWrapping::PALIGEMMA) {
      throw std::invalid_argument("Not a PaliGemma model.");
    }
    py::buffer_info buffer = image.request();
    if (buffer.ndim != 3 || buffer.shape[2] != 3)
      throw std::runtime_error(
          "Expected a 3D numpy array with shape (height, width, 3)");
    int height = buffer.shape[0];
    int width = buffer.shape[1];
    float* ptr = static_cast<float*>(buffer.ptr);
    gcpp::Image c_image;
    c_image.Set(height, width, ptr);
    const size_t image_size = model.GetModelConfig().vit_config.image_size;
    c_image.Resize(image_size, image_size);
    image_tokens_ = gcpp::ImageTokens(gcpp::Extents2D(
        model.GetModelConfig().vit_config.seq_len,
        model.GetModelConfig().model_dim));
    gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
                                          .verbosity = 0};
    model.GenerateImageTokens(runtime_config, c_image, image_tokens_);
  }
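
  // Illustrative Python usage for a PaliGemma model via the bindings below
  // (the array contents are placeholders):
  //   import numpy as np
  //   img = np.zeros((224, 224, 3), dtype=np.float32)  # (height, width, 3)
  //   model.set_image(img)
  //   text, ids = model.generate_with_image("describe this image")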

  // Generates a response to the given prompt, using the last set image.
  // Uses the prompt_tokens if provided, otherwise tokenizes the prompt string.
  std::pair<std::string, std::vector<int>> GenerateWithImage(
      std::string prompt, size_t max_generated_tokens, float temperature,
      float seed, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) {
    if (image_tokens_.Cols() == 0) {
      throw std::invalid_argument("No image set.");
    }
    gcpp::Gemma& model = *(gemma_.GetModel());
    gemma_.MutableGen().seed(seed);
    gcpp::RuntimeConfig& config = gemma_.MutableConfig();
    config.max_generated_tokens = max_generated_tokens;
    config.temperature = temperature;
    config.verbosity = 0;
    config.accept_token = accept;
    config.image_tokens = &image_tokens_;
    std::vector<int> tokens;
    if (!prompt_tokens.empty()) {
      if (!prompt.empty()) {
        throw std::invalid_argument(
            "Cannot pass both prompt and prompt_tokens.");
      }
      tokens = prompt_tokens;
      RemoveTrailingZeros(tokens);  // Remove padding, if any.
    } else {
      tokens = gemma_.WrapAndTokenize(prompt);
    }
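    // Prepend one placeholder (0) per image token; during prefill these
    // positions are covered by the precomputed image_tokens_ passed in via
    // config.image_tokens.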
    tokens.insert(tokens.begin(), image_tokens_.BatchSize(), 0);
    size_t num_tokens = tokens.size();
    size_t prefix_end = num_tokens;
    config.prefill_tbatch_size = num_tokens;
    int count_down = static_cast<int>(num_tokens);
    std::vector<int> response_tokens;
    auto stream_token = [&](int token, float) {
      if (count_down > 0) {
        count_down--;
        return true;
      }
      response_tokens.push_back(token);
      return true;
    };
    config.stream_token = stream_token;
    gcpp::TimingInfo timing_info = {.verbosity = 0};
    model.Generate(config, tokens, /*pos=*/0, prefix_end,
                   gemma_.MutableKVCache(), timing_info);
    std::string response;
    model.Tokenizer().Decode(response_tokens, &response);
    return {response, response_tokens};
  }
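
  // Returns the probability recorded for the most recent token streamed by
  // Generate().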
  float GetLastProb() const { return last_prob_; }

  std::string Detokenize(const std::vector<int>& token_ids) const {
    return gemma_.StringFromTokens(token_ids);
  }

  bool ModelIsLoaded() const { return gemma_.GetModel() != nullptr; }

 private:
  gcpp::GemmaEnv gemma_;
  gcpp::ImageTokens image_tokens_;
  float last_prob_;
};

PYBIND11_MODULE(gemma, mod) {
  py::class_<GemmaModel>(mod, "GemmaModel")
      .def(py::init([](std::string tokenizer, std::string weights,
                       std::string model, std::string weight_type,
                       size_t max_threads) {
             gcpp::LoaderArgs loader(tokenizer, weights, model);
             if (const char* err = loader.Validate()) {
               throw std::invalid_argument(err);
             }
             loader.weight_type_str = weight_type;
             gcpp::InferenceArgs inference;
             inference.max_generated_tokens = 512;
             gcpp::AppArgs app;
             app.max_threads = max_threads;
             auto gemma = std::make_unique<GemmaModel>(loader, inference, app);
             if (!gemma->ModelIsLoaded()) {
               throw std::invalid_argument("Could not load model.");
             }
             return gemma;
           }),
           py::arg("tokenizer_path"), py::arg("weights_path"),
           py::arg("model_flag"), py::arg("weight_type") = "sfp",
           py::arg("max_threads") = 0)
      .def("generate_ex", &GemmaModel::GenerateEx, py::arg("prompt"),
           py::arg("stream"), py::arg("max_generated_tokens") = 1024,
           py::arg("temperature") = 0.9, py::arg("seed") = 123456789,
           py::arg("accept") = gcpp::AcceptFunc(),
           py::arg("skip_prompt") = false)
      .def("generate", &GemmaModel::Generate, py::arg("prompt"),
           py::arg("max_generated_tokens") = 1024, py::arg("temperature") = 0.9,
           py::arg("seed") = 123456789,
           py::arg("accept") = std::vector<std::string>(),
           py::arg("end") = std::vector<std::string>())
      .def("generate_batch", &GemmaModel::GenerateBatch, py::arg("inputs"),
           py::arg("max_generated_tokens") = 1024, py::arg("temperature") = 0.9,
           py::arg("seed") = 123456789, py::arg("top_k") = 5)
      .def("set_image", &GemmaModel::SetImage, py::arg("image"))
      .def("generate_with_image", &GemmaModel::GenerateWithImage,
           py::arg("prompt") = "", py::arg("max_generated_tokens") = 1024,
           py::arg("temperature") = 0.9, py::arg("seed") = 123456789,
           py::arg("accept") = gcpp::AcceptFunc(),
           py::arg("prompt_tokens") = std::vector<int>())
      .def("get_last_prob", &GemmaModel::GetLastProb)
      .def("detokenize", &GemmaModel::Detokenize, py::arg("token_ids"));
}
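
// End-to-end sketch of using the extension from Python; the paths and the
// model flag are placeholders, and it assumes the built module is importable
// as `gemma`:
//   import gemma
//   model = gemma.GemmaModel(tokenizer_path="/path/to/tokenizer.spm",
//                            weights_path="/path/to/weights.sbs",
//                            model_flag="gemma2-2b-it")
//   print(model.generate("Hello!", max_generated_tokens=32))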