mirror of https://github.com/google/gemma.cpp.git
Simplified interface class and example for Gemma.cpp usage.
PiperOrigin-RevId: 720591037
This commit is contained in:
parent
7af2e70321
commit
23dac72463
|
|
@ -0,0 +1,39 @@
|
||||||
|
# Hello World example frontend to gemma.cpp.
|
||||||
|
package(
|
||||||
|
default_applicable_licenses = [
|
||||||
|
"//:license", # Placeholder comment, do not modify
|
||||||
|
],
|
||||||
|
default_visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "gemma",
|
||||||
|
hdrs = ["gemma.hpp"],
|
||||||
|
deps = [
|
||||||
|
"//:app",
|
||||||
|
"//:args",
|
||||||
|
"//:common",
|
||||||
|
"//:gemma_lib",
|
||||||
|
"//:threading",
|
||||||
|
"//:tokenizer",
|
||||||
|
"@highway//:hwy",
|
||||||
|
"@highway//:thread_pool",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "simplified_gemma",
|
||||||
|
srcs = ["run.cc"],
|
||||||
|
deps = [
|
||||||
|
":gemma",
|
||||||
|
# Placeholder for internal dep, do not remove.,
|
||||||
|
"//:app",
|
||||||
|
"//:args",
|
||||||
|
"//:common",
|
||||||
|
"//:gemma_lib",
|
||||||
|
"//:threading",
|
||||||
|
"//:tokenizer",
|
||||||
|
"@highway//:hwy",
|
||||||
|
"@highway//:thread_pool",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,49 @@
|
||||||
|
# Copyright 2019 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
cmake_minimum_required(VERSION 3.11)
|
||||||
|
project(simplified_gemma)
|
||||||
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
|
|
||||||
|
include(FetchContent)
|
||||||
|
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG f2209b911c74019e85d0b7a7a2833c9a2e1b7995)
|
||||||
|
FetchContent_MakeAvailable(highway)
|
||||||
|
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
|
||||||
|
FetchContent_MakeAvailable(sentencepiece)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Allow for both local and remote building.
|
||||||
|
# Selects where gemma.cpp is taken from: "remote" (a pinned git commit, the
# default) or "local" (the source tree on disk). The value is a string, so it
# is declared as a CACHE STRING entry; option() is reserved for ON/OFF booleans.
# A value passed with -DBUILD_MODE=... on the command line is kept as-is.
set(BUILD_MODE "remote" CACHE STRING "'local' or 'remote' git fetch for builds")
|
||||||
|
if (NOT BUILD_MODE)
|
||||||
|
set(BUILD_MODE "remote")
|
||||||
|
endif()
|
||||||
|
if (BUILD_MODE STREQUAL "local")
|
||||||
|
# Relative path to gemma.cpp from examples/simplified_gemma/build/
|
||||||
|
FetchContent_Declare(gemma SOURCE_DIR ../../..)
|
||||||
|
else()
|
||||||
|
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e)
|
||||||
|
endif()
|
||||||
|
FetchContent_MakeAvailable(gemma)
|
||||||
|
|
||||||
|
if(NOT CMAKE_BUILD_TYPE)
|
||||||
|
set(CMAKE_BUILD_TYPE "Release")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
add_executable(simplified_gemma run.cc)
|
||||||
|
# Link privately: an executable has no consumers, and the explicit keyword
# avoids the legacy keyword-less target_link_libraries signature.
target_link_libraries(simplified_gemma PRIVATE hwy hwy_contrib sentencepiece libgemma)
|
||||||
|
FetchContent_GetProperties(sentencepiece)
|
||||||
|
target_include_directories(simplified_gemma PRIVATE ${sentencepiece_SOURCE_DIR})
|
||||||
|
# Windows-only definitions. The generator expression is quoted and its items
# are separated by ';' so that it reaches CMake as one well-formed argument
# instead of being split on whitespace.
target_compile_definitions(simplified_gemma PRIVATE "$<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS;NOMINMAX>")
|
||||||
|
target_compile_options(simplified_gemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
||||||
|
|
@ -0,0 +1,60 @@
|
||||||
|
# Simplified Gemma.cpp Example
|
||||||
|
|
||||||
|
This is a minimal/template project for using `gemma.cpp` as a library. Instead
|
||||||
|
of an interactive interface, it sets up the model state and generates text for a
|
||||||
|
single hard coded prompt.
|
||||||
|
|
||||||
|
Build steps are similar to the main `gemma` executable. For now only
|
||||||
|
`cmake`/`make` is available for builds (PRs welcome for other build options).
|
||||||
|
|
||||||
|
First use `cmake` to configure the project, starting from the `simplified_gemma`
|
||||||
|
example directory (`gemma.cpp/examples/simplified_gemma`):
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cmake -B build
|
||||||
|
```
|
||||||
|
|
||||||
|
This sets up a build configuration in `gemma.cpp/examples/simplified_gemma/build`.
|
||||||
|
Note that this fetches `libgemma` from a git commit hash on github.
|
||||||
|
Alternatively if you want to build using the local version of `gemma.cpp` use:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cmake -B build -DBUILD_MODE=local
|
||||||
|
```
|
||||||
|
|
||||||
|
Make sure you delete the contents of the build directory before changing
|
||||||
|
configurations.
|
||||||
|
|
||||||
|
Then use `make` to build the project:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cd build
|
||||||
|
make simplified_gemma
|
||||||
|
```
|
||||||
|
|
||||||
|
As with the top-level `gemma.cpp` project you can use the `make` command's `-j`
|
||||||
|
flag to use parallel threads for faster builds.
|
||||||
|
|
||||||
|
From inside the `gemma.cpp/examples/simplified_gemma/build` directory, there should
|
||||||
|
be a `simplified_gemma` executable. You can run it with the same 3 model arguments as
|
||||||
|
gemma.cpp specifying the tokenizer, compressed weights file, and model type, for
|
||||||
|
example:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
./simplified_gemma --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it
|
||||||
|
```
|
||||||
|
|
||||||
|
This should print a greeting to the terminal:
|
||||||
|
|
||||||
|
```
|
||||||
|
"Hello, world! It's a pleasure to greet you all. May your day be filled with joy, peace, and all the things that make your heart soar.
|
||||||
|
```
|
||||||
|
|
||||||
|
For a demonstration of constrained decoding, add the `--reject` flag followed by
|
||||||
|
a list of token IDs (note that it must be the last flag, since it consumes every
|
||||||
|
subsequent argument). For example, to reject variations of the word "greeting",
|
||||||
|
run:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
./simplified_gemma [...] --reject 32338 42360 78107 106837 132832 143859 154230 190205
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,2 @@
|
||||||
|
*
|
||||||
|
!.gitignore
|
||||||
|
|
@ -0,0 +1,114 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
#include <iostream>
|
||||||
|
#include <random>
|
||||||
|
#include <set>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "third_party/gemma_cpp/gemma/gemma.h"
|
||||||
|
#include "third_party/gemma_cpp/gemma/tokenizer.h"
|
||||||
|
#include "third_party/gemma_cpp/util/app.h" // LoaderArgs
|
||||||
|
#include "third_party/gemma_cpp/util/threading.h"
|
||||||
|
#include "third_party/highway/hwy/base.h"
|
||||||
|
#include "third_party/highway/hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
|
class SimplifiedGemma {
|
||||||
|
public:
|
||||||
|
SimplifiedGemma(const gcpp::LoaderArgs& loader,
|
||||||
|
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs(),
|
||||||
|
const gcpp::AppArgs& app = gcpp::AppArgs())
|
||||||
|
: loader_(loader),
|
||||||
|
inference_(inference),
|
||||||
|
app_(app),
|
||||||
|
pools_(gcpp::CreatePools(app_)),
|
||||||
|
model_(gcpp::CreateGemma(loader_, pools_)) {
|
||||||
|
Init();
|
||||||
|
}
|
||||||
|
|
||||||
|
SimplifiedGemma(int argc, char** argv)
|
||||||
|
: loader_(argc, argv, /*validate=*/true),
|
||||||
|
inference_(argc, argv),
|
||||||
|
app_(argc, argv),
|
||||||
|
pools_(gcpp::CreatePools(app_)),
|
||||||
|
model_(gcpp::CreateGemma(loader_, pools_)) {
|
||||||
|
Init();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Init() {
|
||||||
|
gcpp::Allocator::Init(pools_.Topology());
|
||||||
|
|
||||||
|
// Instantiate model and KV Cache
|
||||||
|
kv_cache_ = gcpp::KVCache::Create(model_.GetModelConfig(),
|
||||||
|
inference_.prefill_tbatch_size);
|
||||||
|
|
||||||
|
// Initialize random number generator
|
||||||
|
std::random_device rd;
|
||||||
|
gen_.seed(rd());
|
||||||
|
}
|
||||||
|
|
||||||
|
void Generate(std::string& prompt, size_t max_generated_tokens = 1024,
|
||||||
|
float temperature = 0.7,
|
||||||
|
const std::set<int>& reject_tokens = {}) {
|
||||||
|
size_t generated = 0;
|
||||||
|
|
||||||
|
const std::vector<int> tokens = gcpp::WrapAndTokenize(
|
||||||
|
model_.Tokenizer(), loader_.Info(), generated, prompt);
|
||||||
|
const size_t prompt_size = tokens.size();
|
||||||
|
|
||||||
|
// This callback function gets invoked every time a token is generated
|
||||||
|
auto stream_token = [&generated, &prompt_size, this](int token, float) {
|
||||||
|
++generated;
|
||||||
|
if (generated < prompt_size) {
|
||||||
|
// print feedback
|
||||||
|
} else if (token != gcpp::EOS_ID) {
|
||||||
|
std::string token_text;
|
||||||
|
HWY_ASSERT(this->model_.Tokenizer().Decode({token}, &token_text));
|
||||||
|
std::cout << token_text << std::flush;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
gcpp::TimingInfo timing_info;
|
||||||
|
gcpp::RuntimeConfig runtime_config = {
|
||||||
|
.max_generated_tokens = max_generated_tokens,
|
||||||
|
.temperature = temperature,
|
||||||
|
.gen = &gen_,
|
||||||
|
.verbosity = 0,
|
||||||
|
.stream_token = stream_token,
|
||||||
|
.accept_token =
|
||||||
|
[&](int token, float /* prob */) {
|
||||||
|
return !reject_tokens.contains(token);
|
||||||
|
},
|
||||||
|
};
|
||||||
|
model_.Generate(runtime_config, tokens, 0, kv_cache_, timing_info);
|
||||||
|
}
|
||||||
|
~SimplifiedGemma() = default;
|
||||||
|
|
||||||
|
private:
|
||||||
|
gcpp::LoaderArgs loader_;
|
||||||
|
gcpp::InferenceArgs inference_;
|
||||||
|
gcpp::AppArgs app_;
|
||||||
|
gcpp::NestedPools pools_;
|
||||||
|
gcpp::Gemma model_;
|
||||||
|
gcpp::KVCache kv_cache_;
|
||||||
|
std::mt19937 gen_;
|
||||||
|
std::string validation_error_;
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,50 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
// Placeholder for internal header, do not modify.
|
||||||
|
#include "third_party/gemma_cpp/examples/simplified_gemma/gemma.hpp"
|
||||||
|
#include "util/app.h" // LoaderArgs
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
{
|
||||||
|
// Placeholder for internal init, do not modify.
|
||||||
|
}
|
||||||
|
|
||||||
|
// Standard usage: LoaderArgs takes argc and argv as input, then parses
|
||||||
|
// necessary flags.
|
||||||
|
gcpp::LoaderArgs loader(argc, argv, /*validate=*/true);
|
||||||
|
|
||||||
|
// Optional: LoaderArgs can also take tokenizer and weights paths directly.
|
||||||
|
//
|
||||||
|
// gcpp::LoaderArgs loader("/path/to/tokenizer", "/path/to/weights",
|
||||||
|
// "model_identifier");
|
||||||
|
|
||||||
|
// Optional: InferenceArgs and AppArgs can be passed in as well. If not
|
||||||
|
// specified, default values will be used.
|
||||||
|
//
|
||||||
|
// gcpp::InferenceArgs inference(argc, argv);
|
||||||
|
// gcpp::AppArgs app(argc, argv);
|
||||||
|
// SimplifiedGemma gemma(loader, inference, app);
|
||||||
|
|
||||||
|
SimplifiedGemma gemma(loader);
|
||||||
|
std::string prompt = "Write a greeting to the world.";
|
||||||
|
gemma.Generate(prompt, 256, 0.6);
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
16
util/app.h
16
util/app.h
|
|
@ -126,15 +126,27 @@ static inline NestedPools CreatePools(const AppArgs& app) {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
LoaderArgs(int argc, char* argv[]) {
|
LoaderArgs(int argc, char* argv[], bool validate = true) {
|
||||||
InitAndParse(argc, argv);
|
InitAndParse(argc, argv);
|
||||||
|
|
||||||
|
if (validate) {
|
||||||
|
if (const char* error = Validate()) {
|
||||||
|
HWY_ABORT("Invalid args: %s", error);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
|
LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
|
||||||
const std::string& model) {
|
const std::string& model, bool validate = true) {
|
||||||
Init(); // Init sets to defaults, so assignments must come after Init().
|
Init(); // Init sets to defaults, so assignments must come after Init().
|
||||||
tokenizer.path = tokenizer_path;
|
tokenizer.path = tokenizer_path;
|
||||||
weights.path = weights_path;
|
weights.path = weights_path;
|
||||||
model_type_str = model;
|
model_type_str = model;
|
||||||
|
|
||||||
|
if (validate) {
|
||||||
|
if (const char* error = Validate()) {
|
||||||
|
HWY_ABORT("Invalid args: %s", error);
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Returns error string or nullptr if OK.
|
// Returns error string or nullptr if OK.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue