mirror of https://github.com/google/gemma.cpp.git
Add --prompt flag for non-interactive mode
parent 716713f0e6
commit cbf179990f
596  README.md
@@ -1,27 +1,583 @@

# gemma.cpp

gemma.cpp is a lightweight, standalone C++ inference engine for the Gemma
foundation models from Google.

For additional information about Gemma, see
[ai.google.dev/gemma](https://ai.google.dev/gemma). Model weights, including
gemma.cpp specific artifacts, are
[available on kaggle](https://www.kaggle.com/models/google/gemma).

## Who is this project for?

Modern LLM inference engines are sophisticated systems, often with bespoke
capabilities extending beyond traditional neural network runtimes. With this
come opportunities for research and innovation through co-design of high-level
algorithms and low-level computation. However, there is a gap between
deployment-oriented C++ inference runtimes, which are not designed for
experimentation, and Python-centric ML research frameworks, which abstract away
low-level computation through compilation.

gemma.cpp provides a minimalist implementation of Gemma-1, Gemma-2, Gemma-3, and
PaliGemma models, focusing on simplicity and directness rather than full
generality. This is inspired by vertically-integrated model implementations such
as [ggml](https://github.com/ggerganov/ggml),
[llama.c](https://github.com/karpathy/llama2.c), and
[llama.rs](https://github.com/srush/llama2.rs).

gemma.cpp targets experimentation and research use cases. It is intended to be
straightforward to embed in other projects with minimal dependencies and also
easily modifiable with a small ~2K LoC core implementation (along with ~4K LoC
of supporting utilities). We use the [Google
Highway](https://github.com/google/highway) Library to take advantage of
portable SIMD for CPU inference.

For production-oriented edge deployments we recommend standard deployment
pathways using Python frameworks like JAX, Keras, PyTorch, and Transformers
([all model variations here](https://www.kaggle.com/models/google/gemma)).
## Contributing

Community contributions large and small are welcome. See
[DEVELOPERS.md](https://github.com/google/gemma.cpp/blob/main/DEVELOPERS.md)
for additional notes for contributing developers, and [join the discord by
following this invite link](https://discord.gg/H5jCBAWxAe). This project follows
[Google's Open Source Community
Guidelines](https://opensource.google.com/conduct/).

*Active development is currently done on the `dev` branch. Please open pull
requests targeting the `dev` branch instead of `main`, which is intended to be
more stable.*
## Quick Start

### System requirements

Before starting, you should have installed:

- [CMake](https://cmake.org/)
- [Clang C++ compiler](https://clang.llvm.org/get_started.html), supporting at
  least C++17.
- `tar` for extracting archives from Kaggle.

Building natively on Windows requires the Visual Studio 2022 Build Tools with the
optional Clang/LLVM C++ frontend (`clang-cl`). This can be installed from the
command line with
[`winget`](https://learn.microsoft.com/en-us/windows/package-manager/winget/):

```sh
winget install --id Kitware.CMake
winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;installRecommended --add Microsoft.VisualStudio.Component.VC.Llvm.Clang --add Microsoft.VisualStudio.Component.VC.Llvm.ClangToolset"
```
### Step 1: Obtain model weights and tokenizer from Kaggle or Hugging Face Hub

Visit the
[Kaggle page for Gemma-2](https://www.kaggle.com/models/google/gemma-2/gemmaCpp)
[or Gemma-1](https://www.kaggle.com/models/google/gemma/frameworks/gemmaCpp),
and select `Model Variations |> Gemma C++`.

On this tab, the `Variation` dropdown includes the options below. Note bfloat16
weights are higher fidelity, while 8-bit switched floating point weights enable
faster inference. In general, we recommend starting with the `-sfp` checkpoints.

If you are unsure which model to start with, we recommend starting with the
smallest Gemma-2 model, i.e. `2.0-2b-it-sfp`.

Alternatively, visit the
[gemma.cpp](https://huggingface.co/models?other=gemma.cpp) models on the Hugging
Face Hub. First go to the model repository of the model of interest (see
recommendations below). Then, click the `Files and versions` tab and download
the model and tokenizer files. For programmatic downloading, if you have
`huggingface_hub` installed, you can also download by running:

```
huggingface-cli login # Just the first time
huggingface-cli download google/gemma-2b-sfp-cpp --local-dir build/
```
Gemma-1 2B instruction-tuned (`it`) and pre-trained (`pt`) models:

| Model name  | Description |
| ----------- | ----------- |
| `2b-it`     | 2 billion parameter instruction-tuned model, bfloat16 |
| `2b-it-sfp` | 2 billion parameter instruction-tuned model, 8-bit switched floating point |
| `2b-pt`     | 2 billion parameter pre-trained model, bfloat16 |
| `2b-pt-sfp` | 2 billion parameter pre-trained model, 8-bit switched floating point |

Gemma-1 7B instruction-tuned (`it`) and pre-trained (`pt`) models:

| Model name  | Description |
| ----------- | ----------- |
| `7b-it`     | 7 billion parameter instruction-tuned model, bfloat16 |
| `7b-it-sfp` | 7 billion parameter instruction-tuned model, 8-bit switched floating point |
| `7b-pt`     | 7 billion parameter pre-trained model, bfloat16 |
| `7b-pt-sfp` | 7 billion parameter pre-trained model, 8-bit switched floating point |

> [!NOTE]
> **Important**: We strongly recommend starting off with the `2b-it-sfp` model to
> get up and running.

Gemma-2 models are named `gemma2-2b-it` for 2B, and `9b-it` or `27b-it` for the
larger variants. See the `kModelFlags` definition in `common.cc`.
### Step 2: Extract Files

If you downloaded the models from Hugging Face, skip to step 3.

After filling out the consent form, the download should proceed to retrieve a
tar archive file `archive.tar.gz`. Extract the files from `archive.tar.gz` (this
can take a few minutes):

```
tar -xf archive.tar.gz
```

This should produce a file containing model weights such as `2b-it-sfp.sbs` and
a tokenizer file (`tokenizer.spm`). You may want to move these files to a
convenient directory location (e.g. the `build/` directory in this repo).

### Step 3: Build

The build system uses [CMake](https://cmake.org/). To build the gemma inference
runtime, create a build directory and generate the build files using `cmake`
from the top-level project directory. Note that if you previously ran `cmake`
and are re-running with a different setting, be sure to delete all files in the
`build/` directory with `rm -rf build/*`.
#### Unix-like Platforms

```sh
cmake -B build
```

After running `cmake`, you can enter the `build/` directory and run `make` to
build the `./gemma` executable:

```sh
# Configure `build` directory
cmake --preset make

# Build project using make
cmake --build --preset make -j [number of parallel threads to use]
```

Replace `[number of parallel threads to use]` with a number - the number of
cores available on your system is a reasonable heuristic. For example,
`make -j4 gemma` will build using 4 threads. If the `nproc` command is
available, you can use `make -j$(nproc) gemma` as a reasonable default
for the number of threads.

If you aren't sure of the right value for the `-j` flag, you can simply run
`make gemma` instead and it should still build the `./gemma` executable.

> [!NOTE]
> On Windows Subsystem for Linux (WSL), users should set the number of
> parallel threads to 1. Using a larger number may result in errors.

If the build is successful, you should now have a `gemma` executable in the
`build/` directory.
#### Windows

```sh
# Configure `build` directory
cmake --preset windows

# Build project using Visual Studio Build Tools
cmake --build --preset windows -j [number of parallel threads to use]
```

If the build is successful, you should now have a `gemma.exe` executable in the
`build/` directory.

#### Bazel

```sh
bazel build -c opt --cxxopt=-std=c++20 :gemma
```

If the build is successful, you should now have a `gemma` executable in the
`bazel-bin/` directory.

#### Make

If you prefer Makefiles, @jart has made one available here:

https://github.com/jart/gemma3/blob/main/Makefile
### Step 4: Run

You can now run `gemma` from inside the `build/` directory.

`gemma` has the following required arguments:

Argument        | Description                  | Example value
--------------- | ---------------------------- | -----------------------
`--model`       | The model type.              | `2b-it` ... (see below)
`--weights`     | The compressed weights file. | `2b-it-sfp.sbs`
`--weight_type` | The compressed weight type.  | `sfp`
`--tokenizer`   | The tokenizer file.          | `tokenizer.spm`

`gemma` is invoked as:

```sh
./gemma \
--tokenizer [tokenizer file] \
--weights [compressed weights file] \
--weight_type [f32 or bf16 or sfp (default:sfp)] \
--model [2b-it or 2b-pt or 7b-it or 7b-pt or ...]
```

Example invocation for the following configuration:

- Compressed weights file `2b-it-sfp.sbs` (2B instruction-tuned model, 8-bit
  switched floating point).
- Tokenizer file `tokenizer.spm`.

```sh
./gemma \
--tokenizer tokenizer.spm \
--weights 2b-it-sfp.sbs --model 2b-it
```
### RecurrentGemma

This repository includes a version of Gemma based on Griffin
([paper](https://arxiv.org/abs/2402.19427),
[code](https://github.com/google-deepmind/recurrentgemma)). Its architecture
includes both recurrent layers and local attention, so it is more efficient for
longer sequences and has a smaller memory footprint than standard Gemma. We
provide a C++ implementation of this model here, based on the paper.

To use the recurrent version of Gemma included in this repository, build the
gemma binary as noted above in Step 3. Download the compressed weights and
tokenizer from the RecurrentGemma
[Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) page as
in Step 1, and run the binary as follows:

`./gemma --tokenizer tokenizer.spm --model gr2b-it --weights 2b-it-sfp.sbs`
### PaliGemma Vision-Language Model

This repository includes a version of the PaliGemma VLM
([paper](https://arxiv.org/abs/2407.07726),
[code](https://github.com/google-research/big_vision/tree/main/big_vision/configs/proj/paligemma))
and its successor PaliGemma 2 ([paper](https://arxiv.org/abs/2412.03555)). We
provide a C++ implementation of the PaliGemma model family here.

To use the version of PaliGemma included in this repository, build the gemma
binary as noted above in Step 3. Download the compressed weights and tokenizer
from
[Kaggle](https://www.kaggle.com/models/google/paligemma/gemmaCpp/paligemma-3b-mix-224)
and run the binary as follows:

```sh
./gemma \
--tokenizer paligemma_tokenizer.model \
--model paligemma-224 \
--weights paligemma-3b-mix-224-sfp.sbs \
--image_file paligemma/testdata/image.ppm
```

Note that the image reading code is very basic to avoid depending on an image
processing library for now. We currently only support reading binary PPMs (P6),
so use a tool like `convert` to first convert your images into that format, e.g.

`convert image.jpeg -resize 224x224^ image.ppm`

(As the image will be resized for processing anyway, we can already resize at
this stage for slightly faster loading.)

The interaction with the image (using the mix-224 checkpoint) may then look
something like this:

```
> Describe the image briefly
A large building with two towers in the middle of a city.
> What type of building is it?
church
> What color is the church?
gray
> caption image
A large building with two towers stands tall on the water's edge. The building
has a brown roof and a window on the side. A tree stands in front of the
building, and a flag waves proudly from its top. The water is calm and blue,
reflecting the sky above. A bridge crosses the water, and a red and white boat
rests on its surface. The building has a window on the side, and a flag on top.
A tall tree stands in front of the building, and a window on the building is
visible from the water. The water is green, and the sky is blue.
```
### Migrating to single-file format

There is now a new format for the weights file: a single file that can also
contain the tokenizer (and the model type) directly. A tool to migrate from the
multi-file format to the single-file format is available:

```sh
compression/migrate_weights \
  --tokenizer .../tokenizer.spm --weights .../gemma2-2b-it-sfp.sbs \
  --model gemma2-2b-it --output_weights .../gemma2-2b-it-sfp-single.sbs
```

After migration, you can use the new weights file with gemma.cpp like this:

```sh
./gemma --weights .../gemma2-2b-it-sfp-single.sbs
```
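The same simplification carries over when gemma.cpp is embedded as a library
(see the library sections below). Judging from the `CreateGemma` helper in the
`gemma_args.h` listing later on this page, migrated single-file `.sbs` weights
can be loaded on their own, whereas the multi-file format still needs the
tokenizer path and model info. A minimal sketch, assuming the constructors shown
in that listing (signatures may differ between gemma.cpp versions):

```cpp
#include "gemma/gemma.h"  // gcpp::Gemma

// Sketch only; mirrors the CreateGemma() helper shown later on this page.
gcpp::Gemma LoadWeights(const gcpp::Path& weights, const gcpp::Path& tokenizer,
                        const gcpp::ModelInfo& info, gcpp::MatMulEnv& env) {
  if (tokenizer.path.empty()) {
    // Single-file format: tokenizer and model type are embedded in the weights.
    return gcpp::Gemma(weights, env);
  }
  // Multi-file format: tokenizer path and model/weight info are still required.
  return gcpp::Gemma(tokenizer, weights, info, env);
}
```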
### Troubleshooting and FAQs

**Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**

The most common problem is that the `--weight_type` argument does not match that
of the model file. Revisit step #3 and check which weights you downloaded.

Note that we have already moved weight type from a compile-time decision to a
runtime argument. In a subsequent step, we plan to bake this information into
the weights.

**Problems building in Windows / Visual Studio**

Currently if you're using Windows, we recommend building in WSL (Windows
Subsystem for Linux). We are exploring options to enable other build
configurations, see issues for active discussion.

**Model does not respond to instructions and produces strange output**

A common issue is that you are using a pre-trained model, which is not
instruction-tuned and thus does not respond to instructions. Make sure you are
using an instruction-tuned model (`2b-it-sfp`, `2b-it`, `7b-it-sfp`, `7b-it`)
and not a pre-trained model (any model with a `-pt` suffix).

**What sequence lengths are supported?**

See `seq_len` in `configs.cc`. For the Gemma 3 models larger than 1B, this is
typically 32K, but 128K would also work given enough RAM. Note that long
sequences will be slow due to the quadratic cost of attention.

**How do I convert my fine-tune to a `.sbs` compressed model file?**

For PaliGemma (1 and 2) checkpoints, you can use
`python/convert_from_safetensors.py` to convert from safetensors format (tested
with building via bazel). For an adapter model, you will likely need to call
`merge_and_unload()` to convert the adapter model to a single-file format before
converting it.

Here is how to use it with a bazel build of the compression library, assuming
locally installed (venv) torch, numpy, safetensors, absl-py, etc.:

```sh
bazel build //compression/python:compression
BAZEL_OUTPUT_DIR="${PWD}/bazel-bin/compression"
python3 -c "import site; print(site.getsitepackages())"
# Use your site-packages path here:
ln -s $BAZEL_OUTPUT_DIR [...]/site-packages/compression
python3 python/convert_from_safetensors.py --load_path [...].safetensors.index.json
```

See also `compression/convert_weights.py` for a slightly older option to convert
a pytorch checkpoint. (The code may need updates to work with Gemma-2 models.)
**What are some easy ways to make the model run faster?**

1. Make sure you are using the 8-bit switched floating point `-sfp` models.
   These are half the size of bf16 and thus use less memory bandwidth and cache
   space.
2. If you're on a laptop, make sure power mode is set to maximize performance
   and power-saving mode is **off**. For most laptops, the power saving modes
   get activated automatically if the computer is not plugged in.
3. Close other unused CPU-intensive applications.
4. On Macs, anecdotally we observe a "warm-up" ramp-up in speed as performance
   cores get engaged.
5. Experiment with the `--num_threads` argument value. Depending on the device,
   larger numbers don't always mean better performance.

We're also working on algorithmic and optimization approaches for faster
inference; stay tuned.
## Usage

`gemma` has different usage modes, controlled by the verbosity flag.

All usage modes are currently interactive, triggering text generation upon
newline input.

| Verbosity       | Usage mode | Details                                       |
| --------------- | ---------- | --------------------------------------------- |
| `--verbosity 0` | Minimal    | Only prints generation output. Suitable as a CLI tool. |
| `--verbosity 1` | Default    | Standard user-facing terminal UI.             |
| `--verbosity 2` | Detailed   | Shows additional developer and debug info.    |
### Interactive Terminal App

By default, verbosity is set to 1, bringing up a terminal-based interactive
interface when `gemma` is invoked:

```console
$ ./gemma [...]
  __ _  ___ _ __ ___  _ __ ___   __ _   ___ _ __  _ __
 / _` |/ _ \ '_ ` _ \| '_ ` _ \ / _` | / __| '_ \| '_ \
| (_| |  __/ | | | | | | | | | | (_| || (__| |_) | |_) |
 \__, |\___|_| |_| |_|_| |_| |_|\__,_(_)___| .__/| .__/
  __/ |                                    | |   | |
 |___/                                     |_|   |_|

tokenizer                     : tokenizer.spm
compressed_weights            : 2b-it-sfp.sbs
model                         : 2b-it
weights                       : [no path specified]
max_generated_tokens          : 2048

*Usage*
  Enter an instruction and press enter (%C reset conversation, %Q quits).

*Examples*
  - Write an email to grandma thanking her for the cookies.
  - What are some historical attractions to visit around Massachusetts?
  - Compute the nth fibonacci number in javascript.
  - Write a standup comedy bit about WebGPU programming.

> What are some outdoorsy places to visit around Boston?

[ Reading prompt ] .....................

**Boston Harbor and Islands:**

* **Boston Harbor Islands National and State Park:** Explore pristine beaches, wildlife, and maritime history.
* **Charles River Esplanade:** Enjoy scenic views of the harbor and city skyline.
* **Boston Harbor Cruise Company:** Take a relaxing harbor cruise and admire the city from a different perspective.
* **Seaport Village:** Visit a charming waterfront area with shops, restaurants, and a seaport museum.

**Forest and Nature:**

* **Forest Park:** Hike through a scenic forest with diverse wildlife.
* **Quabbin Reservoir:** Enjoy boating, fishing, and hiking in a scenic setting.
* **Mount Forest:** Explore a mountain with breathtaking views of the city and surrounding landscape.

...
```
### Usage as a Command Line Tool

For using the `gemma` executable as a command line tool, it may be useful to
create an alias for gemma.cpp with arguments fully specified:

```sh
alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --weights ~/gemma.cpp/build/gemma2-2b-it-sfp.sbs --model gemma2-2b-it --verbosity 0"
```

Replace the above paths with your own paths to the model weights and tokenizer
from the download.

Here is an example of prompting `gemma` with a truncated input file (using a
`gemma2b` alias as defined above):

```sh
cat configs.h | tail -n 35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b
```

> [!NOTE]
> CLI usage of gemma.cpp is experimental and should take context length
> limitations into account.

The output of the above command should look like:

```console
[ Reading prompt ] [...]
This C++ code snippet defines a set of **constants** used in a large language model (LLM) implementation, likely related to the **attention mechanism**.

Let's break down the code:
[...]
```
### Incorporating gemma.cpp as a Library in your Project

The easiest way to incorporate gemma.cpp in your own project is to pull in
gemma.cpp and its dependencies using `FetchContent`. You can add the following
to your CMakeLists.txt:

```
include(FetchContent)

FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
FetchContent_MakeAvailable(sentencepiece)

FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
FetchContent_MakeAvailable(gemma)

FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
FetchContent_MakeAvailable(highway)
```

Note that for the gemma.cpp `GIT_TAG`, you may replace `origin/main` with a
specific commit hash if you would like to pin the library version.

After your executable is defined (substitute your executable name for
`[Executable Name]` below):

```
target_link_libraries([Executable Name] libgemma hwy hwy_contrib sentencepiece)
FetchContent_GetProperties(gemma)
FetchContent_GetProperties(sentencepiece)
target_include_directories([Executable Name] PRIVATE ${gemma_SOURCE_DIR})
target_include_directories([Executable Name] PRIVATE ${sentencepiece_SOURCE_DIR})
```
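Once linked, the model can be constructed directly from C++. The following is a
minimal sketch based on the `LoaderArgs`, `MakeMatMulEnv`, and `CreateGemma`
helpers shown in the `gemma_args.h` listing further down this page; exact names
and signatures vary between gemma.cpp versions, so treat it as an illustration
rather than a stable API.

```cpp
#include "gemma/gemma.h"       // gcpp::Gemma
#include "gemma/gemma_args.h"  // gcpp::LoaderArgs, gcpp::ThreadingArgs (see listing below)

int main() {
  // Paths as produced by Steps 1 and 2 above; adjust to your own locations.
  gcpp::LoaderArgs loader("tokenizer.spm", "2b-it-sfp.sbs", "2b-it");
  gcpp::ThreadingArgs threading;  // defaults: auto-detected threads

  // MatMulEnv wraps the thread pools used for inference; it must outlive the model.
  gcpp::MatMulEnv env = gcpp::MakeMatMulEnv(threading);
  gcpp::Gemma model = gcpp::CreateGemma(loader, env);

  // Tokenization and generation then go through model.Tokenizer() and the
  // RuntimeConfig/KVCache machinery shown in gemma/run.cc below.
  return 0;
}
```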
### Building gemma.cpp as a Library

gemma.cpp can also be used as a library dependency in your own project. The
shared library artifact can be built by modifying the make invocation to build
the `libgemma` target instead of `gemma`.

> [!NOTE]
> If you are using gemma.cpp in your own project with the `FetchContent` steps
> in the previous section, building the library is done automatically by `cmake`
> and this section can be skipped.

First, run `cmake`:

```sh
cmake -B build
```

Then, run `make` with the `libgemma` target:

```sh
cd build
make -j [number of parallel threads to use] libgemma
```

If this is successful, you should now have a `libgemma` library file in the
`build/` directory. On Unix platforms, the filename is `libgemma.a`.
## Independent Projects Using gemma.cpp

Some independent projects using gemma.cpp:

- [gemma-cpp-python - Python bindings](https://github.com/namtranase/gemma-cpp-python)
- [lua-cgemma - Lua bindings](https://github.com/ufownl/lua-cgemma)
- [Godot engine demo project](https://github.com/Rliop913/Gemma-godot-demo-project)

If you would like to have your project included, feel free to get in touch or
submit a PR with a `README.md` edit.

## Acknowledgements and Contacts

gemma.cpp was started in fall 2023 by [Austin Huang](mailto:austinvhuang@google.com)
and [Jan Wassenberg](mailto:janwas@google.com), and subsequently released in
February 2024 thanks to contributions from Phil Culliton, Paul Chang, and Dan
Zheng.

Griffin support was implemented in April 2024 thanks to contributions by Andrey
Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode
Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas
Fischbacher and Zoltan Szabadka.

Gemma-2 support was implemented in June/July 2024 with the help of several
people.

PaliGemma support was implemented in September 2024 with contributions from
Daniel Keysers.

[Jan Wassenberg](mailto:janwas@google.com) has continued to contribute many
improvements, including major gains in efficiency, since the initial release.

This is not an officially supported Google product.
gemma/gemma_args.h

@@ -13,383 +13,85 @@

// See the License for the specific language governing permissions and
// limitations under the License.

// Shared between various frontends.
// Argument parsing for Gemma.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_

#include <stddef.h>
#include <stdio.h>

#include <memory>
#include <string>

#include "compression/io.h"  // Path
#include "compression/shared.h"
#include "gemma/common.h"
#include "gemma/gemma.h"  // For CreateGemma
#include "hwy/base.h"  // HWY_IS_ASAN, HWY_ABORT
#include "hwy/base.h"  // HWY_ABORT
#include "ops/matmul.h"
#include "util/allocator.h"
#include "util/args.h"
#include "util/basics.h"  // Tristate
#include "util/threading.h"
#include "util/threading_context.h"

namespace gcpp {
static inline const char* CompiledConfig() {
// Arguments related to inference: sampling, text etc.
  if (HWY_IS_ASAN) {
    return "asan";
  } else if (HWY_IS_MSAN) {
    return "msan";
  } else if (HWY_IS_TSAN) {
    return "tsan";
  } else if (HWY_IS_HWASAN) {
    return "hwasan";
  } else if (HWY_IS_UBSAN) {
    return "ubsan";
  } else if (HWY_IS_DEBUG_BUILD) {
    return "dbg";
  } else {
    return "opt";
  }
}

template <typename Derived>
struct ArgsBase {
  void Init() { static_cast<Derived*>(this)->ForEach(SetToDefault()); }

  void InitAndParse(int argc, char* argv[]) {
    Init();
    static_cast<Derived*>(this)->ForEach(ParseOption(argc, argv));
  }

  void Print(int min_verbosity = 1) const {
    static_cast<const Derived*>(this)->ForEach(PrintOption(min_verbosity));
  }

  void Help() const { static_cast<const Derived*>(this)->ForEach(PrintHelp()); }

 protected:
  // Helper struct for printing help messages
  struct PrintHelp {
    template <typename T>
    void operator()(const T& value, const char* name, const T& default_value,
                    const char* description, int verbosity = 1) const {
      fprintf(stderr, "  --%s\n    %s\n", name, description);
    }
    // Special case for strings to avoid template deduction issues
    void operator()(const std::string& value, const char* name,
                    const std::string& default_value, const char* description,
                    int verbosity = 1) const {
      fprintf(stderr, "  --%s\n    %s\n", name, description);
    }
    // Special case for Path type
    void operator()(const Path& value, const char* name,
                    const Path& default_value, const char* description,
                    int verbosity = 1) const {
      fprintf(stderr, "  --%s\n    %s\n", name, description);
    }
  };

  // Helper struct for setting default values
  struct SetToDefault {
    template <typename T>
    void operator()(T& value, const char* name, const T& default_value,
                    const char* description, int verbosity = 1) const {
      value = default_value;
    }
  };

  // Helper struct for printing values
  struct PrintOption {
    explicit PrintOption(int min_verbosity) : min_verbosity_(min_verbosity) {}

    template <typename T>
    void operator()(const T& value, const char* name, const T& default_value,
                    const char* description, int verbosity = 1) const {
      if (verbosity >= min_verbosity_) {
        fprintf(stderr, "%s: %s\n", name, ToString(value).c_str());
      }
    }

   private:
    int min_verbosity_;

    // Helper function to convert values to string
    template <typename T>
    static std::string ToString(const T& value) {
      return std::to_string(value);
    }
    // Specialization for string
    static std::string ToString(const std::string& value) { return value; }
    // Specialization for Path
    static std::string ToString(const Path& value) { return value.path; }
  };
};
struct ThreadingArgs : public ArgsBase<ThreadingArgs> {
 public:
  ThreadingArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
  ThreadingArgs() { Init(); };

  int verbosity;

  size_t max_threads;  // divided among the detected clusters
  Tristate pin;        // pin threads?
  Tristate spin;       // use spin waits?

  // For BoundedSlice:
  size_t skip_packages;
  size_t max_packages;
  size_t skip_clusters;
  size_t max_clusters;
  size_t skip_lps;
  size_t max_lps;

  std::string eot_line;
  std::string prompt;

  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(verbosity, "verbosity", 1,
            "Show verbose developer information\n    0 = only print generation "
            "output\n    1 = standard user-facing terminal ui\n    2 = show "
            "developer/debug info).\n    Default = 1.",
            2);

    // The exact meaning is more subtle: see the comment at NestedPools ctor.
    visitor(max_threads, "num_threads", size_t{0},
            "Maximum number of threads to use; default 0 = unlimited.", 2);
    visitor(pin, "pin", Tristate::kDefault,
            "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
    visitor(spin, "spin", Tristate::kDefault,
            "Use spin waits? -1 = auto, 0 = no, 1 = yes.", 2);
    // These can be used to partition CPU sockets/packages and their
    // clusters/CCXs across several program instances. The default is to use
    // all available resources.
    visitor(skip_packages, "skip_packages", size_t{0},
            "Index of the first socket to use; default 0 = unlimited.", 2);
    visitor(max_packages, "max_packages", size_t{0},
            "Maximum number of sockets to use; default 0 = unlimited.", 2);
    visitor(skip_clusters, "skip_clusters", size_t{0},
            "Index of the first CCX to use; default 0 = unlimited.", 2);
    visitor(max_clusters, "max_clusters", size_t{0},
            "Maximum number of CCXs to use; default 0 = unlimited.", 2);
    // These are only used when CPU topology is unknown.
    visitor(skip_lps, "skip_lps", size_t{0},
            "Index of the first LP to use; default 0 = unlimited.", 2);
    visitor(max_lps, "max_lps", size_t{0},
            "Maximum number of LPs to use; default 0 = unlimited.", 2);

    visitor(
        eot_line, "eot_line", std::string(""),
        "End of turn line. "
        "When you specify this, the prompt will be all lines "
        "before the line where only the given string appears.\n    Default = "
        "When a newline is encountered, that signals the end of the turn.",
        2);

    visitor(prompt, "prompt", std::string(""),
            "Prompt string for non-interactive mode. When provided, the model "
            "generates a response and exits.",
            2);
  }
};

static inline BoundedTopology CreateTopology(const ThreadingArgs& threading) {
  return BoundedTopology(
      BoundedSlice(threading.skip_packages, threading.max_packages),
      BoundedSlice(threading.skip_clusters, threading.max_clusters),
      BoundedSlice(threading.skip_lps, threading.max_lps));
}
static inline MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading) {
  ThreadingContext2::SetArgs(threading);
  return MatMulEnv(ThreadingContext2::Get());
}

// Note: These functions may need adjustments depending on your specific class
// definitions.
static inline BoundedTopology CreateTopology(const ThreadingArgs& app) {
  return BoundedTopology(BoundedSlice(app.skip_packages, app.max_packages),
                         BoundedSlice(app.skip_clusters, app.max_clusters),
                         BoundedSlice(app.skip_lps, app.max_lps));
}

// This function may need to be adjusted based on your NestedPools constructor
// signature.
static inline NestedPools CreatePools(const BoundedTopology& topology,
                                      const ThreadingArgs& threading) {
  // Make sure Allocator::Init() is properly declared/defined.
  const Allocator2& allocator = ThreadingContext2::Get().allocator;
  // Allocator::Init(topology);

  // Adjust the constructor call based on your actual NestedPools constructor.
  // The error suggests that the constructor doesn't match these arguments.
  return NestedPools(topology, allocator, threading.max_threads, threading.pin);
  // Alternative: return NestedPools(topology, app.max_threads, app.pin);
}

struct LoaderArgs : public ArgsBase<LoaderArgs> {
  LoaderArgs(int argc, char* argv[], bool validate = true) {
    InitAndParse(argc, argv);

    if (validate) {
      if (const char* error = Validate()) {
        HWY_ABORT("Invalid args: %s", error);
      }
    }
  }
  LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
             const std::string& model, bool validate = true) {
    Init();  // Init sets to defaults, so assignments must come after Init().
    tokenizer.path = tokenizer_path;
    weights.path = weights_path;
    model_type_str = model;

    if (validate) {
      if (const char* error = Validate()) {
        HWY_ABORT("Invalid args: %s", error);
      }
    }
  };

  // Returns error string or nullptr if OK.
  const char* Validate() {
    if (weights.path.empty()) {
      return "Missing --weights flag, a file for the model weights.";
    }
    if (!weights.Exists()) {
      return "Can't open file specified with --weights flag.";
    }
    info_.model = Model::UNKNOWN;
    info_.wrapping = PromptWrapping::GEMMA_PT;
    info_.weight = Type::kUnknown;
    if (!model_type_str.empty()) {
      const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
                                                  info_.wrapping);
      if (err != nullptr) return err;
    }
    if (!weight_type_str.empty()) {
      const char* err = ParseType(weight_type_str, info_.weight);
      if (err != nullptr) return err;
    }
    if (!tokenizer.path.empty()) {
      if (!tokenizer.Exists()) {
        return "Can't open file specified with --tokenizer flag.";
      }
    }
    // model_type and tokenizer must be either both present or both absent.
    // Further checks happen on weight loading.
    if (model_type_str.empty() != tokenizer.path.empty()) {
      return "Missing or extra flags for model_type or tokenizer.";
    }
    return nullptr;
  }

  Path tokenizer;
  Path weights;  // weights file location
  Path compressed_weights;
  std::string model_type_str;
  std::string weight_type_str;

  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(tokenizer, "tokenizer", Path(),
            "Path name of tokenizer model file.");
    visitor(weights, "weights", Path(),
            "Path name of model weights (.sbs) file.\n    Required argument.\n");
    visitor(compressed_weights, "compressed_weights", Path(),
            "Deprecated alias for --weights.");
    visitor(model_type_str, "model", std::string(),
            "Model type, see common.cc for valid values.\n");
    visitor(weight_type_str, "weight_type", std::string("sfp"),
            "Weight type\n    f32 = float, bf16 = bfloat16, sfp = 8-bit SFP.");
  }

  // Uninitialized before Validate, must call after that.
  const ModelInfo& Info() const { return info_; }

 private:
  // TODO(rays): remove this. Eventually ModelConfig will be loaded from the
  // weights file, so we can remove the need for this struct entirely.
  ModelInfo info_;
};

// `env` must remain valid for the lifetime of the Gemma.
static inline Gemma CreateGemma(const LoaderArgs& loader, MatMulEnv& env) {
  if (Type::kUnknown == loader.Info().weight ||
      Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
    // New weights file format doesn't need tokenizer path or model/weight info.
    return Gemma(loader.weights, env);
  }
  return Gemma(loader.tokenizer, loader.weights, loader.Info(), env);
}

// `env` must remain valid for the lifetime of the Gemma.
static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
                                                   MatMulEnv& env) {
  if (Type::kUnknown == loader.Info().weight ||
      Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
    // New weights file format doesn't need tokenizer path or model/weight info.
    return std::make_unique<Gemma>(loader.weights, env);
  }
  return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
                                 loader.Info(), env);
}
struct InferenceArgs : public ArgsBase<InferenceArgs> {
  InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
  // Arguments for getc-like interfaces
  InferenceArgs() { Init(); };
  size_t max_tokens;

  int verbosity;

  size_t max_generated_tokens;

  size_t prefill_tbatch_size;
  size_t decode_qbatch_size;

  float temperature;
  size_t top_k;
  bool deterministic;
  float top_p;
  bool multiturn;
  float min_p;
  Path image_file;
  int repeat_penalty_power;
  float repeat_penalty_presence;
  float repeat_penalty_decay;
  float repeat_penalty_range;

  // Batch configuration:
  size_t prefill_tbatch_size;
  size_t decode_tbatch_size;

  // Non-interactive mode prompt
  std::string prompt;
  std::string eot_line;

  // Returns error string or nullptr if OK.
  const char* Validate() const {
    if (max_generated_tokens > gcpp::kSeqLen) {
      return "max_generated_tokens is larger than the maximum sequence length "
             "(see configs.h).";
    }
    return nullptr;
  }

  template <class Visitor>
  void ForEach(const Visitor& visitor) {
  void ForEach(Visitor& visitor) {
    visitor(verbosity, "verbosity", 1,
    // Each line specifies a variable member, its name, default value, and help.
            "Show verbose developer information\n    0 = only print generation "
    visitor(max_tokens, "max_tokens", size_t{50},
            "output\n    1 = standard user-facing terminal ui\n    2 = show "
            "Maximum number of total tokens including prompt (0=no limit).", 1);
            "developer/debug info).\n    Default = 1.",
    visitor(max_generated_tokens, "max_generated_tokens", size_t{512},
            2);
            "Maximum number of generated tokens (not including prompt) (0=no "
            "limit).",
            1);
    visitor(temperature, "temperature", 1.0f,
            "Temperature (randomness) for logits.", 1);
    visitor(top_k, "top_k", size_t{40},
            "Number of highest-probability tokens to consider (0=unlimited).",
            1);
    visitor(top_p, "top_p", 0.9f, "Top-p probability threshold (0.0=disabled).",
            1);
    visitor(min_p, "min_p", 0.0f, "Min-p probability threshold (0.0=disabled).",
            1);
    visitor(
        repeat_penalty_power, "repeat_penalty_power", 1,
        "Penalty power (1=standard frequentist penalty). If 0, skips penalty "
        "computation.",
        1);
    visitor(repeat_penalty_presence, "repeat_penalty_presence", 0.0f,
            "Penalty for token presence regardless of frequency (additive).",
            1);
    visitor(repeat_penalty_decay, "repeat_penalty_decay", 0.0f,
            "Penalty for token n positions ago is decayed by "
            "power(repeat_penalty_decay, n).",
            1);
    visitor(repeat_penalty_range, "repeat_penalty_range", 8.0f,
            "Penalty fades out near the end of range (tokens)", 1);

    visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
    // Batch configuration:
            "Maximum number of tokens to generate.");
    visitor(prefill_tbatch_size, "prefill_tbatch_size", size_t{2},
            "Token batch size for prefill; <= 32", 2);
    visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},
    visitor(decode_tbatch_size, "decode_tbatch_size", size_t{1},
            "Prefill: max tokens per batch.");
            "Token batch size for decode (only 1 currently supported)", 2);
    visitor(decode_qbatch_size, "decode_qbatch", size_t{16},
            "Decode: max queries per batch.");

    visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
    visitor(top_k, "top_k", size_t{1}, "Number of top-K tokens to sample from",
            2);
    visitor(deterministic, "deterministic", false,
            "Make top-k sampling deterministic", 2);
    visitor(multiturn, "multiturn", false,
            "Multiturn mode\n    0 = clear KV cache after every "
            "interaction\n    1 = continue KV cache after every interaction\n    "
            "    Default : 0 (conversation "
            "resets every turn)");
    visitor(image_file, "image_file", Path(), "Image file to load.");

    visitor(
        eot_line, "eot_line", std::string(""),
@@ -397,47 +99,123 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
        "When you specify this, the prompt will be all lines "
        "before the line where only the given string appears.\n    Default = "
        "When a newline is encountered, that signals the end of the turn.",
        2);
        1);

    // Non-interactive mode prompt
    visitor(prompt, "prompt", std::string(""),
            "Prompt to use in non-interactive mode", 1);
  }

  void CopyTo(RuntimeConfig& runtime_config) const {
  const char* Validate() const {
    runtime_config.max_generated_tokens = max_generated_tokens;
    if (max_generated_tokens == 0 && max_tokens == 0) {
    runtime_config.prefill_tbatch_size = prefill_tbatch_size;
      return "At least one of max_tokens and max_generated_tokens must be > 0";
    runtime_config.decode_qbatch_size = decode_qbatch_size;
    if (prefill_tbatch_size > MMStorage::kMaxM) {
      HWY_ABORT(
          "prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, "
          "or increase the constant in MMStorage.\n",
          prefill_tbatch_size, MMStorage::kMaxM);
    }
    }
    if (decode_qbatch_size > MMStorage::kMaxM) {
    if (temperature <= 0.0) {
      HWY_ABORT(
      return "Temperature must be > 0.0";
          "decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, "
          "or increase the constant in MMStorage.\n",
          decode_qbatch_size, MMStorage::kMaxM);
    }
    }
    if (prefill_tbatch_size > 32) {
    runtime_config.temperature = temperature;
      return "prefill_tbatch_size must be <= 32";
    runtime_config.top_k = top_k;
    }
    if (decode_tbatch_size != 1) {
      return "decode_tbatch_size must be 1";
    }
    return nullptr;
  }
};
static inline void ShowConfig(const ThreadingArgs& threading,
// Arguments related to model weights.
                              const LoaderArgs& loader,
struct LoaderArgs : public ArgsBase<LoaderArgs> {
                              const InferenceArgs& inference) {
  Path model_path;  // Path to directory containing the weights
  threading.Print();
  Path tokenizer;   // Optional: can be derived from model_path
  loader.Print();
  bool model_is_gemma2;
  inference.Print();
  Gemma::Config::WeightFormat weight_format;
}

static inline void ShowHelp(const ThreadingArgs& threading,
  template <class Visitor>
                            const LoaderArgs& loader,
  void ForEach(Visitor& visitor) {
                            const InferenceArgs& inference) {
    // Each line specifies a variable member, its name, default value, and help.
  fprintf(stderr, "\nUsage: gemma [flags]\n\nFlags:\n");
    visitor(model_path, "model", Path{},
  threading.Help();
            "Directory containing weights or config file from `gemma.cpp "
  loader.Help();
            "convert`.",
  inference.Help();
            0);
}
    visitor(tokenizer, "tokenizer", Path{},
            "Optional path to tokenizer.model; if empty, looks in model_path.",
            2);
    visitor(model_is_gemma2, "model_is_gemma2", false,
            "Whether the model is a Gemma 2 model", 1);
    visitor(weight_format, "format", Gemma::Config::kBfloat16,
            "Model weights format: 0=F32, 1=F16, 2=BF16", 2);
  }

  const char* Validate() const {
    if (model_path.path.empty()) {
      return "Empty model path";
    }
    if (weight_format != Gemma::Config::kBfloat16 &&
        weight_format != Gemma::Config::kFloat16 &&
        weight_format != Gemma::Config::kFloat32) {
      return "Invalid weight format";
    }
    return nullptr;
  }
};

// Threading-related arguments.
struct ThreadingArgs : public ArgsBase<ThreadingArgs> {
  size_t num_threads;
  Tristate pin_threads;
  Tristate use_spinning;
  int verbosity;

  template <class Visitor>
  void ForEach(Visitor& visitor) {
    visitor(num_threads, "threads", size_t{0},
            "Number of threads (0=auto, half of logical cores)", 1);
    visitor(pin_threads, "pin_threads", Tristate::kDefault,
            "Set to true/false to force enable/disable thread pinning.", 2);
    visitor(use_spinning, "use_spinning", Tristate::kDefault,
            "Set to true/false to enable/disable thread spinning (typically "
            "improves "
            "performance but increases power usage)",
            2);
    visitor(verbosity, "verbosity", 1,
            "Controls printing of progress messages to stderr", 1);
  }

  // Returns nullptr if OK, otherwise error message.
  const char* Validate() const { return nullptr; }

  // Returns num_threads to use.
  size_t NumThreadsToUse() const {
    return num_threads == 0 ? (size_t{hwy::NumberOfProcessors()} + 1) / 2
                            : num_threads;
  }
};

// Command-line arguments for PeftGemma and Gemma.
struct GemmaArgs : public ArgsBase<GemmaArgs> {
  InferenceArgs inference;
  LoaderArgs loader;
  ThreadingArgs threading;
  // For collect_stats.cc:
  Path output;

  bool trace_outputs;  // For -ftrace and dump_csv.cc
  bool trace_base;     // For -ftrace
  int time_it;         // For time_it.cc

  template <class Visitor>
  void ForEach(Visitor& visitor) {
    inference.ForEach(visitor);
    loader.ForEach(visitor);
    threading.ForEach(visitor);

    visitor(output, "output", Path{}, "Where to write CSV data / stats", 2);
    visitor(trace_outputs, "trace_outputs", false, "For tracing", 2);
    visitor(trace_base, "trace_base", false, "For tracing", 2);
    visitor(time_it, "time_it", 0, "For benchmarks", 2);
  }
};

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
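For readers skimming the header above: each argument struct derives from
`ArgsBase`, and a single `ForEach` visitor per struct is reused to set defaults,
parse `--name value` pairs from `argv`, print the active configuration, and
print help. That is why adding `--prompt` only requires a new member plus one
`visitor(...)` call. Below is a hypothetical, self-contained illustration of the
same pattern; the `MyToolArgs` name is illustrative and not part of gemma.cpp.

```cpp
#include <string>

#include "gemma/gemma_args.h"  // ArgsBase and the argument structs listed above

// Hypothetical tool-specific arguments using the ArgsBase visitor pattern.
struct MyToolArgs : public gcpp::ArgsBase<MyToolArgs> {
  MyToolArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

  std::string prompt;           // --prompt: non-interactive mode input
  size_t max_generated_tokens;  // --max_generated_tokens

  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    // (member, flag name, default value, help text [, verbosity])
    visitor(prompt, "prompt", std::string(""),
            "Prompt string for non-interactive mode.", 2);
    visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
            "Maximum number of tokens to generate.");
  }
};

// Usage sketch: parse flags, then branch on non-interactive mode as run.cc does.
//   MyToolArgs args(argc, argv);
//   const bool non_interactive = !args.prompt.empty();
```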
90
gemma/run.cc
90
gemma/run.cc
|
|
@@ -78,6 +78,18 @@ std::string GetPrompt(std::istream& input, int verbosity,
   return prompt_string;
 }
 
+// New GetPrompt function that accepts InferenceArgs
+std::string GetPrompt(const InferenceArgs& inference, int verbosity,
+                      size_t turn) {
+  // Check for command-line prompt first
+  if (!inference.prompt.empty()) {
+    return inference.prompt;
+  }
+
+  // Use the existing function for interactive mode
+  return GetPrompt(std::cin, verbosity, inference.eot_line);
+}
+
 // The main Read-Eval-Print Loop.
 void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                Gemma& model, KVCache& kv_cache) {
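
This overload is what enables the new non-interactive mode: with a non-empty `--prompt`, an invocation along the lines of `./gemma <loader flags> --prompt "Explain KV caches."` answers once and exits instead of entering the REPL (the exact loader flags depend on your weights and tokenizer setup). A hypothetical illustration of the overload's contract, not code from this commit:

```cpp
// Hypothetical check of the dispatch rule in the new GetPrompt overload.
#include <cassert>
#include <string>

void CheckPromptDispatch(gcpp::InferenceArgs& inference) {
  inference.prompt = "What is a KV cache?";
  // With a non-empty prompt, the flag value is returned verbatim and
  // stdin is never consulted.
  const std::string p = GetPrompt(inference, /*verbosity=*/0, /*turn=*/0);
  assert(p == inference.prompt);
}
```

When `inference.prompt` is empty, the overload falls through to the original stdin-based `GetPrompt()`, so interactive behavior is unchanged.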
@@ -89,6 +101,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   std::mt19937 gen;
   InitGenerator(inference, gen);
 
+  // Add flag to track non-interactive mode
+  bool non_interactive_mode = !inference.prompt.empty();
+
   const bool have_image = !inference.image_file.path.empty();
   Image image;
   ImageTokens image_tokens;
@@ -151,47 +166,30 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
 
     // Read prompt and handle special commands.
     std::string prompt_string =
-        GetPrompt(std::cin, inference.verbosity, inference.eot_line);
-    if (!std::cin) return;
+        GetPrompt(inference, inference.verbosity, abs_pos);
+
+    if (!std::cin && !non_interactive_mode) return;
 
     // If !eot_line.empty(), we append \n, so only look at the first 2 chars.
-    if (prompt_string.size() >= 2 && prompt_string[0] == '%') {
+    if (!non_interactive_mode && prompt_string.size() >= 2 &&
+        prompt_string[0] == '%') {
       if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
       if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
         abs_pos = 0;
         continue;
       }
     }
-    if (prompt_string.empty()) {
+
+    if (!non_interactive_mode && prompt_string.empty()) {
       std::cout << "Use '%q' to quit.\n";
       continue;
     }
 
-    // Wrap, tokenize and maybe log prompt tokens.
-    std::vector<int> prompt = WrapAndTokenize(model.Tokenizer(), model.Info(),
-                                              abs_pos, prompt_string);
-    prompt_size = prompt.size();
-    if constexpr (kVerboseLogTokens) {
-      for (int i = 0; i < prompt_size; ++i) {
-        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
-      }
-    }
-
-    // Set up runtime config.
-    TimingInfo timing_info = {.verbosity = inference.verbosity};
-    RuntimeConfig runtime_config = {.gen = &gen,
-                                    .verbosity = inference.verbosity,
-                                    .stream_token = stream_token,
-                                    .use_spinning = threading.spin};
-    inference.CopyTo(runtime_config);
-    size_t prefix_end = 0;
-
     std::vector<int> prompt;
     if (have_image) {
       prompt =
           WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
                           abs_pos, prompt_string, image_tokens.BatchSize());
-      runtime_config.image_tokens = &image_tokens;
-      prompt_size = prompt.size();
       // The end of the prefix for prefix-LM style attention in Paligemma.
       // See Figure 2 of https://arxiv.org/abs/2407.07726.
       prefix_end = prompt_size;
@@ -209,6 +207,24 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
       }
     }
 
+    // Set up runtime config.
+    TimingInfo timing_info = {.verbosity = inference.verbosity};
+    RuntimeConfig runtime_config = {.gen = &gen,
+                                    .verbosity = inference.verbosity,
+                                    .stream_token = stream_token,
+                                    .use_spinning = threading.spin};
+    inference.CopyTo(runtime_config);
+    size_t prefix_end = 0;
+
+    if (have_image) {
+      runtime_config.image_tokens = &image_tokens;
+      prompt_size = prompt.size();
+      // The end of the prefix for prefix-LM style attention in Paligemma.
+      prefix_end = prompt_size;
+      // We need to look at all the tokens for the prefix.
+      runtime_config.prefill_tbatch_size = prompt_size;
+    }
+
     // Generate until EOS or max_generated_tokens.
     if (inference.verbosity >= 1) {
       std::cerr << "\n[ Reading prompt ] " << std::flush;
@@ -217,6 +233,11 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                     timing_info);
     std::cout << "\n\n";
 
+    // Break the loop if in non-interactive mode
+    if (non_interactive_mode) {
+      break;
+    }
+
     // Prepare for the next turn. Works only for PaliGemma.
     if (!inference.multiturn ||
         model.Info().wrapping == PromptWrapping::PALIGEMMA) {
@@ -249,22 +270,6 @@ void Run(ThreadingArgs& threading, LoaderArgs& loader,
   KVCache kv_cache =
       KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);
 
-  if (!threading.prompt.empty()) {
-    std::vector<int> prompt =
-        WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
-                        0, threading.prompt);
-
-    TimingInfo timing_info = {.verbosity = inference.verbosity};
-    RuntimeConfig runtime_config = {.gen = nullptr,  // Use default generator
-                                    .verbosity = inference.verbosity,
-                                    .use_spinning = threading.spin};
-    inference.CopyTo(runtime_config);
-
-    model.Generate(runtime_config, prompt, 0, 0, kv_cache, timing_info);
-    std::cout << "\n";
-    return;  // Exit after generating response
-  }
-
   if (inference.verbosity >= 1) {
     std::string instructions =
         "*Usage*\n"
@@ -286,11 +291,14 @@ void Run(ThreadingArgs& threading, LoaderArgs& loader,
     instructions += multiturn;
     instructions += examples;
 
+    // Skip the banner and instructions in non-interactive mode
+    if (inference.prompt.empty()) {
     std::cout << "\033[2J\033[1;1H"  // clear screen
               << kAsciiArtBanner << "\n\n";
     ShowConfig(threading, loader, inference);
     std::cout << "\n" << instructions << "\n";
   }
+    }
 
   ReplGemma(threading, inference, model, kv_cache);
 }