Add --prompt flag for non-interactive mode

This commit is contained in:
prajwalc22 2025-04-16 15:34:43 +05:30
parent 716713f0e6
commit cbf179990f
3 changed files with 795 additions and 453 deletions

596
README.md
View File

@ -1,27 +1,583 @@
--- # gemma.cpp
library_name: gemma.cpp
license: gemma
pipeline_tag: text-generation
tags: []
extra_gated_heading: Access Gemma on Hugging Face
extra_gated_prompt: To access Gemma on Hugging Face, youre required to review and
agree to Googles usage license. To do this, please ensure youre logged-in to Hugging
Face and click below. Requests are processed immediately.
extra_gated_button_content: Acknowledge license
---
# Gemma Model Card gemma.cpp is a lightweight, standalone C++ inference engine for the Gemma
foundation models from Google.
**Model Page**: [Gemma](https://ai.google.dev/gemma/docs) For additional information about Gemma, see
[ai.google.dev/gemma](https://ai.google.dev/gemma). Model weights, including
gemma.cpp specific artifacts, are
[available on kaggle](https://www.kaggle.com/models/google/gemma).
This model card corresponds to the 2B base version of the Gemma model for usage with C++ (https://github.com/google/gemma.cpp). This is a compressed version of the weights, which will load, run, and download more quickly. For more information about the model, visit https://huggingface.co/google/gemma-2b. ## Who is this project for?
**Resources and Technical Documentation**: Modern LLM inference engines are sophisticated systems, often with bespoke
capabilities extending beyond traditional neural network runtimes. With this
comes opportunities for research and innovation through co-design of high level
algorithms and low-level computation. However, there is a gap between
deployment-oriented C++ inference runtimes, which are not designed for
experimentation, and Python-centric ML research frameworks, which abstract away
low-level computation through compilation.
* [Responsible Generative AI Toolkit](https://ai.google.dev/responsible) gemma.cpp provides a minimalist implementation of Gemma-1, Gemma-2, Gemma-3, and
* [Gemma on Kaggle](https://www.kaggle.com/models/google/gemma) PaliGemma models, focusing on simplicity and directness rather than full
* [Gemma on Vertex Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/335?version=gemma-2b-gg-hf) generality. This is inspired by vertically-integrated model implementations such
as [ggml](https://github.com/ggerganov/ggml),
[llama.c](https://github.com/karpathy/llama2.c), and
[llama.rs](https://github.com/srush/llama2.rs).
**Terms of Use**: [Terms](https://www.kaggle.com/models/google/gemma/license/consent/verify/huggingface?returnModelRepoId=google/gemma-2b-sfp-cpp) gemma.cpp targets experimentation and research use cases. It is intended to be
straightforward to embed in other projects with minimal dependencies and also
easily modifiable with a small ~2K LoC core implementation (along with ~4K LoC
of supporting utilities). We use the [Google
Highway](https://github.com/google/highway) Library to take advantage of
portable SIMD for CPU inference.
**Authors**: Google For production-oriented edge deployments we recommend standard deployment
pathways using Python frameworks like JAX, Keras, PyTorch, and Transformers
([all model variations here](https://www.kaggle.com/models/google/gemma)).
## Contributing
Community contributions large and small are welcome. See
[DEVELOPERS.md](https://github.com/google/gemma.cpp/blob/main/DEVELOPERS.md)
for additional notes contributing developers and [join the discord by following
this invite link](https://discord.gg/H5jCBAWxAe). This project follows
[Google's Open Source Community
Guidelines](https://opensource.google.com/conduct/).
*Active development is currently done on the `dev` branch. Please open pull
requests targeting `dev` branch instead of `main`, which is intended to be more
stable.*
## Quick Start
### System requirements
Before starting, you should have installed:
- [CMake](https://cmake.org/)
- [Clang C++ compiler](https://clang.llvm.org/get_started.html), supporting at
least C++17.
- `tar` for extracting archives from Kaggle.
Building natively on Windows requires the Visual Studio 2012 Build Tools with the
optional Clang/LLVM C++ frontend (`clang-cl`). This can be installed from the
command line with
[`winget`](https://learn.microsoft.com/en-us/windows/package-manager/winget/):
```sh
winget install --id Kitware.CMake
winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;installRecommended --add Microsoft.VisualStudio.Component.VC.Llvm.Clang --add Microsoft.VisualStudio.Component.VC.Llvm.ClangToolset"
```
### Step 1: Obtain model weights and tokenizer from Kaggle or Hugging Face Hub
Visit the
[Kaggle page for Gemma-2](https://www.kaggle.com/models/google/gemma-2/gemmaCpp)
[or Gemma-1](https://www.kaggle.com/models/google/gemma/frameworks/gemmaCpp),
and select `Model Variations |> Gemma C++`.
On this tab, the `Variation` dropdown includes the options below. Note bfloat16
weights are higher fidelity, while 8-bit switched floating point weights enable
faster inference. In general, we recommend starting with the `-sfp` checkpoints.
If you are unsure which model to start with, we recommend starting with the
smallest Gemma-2 model, i.e. `2.0-2b-it-sfp`.
Alternatively, visit the
[gemma.cpp](https://huggingface.co/models?other=gemma.cpp) models on the Hugging
Face Hub. First go the model repository of the model of interest (see
recommendations below). Then, click the `Files and versions` tab and download
the model and tokenizer files. For programmatic downloading, if you have
`huggingface_hub` installed, you can also download by running:
```
huggingface-cli login # Just the first time
huggingface-cli download google/gemma-2b-sfp-cpp --local-dir build/
```
Gemma-1 2B instruction-tuned (`it`) and pre-trained (`pt`) models:
| Model name | Description |
| ----------- | ----------- |
| `2b-it` | 2 billion parameter instruction-tuned model, bfloat16 |
| `2b-it-sfp` | 2 billion parameter instruction-tuned model, 8-bit switched floating point |
| `2b-pt` | 2 billion parameter pre-trained model, bfloat16 |
| `2b-pt-sfp` | 2 billion parameter pre-trained model, 8-bit switched floating point |
Gemma-1 7B instruction-tuned (`it`) and pre-trained (`pt`) models:
| Model name | Description |
| ----------- | ----------- |
| `7b-it` | 7 billion parameter instruction-tuned model, bfloat16 |
| `7b-it-sfp` | 7 billion parameter instruction-tuned model, 8-bit switched floating point |
| `7b-pt` | 7 billion parameter pre-trained model, bfloat16 |
| `7b-pt-sfp` | 7 billion parameter pre-trained model, 8-bit switched floating point |
> [!NOTE]
> **Important**: We strongly recommend starting off with the `2b-it-sfp` model to
> get up and running.
Gemma 2 models are named `gemma2-2b-it` for 2B and `9b-it` or `27b-it`. See the
`kModelFlags` definition in `common.cc`.
### Step 2: Extract Files
If you downloaded the models from Hugging Face, skip to step 3.
After filling out the consent form, the download should proceed to retrieve a
tar archive file `archive.tar.gz`. Extract files from `archive.tar.gz` (this can
take a few minutes):
```
tar -xf archive.tar.gz
```
This should produce a file containing model weights such as `2b-it-sfp.sbs` and
a tokenizer file (`tokenizer.spm`). You may want to move these files to a
convenient directory location (e.g. the `build/` directory in this repo).
### Step 3: Build
The build system uses [CMake](https://cmake.org/). To build the gemma inference
runtime, create a build directory and generate the build files using `cmake`
from the top-level project directory. Note if you previous ran `cmake` and are
re-running with a different setting, be sure to delete all files in the `build/`
directory with `rm -rf build/*`.
#### Unix-like Platforms
```sh
cmake -B build
```
After running `cmake`, you can enter the `build/` directory and run `make` to
build the `./gemma` executable:
```sh
# Configure `build` directory
cmake --preset make
# Build project using make
cmake --build --preset make -j [number of parallel threads to use]
```
Replace `[number of parallel threads to use]` with a number - the number of
cores available on your system is a reasonable heuristic. For example,
`make -j4 gemma` will build using 4 threads. If the `nproc` command is
available, you can use `make -j$(nproc) gemma` as a reasonable default
for the number of threads.
If you aren't sure of the right value for the `-j` flag, you can simply run
`make gemma` instead and it should still build the `./gemma` executable.
> [!NOTE]
> On Windows Subsystem for Linux (WSL) users should set the number of
> parallel threads to 1. Using a larger number may result in errors.
If the build is successful, you should now have a `gemma` executable in the `build/` directory.
#### Windows
```sh
# Configure `build` directory
cmake --preset windows
# Build project using Visual Studio Build Tools
cmake --build --preset windows -j [number of parallel threads to use]
```
If the build is successful, you should now have a `gemma.exe` executable in the `build/` directory.
#### Bazel
```sh
bazel build -c opt --cxxopt=-std=c++20 :gemma
```
If the build is successful, you should now have a `gemma` executable in the `bazel-bin/` directory.
#### Make
If you prefer Makefiles, @jart has made one available here:
https://github.com/jart/gemma3/blob/main/Makefile
### Step 4: Run
You can now run `gemma` from inside the `build/` directory.
`gemma` has the following required arguments:
Argument | Description | Example value
--------------- | ---------------------------- | -----------------------
`--model` | The model type. | `2b-it` ... (see below)
`--weights` | The compressed weights file. | `2b-it-sfp.sbs`
`--weight_type` | The compressed weight type. | `sfp`
`--tokenizer` | The tokenizer file. | `tokenizer.spm`
`gemma` is invoked as:
```sh
./gemma \
--tokenizer [tokenizer file] \
--weights [compressed weights file] \
--weight_type [f32 or bf16 or sfp (default:sfp)] \
--model [2b-it or 2b-pt or 7b-it or 7b-pt or ...]
```
Example invocation for the following configuration:
- Compressed weights file `2b-it-sfp.sbs` (2B instruction-tuned model, 8-bit
switched floating point).
- Tokenizer file `tokenizer.spm`.
```sh
./gemma \
--tokenizer tokenizer.spm \
--weights 2b-it-sfp.sbs --model 2b-it
```
### RecurrentGemma
This repository includes a version of Gemma based on Griffin
([paper](https://arxiv.org/abs/2402.19427),
[code](https://github.com/google-deepmind/recurrentgemma)). Its architecture
includes both recurrent layers and local attention, thus it is more efficient
for longer sequences and has a smaller memory footprint than standard Gemma. We
here provide a C++ implementation of this model based on the paper.
To use the recurrent version of Gemma included in this repository, build the
gemma binary as noted above in Step 3. Download the compressed weights and
tokenizer from the RecurrentGemma
[Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) as in
Step 1, and run the binary as follows:
`./gemma --tokenizer tokenizer.spm --model gr2b-it --weights 2b-it-sfp.sbs`
### PaliGemma Vision-Language Model
This repository includes a version of the PaliGemma VLM
([paper](https://arxiv.org/abs/2407.07726),
[code](https://github.com/google-research/big_vision/tree/main/big_vision/configs/proj/paligemma))
and its successor PaliGemma 2 ([paper](https://arxiv.org/abs/2412.03555)). We
provide a C++ implementation of the PaliGemma model family here.
To use the version of PaliGemma included in this repository, build the gemma
binary as noted above in Step 3. Download the compressed weights and tokenizer
from
[Kaggle](https://www.kaggle.com/models/google/paligemma/gemmaCpp/paligemma-3b-mix-224)
and run the binary as follows:
```sh
./gemma \
--tokenizer paligemma_tokenizer.model \
--model paligemma-224 \
--weights paligemma-3b-mix-224-sfp.sbs \
--image_file paligemma/testdata/image.ppm
```
Note that the image reading code is very basic to avoid depending on an image
processing library for now. We currently only support reading binary PPMs (P6).
So use a tool like `convert` to first convert your images into that format, e.g.
`convert image.jpeg -resize 224x224^ image.ppm`
(As the image will be resized for processing anyway, we can already resize at
this stage for slightly faster loading.)
The interaction with the image (using the mix-224 checkpoint) may then look
something like this:
```
> Describe the image briefly
A large building with two towers in the middle of a city.
> What type of building is it?
church
> What color is the church?
gray
> caption image
A large building with two towers stands tall on the water's edge. The building
has a brown roof and a window on the side. A tree stands in front of the
building, and a flag waves proudly from its top. The water is calm and blue,
reflecting the sky above. A bridge crosses the water, and a red and white boat
rests on its surface. The building has a window on the side, and a flag on top.
A tall tree stands in front of the building, and a window on the building is
visible from the water. The water is green, and the sky is blue.
```
### Migrating to single-file format
There is now a new format for the weights file, which is a single file that
allows to contain the tokenizer (and the model type) directly. A tool to migrate
from the multi-file format to the single-file format is available.
```sh
compression/migrate_weights \
--tokenizer .../tokenizer.spm --weights .../gemma2-2b-it-sfp.sbs \
--model gemma2-2b-it --output_weights .../gemma2-2b-it-sfp-single.sbs
```
After migration, you can use the new weights file with gemma.cpp like this:
```sh
./gemma --weights .../gemma2-2b-it-sfp-single.sbs
```
### Troubleshooting and FAQs
**Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
The most common problem is that the `--weight_type` argument does not match that
of the model file. Revisit step #3 and check which weights you downloaded.
Note that we have already moved weight type from a compile-time decision to a
runtime argument. In a subsequent step, we plan to bake this information into
the weights.
**Problems building in Windows / Visual Studio**
Currently if you're using Windows, we recommend building in WSL (Windows
Subsystem for Linux). We are exploring options to enable other build
configurations, see issues for active discussion.
**Model does not respond to instructions and produces strange output**
A common issue is that you are using a pre-trained model, which is not
instruction-tuned and thus does not respond to instructions. Make sure you are
using an instruction-tuned model (`2b-it-sfp`, `2b-it`, `7b-it-sfp`, `7b-it`)
and not a pre-trained model (any model with a `-pt` suffix).
**What sequence lengths are supported?**
See `seq_len` in `configs.cc`. For the Gemma 3 models larger than 1B, this is
typically 32K but 128K would also work given enough RAM. Note that long
sequences will be slow due to the quadratic cost of attention.
**How do I convert my fine-tune to a `.sbs` compressed model file?**
For PaliGemma (1 and 2) checkpoints, you can use
python/convert_from_safetensors.py to convert from safetensors format (tested
with building via bazel). For an adapter model, you will likely need to call
merge_and_unload() to convert the adapter model to a single-file format before
converting it.
Here is how to use it using a bazel build of the compression library assuming
locally installed (venv) torch, numpy, safetensors, absl-py, etc.:
```sh
bazel build //compression/python:compression
BAZEL_OUTPUT_DIR="${PWD}/bazel-bin/compression"
python3 -c "import site; print(site.getsitepackages())"
# Use your sites-packages file here:
ln -s $BAZEL_OUTPUT_DIR [...]/site-packages/compression
python3 python/convert_from_safetensors.py --load_path [...].safetensors.index.json
```
See also compression/convert_weights.py for a slightly older option to convert a
pytorch checkpoint. (The code may need updates to work with Gemma-2 models.)
**What are some easy ways to make the model run faster?**
1. Make sure you are using the 8-bit switched floating point `-sfp` models.
These are half the size of bf16 and thus use less memory bandwidth and cache
space.
2. If you're on a laptop, make sure power mode is set to maximize performance
and saving mode is **off**. For most laptops, the power saving modes get
activated automatically if the computer is not plugged in.
3. Close other unused cpu-intensive applications.
4. On macs, anecdotally we observe a "warm-up" ramp-up in speed as performance
cores get engaged.
5. Experiment with the `--num_threads` argument value. Depending on the device,
larger numbers don't always mean better performance.
We're also working on algorithmic and optimization approaches for faster
inference, stay tuned.
## Usage
`gemma` has different usage modes, controlled by the verbosity flag.
All usage modes are currently interactive, triggering text generation upon
newline input.
| Verbosity | Usage mode | Details |
| --------------- | ---------- | --------------------------------------------- |
| `--verbosity 0` | Minimal | Only prints generation output. Suitable as a CLI tool. |
| `--verbosity 1` | Default | Standard user-facing terminal UI. |
| `--verbosity 2` | Detailed | Shows additional developer and debug info. |
### Interactive Terminal App
By default, verbosity is set to 1, bringing up a terminal-based interactive
interface when `gemma` is invoked:
```console
$ ./gemma [...]
__ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __
/ _` |/ _ \ '_ ` _ \| '_ ` _ \ / _` | / __| '_ \| '_ \
| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |
\__, |\___|_| |_| |_|_| |_| |_|\__,_(_)___| .__/| .__/
__/ | | | | |
|___/ |_| |_|
tokenizer : tokenizer.spm
compressed_weights : 2b-it-sfp.sbs
model : 2b-it
weights : [no path specified]
max_generated_tokens : 2048
*Usage*
Enter an instruction and press enter (%C reset conversation, %Q quits).
*Examples*
- Write an email to grandma thanking her for the cookies.
- What are some historical attractions to visit around Massachusetts?
- Compute the nth fibonacci number in javascript.
- Write a standup comedy bit about WebGPU programming.
> What are some outdoorsy places to visit around Boston?
[ Reading prompt ] .....................
**Boston Harbor and Islands:**
* **Boston Harbor Islands National and State Park:** Explore pristine beaches, wildlife, and maritime history.
* **Charles River Esplanade:** Enjoy scenic views of the harbor and city skyline.
* **Boston Harbor Cruise Company:** Take a relaxing harbor cruise and admire the city from a different perspective.
* **Seaport Village:** Visit a charming waterfront area with shops, restaurants, and a seaport museum.
**Forest and Nature:**
* **Forest Park:** Hike through a scenic forest with diverse wildlife.
* **Quabbin Reservoir:** Enjoy boating, fishing, and hiking in a scenic setting.
* **Mount Forest:** Explore a mountain with breathtaking views of the city and surrounding landscape.
...
```
### Usage as a Command Line Tool
For using the `gemma` executable as a command line tool, it may be useful to
create an alias for gemma.cpp with arguments fully specified:
```sh
alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --weights ~/gemma.cpp/build/gemma2-2b-it-sfp.sbs --model gemma2-2b-it --verbosity 0"
```
Replace the above paths with your own paths to the model and tokenizer paths
from the download.
Here is an example of prompting `gemma` with a truncated input
file (using a `gemma2b` alias like defined above):
```sh
cat configs.h | tail -n 35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b
```
> [!NOTE]
> CLI usage of gemma.cpp is experimental and should take context length
> limitations into account.
The output of the above command should look like:
```console
[ Reading prompt ] [...]
This C++ code snippet defines a set of **constants** used in a large language model (LLM) implementation, likely related to the **attention mechanism**.
Let's break down the code:
[...]
```
### Incorporating gemma.cpp as a Library in your Project
The easiest way to incorporate gemma.cpp in your own project is to pull in
gemma.cpp and dependencies using `FetchContent`. You can add the following to your
CMakeLists.txt:
```
include(FetchContent)
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece)
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
FetchContent_MakeAvailable(gemma)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
FetchContent_MakeAvailable(highway)
```
Note for the gemma.cpp `GIT_TAG`, you may replace `origin/main` for a specific
commit hash if you would like to pin the library version.
After your executable is defined (substitute your executable name for
`[Executable Name]` below):
```
target_link_libraries([Executable Name] libgemma hwy hwy_contrib sentencepiece)
FetchContent_GetProperties(gemma)
FetchContent_GetProperties(sentencepiece)
target_include_directories([Executable Name] PRIVATE ${gemma_SOURCE_DIR})
target_include_directories([Executable Name] PRIVATE ${sentencepiece_SOURCE_DIR})
```
### Building gemma.cpp as a Library
gemma.cpp can also be used as a library dependency in your own project. The
shared library artifact can be built by modifying the make invocation to build
the `libgemma` target instead of `gemma`.
> [!NOTE]
> If you are using gemma.cpp in your own project with the `FetchContent` steps
> in the previous section, building the library is done automatically by `cmake`
> and this section can be skipped.
First, run `cmake`:
```sh
cmake -B build
```
Then, run `make` with the `libgemma` target:
```sh
cd build
make -j [number of parallel threads to use] libgemma
```
If this is successful, you should now have a `libgemma` library file in the
`build/` directory. On Unix platforms, the filename is `libgemma.a`.
## Independent Projects Using gemma.cpp
Some independent projects using gemma.cpp:
- [gemma-cpp-python - Python bindings](https://github.com/namtranase/gemma-cpp-python)
- [lua-cgemma - Lua bindings](https://github.com/ufownl/lua-cgemma)
- [Godot engine demo project](https://github.com/Rliop913/Gemma-godot-demo-project)
If you would like to have your project included, feel free to get in touch or
submit a PR with a `README.md` edit.
## Acknowledgements and Contacts
gemma.cpp was started in fall 2023 by [Austin Huang](mailto:austinvhuang@google.com)
and [Jan Wassenberg](mailto:janwas@google.com), and subsequently released February 2024
thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng.
Griffin support was implemented in April 2024 thanks to contributions by Andrey
Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode
Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas
Fischbacher and Zoltan Szabadka.
Gemma-2 support was implemented in June/July 2024 with the help of several
people.
PaliGemma support was implemented in September 2024 with contributions from
Daniel Keysers.
[Jan Wassenberg](mailto:janwas@google.com) has continued to contribute many
improvements, including major gains in efficiency, since the initial release.
This is not an officially supported Google product.

View File

@ -13,383 +13,85 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Shared between various frontends. // Argument parsing for Gemma.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
#include <stddef.h>
#include <stdio.h> #include <stdio.h>
#include <memory>
#include <string> #include <string>
#include "compression/io.h" // Path
#include "compression/shared.h" #include "compression/shared.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/gemma.h" // For CreateGemma #include "gemma/gemma.h" // For CreateGemma
#include "hwy/base.h" // HWY_IS_ASAN, HWY_ABORT #include "hwy/base.h" // HWY_ABORT
#include "ops/matmul.h" #include "ops/matmul.h"
#include "util/allocator.h"
#include "util/args.h" #include "util/args.h"
#include "util/basics.h" // Tristate #include "util/basics.h" // Tristate
#include "util/threading.h"
#include "util/threading_context.h"
namespace gcpp { namespace gcpp {
static inline const char* CompiledConfig() { // Arguments related to inference: sampling, text etc.
if (HWY_IS_ASAN) {
return "asan";
} else if (HWY_IS_MSAN) {
return "msan";
} else if (HWY_IS_TSAN) {
return "tsan";
} else if (HWY_IS_HWASAN) {
return "hwasan";
} else if (HWY_IS_UBSAN) {
return "ubsan";
} else if (HWY_IS_DEBUG_BUILD) {
return "dbg";
} else {
return "opt";
}
}
template <typename Derived>
struct ArgsBase {
void Init() { static_cast<Derived*>(this)->ForEach(SetToDefault()); }
void InitAndParse(int argc, char* argv[]) {
Init();
static_cast<Derived*>(this)->ForEach(ParseOption(argc, argv));
}
void Print(int min_verbosity = 1) const {
static_cast<const Derived*>(this)->ForEach(PrintOption(min_verbosity));
}
void Help() const { static_cast<const Derived*>(this)->ForEach(PrintHelp()); }
protected:
// Helper struct for printing help messages
struct PrintHelp {
template <typename T>
void operator()(const T& value, const char* name, const T& default_value,
const char* description, int verbosity = 1) const {
fprintf(stderr, " --%s\n %s\n", name, description);
}
// Special case for strings to avoid template deduction issues
void operator()(const std::string& value, const char* name,
const std::string& default_value, const char* description,
int verbosity = 1) const {
fprintf(stderr, " --%s\n %s\n", name, description);
}
// Special case for Path type
void operator()(const Path& value, const char* name,
const Path& default_value, const char* description,
int verbosity = 1) const {
fprintf(stderr, " --%s\n %s\n", name, description);
}
};
// Helper struct for setting default values
struct SetToDefault {
template <typename T>
void operator()(T& value, const char* name, const T& default_value,
const char* description, int verbosity = 1) const {
value = default_value;
}
};
// Helper struct for printing values
struct PrintOption {
explicit PrintOption(int min_verbosity) : min_verbosity_(min_verbosity) {}
template <typename T>
void operator()(const T& value, const char* name, const T& default_value,
const char* description, int verbosity = 1) const {
if (verbosity >= min_verbosity_) {
fprintf(stderr, "%s: %s\n", name, ToString(value).c_str());
}
}
private:
int min_verbosity_;
// Helper function to convert values to string
template <typename T>
static std::string ToString(const T& value) {
return std::to_string(value);
}
// Specialization for string
static std::string ToString(const std::string& value) { return value; }
// Specialization for Path
static std::string ToString(const Path& value) { return value.path; }
};
};
struct ThreadingArgs : public ArgsBase<ThreadingArgs> {
public:
ThreadingArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
ThreadingArgs() { Init(); };
int verbosity;
size_t max_threads; // divided among the detected clusters
Tristate pin; // pin threads?
Tristate spin; // use spin waits?
// For BoundedSlice:
size_t skip_packages;
size_t max_packages;
size_t skip_clusters;
size_t max_clusters;
size_t skip_lps;
size_t max_lps;
std::string eot_line;
std::string prompt;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(verbosity, "verbosity", 1,
"Show verbose developer information\n 0 = only print generation "
"output\n 1 = standard user-facing terminal ui\n 2 = show "
"developer/debug info).\n Default = 1.",
2);
// The exact meaning is more subtle: see the comment at NestedPools ctor.
visitor(max_threads, "num_threads", size_t{0},
"Maximum number of threads to use; default 0 = unlimited.", 2);
visitor(pin, "pin", Tristate::kDefault,
"Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
visitor(spin, "spin", Tristate::kDefault,
"Use spin waits? -1 = auto, 0 = no, 1 = yes.", 2);
// These can be used to partition CPU sockets/packages and their
// clusters/CCXs across several program instances. The default is to use
// all available resources.
visitor(skip_packages, "skip_packages", size_t{0},
"Index of the first socket to use; default 0 = unlimited.", 2);
visitor(max_packages, "max_packages", size_t{0},
"Maximum number of sockets to use; default 0 = unlimited.", 2);
visitor(skip_clusters, "skip_clusters", size_t{0},
"Index of the first CCX to use; default 0 = unlimited.", 2);
visitor(max_clusters, "max_clusters", size_t{0},
"Maximum number of CCXs to use; default 0 = unlimited.", 2);
// These are only used when CPU topology is unknown.
visitor(skip_lps, "skip_lps", size_t{0},
"Index of the first LP to use; default 0 = unlimited.", 2);
visitor(max_lps, "max_lps", size_t{0},
"Maximum number of LPs to use; default 0 = unlimited.", 2);
visitor(
eot_line, "eot_line", std::string(""),
"End of turn line. "
"When you specify this, the prompt will be all lines "
"before the line where only the given string appears.\n Default = "
"When a newline is encountered, that signals the end of the turn.",
2);
visitor(prompt, "prompt", std::string(""),
"Prompt string for non-interactive mode. When provided, the model "
"generates a response and exits.",
2);
}
};
static inline BoundedTopology CreateTopology(const ThreadingArgs& threading) {
return BoundedTopology(
BoundedSlice(threading.skip_packages, threading.max_packages),
BoundedSlice(threading.skip_clusters, threading.max_clusters),
BoundedSlice(threading.skip_lps, threading.max_lps));
}
static inline MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading) {
ThreadingContext2::SetArgs(threading);
return MatMulEnv(ThreadingContext2::Get());
}
// Note: These functions may need adjustments depending on your specific class
// definitions
static inline BoundedTopology CreateTopology(const ThreadingArgs& app) {
return BoundedTopology(BoundedSlice(app.skip_packages, app.max_packages),
BoundedSlice(app.skip_clusters, app.max_clusters),
BoundedSlice(app.skip_lps, app.max_lps));
}
// This function may need to be adjusted based on your NestedPools constructor
// signature
static inline NestedPools CreatePools(const BoundedTopology& topology,
const ThreadingArgs& threading) {
// Make sure Allocator::Init() is properly declared/defined
const Allocator2& allocator = ThreadingContext2::Get().allocator;
// Allocator::Init(topology);
// Adjust the constructor call based on your actual NestedPools constructor
// The error suggests that the constructor doesn't match these arguments
return NestedPools(topology, allocator, threading.max_threads, threading.pin);
// Alternative: return NestedPools(topology, app.max_threads, app.pin);
}
struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char* argv[], bool validate = true) {
InitAndParse(argc, argv);
if (validate) {
if (const char* error = Validate()) {
HWY_ABORT("Invalid args: %s", error);
}
}
}
LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
const std::string& model, bool validate = true) {
Init(); // Init sets to defaults, so assignments must come after Init().
tokenizer.path = tokenizer_path;
weights.path = weights_path;
model_type_str = model;
if (validate) {
if (const char* error = Validate()) {
HWY_ABORT("Invalid args: %s", error);
}
}
};
// Returns error string or nullptr if OK.
const char* Validate() {
if (weights.path.empty()) {
return "Missing --weights flag, a file for the model weights.";
}
if (!weights.Exists()) {
return "Can't open file specified with --weights flag.";
}
info_.model = Model::UNKNOWN;
info_.wrapping = PromptWrapping::GEMMA_PT;
info_.weight = Type::kUnknown;
if (!model_type_str.empty()) {
const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
info_.wrapping);
if (err != nullptr) return err;
}
if (!weight_type_str.empty()) {
const char* err = ParseType(weight_type_str, info_.weight);
if (err != nullptr) return err;
}
if (!tokenizer.path.empty()) {
if (!tokenizer.Exists()) {
return "Can't open file specified with --tokenizer flag.";
}
}
// model_type and tokenizer must be either both present or both absent.
// Further checks happen on weight loading.
if (model_type_str.empty() != tokenizer.path.empty()) {
return "Missing or extra flags for model_type or tokenizer.";
}
return nullptr;
}
Path tokenizer;
Path weights; // weights file location
Path compressed_weights;
std::string model_type_str;
std::string weight_type_str;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(tokenizer, "tokenizer", Path(),
"Path name of tokenizer model file.");
visitor(weights, "weights", Path(),
"Path name of model weights (.sbs) file.\n Required argument.\n");
visitor(compressed_weights, "compressed_weights", Path(),
"Deprecated alias for --weights.");
visitor(model_type_str, "model", std::string(),
"Model type, see common.cc for valid values.\n");
visitor(weight_type_str, "weight_type", std::string("sfp"),
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit SFP.");
}
// Uninitialized before Validate, must call after that.
const ModelInfo& Info() const { return info_; }
private:
// TODO(rays): remove this. Eventually ModelConfig will be loaded from the
// weights file, so we can remove the need for this struct entirely.
ModelInfo info_;
};
// `env` must remain valid for the lifetime of the Gemma.
static inline Gemma CreateGemma(const LoaderArgs& loader, MatMulEnv& env) {
if (Type::kUnknown == loader.Info().weight ||
Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
// New weights file format doesn't need tokenizer path or model/weightinfo.
return Gemma(loader.weights, env);
}
return Gemma(loader.tokenizer, loader.weights, loader.Info(), env);
}
// `env` must remain valid for the lifetime of the Gemma.
static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
MatMulEnv& env) {
if (Type::kUnknown == loader.Info().weight ||
Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
// New weights file format doesn't need tokenizer path or model/weight info.
return std::make_unique<Gemma>(loader.weights, env);
}
return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
loader.Info(), env);
}
struct InferenceArgs : public ArgsBase<InferenceArgs> { struct InferenceArgs : public ArgsBase<InferenceArgs> {
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } // Arguments for getc-like interfaces
InferenceArgs() { Init(); }; size_t max_tokens;
int verbosity;
size_t max_generated_tokens; size_t max_generated_tokens;
size_t prefill_tbatch_size;
size_t decode_qbatch_size;
float temperature; float temperature;
size_t top_k; size_t top_k;
bool deterministic; float top_p;
bool multiturn; float min_p;
Path image_file; int repeat_penalty_power;
float repeat_penalty_presence;
float repeat_penalty_decay;
float repeat_penalty_range;
// Batch configuration:
size_t prefill_tbatch_size;
size_t decode_tbatch_size;
// Non-interactive mode prompt
std::string prompt;
std::string eot_line; std::string eot_line;
// Returns error string or nullptr if OK.
const char* Validate() const {
if (max_generated_tokens > gcpp::kSeqLen) {
return "max_generated_tokens is larger than the maximum sequence length "
"(see configs.h).";
}
return nullptr;
}
template <class Visitor> template <class Visitor>
void ForEach(const Visitor& visitor) { void ForEach(Visitor& visitor) {
visitor(verbosity, "verbosity", 1, // Each line specifies a variable member, its name, default value, and help.
"Show verbose developer information\n 0 = only print generation " visitor(max_tokens, "max_tokens", size_t{50},
"output\n 1 = standard user-facing terminal ui\n 2 = show " "Maximum number of total tokens including prompt (0=no limit).", 1);
"developer/debug info).\n Default = 1.", visitor(max_generated_tokens, "max_generated_tokens", size_t{512},
2); "Maximum number of generated tokens (not including prompt) (0=no "
"limit).",
1);
visitor(temperature, "temperature", 1.0f,
"Temperature (randomness) for logits.", 1);
visitor(top_k, "top_k", size_t{40},
"Number of highest-probability tokens to consider (0=unlimited).",
1);
visitor(top_p, "top_p", 0.9f, "Top-p probability threshold (0.0=disabled).",
1);
visitor(min_p, "min_p", 0.0f, "Min-p probability threshold (0.0=disabled).",
1);
visitor(
repeat_penalty_power, "repeat_penalty_power", 1,
"Penalty power (1=standard frequentist penalty). If 0, skips penalty "
"computation.",
1);
visitor(repeat_penalty_presence, "repeat_penalty_presence", 0.0f,
"Penalty for token presence regardless of frequency (additive).",
1);
visitor(repeat_penalty_decay, "repeat_penalty_decay", 0.0f,
"Penalty for token n positions ago is decayed by "
"power(repeat_penalty_decay, n).",
1);
visitor(repeat_penalty_range, "repeat_penalty_range", 8.0f,
"Penalty fades out near the end of range (tokens)", 1);
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048}, // Batch configuration:
"Maximum number of tokens to generate."); visitor(prefill_tbatch_size, "prefill_tbatch_size", size_t{2},
"Token batch size for prefill; <= 32", 2);
visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256}, visitor(decode_tbatch_size, "decode_tbatch_size", size_t{1},
"Prefill: max tokens per batch."); "Token batch size for decode (only 1 currently supported)", 2);
visitor(decode_qbatch_size, "decode_qbatch", size_t{16},
"Decode: max queries per batch.");
visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
visitor(top_k, "top_k", size_t{1}, "Number of top-K tokens to sample from",
2);
visitor(deterministic, "deterministic", false,
"Make top-k sampling deterministic", 2);
visitor(multiturn, "multiturn", false,
"Multiturn mode\n 0 = clear KV cache after every "
"interaction\n 1 = continue KV cache after every interaction\n "
" Default : 0 (conversation "
"resets every turn)");
visitor(image_file, "image_file", Path(), "Image file to load.");
visitor( visitor(
eot_line, "eot_line", std::string(""), eot_line, "eot_line", std::string(""),
@ -397,47 +99,123 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
"When you specify this, the prompt will be all lines " "When you specify this, the prompt will be all lines "
"before the line where only the given string appears.\n Default = " "before the line where only the given string appears.\n Default = "
"When a newline is encountered, that signals the end of the turn.", "When a newline is encountered, that signals the end of the turn.",
2); 1);
// Non-interactive mode prompt
visitor(prompt, "prompt", std::string(""),
"Prompt to use in non-interactive mode", 1);
} }
void CopyTo(RuntimeConfig& runtime_config) const { const char* Validate() const {
runtime_config.max_generated_tokens = max_generated_tokens; if (max_generated_tokens == 0 && max_tokens == 0) {
runtime_config.prefill_tbatch_size = prefill_tbatch_size; return "At least one of max_tokens and max_generated_tokens must be > 0";
runtime_config.decode_qbatch_size = decode_qbatch_size;
if (prefill_tbatch_size > MMStorage::kMaxM) {
HWY_ABORT(
"prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, "
"or increase the constant in MMStorage.\n",
prefill_tbatch_size, MMStorage::kMaxM);
} }
if (decode_qbatch_size > MMStorage::kMaxM) { if (temperature <= 0.0) {
HWY_ABORT( return "Temperature must be > 0.0";
"decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, "
"or increase the constant in MMStorage.\n",
decode_qbatch_size, MMStorage::kMaxM);
} }
if (prefill_tbatch_size > 32) {
runtime_config.temperature = temperature; return "prefill_tbatch_size must be <= 32";
runtime_config.top_k = top_k; }
if (decode_tbatch_size != 1) {
return "decode_tbatch_size must be 1";
}
return nullptr;
} }
}; };
static inline void ShowConfig(const ThreadingArgs& threading, // Arguments related to model weights.
const LoaderArgs& loader, struct LoaderArgs : public ArgsBase<LoaderArgs> {
const InferenceArgs& inference) { Path model_path; // Path to directory containing the weights
threading.Print(); Path tokenizer; // Optional: can be derived from model_path
loader.Print(); bool model_is_gemma2;
inference.Print(); Gemma::Config::WeightFormat weight_format;
template <class Visitor>
void ForEach(Visitor& visitor) {
// Each line specifies a variable member, its name, default value, and help.
visitor(model_path, "model", Path{},
"Directory containing weights or config file from `gemma.cpp "
"convert`.",
0);
visitor(tokenizer, "tokenizer", Path{},
"Optional path to tokenizer.model; if empty, looks in model_path.",
2);
visitor(model_is_gemma2, "model_is_gemma2", false,
"Whether the model is a Gemma 2 model", 1);
visitor(weight_format, "format", Gemma::Config::kBfloat16,
"Model weights format: 0=F32, 1=F16, 2=BF16", 2);
} }
static inline void ShowHelp(const ThreadingArgs& threading,
const LoaderArgs& loader, const char* Validate() const {
const InferenceArgs& inference) { if (model_path.path.empty()) {
fprintf(stderr, "\nUsage: gemma [flags]\n\nFlags:\n"); return "Empty model path";
threading.Help();
loader.Help();
inference.Help();
} }
if (weight_format != Gemma::Config::kBfloat16 &&
weight_format != Gemma::Config::kFloat16 &&
weight_format != Gemma::Config::kFloat32) {
return "Invalid weight format";
}
return nullptr;
}
};
// Threading-related arguments.
struct ThreadingArgs : public ArgsBase<ThreadingArgs> {
size_t num_threads;
Tristate pin_threads;
Tristate use_spinning;
int verbosity;
template <class Visitor>
void ForEach(Visitor& visitor) {
visitor(num_threads, "threads", size_t{0},
"Number of threads (0=auto, half of logical cores)", 1);
visitor(pin_threads, "pin_threads", Tristate::kDefault,
"Set to true/false to force enable/disable thread pinning.", 2);
visitor(use_spinning, "use_spinning", Tristate::kDefault,
"Set to true/false to enable/disable thread spinning (typically "
"improves "
"performance but increases power usage)",
2);
visitor(verbosity, "verbosity", 1,
"Controls printing of progress messages to stderr", 1);
}
// Returns nullptr if OK, otherwise error message.
const char* Validate() const { return nullptr; }
// Returns num_threads to use.
size_t NumThreadsToUse() const {
return num_threads == 0 ? (size_t{hwy::NumberOfProcessors()} + 1) / 2
: num_threads;
}
};
// Command-line arguments for PeftGemma and Gemma.
struct GemmaArgs : public ArgsBase<GemmaArgs> {
InferenceArgs inference;
LoaderArgs loader;
ThreadingArgs threading;
// For collect_stats.cc:
Path output;
bool trace_outputs; // For -ftrace and dump_csv.cc
bool trace_base; // For -ftrace
int time_it; // For time_it.cc
template <class Visitor>
void ForEach(Visitor& visitor) {
inference.ForEach(visitor);
loader.ForEach(visitor);
threading.ForEach(visitor);
visitor(output, "output", Path{}, "Where to write CSV data / stats", 2);
visitor(trace_outputs, "trace_outputs", false, "For tracing", 2);
visitor(trace_base, "trace_base", false, "For tracing", 2);
visitor(time_it, "time_it", 0, "For benchmarks", 2);
}
};
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_

View File

@ -78,6 +78,18 @@ std::string GetPrompt(std::istream& input, int verbosity,
return prompt_string; return prompt_string;
} }
// New GetPrompt function that accepts InferenceArgs
std::string GetPrompt(const InferenceArgs& inference, int verbosity,
size_t turn) {
// Check for command-line prompt first
if (!inference.prompt.empty()) {
return inference.prompt;
}
// Use the existing function for interactive mode
return GetPrompt(std::cin, verbosity, inference.eot_line);
}
// The main Read-Eval-Print Loop. // The main Read-Eval-Print Loop.
void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
Gemma& model, KVCache& kv_cache) { Gemma& model, KVCache& kv_cache) {
@ -89,6 +101,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
std::mt19937 gen; std::mt19937 gen;
InitGenerator(inference, gen); InitGenerator(inference, gen);
// Add flag to track non-interactive mode
bool non_interactive_mode = !inference.prompt.empty();
const bool have_image = !inference.image_file.path.empty(); const bool have_image = !inference.image_file.path.empty();
Image image; Image image;
ImageTokens image_tokens; ImageTokens image_tokens;
@ -151,47 +166,30 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
// Read prompt and handle special commands. // Read prompt and handle special commands.
std::string prompt_string = std::string prompt_string =
GetPrompt(std::cin, inference.verbosity, inference.eot_line); GetPrompt(inference, inference.verbosity, abs_pos);
if (!std::cin) return;
if (!std::cin && !non_interactive_mode) return;
// If !eot_line.empty(), we append \n, so only look at the first 2 chars. // If !eot_line.empty(), we append \n, so only look at the first 2 chars.
if (prompt_string.size() >= 2 && prompt_string[0] == '%') { if (!non_interactive_mode && prompt_string.size() >= 2 &&
prompt_string[0] == '%') {
if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return; if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
if (prompt_string[1] == 'c' || prompt_string[1] == 'C') { if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
abs_pos = 0; abs_pos = 0;
continue; continue;
} }
} }
if (prompt_string.empty()) {
if (!non_interactive_mode && prompt_string.empty()) {
std::cout << "Use '%q' to quit.\n"; std::cout << "Use '%q' to quit.\n";
continue; continue;
} }
// Wrap, tokenize and maybe log prompt tokens.
std::vector<int> prompt = WrapAndTokenize(model.Tokenizer(), model.Info(),
abs_pos, prompt_string);
prompt_size = prompt.size();
if constexpr (kVerboseLogTokens) {
for (int i = 0; i < prompt_size; ++i) {
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
}
}
// Set up runtime config.
TimingInfo timing_info = {.verbosity = inference.verbosity};
RuntimeConfig runtime_config = {.gen = &gen,
.verbosity = inference.verbosity,
.stream_token = stream_token,
.use_spinning = threading.spin};
inference.CopyTo(runtime_config);
size_t prefix_end = 0;
std::vector<int> prompt; std::vector<int> prompt;
if (have_image) { if (have_image) {
prompt = prompt =
WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(), WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
abs_pos, prompt_string, image_tokens.BatchSize()); abs_pos, prompt_string, image_tokens.BatchSize());
runtime_config.image_tokens = &image_tokens;
prompt_size = prompt.size();
// The end of the prefix for prefix-LM style attention in Paligemma. // The end of the prefix for prefix-LM style attention in Paligemma.
// See Figure 2 of https://arxiv.org/abs/2407.07726. // See Figure 2 of https://arxiv.org/abs/2407.07726.
prefix_end = prompt_size; prefix_end = prompt_size;
@ -209,6 +207,24 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
} }
} }
// Set up runtime config.
TimingInfo timing_info = {.verbosity = inference.verbosity};
RuntimeConfig runtime_config = {.gen = &gen,
.verbosity = inference.verbosity,
.stream_token = stream_token,
.use_spinning = threading.spin};
inference.CopyTo(runtime_config);
size_t prefix_end = 0;
if (have_image) {
runtime_config.image_tokens = &image_tokens;
prompt_size = prompt.size();
// The end of the prefix for prefix-LM style attention in Paligemma.
prefix_end = prompt_size;
// We need to look at all the tokens for the prefix.
runtime_config.prefill_tbatch_size = prompt_size;
}
// Generate until EOS or max_generated_tokens. // Generate until EOS or max_generated_tokens.
if (inference.verbosity >= 1) { if (inference.verbosity >= 1) {
std::cerr << "\n[ Reading prompt ] " << std::flush; std::cerr << "\n[ Reading prompt ] " << std::flush;
@ -217,6 +233,11 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
timing_info); timing_info);
std::cout << "\n\n"; std::cout << "\n\n";
// Break the loop if in non-interactive mode
if (non_interactive_mode) {
break;
}
// Prepare for the next turn. Works only for PaliGemma. // Prepare for the next turn. Works only for PaliGemma.
if (!inference.multiturn || if (!inference.multiturn ||
model.Info().wrapping == PromptWrapping::PALIGEMMA) { model.Info().wrapping == PromptWrapping::PALIGEMMA) {
@ -249,22 +270,6 @@ void Run(ThreadingArgs& threading, LoaderArgs& loader,
KVCache kv_cache = KVCache kv_cache =
KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size); KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);
if (!threading.prompt.empty()) {
std::vector<int> prompt =
WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
0, threading.prompt);
TimingInfo timing_info = {.verbosity = inference.verbosity};
RuntimeConfig runtime_config = {.gen = nullptr, // Use default generator
.verbosity = inference.verbosity,
.use_spinning = threading.spin};
inference.CopyTo(runtime_config);
model.Generate(runtime_config, prompt, 0, 0, kv_cache, timing_info);
std::cout << "\n";
return; // Exit after generating response
}
if (inference.verbosity >= 1) { if (inference.verbosity >= 1) {
std::string instructions = std::string instructions =
"*Usage*\n" "*Usage*\n"
@ -286,11 +291,14 @@ void Run(ThreadingArgs& threading, LoaderArgs& loader,
instructions += multiturn; instructions += multiturn;
instructions += examples; instructions += examples;
// Skip the banner and instructions in non-interactive mode
if (inference.prompt.empty()) {
std::cout << "\033[2J\033[1;1H" // clear screen std::cout << "\033[2J\033[1;1H" // clear screen
<< kAsciiArtBanner << "\n\n"; << kAsciiArtBanner << "\n\n";
ShowConfig(threading, loader, inference); ShowConfig(threading, loader, inference);
std::cout << "\n" << instructions << "\n"; std::cout << "\n" << instructions << "\n";
} }
}
ReplGemma(threading, inference, model, kv_cache); ReplGemma(threading, inference, model, kv_cache);
} }