From f3116d25775f536c8e93f232447b7f344aedd0a2 Mon Sep 17 00:00:00 2001 From: prajwalc22 Date: Sat, 12 Apr 2025 13:22:48 +0530 Subject: [PATCH 1/9] Add --prompt flag for non-interactive mode This change adds a --prompt command-line option that allows users to provide prompts directly without entering interactive mode, which is useful for scripting and automation. --- gemma/gemma_args.h | 217 +++++++++++++++++++++++++++++++++++++++++++-- gemma/run.cc | 38 ++++++-- 2 files changed, 245 insertions(+), 10 deletions(-) diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index 4fe2d33..dc4019c 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -28,13 +28,205 @@ #include "compression/shared.h" #include "gemma/common.h" #include "gemma/gemma.h" // For CreateGemma +#include "hwy/base.h" // HWY_IS_ASAN, HWY_ABORT #include "ops/matmul.h" +#include "util/allocator.h" #include "util/args.h" #include "util/basics.h" // Tristate -#include "hwy/base.h" // HWY_ABORT +#include "util/threading.h" +#include "util/threading_context.h" namespace gcpp { +static inline const char* CompiledConfig() { + if (HWY_IS_ASAN) { + return "asan"; + } else if (HWY_IS_MSAN) { + return "msan"; + } else if (HWY_IS_TSAN) { + return "tsan"; + } else if (HWY_IS_HWASAN) { + return "hwasan"; + } else if (HWY_IS_UBSAN) { + return "ubsan"; + } else if (HWY_IS_DEBUG_BUILD) { + return "dbg"; + } else { + return "opt"; + } +} +template +struct ArgsBase { + void Init() { static_cast(this)->ForEach(SetToDefault()); } + + void InitAndParse(int argc, char* argv[]) { + Init(); + static_cast(this)->ForEach(ParseOption(argc, argv)); + } + + void Print(int min_verbosity = 1) const { + static_cast(this)->ForEach(PrintOption(min_verbosity)); + } + + void Help() const { static_cast(this)->ForEach(PrintHelp()); } + + protected: + // Helper struct for printing help messages + struct PrintHelp { + template + void operator()(const T& value, const char* name, const T& default_value, + const char* description, int verbosity = 1) const { + fprintf(stderr, " --%s\n %s\n", name, description); + } + // Special case for strings to avoid template deduction issues + void operator()(const std::string& value, const char* name, + const std::string& default_value, const char* description, + int verbosity = 1) const { + fprintf(stderr, " --%s\n %s\n", name, description); + } + // Special case for Path type + void operator()(const Path& value, const char* name, + const Path& default_value, const char* description, + int verbosity = 1) const { + fprintf(stderr, " --%s\n %s\n", name, description); + } + }; + + // Helper struct for setting default values + struct SetToDefault { + template + void operator()(T& value, const char* name, const T& default_value, + const char* description, int verbosity = 1) const { + value = default_value; + } + }; + + // Helper struct for printing values + struct PrintOption { + explicit PrintOption(int min_verbosity) : min_verbosity_(min_verbosity) {} + + template + void operator()(const T& value, const char* name, const T& default_value, + const char* description, int verbosity = 1) const { + if (verbosity >= min_verbosity_) { + fprintf(stderr, "%s: %s\n", name, ToString(value).c_str()); + } + } + + private: + int min_verbosity_; + + // Helper function to convert values to string + template + static std::string ToString(const T& value) { + return std::to_string(value); + } + // Specialization for string + static std::string ToString(const std::string& value) { return value; } + // Specialization for Path + static std::string 
ToString(const Path& value) { return value.path; } + }; +}; +struct ThreadingArgs : public ArgsBase { + public: + ThreadingArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + ThreadingArgs() { Init(); }; + + int verbosity; + + size_t max_threads; // divided among the detected clusters + Tristate pin; // pin threads? + Tristate spin; // use spin waits? + + // For BoundedSlice: + size_t skip_packages; + size_t max_packages; + size_t skip_clusters; + size_t max_clusters; + size_t skip_lps; + size_t max_lps; + + std::string eot_line; + std::string prompt; + template + void ForEach(const Visitor& visitor) { + visitor(verbosity, "verbosity", 1, + "Show verbose developer information\n 0 = only print generation " + "output\n 1 = standard user-facing terminal ui\n 2 = show " + "developer/debug info).\n Default = 1.", + 2); + + // The exact meaning is more subtle: see the comment at NestedPools ctor. + visitor(max_threads, "num_threads", size_t{0}, + "Maximum number of threads to use; default 0 = unlimited.", 2); + visitor(pin, "pin", Tristate::kDefault, + "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2); + visitor(spin, "spin", Tristate::kDefault, + "Use spin waits? -1 = auto, 0 = no, 1 = yes.", 2); + // These can be used to partition CPU sockets/packages and their + // clusters/CCXs across several program instances. The default is to use + // all available resources. + visitor(skip_packages, "skip_packages", size_t{0}, + "Index of the first socket to use; default 0 = unlimited.", 2); + visitor(max_packages, "max_packages", size_t{0}, + "Maximum number of sockets to use; default 0 = unlimited.", 2); + visitor(skip_clusters, "skip_clusters", size_t{0}, + "Index of the first CCX to use; default 0 = unlimited.", 2); + visitor(max_clusters, "max_clusters", size_t{0}, + "Maximum number of CCXs to use; default 0 = unlimited.", 2); + // These are only used when CPU topology is unknown. + visitor(skip_lps, "skip_lps", size_t{0}, + "Index of the first LP to use; default 0 = unlimited.", 2); + visitor(max_lps, "max_lps", size_t{0}, + "Maximum number of LPs to use; default 0 = unlimited.", 2); + + visitor( + eot_line, "eot_line", std::string(""), + "End of turn line. " + "When you specify this, the prompt will be all lines " + "before the line where only the given string appears.\n Default = " + "When a newline is encountered, that signals the end of the turn.", + 2); + + visitor(prompt, "prompt", std::string(""), + "Prompt string for non-interactive mode. 
When provided, the model " + "generates a response and exits.", + 2); + } +}; +static inline BoundedTopology CreateTopology(const ThreadingArgs& threading) { + return BoundedTopology( + BoundedSlice(threading.skip_packages, threading.max_packages), + BoundedSlice(threading.skip_clusters, threading.max_clusters), + BoundedSlice(threading.skip_lps, threading.max_lps)); +} + +static inline MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading) { + ThreadingContext2::SetArgs(threading); + return MatMulEnv(ThreadingContext2::Get()); +} +// Note: These functions may need adjustments depending on your specific class +// definitions +static inline BoundedTopology CreateTopology(const ThreadingArgs& app) { + return BoundedTopology(BoundedSlice(app.skip_packages, app.max_packages), + BoundedSlice(app.skip_clusters, app.max_clusters), + BoundedSlice(app.skip_lps, app.max_lps)); +} + +// This function may need to be adjusted based on your NestedPools constructor +// signature +static inline NestedPools CreatePools(const BoundedTopology& topology, + const ThreadingArgs& threading) { + // Make sure Allocator::Init() is properly declared/defined + const Allocator2& allocator = ThreadingContext2::Get().allocator; + // Allocator::Init(topology); + + // Adjust the constructor call based on your actual NestedPools constructor + // The error suggests that the constructor doesn't match these arguments + return NestedPools(topology, allocator, threading.max_threads, threading.pin); + // Alternative: return NestedPools(topology, app.max_threads, app.pin); +} + struct LoaderArgs : public ArgsBase { LoaderArgs(int argc, char* argv[], bool validate = true) { InitAndParse(argc, argv); @@ -106,9 +298,8 @@ struct LoaderArgs : public ArgsBase { "Path name of model weights (.sbs) file.\n Required argument.\n"); visitor(compressed_weights, "compressed_weights", Path(), "Deprecated alias for --weights."); - visitor( - model_type_str, "model", std::string(), - "Model type, see common.cc for valid values.\n"); + visitor(model_type_str, "model", std::string(), + "Model type, see common.cc for valid values.\n"); visitor(weight_type_str, "weight_type", std::string("sfp"), "Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit SFP."); } @@ -231,6 +422,22 @@ struct InferenceArgs : public ArgsBase { } }; +static inline void ShowConfig(const ThreadingArgs& threading, + const LoaderArgs& loader, + const InferenceArgs& inference) { + threading.Print(); + loader.Print(); + inference.Print(); +} +static inline void ShowHelp(const ThreadingArgs& threading, + const LoaderArgs& loader, + const InferenceArgs& inference) { + fprintf(stderr, "\nUsage: gemma [flags]\n\nFlags:\n"); + threading.Help(); + loader.Help(); + inference.Help(); +} + } // namespace gcpp -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ \ No newline at end of file diff --git a/gemma/run.cc b/gemma/run.cc index 5170b6e..381dac4 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -27,13 +27,14 @@ #include "evals/benchmark_helper.h" #include "gemma/common.h" #include "gemma/gemma.h" // Gemma -#include "gemma/gemma_args.h" // LoaderArgs -#include "ops/matmul.h" // MatMulEnv -#include "paligemma/image.h" -#include "util/args.h" // HasHelp -#include "util/threading_context.h" +#include "gemma/gemma_args.h" +#include "hwy/base.h" #include "hwy/highway.h" #include "hwy/profiler.h" +#include "ops/matmul.h" // MatMulEnv +#include "paligemma/image.h" +#include "util/args.h" // HasHelp +#include "util/threading.h" #if 
(!defined(HWY_VERSION_LT) || HWY_VERSION_LT(1, 2)) && !HWY_IDE #error "Please update to version 1.2 of github.com/google/highway." @@ -165,6 +166,16 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, continue; } + // Wrap, tokenize and maybe log prompt tokens. + std::vector prompt = WrapAndTokenize(model.Tokenizer(), model.Info(), + abs_pos, prompt_string); + prompt_size = prompt.size(); + if constexpr (kVerboseLogTokens) { + for (int i = 0; i < prompt_size; ++i) { + fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]); + } + } + // Set up runtime config. TimingInfo timing_info = {.verbosity = inference.verbosity}; RuntimeConfig runtime_config = {.gen = &gen, @@ -238,6 +249,22 @@ void Run(ThreadingArgs& threading, LoaderArgs& loader, KVCache kv_cache = KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size); + if (!threading.prompt.empty()) { + std::vector prompt = + WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(), + 0, threading.prompt); + + TimingInfo timing_info = {.verbosity = inference.verbosity}; + RuntimeConfig runtime_config = {.gen = nullptr, // Use default generator + .verbosity = inference.verbosity, + .use_spinning = threading.spin}; + inference.CopyTo(runtime_config); + + model.Generate(runtime_config, prompt, 0, 0, kv_cache, timing_info); + std::cout << "\n"; + return; // Exit after generating response + } + if (inference.verbosity >= 1) { std::string instructions = "*Usage*\n" @@ -280,6 +307,7 @@ int main(int argc, char** argv) { if (gcpp::HasHelp(argc, argv)) { std::cerr << gcpp::kAsciiArtBanner; + gcpp::ShowHelp(threading, loader, inference); return 0; } From 87a1c76578d8e59c33351c5a7299c3b9b730694c Mon Sep 17 00:00:00 2001 From: prajwalc22 Date: Tue, 15 Apr 2025 08:16:02 +0530 Subject: [PATCH 2/9] Update CMake configuration and documentation for --prompt flag --- CMakeLists.txt | 2 +- README.md | 596 ++--------------------------------------------- build/.gitignore | 3 - 3 files changed, 21 insertions(+), 580 deletions(-) delete mode 100644 build/.gitignore diff --git a/CMakeLists.txt b/CMakeLists.txt index b572835..b9558ad 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -cmake_minimum_required(VERSION 3.11) +cmake_minimum_required(VERSION 3.11...4.0) include(FetchContent) diff --git a/README.md b/README.md index e9a6745..2c2020e 100644 --- a/README.md +++ b/README.md @@ -1,583 +1,27 @@ -# gemma.cpp +--- +library_name: gemma.cpp +license: gemma +pipeline_tag: text-generation +tags: [] +extra_gated_heading: Access Gemma on Hugging Face +extra_gated_prompt: To access Gemma on Hugging Face, you’re required to review and + agree to Google’s usage license. To do this, please ensure you’re logged-in to Hugging + Face and click below. Requests are processed immediately. +extra_gated_button_content: Acknowledge license +--- -gemma.cpp is a lightweight, standalone C++ inference engine for the Gemma -foundation models from Google. +# Gemma Model Card -For additional information about Gemma, see -[ai.google.dev/gemma](https://ai.google.dev/gemma). Model weights, including -gemma.cpp specific artifacts, are -[available on kaggle](https://www.kaggle.com/models/google/gemma). +**Model Page**: [Gemma](https://ai.google.dev/gemma/docs) -## Who is this project for? +This model card corresponds to the 2B base version of the Gemma model for usage with C++ (https://github.com/google/gemma.cpp). 
This is a compressed version of the weights, which will load, run, and download more quickly. For more information about the model, visit https://huggingface.co/google/gemma-2b. -Modern LLM inference engines are sophisticated systems, often with bespoke -capabilities extending beyond traditional neural network runtimes. With this -comes opportunities for research and innovation through co-design of high level -algorithms and low-level computation. However, there is a gap between -deployment-oriented C++ inference runtimes, which are not designed for -experimentation, and Python-centric ML research frameworks, which abstract away -low-level computation through compilation. +**Resources and Technical Documentation**: -gemma.cpp provides a minimalist implementation of Gemma-1, Gemma-2, Gemma-3, and -PaliGemma models, focusing on simplicity and directness rather than full -generality. This is inspired by vertically-integrated model implementations such -as [ggml](https://github.com/ggerganov/ggml), -[llama.c](https://github.com/karpathy/llama2.c), and -[llama.rs](https://github.com/srush/llama2.rs). +* [Responsible Generative AI Toolkit](https://ai.google.dev/responsible) +* [Gemma on Kaggle](https://www.kaggle.com/models/google/gemma) +* [Gemma on Vertex Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/335?version=gemma-2b-gg-hf) -gemma.cpp targets experimentation and research use cases. It is intended to be -straightforward to embed in other projects with minimal dependencies and also -easily modifiable with a small ~2K LoC core implementation (along with ~4K LoC -of supporting utilities). We use the [Google -Highway](https://github.com/google/highway) Library to take advantage of -portable SIMD for CPU inference. +**Terms of Use**: [Terms](https://www.kaggle.com/models/google/gemma/license/consent/verify/huggingface?returnModelRepoId=google/gemma-2b-sfp-cpp) -For production-oriented edge deployments we recommend standard deployment -pathways using Python frameworks like JAX, Keras, PyTorch, and Transformers -([all model variations here](https://www.kaggle.com/models/google/gemma)). - -## Contributing - -Community contributions large and small are welcome. See -[DEVELOPERS.md](https://github.com/google/gemma.cpp/blob/main/DEVELOPERS.md) -for additional notes contributing developers and [join the discord by following -this invite link](https://discord.gg/H5jCBAWxAe). This project follows -[Google's Open Source Community -Guidelines](https://opensource.google.com/conduct/). - -*Active development is currently done on the `dev` branch. Please open pull -requests targeting `dev` branch instead of `main`, which is intended to be more -stable.* - -## Quick Start - -### System requirements - -Before starting, you should have installed: - -- [CMake](https://cmake.org/) -- [Clang C++ compiler](https://clang.llvm.org/get_started.html), supporting at - least C++17. -- `tar` for extracting archives from Kaggle. - -Building natively on Windows requires the Visual Studio 2012 Build Tools with the -optional Clang/LLVM C++ frontend (`clang-cl`). 
This can be installed from the -command line with -[`winget`](https://learn.microsoft.com/en-us/windows/package-manager/winget/): - -```sh -winget install --id Kitware.CMake -winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;installRecommended --add Microsoft.VisualStudio.Component.VC.Llvm.Clang --add Microsoft.VisualStudio.Component.VC.Llvm.ClangToolset" -``` - -### Step 1: Obtain model weights and tokenizer from Kaggle or Hugging Face Hub - -Visit the -[Kaggle page for Gemma-2](https://www.kaggle.com/models/google/gemma-2/gemmaCpp) -[or Gemma-1](https://www.kaggle.com/models/google/gemma/frameworks/gemmaCpp), -and select `Model Variations |> Gemma C++`. - -On this tab, the `Variation` dropdown includes the options below. Note bfloat16 -weights are higher fidelity, while 8-bit switched floating point weights enable -faster inference. In general, we recommend starting with the `-sfp` checkpoints. - -If you are unsure which model to start with, we recommend starting with the -smallest Gemma-2 model, i.e. `2.0-2b-it-sfp`. - -Alternatively, visit the -[gemma.cpp](https://huggingface.co/models?other=gemma.cpp) models on the Hugging -Face Hub. First go the model repository of the model of interest (see -recommendations below). Then, click the `Files and versions` tab and download -the model and tokenizer files. For programmatic downloading, if you have -`huggingface_hub` installed, you can also download by running: - -``` -huggingface-cli login # Just the first time -huggingface-cli download google/gemma-2b-sfp-cpp --local-dir build/ -``` - -Gemma-1 2B instruction-tuned (`it`) and pre-trained (`pt`) models: - -| Model name | Description | -| ----------- | ----------- | -| `2b-it` | 2 billion parameter instruction-tuned model, bfloat16 | -| `2b-it-sfp` | 2 billion parameter instruction-tuned model, 8-bit switched floating point | -| `2b-pt` | 2 billion parameter pre-trained model, bfloat16 | -| `2b-pt-sfp` | 2 billion parameter pre-trained model, 8-bit switched floating point | - -Gemma-1 7B instruction-tuned (`it`) and pre-trained (`pt`) models: - -| Model name | Description | -| ----------- | ----------- | -| `7b-it` | 7 billion parameter instruction-tuned model, bfloat16 | -| `7b-it-sfp` | 7 billion parameter instruction-tuned model, 8-bit switched floating point | -| `7b-pt` | 7 billion parameter pre-trained model, bfloat16 | -| `7b-pt-sfp` | 7 billion parameter pre-trained model, 8-bit switched floating point | - -> [!NOTE] -> **Important**: We strongly recommend starting off with the `2b-it-sfp` model to -> get up and running. - -Gemma 2 models are named `gemma2-2b-it` for 2B and `9b-it` or `27b-it`. See the -`kModelFlags` definition in `common.cc`. - -### Step 2: Extract Files - -If you downloaded the models from Hugging Face, skip to step 3. - -After filling out the consent form, the download should proceed to retrieve a -tar archive file `archive.tar.gz`. Extract files from `archive.tar.gz` (this can -take a few minutes): - -``` -tar -xf archive.tar.gz -``` - -This should produce a file containing model weights such as `2b-it-sfp.sbs` and -a tokenizer file (`tokenizer.spm`). You may want to move these files to a -convenient directory location (e.g. the `build/` directory in this repo). - -### Step 3: Build - -The build system uses [CMake](https://cmake.org/). 
To build the gemma inference -runtime, create a build directory and generate the build files using `cmake` -from the top-level project directory. Note if you previous ran `cmake` and are -re-running with a different setting, be sure to delete all files in the `build/` -directory with `rm -rf build/*`. - -#### Unix-like Platforms -```sh -cmake -B build -``` - -After running `cmake`, you can enter the `build/` directory and run `make` to -build the `./gemma` executable: - -```sh -# Configure `build` directory -cmake --preset make - -# Build project using make -cmake --build --preset make -j [number of parallel threads to use] -``` - -Replace `[number of parallel threads to use]` with a number - the number of -cores available on your system is a reasonable heuristic. For example, -`make -j4 gemma` will build using 4 threads. If the `nproc` command is -available, you can use `make -j$(nproc) gemma` as a reasonable default -for the number of threads. - -If you aren't sure of the right value for the `-j` flag, you can simply run -`make gemma` instead and it should still build the `./gemma` executable. - -> [!NOTE] -> On Windows Subsystem for Linux (WSL) users should set the number of -> parallel threads to 1. Using a larger number may result in errors. - -If the build is successful, you should now have a `gemma` executable in the `build/` directory. - -#### Windows - -```sh -# Configure `build` directory -cmake --preset windows - -# Build project using Visual Studio Build Tools -cmake --build --preset windows -j [number of parallel threads to use] -``` - -If the build is successful, you should now have a `gemma.exe` executable in the `build/` directory. - -#### Bazel - -```sh -bazel build -c opt --cxxopt=-std=c++20 :gemma -``` - -If the build is successful, you should now have a `gemma` executable in the `bazel-bin/` directory. - -#### Make - -If you prefer Makefiles, @jart has made one available here: - -https://github.com/jart/gemma3/blob/main/Makefile - -### Step 4: Run - -You can now run `gemma` from inside the `build/` directory. - -`gemma` has the following required arguments: - -Argument | Description | Example value ---------------- | ---------------------------- | ----------------------- -`--model` | The model type. | `2b-it` ... (see below) -`--weights` | The compressed weights file. | `2b-it-sfp.sbs` -`--weight_type` | The compressed weight type. | `sfp` -`--tokenizer` | The tokenizer file. | `tokenizer.spm` - -`gemma` is invoked as: - -```sh -./gemma \ ---tokenizer [tokenizer file] \ ---weights [compressed weights file] \ ---weight_type [f32 or bf16 or sfp (default:sfp)] \ ---model [2b-it or 2b-pt or 7b-it or 7b-pt or ...] -``` - -Example invocation for the following configuration: - -- Compressed weights file `2b-it-sfp.sbs` (2B instruction-tuned model, 8-bit - switched floating point). -- Tokenizer file `tokenizer.spm`. - -```sh -./gemma \ ---tokenizer tokenizer.spm \ ---weights 2b-it-sfp.sbs --model 2b-it -``` - -### RecurrentGemma - -This repository includes a version of Gemma based on Griffin -([paper](https://arxiv.org/abs/2402.19427), -[code](https://github.com/google-deepmind/recurrentgemma)). Its architecture -includes both recurrent layers and local attention, thus it is more efficient -for longer sequences and has a smaller memory footprint than standard Gemma. We -here provide a C++ implementation of this model based on the paper. - -To use the recurrent version of Gemma included in this repository, build the -gemma binary as noted above in Step 3. 
Download the compressed weights and -tokenizer from the RecurrentGemma -[Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) as in -Step 1, and run the binary as follows: - -`./gemma --tokenizer tokenizer.spm --model gr2b-it --weights 2b-it-sfp.sbs` - -### PaliGemma Vision-Language Model - -This repository includes a version of the PaliGemma VLM -([paper](https://arxiv.org/abs/2407.07726), -[code](https://github.com/google-research/big_vision/tree/main/big_vision/configs/proj/paligemma)) -and its successor PaliGemma 2 ([paper](https://arxiv.org/abs/2412.03555)). We -provide a C++ implementation of the PaliGemma model family here. - -To use the version of PaliGemma included in this repository, build the gemma -binary as noted above in Step 3. Download the compressed weights and tokenizer -from -[Kaggle](https://www.kaggle.com/models/google/paligemma/gemmaCpp/paligemma-3b-mix-224) -and run the binary as follows: - -```sh -./gemma \ ---tokenizer paligemma_tokenizer.model \ ---model paligemma-224 \ ---weights paligemma-3b-mix-224-sfp.sbs \ ---image_file paligemma/testdata/image.ppm -``` - -Note that the image reading code is very basic to avoid depending on an image -processing library for now. We currently only support reading binary PPMs (P6). -So use a tool like `convert` to first convert your images into that format, e.g. - -`convert image.jpeg -resize 224x224^ image.ppm` - -(As the image will be resized for processing anyway, we can already resize at -this stage for slightly faster loading.) - -The interaction with the image (using the mix-224 checkpoint) may then look -something like this: - -``` -> Describe the image briefly -A large building with two towers in the middle of a city. -> What type of building is it? -church -> What color is the church? -gray -> caption image -A large building with two towers stands tall on the water's edge. The building -has a brown roof and a window on the side. A tree stands in front of the -building, and a flag waves proudly from its top. The water is calm and blue, -reflecting the sky above. A bridge crosses the water, and a red and white boat -rests on its surface. The building has a window on the side, and a flag on top. -A tall tree stands in front of the building, and a window on the building is -visible from the water. The water is green, and the sky is blue. -``` - -### Migrating to single-file format - -There is now a new format for the weights file, which is a single file that -allows to contain the tokenizer (and the model type) directly. A tool to migrate -from the multi-file format to the single-file format is available. - -```sh -compression/migrate_weights \ - --tokenizer .../tokenizer.spm --weights .../gemma2-2b-it-sfp.sbs \ - --model gemma2-2b-it --output_weights .../gemma2-2b-it-sfp-single.sbs -``` - -After migration, you can use the new weights file with gemma.cpp like this: - -```sh -./gemma --weights .../gemma2-2b-it-sfp-single.sbs -``` - -### Troubleshooting and FAQs - -**Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."** - -The most common problem is that the `--weight_type` argument does not match that -of the model file. Revisit step #3 and check which weights you downloaded. - -Note that we have already moved weight type from a compile-time decision to a -runtime argument. In a subsequent step, we plan to bake this information into -the weights. 
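For example, an illustrative pair of invocations (file names are placeholders; the flag values come from the table in Step 4 above): loading an `-sfp` checkpoint while declaring a different weight type reproduces the read error, whereas the matching type loads cleanly.

```sh
# Mismatch: the checkpoint was compressed as 8-bit SFP (-sfp suffix),
# so declaring bf16 fails with "Failed to read cache gating_ein_0 ...":
./gemma --tokenizer tokenizer.spm --weights 2b-it-sfp.sbs --model 2b-it --weight_type bf16

# Match: sfp agrees with the checkpoint, so the weights load:
./gemma --tokenizer tokenizer.spm --weights 2b-it-sfp.sbs --model 2b-it --weight_type sfp
```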
- -**Problems building in Windows / Visual Studio** - -Currently if you're using Windows, we recommend building in WSL (Windows -Subsystem for Linux). We are exploring options to enable other build -configurations, see issues for active discussion. - -**Model does not respond to instructions and produces strange output** - -A common issue is that you are using a pre-trained model, which is not -instruction-tuned and thus does not respond to instructions. Make sure you are -using an instruction-tuned model (`2b-it-sfp`, `2b-it`, `7b-it-sfp`, `7b-it`) -and not a pre-trained model (any model with a `-pt` suffix). - -**What sequence lengths are supported?** - -See `seq_len` in `configs.cc`. For the Gemma 3 models larger than 1B, this is -typically 32K but 128K would also work given enough RAM. Note that long -sequences will be slow due to the quadratic cost of attention. - -**How do I convert my fine-tune to a `.sbs` compressed model file?** - -For PaliGemma (1 and 2) checkpoints, you can use -python/convert_from_safetensors.py to convert from safetensors format (tested -with building via bazel). For an adapter model, you will likely need to call -merge_and_unload() to convert the adapter model to a single-file format before -converting it. - -Here is how to use it using a bazel build of the compression library assuming -locally installed (venv) torch, numpy, safetensors, absl-py, etc.: - -```sh -bazel build //compression/python:compression -BAZEL_OUTPUT_DIR="${PWD}/bazel-bin/compression" -python3 -c "import site; print(site.getsitepackages())" -# Use your sites-packages file here: -ln -s $BAZEL_OUTPUT_DIR [...]/site-packages/compression -python3 python/convert_from_safetensors.py --load_path [...].safetensors.index.json -``` - -See also compression/convert_weights.py for a slightly older option to convert a -pytorch checkpoint. (The code may need updates to work with Gemma-2 models.) - -**What are some easy ways to make the model run faster?** - -1. Make sure you are using the 8-bit switched floating point `-sfp` models. - These are half the size of bf16 and thus use less memory bandwidth and cache - space. -2. If you're on a laptop, make sure power mode is set to maximize performance - and saving mode is **off**. For most laptops, the power saving modes get - activated automatically if the computer is not plugged in. -3. Close other unused cpu-intensive applications. -4. On macs, anecdotally we observe a "warm-up" ramp-up in speed as performance - cores get engaged. -5. Experiment with the `--num_threads` argument value. Depending on the device, - larger numbers don't always mean better performance. - -We're also working on algorithmic and optimization approaches for faster -inference, stay tuned. - -## Usage - -`gemma` has different usage modes, controlled by the verbosity flag. - -All usage modes are currently interactive, triggering text generation upon -newline input. - -| Verbosity | Usage mode | Details | -| --------------- | ---------- | --------------------------------------------- | -| `--verbosity 0` | Minimal | Only prints generation output. Suitable as a CLI tool. | -| `--verbosity 1` | Default | Standard user-facing terminal UI. | -| `--verbosity 2` | Detailed | Shows additional developer and debug info. | - -### Interactive Terminal App - -By default, verbosity is set to 1, bringing up a terminal-based interactive -interface when `gemma` is invoked: - -```console -$ ./gemma [...] 
- __ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __ - / _` |/ _ \ '_ ` _ \| '_ ` _ \ / _` | / __| '_ \| '_ \ -| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) | - \__, |\___|_| |_| |_|_| |_| |_|\__,_(_)___| .__/| .__/ - __/ | | | | | - |___/ |_| |_| - -tokenizer : tokenizer.spm -compressed_weights : 2b-it-sfp.sbs -model : 2b-it -weights : [no path specified] -max_generated_tokens : 2048 - -*Usage* - Enter an instruction and press enter (%C reset conversation, %Q quits). - -*Examples* - - Write an email to grandma thanking her for the cookies. - - What are some historical attractions to visit around Massachusetts? - - Compute the nth fibonacci number in javascript. - - Write a standup comedy bit about WebGPU programming. - -> What are some outdoorsy places to visit around Boston? - -[ Reading prompt ] ..................... - - -**Boston Harbor and Islands:** - -* **Boston Harbor Islands National and State Park:** Explore pristine beaches, wildlife, and maritime history. -* **Charles River Esplanade:** Enjoy scenic views of the harbor and city skyline. -* **Boston Harbor Cruise Company:** Take a relaxing harbor cruise and admire the city from a different perspective. -* **Seaport Village:** Visit a charming waterfront area with shops, restaurants, and a seaport museum. - -**Forest and Nature:** - -* **Forest Park:** Hike through a scenic forest with diverse wildlife. -* **Quabbin Reservoir:** Enjoy boating, fishing, and hiking in a scenic setting. -* **Mount Forest:** Explore a mountain with breathtaking views of the city and surrounding landscape. - -... -``` - -### Usage as a Command Line Tool - -For using the `gemma` executable as a command line tool, it may be useful to -create an alias for gemma.cpp with arguments fully specified: - -```sh -alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --weights ~/gemma.cpp/build/gemma2-2b-it-sfp.sbs --model gemma2-2b-it --verbosity 0" -``` - -Replace the above paths with your own paths to the model and tokenizer paths -from the download. - -Here is an example of prompting `gemma` with a truncated input -file (using a `gemma2b` alias like defined above): - -```sh -cat configs.h | tail -n 35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b -``` - -> [!NOTE] -> CLI usage of gemma.cpp is experimental and should take context length -> limitations into account. - -The output of the above command should look like: - -```console -[ Reading prompt ] [...] -This C++ code snippet defines a set of **constants** used in a large language model (LLM) implementation, likely related to the **attention mechanism**. - -Let's break down the code: -[...] -``` - -### Incorporating gemma.cpp as a Library in your Project - -The easiest way to incorporate gemma.cpp in your own project is to pull in -gemma.cpp and dependencies using `FetchContent`. 
You can add the following to your -CMakeLists.txt: - -``` -include(FetchContent) - -FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) -FetchContent_MakeAvailable(sentencepiece) - -FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main) -FetchContent_MakeAvailable(gemma) - -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f) -FetchContent_MakeAvailable(highway) -``` - -Note for the gemma.cpp `GIT_TAG`, you may replace `origin/main` for a specific -commit hash if you would like to pin the library version. - -After your executable is defined (substitute your executable name for -`[Executable Name]` below): - -``` -target_link_libraries([Executable Name] libgemma hwy hwy_contrib sentencepiece) -FetchContent_GetProperties(gemma) -FetchContent_GetProperties(sentencepiece) -target_include_directories([Executable Name] PRIVATE ${gemma_SOURCE_DIR}) -target_include_directories([Executable Name] PRIVATE ${sentencepiece_SOURCE_DIR}) -``` - -### Building gemma.cpp as a Library - -gemma.cpp can also be used as a library dependency in your own project. The -shared library artifact can be built by modifying the make invocation to build -the `libgemma` target instead of `gemma`. - -> [!NOTE] -> If you are using gemma.cpp in your own project with the `FetchContent` steps -> in the previous section, building the library is done automatically by `cmake` -> and this section can be skipped. - -First, run `cmake`: - -```sh -cmake -B build -``` - -Then, run `make` with the `libgemma` target: - -```sh -cd build -make -j [number of parallel threads to use] libgemma -``` - -If this is successful, you should now have a `libgemma` library file in the -`build/` directory. On Unix platforms, the filename is `libgemma.a`. - -## Independent Projects Using gemma.cpp - -Some independent projects using gemma.cpp: - -- [gemma-cpp-python - Python bindings](https://github.com/namtranase/gemma-cpp-python) -- [lua-cgemma - Lua bindings](https://github.com/ufownl/lua-cgemma) -- [Godot engine demo project](https://github.com/Rliop913/Gemma-godot-demo-project) - -If you would like to have your project included, feel free to get in touch or -submit a PR with a `README.md` edit. - -## Acknowledgements and Contacts - -gemma.cpp was started in fall 2023 by [Austin Huang](mailto:austinvhuang@google.com) -and [Jan Wassenberg](mailto:janwas@google.com), and subsequently released February 2024 -thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng. - -Griffin support was implemented in April 2024 thanks to contributions by Andrey -Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode -Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas -Fischbacher and Zoltan Szabadka. - -Gemma-2 support was implemented in June/July 2024 with the help of several -people. - -PaliGemma support was implemented in September 2024 with contributions from -Daniel Keysers. - -[Jan Wassenberg](mailto:janwas@google.com) has continued to contribute many -improvements, including major gains in efficiency, since the initial release. - -This is not an officially supported Google product. 
+**Authors**: Google \ No newline at end of file diff --git a/build/.gitignore b/build/.gitignore deleted file mode 100644 index 3822a0b..0000000 --- a/build/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -* -!.gitignore -!.hgignore \ No newline at end of file From 01caf379ba724b16b48255e571d68ed2b76157b0 Mon Sep 17 00:00:00 2001 From: prajwalc22 Date: Tue, 15 Apr 2025 08:21:19 +0530 Subject: [PATCH 3/9] Update .gitignore to exclude build directory and model files --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index d4264cb..7025304 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,6 @@ bazel-*/ build-*/ python/*/__pycache__ +build/ +*.sbs +*.spm From 716713f0e60fb8ec1f857e73114b394909058918 Mon Sep 17 00:00:00 2001 From: prajwalc22 Date: Wed, 16 Apr 2025 09:52:30 +0530 Subject: [PATCH 4/9] Update .gitignore to exclude build directory and model files --- .gitattributes | 37 +++++++++++++++++++++++++++++++++++ .gitignore | 20 ++++++++++++++++++- .vscode/c_cpp_properties.json | 15 ++++++++++++++ 3 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 .gitattributes create mode 100644 .vscode/c_cpp_properties.json diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..ae66414 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,37 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +2b-pt-sfp.sbs filter=lfs diff=lfs merge=lfs -text +tokenizer.spm filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index 7025304..1c13032 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,25 @@ +# Build directories .cache/ bazel-*/ build-*/ -python/*/__pycache__ build/ + +# Python cache +python/*/__pycache__ + +# Model files *.sbs *.spm +*.data +*.bin +*.weights + +# IDE and editor files +.vscode/ +.idea/ +*.swp +*~ + +# Local development +.env +.env.local \ No newline at end of file diff --git a/.vscode/c_cpp_properties.json b/.vscode/c_cpp_properties.json new file mode 100644 index 0000000..64d3f90 --- /dev/null +++ 
b/.vscode/c_cpp_properties.json @@ -0,0 +1,15 @@ +{ + "configurations": [ + { + "name": "Linux", + "includePath": [ + "${workspaceFolder}/**" + ], + "defines": [], + "cStandard": "c17", + "cppStandard": "c++17", + "intelliSenseMode": "linux-clang-x64" + } + ], + "version": 4 +} \ No newline at end of file From cbf179990f57d3f7e0ecd9b67a08c9c3d2bdd799 Mon Sep 17 00:00:00 2001 From: prajwalc22 Date: Wed, 16 Apr 2025 15:34:43 +0530 Subject: [PATCH 5/9] Add --prompt flag for non-interactive mode --- README.md | 596 +++++++++++++++++++++++++++++++++++++++++++-- gemma/gemma_args.h | 552 +++++++++++++---------------------------- gemma/run.cc | 100 ++++---- 3 files changed, 795 insertions(+), 453 deletions(-) diff --git a/README.md b/README.md index 2c2020e..e9a6745 100644 --- a/README.md +++ b/README.md @@ -1,27 +1,583 @@ ---- -library_name: gemma.cpp -license: gemma -pipeline_tag: text-generation -tags: [] -extra_gated_heading: Access Gemma on Hugging Face -extra_gated_prompt: To access Gemma on Hugging Face, you’re required to review and - agree to Google’s usage license. To do this, please ensure you’re logged-in to Hugging - Face and click below. Requests are processed immediately. -extra_gated_button_content: Acknowledge license ---- +# gemma.cpp -# Gemma Model Card +gemma.cpp is a lightweight, standalone C++ inference engine for the Gemma +foundation models from Google. -**Model Page**: [Gemma](https://ai.google.dev/gemma/docs) +For additional information about Gemma, see +[ai.google.dev/gemma](https://ai.google.dev/gemma). Model weights, including +gemma.cpp specific artifacts, are +[available on kaggle](https://www.kaggle.com/models/google/gemma). -This model card corresponds to the 2B base version of the Gemma model for usage with C++ (https://github.com/google/gemma.cpp). This is a compressed version of the weights, which will load, run, and download more quickly. For more information about the model, visit https://huggingface.co/google/gemma-2b. +## Who is this project for? -**Resources and Technical Documentation**: +Modern LLM inference engines are sophisticated systems, often with bespoke +capabilities extending beyond traditional neural network runtimes. With this +comes opportunities for research and innovation through co-design of high level +algorithms and low-level computation. However, there is a gap between +deployment-oriented C++ inference runtimes, which are not designed for +experimentation, and Python-centric ML research frameworks, which abstract away +low-level computation through compilation. -* [Responsible Generative AI Toolkit](https://ai.google.dev/responsible) -* [Gemma on Kaggle](https://www.kaggle.com/models/google/gemma) -* [Gemma on Vertex Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/335?version=gemma-2b-gg-hf) +gemma.cpp provides a minimalist implementation of Gemma-1, Gemma-2, Gemma-3, and +PaliGemma models, focusing on simplicity and directness rather than full +generality. This is inspired by vertically-integrated model implementations such +as [ggml](https://github.com/ggerganov/ggml), +[llama.c](https://github.com/karpathy/llama2.c), and +[llama.rs](https://github.com/srush/llama2.rs). -**Terms of Use**: [Terms](https://www.kaggle.com/models/google/gemma/license/consent/verify/huggingface?returnModelRepoId=google/gemma-2b-sfp-cpp) +gemma.cpp targets experimentation and research use cases. 
It is intended to be +straightforward to embed in other projects with minimal dependencies and also +easily modifiable with a small ~2K LoC core implementation (along with ~4K LoC +of supporting utilities). We use the [Google +Highway](https://github.com/google/highway) Library to take advantage of +portable SIMD for CPU inference. -**Authors**: Google \ No newline at end of file +For production-oriented edge deployments we recommend standard deployment +pathways using Python frameworks like JAX, Keras, PyTorch, and Transformers +([all model variations here](https://www.kaggle.com/models/google/gemma)). + +## Contributing + +Community contributions large and small are welcome. See +[DEVELOPERS.md](https://github.com/google/gemma.cpp/blob/main/DEVELOPERS.md) +for additional notes contributing developers and [join the discord by following +this invite link](https://discord.gg/H5jCBAWxAe). This project follows +[Google's Open Source Community +Guidelines](https://opensource.google.com/conduct/). + +*Active development is currently done on the `dev` branch. Please open pull +requests targeting `dev` branch instead of `main`, which is intended to be more +stable.* + +## Quick Start + +### System requirements + +Before starting, you should have installed: + +- [CMake](https://cmake.org/) +- [Clang C++ compiler](https://clang.llvm.org/get_started.html), supporting at + least C++17. +- `tar` for extracting archives from Kaggle. + +Building natively on Windows requires the Visual Studio 2012 Build Tools with the +optional Clang/LLVM C++ frontend (`clang-cl`). This can be installed from the +command line with +[`winget`](https://learn.microsoft.com/en-us/windows/package-manager/winget/): + +```sh +winget install --id Kitware.CMake +winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;installRecommended --add Microsoft.VisualStudio.Component.VC.Llvm.Clang --add Microsoft.VisualStudio.Component.VC.Llvm.ClangToolset" +``` + +### Step 1: Obtain model weights and tokenizer from Kaggle or Hugging Face Hub + +Visit the +[Kaggle page for Gemma-2](https://www.kaggle.com/models/google/gemma-2/gemmaCpp) +[or Gemma-1](https://www.kaggle.com/models/google/gemma/frameworks/gemmaCpp), +and select `Model Variations |> Gemma C++`. + +On this tab, the `Variation` dropdown includes the options below. Note bfloat16 +weights are higher fidelity, while 8-bit switched floating point weights enable +faster inference. In general, we recommend starting with the `-sfp` checkpoints. + +If you are unsure which model to start with, we recommend starting with the +smallest Gemma-2 model, i.e. `2.0-2b-it-sfp`. + +Alternatively, visit the +[gemma.cpp](https://huggingface.co/models?other=gemma.cpp) models on the Hugging +Face Hub. First go the model repository of the model of interest (see +recommendations below). Then, click the `Files and versions` tab and download +the model and tokenizer files. 
For programmatic downloading, if you have +`huggingface_hub` installed, you can also download by running: + +``` +huggingface-cli login # Just the first time +huggingface-cli download google/gemma-2b-sfp-cpp --local-dir build/ +``` + +Gemma-1 2B instruction-tuned (`it`) and pre-trained (`pt`) models: + +| Model name | Description | +| ----------- | ----------- | +| `2b-it` | 2 billion parameter instruction-tuned model, bfloat16 | +| `2b-it-sfp` | 2 billion parameter instruction-tuned model, 8-bit switched floating point | +| `2b-pt` | 2 billion parameter pre-trained model, bfloat16 | +| `2b-pt-sfp` | 2 billion parameter pre-trained model, 8-bit switched floating point | + +Gemma-1 7B instruction-tuned (`it`) and pre-trained (`pt`) models: + +| Model name | Description | +| ----------- | ----------- | +| `7b-it` | 7 billion parameter instruction-tuned model, bfloat16 | +| `7b-it-sfp` | 7 billion parameter instruction-tuned model, 8-bit switched floating point | +| `7b-pt` | 7 billion parameter pre-trained model, bfloat16 | +| `7b-pt-sfp` | 7 billion parameter pre-trained model, 8-bit switched floating point | + +> [!NOTE] +> **Important**: We strongly recommend starting off with the `2b-it-sfp` model to +> get up and running. + +Gemma 2 models are named `gemma2-2b-it` for 2B and `9b-it` or `27b-it`. See the +`kModelFlags` definition in `common.cc`. + +### Step 2: Extract Files + +If you downloaded the models from Hugging Face, skip to step 3. + +After filling out the consent form, the download should proceed to retrieve a +tar archive file `archive.tar.gz`. Extract files from `archive.tar.gz` (this can +take a few minutes): + +``` +tar -xf archive.tar.gz +``` + +This should produce a file containing model weights such as `2b-it-sfp.sbs` and +a tokenizer file (`tokenizer.spm`). You may want to move these files to a +convenient directory location (e.g. the `build/` directory in this repo). + +### Step 3: Build + +The build system uses [CMake](https://cmake.org/). To build the gemma inference +runtime, create a build directory and generate the build files using `cmake` +from the top-level project directory. Note if you previous ran `cmake` and are +re-running with a different setting, be sure to delete all files in the `build/` +directory with `rm -rf build/*`. + +#### Unix-like Platforms +```sh +cmake -B build +``` + +After running `cmake`, you can enter the `build/` directory and run `make` to +build the `./gemma` executable: + +```sh +# Configure `build` directory +cmake --preset make + +# Build project using make +cmake --build --preset make -j [number of parallel threads to use] +``` + +Replace `[number of parallel threads to use]` with a number - the number of +cores available on your system is a reasonable heuristic. For example, +`make -j4 gemma` will build using 4 threads. If the `nproc` command is +available, you can use `make -j$(nproc) gemma` as a reasonable default +for the number of threads. + +If you aren't sure of the right value for the `-j` flag, you can simply run +`make gemma` instead and it should still build the `./gemma` executable. + +> [!NOTE] +> On Windows Subsystem for Linux (WSL) users should set the number of +> parallel threads to 1. Using a larger number may result in errors. + +If the build is successful, you should now have a `gemma` executable in the `build/` directory. 
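As a quick sanity check of the build, you can ask the binary for its flag summary from the top-level directory; this also exercises the `HasHelp`/`ShowHelp` wiring added in the first patch of this series (assuming `--help` is among the forms `HasHelp` accepts):

```sh
# Print the ASCII-art banner and the per-flag help text, then exit:
./build/gemma --help
```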
+ +#### Windows + +```sh +# Configure `build` directory +cmake --preset windows + +# Build project using Visual Studio Build Tools +cmake --build --preset windows -j [number of parallel threads to use] +``` + +If the build is successful, you should now have a `gemma.exe` executable in the `build/` directory. + +#### Bazel + +```sh +bazel build -c opt --cxxopt=-std=c++20 :gemma +``` + +If the build is successful, you should now have a `gemma` executable in the `bazel-bin/` directory. + +#### Make + +If you prefer Makefiles, @jart has made one available here: + +https://github.com/jart/gemma3/blob/main/Makefile + +### Step 4: Run + +You can now run `gemma` from inside the `build/` directory. + +`gemma` has the following required arguments: + +Argument | Description | Example value +--------------- | ---------------------------- | ----------------------- +`--model` | The model type. | `2b-it` ... (see below) +`--weights` | The compressed weights file. | `2b-it-sfp.sbs` +`--weight_type` | The compressed weight type. | `sfp` +`--tokenizer` | The tokenizer file. | `tokenizer.spm` + +`gemma` is invoked as: + +```sh +./gemma \ +--tokenizer [tokenizer file] \ +--weights [compressed weights file] \ +--weight_type [f32 or bf16 or sfp (default:sfp)] \ +--model [2b-it or 2b-pt or 7b-it or 7b-pt or ...] +``` + +Example invocation for the following configuration: + +- Compressed weights file `2b-it-sfp.sbs` (2B instruction-tuned model, 8-bit + switched floating point). +- Tokenizer file `tokenizer.spm`. + +```sh +./gemma \ +--tokenizer tokenizer.spm \ +--weights 2b-it-sfp.sbs --model 2b-it +``` + +### RecurrentGemma + +This repository includes a version of Gemma based on Griffin +([paper](https://arxiv.org/abs/2402.19427), +[code](https://github.com/google-deepmind/recurrentgemma)). Its architecture +includes both recurrent layers and local attention, thus it is more efficient +for longer sequences and has a smaller memory footprint than standard Gemma. We +here provide a C++ implementation of this model based on the paper. + +To use the recurrent version of Gemma included in this repository, build the +gemma binary as noted above in Step 3. Download the compressed weights and +tokenizer from the RecurrentGemma +[Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) as in +Step 1, and run the binary as follows: + +`./gemma --tokenizer tokenizer.spm --model gr2b-it --weights 2b-it-sfp.sbs` + +### PaliGemma Vision-Language Model + +This repository includes a version of the PaliGemma VLM +([paper](https://arxiv.org/abs/2407.07726), +[code](https://github.com/google-research/big_vision/tree/main/big_vision/configs/proj/paligemma)) +and its successor PaliGemma 2 ([paper](https://arxiv.org/abs/2412.03555)). We +provide a C++ implementation of the PaliGemma model family here. + +To use the version of PaliGemma included in this repository, build the gemma +binary as noted above in Step 3. Download the compressed weights and tokenizer +from +[Kaggle](https://www.kaggle.com/models/google/paligemma/gemmaCpp/paligemma-3b-mix-224) +and run the binary as follows: + +```sh +./gemma \ +--tokenizer paligemma_tokenizer.model \ +--model paligemma-224 \ +--weights paligemma-3b-mix-224-sfp.sbs \ +--image_file paligemma/testdata/image.ppm +``` + +Note that the image reading code is very basic to avoid depending on an image +processing library for now. We currently only support reading binary PPMs (P6). +So use a tool like `convert` to first convert your images into that format, e.g. 
+ +`convert image.jpeg -resize 224x224^ image.ppm` + +(As the image will be resized for processing anyway, we can already resize at +this stage for slightly faster loading.) + +The interaction with the image (using the mix-224 checkpoint) may then look +something like this: + +``` +> Describe the image briefly +A large building with two towers in the middle of a city. +> What type of building is it? +church +> What color is the church? +gray +> caption image +A large building with two towers stands tall on the water's edge. The building +has a brown roof and a window on the side. A tree stands in front of the +building, and a flag waves proudly from its top. The water is calm and blue, +reflecting the sky above. A bridge crosses the water, and a red and white boat +rests on its surface. The building has a window on the side, and a flag on top. +A tall tree stands in front of the building, and a window on the building is +visible from the water. The water is green, and the sky is blue. +``` + +### Migrating to single-file format + +There is now a new format for the weights file, which is a single file that +allows to contain the tokenizer (and the model type) directly. A tool to migrate +from the multi-file format to the single-file format is available. + +```sh +compression/migrate_weights \ + --tokenizer .../tokenizer.spm --weights .../gemma2-2b-it-sfp.sbs \ + --model gemma2-2b-it --output_weights .../gemma2-2b-it-sfp-single.sbs +``` + +After migration, you can use the new weights file with gemma.cpp like this: + +```sh +./gemma --weights .../gemma2-2b-it-sfp-single.sbs +``` + +### Troubleshooting and FAQs + +**Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."** + +The most common problem is that the `--weight_type` argument does not match that +of the model file. Revisit step #3 and check which weights you downloaded. + +Note that we have already moved weight type from a compile-time decision to a +runtime argument. In a subsequent step, we plan to bake this information into +the weights. + +**Problems building in Windows / Visual Studio** + +Currently if you're using Windows, we recommend building in WSL (Windows +Subsystem for Linux). We are exploring options to enable other build +configurations, see issues for active discussion. + +**Model does not respond to instructions and produces strange output** + +A common issue is that you are using a pre-trained model, which is not +instruction-tuned and thus does not respond to instructions. Make sure you are +using an instruction-tuned model (`2b-it-sfp`, `2b-it`, `7b-it-sfp`, `7b-it`) +and not a pre-trained model (any model with a `-pt` suffix). + +**What sequence lengths are supported?** + +See `seq_len` in `configs.cc`. For the Gemma 3 models larger than 1B, this is +typically 32K but 128K would also work given enough RAM. Note that long +sequences will be slow due to the quadratic cost of attention. + +**How do I convert my fine-tune to a `.sbs` compressed model file?** + +For PaliGemma (1 and 2) checkpoints, you can use +python/convert_from_safetensors.py to convert from safetensors format (tested +with building via bazel). For an adapter model, you will likely need to call +merge_and_unload() to convert the adapter model to a single-file format before +converting it. 
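The merge itself happens in whatever framework produced the adapter (typically PEFT, whose `merge_and_unload()` is referenced above), not in this repository. The sketch below assumes a hypothetical `merge_adapter.py` helper of your own that performs the merge and writes a plain safetensors checkpoint; only the second command comes from this repo, and it requires the bazel setup described next.

```sh
# Hypothetical helper (not part of gemma.cpp): fold the adapter into the
# base model via merge_and_unload() and save the result as safetensors.
python3 merge_adapter.py --adapter my-adapter/ --out merged/

# Then convert the merged checkpoint (see the bazel setup below):
python3 python/convert_from_safetensors.py --load_path merged/model.safetensors.index.json
```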
+ +Here is how to use it using a bazel build of the compression library assuming +locally installed (venv) torch, numpy, safetensors, absl-py, etc.: + +```sh +bazel build //compression/python:compression +BAZEL_OUTPUT_DIR="${PWD}/bazel-bin/compression" +python3 -c "import site; print(site.getsitepackages())" +# Use your sites-packages file here: +ln -s $BAZEL_OUTPUT_DIR [...]/site-packages/compression +python3 python/convert_from_safetensors.py --load_path [...].safetensors.index.json +``` + +See also compression/convert_weights.py for a slightly older option to convert a +pytorch checkpoint. (The code may need updates to work with Gemma-2 models.) + +**What are some easy ways to make the model run faster?** + +1. Make sure you are using the 8-bit switched floating point `-sfp` models. + These are half the size of bf16 and thus use less memory bandwidth and cache + space. +2. If you're on a laptop, make sure power mode is set to maximize performance + and saving mode is **off**. For most laptops, the power saving modes get + activated automatically if the computer is not plugged in. +3. Close other unused cpu-intensive applications. +4. On macs, anecdotally we observe a "warm-up" ramp-up in speed as performance + cores get engaged. +5. Experiment with the `--num_threads` argument value. Depending on the device, + larger numbers don't always mean better performance. + +We're also working on algorithmic and optimization approaches for faster +inference, stay tuned. + +## Usage + +`gemma` has different usage modes, controlled by the verbosity flag. + +All usage modes are currently interactive, triggering text generation upon +newline input. + +| Verbosity | Usage mode | Details | +| --------------- | ---------- | --------------------------------------------- | +| `--verbosity 0` | Minimal | Only prints generation output. Suitable as a CLI tool. | +| `--verbosity 1` | Default | Standard user-facing terminal UI. | +| `--verbosity 2` | Detailed | Shows additional developer and debug info. | + +### Interactive Terminal App + +By default, verbosity is set to 1, bringing up a terminal-based interactive +interface when `gemma` is invoked: + +```console +$ ./gemma [...] + __ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __ + / _` |/ _ \ '_ ` _ \| '_ ` _ \ / _` | / __| '_ \| '_ \ +| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) | + \__, |\___|_| |_| |_|_| |_| |_|\__,_(_)___| .__/| .__/ + __/ | | | | | + |___/ |_| |_| + +tokenizer : tokenizer.spm +compressed_weights : 2b-it-sfp.sbs +model : 2b-it +weights : [no path specified] +max_generated_tokens : 2048 + +*Usage* + Enter an instruction and press enter (%C reset conversation, %Q quits). + +*Examples* + - Write an email to grandma thanking her for the cookies. + - What are some historical attractions to visit around Massachusetts? + - Compute the nth fibonacci number in javascript. + - Write a standup comedy bit about WebGPU programming. + +> What are some outdoorsy places to visit around Boston? + +[ Reading prompt ] ..................... + + +**Boston Harbor and Islands:** + +* **Boston Harbor Islands National and State Park:** Explore pristine beaches, wildlife, and maritime history. +* **Charles River Esplanade:** Enjoy scenic views of the harbor and city skyline. +* **Boston Harbor Cruise Company:** Take a relaxing harbor cruise and admire the city from a different perspective. +* **Seaport Village:** Visit a charming waterfront area with shops, restaurants, and a seaport museum. 
**Forest and Nature:**

* **Forest Park:** Hike through a scenic forest with diverse wildlife.
* **Quabbin Reservoir:** Enjoy boating, fishing, and hiking in a scenic setting.
* **Mount Forest:** Explore a mountain with breathtaking views of the city and surrounding landscape.

...
```

### Usage as a Command Line Tool

To use the `gemma` executable as a command line tool, it may be useful to
create an alias with the arguments fully specified:

```sh
alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --weights ~/gemma.cpp/build/gemma2-2b-it-sfp.sbs --model gemma2-2b-it --verbosity 0"
```

Replace the above paths with your own paths to the model and tokenizer files
from the download.

Here is an example of prompting `gemma` with a truncated input
file (using a `gemma2b` alias as defined above):

```sh
cat configs.h | tail -n 35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b
```

> [!NOTE]
> CLI usage of gemma.cpp is experimental and should take context length
> limitations into account.

The output of the above command should look like:

```console
[ Reading prompt ] [...]
This C++ code snippet defines a set of **constants** used in a large language model (LLM) implementation, likely related to the **attention mechanism**.

Let's break down the code:
[...]
```

### Incorporating gemma.cpp as a Library in your Project

The easiest way to incorporate gemma.cpp in your own project is to pull in
gemma.cpp and its dependencies using `FetchContent`. You can add the following
to your CMakeLists.txt:

```
include(FetchContent)

FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece)

FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
FetchContent_MakeAvailable(gemma)

FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
FetchContent_MakeAvailable(highway)
```

Note that for the gemma.cpp `GIT_TAG`, you may replace `origin/main` with a
specific commit hash if you would like to pin the library version.

After your executable is defined (substitute your executable name for
`[Executable Name]` below):

```
target_link_libraries([Executable Name] libgemma hwy hwy_contrib sentencepiece)
FetchContent_GetProperties(gemma)
FetchContent_GetProperties(sentencepiece)
target_include_directories([Executable Name] PRIVATE ${gemma_SOURCE_DIR})
target_include_directories([Executable Name] PRIVATE ${sentencepiece_SOURCE_DIR})
```

### Building gemma.cpp as a Library

gemma.cpp can also be used as a library dependency in your own project. The
library artifact can be built by modifying the make invocation to build the
`libgemma` target instead of `gemma`.

> [!NOTE]
> If you are using gemma.cpp in your own project with the `FetchContent` steps
> in the previous section, building the library is done automatically by `cmake`
> and this section can be skipped.

First, run `cmake`:

```sh
cmake -B build
```

Then, run `make` with the `libgemma` target:

```sh
cd build
make -j [number of parallel threads to use] libgemma
```

If this is successful, you should now have a `libgemma` library file in the
`build/` directory. On Unix platforms, the filename is `libgemma.a`. A minimal
usage sketch follows.
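The sketch below shows what calling into the library can look like, using the
`LoaderArgs`/`InferenceArgs`/`ThreadingArgs` helpers from `gemma_args.h` as
they appear in this patch series; the construction of `MatMulEnv` (here via an
assumed `MakeMatMulEnv` helper), the `WrapAndTokenize` overload, and the
tokenizer `Decode` call vary between gemma.cpp versions, so treat this as a
starting point rather than a drop-in sample:

```cpp
// Minimal sketch, assuming the helpers in gemma_args.h (LoaderArgs,
// InferenceArgs, ThreadingArgs, MakeMatMulEnv, CreateGemma); signatures may
// differ in your checkout of gemma.cpp.
#include <iostream>
#include <random>
#include <string>
#include <vector>

#include "gemma/gemma.h"       // Gemma, KVCache, RuntimeConfig, TimingInfo
#include "gemma/gemma_args.h"  // LoaderArgs, InferenceArgs, ThreadingArgs

int main(int argc, char** argv) {
  gcpp::ThreadingArgs threading(argc, argv);
  gcpp::LoaderArgs loader(argc, argv);
  gcpp::InferenceArgs inference(argc, argv);

  gcpp::MatMulEnv env = gcpp::MakeMatMulEnv(threading);  // assumed helper
  gcpp::Gemma model = gcpp::CreateGemma(loader, env);
  gcpp::KVCache kv_cache = gcpp::KVCache::Create(
      model.GetModelConfig(), inference.prefill_tbatch_size);

  // Wrap the prompt in the model's chat template and tokenize it at pos 0.
  const std::vector<int> prompt =
      gcpp::WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
                            model.Info(), /*pos=*/0, std::string("Hello!"));

  std::mt19937 gen(/*seed=*/42);
  gcpp::TimingInfo timing_info = {.verbosity = inference.verbosity};
  gcpp::RuntimeConfig runtime_config = {
      .gen = &gen,
      .verbosity = inference.verbosity,
      .stream_token = [&](int token, float /*prob*/) {
        std::string text;
        model.Tokenizer().Decode(std::vector<int>{token}, &text);
        std::cout << text << std::flush;
        return true;  // returning false stops generation early
      }};
  inference.CopyTo(runtime_config);

  // Generate until EOS or max_generated_tokens, streaming tokens as they come.
  model.Generate(runtime_config, prompt, /*pos=*/0, /*prefix_end=*/0, kv_cache,
                 timing_info);
  std::cout << "\n";
  return 0;
}
```

Link the resulting binary against `libgemma`, `hwy`, `hwy_contrib`, and
`sentencepiece`, as in the `FetchContent` snippet above.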
+ +## Independent Projects Using gemma.cpp + +Some independent projects using gemma.cpp: + +- [gemma-cpp-python - Python bindings](https://github.com/namtranase/gemma-cpp-python) +- [lua-cgemma - Lua bindings](https://github.com/ufownl/lua-cgemma) +- [Godot engine demo project](https://github.com/Rliop913/Gemma-godot-demo-project) + +If you would like to have your project included, feel free to get in touch or +submit a PR with a `README.md` edit. + +## Acknowledgements and Contacts + +gemma.cpp was started in fall 2023 by [Austin Huang](mailto:austinvhuang@google.com) +and [Jan Wassenberg](mailto:janwas@google.com), and subsequently released February 2024 +thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng. + +Griffin support was implemented in April 2024 thanks to contributions by Andrey +Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode +Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas +Fischbacher and Zoltan Szabadka. + +Gemma-2 support was implemented in June/July 2024 with the help of several +people. + +PaliGemma support was implemented in September 2024 with contributions from +Daniel Keysers. + +[Jan Wassenberg](mailto:janwas@google.com) has continued to contribute many +improvements, including major gains in efficiency, since the initial release. + +This is not an officially supported Google product. diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index dc4019c..2c15986 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -13,383 +13,85 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Shared between various frontends. +// Argument parsing for Gemma. -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_ -#include #include -#include #include -#include "compression/io.h" // Path #include "compression/shared.h" #include "gemma/common.h" #include "gemma/gemma.h" // For CreateGemma -#include "hwy/base.h" // HWY_IS_ASAN, HWY_ABORT +#include "hwy/base.h" // HWY_ABORT #include "ops/matmul.h" -#include "util/allocator.h" #include "util/args.h" #include "util/basics.h" // Tristate -#include "util/threading.h" -#include "util/threading_context.h" namespace gcpp { -static inline const char* CompiledConfig() { - if (HWY_IS_ASAN) { - return "asan"; - } else if (HWY_IS_MSAN) { - return "msan"; - } else if (HWY_IS_TSAN) { - return "tsan"; - } else if (HWY_IS_HWASAN) { - return "hwasan"; - } else if (HWY_IS_UBSAN) { - return "ubsan"; - } else if (HWY_IS_DEBUG_BUILD) { - return "dbg"; - } else { - return "opt"; - } -} -template -struct ArgsBase { - void Init() { static_cast(this)->ForEach(SetToDefault()); } - - void InitAndParse(int argc, char* argv[]) { - Init(); - static_cast(this)->ForEach(ParseOption(argc, argv)); - } - - void Print(int min_verbosity = 1) const { - static_cast(this)->ForEach(PrintOption(min_verbosity)); - } - - void Help() const { static_cast(this)->ForEach(PrintHelp()); } - - protected: - // Helper struct for printing help messages - struct PrintHelp { - template - void operator()(const T& value, const char* name, const T& default_value, - const char* description, int verbosity = 1) const { - fprintf(stderr, " --%s\n %s\n", name, description); - } - // Special case for strings to avoid template deduction issues - void operator()(const std::string& value, const char* name, - const std::string& 
default_value, const char* description, - int verbosity = 1) const { - fprintf(stderr, " --%s\n %s\n", name, description); - } - // Special case for Path type - void operator()(const Path& value, const char* name, - const Path& default_value, const char* description, - int verbosity = 1) const { - fprintf(stderr, " --%s\n %s\n", name, description); - } - }; - - // Helper struct for setting default values - struct SetToDefault { - template - void operator()(T& value, const char* name, const T& default_value, - const char* description, int verbosity = 1) const { - value = default_value; - } - }; - - // Helper struct for printing values - struct PrintOption { - explicit PrintOption(int min_verbosity) : min_verbosity_(min_verbosity) {} - - template - void operator()(const T& value, const char* name, const T& default_value, - const char* description, int verbosity = 1) const { - if (verbosity >= min_verbosity_) { - fprintf(stderr, "%s: %s\n", name, ToString(value).c_str()); - } - } - - private: - int min_verbosity_; - - // Helper function to convert values to string - template - static std::string ToString(const T& value) { - return std::to_string(value); - } - // Specialization for string - static std::string ToString(const std::string& value) { return value; } - // Specialization for Path - static std::string ToString(const Path& value) { return value.path; } - }; -}; -struct ThreadingArgs : public ArgsBase { - public: - ThreadingArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } - ThreadingArgs() { Init(); }; - - int verbosity; - - size_t max_threads; // divided among the detected clusters - Tristate pin; // pin threads? - Tristate spin; // use spin waits? - - // For BoundedSlice: - size_t skip_packages; - size_t max_packages; - size_t skip_clusters; - size_t max_clusters; - size_t skip_lps; - size_t max_lps; - - std::string eot_line; - std::string prompt; - template - void ForEach(const Visitor& visitor) { - visitor(verbosity, "verbosity", 1, - "Show verbose developer information\n 0 = only print generation " - "output\n 1 = standard user-facing terminal ui\n 2 = show " - "developer/debug info).\n Default = 1.", - 2); - - // The exact meaning is more subtle: see the comment at NestedPools ctor. - visitor(max_threads, "num_threads", size_t{0}, - "Maximum number of threads to use; default 0 = unlimited.", 2); - visitor(pin, "pin", Tristate::kDefault, - "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2); - visitor(spin, "spin", Tristate::kDefault, - "Use spin waits? -1 = auto, 0 = no, 1 = yes.", 2); - // These can be used to partition CPU sockets/packages and their - // clusters/CCXs across several program instances. The default is to use - // all available resources. - visitor(skip_packages, "skip_packages", size_t{0}, - "Index of the first socket to use; default 0 = unlimited.", 2); - visitor(max_packages, "max_packages", size_t{0}, - "Maximum number of sockets to use; default 0 = unlimited.", 2); - visitor(skip_clusters, "skip_clusters", size_t{0}, - "Index of the first CCX to use; default 0 = unlimited.", 2); - visitor(max_clusters, "max_clusters", size_t{0}, - "Maximum number of CCXs to use; default 0 = unlimited.", 2); - // These are only used when CPU topology is unknown. - visitor(skip_lps, "skip_lps", size_t{0}, - "Index of the first LP to use; default 0 = unlimited.", 2); - visitor(max_lps, "max_lps", size_t{0}, - "Maximum number of LPs to use; default 0 = unlimited.", 2); - - visitor( - eot_line, "eot_line", std::string(""), - "End of turn line. 
" - "When you specify this, the prompt will be all lines " - "before the line where only the given string appears.\n Default = " - "When a newline is encountered, that signals the end of the turn.", - 2); - - visitor(prompt, "prompt", std::string(""), - "Prompt string for non-interactive mode. When provided, the model " - "generates a response and exits.", - 2); - } -}; -static inline BoundedTopology CreateTopology(const ThreadingArgs& threading) { - return BoundedTopology( - BoundedSlice(threading.skip_packages, threading.max_packages), - BoundedSlice(threading.skip_clusters, threading.max_clusters), - BoundedSlice(threading.skip_lps, threading.max_lps)); -} - -static inline MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading) { - ThreadingContext2::SetArgs(threading); - return MatMulEnv(ThreadingContext2::Get()); -} -// Note: These functions may need adjustments depending on your specific class -// definitions -static inline BoundedTopology CreateTopology(const ThreadingArgs& app) { - return BoundedTopology(BoundedSlice(app.skip_packages, app.max_packages), - BoundedSlice(app.skip_clusters, app.max_clusters), - BoundedSlice(app.skip_lps, app.max_lps)); -} - -// This function may need to be adjusted based on your NestedPools constructor -// signature -static inline NestedPools CreatePools(const BoundedTopology& topology, - const ThreadingArgs& threading) { - // Make sure Allocator::Init() is properly declared/defined - const Allocator2& allocator = ThreadingContext2::Get().allocator; - // Allocator::Init(topology); - - // Adjust the constructor call based on your actual NestedPools constructor - // The error suggests that the constructor doesn't match these arguments - return NestedPools(topology, allocator, threading.max_threads, threading.pin); - // Alternative: return NestedPools(topology, app.max_threads, app.pin); -} - -struct LoaderArgs : public ArgsBase { - LoaderArgs(int argc, char* argv[], bool validate = true) { - InitAndParse(argc, argv); - - if (validate) { - if (const char* error = Validate()) { - HWY_ABORT("Invalid args: %s", error); - } - } - } - LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path, - const std::string& model, bool validate = true) { - Init(); // Init sets to defaults, so assignments must come after Init(). - tokenizer.path = tokenizer_path; - weights.path = weights_path; - model_type_str = model; - - if (validate) { - if (const char* error = Validate()) { - HWY_ABORT("Invalid args: %s", error); - } - } - }; - - // Returns error string or nullptr if OK. - const char* Validate() { - if (weights.path.empty()) { - return "Missing --weights flag, a file for the model weights."; - } - if (!weights.Exists()) { - return "Can't open file specified with --weights flag."; - } - info_.model = Model::UNKNOWN; - info_.wrapping = PromptWrapping::GEMMA_PT; - info_.weight = Type::kUnknown; - if (!model_type_str.empty()) { - const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model, - info_.wrapping); - if (err != nullptr) return err; - } - if (!weight_type_str.empty()) { - const char* err = ParseType(weight_type_str, info_.weight); - if (err != nullptr) return err; - } - if (!tokenizer.path.empty()) { - if (!tokenizer.Exists()) { - return "Can't open file specified with --tokenizer flag."; - } - } - // model_type and tokenizer must be either both present or both absent. - // Further checks happen on weight loading. 
- if (model_type_str.empty() != tokenizer.path.empty()) { - return "Missing or extra flags for model_type or tokenizer."; - } - return nullptr; - } - - Path tokenizer; - Path weights; // weights file location - Path compressed_weights; - std::string model_type_str; - std::string weight_type_str; - - template - void ForEach(const Visitor& visitor) { - visitor(tokenizer, "tokenizer", Path(), - "Path name of tokenizer model file."); - visitor(weights, "weights", Path(), - "Path name of model weights (.sbs) file.\n Required argument.\n"); - visitor(compressed_weights, "compressed_weights", Path(), - "Deprecated alias for --weights."); - visitor(model_type_str, "model", std::string(), - "Model type, see common.cc for valid values.\n"); - visitor(weight_type_str, "weight_type", std::string("sfp"), - "Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit SFP."); - } - - // Uninitialized before Validate, must call after that. - const ModelInfo& Info() const { return info_; } - - private: - // TODO(rays): remove this. Eventually ModelConfig will be loaded from the - // weights file, so we can remove the need for this struct entirely. - ModelInfo info_; -}; - -// `env` must remain valid for the lifetime of the Gemma. -static inline Gemma CreateGemma(const LoaderArgs& loader, MatMulEnv& env) { - if (Type::kUnknown == loader.Info().weight || - Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) { - // New weights file format doesn't need tokenizer path or model/weightinfo. - return Gemma(loader.weights, env); - } - return Gemma(loader.tokenizer, loader.weights, loader.Info(), env); -} - -// `env` must remain valid for the lifetime of the Gemma. -static inline std::unique_ptr AllocateGemma(const LoaderArgs& loader, - MatMulEnv& env) { - if (Type::kUnknown == loader.Info().weight || - Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) { - // New weights file format doesn't need tokenizer path or model/weight info. - return std::make_unique(loader.weights, env); - } - return std::make_unique(loader.tokenizer, loader.weights, - loader.Info(), env); -} - +// Arguments related to inference: sampling, text etc. struct InferenceArgs : public ArgsBase { - InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } - InferenceArgs() { Init(); }; - - int verbosity; - + // Arguments for getc-like interfaces + size_t max_tokens; size_t max_generated_tokens; - - size_t prefill_tbatch_size; - size_t decode_qbatch_size; - float temperature; size_t top_k; - bool deterministic; - bool multiturn; - Path image_file; + float top_p; + float min_p; + int repeat_penalty_power; + float repeat_penalty_presence; + float repeat_penalty_decay; + float repeat_penalty_range; + // Batch configuration: + size_t prefill_tbatch_size; + size_t decode_tbatch_size; + + // Non-interactive mode prompt + std::string prompt; std::string eot_line; - // Returns error string or nullptr if OK. - const char* Validate() const { - if (max_generated_tokens > gcpp::kSeqLen) { - return "max_generated_tokens is larger than the maximum sequence length " - "(see configs.h)."; - } - return nullptr; - } - template - void ForEach(const Visitor& visitor) { - visitor(verbosity, "verbosity", 1, - "Show verbose developer information\n 0 = only print generation " - "output\n 1 = standard user-facing terminal ui\n 2 = show " - "developer/debug info).\n Default = 1.", - 2); + void ForEach(Visitor& visitor) { + // Each line specifies a variable member, its name, default value, and help. 
+ visitor(max_tokens, "max_tokens", size_t{50}, + "Maximum number of total tokens including prompt (0=no limit).", 1); + visitor(max_generated_tokens, "max_generated_tokens", size_t{512}, + "Maximum number of generated tokens (not including prompt) (0=no " + "limit).", + 1); + visitor(temperature, "temperature", 1.0f, + "Temperature (randomness) for logits.", 1); + visitor(top_k, "top_k", size_t{40}, + "Number of highest-probability tokens to consider (0=unlimited).", + 1); + visitor(top_p, "top_p", 0.9f, "Top-p probability threshold (0.0=disabled).", + 1); + visitor(min_p, "min_p", 0.0f, "Min-p probability threshold (0.0=disabled).", + 1); + visitor( + repeat_penalty_power, "repeat_penalty_power", 1, + "Penalty power (1=standard frequentist penalty). If 0, skips penalty " + "computation.", + 1); + visitor(repeat_penalty_presence, "repeat_penalty_presence", 0.0f, + "Penalty for token presence regardless of frequency (additive).", + 1); + visitor(repeat_penalty_decay, "repeat_penalty_decay", 0.0f, + "Penalty for token n positions ago is decayed by " + "power(repeat_penalty_decay, n).", + 1); + visitor(repeat_penalty_range, "repeat_penalty_range", 8.0f, + "Penalty fades out near the end of range (tokens)", 1); - visitor(max_generated_tokens, "max_generated_tokens", size_t{2048}, - "Maximum number of tokens to generate."); - - visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256}, - "Prefill: max tokens per batch."); - visitor(decode_qbatch_size, "decode_qbatch", size_t{16}, - "Decode: max queries per batch."); - - visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2); - visitor(top_k, "top_k", size_t{1}, "Number of top-K tokens to sample from", - 2); - visitor(deterministic, "deterministic", false, - "Make top-k sampling deterministic", 2); - visitor(multiturn, "multiturn", false, - "Multiturn mode\n 0 = clear KV cache after every " - "interaction\n 1 = continue KV cache after every interaction\n " - " Default : 0 (conversation " - "resets every turn)"); - visitor(image_file, "image_file", Path(), "Image file to load."); + // Batch configuration: + visitor(prefill_tbatch_size, "prefill_tbatch_size", size_t{2}, + "Token batch size for prefill; <= 32", 2); + visitor(decode_tbatch_size, "decode_tbatch_size", size_t{1}, + "Token batch size for decode (only 1 currently supported)", 2); visitor( eot_line, "eot_line", std::string(""), @@ -397,47 +99,123 @@ struct InferenceArgs : public ArgsBase { "When you specify this, the prompt will be all lines " "before the line where only the given string appears.\n Default = " "When a newline is encountered, that signals the end of the turn.", - 2); + 1); + + // Non-interactive mode prompt + visitor(prompt, "prompt", std::string(""), + "Prompt to use in non-interactive mode", 1); } - void CopyTo(RuntimeConfig& runtime_config) const { - runtime_config.max_generated_tokens = max_generated_tokens; - runtime_config.prefill_tbatch_size = prefill_tbatch_size; - runtime_config.decode_qbatch_size = decode_qbatch_size; - if (prefill_tbatch_size > MMStorage::kMaxM) { - HWY_ABORT( - "prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, " - "or increase the constant in MMStorage.\n", - prefill_tbatch_size, MMStorage::kMaxM); + const char* Validate() const { + if (max_generated_tokens == 0 && max_tokens == 0) { + return "At least one of max_tokens and max_generated_tokens must be > 0"; } - if (decode_qbatch_size > MMStorage::kMaxM) { - HWY_ABORT( - "decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, " - "or increase the constant in 
MMStorage.\n", - decode_qbatch_size, MMStorage::kMaxM); + if (temperature <= 0.0) { + return "Temperature must be > 0.0"; } - - runtime_config.temperature = temperature; - runtime_config.top_k = top_k; + if (prefill_tbatch_size > 32) { + return "prefill_tbatch_size must be <= 32"; + } + if (decode_tbatch_size != 1) { + return "decode_tbatch_size must be 1"; + } + return nullptr; } }; -static inline void ShowConfig(const ThreadingArgs& threading, - const LoaderArgs& loader, - const InferenceArgs& inference) { - threading.Print(); - loader.Print(); - inference.Print(); -} -static inline void ShowHelp(const ThreadingArgs& threading, - const LoaderArgs& loader, - const InferenceArgs& inference) { - fprintf(stderr, "\nUsage: gemma [flags]\n\nFlags:\n"); - threading.Help(); - loader.Help(); - inference.Help(); -} +// Arguments related to model weights. +struct LoaderArgs : public ArgsBase { + Path model_path; // Path to directory containing the weights + Path tokenizer; // Optional: can be derived from model_path + bool model_is_gemma2; + Gemma::Config::WeightFormat weight_format; + + template + void ForEach(Visitor& visitor) { + // Each line specifies a variable member, its name, default value, and help. + visitor(model_path, "model", Path{}, + "Directory containing weights or config file from `gemma.cpp " + "convert`.", + 0); + visitor(tokenizer, "tokenizer", Path{}, + "Optional path to tokenizer.model; if empty, looks in model_path.", + 2); + visitor(model_is_gemma2, "model_is_gemma2", false, + "Whether the model is a Gemma 2 model", 1); + visitor(weight_format, "format", Gemma::Config::kBfloat16, + "Model weights format: 0=F32, 1=F16, 2=BF16", 2); + } + + const char* Validate() const { + if (model_path.path.empty()) { + return "Empty model path"; + } + if (weight_format != Gemma::Config::kBfloat16 && + weight_format != Gemma::Config::kFloat16 && + weight_format != Gemma::Config::kFloat32) { + return "Invalid weight format"; + } + return nullptr; + } +}; + +// Threading-related arguments. +struct ThreadingArgs : public ArgsBase { + size_t num_threads; + Tristate pin_threads; + Tristate use_spinning; + int verbosity; + + template + void ForEach(Visitor& visitor) { + visitor(num_threads, "threads", size_t{0}, + "Number of threads (0=auto, half of logical cores)", 1); + visitor(pin_threads, "pin_threads", Tristate::kDefault, + "Set to true/false to force enable/disable thread pinning.", 2); + visitor(use_spinning, "use_spinning", Tristate::kDefault, + "Set to true/false to enable/disable thread spinning (typically " + "improves " + "performance but increases power usage)", + 2); + visitor(verbosity, "verbosity", 1, + "Controls printing of progress messages to stderr", 1); + } + + // Returns nullptr if OK, otherwise error message. + const char* Validate() const { return nullptr; } + + // Returns num_threads to use. + size_t NumThreadsToUse() const { + return num_threads == 0 ? (size_t{hwy::NumberOfProcessors()} + 1) / 2 + : num_threads; + } +}; + +// Command-line arguments for PeftGemma and Gemma. 
+struct GemmaArgs : public ArgsBase { + InferenceArgs inference; + LoaderArgs loader; + ThreadingArgs threading; + // For collect_stats.cc: + Path output; + + bool trace_outputs; // For -ftrace and dump_csv.cc + bool trace_base; // For -ftrace + int time_it; // For time_it.cc + + template + void ForEach(Visitor& visitor) { + inference.ForEach(visitor); + loader.ForEach(visitor); + threading.ForEach(visitor); + + visitor(output, "output", Path{}, "Where to write CSV data / stats", 2); + visitor(trace_outputs, "trace_outputs", false, "For tracing", 2); + visitor(trace_base, "trace_base", false, "For tracing", 2); + visitor(time_it, "time_it", 0, "For benchmarks", 2); + } +}; } // namespace gcpp -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ \ No newline at end of file +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_ \ No newline at end of file diff --git a/gemma/run.cc b/gemma/run.cc index 381dac4..32eb5ff 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -78,6 +78,18 @@ std::string GetPrompt(std::istream& input, int verbosity, return prompt_string; } +// New GetPrompt function that accepts InferenceArgs +std::string GetPrompt(const InferenceArgs& inference, int verbosity, + size_t turn) { + // Check for command-line prompt first + if (!inference.prompt.empty()) { + return inference.prompt; + } + + // Use the existing function for interactive mode + return GetPrompt(std::cin, verbosity, inference.eot_line); +} + // The main Read-Eval-Print Loop. void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, Gemma& model, KVCache& kv_cache) { @@ -89,6 +101,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, std::mt19937 gen; InitGenerator(inference, gen); + // Add flag to track non-interactive mode + bool non_interactive_mode = !inference.prompt.empty(); + const bool have_image = !inference.image_file.path.empty(); Image image; ImageTokens image_tokens; @@ -151,47 +166,30 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, // Read prompt and handle special commands. std::string prompt_string = - GetPrompt(std::cin, inference.verbosity, inference.eot_line); - if (!std::cin) return; + GetPrompt(inference, inference.verbosity, abs_pos); + + if (!std::cin && !non_interactive_mode) return; + // If !eot_line.empty(), we append \n, so only look at the first 2 chars. - if (prompt_string.size() >= 2 && prompt_string[0] == '%') { + if (!non_interactive_mode && prompt_string.size() >= 2 && + prompt_string[0] == '%') { if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return; if (prompt_string[1] == 'c' || prompt_string[1] == 'C') { abs_pos = 0; continue; } } - if (prompt_string.empty()) { + + if (!non_interactive_mode && prompt_string.empty()) { std::cout << "Use '%q' to quit.\n"; continue; } - // Wrap, tokenize and maybe log prompt tokens. - std::vector prompt = WrapAndTokenize(model.Tokenizer(), model.Info(), - abs_pos, prompt_string); - prompt_size = prompt.size(); - if constexpr (kVerboseLogTokens) { - for (int i = 0; i < prompt_size; ++i) { - fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]); - } - } - - // Set up runtime config. 
- TimingInfo timing_info = {.verbosity = inference.verbosity}; - RuntimeConfig runtime_config = {.gen = &gen, - .verbosity = inference.verbosity, - .stream_token = stream_token, - .use_spinning = threading.spin}; - inference.CopyTo(runtime_config); - size_t prefix_end = 0; - std::vector prompt; if (have_image) { prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(), abs_pos, prompt_string, image_tokens.BatchSize()); - runtime_config.image_tokens = &image_tokens; - prompt_size = prompt.size(); // The end of the prefix for prefix-LM style attention in Paligemma. // See Figure 2 of https://arxiv.org/abs/2407.07726. prefix_end = prompt_size; @@ -209,6 +207,24 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, } } + // Set up runtime config. + TimingInfo timing_info = {.verbosity = inference.verbosity}; + RuntimeConfig runtime_config = {.gen = &gen, + .verbosity = inference.verbosity, + .stream_token = stream_token, + .use_spinning = threading.spin}; + inference.CopyTo(runtime_config); + size_t prefix_end = 0; + + if (have_image) { + runtime_config.image_tokens = &image_tokens; + prompt_size = prompt.size(); + // The end of the prefix for prefix-LM style attention in Paligemma. + prefix_end = prompt_size; + // We need to look at all the tokens for the prefix. + runtime_config.prefill_tbatch_size = prompt_size; + } + // Generate until EOS or max_generated_tokens. if (inference.verbosity >= 1) { std::cerr << "\n[ Reading prompt ] " << std::flush; @@ -217,6 +233,11 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, timing_info); std::cout << "\n\n"; + // Break the loop if in non-interactive mode + if (non_interactive_mode) { + break; + } + // Prepare for the next turn. Works only for PaliGemma. if (!inference.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) { @@ -249,22 +270,6 @@ void Run(ThreadingArgs& threading, LoaderArgs& loader, KVCache kv_cache = KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size); - if (!threading.prompt.empty()) { - std::vector prompt = - WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(), - 0, threading.prompt); - - TimingInfo timing_info = {.verbosity = inference.verbosity}; - RuntimeConfig runtime_config = {.gen = nullptr, // Use default generator - .verbosity = inference.verbosity, - .use_spinning = threading.spin}; - inference.CopyTo(runtime_config); - - model.Generate(runtime_config, prompt, 0, 0, kv_cache, timing_info); - std::cout << "\n"; - return; // Exit after generating response - } - if (inference.verbosity >= 1) { std::string instructions = "*Usage*\n" @@ -286,10 +291,13 @@ void Run(ThreadingArgs& threading, LoaderArgs& loader, instructions += multiturn; instructions += examples; - std::cout << "\033[2J\033[1;1H" // clear screen - << kAsciiArtBanner << "\n\n"; - ShowConfig(threading, loader, inference); - std::cout << "\n" << instructions << "\n"; + // Skip the banner and instructions in non-interactive mode + if (inference.prompt.empty()) { + std::cout << "\033[2J\033[1;1H" // clear screen + << kAsciiArtBanner << "\n\n"; + ShowConfig(threading, loader, inference); + std::cout << "\n" << instructions << "\n"; + } } ReplGemma(threading, inference, model, kv_cache); @@ -328,4 +336,4 @@ int main(int argc, char** argv) { } PROFILER_PRINT_RESULTS(); // Must call outside the zone above. 
return 0; -} +} \ No newline at end of file From 8246e4919945c1bdc2022730add7e378f00a2373 Mon Sep 17 00:00:00 2001 From: prajwalc22 Date: Wed, 16 Apr 2025 16:26:52 +0530 Subject: [PATCH 6/9] Add non-interactive mode support - Added prompt flag to InferenceArgs for non-interactive mode - Set user-facing options to verbosity level 1 - Fixed prompt_size declaration and variable ordering in run.cc - Properly set prompt_size after WrapAndTokenize calls - Moved kVerboseLogTokens block after prompt_size is set --- gemma/gemma_args.h | 355 ++++++++++++++++++++++++--------------------- gemma/run.cc | 57 ++++---- 2 files changed, 213 insertions(+), 199 deletions(-) diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index 2c15986..d02dece 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -13,15 +13,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Argument parsing for Gemma. +// Shared between various frontends. -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_ +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ +#include #include +#include #include +#include "compression/io.h" // Path #include "compression/shared.h" #include "gemma/common.h" #include "gemma/gemma.h" // For CreateGemma @@ -32,66 +35,174 @@ namespace gcpp { -// Arguments related to inference: sampling, text etc. -struct InferenceArgs : public ArgsBase { - // Arguments for getc-like interfaces - size_t max_tokens; - size_t max_generated_tokens; - float temperature; - size_t top_k; - float top_p; - float min_p; - int repeat_penalty_power; - float repeat_penalty_presence; - float repeat_penalty_decay; - float repeat_penalty_range; +struct LoaderArgs : public ArgsBase { + LoaderArgs(int argc, char* argv[], bool validate = true) { + InitAndParse(argc, argv); - // Batch configuration: - size_t prefill_tbatch_size; - size_t decode_tbatch_size; + if (validate) { + if (const char* error = Validate()) { + HWY_ABORT("Invalid args: %s", error); + } + } + } + LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path, + const std::string& model, bool validate = true) { + Init(); // Init sets to defaults, so assignments must come after Init(). + tokenizer.path = tokenizer_path; + weights.path = weights_path; + model_type_str = model; - // Non-interactive mode prompt - std::string prompt; - std::string eot_line; + if (validate) { + if (const char* error = Validate()) { + HWY_ABORT("Invalid args: %s", error); + } + } + }; + + // Returns error string or nullptr if OK. + const char* Validate() { + if (weights.path.empty()) { + return "Missing --weights flag, a file for the model weights."; + } + if (!weights.Exists()) { + return "Can't open file specified with --weights flag."; + } + info_.model = Model::UNKNOWN; + info_.wrapping = PromptWrapping::GEMMA_PT; + info_.weight = Type::kUnknown; + if (!model_type_str.empty()) { + const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model, + info_.wrapping); + if (err != nullptr) return err; + } + if (!weight_type_str.empty()) { + const char* err = ParseType(weight_type_str, info_.weight); + if (err != nullptr) return err; + } + if (!tokenizer.path.empty()) { + if (!tokenizer.Exists()) { + return "Can't open file specified with --tokenizer flag."; + } + } + // model_type and tokenizer must be either both present or both absent. + // Further checks happen on weight loading. 
+ if (model_type_str.empty() != tokenizer.path.empty()) { + return "Missing or extra flags for model_type or tokenizer."; + } + return nullptr; + } + + Path tokenizer; + Path weights; // weights file location + Path compressed_weights; + std::string model_type_str; + std::string weight_type_str; template - void ForEach(Visitor& visitor) { - // Each line specifies a variable member, its name, default value, and help. - visitor(max_tokens, "max_tokens", size_t{50}, - "Maximum number of total tokens including prompt (0=no limit).", 1); - visitor(max_generated_tokens, "max_generated_tokens", size_t{512}, - "Maximum number of generated tokens (not including prompt) (0=no " - "limit).", - 1); - visitor(temperature, "temperature", 1.0f, - "Temperature (randomness) for logits.", 1); - visitor(top_k, "top_k", size_t{40}, - "Number of highest-probability tokens to consider (0=unlimited).", - 1); - visitor(top_p, "top_p", 0.9f, "Top-p probability threshold (0.0=disabled).", - 1); - visitor(min_p, "min_p", 0.0f, "Min-p probability threshold (0.0=disabled).", - 1); - visitor( - repeat_penalty_power, "repeat_penalty_power", 1, - "Penalty power (1=standard frequentist penalty). If 0, skips penalty " - "computation.", - 1); - visitor(repeat_penalty_presence, "repeat_penalty_presence", 0.0f, - "Penalty for token presence regardless of frequency (additive).", - 1); - visitor(repeat_penalty_decay, "repeat_penalty_decay", 0.0f, - "Penalty for token n positions ago is decayed by " - "power(repeat_penalty_decay, n).", - 1); - visitor(repeat_penalty_range, "repeat_penalty_range", 8.0f, - "Penalty fades out near the end of range (tokens)", 1); + void ForEach(const Visitor& visitor) { + visitor(tokenizer, "tokenizer", Path(), + "Path name of tokenizer model file."); + visitor(weights, "weights", Path(), + "Path name of model weights (.sbs) file.\n Required argument.\n"); + visitor(compressed_weights, "compressed_weights", Path(), + "Deprecated alias for --weights."); + visitor(model_type_str, "model", std::string(), + "Model type, see common.cc for valid values.\n"); + visitor(weight_type_str, "weight_type", std::string("sfp"), + "Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit SFP."); + } - // Batch configuration: - visitor(prefill_tbatch_size, "prefill_tbatch_size", size_t{2}, - "Token batch size for prefill; <= 32", 2); - visitor(decode_tbatch_size, "decode_tbatch_size", size_t{1}, - "Token batch size for decode (only 1 currently supported)", 2); + // Uninitialized before Validate, must call after that. + const ModelInfo& Info() const { return info_; } + + private: + ModelInfo info_; +}; + +// `env` must remain valid for the lifetime of the Gemma. +static inline Gemma CreateGemma(const LoaderArgs& loader, MatMulEnv& env) { + if (Type::kUnknown == loader.Info().weight || + Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) { + // New weights file format doesn't need tokenizer path or model/weightinfo. + return Gemma(loader.weights, env); + } + return Gemma(loader.tokenizer, loader.weights, loader.Info(), env); +} + +// `env` must remain valid for the lifetime of the Gemma. +static inline std::unique_ptr AllocateGemma(const LoaderArgs& loader, + MatMulEnv& env) { + if (Type::kUnknown == loader.Info().weight || + Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) { + // New weights file format doesn't need tokenizer path or model/weight info. 
+ return std::make_unique(loader.weights, env); + } + return std::make_unique(loader.tokenizer, loader.weights, + loader.Info(), env); +} + +struct InferenceArgs : public ArgsBase { + InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + InferenceArgs() { Init(); }; + + int verbosity; + + size_t max_generated_tokens; + + size_t prefill_tbatch_size; + size_t decode_qbatch_size; + + float temperature; + size_t top_k; + bool deterministic; + bool multiturn; + Path image_file; + + std::string prompt; // Added prompt flag for non-interactive mode + std::string eot_line; + + // Returns error string or nullptr if OK. + const char* Validate() const { + if (max_generated_tokens > gcpp::kSeqLen) { + return "max_generated_tokens is larger than the maximum sequence length " + "(see configs.h)."; + } + return nullptr; + } + + template + void ForEach(const Visitor& visitor) { + visitor(verbosity, "verbosity", 1, + "Show verbose developer information\n 0 = only print generation " + "output\n 1 = standard user-facing terminal ui\n 2 = show " + "developer/debug info).\n Default = 1.", + 1); // Changed verbosity level to 1 since it's user-facing + + visitor(max_generated_tokens, "max_generated_tokens", size_t{2048}, + "Maximum number of tokens to generate."); + + visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256}, + "Prefill: max tokens per batch."); + visitor(decode_qbatch_size, "decode_qbatch", size_t{16}, + "Decode: max queries per batch."); + + visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2); + visitor(top_k, "top_k", size_t{1}, "Number of top-K tokens to sample from", + 2); + visitor(deterministic, "deterministic", false, + "Make top-k sampling deterministic", 2); + visitor(multiturn, "multiturn", false, + "Multiturn mode\n 0 = clear KV cache after every " + "interaction\n 1 = continue KV cache after every interaction\n " + " Default : 0 (conversation " + "resets every turn)"); + visitor(image_file, "image_file", Path(), "Image file to load."); + + visitor(prompt, "prompt", std::string(""), + "Initial prompt for non-interactive mode. 
When specified, " + "generates a response" + " and exits.", + 1); // Added as user-facing option visitor( eot_line, "eot_line", std::string(""), @@ -99,123 +210,31 @@ struct InferenceArgs : public ArgsBase { "When you specify this, the prompt will be all lines " "before the line where only the given string appears.\n Default = " "When a newline is encountered, that signals the end of the turn.", - 1); - - // Non-interactive mode prompt - visitor(prompt, "prompt", std::string(""), - "Prompt to use in non-interactive mode", 1); + 2); } - const char* Validate() const { - if (max_generated_tokens == 0 && max_tokens == 0) { - return "At least one of max_tokens and max_generated_tokens must be > 0"; + void CopyTo(RuntimeConfig& runtime_config) const { + runtime_config.max_generated_tokens = max_generated_tokens; + runtime_config.prefill_tbatch_size = prefill_tbatch_size; + runtime_config.decode_qbatch_size = decode_qbatch_size; + if (prefill_tbatch_size > MMStorage::kMaxM) { + HWY_ABORT( + "prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, " + "or increase the constant in MMStorage.\n", + prefill_tbatch_size, MMStorage::kMaxM); } - if (temperature <= 0.0) { - return "Temperature must be > 0.0"; + if (decode_qbatch_size > MMStorage::kMaxM) { + HWY_ABORT( + "decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, " + "or increase the constant in MMStorage.\n", + decode_qbatch_size, MMStorage::kMaxM); } - if (prefill_tbatch_size > 32) { - return "prefill_tbatch_size must be <= 32"; - } - if (decode_tbatch_size != 1) { - return "decode_tbatch_size must be 1"; - } - return nullptr; - } -}; -// Arguments related to model weights. -struct LoaderArgs : public ArgsBase { - Path model_path; // Path to directory containing the weights - Path tokenizer; // Optional: can be derived from model_path - bool model_is_gemma2; - Gemma::Config::WeightFormat weight_format; - - template - void ForEach(Visitor& visitor) { - // Each line specifies a variable member, its name, default value, and help. - visitor(model_path, "model", Path{}, - "Directory containing weights or config file from `gemma.cpp " - "convert`.", - 0); - visitor(tokenizer, "tokenizer", Path{}, - "Optional path to tokenizer.model; if empty, looks in model_path.", - 2); - visitor(model_is_gemma2, "model_is_gemma2", false, - "Whether the model is a Gemma 2 model", 1); - visitor(weight_format, "format", Gemma::Config::kBfloat16, - "Model weights format: 0=F32, 1=F16, 2=BF16", 2); - } - - const char* Validate() const { - if (model_path.path.empty()) { - return "Empty model path"; - } - if (weight_format != Gemma::Config::kBfloat16 && - weight_format != Gemma::Config::kFloat16 && - weight_format != Gemma::Config::kFloat32) { - return "Invalid weight format"; - } - return nullptr; - } -}; - -// Threading-related arguments. 
-struct ThreadingArgs : public ArgsBase { - size_t num_threads; - Tristate pin_threads; - Tristate use_spinning; - int verbosity; - - template - void ForEach(Visitor& visitor) { - visitor(num_threads, "threads", size_t{0}, - "Number of threads (0=auto, half of logical cores)", 1); - visitor(pin_threads, "pin_threads", Tristate::kDefault, - "Set to true/false to force enable/disable thread pinning.", 2); - visitor(use_spinning, "use_spinning", Tristate::kDefault, - "Set to true/false to enable/disable thread spinning (typically " - "improves " - "performance but increases power usage)", - 2); - visitor(verbosity, "verbosity", 1, - "Controls printing of progress messages to stderr", 1); - } - - // Returns nullptr if OK, otherwise error message. - const char* Validate() const { return nullptr; } - - // Returns num_threads to use. - size_t NumThreadsToUse() const { - return num_threads == 0 ? (size_t{hwy::NumberOfProcessors()} + 1) / 2 - : num_threads; - } -}; - -// Command-line arguments for PeftGemma and Gemma. -struct GemmaArgs : public ArgsBase { - InferenceArgs inference; - LoaderArgs loader; - ThreadingArgs threading; - // For collect_stats.cc: - Path output; - - bool trace_outputs; // For -ftrace and dump_csv.cc - bool trace_base; // For -ftrace - int time_it; // For time_it.cc - - template - void ForEach(Visitor& visitor) { - inference.ForEach(visitor); - loader.ForEach(visitor); - threading.ForEach(visitor); - - visitor(output, "output", Path{}, "Where to write CSV data / stats", 2); - visitor(trace_outputs, "trace_outputs", false, "For tracing", 2); - visitor(trace_base, "trace_base", false, "For tracing", 2); - visitor(time_it, "time_it", 0, "For benchmarks", 2); + runtime_config.temperature = temperature; + runtime_config.top_k = top_k; } }; } // namespace gcpp -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_ \ No newline at end of file +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ \ No newline at end of file diff --git a/gemma/run.cc b/gemma/run.cc index 32eb5ff..b7e8fa1 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -78,16 +78,15 @@ std::string GetPrompt(std::istream& input, int verbosity, return prompt_string; } -// New GetPrompt function that accepts InferenceArgs -std::string GetPrompt(const InferenceArgs& inference, int verbosity, - size_t turn) { - // Check for command-line prompt first +// Get prompt either from interactive input or command line +std::string GetPrompt(const InferenceArgs& inference) { + // If prompt is provided via command line, use that if (!inference.prompt.empty()) { return inference.prompt; } - // Use the existing function for interactive mode - return GetPrompt(std::cin, verbosity, inference.eot_line); + // Otherwise get interactive prompt + return GetPrompt(std::cin, inference.verbosity, inference.eot_line); } // The main Read-Eval-Print Loop. @@ -101,9 +100,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, std::mt19937 gen; InitGenerator(inference, gen); - // Add flag to track non-interactive mode - bool non_interactive_mode = !inference.prompt.empty(); - const bool have_image = !inference.image_file.path.empty(); Image image; ImageTokens image_tokens; @@ -165,13 +161,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, tokens_generated_this_turn = 0; // Read prompt and handle special commands. 
- std::string prompt_string = - GetPrompt(inference, inference.verbosity, abs_pos); + std::string prompt_string = GetPrompt(inference); - if (!std::cin && !non_interactive_mode) return; + if (!std::cin && inference.prompt.empty()) return; // If !eot_line.empty(), we append \n, so only look at the first 2 chars. - if (!non_interactive_mode && prompt_string.size() >= 2 && + if (inference.prompt.empty() && prompt_string.size() >= 2 && prompt_string[0] == '%') { if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return; if (prompt_string[1] == 'c' || prompt_string[1] == 'C') { @@ -180,12 +175,27 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, } } - if (!non_interactive_mode && prompt_string.empty()) { + if (inference.prompt.empty() && prompt_string.empty()) { std::cout << "Use '%q' to quit.\n"; continue; } - std::vector prompt; + size_t prompt_size = 0; + size_t prefix_end = 0; + if constexpr (kVerboseLogTokens) { + for (int i = 0; i < prompt_size; ++i) { + fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]); + } + } + + // Set up runtime config. + TimingInfo timing_info = {.verbosity = inference.verbosity}; + RuntimeConfig runtime_config = {.gen = &gen, + .verbosity = inference.verbosity, + .stream_token = stream_token, + .use_spinning = threading.spin}; + inference.CopyTo(runtime_config); + if (have_image) { prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(), @@ -201,21 +211,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, prompt_size = prompt.size(); } - if constexpr (kVerboseLogTokens) { - for (int i = 0; i < prompt_size; ++i) { - fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]); - } - } - - // Set up runtime config. - TimingInfo timing_info = {.verbosity = inference.verbosity}; - RuntimeConfig runtime_config = {.gen = &gen, - .verbosity = inference.verbosity, - .stream_token = stream_token, - .use_spinning = threading.spin}; - inference.CopyTo(runtime_config); - size_t prefix_end = 0; - if (have_image) { runtime_config.image_tokens = &image_tokens; prompt_size = prompt.size(); @@ -234,7 +229,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, std::cout << "\n\n"; // Break the loop if in non-interactive mode - if (non_interactive_mode) { + if (inference.prompt.empty()) { break; } From 27c28cc9386f8642071bd909169176e4bffcb98c Mon Sep 17 00:00:00 2001 From: prajwalc22 Date: Thu, 17 Apr 2025 10:15:05 +0530 Subject: [PATCH 7/9] Address review feedback: Fix prefill_tbatch_size and variable placement issues --- gemma/run.cc | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/gemma/run.cc b/gemma/run.cc index b7e8fa1..36d1bc2 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -156,7 +156,8 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, std::cout << token_text << std::flush; return true; }; - + // Flag to check if we should exit after processing non-interactive prompt + bool exit_after_generation = !inference.prompt.empty(); while (true) { // Loop until user quits. 
tokens_generated_this_turn = 0; @@ -179,14 +180,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, std::cout << "Use '%q' to quit.\n"; continue; } - std::vector prompt; - size_t prompt_size = 0; - size_t prefix_end = 0; - if constexpr (kVerboseLogTokens) { - for (int i = 0; i < prompt_size; ++i) { - fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]); - } - } // Set up runtime config. TimingInfo timing_info = {.verbosity = inference.verbosity}; @@ -195,29 +188,31 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, .stream_token = stream_token, .use_spinning = threading.spin}; inference.CopyTo(runtime_config); - + std::vector prompt; + size_t prompt_size = 0; + size_t prefix_end = 0; if (have_image) { prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(), abs_pos, prompt_string, image_tokens.BatchSize()); + runtime_config.image_tokens = &image_tokens; + prompt_size = prompt.size(); // The end of the prefix for prefix-LM style attention in Paligemma. // See Figure 2 of https://arxiv.org/abs/2407.07726. prefix_end = prompt_size; - // We need to look at all the tokens for the prefix. - runtime_config.prefill_tbatch_size = prompt_size; + + // REMOVED: Don't change prefill_tbatch_size for image handling + // runtime_config.prefill_tbatch_size = prompt_size; } else { prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(), abs_pos, prompt_string); prompt_size = prompt.size(); } - if (have_image) { - runtime_config.image_tokens = &image_tokens; - prompt_size = prompt.size(); - // The end of the prefix for prefix-LM style attention in Paligemma. - prefix_end = prompt_size; - // We need to look at all the tokens for the prefix. - runtime_config.prefill_tbatch_size = prompt_size; + if constexpr (kVerboseLogTokens) { + for (int i = 0; i < prompt_size; ++i) { + fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]); + } } // Generate until EOS or max_generated_tokens. @@ -229,7 +224,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, std::cout << "\n\n"; // Break the loop if in non-interactive mode - if (inference.prompt.empty()) { + if (exit_after_generation) { break; } From f55c321397895c60db063938a2bc76b4f08ede38 Mon Sep 17 00:00:00 2001 From: prajwalc22 Date: Thu, 17 Apr 2025 10:15:21 +0530 Subject: [PATCH 8/9] Address review feedback: Fix prefill_tbatch_size and variable placement issues --- gemma/run.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/gemma/run.cc b/gemma/run.cc index 36d1bc2..56cdb75 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -156,8 +156,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, std::cout << token_text << std::flush; return true; }; - // Flag to check if we should exit after processing non-interactive prompt - bool exit_after_generation = !inference.prompt.empty(); + while (true) { // Loop until user quits. 
tokens_generated_this_turn = 0; @@ -224,7 +223,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, std::cout << "\n\n"; // Break the loop if in non-interactive mode - if (exit_after_generation) { + if (!inference.prompt.empty()) { break; } From a9e56c27eb546e6a4de10b4d02e0ac58f341ddd7 Mon Sep 17 00:00:00 2001 From: prajwalc22 Date: Thu, 17 Apr 2025 23:44:23 +0530 Subject: [PATCH 9/9] removed unnecessary threading.h import --- gemma/run.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/gemma/run.cc b/gemma/run.cc index 56cdb75..2de1c1d 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -34,7 +34,6 @@ #include "ops/matmul.h" // MatMulEnv #include "paligemma/image.h" #include "util/args.h" // HasHelp -#include "util/threading.h" #if (!defined(HWY_VERSION_LT) || HWY_VERSION_LT(1, 2)) && !HWY_IDE #error "Please update to version 1.2 of github.com/google/highway."
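
End-to-end, the `--prompt` flag added by this series enables script-friendly
invocations such as the following; the tokenizer, weights, and model arguments
are placeholders for your own downloaded files:

```sh
# Non-interactive mode: generate a response for --prompt, then exit.
# The banner and config dump are skipped when --prompt is set.
./gemma --tokenizer tokenizer.spm --weights gemma2-2b-it-sfp.sbs \
  --model gemma2-2b-it --verbosity 0 \
  --prompt "Write a one-sentence summary of non-interactive mode."
```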