mirror of https://github.com/google/gemma.cpp.git
Add --prompt flag for non-interactive mode
This commit is contained in:
parent
716713f0e6
commit
cbf179990f
596
README.md
596
README.md
|
|
@ -1,27 +1,583 @@
|
|||
---
|
||||
library_name: gemma.cpp
|
||||
license: gemma
|
||||
pipeline_tag: text-generation
|
||||
tags: []
|
||||
extra_gated_heading: Access Gemma on Hugging Face
|
||||
extra_gated_prompt: To access Gemma on Hugging Face, you’re required to review and
|
||||
agree to Google’s usage license. To do this, please ensure you’re logged-in to Hugging
|
||||
Face and click below. Requests are processed immediately.
|
||||
extra_gated_button_content: Acknowledge license
|
||||
---
|
||||
# gemma.cpp
|
||||
|
||||
# Gemma Model Card
|
||||
gemma.cpp is a lightweight, standalone C++ inference engine for the Gemma
|
||||
foundation models from Google.
|
||||
|
||||
**Model Page**: [Gemma](https://ai.google.dev/gemma/docs)
|
||||
For additional information about Gemma, see
|
||||
[ai.google.dev/gemma](https://ai.google.dev/gemma). Model weights, including
|
||||
gemma.cpp specific artifacts, are
|
||||
[available on kaggle](https://www.kaggle.com/models/google/gemma).
|
||||
|
||||
This model card corresponds to the 2B base version of the Gemma model for usage with C++ (https://github.com/google/gemma.cpp). This is a compressed version of the weights, which will load, run, and download more quickly. For more information about the model, visit https://huggingface.co/google/gemma-2b.
|
||||
## Who is this project for?
|
||||
|
||||
**Resources and Technical Documentation**:
|
||||
Modern LLM inference engines are sophisticated systems, often with bespoke
|
||||
capabilities extending beyond traditional neural network runtimes. With this
|
||||
comes opportunities for research and innovation through co-design of high level
|
||||
algorithms and low-level computation. However, there is a gap between
|
||||
deployment-oriented C++ inference runtimes, which are not designed for
|
||||
experimentation, and Python-centric ML research frameworks, which abstract away
|
||||
low-level computation through compilation.
|
||||
|
||||
* [Responsible Generative AI Toolkit](https://ai.google.dev/responsible)
|
||||
* [Gemma on Kaggle](https://www.kaggle.com/models/google/gemma)
|
||||
* [Gemma on Vertex Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/335?version=gemma-2b-gg-hf)
|
||||
gemma.cpp provides a minimalist implementation of Gemma-1, Gemma-2, Gemma-3, and
|
||||
PaliGemma models, focusing on simplicity and directness rather than full
|
||||
generality. This is inspired by vertically-integrated model implementations such
|
||||
as [ggml](https://github.com/ggerganov/ggml),
|
||||
[llama.c](https://github.com/karpathy/llama2.c), and
|
||||
[llama.rs](https://github.com/srush/llama2.rs).
|
||||
|
||||
**Terms of Use**: [Terms](https://www.kaggle.com/models/google/gemma/license/consent/verify/huggingface?returnModelRepoId=google/gemma-2b-sfp-cpp)
|
||||
gemma.cpp targets experimentation and research use cases. It is intended to be
|
||||
straightforward to embed in other projects with minimal dependencies and also
|
||||
easily modifiable with a small ~2K LoC core implementation (along with ~4K LoC
|
||||
of supporting utilities). We use the [Google
|
||||
Highway](https://github.com/google/highway) Library to take advantage of
|
||||
portable SIMD for CPU inference.
|
||||
|
||||
**Authors**: Google
|
||||
For production-oriented edge deployments we recommend standard deployment
|
||||
pathways using Python frameworks like JAX, Keras, PyTorch, and Transformers
|
||||
([all model variations here](https://www.kaggle.com/models/google/gemma)).
|
||||
|
||||
## Contributing
|
||||
|
||||
Community contributions large and small are welcome. See
|
||||
[DEVELOPERS.md](https://github.com/google/gemma.cpp/blob/main/DEVELOPERS.md)
|
||||
for additional notes contributing developers and [join the discord by following
|
||||
this invite link](https://discord.gg/H5jCBAWxAe). This project follows
|
||||
[Google's Open Source Community
|
||||
Guidelines](https://opensource.google.com/conduct/).
|
||||
|
||||
*Active development is currently done on the `dev` branch. Please open pull
|
||||
requests targeting `dev` branch instead of `main`, which is intended to be more
|
||||
stable.*
|
||||
|
||||
## Quick Start
|
||||
|
||||
### System requirements
|
||||
|
||||
Before starting, you should have installed:
|
||||
|
||||
- [CMake](https://cmake.org/)
|
||||
- [Clang C++ compiler](https://clang.llvm.org/get_started.html), supporting at
|
||||
least C++17.
|
||||
- `tar` for extracting archives from Kaggle.
|
||||
|
||||
Building natively on Windows requires the Visual Studio 2012 Build Tools with the
|
||||
optional Clang/LLVM C++ frontend (`clang-cl`). This can be installed from the
|
||||
command line with
|
||||
[`winget`](https://learn.microsoft.com/en-us/windows/package-manager/winget/):
|
||||
|
||||
```sh
|
||||
winget install --id Kitware.CMake
|
||||
winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;installRecommended --add Microsoft.VisualStudio.Component.VC.Llvm.Clang --add Microsoft.VisualStudio.Component.VC.Llvm.ClangToolset"
|
||||
```
|
||||
|
||||
### Step 1: Obtain model weights and tokenizer from Kaggle or Hugging Face Hub
|
||||
|
||||
Visit the
|
||||
[Kaggle page for Gemma-2](https://www.kaggle.com/models/google/gemma-2/gemmaCpp)
|
||||
[or Gemma-1](https://www.kaggle.com/models/google/gemma/frameworks/gemmaCpp),
|
||||
and select `Model Variations |> Gemma C++`.
|
||||
|
||||
On this tab, the `Variation` dropdown includes the options below. Note bfloat16
|
||||
weights are higher fidelity, while 8-bit switched floating point weights enable
|
||||
faster inference. In general, we recommend starting with the `-sfp` checkpoints.
|
||||
|
||||
If you are unsure which model to start with, we recommend starting with the
|
||||
smallest Gemma-2 model, i.e. `2.0-2b-it-sfp`.
|
||||
|
||||
Alternatively, visit the
|
||||
[gemma.cpp](https://huggingface.co/models?other=gemma.cpp) models on the Hugging
|
||||
Face Hub. First go the model repository of the model of interest (see
|
||||
recommendations below). Then, click the `Files and versions` tab and download
|
||||
the model and tokenizer files. For programmatic downloading, if you have
|
||||
`huggingface_hub` installed, you can also download by running:
|
||||
|
||||
```
|
||||
huggingface-cli login # Just the first time
|
||||
huggingface-cli download google/gemma-2b-sfp-cpp --local-dir build/
|
||||
```
|
||||
|
||||
Gemma-1 2B instruction-tuned (`it`) and pre-trained (`pt`) models:
|
||||
|
||||
| Model name | Description |
|
||||
| ----------- | ----------- |
|
||||
| `2b-it` | 2 billion parameter instruction-tuned model, bfloat16 |
|
||||
| `2b-it-sfp` | 2 billion parameter instruction-tuned model, 8-bit switched floating point |
|
||||
| `2b-pt` | 2 billion parameter pre-trained model, bfloat16 |
|
||||
| `2b-pt-sfp` | 2 billion parameter pre-trained model, 8-bit switched floating point |
|
||||
|
||||
Gemma-1 7B instruction-tuned (`it`) and pre-trained (`pt`) models:
|
||||
|
||||
| Model name | Description |
|
||||
| ----------- | ----------- |
|
||||
| `7b-it` | 7 billion parameter instruction-tuned model, bfloat16 |
|
||||
| `7b-it-sfp` | 7 billion parameter instruction-tuned model, 8-bit switched floating point |
|
||||
| `7b-pt` | 7 billion parameter pre-trained model, bfloat16 |
|
||||
| `7b-pt-sfp` | 7 billion parameter pre-trained model, 8-bit switched floating point |
|
||||
|
||||
> [!NOTE]
|
||||
> **Important**: We strongly recommend starting off with the `2b-it-sfp` model to
|
||||
> get up and running.
|
||||
|
||||
Gemma 2 models are named `gemma2-2b-it` for 2B and `9b-it` or `27b-it`. See the
|
||||
`kModelFlags` definition in `common.cc`.
|
||||
|
||||
### Step 2: Extract Files
|
||||
|
||||
If you downloaded the models from Hugging Face, skip to step 3.
|
||||
|
||||
After filling out the consent form, the download should proceed to retrieve a
|
||||
tar archive file `archive.tar.gz`. Extract files from `archive.tar.gz` (this can
|
||||
take a few minutes):
|
||||
|
||||
```
|
||||
tar -xf archive.tar.gz
|
||||
```
|
||||
|
||||
This should produce a file containing model weights such as `2b-it-sfp.sbs` and
|
||||
a tokenizer file (`tokenizer.spm`). You may want to move these files to a
|
||||
convenient directory location (e.g. the `build/` directory in this repo).
|
||||
|
||||
### Step 3: Build
|
||||
|
||||
The build system uses [CMake](https://cmake.org/). To build the gemma inference
|
||||
runtime, create a build directory and generate the build files using `cmake`
|
||||
from the top-level project directory. Note if you previous ran `cmake` and are
|
||||
re-running with a different setting, be sure to delete all files in the `build/`
|
||||
directory with `rm -rf build/*`.
|
||||
|
||||
#### Unix-like Platforms
|
||||
```sh
|
||||
cmake -B build
|
||||
```
|
||||
|
||||
After running `cmake`, you can enter the `build/` directory and run `make` to
|
||||
build the `./gemma` executable:
|
||||
|
||||
```sh
|
||||
# Configure `build` directory
|
||||
cmake --preset make
|
||||
|
||||
# Build project using make
|
||||
cmake --build --preset make -j [number of parallel threads to use]
|
||||
```
|
||||
|
||||
Replace `[number of parallel threads to use]` with a number - the number of
|
||||
cores available on your system is a reasonable heuristic. For example,
|
||||
`make -j4 gemma` will build using 4 threads. If the `nproc` command is
|
||||
available, you can use `make -j$(nproc) gemma` as a reasonable default
|
||||
for the number of threads.
|
||||
|
||||
If you aren't sure of the right value for the `-j` flag, you can simply run
|
||||
`make gemma` instead and it should still build the `./gemma` executable.
|
||||
|
||||
> [!NOTE]
|
||||
> On Windows Subsystem for Linux (WSL) users should set the number of
|
||||
> parallel threads to 1. Using a larger number may result in errors.
|
||||
|
||||
If the build is successful, you should now have a `gemma` executable in the `build/` directory.
|
||||
|
||||
#### Windows
|
||||
|
||||
```sh
|
||||
# Configure `build` directory
|
||||
cmake --preset windows
|
||||
|
||||
# Build project using Visual Studio Build Tools
|
||||
cmake --build --preset windows -j [number of parallel threads to use]
|
||||
```
|
||||
|
||||
If the build is successful, you should now have a `gemma.exe` executable in the `build/` directory.
|
||||
|
||||
#### Bazel
|
||||
|
||||
```sh
|
||||
bazel build -c opt --cxxopt=-std=c++20 :gemma
|
||||
```
|
||||
|
||||
If the build is successful, you should now have a `gemma` executable in the `bazel-bin/` directory.
|
||||
|
||||
#### Make
|
||||
|
||||
If you prefer Makefiles, @jart has made one available here:
|
||||
|
||||
https://github.com/jart/gemma3/blob/main/Makefile
|
||||
|
||||
### Step 4: Run
|
||||
|
||||
You can now run `gemma` from inside the `build/` directory.
|
||||
|
||||
`gemma` has the following required arguments:
|
||||
|
||||
Argument | Description | Example value
|
||||
--------------- | ---------------------------- | -----------------------
|
||||
`--model` | The model type. | `2b-it` ... (see below)
|
||||
`--weights` | The compressed weights file. | `2b-it-sfp.sbs`
|
||||
`--weight_type` | The compressed weight type. | `sfp`
|
||||
`--tokenizer` | The tokenizer file. | `tokenizer.spm`
|
||||
|
||||
`gemma` is invoked as:
|
||||
|
||||
```sh
|
||||
./gemma \
|
||||
--tokenizer [tokenizer file] \
|
||||
--weights [compressed weights file] \
|
||||
--weight_type [f32 or bf16 or sfp (default:sfp)] \
|
||||
--model [2b-it or 2b-pt or 7b-it or 7b-pt or ...]
|
||||
```
|
||||
|
||||
Example invocation for the following configuration:
|
||||
|
||||
- Compressed weights file `2b-it-sfp.sbs` (2B instruction-tuned model, 8-bit
|
||||
switched floating point).
|
||||
- Tokenizer file `tokenizer.spm`.
|
||||
|
||||
```sh
|
||||
./gemma \
|
||||
--tokenizer tokenizer.spm \
|
||||
--weights 2b-it-sfp.sbs --model 2b-it
|
||||
```
|
||||
|
||||
### RecurrentGemma
|
||||
|
||||
This repository includes a version of Gemma based on Griffin
|
||||
([paper](https://arxiv.org/abs/2402.19427),
|
||||
[code](https://github.com/google-deepmind/recurrentgemma)). Its architecture
|
||||
includes both recurrent layers and local attention, thus it is more efficient
|
||||
for longer sequences and has a smaller memory footprint than standard Gemma. We
|
||||
here provide a C++ implementation of this model based on the paper.
|
||||
|
||||
To use the recurrent version of Gemma included in this repository, build the
|
||||
gemma binary as noted above in Step 3. Download the compressed weights and
|
||||
tokenizer from the RecurrentGemma
|
||||
[Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) as in
|
||||
Step 1, and run the binary as follows:
|
||||
|
||||
`./gemma --tokenizer tokenizer.spm --model gr2b-it --weights 2b-it-sfp.sbs`
|
||||
|
||||
### PaliGemma Vision-Language Model
|
||||
|
||||
This repository includes a version of the PaliGemma VLM
|
||||
([paper](https://arxiv.org/abs/2407.07726),
|
||||
[code](https://github.com/google-research/big_vision/tree/main/big_vision/configs/proj/paligemma))
|
||||
and its successor PaliGemma 2 ([paper](https://arxiv.org/abs/2412.03555)). We
|
||||
provide a C++ implementation of the PaliGemma model family here.
|
||||
|
||||
To use the version of PaliGemma included in this repository, build the gemma
|
||||
binary as noted above in Step 3. Download the compressed weights and tokenizer
|
||||
from
|
||||
[Kaggle](https://www.kaggle.com/models/google/paligemma/gemmaCpp/paligemma-3b-mix-224)
|
||||
and run the binary as follows:
|
||||
|
||||
```sh
|
||||
./gemma \
|
||||
--tokenizer paligemma_tokenizer.model \
|
||||
--model paligemma-224 \
|
||||
--weights paligemma-3b-mix-224-sfp.sbs \
|
||||
--image_file paligemma/testdata/image.ppm
|
||||
```
|
||||
|
||||
Note that the image reading code is very basic to avoid depending on an image
|
||||
processing library for now. We currently only support reading binary PPMs (P6).
|
||||
So use a tool like `convert` to first convert your images into that format, e.g.
|
||||
|
||||
`convert image.jpeg -resize 224x224^ image.ppm`
|
||||
|
||||
(As the image will be resized for processing anyway, we can already resize at
|
||||
this stage for slightly faster loading.)
|
||||
|
||||
The interaction with the image (using the mix-224 checkpoint) may then look
|
||||
something like this:
|
||||
|
||||
```
|
||||
> Describe the image briefly
|
||||
A large building with two towers in the middle of a city.
|
||||
> What type of building is it?
|
||||
church
|
||||
> What color is the church?
|
||||
gray
|
||||
> caption image
|
||||
A large building with two towers stands tall on the water's edge. The building
|
||||
has a brown roof and a window on the side. A tree stands in front of the
|
||||
building, and a flag waves proudly from its top. The water is calm and blue,
|
||||
reflecting the sky above. A bridge crosses the water, and a red and white boat
|
||||
rests on its surface. The building has a window on the side, and a flag on top.
|
||||
A tall tree stands in front of the building, and a window on the building is
|
||||
visible from the water. The water is green, and the sky is blue.
|
||||
```
|
||||
|
||||
### Migrating to single-file format
|
||||
|
||||
There is now a new format for the weights file, which is a single file that
|
||||
allows to contain the tokenizer (and the model type) directly. A tool to migrate
|
||||
from the multi-file format to the single-file format is available.
|
||||
|
||||
```sh
|
||||
compression/migrate_weights \
|
||||
--tokenizer .../tokenizer.spm --weights .../gemma2-2b-it-sfp.sbs \
|
||||
--model gemma2-2b-it --output_weights .../gemma2-2b-it-sfp-single.sbs
|
||||
```
|
||||
|
||||
After migration, you can use the new weights file with gemma.cpp like this:
|
||||
|
||||
```sh
|
||||
./gemma --weights .../gemma2-2b-it-sfp-single.sbs
|
||||
```
|
||||
|
||||
### Troubleshooting and FAQs
|
||||
|
||||
**Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
|
||||
|
||||
The most common problem is that the `--weight_type` argument does not match that
|
||||
of the model file. Revisit step #3 and check which weights you downloaded.
|
||||
|
||||
Note that we have already moved weight type from a compile-time decision to a
|
||||
runtime argument. In a subsequent step, we plan to bake this information into
|
||||
the weights.
|
||||
|
||||
**Problems building in Windows / Visual Studio**
|
||||
|
||||
Currently if you're using Windows, we recommend building in WSL (Windows
|
||||
Subsystem for Linux). We are exploring options to enable other build
|
||||
configurations, see issues for active discussion.
|
||||
|
||||
**Model does not respond to instructions and produces strange output**
|
||||
|
||||
A common issue is that you are using a pre-trained model, which is not
|
||||
instruction-tuned and thus does not respond to instructions. Make sure you are
|
||||
using an instruction-tuned model (`2b-it-sfp`, `2b-it`, `7b-it-sfp`, `7b-it`)
|
||||
and not a pre-trained model (any model with a `-pt` suffix).
|
||||
|
||||
**What sequence lengths are supported?**
|
||||
|
||||
See `seq_len` in `configs.cc`. For the Gemma 3 models larger than 1B, this is
|
||||
typically 32K but 128K would also work given enough RAM. Note that long
|
||||
sequences will be slow due to the quadratic cost of attention.
|
||||
|
||||
**How do I convert my fine-tune to a `.sbs` compressed model file?**
|
||||
|
||||
For PaliGemma (1 and 2) checkpoints, you can use
|
||||
python/convert_from_safetensors.py to convert from safetensors format (tested
|
||||
with building via bazel). For an adapter model, you will likely need to call
|
||||
merge_and_unload() to convert the adapter model to a single-file format before
|
||||
converting it.
|
||||
|
||||
Here is how to use it using a bazel build of the compression library assuming
|
||||
locally installed (venv) torch, numpy, safetensors, absl-py, etc.:
|
||||
|
||||
```sh
|
||||
bazel build //compression/python:compression
|
||||
BAZEL_OUTPUT_DIR="${PWD}/bazel-bin/compression"
|
||||
python3 -c "import site; print(site.getsitepackages())"
|
||||
# Use your sites-packages file here:
|
||||
ln -s $BAZEL_OUTPUT_DIR [...]/site-packages/compression
|
||||
python3 python/convert_from_safetensors.py --load_path [...].safetensors.index.json
|
||||
```
|
||||
|
||||
See also compression/convert_weights.py for a slightly older option to convert a
|
||||
pytorch checkpoint. (The code may need updates to work with Gemma-2 models.)
|
||||
|
||||
**What are some easy ways to make the model run faster?**
|
||||
|
||||
1. Make sure you are using the 8-bit switched floating point `-sfp` models.
|
||||
These are half the size of bf16 and thus use less memory bandwidth and cache
|
||||
space.
|
||||
2. If you're on a laptop, make sure power mode is set to maximize performance
|
||||
and saving mode is **off**. For most laptops, the power saving modes get
|
||||
activated automatically if the computer is not plugged in.
|
||||
3. Close other unused cpu-intensive applications.
|
||||
4. On macs, anecdotally we observe a "warm-up" ramp-up in speed as performance
|
||||
cores get engaged.
|
||||
5. Experiment with the `--num_threads` argument value. Depending on the device,
|
||||
larger numbers don't always mean better performance.
|
||||
|
||||
We're also working on algorithmic and optimization approaches for faster
|
||||
inference, stay tuned.
|
||||
|
||||
## Usage
|
||||
|
||||
`gemma` has different usage modes, controlled by the verbosity flag.
|
||||
|
||||
All usage modes are currently interactive, triggering text generation upon
|
||||
newline input.
|
||||
|
||||
| Verbosity | Usage mode | Details |
|
||||
| --------------- | ---------- | --------------------------------------------- |
|
||||
| `--verbosity 0` | Minimal | Only prints generation output. Suitable as a CLI tool. |
|
||||
| `--verbosity 1` | Default | Standard user-facing terminal UI. |
|
||||
| `--verbosity 2` | Detailed | Shows additional developer and debug info. |
|
||||
|
||||
### Interactive Terminal App
|
||||
|
||||
By default, verbosity is set to 1, bringing up a terminal-based interactive
|
||||
interface when `gemma` is invoked:
|
||||
|
||||
```console
|
||||
$ ./gemma [...]
|
||||
__ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __
|
||||
/ _` |/ _ \ '_ ` _ \| '_ ` _ \ / _` | / __| '_ \| '_ \
|
||||
| (_| | __/ | | | | | | | | | | (_| || (__| |_) | |_) |
|
||||
\__, |\___|_| |_| |_|_| |_| |_|\__,_(_)___| .__/| .__/
|
||||
__/ | | | | |
|
||||
|___/ |_| |_|
|
||||
|
||||
tokenizer : tokenizer.spm
|
||||
compressed_weights : 2b-it-sfp.sbs
|
||||
model : 2b-it
|
||||
weights : [no path specified]
|
||||
max_generated_tokens : 2048
|
||||
|
||||
*Usage*
|
||||
Enter an instruction and press enter (%C reset conversation, %Q quits).
|
||||
|
||||
*Examples*
|
||||
- Write an email to grandma thanking her for the cookies.
|
||||
- What are some historical attractions to visit around Massachusetts?
|
||||
- Compute the nth fibonacci number in javascript.
|
||||
- Write a standup comedy bit about WebGPU programming.
|
||||
|
||||
> What are some outdoorsy places to visit around Boston?
|
||||
|
||||
[ Reading prompt ] .....................
|
||||
|
||||
|
||||
**Boston Harbor and Islands:**
|
||||
|
||||
* **Boston Harbor Islands National and State Park:** Explore pristine beaches, wildlife, and maritime history.
|
||||
* **Charles River Esplanade:** Enjoy scenic views of the harbor and city skyline.
|
||||
* **Boston Harbor Cruise Company:** Take a relaxing harbor cruise and admire the city from a different perspective.
|
||||
* **Seaport Village:** Visit a charming waterfront area with shops, restaurants, and a seaport museum.
|
||||
|
||||
**Forest and Nature:**
|
||||
|
||||
* **Forest Park:** Hike through a scenic forest with diverse wildlife.
|
||||
* **Quabbin Reservoir:** Enjoy boating, fishing, and hiking in a scenic setting.
|
||||
* **Mount Forest:** Explore a mountain with breathtaking views of the city and surrounding landscape.
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
### Usage as a Command Line Tool
|
||||
|
||||
For using the `gemma` executable as a command line tool, it may be useful to
|
||||
create an alias for gemma.cpp with arguments fully specified:
|
||||
|
||||
```sh
|
||||
alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --weights ~/gemma.cpp/build/gemma2-2b-it-sfp.sbs --model gemma2-2b-it --verbosity 0"
|
||||
```
|
||||
|
||||
Replace the above paths with your own paths to the model and tokenizer paths
|
||||
from the download.
|
||||
|
||||
Here is an example of prompting `gemma` with a truncated input
|
||||
file (using a `gemma2b` alias like defined above):
|
||||
|
||||
```sh
|
||||
cat configs.h | tail -n 35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> CLI usage of gemma.cpp is experimental and should take context length
|
||||
> limitations into account.
|
||||
|
||||
The output of the above command should look like:
|
||||
|
||||
```console
|
||||
[ Reading prompt ] [...]
|
||||
This C++ code snippet defines a set of **constants** used in a large language model (LLM) implementation, likely related to the **attention mechanism**.
|
||||
|
||||
Let's break down the code:
|
||||
[...]
|
||||
```
|
||||
|
||||
### Incorporating gemma.cpp as a Library in your Project
|
||||
|
||||
The easiest way to incorporate gemma.cpp in your own project is to pull in
|
||||
gemma.cpp and dependencies using `FetchContent`. You can add the following to your
|
||||
CMakeLists.txt:
|
||||
|
||||
```
|
||||
include(FetchContent)
|
||||
|
||||
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
|
||||
FetchContent_MakeAvailable(sentencepiece)
|
||||
|
||||
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
|
||||
FetchContent_MakeAvailable(gemma)
|
||||
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
```
|
||||
|
||||
Note for the gemma.cpp `GIT_TAG`, you may replace `origin/main` for a specific
|
||||
commit hash if you would like to pin the library version.
|
||||
|
||||
After your executable is defined (substitute your executable name for
|
||||
`[Executable Name]` below):
|
||||
|
||||
```
|
||||
target_link_libraries([Executable Name] libgemma hwy hwy_contrib sentencepiece)
|
||||
FetchContent_GetProperties(gemma)
|
||||
FetchContent_GetProperties(sentencepiece)
|
||||
target_include_directories([Executable Name] PRIVATE ${gemma_SOURCE_DIR})
|
||||
target_include_directories([Executable Name] PRIVATE ${sentencepiece_SOURCE_DIR})
|
||||
```
|
||||
|
||||
### Building gemma.cpp as a Library
|
||||
|
||||
gemma.cpp can also be used as a library dependency in your own project. The
|
||||
shared library artifact can be built by modifying the make invocation to build
|
||||
the `libgemma` target instead of `gemma`.
|
||||
|
||||
> [!NOTE]
|
||||
> If you are using gemma.cpp in your own project with the `FetchContent` steps
|
||||
> in the previous section, building the library is done automatically by `cmake`
|
||||
> and this section can be skipped.
|
||||
|
||||
First, run `cmake`:
|
||||
|
||||
```sh
|
||||
cmake -B build
|
||||
```
|
||||
|
||||
Then, run `make` with the `libgemma` target:
|
||||
|
||||
```sh
|
||||
cd build
|
||||
make -j [number of parallel threads to use] libgemma
|
||||
```
|
||||
|
||||
If this is successful, you should now have a `libgemma` library file in the
|
||||
`build/` directory. On Unix platforms, the filename is `libgemma.a`.
|
||||
|
||||
## Independent Projects Using gemma.cpp
|
||||
|
||||
Some independent projects using gemma.cpp:
|
||||
|
||||
- [gemma-cpp-python - Python bindings](https://github.com/namtranase/gemma-cpp-python)
|
||||
- [lua-cgemma - Lua bindings](https://github.com/ufownl/lua-cgemma)
|
||||
- [Godot engine demo project](https://github.com/Rliop913/Gemma-godot-demo-project)
|
||||
|
||||
If you would like to have your project included, feel free to get in touch or
|
||||
submit a PR with a `README.md` edit.
|
||||
|
||||
## Acknowledgements and Contacts
|
||||
|
||||
gemma.cpp was started in fall 2023 by [Austin Huang](mailto:austinvhuang@google.com)
|
||||
and [Jan Wassenberg](mailto:janwas@google.com), and subsequently released February 2024
|
||||
thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng.
|
||||
|
||||
Griffin support was implemented in April 2024 thanks to contributions by Andrey
|
||||
Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode
|
||||
Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas
|
||||
Fischbacher and Zoltan Szabadka.
|
||||
|
||||
Gemma-2 support was implemented in June/July 2024 with the help of several
|
||||
people.
|
||||
|
||||
PaliGemma support was implemented in September 2024 with contributions from
|
||||
Daniel Keysers.
|
||||
|
||||
[Jan Wassenberg](mailto:janwas@google.com) has continued to contribute many
|
||||
improvements, including major gains in efficiency, since the initial release.
|
||||
|
||||
This is not an officially supported Google product.
|
||||
|
|
|
|||
|
|
@ -13,383 +13,85 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Shared between various frontends.
|
||||
// Argument parsing for Gemma.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "compression/io.h" // Path
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h" // For CreateGemma
|
||||
#include "hwy/base.h" // HWY_IS_ASAN, HWY_ABORT
|
||||
#include "hwy/base.h" // HWY_ABORT
|
||||
#include "ops/matmul.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/args.h"
|
||||
#include "util/basics.h" // Tristate
|
||||
#include "util/threading.h"
|
||||
#include "util/threading_context.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
static inline const char* CompiledConfig() {
|
||||
if (HWY_IS_ASAN) {
|
||||
return "asan";
|
||||
} else if (HWY_IS_MSAN) {
|
||||
return "msan";
|
||||
} else if (HWY_IS_TSAN) {
|
||||
return "tsan";
|
||||
} else if (HWY_IS_HWASAN) {
|
||||
return "hwasan";
|
||||
} else if (HWY_IS_UBSAN) {
|
||||
return "ubsan";
|
||||
} else if (HWY_IS_DEBUG_BUILD) {
|
||||
return "dbg";
|
||||
} else {
|
||||
return "opt";
|
||||
}
|
||||
}
|
||||
template <typename Derived>
|
||||
struct ArgsBase {
|
||||
void Init() { static_cast<Derived*>(this)->ForEach(SetToDefault()); }
|
||||
|
||||
void InitAndParse(int argc, char* argv[]) {
|
||||
Init();
|
||||
static_cast<Derived*>(this)->ForEach(ParseOption(argc, argv));
|
||||
}
|
||||
|
||||
void Print(int min_verbosity = 1) const {
|
||||
static_cast<const Derived*>(this)->ForEach(PrintOption(min_verbosity));
|
||||
}
|
||||
|
||||
void Help() const { static_cast<const Derived*>(this)->ForEach(PrintHelp()); }
|
||||
|
||||
protected:
|
||||
// Helper struct for printing help messages
|
||||
struct PrintHelp {
|
||||
template <typename T>
|
||||
void operator()(const T& value, const char* name, const T& default_value,
|
||||
const char* description, int verbosity = 1) const {
|
||||
fprintf(stderr, " --%s\n %s\n", name, description);
|
||||
}
|
||||
// Special case for strings to avoid template deduction issues
|
||||
void operator()(const std::string& value, const char* name,
|
||||
const std::string& default_value, const char* description,
|
||||
int verbosity = 1) const {
|
||||
fprintf(stderr, " --%s\n %s\n", name, description);
|
||||
}
|
||||
// Special case for Path type
|
||||
void operator()(const Path& value, const char* name,
|
||||
const Path& default_value, const char* description,
|
||||
int verbosity = 1) const {
|
||||
fprintf(stderr, " --%s\n %s\n", name, description);
|
||||
}
|
||||
};
|
||||
|
||||
// Helper struct for setting default values
|
||||
struct SetToDefault {
|
||||
template <typename T>
|
||||
void operator()(T& value, const char* name, const T& default_value,
|
||||
const char* description, int verbosity = 1) const {
|
||||
value = default_value;
|
||||
}
|
||||
};
|
||||
|
||||
// Helper struct for printing values
|
||||
struct PrintOption {
|
||||
explicit PrintOption(int min_verbosity) : min_verbosity_(min_verbosity) {}
|
||||
|
||||
template <typename T>
|
||||
void operator()(const T& value, const char* name, const T& default_value,
|
||||
const char* description, int verbosity = 1) const {
|
||||
if (verbosity >= min_verbosity_) {
|
||||
fprintf(stderr, "%s: %s\n", name, ToString(value).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int min_verbosity_;
|
||||
|
||||
// Helper function to convert values to string
|
||||
template <typename T>
|
||||
static std::string ToString(const T& value) {
|
||||
return std::to_string(value);
|
||||
}
|
||||
// Specialization for string
|
||||
static std::string ToString(const std::string& value) { return value; }
|
||||
// Specialization for Path
|
||||
static std::string ToString(const Path& value) { return value.path; }
|
||||
};
|
||||
};
|
||||
struct ThreadingArgs : public ArgsBase<ThreadingArgs> {
|
||||
public:
|
||||
ThreadingArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
ThreadingArgs() { Init(); };
|
||||
|
||||
int verbosity;
|
||||
|
||||
size_t max_threads; // divided among the detected clusters
|
||||
Tristate pin; // pin threads?
|
||||
Tristate spin; // use spin waits?
|
||||
|
||||
// For BoundedSlice:
|
||||
size_t skip_packages;
|
||||
size_t max_packages;
|
||||
size_t skip_clusters;
|
||||
size_t max_clusters;
|
||||
size_t skip_lps;
|
||||
size_t max_lps;
|
||||
|
||||
std::string eot_line;
|
||||
std::string prompt;
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(verbosity, "verbosity", 1,
|
||||
"Show verbose developer information\n 0 = only print generation "
|
||||
"output\n 1 = standard user-facing terminal ui\n 2 = show "
|
||||
"developer/debug info).\n Default = 1.",
|
||||
2);
|
||||
|
||||
// The exact meaning is more subtle: see the comment at NestedPools ctor.
|
||||
visitor(max_threads, "num_threads", size_t{0},
|
||||
"Maximum number of threads to use; default 0 = unlimited.", 2);
|
||||
visitor(pin, "pin", Tristate::kDefault,
|
||||
"Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
|
||||
visitor(spin, "spin", Tristate::kDefault,
|
||||
"Use spin waits? -1 = auto, 0 = no, 1 = yes.", 2);
|
||||
// These can be used to partition CPU sockets/packages and their
|
||||
// clusters/CCXs across several program instances. The default is to use
|
||||
// all available resources.
|
||||
visitor(skip_packages, "skip_packages", size_t{0},
|
||||
"Index of the first socket to use; default 0 = unlimited.", 2);
|
||||
visitor(max_packages, "max_packages", size_t{0},
|
||||
"Maximum number of sockets to use; default 0 = unlimited.", 2);
|
||||
visitor(skip_clusters, "skip_clusters", size_t{0},
|
||||
"Index of the first CCX to use; default 0 = unlimited.", 2);
|
||||
visitor(max_clusters, "max_clusters", size_t{0},
|
||||
"Maximum number of CCXs to use; default 0 = unlimited.", 2);
|
||||
// These are only used when CPU topology is unknown.
|
||||
visitor(skip_lps, "skip_lps", size_t{0},
|
||||
"Index of the first LP to use; default 0 = unlimited.", 2);
|
||||
visitor(max_lps, "max_lps", size_t{0},
|
||||
"Maximum number of LPs to use; default 0 = unlimited.", 2);
|
||||
|
||||
visitor(
|
||||
eot_line, "eot_line", std::string(""),
|
||||
"End of turn line. "
|
||||
"When you specify this, the prompt will be all lines "
|
||||
"before the line where only the given string appears.\n Default = "
|
||||
"When a newline is encountered, that signals the end of the turn.",
|
||||
2);
|
||||
|
||||
visitor(prompt, "prompt", std::string(""),
|
||||
"Prompt string for non-interactive mode. When provided, the model "
|
||||
"generates a response and exits.",
|
||||
2);
|
||||
}
|
||||
};
|
||||
static inline BoundedTopology CreateTopology(const ThreadingArgs& threading) {
|
||||
return BoundedTopology(
|
||||
BoundedSlice(threading.skip_packages, threading.max_packages),
|
||||
BoundedSlice(threading.skip_clusters, threading.max_clusters),
|
||||
BoundedSlice(threading.skip_lps, threading.max_lps));
|
||||
}
|
||||
|
||||
static inline MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading) {
|
||||
ThreadingContext2::SetArgs(threading);
|
||||
return MatMulEnv(ThreadingContext2::Get());
|
||||
}
|
||||
// Note: These functions may need adjustments depending on your specific class
|
||||
// definitions
|
||||
static inline BoundedTopology CreateTopology(const ThreadingArgs& app) {
|
||||
return BoundedTopology(BoundedSlice(app.skip_packages, app.max_packages),
|
||||
BoundedSlice(app.skip_clusters, app.max_clusters),
|
||||
BoundedSlice(app.skip_lps, app.max_lps));
|
||||
}
|
||||
|
||||
// This function may need to be adjusted based on your NestedPools constructor
|
||||
// signature
|
||||
static inline NestedPools CreatePools(const BoundedTopology& topology,
|
||||
const ThreadingArgs& threading) {
|
||||
// Make sure Allocator::Init() is properly declared/defined
|
||||
const Allocator2& allocator = ThreadingContext2::Get().allocator;
|
||||
// Allocator::Init(topology);
|
||||
|
||||
// Adjust the constructor call based on your actual NestedPools constructor
|
||||
// The error suggests that the constructor doesn't match these arguments
|
||||
return NestedPools(topology, allocator, threading.max_threads, threading.pin);
|
||||
// Alternative: return NestedPools(topology, app.max_threads, app.pin);
|
||||
}
|
||||
|
||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||
LoaderArgs(int argc, char* argv[], bool validate = true) {
|
||||
InitAndParse(argc, argv);
|
||||
|
||||
if (validate) {
|
||||
if (const char* error = Validate()) {
|
||||
HWY_ABORT("Invalid args: %s", error);
|
||||
}
|
||||
}
|
||||
}
|
||||
LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
|
||||
const std::string& model, bool validate = true) {
|
||||
Init(); // Init sets to defaults, so assignments must come after Init().
|
||||
tokenizer.path = tokenizer_path;
|
||||
weights.path = weights_path;
|
||||
model_type_str = model;
|
||||
|
||||
if (validate) {
|
||||
if (const char* error = Validate()) {
|
||||
HWY_ABORT("Invalid args: %s", error);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Returns error string or nullptr if OK.
|
||||
const char* Validate() {
|
||||
if (weights.path.empty()) {
|
||||
return "Missing --weights flag, a file for the model weights.";
|
||||
}
|
||||
if (!weights.Exists()) {
|
||||
return "Can't open file specified with --weights flag.";
|
||||
}
|
||||
info_.model = Model::UNKNOWN;
|
||||
info_.wrapping = PromptWrapping::GEMMA_PT;
|
||||
info_.weight = Type::kUnknown;
|
||||
if (!model_type_str.empty()) {
|
||||
const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
|
||||
info_.wrapping);
|
||||
if (err != nullptr) return err;
|
||||
}
|
||||
if (!weight_type_str.empty()) {
|
||||
const char* err = ParseType(weight_type_str, info_.weight);
|
||||
if (err != nullptr) return err;
|
||||
}
|
||||
if (!tokenizer.path.empty()) {
|
||||
if (!tokenizer.Exists()) {
|
||||
return "Can't open file specified with --tokenizer flag.";
|
||||
}
|
||||
}
|
||||
// model_type and tokenizer must be either both present or both absent.
|
||||
// Further checks happen on weight loading.
|
||||
if (model_type_str.empty() != tokenizer.path.empty()) {
|
||||
return "Missing or extra flags for model_type or tokenizer.";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Path tokenizer;
|
||||
Path weights; // weights file location
|
||||
Path compressed_weights;
|
||||
std::string model_type_str;
|
||||
std::string weight_type_str;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(tokenizer, "tokenizer", Path(),
|
||||
"Path name of tokenizer model file.");
|
||||
visitor(weights, "weights", Path(),
|
||||
"Path name of model weights (.sbs) file.\n Required argument.\n");
|
||||
visitor(compressed_weights, "compressed_weights", Path(),
|
||||
"Deprecated alias for --weights.");
|
||||
visitor(model_type_str, "model", std::string(),
|
||||
"Model type, see common.cc for valid values.\n");
|
||||
visitor(weight_type_str, "weight_type", std::string("sfp"),
|
||||
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit SFP.");
|
||||
}
|
||||
|
||||
// Uninitialized before Validate, must call after that.
|
||||
const ModelInfo& Info() const { return info_; }
|
||||
|
||||
private:
|
||||
// TODO(rays): remove this. Eventually ModelConfig will be loaded from the
|
||||
// weights file, so we can remove the need for this struct entirely.
|
||||
ModelInfo info_;
|
||||
};
|
||||
|
||||
// `env` must remain valid for the lifetime of the Gemma.
|
||||
static inline Gemma CreateGemma(const LoaderArgs& loader, MatMulEnv& env) {
|
||||
if (Type::kUnknown == loader.Info().weight ||
|
||||
Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
|
||||
// New weights file format doesn't need tokenizer path or model/weightinfo.
|
||||
return Gemma(loader.weights, env);
|
||||
}
|
||||
return Gemma(loader.tokenizer, loader.weights, loader.Info(), env);
|
||||
}
|
||||
|
||||
// `env` must remain valid for the lifetime of the Gemma.
|
||||
static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
|
||||
MatMulEnv& env) {
|
||||
if (Type::kUnknown == loader.Info().weight ||
|
||||
Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
|
||||
// New weights file format doesn't need tokenizer path or model/weight info.
|
||||
return std::make_unique<Gemma>(loader.weights, env);
|
||||
}
|
||||
return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
|
||||
loader.Info(), env);
|
||||
}
|
||||
|
||||
// Arguments related to inference: sampling, text etc.
|
||||
struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
InferenceArgs() { Init(); };
|
||||
|
||||
int verbosity;
|
||||
|
||||
// Arguments for getc-like interfaces
|
||||
size_t max_tokens;
|
||||
size_t max_generated_tokens;
|
||||
|
||||
size_t prefill_tbatch_size;
|
||||
size_t decode_qbatch_size;
|
||||
|
||||
float temperature;
|
||||
size_t top_k;
|
||||
bool deterministic;
|
||||
bool multiturn;
|
||||
Path image_file;
|
||||
float top_p;
|
||||
float min_p;
|
||||
int repeat_penalty_power;
|
||||
float repeat_penalty_presence;
|
||||
float repeat_penalty_decay;
|
||||
float repeat_penalty_range;
|
||||
|
||||
// Batch configuration:
|
||||
size_t prefill_tbatch_size;
|
||||
size_t decode_tbatch_size;
|
||||
|
||||
// Non-interactive mode prompt
|
||||
std::string prompt;
|
||||
std::string eot_line;
|
||||
|
||||
// Returns error string or nullptr if OK.
|
||||
const char* Validate() const {
|
||||
if (max_generated_tokens > gcpp::kSeqLen) {
|
||||
return "max_generated_tokens is larger than the maximum sequence length "
|
||||
"(see configs.h).";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(verbosity, "verbosity", 1,
|
||||
"Show verbose developer information\n 0 = only print generation "
|
||||
"output\n 1 = standard user-facing terminal ui\n 2 = show "
|
||||
"developer/debug info).\n Default = 1.",
|
||||
2);
|
||||
void ForEach(Visitor& visitor) {
|
||||
// Each line specifies a variable member, its name, default value, and help.
|
||||
visitor(max_tokens, "max_tokens", size_t{50},
|
||||
"Maximum number of total tokens including prompt (0=no limit).", 1);
|
||||
visitor(max_generated_tokens, "max_generated_tokens", size_t{512},
|
||||
"Maximum number of generated tokens (not including prompt) (0=no "
|
||||
"limit).",
|
||||
1);
|
||||
visitor(temperature, "temperature", 1.0f,
|
||||
"Temperature (randomness) for logits.", 1);
|
||||
visitor(top_k, "top_k", size_t{40},
|
||||
"Number of highest-probability tokens to consider (0=unlimited).",
|
||||
1);
|
||||
visitor(top_p, "top_p", 0.9f, "Top-p probability threshold (0.0=disabled).",
|
||||
1);
|
||||
visitor(min_p, "min_p", 0.0f, "Min-p probability threshold (0.0=disabled).",
|
||||
1);
|
||||
visitor(
|
||||
repeat_penalty_power, "repeat_penalty_power", 1,
|
||||
"Penalty power (1=standard frequentist penalty). If 0, skips penalty "
|
||||
"computation.",
|
||||
1);
|
||||
visitor(repeat_penalty_presence, "repeat_penalty_presence", 0.0f,
|
||||
"Penalty for token presence regardless of frequency (additive).",
|
||||
1);
|
||||
visitor(repeat_penalty_decay, "repeat_penalty_decay", 0.0f,
|
||||
"Penalty for token n positions ago is decayed by "
|
||||
"power(repeat_penalty_decay, n).",
|
||||
1);
|
||||
visitor(repeat_penalty_range, "repeat_penalty_range", 8.0f,
|
||||
"Penalty fades out near the end of range (tokens)", 1);
|
||||
|
||||
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
|
||||
"Maximum number of tokens to generate.");
|
||||
|
||||
visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},
|
||||
"Prefill: max tokens per batch.");
|
||||
visitor(decode_qbatch_size, "decode_qbatch", size_t{16},
|
||||
"Decode: max queries per batch.");
|
||||
|
||||
visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
|
||||
visitor(top_k, "top_k", size_t{1}, "Number of top-K tokens to sample from",
|
||||
2);
|
||||
visitor(deterministic, "deterministic", false,
|
||||
"Make top-k sampling deterministic", 2);
|
||||
visitor(multiturn, "multiturn", false,
|
||||
"Multiturn mode\n 0 = clear KV cache after every "
|
||||
"interaction\n 1 = continue KV cache after every interaction\n "
|
||||
" Default : 0 (conversation "
|
||||
"resets every turn)");
|
||||
visitor(image_file, "image_file", Path(), "Image file to load.");
|
||||
// Batch configuration:
|
||||
visitor(prefill_tbatch_size, "prefill_tbatch_size", size_t{2},
|
||||
"Token batch size for prefill; <= 32", 2);
|
||||
visitor(decode_tbatch_size, "decode_tbatch_size", size_t{1},
|
||||
"Token batch size for decode (only 1 currently supported)", 2);
|
||||
|
||||
visitor(
|
||||
eot_line, "eot_line", std::string(""),
|
||||
|
|
@ -397,47 +99,123 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
"When you specify this, the prompt will be all lines "
|
||||
"before the line where only the given string appears.\n Default = "
|
||||
"When a newline is encountered, that signals the end of the turn.",
|
||||
2);
|
||||
1);
|
||||
|
||||
// Non-interactive mode prompt
|
||||
visitor(prompt, "prompt", std::string(""),
|
||||
"Prompt to use in non-interactive mode", 1);
|
||||
}
|
||||
|
||||
void CopyTo(RuntimeConfig& runtime_config) const {
|
||||
runtime_config.max_generated_tokens = max_generated_tokens;
|
||||
runtime_config.prefill_tbatch_size = prefill_tbatch_size;
|
||||
runtime_config.decode_qbatch_size = decode_qbatch_size;
|
||||
if (prefill_tbatch_size > MMStorage::kMaxM) {
|
||||
HWY_ABORT(
|
||||
"prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, "
|
||||
"or increase the constant in MMStorage.\n",
|
||||
prefill_tbatch_size, MMStorage::kMaxM);
|
||||
const char* Validate() const {
|
||||
if (max_generated_tokens == 0 && max_tokens == 0) {
|
||||
return "At least one of max_tokens and max_generated_tokens must be > 0";
|
||||
}
|
||||
if (decode_qbatch_size > MMStorage::kMaxM) {
|
||||
HWY_ABORT(
|
||||
"decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, "
|
||||
"or increase the constant in MMStorage.\n",
|
||||
decode_qbatch_size, MMStorage::kMaxM);
|
||||
if (temperature <= 0.0) {
|
||||
return "Temperature must be > 0.0";
|
||||
}
|
||||
|
||||
runtime_config.temperature = temperature;
|
||||
runtime_config.top_k = top_k;
|
||||
if (prefill_tbatch_size > 32) {
|
||||
return "prefill_tbatch_size must be <= 32";
|
||||
}
|
||||
if (decode_tbatch_size != 1) {
|
||||
return "decode_tbatch_size must be 1";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
static inline void ShowConfig(const ThreadingArgs& threading,
|
||||
const LoaderArgs& loader,
|
||||
const InferenceArgs& inference) {
|
||||
threading.Print();
|
||||
loader.Print();
|
||||
inference.Print();
|
||||
}
|
||||
static inline void ShowHelp(const ThreadingArgs& threading,
|
||||
const LoaderArgs& loader,
|
||||
const InferenceArgs& inference) {
|
||||
fprintf(stderr, "\nUsage: gemma [flags]\n\nFlags:\n");
|
||||
threading.Help();
|
||||
loader.Help();
|
||||
inference.Help();
|
||||
}
|
||||
// Arguments related to model weights.
|
||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||
Path model_path; // Path to directory containing the weights
|
||||
Path tokenizer; // Optional: can be derived from model_path
|
||||
bool model_is_gemma2;
|
||||
Gemma::Config::WeightFormat weight_format;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(Visitor& visitor) {
|
||||
// Each line specifies a variable member, its name, default value, and help.
|
||||
visitor(model_path, "model", Path{},
|
||||
"Directory containing weights or config file from `gemma.cpp "
|
||||
"convert`.",
|
||||
0);
|
||||
visitor(tokenizer, "tokenizer", Path{},
|
||||
"Optional path to tokenizer.model; if empty, looks in model_path.",
|
||||
2);
|
||||
visitor(model_is_gemma2, "model_is_gemma2", false,
|
||||
"Whether the model is a Gemma 2 model", 1);
|
||||
visitor(weight_format, "format", Gemma::Config::kBfloat16,
|
||||
"Model weights format: 0=F32, 1=F16, 2=BF16", 2);
|
||||
}
|
||||
|
||||
const char* Validate() const {
|
||||
if (model_path.path.empty()) {
|
||||
return "Empty model path";
|
||||
}
|
||||
if (weight_format != Gemma::Config::kBfloat16 &&
|
||||
weight_format != Gemma::Config::kFloat16 &&
|
||||
weight_format != Gemma::Config::kFloat32) {
|
||||
return "Invalid weight format";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
// Threading-related arguments.
|
||||
struct ThreadingArgs : public ArgsBase<ThreadingArgs> {
|
||||
size_t num_threads;
|
||||
Tristate pin_threads;
|
||||
Tristate use_spinning;
|
||||
int verbosity;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(Visitor& visitor) {
|
||||
visitor(num_threads, "threads", size_t{0},
|
||||
"Number of threads (0=auto, half of logical cores)", 1);
|
||||
visitor(pin_threads, "pin_threads", Tristate::kDefault,
|
||||
"Set to true/false to force enable/disable thread pinning.", 2);
|
||||
visitor(use_spinning, "use_spinning", Tristate::kDefault,
|
||||
"Set to true/false to enable/disable thread spinning (typically "
|
||||
"improves "
|
||||
"performance but increases power usage)",
|
||||
2);
|
||||
visitor(verbosity, "verbosity", 1,
|
||||
"Controls printing of progress messages to stderr", 1);
|
||||
}
|
||||
|
||||
// Returns nullptr if OK, otherwise error message.
|
||||
const char* Validate() const { return nullptr; }
|
||||
|
||||
// Returns num_threads to use.
|
||||
size_t NumThreadsToUse() const {
|
||||
return num_threads == 0 ? (size_t{hwy::NumberOfProcessors()} + 1) / 2
|
||||
: num_threads;
|
||||
}
|
||||
};
|
||||
|
||||
// Command-line arguments for PeftGemma and Gemma.
|
||||
struct GemmaArgs : public ArgsBase<GemmaArgs> {
|
||||
InferenceArgs inference;
|
||||
LoaderArgs loader;
|
||||
ThreadingArgs threading;
|
||||
// For collect_stats.cc:
|
||||
Path output;
|
||||
|
||||
bool trace_outputs; // For -ftrace and dump_csv.cc
|
||||
bool trace_base; // For -ftrace
|
||||
int time_it; // For time_it.cc
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(Visitor& visitor) {
|
||||
inference.ForEach(visitor);
|
||||
loader.ForEach(visitor);
|
||||
threading.ForEach(visitor);
|
||||
|
||||
visitor(output, "output", Path{}, "Where to write CSV data / stats", 2);
|
||||
visitor(trace_outputs, "trace_outputs", false, "For tracing", 2);
|
||||
visitor(trace_base, "trace_base", false, "For tracing", 2);
|
||||
visitor(time_it, "time_it", 0, "For benchmarks", 2);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
|
||||
100
gemma/run.cc
100
gemma/run.cc
|
|
@ -78,6 +78,18 @@ std::string GetPrompt(std::istream& input, int verbosity,
|
|||
return prompt_string;
|
||||
}
|
||||
|
||||
// New GetPrompt function that accepts InferenceArgs
|
||||
std::string GetPrompt(const InferenceArgs& inference, int verbosity,
|
||||
size_t turn) {
|
||||
// Check for command-line prompt first
|
||||
if (!inference.prompt.empty()) {
|
||||
return inference.prompt;
|
||||
}
|
||||
|
||||
// Use the existing function for interactive mode
|
||||
return GetPrompt(std::cin, verbosity, inference.eot_line);
|
||||
}
|
||||
|
||||
// The main Read-Eval-Print Loop.
|
||||
void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
||||
Gemma& model, KVCache& kv_cache) {
|
||||
|
|
@ -89,6 +101,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
std::mt19937 gen;
|
||||
InitGenerator(inference, gen);
|
||||
|
||||
// Add flag to track non-interactive mode
|
||||
bool non_interactive_mode = !inference.prompt.empty();
|
||||
|
||||
const bool have_image = !inference.image_file.path.empty();
|
||||
Image image;
|
||||
ImageTokens image_tokens;
|
||||
|
|
@ -151,47 +166,30 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
|
||||
// Read prompt and handle special commands.
|
||||
std::string prompt_string =
|
||||
GetPrompt(std::cin, inference.verbosity, inference.eot_line);
|
||||
if (!std::cin) return;
|
||||
GetPrompt(inference, inference.verbosity, abs_pos);
|
||||
|
||||
if (!std::cin && !non_interactive_mode) return;
|
||||
|
||||
// If !eot_line.empty(), we append \n, so only look at the first 2 chars.
|
||||
if (prompt_string.size() >= 2 && prompt_string[0] == '%') {
|
||||
if (!non_interactive_mode && prompt_string.size() >= 2 &&
|
||||
prompt_string[0] == '%') {
|
||||
if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
|
||||
if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
|
||||
abs_pos = 0;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (prompt_string.empty()) {
|
||||
|
||||
if (!non_interactive_mode && prompt_string.empty()) {
|
||||
std::cout << "Use '%q' to quit.\n";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Wrap, tokenize and maybe log prompt tokens.
|
||||
std::vector<int> prompt = WrapAndTokenize(model.Tokenizer(), model.Info(),
|
||||
abs_pos, prompt_string);
|
||||
prompt_size = prompt.size();
|
||||
if constexpr (kVerboseLogTokens) {
|
||||
for (int i = 0; i < prompt_size; ++i) {
|
||||
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Set up runtime config.
|
||||
TimingInfo timing_info = {.verbosity = inference.verbosity};
|
||||
RuntimeConfig runtime_config = {.gen = &gen,
|
||||
.verbosity = inference.verbosity,
|
||||
.stream_token = stream_token,
|
||||
.use_spinning = threading.spin};
|
||||
inference.CopyTo(runtime_config);
|
||||
size_t prefix_end = 0;
|
||||
|
||||
std::vector<int> prompt;
|
||||
if (have_image) {
|
||||
prompt =
|
||||
WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
|
||||
abs_pos, prompt_string, image_tokens.BatchSize());
|
||||
runtime_config.image_tokens = &image_tokens;
|
||||
prompt_size = prompt.size();
|
||||
// The end of the prefix for prefix-LM style attention in Paligemma.
|
||||
// See Figure 2 of https://arxiv.org/abs/2407.07726.
|
||||
prefix_end = prompt_size;
|
||||
|
|
@ -209,6 +207,24 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
}
|
||||
}
|
||||
|
||||
// Set up runtime config.
|
||||
TimingInfo timing_info = {.verbosity = inference.verbosity};
|
||||
RuntimeConfig runtime_config = {.gen = &gen,
|
||||
.verbosity = inference.verbosity,
|
||||
.stream_token = stream_token,
|
||||
.use_spinning = threading.spin};
|
||||
inference.CopyTo(runtime_config);
|
||||
size_t prefix_end = 0;
|
||||
|
||||
if (have_image) {
|
||||
runtime_config.image_tokens = &image_tokens;
|
||||
prompt_size = prompt.size();
|
||||
// The end of the prefix for prefix-LM style attention in Paligemma.
|
||||
prefix_end = prompt_size;
|
||||
// We need to look at all the tokens for the prefix.
|
||||
runtime_config.prefill_tbatch_size = prompt_size;
|
||||
}
|
||||
|
||||
// Generate until EOS or max_generated_tokens.
|
||||
if (inference.verbosity >= 1) {
|
||||
std::cerr << "\n[ Reading prompt ] " << std::flush;
|
||||
|
|
@ -217,6 +233,11 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
timing_info);
|
||||
std::cout << "\n\n";
|
||||
|
||||
// Break the loop if in non-interactive mode
|
||||
if (non_interactive_mode) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Prepare for the next turn. Works only for PaliGemma.
|
||||
if (!inference.multiturn ||
|
||||
model.Info().wrapping == PromptWrapping::PALIGEMMA) {
|
||||
|
|
@ -249,22 +270,6 @@ void Run(ThreadingArgs& threading, LoaderArgs& loader,
|
|||
KVCache kv_cache =
|
||||
KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);
|
||||
|
||||
if (!threading.prompt.empty()) {
|
||||
std::vector<int> prompt =
|
||||
WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
|
||||
0, threading.prompt);
|
||||
|
||||
TimingInfo timing_info = {.verbosity = inference.verbosity};
|
||||
RuntimeConfig runtime_config = {.gen = nullptr, // Use default generator
|
||||
.verbosity = inference.verbosity,
|
||||
.use_spinning = threading.spin};
|
||||
inference.CopyTo(runtime_config);
|
||||
|
||||
model.Generate(runtime_config, prompt, 0, 0, kv_cache, timing_info);
|
||||
std::cout << "\n";
|
||||
return; // Exit after generating response
|
||||
}
|
||||
|
||||
if (inference.verbosity >= 1) {
|
||||
std::string instructions =
|
||||
"*Usage*\n"
|
||||
|
|
@ -286,10 +291,13 @@ void Run(ThreadingArgs& threading, LoaderArgs& loader,
|
|||
instructions += multiturn;
|
||||
instructions += examples;
|
||||
|
||||
std::cout << "\033[2J\033[1;1H" // clear screen
|
||||
<< kAsciiArtBanner << "\n\n";
|
||||
ShowConfig(threading, loader, inference);
|
||||
std::cout << "\n" << instructions << "\n";
|
||||
// Skip the banner and instructions in non-interactive mode
|
||||
if (inference.prompt.empty()) {
|
||||
std::cout << "\033[2J\033[1;1H" // clear screen
|
||||
<< kAsciiArtBanner << "\n\n";
|
||||
ShowConfig(threading, loader, inference);
|
||||
std::cout << "\n" << instructions << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
ReplGemma(threading, inference, model, kv_cache);
|
||||
|
|
@ -328,4 +336,4 @@ int main(int argc, char** argv) {
|
|||
}
|
||||
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue