diff --git a/examples/simplified_gemma/BUILD.bazel b/examples/simplified_gemma/BUILD.bazel new file mode 100644 index 0000000..2ae7861 --- /dev/null +++ b/examples/simplified_gemma/BUILD.bazel @@ -0,0 +1,39 @@ +# Hello World example frontend to gemma.cpp. +package( + default_applicable_licenses = [ + "//:license", # Placeholder comment, do not modify + ], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "gemma", + hdrs = ["gemma.hpp"], + deps = [ + "//:app", + "//:args", + "//:common", + "//:gemma_lib", + "//:threading", + "//:tokenizer", + "@highway//:hwy", + "@highway//:thread_pool", + ], +) + +cc_binary( + name = "simplified_gemma", + srcs = ["run.cc"], + deps = [ + ":gemma", + # Placeholder for internal dep, do not remove., + "//:app", + "//:args", + "//:common", + "//:gemma_lib", + "//:threading", + "//:tokenizer", + "@highway//:hwy", + "@highway//:thread_pool", + ], +) diff --git a/examples/simplified_gemma/CMakeLists.txt b/examples/simplified_gemma/CMakeLists.txt new file mode 100644 index 0000000..609459e --- /dev/null +++ b/examples/simplified_gemma/CMakeLists.txt @@ -0,0 +1,49 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.11) +project(simplified_gemma) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +include(FetchContent) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG f2209b911c74019e85d0b7a7a2833c9a2e1b7995) +FetchContent_MakeAvailable(highway) +FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) +FetchContent_MakeAvailable(sentencepiece) + + + +# Allow for both local and remote building) +option(BUILD_MODE "'local' or 'remote' git fetch for builds") +if (NOT BUILD_MODE) + set(BUILD_MODE "remote") +endif() +if (BUILD_MODE STREQUAL "local") + # Relative path to gemma.cpp from examples/simplified_gemma/build/ + FetchContent_Declare(gemma SOURCE_DIR ../../..) +else() + FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e) +endif() +FetchContent_MakeAvailable(gemma) + +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Release") +endif() + +add_executable(simplified_gemma run.cc) +target_link_libraries(simplified_gemma hwy hwy_contrib sentencepiece libgemma) +FetchContent_GetProperties(sentencepiece) +target_include_directories(simplified_gemma PRIVATE ${sentencepiece_SOURCE_DIR}) +target_compile_definitions(simplified_gemma PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>) +target_compile_options(simplified_gemma PRIVATE $<$:-Wno-deprecated-declarations>) diff --git a/examples/simplified_gemma/README.md b/examples/simplified_gemma/README.md new file mode 100644 index 0000000..d8f9394 --- /dev/null +++ b/examples/simplified_gemma/README.md @@ -0,0 +1,60 @@ +# Simplified Gemma.cpp Example + +This is a minimal/template project for using `gemma.cpp` as a library. Instead +of an interactive interface, it sets up the model state and generates text for a +single hard coded prompt. + +Build steps are similar to the main `gemma` executable. For now only +`cmake`/`make` is available for builds (PRs welcome for other build options). + +First use `cmake` to configure the project, starting from the `simplified_gemma` +example directory (`gemma.cpp/examples/simplified_gemma`): + +```sh +cmake -B build +``` + +This sets up a build configuration in `gemma.cpp/examples/simplified_gemma/build`. +Note that this fetches `libgemma` from a git commit hash on github. +Alternatively if you want to build using the local version of `gemma.cpp` use: + +```sh +cmake -B build -DBUILD_MODE=local +``` + +Make sure you delete the contents of the build directory before changing +configurations. + +Then use `make` to build the project: + +```sh +cd build +make simplified_gemma +``` + +As with the top-level `gemma.cpp` project you can use the `make` commands `-j` +flag to use parallel threads for faster builds. + +From inside the `gemma.cpp/examples/simplified_gemma/build` directory, there should +be a `simplified_gemma` executable. You can run it with the same 3 model arguments as +gemma.cpp specifying the tokenizer, compressed weights file, and model type, for +example: + +```sh +./simplified_gemma --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it +``` + +Should print a greeting to the terminal: + +``` +"Hello, world! It's a pleasure to greet you all. May your day be filled with joy, peace, and all the things that make your heart soar. +``` + +For a demonstration of constrained decoding, add the `--reject` flag followed by +a list of token IDs (note that it must be the last flag, since it consumes every +subsequent argument). For example, to reject variations of the word "greeting", +run: + +```sh +./simplified_gemma [...] --reject 32338 42360 78107 106837 132832 143859 154230 190205 +``` diff --git a/examples/simplified_gemma/build/.gitignore b/examples/simplified_gemma/build/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/examples/simplified_gemma/build/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/examples/simplified_gemma/gemma.hpp b/examples/simplified_gemma/gemma.hpp new file mode 100644 index 0000000..84283a3 --- /dev/null +++ b/examples/simplified_gemma/gemma.hpp @@ -0,0 +1,114 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by app_licable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "third_party/gemma_cpp/gemma/gemma.h" +#include "third_party/gemma_cpp/gemma/tokenizer.h" +#include "third_party/gemma_cpp/util/app.h" // LoaderArgs +#include "third_party/gemma_cpp/util/threading.h" +#include "third_party/highway/hwy/base.h" +#include "third_party/highway/hwy/contrib/thread_pool/thread_pool.h" + +class SimplifiedGemma { + public: + SimplifiedGemma(const gcpp::LoaderArgs& loader, + const gcpp::InferenceArgs& inference = gcpp::InferenceArgs(), + const gcpp::AppArgs& app = gcpp::AppArgs()) + : loader_(loader), + inference_(inference), + app_(app), + pools_(gcpp::CreatePools(app_)), + model_(gcpp::CreateGemma(loader_, pools_)) { + Init(); + } + + SimplifiedGemma(int argc, char** argv) + : loader_(argc, argv, /*validate=*/true), + inference_(argc, argv), + app_(argc, argv), + pools_(gcpp::CreatePools(app_)), + model_(gcpp::CreateGemma(loader_, pools_)) { + Init(); + } + + void Init() { + gcpp::Allocator::Init(pools_.Topology()); + + // Instantiate model and KV Cache + kv_cache_ = gcpp::KVCache::Create(model_.GetModelConfig(), + inference_.prefill_tbatch_size); + + // Initialize random number generator + std::random_device rd; + gen_.seed(rd()); + } + + void Generate(std::string& prompt, size_t max_generated_tokens = 1024, + float temperature = 0.7, + const std::set& reject_tokens = {}) { + size_t generated = 0; + + const std::vector tokens = gcpp::WrapAndTokenize( + model_.Tokenizer(), loader_.Info(), generated, prompt); + const size_t prompt_size = tokens.size(); + + // This callback function gets invoked every time a token is generated + auto stream_token = [&generated, &prompt_size, this](int token, float) { + ++generated; + if (generated < prompt_size) { + // print feedback + } else if (token != gcpp::EOS_ID) { + std::string token_text; + HWY_ASSERT(this->model_.Tokenizer().Decode({token}, &token_text)); + std::cout << token_text << std::flush; + } + return true; + }; + + gcpp::TimingInfo timing_info; + gcpp::RuntimeConfig runtime_config = { + .max_generated_tokens = max_generated_tokens, + .temperature = temperature, + .gen = &gen_, + .verbosity = 0, + .stream_token = stream_token, + .accept_token = + [&](int token, float /* prob */) { + return !reject_tokens.contains(token); + }, + }; + model_.Generate(runtime_config, tokens, 0, kv_cache_, timing_info); + } + ~SimplifiedGemma() = default; + + private: + gcpp::LoaderArgs loader_; + gcpp::InferenceArgs inference_; + gcpp::AppArgs app_; + gcpp::NestedPools pools_; + gcpp::Gemma model_; + gcpp::KVCache kv_cache_; + std::mt19937 gen_; + std::string validation_error_; +}; \ No newline at end of file diff --git a/examples/simplified_gemma/run.cc b/examples/simplified_gemma/run.cc new file mode 100644 index 0000000..f73ddb5 --- /dev/null +++ b/examples/simplified_gemma/run.cc @@ -0,0 +1,50 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include + +// Placeholder for internal header, do not modify. +#include "third_party/gemma_cpp/examples/simplified_gemma/gemma.hpp" +#include "util/app.h" // LoaderArgs + +int main(int argc, char** argv) { + { + // Placeholder for internal init, do not modify. + } + + // Standard usage: LoaderArgs takes argc and argv as input, then parses + // necessary flags. + gcpp::LoaderArgs loader(argc, argv, /*validate=*/true); + + // Optional: LoaderArgs can also take tokenizer and weights paths directly. + // + // gcpp::LoaderArgs loader("/path/to/tokenizer", "/path/to/weights", + // "model_identifier"); + + // Optional: InferenceArgs and AppArgs can be passed in as well. If not + // specified, default values will be used. + // + // gcpp::InferenceArgs inference(argc, argv); + // gcpp::AppArgs app(argc, argv); + // SimplifiedGemma gemma(loader, inference, app); + + SimplifiedGemma gemma(loader); + std::string prompt = "Write a greeting to the world."; + gemma.Generate(prompt, 256, 0.6); + + return 0; +} \ No newline at end of file diff --git a/util/app.h b/util/app.h index 4b6dffb..d759467 100644 --- a/util/app.h +++ b/util/app.h @@ -126,15 +126,27 @@ static inline NestedPools CreatePools(const AppArgs& app) { } struct LoaderArgs : public ArgsBase { - LoaderArgs(int argc, char* argv[]) { + LoaderArgs(int argc, char* argv[], bool validate = true) { InitAndParse(argc, argv); + + if (validate) { + if (const char* error = Validate()) { + HWY_ABORT("Invalid args: %s", error); + } + } } LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path, - const std::string& model) { + const std::string& model, bool validate = true) { Init(); // Init sets to defaults, so assignments must come after Init(). tokenizer.path = tokenizer_path; weights.path = weights_path; model_type_str = model; + + if (validate) { + if (const char* error = Validate()) { + HWY_ABORT("Invalid args: %s", error); + } + } }; // Returns error string or nullptr if OK.