mirror of https://github.com/google/gemma.cpp.git
Remove support for Gemma 1 and PaliGemma 1 models, superseded by (Pali)Gemma 2.
PiperOrigin-RevId: 756671308
Parent: d834c07042
Commit: 252a4e955e

Changed files: README.md (164 lines changed), plus the model-config sources and tests shown below.
@@ -6,7 +6,7 @@ foundation models from Google.
 For additional information about Gemma, see
 [ai.google.dev/gemma](https://ai.google.dev/gemma). Model weights, including
 gemma.cpp specific artifacts, are
-[available on kaggle](https://www.kaggle.com/models/google/gemma).
+[available on kaggle](https://www.kaggle.com/models/google/gemma-2).

 ## Who is this project for?

@@ -18,8 +18,8 @@ deployment-oriented C++ inference runtimes, which are not designed for
 experimentation, and Python-centric ML research frameworks, which abstract away
 low-level computation through compilation.

-gemma.cpp provides a minimalist implementation of Gemma-1, Gemma-2, Gemma-3, and
-PaliGemma models, focusing on simplicity and directness rather than full
+gemma.cpp provides a minimalist implementation of Gemma-2, Gemma-3, and
+PaliGemma-2 models, focusing on simplicity and directness rather than full
 generality. This is inspired by vertically-integrated model implementations such
 as [ggml](https://github.com/ggerganov/ggml),
 [llama.c](https://github.com/karpathy/llama2.c), and

@@ -53,7 +53,7 @@ Guidelines](https://opensource.google.com/conduct/).

 - LLM

-  - CPU-only inference for: Gemma 1-3, Griffin(SSM), PaliGemma 1-2.
+  - CPU-only inference for: Gemma 2-3, Griffin(SSM), PaliGemma 2.
   - Sampling with TopK and temperature.
   - Backward pass (VJP) and Adam optimizer for Gemma research.

@@ -106,57 +106,20 @@ winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "-
 Visit the
 [Kaggle page for Gemma-2](https://www.kaggle.com/models/google/gemma-2/gemmaCpp)
-[or Gemma-1](https://www.kaggle.com/models/google/gemma/frameworks/gemmaCpp),
 and select `Model Variations |> Gemma C++`.

-On this tab, the `Variation` dropdown includes the options below. Note bfloat16
-weights are higher fidelity, while 8-bit switched floating point weights enable
-faster inference. In general, we recommend starting with the `-sfp` checkpoints.
-
-If you are unsure which model to start with, we recommend starting with the
-smallest Gemma-2 model, i.e. `2.0-2b-it-sfp`.
-
-Alternatively, visit the
-[gemma.cpp](https://huggingface.co/models?other=gemma.cpp) models on the Hugging
-Face Hub. First go the model repository of the model of interest (see
-recommendations below). Then, click the `Files and versions` tab and download
-the model and tokenizer files. For programmatic downloading, if you have
-`huggingface_hub` installed, you can also download by running:
-
-```
-huggingface-cli login # Just the first time
-huggingface-cli download google/gemma-2b-sfp-cpp --local-dir build/
-```
-
-Gemma-1 2B instruction-tuned (`it`) and pre-trained (`pt`) models:
-
-| Model name | Description |
-| ----------- | ----------- |
-| `2b-it` | 2 billion parameter instruction-tuned model, bfloat16 |
-| `2b-it-sfp` | 2 billion parameter instruction-tuned model, 8-bit switched floating point |
-| `2b-pt` | 2 billion parameter pre-trained model, bfloat16 |
-| `2b-pt-sfp` | 2 billion parameter pre-trained model, 8-bit switched floating point |
-
-Gemma-1 7B instruction-tuned (`it`) and pre-trained (`pt`) models:
-
-| Model name | Description |
-| ----------- | ----------- |
-| `7b-it` | 7 billion parameter instruction-tuned model, bfloat16 |
-| `7b-it-sfp` | 7 billion parameter instruction-tuned model, 8-bit switched floating point |
-| `7b-pt` | 7 billion parameter pre-trained model, bfloat16 |
-| `7b-pt-sfp` | 7 billion parameter pre-trained model, 8-bit switched floating point |
-
-> [!NOTE]
-> **Important**: We strongly recommend starting off with the `2b-it-sfp` model to
-> get up and running.
+> [!NOTE] **Important**: We strongly recommend starting off with the
+> `gemma2-2b-it-sfp` model to get up and running.

 Gemma 2 models are named `gemma2-2b-it` for 2B and `9b-it` or `27b-it`. See the
-`kModelFlags` definition in `common.cc`.
+`ModelPrefix` function in `configs.cc`.

 ### Step 2: Extract Files

 If you downloaded the models from Hugging Face, skip to step 3.

 After filling out the consent form, the download should proceed to retrieve a
 tar archive file `archive.tar.gz`. Extract files from `archive.tar.gz` (this can
 take a few minutes):
@@ -194,10 +157,9 @@ cmake --build --preset make -j [number of parallel threads to use]
 ```

 Replace `[number of parallel threads to use]` with a number - the number of
-cores available on your system is a reasonable heuristic. For example,
-`make -j4 gemma` will build using 4 threads. If the `nproc` command is
-available, you can use `make -j$(nproc) gemma` as a reasonable default
-for the number of threads.
+cores available on your system is a reasonable heuristic. For example, `make -j4
+gemma` will build using 4 threads. If the `nproc` command is available, you can
+use `make -j$(nproc) gemma` as a reasonable default for the number of threads.

 If you aren't sure of the right value for the `-j` flag, you can simply run
 `make gemma` instead and it should still build the `./gemma` executable.

@@ -206,7 +168,8 @@ If you aren't sure of the right value for the `-j` flag, you can simply run
 > On Windows Subsystem for Linux (WSL) users should set the number of
 > parallel threads to 1. Using a larger number may result in errors.

-If the build is successful, you should now have a `gemma` executable in the `build/` directory.
+If the build is successful, you should now have a `gemma` executable in the
+`build/` directory.

 #### Windows

@@ -218,7 +181,8 @@ cmake --preset windows
 cmake --build --preset windows -j [number of parallel threads to use]
 ```

-If the build is successful, you should now have a `gemma.exe` executable in the `build/` directory.
+If the build is successful, you should now have a `gemma.exe` executable in the
+`build/` directory.

 #### Bazel

@@ -226,7 +190,8 @@ If the build is successful, you should now have a `gemma.exe` executable in the
 bazel build -c opt --cxxopt=-std=c++20 :gemma
 ```

-If the build is successful, you should now have a `gemma` executable in the `bazel-bin/` directory.
+If the build is successful, you should now have a `gemma` executable in the
+`bazel-bin/` directory.

 #### Make

@@ -240,33 +205,21 @@ You can now run `gemma` from inside the `build/` directory.

 `gemma` has the following required arguments:

-Argument        | Description                  | Example value
---------------- | ---------------------------- | -----------------------
-`--model`       | The model type.              | `2b-it` ... (see below)
-`--weights`     | The compressed weights file. | `2b-it-sfp.sbs`
-`--weight_type` | The compressed weight type.  | `sfp`
-`--tokenizer`   | The tokenizer file.          | `tokenizer.spm`
-
-`gemma` is invoked as:
-
-```sh
-./gemma \
-  --tokenizer [tokenizer file] \
-  --weights [compressed weights file] \
-  --weight_type [f32 or bf16 or sfp (default:sfp)] \
-  --model [2b-it or 2b-pt or 7b-it or 7b-pt or ...]
-```
+Argument      | Description                  | Example value
+------------- | ---------------------------- | ---------------
+`--weights`   | The compressed weights file. | `2b-it-sfp.sbs`
+`--tokenizer` | The tokenizer file.          | `tokenizer.spm`

 Example invocation for the following configuration:

-- Compressed weights file `2b-it-sfp.sbs` (2B instruction-tuned model, 8-bit
-  switched floating point).
-- Tokenizer file `tokenizer.spm`.
+- weights file `gemma2-2b-it-sfp.sbs` (Gemma2 2B instruction-tuned model,
+  8-bit switched floating point).
+- Tokenizer file `tokenizer.spm` (can omit for single-format weights files
+  created after 2025-05-06, or output by migrate_weights.cc).

 ```sh
 ./gemma \
-  --tokenizer tokenizer.spm \
-  --weights 2b-it-sfp.sbs --model 2b-it
+  --tokenizer tokenizer.spm --weights gemma2-2b-it-sfp.sbs
 ```

 ### RecurrentGemma

@@ -288,11 +241,9 @@ Step 1, and run the binary as follows:

 ### PaliGemma Vision-Language Model

-This repository includes a version of the PaliGemma VLM
-([paper](https://arxiv.org/abs/2407.07726),
-[code](https://github.com/google-research/big_vision/tree/main/big_vision/configs/proj/paligemma))
-and its successor PaliGemma 2 ([paper](https://arxiv.org/abs/2412.03555)). We
-provide a C++ implementation of the PaliGemma model family here.
+This repository includes a version of the PaliGemma 2 VLM
+([paper](https://arxiv.org/abs/2412.03555)). We provide a C++ implementation of
+the PaliGemma 2 model here.

 To use the version of PaliGemma included in this repository, build the gemma
 binary as noted above in Step 3. Download the compressed weights and tokenizer

@@ -303,8 +254,7 @@ and run the binary as follows:
 ```sh
 ./gemma \
   --tokenizer paligemma_tokenizer.model \
-  --model paligemma-224 \
-  --weights paligemma-3b-mix-224-sfp.sbs \
+  --weights paligemma2-3b-mix-224-sfp.sbs \
   --image_file paligemma/testdata/image.ppm
 ```

@@ -346,10 +296,10 @@ from the multi-file format to the single-file format is available.
 ```sh
 io/migrate_weights \
   --tokenizer .../tokenizer.spm --weights .../gemma2-2b-it-sfp.sbs \
-  --model gemma2-2b-it --output_weights .../gemma2-2b-it-sfp-single.sbs
+  --output_weights .../gemma2-2b-it-sfp-single.sbs
 ```

-After migration, you can use the new weights file with gemma.cpp like this:
+After migration, you can omit the tokenizer argument like this:

 ```sh
 ./gemma --weights .../gemma2-2b-it-sfp-single.sbs
@@ -357,15 +307,6 @@ After migration, you can use the new weights file with gemma.cpp like this:

 ### Troubleshooting and FAQs

-**Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
-
-The most common problem is that the `--weight_type` argument does not match that
-of the model file. Revisit step #3 and check which weights you downloaded.
-
-Note that we have already moved weight type from a compile-time decision to a
-runtime argument. In a subsequent step, we plan to bake this information into
-the weights.
-
 **Problems building in Windows / Visual Studio**

 Currently if you're using Windows, we recommend building in WSL (Windows

@@ -376,8 +317,8 @@ configurations, see issues for active discussion.

 A common issue is that you are using a pre-trained model, which is not
 instruction-tuned and thus does not respond to instructions. Make sure you are
-using an instruction-tuned model (`2b-it-sfp`, `2b-it`, `7b-it-sfp`, `7b-it`)
-and not a pre-trained model (any model with a `-pt` suffix).
+using an instruction-tuned model (`gemma2-2b-it-sfp`) and not a pre-trained
+model (any model with a `-pt` suffix).

 **What sequence lengths are supported?**

@@ -387,11 +328,10 @@ sequences will be slow due to the quadratic cost of attention.

 **How do I convert my fine-tune to a `.sbs` compressed model file?**

-For PaliGemma (1 and 2) checkpoints, you can use
-python/convert_from_safetensors.py to convert from safetensors format (tested
-with building via bazel). For an adapter model, you will likely need to call
-merge_and_unload() to convert the adapter model to a single-file format before
-converting it.
+For PaliGemma 2 checkpoints, you can use python/convert_from_safetensors.py to
+convert from safetensors format (tested with building via bazel). For an adapter
+model, you will likely need to call merge_and_unload() to convert the adapter
+model to a single-file format before converting it.

 Here is how to use it using a bazel build of the compression library assuming
 locally installed (venv) torch, numpy, safetensors, absl-py, etc.:

@@ -405,22 +345,18 @@ ln -s $BAZEL_OUTPUT_DIR [...]/site-packages/compression
 python3 python/convert_from_safetensors.py --load_path [...].safetensors.index.json
 ```

-See also compression/convert_weights.py for a slightly older option to convert a
-pytorch checkpoint. (The code may need updates to work with Gemma-2 models.)
-
 **What are some easy ways to make the model run faster?**

 1. Make sure you are using the 8-bit switched floating point `-sfp` models.
    These are half the size of bf16 and thus use less memory bandwidth and cache
    space.
-2. If you're on a laptop, make sure power mode is set to maximize performance
+2. Due to auto-tuning, the second and especially third query will be faster.
+3. If you're on a laptop, make sure power mode is set to maximize performance
    and saving mode is **off**. For most laptops, the power saving modes get
    activated automatically if the computer is not plugged in.
-3. Close other unused cpu-intensive applications.
-4. On macs, anecdotally we observe a "warm-up" ramp-up in speed as performance
+4. Close other unused cpu-intensive applications.
+5. On macs, anecdotally we observe a "warm-up" ramp-up in speed as performance
    cores get engaged.
-5. Experiment with the `--num_threads` argument value. Depending on the device,
-   larger numbers don't always mean better performance.

 We're also working on algorithmic and optimization approaches for faster
 inference, stay tuned.
@@ -452,10 +388,7 @@ $ ./gemma [...]
 __/ | | | | |
 |___/ |_| |_|

-tokenizer : tokenizer.spm
-weights : 2b-it-sfp.sbs
-model : 2b-it
 max_generated_tokens : 2048
 ...

 *Usage*
 Enter an instruction and press enter (%C reset conversation, %Q quits).

@@ -493,7 +426,7 @@ For using the `gemma` executable as a command line tool, it may be useful to
 create an alias for gemma.cpp with arguments fully specified:

 ```sh
-alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --weights ~/gemma.cpp/build/gemma2-2b-it-sfp.sbs --model gemma2-2b-it --verbosity 0"
+alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --weights ~/gemma.cpp/build/gemma2-2b-it-sfp.sbs --verbosity 0"
 ```

 Replace the above paths with your own paths to the model and tokenizer paths

@@ -523,8 +456,8 @@ Let's break down the code:
 ### Incorporating gemma.cpp as a Library in your Project

 The easiest way to incorporate gemma.cpp in your own project is to pull in
-gemma.cpp and dependencies using `FetchContent`. You can add the following to your
-CMakeLists.txt:
+gemma.cpp and dependencies using `FetchContent`. You can add the following to
+your CMakeLists.txt:

 ```
 include(FetchContent)
@@ -593,9 +526,10 @@ submit a PR with a `README.md` edit.

 ## Acknowledgements and Contacts

-gemma.cpp was started in fall 2023 by [Austin Huang](mailto:austinvhuang@google.com)
-and [Jan Wassenberg](mailto:janwas@google.com), and subsequently released February 2024
-thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng.
+gemma.cpp was started in fall 2023 by
+[Austin Huang](mailto:austinvhuang@google.com) and
+[Jan Wassenberg](mailto:janwas@google.com), and subsequently released February
+2024 thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng.

 Griffin support was implemented in April 2024 thanks to contributions by Andrey
 Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode

@@ -137,57 +137,6 @@ static ModelConfig ConfigGemma2_2B() {
   return config;
 }

-static LayerConfig LayerConfigGemma7B(size_t model_dim) {
-  LayerConfig config;
-  config.model_dim = model_dim;
-  config.ff_hidden_dim = 16 * 3072 / 2;  // = 24576
-  config.heads = 16;
-  config.kv_heads = 16;
-  config.qkv_dim = 256;
-  return config;
-}
-
-static ModelConfig ConfigGemma7B() {
-  ModelConfig config = ConfigBaseGemmaV1();
-  config.display_name = "Gemma7B";
-  config.model = Model::GEMMA_7B;
-  config.model_dim = 3072;
-  config.vocab_size = kVocabSize;
-  config.seq_len = GEMMA_MAX_SEQLEN;
-  LayerConfig layer_config = LayerConfigGemma7B(config.model_dim);
-  config.num_layers = 28;
-  config.layer_configs = {config.num_layers, layer_config};
-  config.query_scale = QueryScaleType::SqrtKeySize;
-  config.attention_window_sizes =
-      FixedAttentionWindowSizes<28>(GEMMA_MAX_SEQLEN);
-  return config;
-}
-
-static LayerConfig LayerConfigGemma2B(size_t model_dim) {
-  LayerConfig config;
-  config.model_dim = model_dim;
-  config.ff_hidden_dim = 16 * 2048 / 2;  // = 16384
-  config.heads = 8;
-  config.kv_heads = 1;
-  config.qkv_dim = 256;
-  return config;
-}
-
-static ModelConfig ConfigGemma2B() {
-  ModelConfig config = ConfigBaseGemmaV1();
-  config.display_name = "Gemma2B";
-  config.model = Model::GEMMA_2B;
-  config.model_dim = 2048;
-  config.vocab_size = kVocabSize;
-  config.seq_len = GEMMA_MAX_SEQLEN;
-  LayerConfig layer_config = LayerConfigGemma2B(config.model_dim);
-  config.num_layers = 18;
-  config.layer_configs = {config.num_layers, layer_config};
-  config.attention_window_sizes =
-      FixedAttentionWindowSizes<18>(GEMMA_MAX_SEQLEN);
-  return config;
-}
-
 static LayerConfig LayerConfigGemmaTiny(size_t model_dim) {
   LayerConfig config;
   config.model_dim = model_dim;

@@ -204,7 +153,7 @@ static ModelConfig ConfigGemmaTiny() {
   config.model = Model::GEMMA_TINY;
   config.wrapping = PromptWrapping::GEMMA_IT;
   config.model_dim = 32;
-  config.vocab_size = 16;
+  config.vocab_size = 32;  // at least two f32 vectors
   config.seq_len = 32;  // optimize_test requires more than 24
   LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim);
   config.num_layers = 2;

@@ -290,24 +239,6 @@ static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
   config.vit_config.num_scales = 4 * config.vit_config.layer_configs.size();
 }

-static ModelConfig ConfigPaliGemma_224() {
-  ModelConfig config = ConfigGemma2B();
-  config.display_name = "PaliGemma_224";
-  config.model = Model::PALIGEMMA_224;
-  config.wrapping = PromptWrapping::PALIGEMMA;
-  AddVitConfig(config);
-  return config;
-}
-
-static ModelConfig ConfigPaliGemma_448() {
-  ModelConfig config = ConfigGemma2B();
-  config.display_name = "PaliGemma_448";
-  config.model = Model::PALIGEMMA_448;
-  config.wrapping = PromptWrapping::PALIGEMMA;
-  AddVitConfig(config, /*image_size=*/448);
-  return config;
-}
-
 ModelConfig GetVitConfig(const ModelConfig& config) {
   ModelConfig vit_config = ConfigNoSSM();
   vit_config.model_dim = config.vit_config.model_dim;
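The removed factories above show the composition pattern for the vision-language configs: a Gemma text config plus `AddVitConfig`. As a rough sketch only (the surviving PaliGemma 2 factories are outside this hunk, so the function name below and the choice of `ConfigGemma2_2B` as the text backbone are assumptions), the remaining configs presumably follow the same shape:

```cpp
// Illustrative sketch, not the actual configs.cc code: assumes the PaliGemma 2
// factories mirror the removed PaliGemma 1 ones, i.e. a Gemma text config with
// a ViT config attached.
static ModelConfig ConfigPaliGemma2_3B_224_Sketch() {
  ModelConfig config = ConfigGemma2_2B();    // assumed text backbone
  config.display_name = "PaliGemma2_3B_224";
  config.model = Model::PALIGEMMA2_3B_224;
  config.wrapping = PromptWrapping::PALIGEMMA;
  AddVitConfig(config, /*image_size=*/224);  // attach the 224px ViT encoder
  return config;
}
```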
@@ -547,10 +478,6 @@ static ModelConfig ConfigGemma3_27B() {

 static ModelConfig ConfigFromModel(Model model) {
   switch (model) {
-    case Model::GEMMA_2B:
-      return ConfigGemma2B();
-    case Model::GEMMA_7B:
-      return ConfigGemma7B();
     case Model::GEMMA2_2B:
       return ConfigGemma2_2B();
     case Model::GEMMA2_9B:

@@ -561,10 +488,6 @@ static ModelConfig ConfigFromModel(Model model) {
       return ConfigGriffin2B();
     case Model::GEMMA_TINY:
       return ConfigGemmaTiny();
-    case Model::PALIGEMMA_224:
-      return ConfigPaliGemma_224();
-    case Model::PALIGEMMA_448:
-      return ConfigPaliGemma_448();
     case Model::PALIGEMMA2_3B_224:
       return ConfigPaliGemma2_3B_224();
     case Model::PALIGEMMA2_3B_448:

@@ -590,10 +513,6 @@ const char* ModelPrefix(Model model) {
   switch (model) {
     case Model::UNKNOWN:
       return "unknown";
-    case Model::GEMMA_2B:
-      return "2b";
-    case Model::GEMMA_7B:
-      return "7b";
     case Model::GEMMA2_2B:
       return "gemma2-2b";
     case Model::GEMMA2_9B:

@@ -604,10 +523,6 @@ const char* ModelPrefix(Model model) {
       return "gr2b";
     case Model::GEMMA_TINY:
       return "tiny";
-    case Model::PALIGEMMA_224:
-      return "paligemma-224";
-    case Model::PALIGEMMA_448:
-      return "paligemma-448";
     case Model::PALIGEMMA2_3B_224:
       return "paligemma2-3b-224";
     case Model::PALIGEMMA2_3B_448:
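With the Gemma 1 and PaliGemma 1 cases removed, `ModelPrefix` only names the remaining models, and the README now points at it for the weights-naming scheme. A minimal sketch of mapping a prefix string back to a `Model`, using only `ForEachModel` and `ModelPrefix` from this diff; the helper itself is hypothetical (not part of configs.cc) and assumes the gemma.cpp config declarations are in scope:

```cpp
#include <string>

// Hypothetical helper: find the Model whose ModelPrefix matches a name such as
// "gemma2-2b" or "paligemma2-3b-224"; returns UNKNOWN if nothing matches.
Model ModelFromPrefix(const std::string& prefix) {
  Model found = Model::UNKNOWN;
  ForEachModel([&](Model m) {
    if (prefix == ModelPrefix(m)) found = m;
  });
  return found;
}
```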
@@ -802,16 +717,12 @@ bool ModelConfig::OverwriteWithCanonical() {
 Model DeduceModel(size_t layers, int layer_types) {
   switch (layers) {
     case 3:
     case 2:
       return Model::GEMMA_TINY;
-    case 18:
-      return Model::GEMMA_2B;
     case 26:
       if (layer_types & kDeducedGriffin) return Model::GRIFFIN_2B;
       if (layer_types & kDeducedViT) return Model::GEMMA3_1B;
       return Model::GEMMA2_2B;
-    case 28:
-      return Model::GEMMA_7B;
     case 34:
       return Model::GEMMA3_4B;
     case 42:
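`DeduceModel` is how untagged weights files are identified now that the README no longer passes `--model`: the layer count (plus layer-type flags) selects the config, and the Gemma 1 layer counts (18 and 28) no longer resolve. A small illustration of the `case 26` branch shown above; `kDeducedGriffin` is the flag referenced in the hunk, its numeric value is defined elsewhere in the file:

```cpp
// Illustration of the case 26 branch above: 26 layers alone deduce to
// GEMMA2_2B, while 26 layers plus a Griffin-type layer deduce to GRIFFIN_2B.
const Model plain = DeduceModel(/*layers=*/26, /*layer_types=*/0);
const Model griffin = DeduceModel(/*layers=*/26, /*layer_types=*/kDeducedGriffin);
// plain == Model::GEMMA2_2B, griffin == Model::GRIFFIN_2B
```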
@@ -157,17 +157,15 @@ std::vector<uint32_t> RepeatedAttentionWindowSizes(

 // Model variants: see configs.cc for details.
 enum class Model {
-  UNKNOWN,
-  GEMMA_2B,
-  GEMMA_7B,
-  GEMMA2_9B,
+  UNKNOWN = 0,
+  // 1 and 2 are obsolete.
+  GEMMA2_9B = 3,
   GEMMA2_27B,
   GRIFFIN_2B,
   GEMMA_TINY,  // for backprop/ only
   GEMMA2_2B,
-  PALIGEMMA_224,
-  PALIGEMMA_448,
-  PALIGEMMA2_3B_224,
+  // 8 and 9 are obsolete.
+  PALIGEMMA2_3B_224 = 10,
   PALIGEMMA2_3B_448,
   PALIGEMMA2_10B_224,
   PALIGEMMA2_10B_448,
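The explicit enumerator values are the point of this hunk: the obsolete Gemma 1 slots (1, 2) and PaliGemma 1 slots (8, 9) stay reserved rather than letting later entries shift, which presumably keeps the numeric values already written into serialized weights files meaningful (see the `EnumValid` comment below). A small sanity-check sketch of the numbering this implies:

```cpp
// Numbering implied by the enum above: UNKNOWN=0, slots 1-2 reserved (Gemma 1),
// GEMMA2_9B=3 ... GEMMA2_2B=7, slots 8-9 reserved (PaliGemma 1), then
// PALIGEMMA2_3B_224=10.
static_assert(static_cast<int>(Model::GEMMA2_9B) == 3);
static_assert(static_cast<int>(Model::GEMMA2_2B) == 7);
static_assert(static_cast<int>(Model::PALIGEMMA2_3B_224) == 10);
```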
@@ -190,8 +188,7 @@ static inline bool IsVLM(Model model) {
 }

 static inline bool IsPaliGemma(Model model) {
-  if (model == Model::PALIGEMMA_224 || model == Model::PALIGEMMA_448 ||
-      model == Model::PALIGEMMA2_3B_224 || model == Model::PALIGEMMA2_3B_448 ||
+  if (model == Model::PALIGEMMA2_3B_224 || model == Model::PALIGEMMA2_3B_448 ||
       model == Model::PALIGEMMA2_10B_224 ||
       model == Model::PALIGEMMA2_10B_448) {
     return true;
@@ -202,15 +199,19 @@ static inline bool IsPaliGemma(Model model) {
 // Visits every valid model enum, skipping `UNKNOWN` and `kSentinel`.
 template <class Func>
 void ForEachModel(const Func& func) {
-  for (size_t i = static_cast<size_t>(Model::UNKNOWN) + 1;
+  for (size_t i = static_cast<size_t>(Model::GEMMA2_9B);
        i < static_cast<size_t>(Model::kSentinel); ++i) {
+    if (i == 8 || i == 9) continue;
     func(static_cast<Model>(i));
   }
 }

 static inline bool EnumValid(Model model) {
+  // Valid for purposes of serialization, even if unknown.
+  if (model == Model::UNKNOWN) return true;
   const size_t i = static_cast<size_t>(model);
-  if (i < static_cast<size_t>(Model::kSentinel)) {
+  if (i >= static_cast<size_t>(Model::GEMMA2_9B) &&
+      i < static_cast<size_t>(Model::kSentinel) && i != 8 && i != 9) {
     return true;
   }
   return false;
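A minimal usage sketch of the updated iteration, assuming the gemma.cpp config declarations are included: `ForEachModel` now starts at `GEMMA2_9B` (value 3) and skips the reserved slots 8 and 9, so every model handed to the callback also passes `EnumValid`:

```cpp
#include <cstdio>

// Sketch: print the prefix of every model ForEachModel still visits.
void PrintRemainingModels() {
  ForEachModel([](Model m) {
    std::printf("%s\n", ModelPrefix(m));  // e.g. "gemma2-2b", "paligemma2-3b-224"
  });
}
```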
@@ -48,7 +48,7 @@ class PaliGemmaTest : public ::testing::Test {
 void PaliGemmaTest::InitVit(const std::string& path) {
   ASSERT_NE(s_env->GetGemma(), nullptr);
   const Allocator& allocator = s_env->Env().ctx.allocator;
-  Gemma& gemma = *(s_env->GetGemma());
+  const Gemma& gemma = *(s_env->GetGemma());
   image_tokens_ = ImageTokens(
       allocator, Extents2D(gemma.GetModelConfig().vit_config.seq_len,
                            gemma.GetModelConfig().model_dim));

@@ -62,7 +62,7 @@ void PaliGemmaTest::InitVit(const std::string& path) {
 }

 std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
-  Gemma& model = *(s_env->GetGemma());
+  const Gemma& model = *(s_env->GetGemma());
   s_env->MutableGen().seed(0x12345678);
   RuntimeConfig runtime_config = {.max_generated_tokens = 512,
                                   .gen = &s_env->MutableGen(),
@@ -103,17 +103,6 @@ void PaliGemmaTest::TestQuestions(const char* kQA[][2], size_t num_questions) {

 TEST_F(PaliGemmaTest, General) {
   ASSERT_NE(s_env->GetGemma(), nullptr);
-  static const char* kQA_3B_mix_224[][2] = {
-      {"describe this image",
-       "A large building with two towers stands tall on the water's edge."},
-      {"describe image briefly",
-       "A large building with two towers in the middle of a city."},
-      {"What kind of building is it?", "church"},
-      {"How many towers does the church have?", "2"},
-      {"detect water", "<loc1022> water"},
-      {"segment water", "<seg010> water"},
-      {"Which city is this more likely? Tokio or Zurich?", "zurich"},
-  };
   static const char* kQA_2_3B_pt_448[][2] = {
       {"describe this image", "The Grossmünster in Zürich"},
       {"describe image briefly", "The Grossmünster"},

@@ -123,10 +112,6 @@ TEST_F(PaliGemmaTest, General) {
   const char* (*qa)[2];
   size_t num;
   switch (s_env->GetGemma()->GetModelConfig().model) {
-    case Model::PALIGEMMA_224:
-      qa = kQA_3B_mix_224;
-      num = sizeof(kQA_3B_mix_224) / sizeof(kQA_3B_mix_224[0]);
-      break;
     case Model::PALIGEMMA2_3B_448:
       qa = kQA_2_3B_pt_448;
       num = sizeof(kQA_2_3B_pt_448) / sizeof(kQA_2_3B_pt_448[0]);
@@ -85,8 +85,6 @@ PYBIND11_MODULE(configs, py_module) {

   enum_<Model>(py_module, "Model")
       .value("UNKNOWN", Model::UNKNOWN)
-      .value("GEMMA_2B", Model::GEMMA_2B)
-      .value("GEMMA_7B", Model::GEMMA_7B)
       .value("GEMMA2_9B", Model::GEMMA2_9B)
       .value("GEMMA2_27B", Model::GEMMA2_27B)
       .value("GRIFFIN_2B", Model::GRIFFIN_2B)

@@ -96,7 +94,6 @@ PYBIND11_MODULE(configs, py_module) {
       .value("PALIGEMMA2_10B_224", Model::PALIGEMMA2_10B_224)
       .value("PALIGEMMA2_3B_448", Model::PALIGEMMA2_3B_448)
       .value("PALIGEMMA2_10B_448", Model::PALIGEMMA2_10B_448)
-      .value("PALIGEMMA_224", Model::PALIGEMMA_224)
-      .value("PALIGEMMA_448", Model::PALIGEMMA_448);

   class_<TensorInfo>(py_module, "TensorInfo")