Add support for PaliGemma Vision-LM (224x224) to gemma.cpp

See https://arxiv.org/abs/2407.07726 for a description of the model.
Because PaliGemma operates as a prefix-LM on the image+prompt, this change also adds support for prefix-LM style attention.

PiperOrigin-RevId: 677841119
Daniel Keysers 2024-09-23 10:09:10 -07:00 committed by Copybara-Service
parent c6c10e0a53
commit f8835fe4a4
24 changed files with 1630 additions and 164 deletions


@ -242,6 +242,9 @@ cc_library(
"gemma/instantiations/gemma2_2b_bf16.cc",
"gemma/instantiations/gemma2_2b_f32.cc",
"gemma/instantiations/gemma2_2b_sfp.cc",
"gemma/instantiations/paligemma_224_bf16.cc",
"gemma/instantiations/paligemma_224_f32.cc",
"gemma/instantiations/paligemma_224_sfp.cc",
],
hdrs = [
"gemma/activations.h",
@ -264,6 +267,7 @@ cc_library(
":weights",
":threading",
"//compression:io",
"//paligemma:image",
"@hwy//:hwy",
"@hwy//:bit_set",
"@hwy//:matvec",
@ -361,6 +365,7 @@ cc_binary(
":gemma_lib",
":threading",
# Placeholder for internal dep, do not remove.,
"//paligemma:image",
"@hwy//:hwy",
"@hwy//:profiler",
"@hwy//:thread_pool",


@ -93,6 +93,9 @@ set(SOURCES
gemma/instantiations/gemma2_2b_bf16.cc
gemma/instantiations/gemma2_2b_f32.cc
gemma/instantiations/gemma2_2b_sfp.cc
gemma/instantiations/paligemma_224_bf16.cc
gemma/instantiations/paligemma_224_f32.cc
gemma/instantiations/paligemma_224_sfp.cc
gemma/kv_cache.cc
gemma/kv_cache.h
gemma/tokenizer.cc
@ -103,6 +106,8 @@ set(SOURCES
ops/matmul-inl.h
ops/matvec-inl.h
ops/ops-inl.h
paligemma/image.cc
paligemma/image.h
util/allocator.h
util/app.h
util/args.h
@ -164,6 +169,8 @@ set(GEMMA_TEST_FILES
ops/matmul_test.cc
ops/gemma_matvec_test.cc
evals/gemma_test.cc
paligemma/image_test.cc
paligemma/paligemma_test.cc
)
foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)

README.md

@ -23,7 +23,7 @@ deployment-oriented C++ inference runtimes, which are not designed for
experimentation, and Python-centric ML research frameworks, which abstract away
low-level computation through compilation.
gemma.cpp provides a minimalist implementation of Gemma 2B and 7B models,
gemma.cpp provides a minimalist implementation of Gemma-1 and Gemma-2 models,
focusing on simplicity and directness rather than full generality. This is
inspired by vertically-integrated model implementations such as
[ggml](https://github.com/ggerganov/ggml),
@ -78,17 +78,20 @@ winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "-
### Step 1: Obtain model weights and tokenizer from Kaggle or Hugging Face Hub
Visit the
[Kaggle page for Gemma](https://www.kaggle.com/models/google/gemma/frameworks/gemmaCpp),
or [Gemma-2](https://www.kaggle.com/models/google/gemma-2/gemmaCpp), and select
`Model Variations |> Gemma C++`.
[Kaggle page for Gemma-2](https://www.kaggle.com/models/google/gemma-2/gemmaCpp)
[or Gemma-1](https://www.kaggle.com/models/google/gemma/frameworks/gemmaCpp),
and select `Model Variations |> Gemma C++`.
On this tab, the `Variation` dropdown includes the options below. Note that bfloat16
weights are higher fidelity, while 8-bit switched floating point weights enable
faster inference. In general, we recommend starting with the `-sfp` checkpoints.
If you are unsure which model to start with, we recommend the smallest
Gemma-2 model, i.e. `2.0-2b-it-sfp`.
Alternatively, visit the
[gemma.cpp](https://huggingface.co/models?other=gemma.cpp) models on the Hugging
Face Hub. First go the the model repository of the model of interest (see
Face Hub. First go to the model repository of the model of interest (see
recommendations below). Then, click the `Files and versions` tab and download
the model and tokenizer files. For programmatic downloading, if you have
`huggingface_hub` installed, you can also download by running:
@ -98,7 +101,7 @@ huggingface-cli login # Just the first time
huggingface-cli download google/gemma-2b-sfp-cpp --local-dir build/
```
2B instruction-tuned (`it`) and pre-trained (`pt`) models:
Gemma-1 2B instruction-tuned (`it`) and pre-trained (`pt`) models:
| Model name | Description |
| ----------- | ----------- |
@ -107,7 +110,7 @@ huggingface-cli download google/gemma-2b-sfp-cpp --local-dir build/
| `2b-pt` | 2 billion parameter pre-trained model, bfloat16 |
| `2b-pt-sfp` | 2 billion parameter pre-trained model, 8-bit switched floating point |
7B instruction-tuned (`it`) and pre-trained (`pt`) models:
Gemma-1 7B instruction-tuned (`it`) and pre-trained (`pt`) models:
| Model name | Description |
| ----------- | ----------- |
@ -256,6 +259,53 @@ Step 1, and run the binary as follows:
`./gemma --tokenizer tokenizer.spm --model gr2b-it --weights 2b-it-sfp.sbs`
### PaliGemma Vision-Language Model
This repository includes a C++ implementation of the PaliGemma VLM
([paper](https://arxiv.org/abs/2407.07726),
[code](https://github.com/google-research/big_vision/tree/main/big_vision/configs/proj/paligemma)).
To use the version of PaliGemma included in this repository, build the gemma
binary as noted above in Step 3. Download the compressed weights and tokenizer
from //TODO(keysers) - update location// and run the binary as follows:
```sh
./gemma \
  --tokenizer paligemma_tokenizer.model \
  --model paligemma-224 \
  --weights paligemma-3b-mix-224-sfp.sbs \
  --image_file paligemma/testdata/image.ppm
```
Note that the image-reading code is intentionally basic, to avoid depending on
an image-processing library for now; we currently only support binary PPM (P6).
So use a tool like `convert` to first convert your images into that format, e.g.
`convert image.jpeg -resize 224x224^ image.ppm`
(As the image will be resized for processing anyway, resizing at this stage
makes loading slightly faster.)
The interaction with the image (using the mix-224 checkpoint) may then look
something like this:
```
> Describe the image briefly
A large building with two towers in the middle of a city.
> What type of building is it?
church
> What color is the church?
gray
> caption image
A large building with two towers stands tall on the water's edge. The building
has a brown roof and a window on the side. A tree stands in front of the
building, and a flag waves proudly from its top. The water is calm and blue,
reflecting the sky above. A bridge crosses the water, and a red and white boat
rests on its surface. The building has a window on the side, and a flag on top.
A tall tree stands in front of the building, and a window on the building is
visible from the water. The water is green, and the sky is blue.
```
### Troubleshooting and FAQs
**Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
@ -283,8 +333,8 @@ and not a pre-trained model (any model with a `-pt` suffix).
**How do I convert my fine-tune to a `.sbs` compressed model file?**
We're working on a python script to convert a standard model format to `.sbs`,
and hope have it available in the next week or so. Follow [this
issue](https://github.com/google/gemma.cpp/issues/11) for updates.
and hope to have it available soon. Follow
[this issue](https://github.com/google/gemma.cpp/issues/11) for updates.
**What are some easy ways to make the model run faster?**
@ -371,7 +421,7 @@ For using the `gemma` executable as a command line tool, it may be useful to
create an alias for gemma.cpp with arguments fully specified:
```sh
alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --weights ~/gemma.cpp/build/2b-it-sfp.sbs --model 2b-it --verbosity 0"
alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --weights ~/gemma.cpp/build/gemma2-2b-it-sfp.sbs --model gemma2-2b-it --verbosity 0"
```
Replace the above paths with your own paths to the model and tokenizer paths
@ -381,7 +431,7 @@ Here is an example of prompting `gemma` with a truncated input
file (using a `gemma2b` alias like defined above):
```sh
cat configs.h | tail -35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b
cat configs.h | tail -n 35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b
```
> [!NOTE]
@ -391,27 +441,11 @@ cat configs.h | tail -35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code
The output of the above command should look like:
```console
$ cat configs.h | tail -35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b
[ Reading prompt ] ......................................................................................................................................................................................................................................................................................................................................................................................................................................................................................
The code defines two C++ structs, `ConfigGemma7B` and `ConfigGemma2B`, which are used for configuring a deep learning model.
[ Reading prompt ] [...]
This C++ code snippet defines a set of **constants** used in a large language model (LLM) implementation, likely related to the **attention mechanism**.
**ConfigGemma7B**:
* `kSeqLen`: Stores the length of the sequence to be processed. It's set to 7168.
* `kVocabSize`: Stores the size of the vocabulary, which is 256128.
* `kLayers`: Number of layers in the deep learning model. It's set to 28.
* `kModelDim`: Dimension of the model's internal representation. It's set to 3072.
* `kFFHiddenDim`: Dimension of the feedforward and recurrent layers' hidden representations. It's set to 16 * 3072 / 2.
**ConfigGemma2B**:
* `kSeqLen`: Stores the length of the sequence to be processed. It's also set to 7168.
* `kVocabSize`: Size of the vocabulary, which is 256128.
* `kLayers`: Number of layers in the deep learning model. It's set to 18.
* `kModelDim`: Dimension of the model's internal representation. It's set to 2048.
* `kFFHiddenDim`: Dimension of the feedforward and recurrent layers' hidden representations. It's set to 16 * 2048 / 2.
These structs are used to configure a deep learning model with specific parameters for either Gemma7B or Gemma2B architecture.
Let's break down the code:
[...]
```
### Incorporating gemma.cpp as a Library in your Project
@ -496,4 +530,13 @@ Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode
Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas
Fischbacher and Zoltan Szabadka.
Gemma-2 support was implemented in June/July 2024 with the help of several
people.
PaliGemma support was implemented in September 2024 with contributions from
Daniel Keysers.
[Jan Wassenberg](mailto:janwas@google.com) has continued to contribute many
improvements, including major gains in efficiency, since the initial release.
This is not an officially supported Google product.


@ -322,6 +322,9 @@ namespace gcpp {
void Run(Args& args) {
hwy::ThreadPool pool(args.num_threads);
const Model model_type = args.ModelType();
if (model_type == Model::PALIGEMMA_224) {
HWY_ABORT("PaliGemma is not supported in compress_weights.");
}
const Type weight_type = args.WeightType();
GEMMA_EXPORT_AND_DISPATCH(
model_type, weight_type, CompressWeights,


@ -97,7 +97,9 @@ struct Activations {
x = RowVectorBatch<float>(batch_size, kModelDim);
q = RowVectorBatch<float>(batch_size, kHeads * QStride<TConfig>());
logits = RowVectorBatch<float>(batch_size, kVocabSize);
if constexpr (kVocabSize > 0) {
logits = RowVectorBatch<float>(batch_size, kVocabSize);
}
pre_att_rms_out = RowVectorBatch<float>(batch_size, kModelDim);
att = RowVectorBatch<float>(batch_size, kHeads * kSeqLen);
@ -109,7 +111,7 @@ struct Activations {
C2 = RowVectorBatch<float>(batch_size, kFFHiddenDim);
ffw_out = RowVectorBatch<float>(batch_size, kModelDim);
if (kGriffinLayers > 0) {
if constexpr (kGriffinLayers > 0) {
griffin_x = RowVectorBatch<float>(batch_size, kModelDim);
griffin_y = RowVectorBatch<float>(batch_size, kModelDim);
griffin_gate_x = RowVectorBatch<float>(batch_size, kModelDim);


@ -36,6 +36,7 @@ constexpr const char* kModelFlags[] = {
"gemma2-2b-pt", "gemma2-2b-it", // Gemma2 2B
"9b-pt", "9b-it", // Gemma2 9B
"27b-pt", "27b-it", // Gemma2 27B
"paligemma-224", // PaliGemma 224
};
constexpr Model kModelTypes[] = {
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
@ -43,8 +44,9 @@ constexpr Model kModelTypes[] = {
Model::GRIFFIN_2B, Model::GRIFFIN_2B, // RecurrentGemma
Model::GEMMA_TINY, // Gemma Tiny
Model::GEMMA2_2B, Model::GEMMA2_2B, // Gemma2 2B
Model::GEMMA2_9B, Model::GEMMA2_9B, // Gemma2 9B
Model::GEMMA2_27B, Model::GEMMA2_27B, // Gemma2 27B
Model::GEMMA2_9B, Model::GEMMA2_9B, // Gemma2 9B
Model::GEMMA2_27B, Model::GEMMA2_27B, // Gemma2 27B
Model::PALIGEMMA_224, // PaliGemma 224
};
constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B
@ -54,6 +56,7 @@ constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 2B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 9B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 27B
ModelTraining::PALIGEMMA, // PaliGemma 224
};
constexpr size_t kNumModelFlags =


@ -37,10 +37,11 @@ enum class Model {
GRIFFIN_2B,
GEMMA_TINY,
GEMMA2_2B,
PALIGEMMA_224,
};
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
enum class ModelTraining { GEMMA_IT, GEMMA_PT, PALIGEMMA };
// Tensor types for loading weights. When adding a new one, also
// update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc.
@ -93,7 +94,9 @@ decltype(auto) CallForModel(Model model, TArgs&&... args) {
return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA2_2B:
return FuncT<ConfigGemma2_2B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::PALIGEMMA_224:
return FuncT<ConfigPaliGemma_224<TWeight>>()(
std::forward<TArgs>(args)...);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
@ -136,8 +139,9 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
GEMMA_FOREACH_WEIGHT(X, ConfigGemma7B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGriffin2B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_2B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_9B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_27B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_9B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_27B) \
GEMMA_FOREACH_WEIGHT(X, ConfigPaliGemma_224) \
static_assert(true, "Allow trailing ;")
// Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float),
@ -179,6 +183,11 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
ARGS; \
break; \
} \
case Model::PALIGEMMA_224: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigPaliGemma_224<TWEIGHT>>)\
ARGS; \
break; \
} \
default: \
HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
}


@ -44,6 +44,7 @@ using EmbedderInputT = hwy::bfloat16_t;
enum class LayerAttentionType {
kGemma,
kGriffinRecurrentBlock,
kVit,
};
// Post attention and ffw normalization type.
@ -131,7 +132,22 @@ struct CachePosSize {
}
};
struct ConfigNoSSM {
struct ConfigNoVit {
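// Empty ViT sub-config: with kLayers == 0, GenerateImageTokensT can discard
// the ViT code path via `if constexpr` for text-only models.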
struct VitConfig {
static constexpr int kLayers = 0;
static constexpr std::array<LayerAttentionType, 0> kLayerConfig =
FixedLayerConfig<0>(LayerAttentionType::kVit);
static constexpr int kModelDim = 0;
static constexpr int kFFHiddenDim = 0;
static constexpr int kHeads = 0;
static constexpr int kKVHeads = 0;
static constexpr int kQKVDim = 0;
static constexpr int kSeqLen = 0;
static constexpr ResidualType kResidual = ResidualType::Add;
};
};
struct ConfigNoSSM : ConfigNoVit {
static constexpr int kGriffinLayers = 0;
static constexpr int kConv1dWidth = 0;
@ -247,6 +263,37 @@ struct ConfigGemma2B : public ConfigBaseGemmaV1 {
static constexpr bool kAbsolutePE = false;
};
template <typename TWeight>
struct ConfigPaliGemma_224 : public ConfigGemma2B<TWeight> {
// On the LM side, the vocab size is the one architectural difference from
// Gemma1-2B: PaliGemma adds 1024 <locNNNN> and 128 <segNNN> tokens.
static constexpr int kVocabSize = 256000 + 1024 + 128; // = 257152
// Sub-config for the Vision-Transformer part.
struct VitConfig : public ConfigNoSSM {
using Weight = TWeight;
// The ViT parts. https://arxiv.org/abs/2305.13035
// "SoViT-400m/14 [...] has a width of 1152, depth 27, and MLP dim 4304."
static constexpr std::array<LayerAttentionType, 27> kLayerConfig =
FixedLayerConfig<27>(LayerAttentionType::kVit);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kModelDim = 1152;
static constexpr int kFFHiddenDim = 4304;
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 72;
static constexpr int kSeqLen = 16 * 16; // 256
static constexpr bool kFFBiases = true;
// The ViT part does not have a vocabulary; the image patches are embedded.
static constexpr int kVocabSize = 0;
// Dimensions related to image processing.
static constexpr int kPatchWidth = 14;
static constexpr int kImageSize = 224;
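// 224 / 14 = 16 patches per side, so 16 * 16 = 256 image tokens = kSeqLen.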
// Necessary constant for the layer configuration.
static constexpr PostNormType kPostNorm = PostNormType::None;
};
};
template <typename TWeight>
struct ConfigGemma2_2B : public ConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig
@ -297,7 +344,7 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
};
template <typename TWeight>
struct ConfigGriffin2B {
struct ConfigGriffin2B : ConfigNoVit {
using Weight = TWeight; // make accessible where we only have a TConfig
// Griffin uses local attention, so kSeqLen is actually the local attention


@ -41,6 +41,7 @@
#include "ops/matmul-inl.h"
#include "ops/matvec-inl.h"
#include "ops/ops-inl.h"
#include "paligemma/image.h"
#include "util/allocator.h"
#include "util/threading.h"
#include "hwy/aligned_allocator.h"
@ -209,7 +210,7 @@ class GemmaAttention {
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr bool kIsMHA = Activations::IsMHA<TConfig>();
// The attention window usually starts at 0 unless unless `pos` is larger than
// The attention window usually starts at 0 unless `pos` is larger than
// the attention window size, then it is `pos` - window_size + 1.
static HWY_INLINE size_t StartPos(size_t pos, size_t layer) {
const size_t att_window_size = TConfig::kAttentionWindowSizes[layer];
@ -318,26 +319,26 @@ class GemmaAttention {
}
// Computes Q.K scores, which are "logits" (or scores) stored to head_att.
HWY_INLINE void QDotK(const size_t start_pos, const size_t pos,
HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
const size_t head_offset, const float* HWY_RESTRICT q,
const KVCache& kv_cache, float* HWY_RESTRICT head_att) {
if (HWY_LIKELY(pos < kSeqLen)) {
if (HWY_LIKELY(last_pos < kSeqLen)) {
// Slightly faster: no wraparound.
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const size_t kv_offset =
pos2 * kCachePosSize + layer_ * kCacheLayerSize + head_offset;
pos * kCachePosSize + layer_ * kCacheLayerSize + head_offset;
const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
const float score = Dot(q, k, kQKVDim);
head_att[pos2] = score;
head_att[pos] = score;
}
} else {
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
const size_t cache_pos = div_seq_len_.Remainder(pos2);
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const size_t cache_pos = div_seq_len_.Remainder(pos);
const size_t kv_offset =
cache_pos * kCachePosSize + layer_ * kCacheLayerSize + head_offset;
const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
const float score = Dot(q, k, kQKVDim);
head_att[pos2 % kSeqLen] = score;
head_att[pos % kSeqLen] = score;
}
}
}
@ -345,32 +346,30 @@ class GemmaAttention {
// Accumulates the sum of v (from `kv_cache`) * probability (`head_att`) into
// `att_out`. Equivalent in gemma/modules.py:
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
static HWY_INLINE void WeightedSumV(const size_t start_pos, const size_t pos,
const float* HWY_RESTRICT head_att,
const size_t layer,
const size_t head_offset,
const hwy::Divisor& div_seq_len,
const KVCache& kv_cache,
float* HWY_RESTRICT att_out) {
static HWY_INLINE void WeightedSumV(
const size_t start_pos, const size_t last_pos,
const float* HWY_RESTRICT head_att, const size_t layer,
const size_t head_offset, const hwy::Divisor& div_seq_len,
const KVCache& kv_cache, float* HWY_RESTRICT att_out) {
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
if (HWY_LIKELY(pos < kSeqLen)) {
if (HWY_LIKELY(last_pos < kSeqLen)) {
// Slightly faster: no wraparound.
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const size_t kv_offset =
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
const float* HWY_RESTRICT v =
kv_cache.kv_cache.get() + kv_offset + kQKVDim;
MulByConstAndAdd(head_att[pos2], v, att_out, kQKVDim);
MulByConstAndAdd(head_att[pos], v, att_out, kQKVDim);
}
} else {
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
const size_t cache_pos = div_seq_len.Remainder(pos2);
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const size_t cache_pos = div_seq_len.Remainder(pos);
const size_t kv_offset =
cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
const float* HWY_RESTRICT v =
kv_cache.kv_cache.get() + kv_offset + kQKVDim;
MulByConstAndAdd(head_att[pos2 % kSeqLen], v, att_out, kQKVDim);
MulByConstAndAdd(head_att[pos % kSeqLen], v, att_out, kQKVDim);
}
}
}
@ -402,20 +401,26 @@ class GemmaAttention {
PositionalEncodingQK(q, pos, layer_, kQueryScale, q);
const size_t start_pos = StartPos(pos, layer_);
size_t last_pos = pos;
const size_t prefix_end = queries_prefix_end_[query_idx];
if (prefix_end > 0 && prefix_end - 1 > last_pos) {
// last_pos in QDotK and WeightedSumV is inclusive.
last_pos = prefix_end - 1;
}
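// Example with hypothetical values: prefix_end = 4 and pos = 1 give
// last_pos = 3, so this query attends (bidirectionally) to the whole
// prefix [0, 3] instead of only causally up to pos = 1.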
float* HWY_RESTRICT head_att =
activations_.att.Batch(interleaved_idx) + head * kSeqLen;
QDotK(start_pos, pos, head_offset, q, kv_cache, head_att);
QDotK(start_pos, last_pos, head_offset, q, kv_cache, head_att);
// SoftMax with optional SoftCap yields "probabilities" in
// head_att.
const size_t head_att_len = std::min(pos + 1, kSeqLen);
const size_t head_att_len = std::min(last_pos + 1, kSeqLen);
MaybeLogitsSoftCap(TConfig::kAttCap, head_att, head_att_len);
Softmax(head_att, head_att_len);
float* HWY_RESTRICT att_out =
activations_.att_out.Batch(interleaved_idx) +
head * kQKVDim;
WeightedSumV(start_pos, pos, head_att, layer_, head_offset,
WeightedSumV(start_pos, last_pos, head_att, layer_, head_offset,
div_seq_len_, kv_cache, att_out);
});
}
@ -435,16 +440,40 @@ class GemmaAttention {
MatMul<kAdd>(
num_interleaved, ConstMat(activations_.att_out.All(), kHeads * kQKVDim),
ConstMat(layer_weights_.att_weights.data(), kHeads * kQKVDim),
layer_weights_.attn_vec_einsum_w.scale(), bias, activations_.env,
layer_weights_.att_weights.scale(), bias, activations_.env,
MutableMat(activations_.att_sums.All(), kModelDim));
}
public:
// Constructor with explicit initialization of queries_prefix_end. This is
// needed for the Prefix-LM style attention. For standard causal attention,
// the other constructor can be used.
GemmaAttention(const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, size_t num_tokens,
size_t layer, Activations& activations,
const CompressedLayer<TConfig>* layer_weights,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
: queries_pos_(queries_pos),
queries_prefix_end_(queries_prefix_end),
num_queries_(queries_pos.size()),
num_tokens_(num_tokens),
layer_(layer),
activations_(activations),
layer_weights_(*layer_weights),
div_seq_len_(div_seq_len),
kv_caches_(kv_caches),
pool_(activations.env.Pool()) {
HWY_DASSERT(num_queries_ <= kv_caches_.size());
}
// Constructor with default initialization to 0 for queries_prefix_end.
GemmaAttention(const QueriesPos& queries_pos, size_t num_tokens, size_t layer,
Activations& activations,
const CompressedLayer<TConfig>* layer_weights,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
: queries_pos_(queries_pos),
queries_prefix_end_vec_(queries_pos.size(), 0),
queries_prefix_end_(queries_prefix_end_vec_.data(),
queries_prefix_end_vec_.size()),
num_queries_(queries_pos.size()),
num_tokens_(num_tokens),
layer_(layer),
@ -456,6 +485,7 @@ class GemmaAttention {
HWY_DASSERT(num_queries_ <= kv_caches_.size());
}
// Full attention computation in three steps.
HWY_INLINE void operator()() {
const size_t num_interleaved = num_tokens_ * num_queries_;
ComputeQKV(num_interleaved);
@ -465,6 +495,8 @@ class GemmaAttention {
private:
const QueriesPos& queries_pos_;
const std::vector<size_t> queries_prefix_end_vec_;
const QueriesPos queries_prefix_end_;
const size_t num_queries_;
const size_t num_tokens_;
const size_t layer_;
@ -476,15 +508,15 @@ class GemmaAttention {
};
template <class TConfig>
HWY_NOINLINE void Attention(LayerAttentionType type,
const QueriesPos& queries_pos, size_t num_tokens,
size_t layer, Activations& activations,
const CompressedLayer<TConfig>* layer_weights,
const hwy::Divisor& div_seq_len,
const KVCaches& kv_caches) {
HWY_NOINLINE void Attention(
LayerAttentionType type, const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, size_t num_tokens, size_t layer,
Activations& activations, const CompressedLayer<TConfig>* layer_weights,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) {
if (type == LayerAttentionType::kGemma) {
GemmaAttention<TConfig>(queries_pos, num_tokens, layer, activations,
layer_weights, div_seq_len, kv_caches)();
GemmaAttention<TConfig>(queries_pos, queries_prefix_end, num_tokens, layer,
activations, layer_weights, div_seq_len,
kv_caches)();
} else {
// Only reached if the model is Griffin. `if constexpr` prevents generating
// this code for non-Griffin models.
@ -496,6 +528,115 @@ HWY_NOINLINE void Attention(LayerAttentionType type,
}
}
// Wrapper class; holds arguments in member variables to shorten call sites.
// The main differences from GemmaAttention are:
// - no KV cache necessary; attention is always all-to-all and not causal.
// - no potential wrap-around; attention always goes from 0 to kSeqLen.
// - no need for batching, as we are always computing attention for kSeqLen
// tokens.
// This results in a much simpler implementation. However, to avoid duplicating
// code, we should still consider merging the two classes.
// TODO(keysers): Refactor to share code with GemmaAttention.
template <class TConfig>
class VitAttention {
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kQStride = 3 * kQKVDim;
static constexpr size_t kSeqLen = TConfig::kSeqLen;
// Computes Q, K, V for all heads, stored in activations_.q.
HWY_NOINLINE void ComputeQKV() {
PROFILER_ZONE("Gen.VitAttention.QKV");
const auto y =
ConstMat(activations_.pre_att_rms_out.All(), kModelDim);
auto& qkv = activations_.q;
HWY_ASSERT(qkv.BatchSize() == num_tokens_);
HWY_ASSERT(qkv.Len() == kHeads * kQStride);
MatMul</*kAdd=*/true>(
num_tokens_, y,
ConstMat(layer_weights_.vit.qkv_einsum_w.data_scale1(), kModelDim),
/*scale=*/1.0f, layer_weights_.vit.qkv_einsum_b.data_scale1(),
activations_.env, MutableMat(qkv.All(), qkv.Len()));
}
HWY_NOINLINE void DotSoftmaxWeightedSum() {
GEMMA_CONSTEXPR_SQRT float kQueryScale =
1.0f / Sqrt(static_cast<float>(TConfig::kQKVDim));
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
// A "head group" in the context of GQA refers to a collection of query
// heads that share the same key and value heads.
static_assert(kHeads == kKVHeads, "Vit expects MHA");
// Compute Q.K, softmax, and weighted V.
pool_.Run(0, kHeads * num_tokens_,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t head = task % kHeads;
const size_t token = task / kHeads;
// Compute Q.K scores, which are "logits" stored in head_att.
float* HWY_RESTRICT q =
activations_.q.Batch(token) + head * kQStride;
MulByConst(kQueryScale, q, kQKVDim);
float* HWY_RESTRICT head_att =
activations_.att.Batch(token) + head * kSeqLen;
for (size_t i = 0; i < kSeqLen; ++i) {
float* HWY_RESTRICT k =
activations_.q.Batch(i) + head * kQStride + kQKVDim;
head_att[i] = Dot(q, k, kQKVDim); // score = q.k
}
// SoftMax yields "probabilities" in head_att.
Softmax(head_att, kSeqLen);
// Compute weighted sum of v into att_out.
float* HWY_RESTRICT att_out =
activations_.att_out.Batch(token) + head * kQKVDim;
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
for (size_t i = 0; i < kSeqLen; ++i) {
float* HWY_RESTRICT v =
activations_.q.Batch(i) + head * kQStride + 2 * kQKVDim;
MulByConstAndAdd(head_att[i], v, att_out, kQKVDim);
}
});
}
// Sums encoded (`att_out`) over num_heads (`kHeads`) and head_dim (`kQKVDim`)
// into output (`att_sums`).
HWY_NOINLINE void SumHeads() {
PROFILER_ZONE("Gen.VitAttention.SumHeads");
auto* bias = layer_weights_.vit.attn_out_b.data_scale1();
auto att_out = ConstMat(activations_.att_out.All(), kHeads * kQKVDim);
auto att_weights = ConstMat(layer_weights_.vit.attn_out_w.data_scale1(),
kHeads * kQKVDim);
auto att_sums = MutableMat(activations_.att_sums.All(), kModelDim);
// att_weights and att_out are concatenated heads, each of length kQKVDim.
// Thus the [num_tokens_, kModelDim] matmul output is the sum over heads.
MatMul</*kAdd=*/true>(num_tokens_, att_out, att_weights, /*scale=*/1.0f,
bias, activations_.env, att_sums);
}
public:
VitAttention(size_t num_tokens, size_t layer, Activations& activations,
const CompressedLayer<TConfig>* layer_weights)
: num_tokens_(num_tokens),
layer_(layer),
activations_(activations),
layer_weights_(*layer_weights),
pool_(activations.env.Pool()) {}
HWY_INLINE void operator()() {
ComputeQKV();
DotSoftmaxWeightedSum();
SumHeads();
}
private:
const size_t num_tokens_;
const size_t layer_;
Activations& activations_;
const CompressedLayer<TConfig>& layer_weights_;
hwy::ThreadPool& pool_;
};
template <class TConfig, typename T>
HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2,
size_t count) {
@ -504,6 +645,11 @@ HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2,
using DF = hn::ScalableTag<T>;
using VF = hn::Vec<DF>;
// ActivationType::Gelu
if (c2 == nullptr) { // No multiplier, just Gelu.
Gelu(c1, count);
return;
};
// Has multiplier, Gelu(c1) * c2.
hn::Transform1(DF(), c1, count, c2, [](DF df, VF v, VF mul) HWY_ATTR {
return hn::Mul(mul, Gelu(df, v));
});
@ -515,41 +661,66 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
PROFILER_ZONE("Gen.FFW");
constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
// MatMul expects col-major B, which is what we have: kModelDim consecutive
// elements in memory, repeated kFFHiddenDim times.
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
const auto A = ConstMat(activations.bf_pre_ffw_rms_out.All(), kModelDim);
const auto B1 = ConstMat(layer_weights->gating_einsum_w.data(), kModelDim);
const auto B2 = ConstMat(layer_weights->gating_einsum_w.data(), kModelDim,
kModelDim, kModelDim * kFFHiddenDim);
const float scale = layer_weights->gating_einsum_w.scale();
constexpr bool kAddBias = TConfig::kFFBiases;
constexpr bool kIsVit = TConfig::kLayerConfig[0] == LayerAttentionType::kVit;
using WeightType =
hwy::If<kIsVit,
typename CompressedLayer<TConfig>::WeightF32OrBF16,
typename CompressedLayer<TConfig>::Weight>;
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
// Define slightly more readable names for the weights and activations.
const auto x = ConstMat(activations.bf_pre_ffw_rms_out.All(), kModelDim);
Mat<const WeightType> w1;
const float* bias1 = nullptr;
Mat<const WeightType> w2;
const float* bias2 = nullptr;
float scale = 1.0f;
Mat<const WeightType> w_output;
const float* output_bias = nullptr;
if constexpr (kAddBias) {
float output_scale = 1.0f;
auto hidden_activations = MutableMat(activations.C1.All(), kFFHiddenDim);
auto multiplier = MutableMat(activations.C2.All(), kFFHiddenDim);
auto ffw_out = MutableMat(activations.ffw_out.All(), kModelDim);
// For some of the weights and activations, it depends on the config where to
// get them from or whether to use them at all.
if constexpr (kAddBias && !kIsVit) {
bias1 = layer_weights->ffw_gating_biases.data_scale1();
bias2 = bias1 + kFFHiddenDim;
output_bias = layer_weights->ffw_output_biases.data_scale1();
}
auto C1 = MutableMat(activations.C1.All(), kFFHiddenDim);
auto C2 = MutableMat(activations.C2.All(), kFFHiddenDim);
if constexpr (!kIsVit) {
w1 = ConstMat(layer_weights->gating_einsum_w.data(), kModelDim);
w2 = ConstMat(layer_weights->gating_einsum_w.data(), kModelDim, kModelDim,
kModelDim * kFFHiddenDim);
scale = layer_weights->gating_einsum_w.scale();
w_output = ConstMat(layer_weights->linear_w.data(), kFFHiddenDim);
output_scale = layer_weights->linear_w.scale();
} else {
w1 = ConstMat(layer_weights->vit.linear_0_w.data_scale1(), kModelDim);
bias1 = layer_weights->vit.linear_0_b.data_scale1();
multiplier.ptr = nullptr;
w_output =
ConstMat(layer_weights->vit.linear_1_w.data_scale1(), kFFHiddenDim);
output_bias = layer_weights->vit.linear_1_b.data_scale1();
}
// Will go through GELU.
MatMul<kAddBias>(num_interleaved, A, B1, scale, bias1, activations.env, C1);
// What to multiply by.
MatMul<kAddBias>(num_interleaved, A, B2, scale, bias2, activations.env, C2);
// Compute the hidden layer activations.
MatMul<kAddBias>(num_interleaved, x, w1, scale, bias1, activations.env,
hidden_activations);
if constexpr (!kIsVit) {
MatMul<kAddBias>(num_interleaved, x, w2, scale, bias2, activations.env,
multiplier);
}
// Activation (Gelu) and multiply by gate. Store activations in C1.
Activation<TConfig>(C1.ptr, C2.ptr, kFFHiddenDim * num_interleaved);
// Activation (Gelu) and maybe multiply by gate. Store activations in act.
Activation<TConfig>(hidden_activations.ptr, multiplier.ptr,
kFFHiddenDim * num_interleaved);
// Hidden layer -> output layer.
MatMul<kAddBias>(num_interleaved, ConstMat(C1),
ConstMat(layer_weights->linear_w.data(), kFFHiddenDim),
layer_weights->linear_w.scale(), output_bias,
activations.env,
MutableMat(activations.ffw_out.All(), kModelDim));
MatMul<kAddBias>(num_interleaved, ConstMat(hidden_activations), w_output,
output_scale, output_bias, activations.env, ffw_out);
}
// `batch_idx` indicates which row of `x` to write to.
@ -557,8 +728,17 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
// called for batches of tokens in prefill, but batches of queries in decode.
template <class TConfig>
HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
size_t pos_in_prompt,
const CompressedWeights<TConfig>& weights,
RowVectorBatch<float>& x) {
RowVectorBatch<float>& x,
const ImageTokens* image_tokens) {
// Image tokens just need to be copied.
if (image_tokens != nullptr && pos_in_prompt < image_tokens->BatchSize()) {
hwy::CopyBytes(image_tokens->Batch(pos_in_prompt), x.Batch(batch_idx),
x.Len() * sizeof(x.Const()[0]));
return;
}
constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kVocabSize = TConfig::kVocabSize;
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
@ -598,7 +778,8 @@ void PostNorm(size_t num_interleaved, const WeightT& weights, InOutT* inout) {
template <class TConfig>
HWY_NOINLINE void TransformerLayer(
const QueriesPos& queries_pos, size_t num_tokens, size_t layer,
const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end,
size_t num_tokens, size_t layer,
const CompressedLayer<TConfig>* layer_weights, Activations& activations,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) {
constexpr size_t kModelDim = TConfig::kModelDim;
@ -611,8 +792,9 @@ HWY_NOINLINE void TransformerLayer(
layer_weights->pre_attention_norm_scale.data_scale1(),
activations.pre_att_rms_out.All(), kModelDim);
Attention<TConfig>(type, queries_pos, num_tokens, layer_of_type, activations,
layer_weights, div_seq_len, kv_caches);
Attention<TConfig>(type, queries_pos, queries_prefix_end, num_tokens,
layer_of_type, activations, layer_weights, div_seq_len,
kv_caches);
PostNorm<TConfig>(num_interleaved, layer_weights->post_attention_norm_scale,
activations.att_sums.All());
@ -635,6 +817,51 @@ HWY_NOINLINE void TransformerLayer(
/*is_attention=*/false);
}
// Vit transformer layer. Some comments below refer to the Vit implementation in
// the Big Vision codebase. See
// github.com/google-research/big_vision/blob/main/big_vision/models/vit.py
// TODO(keysers): consider adding a wrapper that supports both LayerNorm and
// RMSNorm, and try merging this with TransformerLayer.
template <class TConfig>
HWY_NOINLINE void VitTransformerLayer(
size_t num_tokens, size_t layer,
const CompressedLayer<TConfig>* layer_weights, Activations& activations) {
constexpr size_t kModelDim = TConfig::kModelDim;
auto type = TConfig::kLayerConfig[layer];
HWY_ASSERT(type == LayerAttentionType::kVit);
auto& x = activations.x;
HWY_ASSERT(x.BatchSize() == num_tokens);
HWY_ASSERT(x.Len() == kModelDim);
// y = nn.LayerNorm()(x)
// y ~ pre_att_rms_out
LayerNormBatched(num_tokens, x.All(),
layer_weights->vit.layer_norm_0_scale.data_scale1(),
layer_weights->vit.layer_norm_0_bias.data_scale1(),
activations.pre_att_rms_out.All(), kModelDim);
// y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
// y ~ att_sums
VitAttention<TConfig>(num_tokens, layer, activations, layer_weights)();
// x = out["+sa"] = x + y
AddFromBatched(num_tokens, activations.att_sums.All(), x.All(), kModelDim);
// y = nn.LayerNorm()(x)
// y ~ bf_pre_ffw_rms_out
LayerNormBatched(num_tokens, x.All(),
layer_weights->vit.layer_norm_1_scale.data_scale1(),
layer_weights->vit.layer_norm_1_bias.data_scale1(),
activations.bf_pre_ffw_rms_out.All(), kModelDim);
// y = out["mlp"] = MlpBlock(...)(y)
// y ~ ffw_out
FFW<TConfig>(activations, num_tokens, layer_weights);
// x = out["+mlp"] = x + y
AddFromBatched(num_tokens, activations.ffw_out.All(), x.All(), kModelDim);
}
// Prefill() and Transformer() increment positions in-place.
using QueriesMutablePos = hwy::Span<size_t>;
@ -642,13 +869,14 @@ using QueriesMutablePos = hwy::Span<size_t>;
template <class TConfig>
HWY_NOINLINE void Prefill(
const QueriesPromptTokens& queries_prompt,
const QueriesMutablePos& queries_pos, const size_t query_idx_start,
const CompressedWeights<TConfig>& weights, Activations& activations,
const RuntimeConfig& runtime_config, const hwy::Divisor& div_seq_len,
const KVCaches& kv_caches) {
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
const size_t query_idx_start, const CompressedWeights<TConfig>& weights,
Activations& activations, const RuntimeConfig& runtime_config,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) {
PROFILER_ZONE("Gen.Prefill");
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(queries_pos.size() == num_queries);
HWY_ASSERT(queries_prefix_end.size() == num_queries);
HWY_ASSERT(kv_caches.size() == num_queries);
// Batches are important for amortizing loading weights over multiple tokens.
@ -667,27 +895,50 @@ HWY_NOINLINE void Prefill(
// Single query at a time, so pass slices of the spans because
// GemmaAttention will only access the first KV cache and position.
QueriesPos single_query_pos(&queries_pos[qi], 1);
QueriesPos single_query_prefix_end(&queries_prefix_end[qi], 1);
KVCaches single_kv_cache(&kv_caches[qi], 1);
const size_t prefill_per_query = queries_prompt[qi].size() - 1;
const size_t prompt_size = queries_prompt[qi].size();
// In autoregressive mode, we don't need to prefill the last token, so - 1.
size_t prefill_this_query = prompt_size - 1;
const size_t prefix_end_this_query = queries_prefix_end[qi];
// We can't attend beyond the prompt_size.
HWY_ASSERT(prefix_end_this_query <= prompt_size);
// Special case: if the prefix includes the last token, we need to prefill
// the last token, too. However, the position must be rewound before
// generating the first token, so we keep track of this.
// TODO: consider implementing masking instead of this logic?
bool attend_to_last_token = (prefill_this_query < prefix_end_this_query);
if (attend_to_last_token) {
// The difference can be at most 1.
prefill_this_query += 1;
HWY_ASSERT(prefill_this_query == prefix_end_this_query);
}
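// Example with hypothetical values: prompt_size = 5 and prefix_end = 5 give
// prefill_this_query = 5 (instead of 4) and attend_to_last_token = true;
// after the token batches below, queries_pos[qi] is rewound by 1 so that
// generation still starts from the last prompt token.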
// In prefix-LM mode, we need to look at all the tokens for the prefix in
// one iteration through the layers, so we need a large enough batch size.
HWY_ASSERT(max_tbatch_size >= prefill_this_query);
// For each batch of tokens in the query:
for (size_t tbatch_start = 0; tbatch_start < prefill_per_query;
for (size_t tbatch_start = 0; tbatch_start < prefill_this_query;
tbatch_start += max_tbatch_size) {
// Fill activations.x (much faster than TransformerLayer).
const size_t tbatch_size =
HWY_MIN(max_tbatch_size, prefill_per_query - tbatch_start);
HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start);
// Fill activations.x (much faster than TransformerLayer).
for (size_t ti = 0; ti < tbatch_size; ++ti) {
const int token = queries_prompt[qi][tbatch_start + ti];
const size_t pos = queries_pos[qi] + ti;
EmbedToken<TConfig>(token, ti, pos, weights, activations.x);
const size_t pos_in_prompt = tbatch_start + ti;
const int token = queries_prompt[qi][pos_in_prompt];
EmbedToken<TConfig>(token, ti, pos, pos_in_prompt, weights,
activations.x, runtime_config.image_tokens);
}
// Transformer with one batch of tokens from a single query.
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
const auto* layer_weights = weights.GetLayer(layer);
TransformerLayer<TConfig>(single_query_pos, tbatch_size, layer,
layer_weights, activations, div_seq_len,
single_kv_cache);
TransformerLayer<TConfig>(single_query_pos, single_query_prefix_end,
tbatch_size, layer, layer_weights,
activations, div_seq_len, single_kv_cache);
}
// NOTE: we unconditionally call StreamToken, even if EOS.
@ -695,19 +946,111 @@ HWY_NOINLINE void Prefill(
const size_t pos = queries_pos[qi] + ti;
const size_t pos_in_prompt = tbatch_start + ti;
const int token = queries_prompt[qi][pos_in_prompt];
runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f);
if (pos_in_prompt < prompt_size - 1) {
runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f);
} else {
// The last token will be streamed later and we should only get here
// if we need to attend to the last token because it is in the prefix.
HWY_ASSERT(attend_to_last_token);
}
}
queries_pos[qi] += tbatch_size;
} // for tbatch_start
if (attend_to_last_token) {
// We need to rewind the position for the last token that we only
// attended to to make sure the prefix LM sees everything.
// This means we duplicate work on the last prompt token in autoregressive
// decoding. Alternatives: (1) real masking; (2) always prefill the last
// token and only generate the next one from the already prefilled
// activations.
queries_pos[qi] -= 1;
}
}
}
// Gets the patches of the image and embeds them with the image embedding
// kernel. The result is stored in activations.x.
template <class TConfig>
HWY_NOINLINE void EmbedImagePatches(const Image& image,
const CompressedWeights<TConfig>& weights,
Activations& activations) {
static constexpr size_t kModelDim = TConfig::VitConfig::kModelDim;
static constexpr size_t kPatchWidth = TConfig::VitConfig::kPatchWidth;
static constexpr size_t kSeqLen = TConfig::VitConfig::kSeqLen;
constexpr size_t kPatchSize = kPatchWidth * kPatchWidth * 3;
HWY_ASSERT(weights.vit_img_embedding_kernel.NumElements() ==
kPatchSize * kModelDim);
HWY_ASSERT(activations.x.Len() == kModelDim);
std::vector<hwy::AlignedFreeUniquePtr<float[]>> image_patches(kSeqLen);
for (size_t i = 0; i < kSeqLen; ++i) {
image_patches[i] = hwy::AllocateAligned<float>(kPatchSize);
image.GetPatch(i, image_patches[i].get());
}
// img/embedding/kernel has original shape (14, 14, 3, 1152)
// H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
// image_patches is (256, 14 * 14 * 3)
// This could be done as one MatMul like:
// RowVectorBatch<float> image_patches(kSeqLen, kPatchSize);
// [Get patches]
// MatMul</*kAdd=*/true>(
// kVitSeqLen, ConstMat(image_patches.All(), kPatchSize),
// ConstMat(weights.vit_img_embedding_kernel.data_scale1(), kPatchSize),
// /*scale=*/1.0f, weights.vit_img_embedding_bias.data_scale1(),
// activations.env, MutableMat(activations.x.All(), kVitModelDim));
// However, MatMul currently requires that
// A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
// which is not the case here. We should relax that requirement on MatMul and
// then use the above. For now, we rely on MatVecAdd instead.
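// Illustration of the constraint with a hypothetical lane count:
// kPatchSize = 14 * 14 * 3 = 588; with 8 f32 lanes, 2 * Lanes = 16 and
// 588 % 16 = 12 != 0, so the requirement fails.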
for (size_t i = 0; i < kSeqLen; ++i) {
MatVecAdd<kModelDim, kPatchSize>(
weights.vit_img_embedding_kernel, 0, image_patches[i].get(),
weights.vit_img_embedding_bias.data_scale1(), activations.x.Batch(i),
activations.env.Pools().Outer());
}
// Add position embeddings.
AddFrom(weights.vit_img_pos_embedding.data_scale1(), activations.x.All(),
kSeqLen * kModelDim);
}
// Prefills the image tokens with the ViT encoder.
template <class TConfig>
HWY_NOINLINE void PrefillVit(const CompressedWeights<TConfig>& weights,
const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens,
Activations& activations) {
PROFILER_ZONE("Gen.PrefillVit");
const size_t num_tokens = TConfig::VitConfig::kSeqLen;
const size_t kVitModelDim = TConfig::VitConfig::kModelDim;
HWY_ASSERT(num_tokens == activations.x.BatchSize());
// Embed the image patches.
EmbedImagePatches<TConfig>(image, weights, activations);
// Go through all layers.
for (size_t layer = 0; layer < TConfig::VitConfig::kLayers; ++layer) {
const auto* layer_weights = weights.GetVitLayer(layer);
VitTransformerLayer<typename TConfig::VitConfig>(
num_tokens, layer, layer_weights, activations);
}
// Final Layernorm.
LayerNormBatched(num_tokens, activations.x.All(),
weights.vit_encoder_norm_scale.data_scale1(),
weights.vit_encoder_norm_bias.data_scale1(),
activations.x.All(), kVitModelDim);
// Apply head embedding into image_tokens of size of the LLM kModelDim.
MatMul</*kAdd=*/true>(
num_tokens, ConstMat(activations.x.All(), kVitModelDim),
ConstMat(weights.vit_img_head_kernel.data_scale1(), kVitModelDim),
/*scale=*/1.0f, weights.vit_img_head_bias.data_scale1(), activations.env,
MutableMat(image_tokens.All(), TConfig::kModelDim));
}
// Generates one token for each query. `queries_token` is the previous token
// from each query, and `queries_pos` are their position in the sequence.
template <class TConfig>
HWY_NOINLINE void Transformer(
const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
const QueriesPos& queries_prefix_end,
const CompressedWeights<TConfig>& weights, Activations& activations,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
const LayersOutputFunc& layers_output,
@ -715,6 +1058,7 @@ HWY_NOINLINE void Transformer(
constexpr size_t kModelDim = TConfig::kModelDim;
const size_t num_queries = queries_token.size();
HWY_DASSERT(queries_pos.size() == num_queries);
HWY_DASSERT(queries_prefix_end.size() == num_queries);
if (layers_output) {
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
@ -726,13 +1070,14 @@ HWY_NOINLINE void Transformer(
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
EmbedToken<TConfig>(queries_token[query_idx], query_idx,
queries_pos[query_idx], weights, activations.x);
queries_pos[query_idx], /*pos_in_prompt=*/0, weights,
activations.x, /*image_tokens=*/nullptr);
}
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
const CompressedLayer<TConfig>* layer_weights = weights.GetLayer(layer);
TransformerLayer<TConfig>(queries_pos, /*num_tokens=*/1, layer,
layer_weights, activations, div_seq_len,
TransformerLayer<TConfig>(queries_pos, queries_prefix_end, /*num_tokens=*/1,
layer, layer_weights, activations, div_seq_len,
kv_caches);
if (activations_observer) {
@ -834,8 +1179,10 @@ template <class TConfig>
void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos_in, const size_t query_idx_start,
const KVCaches& kv_caches, TimingInfo& timing_info) {
const QueriesPos& queries_pos_in,
const QueriesPos& queries_prefix_end,
const size_t query_idx_start, const KVCaches& kv_caches,
TimingInfo& timing_info) {
constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kVocabSize = TConfig::kVocabSize;
const CompressedWeights<TConfig>& weights =
@ -888,8 +1235,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
prefill_activations.Allocate<TConfig>(runtime_config.prefill_tbatch_size,
activations.env.Pools());
}
Prefill<TConfig>(queries_prompt, queries_mutable_pos, query_idx_start,
weights,
Prefill<TConfig>(queries_prompt, queries_mutable_pos, queries_prefix_end,
query_idx_start, weights,
use_prefill_activations ? prefill_activations : activations,
runtime_config, div_seq_len, kv_caches);
// Compute the number of tokens that were prefilled and notify timing_info.
@ -918,10 +1265,10 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
const double gen_start = hwy::platform::Now();
for (size_t gen = 0; gen < HWY_MIN(max_tokens, max_generated_tokens); ++gen) {
// Decode generates one token per query and increments queries_mutable_pos.
Transformer<TConfig>(QueriesToken(gen_tokens.data(), num_queries),
queries_mutable_pos, weights, activations, div_seq_len,
kv_caches, runtime_config.layers_output,
runtime_config.activations_observer);
Transformer<TConfig>(
QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos,
queries_prefix_end, weights, activations, div_seq_len, kv_caches,
runtime_config.layers_output, runtime_config.activations_observer);
// queries_pos are incremented by Transformer.
bool all_queries_eos = true;
@ -954,8 +1301,9 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
template <class TConfig>
void GenerateSingleT(const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
PerClusterPools& pools, TimingInfo& timing_info) {
const PromptTokens& prompt, size_t pos, size_t prefix_end,
KVCache& kv_cache, PerClusterPools& pools,
TimingInfo& timing_info) {
constexpr size_t kNumQueries = 1;
const size_t qbatch_start = 0;
@ -963,20 +1311,24 @@ void GenerateSingleT(const ByteStorageT& weights_u8,
Activations activations;
activations.Allocate<TConfig>(kNumQueries, pools);
const QueriesPromptTokens prompt_span(&prompt, kNumQueries);
QueriesPos pos_span(&pos, kNumQueries);
const QueriesPromptTokens queries_prompt(&prompt, kNumQueries);
QueriesPos queries_pos(&pos, kNumQueries);
const QueriesPos queries_prefix_end(&prefix_end, kNumQueries);
const KVCaches kv_caches{&kv_cache, kNumQueries};
GenerateT<TConfig>(weights_u8, activations, runtime_config, prompt_span,
pos_span, qbatch_start, kv_caches, timing_info);
GenerateT<TConfig>(weights_u8, activations, runtime_config, queries_prompt,
queries_pos, queries_prefix_end, qbatch_start, kv_caches,
timing_info);
}
template <class TConfig>
void GenerateBatchT(const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos, const KVCaches& kv_caches,
PerClusterPools& pools, TimingInfo& timing_info) {
const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end,
const KVCaches& kv_caches, PerClusterPools& pools,
TimingInfo& timing_info) {
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(queries_pos.size() == num_queries);
HWY_ASSERT(kv_caches.size() == num_queries);
@ -995,9 +1347,33 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start],
qbatch_size);
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
qbatch_size);
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
GenerateT<TConfig>(weights_u8, activations, runtime_config, qbatch_prompts,
qbatch_pos, qbatch_start, qbatch_kv, timing_info);
qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv,
timing_info);
}
}
template <class TConfig>
void GenerateImageTokensT(const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens,
PerClusterPools& pools) {
if constexpr (TConfig::VitConfig::kLayers == 0) {
return;
} else {
Activations prefill_activations;
RuntimeConfig prefill_runtime_config = runtime_config;
prefill_runtime_config.prefill_tbatch_size = TConfig::VitConfig::kSeqLen;
prefill_activations.Allocate<typename TConfig::VitConfig>(
prefill_runtime_config.prefill_tbatch_size, pools);
// Weights are for the full PaliGemma model, not just the ViT part.
const CompressedWeights<TConfig>& weights =
*reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
PrefillVit<TConfig>(weights, prefill_runtime_config, image, image_tokens,
prefill_activations);
}
}
@ -1010,20 +1386,30 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
void GenerateSingle( // NOLINT(misc-definitions-in-headers)
GEMMA_CONFIG, const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
KVCache& kv_cache, PerClusterPools& pools, TimingInfo& timing_info) {
size_t prefix_end, KVCache& kv_cache, PerClusterPools& pools,
TimingInfo& timing_info) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_CONFIG>)
(weights_u8, runtime_config, prompt, pos, kv_cache, pools, timing_info);
(weights_u8, runtime_config, prompt, pos, prefix_end, kv_cache, pools,
timing_info);
}
void GenerateBatch( // NOLINT(misc-definitions-in-headers)
GEMMA_CONFIG, const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos,
const KVCaches& kv_caches, PerClusterPools& pools,
TimingInfo& timing_info) {
const QueriesPos& queries_prefix_end, const KVCaches& kv_caches,
PerClusterPools& pools, TimingInfo& timing_info) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>)
(weights_u8, runtime_config, queries_prompt, queries_pos, kv_caches, pools,
timing_info);
(weights_u8, runtime_config, queries_prompt, queries_pos, queries_prefix_end,
kv_caches, pools, timing_info);
}
void GenerateImageTokens( // NOLINT(misc-definitions-in-headers)
GEMMA_CONFIG, const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config, const Image& image,
ImageTokens& image_tokens, PerClusterPools& pools) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImageTokensT<GEMMA_CONFIG>)
(weights_u8, runtime_config, image, image_tokens, pools);
}
#endif // HWY_ONCE


@ -24,10 +24,12 @@
#include <string.h>
#include <utility> // std::move
#include <vector>
#include "compression/io.h" // Path
#include "gemma/common.h"
#include "gemma/weights.h"
#include "paligemma/image.h"
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
@ -62,13 +64,18 @@ Gemma::~Gemma() {
extern void GenerateSingle(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
const RuntimeConfig& runtime_config, \
const PromptTokens& prompt, size_t pos, \
KVCache& kv_cache, PerClusterPools& pools, \
TimingInfo& timing_info); \
size_t prefix_end, KVCache& kv_cache, \
PerClusterPools& pools, TimingInfo& timing_info); \
extern void GenerateBatch( \
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \
const QueriesPos& queries_pos, const KVCaches& kv_caches, \
PerClusterPools& pools, TimingInfo& timing_info);
const QueriesPos& queries_pos, \
const QueriesPos& queries_prefix_end, const KVCaches& kv_caches, \
PerClusterPools& pools, TimingInfo& timing_info); \
extern void GenerateImageTokens( \
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
const RuntimeConfig& runtime_config, const Image& image, \
ImageTokens& image_tokens, PerClusterPools& pools);
GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE);
// Adapters to select from the above overloads via CallForModelAndWeight.
@ -76,10 +83,11 @@ template <class TConfig>
struct GenerateSingleT {
  void operator()(const ByteStorageT& weights_u8,
                  const RuntimeConfig& runtime_config,
                  const PromptTokens& prompt, size_t pos, size_t prefix_end,
                  KVCache& kv_cache, PerClusterPools& pools,
                  TimingInfo& timing_info) const {
    GenerateSingle(TConfig(), weights_u8, runtime_config, prompt, pos,
                   prefix_end, kv_cache, pools, timing_info);
  }
};
@ -88,21 +96,34 @@ struct GenerateBatchT {
  void operator()(const ByteStorageT& weights_u8,
                  const RuntimeConfig& runtime_config,
                  const QueriesPromptTokens& queries_prompt,
                  const QueriesPos& queries_pos,
                  const QueriesPos& queries_prefix_end,
                  const KVCaches& kv_caches, PerClusterPools& pools,
                  TimingInfo& timing_info) const {
    GenerateBatch(TConfig(), weights_u8, runtime_config, queries_prompt,
                  queries_pos, queries_prefix_end, kv_caches, pools,
                  timing_info);
  }
};
template <class TConfig>
struct GenerateImageTokensT {
void operator()(const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config, const Image& image,
ImageTokens& image_tokens, PerClusterPools& pools) const {
GenerateImageTokens(TConfig(), weights_u8, runtime_config, image,
image_tokens, pools);
}
};
void Gemma::Generate(const RuntimeConfig& runtime_config,
                     const PromptTokens& prompt, size_t pos, size_t prefix_end,
                     KVCache& kv_cache, TimingInfo& timing_info) {
  if (runtime_config.use_spinning) pools_.StartSpinning();
  CallForModelAndWeight<GenerateSingleT>(
      info_.model, info_.weight, weights_u8_, runtime_config, prompt, pos,
      prefix_end, kv_cache, pools_, timing_info);
  if (runtime_config.use_spinning) pools_.StopSpinning();
}
@ -110,12 +131,33 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end,
const KVCaches& kv_caches, TimingInfo& timing_info) {
  // If no prefix ends were passed (size 0), default them all to 0 (fully
  // causal attention) and pass that on.
QueriesPos mutable_queries_prefix_end = queries_prefix_end;
std::vector<size_t> prefix_end_vec;
if (queries_prefix_end.size() == 0) {
prefix_end_vec.resize(queries_prompt.size(), 0);
mutable_queries_prefix_end =
QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
}
if (runtime_config.use_spinning) pools_.StartSpinning();
CallForModelAndWeight<GenerateBatchT>(
info_.model, info_.weight, weights_u8_, runtime_config, queries_prompt,
      queries_pos, mutable_queries_prefix_end, kv_caches, pools_, timing_info);
if (runtime_config.use_spinning) pools_.StopSpinning();
}
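For callers that do pass per-query prefix ends, a minimal sketch (hypothetical
query setup; QueriesPos wraps a pointer and a length, as in the defaulting code
above):

  // First query: 260-token bidirectional prefix (256 image tokens + BOS +
  // text + SEP); second query: fully causal.
  std::vector<size_t> prefix_ends = {260, 0};
  QueriesPos queries_prefix_end(prefix_ends.data(), prefix_ends.size());
  gemma.GenerateBatch(runtime_config, queries_prompt, queries_pos,
                      queries_prefix_end, kv_caches, timing_info);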
void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens) {
if (runtime_config.use_spinning) pools_.StartSpinning();
CallForModelAndWeight<GenerateImageTokensT>(info_.model, info_.weight,
weights_u8_, runtime_config,
image, image_tokens, pools_);
if (runtime_config.use_spinning) pools_.StopSpinning();
}

View File

@ -27,6 +27,7 @@
#include "gemma/common.h"
#include "gemma/kv_cache.h"
#include "gemma/tokenizer.h"
#include "paligemma/image.h"
#include "util/allocator.h"
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -75,6 +76,10 @@ using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
using ActivationsObserverFunc =
std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;
// ImageTokens are represented as a RowVectorBatch, where each "batch" index
// corresponds to a token for an image patch as computed by the image encoder.
using ImageTokens = RowVectorBatch<float>;
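A sketch of the allocation for PaliGemma-224, assuming 256 patch tokens and an
LLM model dimension of 2048 as used by the calls elsewhere in this commit
(argument names are illustrative):

  ImageTokens image_tokens(/*batch_size=*/256, /*len=*/2048);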
// RuntimeConfig holds configuration for a single generation run.
struct RuntimeConfig {
// If not empty, batch_stream_token is called for each token in the batch,
@ -110,6 +115,10 @@ struct RuntimeConfig {
LayersOutputFunc layers_output; // if not empty, called after each layer.
ActivationsObserverFunc activations_observer; // if set, called per-layer.
// If not empty, these point to the image tokens and are used in the
// PaliGemma prefix-LM style attention.
const ImageTokens *image_tokens = nullptr;
// Whether to use thread spinning to reduce barrier synchronization latency.
bool use_spinning = true;
@ -198,14 +207,34 @@ class Gemma {
// `pos` is the position in the KV cache. Users are responsible for
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
                size_t pos, KVCache& kv_cache, TimingInfo& timing_info) {
Generate(runtime_config, prompt, pos, /*prefix_end=*/0, kv_cache,
timing_info);
}
// For prefix-LM style attention, we can pass the end of the prefix.
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
size_t pos, size_t prefix_end, KVCache& kv_cache,
TimingInfo& timing_info);
// `queries_pos` are the positions in the KV cache. Users are responsible for
// incrementing them in `BatchStreamFunc`, or setting to zero for single-turn.
void GenerateBatch(const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos, const KVCaches& kv_caches,
                     TimingInfo& timing_info) {
GenerateBatch(runtime_config, queries_prompt, queries_pos,
/*queries_prefix_end=*/{}, kv_caches, timing_info);
}
// For prefix-LM style attention, we can pass the ends of the prefixes.
void GenerateBatch(const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end,
const KVCaches& kv_caches, TimingInfo& timing_info);
// Generates the image tokens by running the image encoder ViT.
void GenerateImageTokens(const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens);
private:
PerClusterPools& pools_;
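Putting the new API together, a minimal sketch of the two-step PaliGemma flow,
condensed from the run.cc changes below (file name hypothetical; setup of the
model, runtime_config, prompt and caches omitted):

  Image image;
  HWY_ASSERT(image.ReadPPM("input.ppm"));
  image.Resize();  // to 224 x 224
  ImageTokens image_tokens(256, 2048);
  model.GenerateImageTokens(runtime_config, image, image_tokens);
  runtime_config.image_tokens = &image_tokens;
  // Placeholder tokens; prefill replaces their embeddings with image_tokens.
  prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
  const size_t prefix_end = prompt.size();  // whole prefix is bidirectional
  model.Generate(runtime_config, prompt, /*pos=*/0, prefix_end, kv_cache,
                 timing_info);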

View File

@ -0,0 +1,21 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/paligemma_224_bf16.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigPaliGemma_224<hwy::bfloat16_t>
#include "gemma/gemma-inl.h"

View File

@ -0,0 +1,21 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/paligemma_224_f32.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigPaliGemma_224<float>
#include "gemma/gemma-inl.h"

View File

@ -0,0 +1,21 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/paligemma_224_sfp.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigPaliGemma_224<SfpStream>
#include "gemma/gemma-inl.h"

View File

@ -25,6 +25,7 @@
#include "evals/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/gemma.h" // Gemma
#include "paligemma/image.h"
#include "util/app.h"
#include "util/args.h" // HasHelp
#include "util/threading.h"
@ -87,6 +88,17 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
std::mt19937 gen;
InitGenerator(args, gen);
const bool have_image = !args.image_file.path.empty();
Image image;
ImageTokens image_tokens(256, 2048);
if (have_image) {
HWY_ASSERT(model.Info().model == Model::PALIGEMMA_224);
HWY_ASSERT(image.ReadPPM(args.image_file.path));
image.Resize();
RuntimeConfig runtime_config = {.verbosity = verbosity, .gen = &gen};
model.GenerateImageTokens(runtime_config, image, image_tokens);
}
// callback function invoked for each generated token.
auto stream_token = [&](int token, float) {
++abs_pos;
@ -132,7 +144,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
}
}
    if (have_image && abs_pos != 0) {
      // This occurs when we have hit max_generated_tokens.
abs_pos = 0;
}
std::vector<int> prompt = WrapAndTokenize(
model.Tokenizer(), model.Info(), abs_pos, prompt_string);
prompt_size = prompt.size();
std::cerr << "\n"
@ -151,7 +168,19 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
.accept_token = accept_token,
};
args.CopyTo(runtime_config);
size_t prefix_end = 0;
if (have_image) {
runtime_config.image_tokens = &image_tokens;
prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
prompt_size = prompt.size();
// The end of the prefix for prefix-LM style attention in PaliGemma.
// See Figure 2 of https://arxiv.org/abs/2407.07726.
prefix_end = prompt_size;
// We need to look at all the tokens for the prefix.
runtime_config.prefill_tbatch_size = prompt_size;
}
model.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache,
timing_info);
std::cout << "\n\n";
}
std::cout

View File

@ -107,6 +107,14 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
if (pos == 0) {
tokens.insert(tokens.begin(), BOS_ID);
}
// PaliGemma separator. The SEP token "\n" is always tokenized separately.
if (info.model == Model::PALIGEMMA_224) {
std::vector<int> sep_tokens;
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
}
return tokens;
}
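A sketch of the resulting PaliGemma sequence once the caller prepends the image
placeholders, mirroring paligemma_test.cc below (the zeros are placeholders
whose embeddings are overwritten by the ViT outputs during prefill):

  std::vector<int> tokens =
      WrapAndTokenize(tokenizer, info, /*pos=*/0, prompt_text);
  // tokens = [BOS, <text tokens>, <"\n">]
  tokens.insert(tokens.begin(), image_tokens.BatchSize(), 0);
  // tokens = [<256 zeros>, BOS, <text tokens>, <"\n">]
  const size_t prefix_end = tokens.size();  // all of it attends bidirectionally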

View File

@ -48,6 +48,7 @@ struct CompressedLayer {
static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim;
static constexpr size_t kQKVEinsumWSize =
(kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
static constexpr size_t kQKVEinsumBSize = (kHeads + 2 * kKVHeads) * kQKVDim;
// 2x for (gelu gating vector, gated vector)
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
@ -81,6 +82,24 @@ struct CompressedLayer {
ArrayT<float, kGriffinDim * 2> gate_biases;
ArrayT<float, kGriffinDim> a;
} griffin;
struct {
// MultiHeadDotProductAttention.
ArrayT<WeightF32OrBF16, kAttVecEinsumWSize> attn_out_w;
ArrayT<float, kModelDim> attn_out_b;
ArrayT<WeightF32OrBF16, kQKVEinsumWSize> qkv_einsum_w;
ArrayT<float, kQKVEinsumBSize> qkv_einsum_b;
// MlpBlock.
ArrayT<WeightF32OrBF16, kModelDim * kFFHiddenDim> linear_0_w;
ArrayT<float, kFFHiddenDim> linear_0_b;
ArrayT<WeightF32OrBF16, kFFHiddenDim * kModelDim> linear_1_w;
ArrayT<float, kModelDim> linear_1_b;
// LayerNorm.
ArrayT<WeightF32OrBF16, kModelDim> layer_norm_0_bias;
ArrayT<WeightF32OrBF16, kModelDim> layer_norm_0_scale;
ArrayT<WeightF32OrBF16, kModelDim> layer_norm_1_bias;
ArrayT<WeightF32OrBF16, kModelDim> layer_norm_1_scale;
} vit;
};
ArrayT<Weight, kGatingEinsumWSize> gating_einsum_w;
@ -121,6 +140,7 @@ struct CompressedLayer {
out_row + h * kQKVDim, kQKVDim * sizeof(Weight));
}
}
att_weights.set_scale(attn_vec_einsum_w.scale());
}
};
@ -133,10 +153,21 @@ struct CompressedLayerPointers {
pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) {
this->c_layers[task] = hwy::AllocateAligned<CompressedLayer<TConfig>>(1);
});
if constexpr (TConfig::VitConfig::kLayers > 0) {
pool.Run(0, TConfig::VitConfig::kLayers,
[this](uint64_t task, size_t /*thread*/) {
this->c_vit_layers[task] = hwy::AllocateAligned<
CompressedLayer<typename TConfig::VitConfig>>(1);
});
}
}
using CLayer = CompressedLayer<TConfig>;
std::array<hwy::AlignedFreeUniquePtr<CLayer[]>, TConfig::kLayers> c_layers;
using CVitLayer = CompressedLayer<typename TConfig::VitConfig>;
std::array<hwy::AlignedFreeUniquePtr<CVitLayer[]>,
TConfig::VitConfig::kLayers>
c_vit_layers;
};
template <class TConfig, typename = void>
@ -159,6 +190,23 @@ struct CompressedWeights {
hwy::If<hwy::IsSame<Weight, float>(), float, hwy::bfloat16_t>;
CompressedArray<WeightF32OrBF16, TConfig::kModelDim> final_norm_scale;
// Vit parts.
CompressedArray<WeightF32OrBF16, TConfig::VitConfig::kModelDim>
vit_encoder_norm_bias;
CompressedArray<WeightF32OrBF16, TConfig::VitConfig::kModelDim>
vit_encoder_norm_scale;
CompressedArray<float, TConfig::VitConfig::kModelDim> vit_img_embedding_bias;
CompressedArray<WeightF32OrBF16, TConfig::VitConfig::kModelDim * 14 * 14 * 3>
vit_img_embedding_kernel;
CompressedArray<float, 256 * TConfig::VitConfig::kModelDim>
vit_img_pos_embedding;
// The head maps from VitConfig::kModelDim (Vit final layer) to
// kModelDim (LLM input).
CompressedArray<float, TConfig::kModelDim> vit_img_head_bias;
CompressedArray<WeightF32OrBF16,
TConfig::VitConfig::kModelDim * TConfig::kModelDim>
vit_img_head_kernel;
// Must be last so that the other arrays remain aligned.
CompressedLayerPointers<TConfig> c_layer_ptrs;
@ -174,9 +222,21 @@ struct CompressedWeights {
void ZeroInit() {
hwy::ZeroBytes(&embedder_input_embedding, sizeof(embedder_input_embedding));
hwy::ZeroBytes(&final_norm_scale, sizeof(final_norm_scale));
hwy::ZeroBytes(&vit_encoder_norm_bias, sizeof(vit_encoder_norm_bias));
hwy::ZeroBytes(&vit_encoder_norm_scale, sizeof(vit_encoder_norm_scale));
hwy::ZeroBytes(&vit_img_embedding_bias, sizeof(vit_img_embedding_bias));
hwy::ZeroBytes(&vit_img_embedding_kernel, sizeof(vit_img_embedding_kernel));
hwy::ZeroBytes(&vit_img_head_bias, sizeof(vit_img_head_bias));
hwy::ZeroBytes(&vit_img_head_kernel, sizeof(vit_img_head_kernel));
hwy::ZeroBytes(&vit_img_pos_embedding, sizeof(vit_img_pos_embedding));
for (int i = 0; i < TConfig::kLayers; ++i) {
hwy::ZeroBytes(GetLayer(i), sizeof(*GetLayer(i)));
}
if constexpr (TConfig::VitConfig::kLayers > 0) {
for (int i = 0; i < TConfig::VitConfig::kLayers; ++i) {
hwy::ZeroBytes(GetVitLayer(i), sizeof(*GetVitLayer(i)));
}
}
}
const CompressedLayer<TConfig>* GetLayer(size_t layer) const {
@ -185,6 +245,13 @@ struct CompressedWeights {
CompressedLayer<TConfig>* GetLayer(size_t layer) {
return c_layer_ptrs.c_layers[layer].get();
}
const CompressedLayer<typename TConfig::VitConfig>* GetVitLayer(
size_t layer) const {
return c_layer_ptrs.c_vit_layers[layer].get();
}
CompressedLayer<typename TConfig::VitConfig>* GetVitLayer(size_t layer) {
return c_layer_ptrs.c_vit_layers[layer].get();
}
};
// ----------------------------------------------------------------------------
@ -288,6 +355,16 @@ void ForEachTensor(RawWeightsPtr raw_weights,
GEMMA_CALL_TOP_FUNC("c_embedding", embedder_input_embedding);
GEMMA_CALL_TOP_FUNC("c_final_norm", final_norm_scale);
if constexpr (TConfig::VitConfig::kLayers > 0 && !kHaveRaw) {
GEMMA_CALL_TOP_FUNC("enc_norm_bias", vit_encoder_norm_bias);
GEMMA_CALL_TOP_FUNC("enc_norm_scale", vit_encoder_norm_scale);
GEMMA_CALL_TOP_FUNC("img_emb_bias", vit_img_embedding_bias);
GEMMA_CALL_TOP_FUNC("img_emb_kernel", vit_img_embedding_kernel);
GEMMA_CALL_TOP_FUNC("img_head_bias", vit_img_head_bias);
GEMMA_CALL_TOP_FUNC("img_head_kernel", vit_img_head_kernel);
GEMMA_CALL_TOP_FUNC("img_pos_emb", vit_img_pos_embedding);
}
char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
@ -334,6 +411,35 @@ void ForEachTensor(RawWeightsPtr raw_weights,
GEMMA_CALL_FUNC("attn_ob", attention_output_biases);
}
}
// Vit layers. Not supported for compress_weights.
if constexpr (TConfig::VitConfig::kLayers > 0 && !kHaveRaw) {
for (int layer_idx = 0; layer_idx < TConfig::VitConfig::kLayers;
++layer_idx) {
auto type = TConfig::VitConfig::kLayerConfig[layer_idx];
HWY_ASSERT(type == LayerAttentionType::kVit);
const size_t idx = static_cast<size_t>(layer_idx);
const RawLayer* raw_layer = nullptr;
CompressedLayer<typename TConfig::VitConfig>* c_layer =
c_weights.GetVitLayer(idx);
// MHA.
GEMMA_CALL_FUNC("attn_out_w", vit.attn_out_w);
GEMMA_CALL_FUNC("attn_out_b", vit.attn_out_b);
GEMMA_CALL_FUNC("qkv_ein_w", vit.qkv_einsum_w);
GEMMA_CALL_FUNC("qkv_ein_b", vit.qkv_einsum_b);
// MlpBlock.
GEMMA_CALL_FUNC("linear_0_w", vit.linear_0_w);
GEMMA_CALL_FUNC("linear_0_b", vit.linear_0_b);
GEMMA_CALL_FUNC("linear_1_w", vit.linear_1_w);
GEMMA_CALL_FUNC("linear_1_b", vit.linear_1_b);
// LayerNorm.
GEMMA_CALL_FUNC("ln_0_bias", vit.layer_norm_0_bias);
GEMMA_CALL_FUNC("ln_0_scale", vit.layer_norm_0_scale);
GEMMA_CALL_FUNC("ln_1_bias", vit.layer_norm_1_bias);
GEMMA_CALL_FUNC("ln_1_scale", vit.layer_norm_1_scale);
}
}
#undef GEMMA_CALL_FUNC
#undef GEMMA_CALL_TOP_FUNC
} // ForEachTensor

45
paligemma/BUILD Normal file
View File

@ -0,0 +1,45 @@
package(
default_applicable_licenses = [
"//:license", # Placeholder comment, do not modify
],
default_visibility = [
"//:__subpackages__", # Placeholder, do not modify
],
)
cc_library(
name = "image",
srcs = ["image.cc"],
hdrs = ["image.h"],
deps = ["@hwy//:hwy"],
)
cc_test(
name = "image_test",
srcs = ["image_test.cc"],
data = ["testdata/image.ppm"],
deps = [
":image",
"@googletest//:gtest_main", # buildcleaner: keep
],
)
cc_test(
name = "paligemma_test",
srcs = ["paligemma_test.cc"],
# Requires model files
tags = [
"local",
"manual",
"no_tap",
],
deps = [
"@googletest//:gtest_main",
"//:benchmark_helper",
"//:common",
"//:gemma_lib",
"//:tokenizer",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
],
)

182
paligemma/image.cc Normal file
View File

@ -0,0 +1,182 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paligemma/image.h"
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include "hwy/base.h"
namespace gcpp {
namespace {
// Hardcoded for PaliGemma-224 ViT input.
constexpr size_t kPatchSize = 14;
constexpr size_t kImageSize = 224;
constexpr size_t kNumPatches = kImageSize / kPatchSize; // 16
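// Hence 16 * 16 = 256 patches in total, each holding 14 * 14 * 3 = 588 floats.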
// Returns the linearly scaled index in [0, to_size) closest to the
// value in [0, from_size).
int NearestNeighbor(int value, int from_size, int to_size) {
float scale_factor = static_cast<float>(to_size - 1) / (from_size - 1);
// Apply nearest neighbor rounding.
int nn = static_cast<int>(std::round(value * scale_factor));
// Ensure the value is within the new range.
nn = std::clamp(nn, 0, to_size - 1);
return nn;
}
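// Worked example: mapping row 100 of a 341-pixel-high image onto 224 rows
// gives round(100 * 223 / 340) = round(65.59) = 66.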
// Returns value in [0,1] mapped linearly to [-1,1].
float StretchToSigned(float value) {
// = out_min + (value - in_min) * (out_max - out_min) / (in_max - in_min);
return value * 2.0f - 1.0f;
}
bool IsLineBreak(int c) { return c == '\r' || c == '\n'; }
void SkipWhitespaceAndComments(std::ifstream& file) {
int value = file.get();
while (std::isspace(value)) value = file.get();
while (value == '#') { // Skip comment lines.
    while (value != EOF && !IsLineBreak(value)) value = file.get();
while (std::isspace(value)) value = file.get();
}
file.unget(); // Rewind last byte.
}
} // namespace
bool Image::ReadPPM(const std::string& filename) {
  std::ifstream file(filename, std::ios::binary);
if (!file.is_open()) {
std::cerr << "Failed to open " << filename << "\n";
return false;
}
std::string format;
file >> format;
if (format != "P6") {
std::cerr << "We only support binary PPM (P6) but got: " << format << "\n";
return false;
}
int width, height, max_value;
SkipWhitespaceAndComments(file);
file >> width;
SkipWhitespaceAndComments(file);
file >> height;
SkipWhitespaceAndComments(file);
file >> max_value;
if (max_value <= 0 || max_value > 255) {
std::cerr << "Unsupported max value " << max_value << "\n";
return false;
}
// P6 requires exactly one whitespace character after the header.
int value = file.get();
if (!std::isspace(value)) {
std::cerr << "Missing whitespace after header\n";
return false;
}
width_ = width;
height_ = height;
int data_size = width * height * 3;
data_.resize(data_size);
std::vector<char> data_bytes(data_size);
file.read(data_bytes.data(), data_size);
if (file.gcount() != data_size) {
std::cerr << "Failed to read " << data_size << " bytes\n";
return false;
}
  for (int i = 0; i < data_size; ++i) {
    // Cast through unsigned char: PPM bytes are 0..255, but plain char may be
    // signed.
    const float v = static_cast<unsigned char>(data_bytes[i]);
    data_[i] = StretchToSigned(v / max_value);
  }
if (file.get() != EOF) {
std::cerr << "Extra data in file\n";
return false;
}
file.close();
return true;
}
void Image::Resize() {
  int new_width = kImageSize;
  int new_height = kImageSize;
std::vector<float> new_data(new_width * new_height * 3);
// TODO: go to bilinear interpolation, or antialias.
// E.g. consider WeightsSymmetric3Lowpass and SlowSymmetric3 from
// jpegxl/lib/jxl/convolve_slow.cc
// For now, just do nearest neighbor.
for (int i = 0; i < new_height; ++i) {
for (int j = 0; j < new_width; ++j) {
int old_i = NearestNeighbor(i, new_height, height_);
int old_j = NearestNeighbor(j, new_width, width_);
for (int k = 0; k < 3; ++k) {
new_data[(i * new_width + j) * 3 + k] =
data_[(old_i * width_ + old_j) * 3 + k];
}
}
}
data_ = std::move(new_data);
height_ = new_height;
width_ = new_width;
}
bool Image::WriteBinary(const std::string& filename) const {
  // Writes the floating point values as float32 in binary format.
  std::ofstream file(filename, std::ios::binary);
  if (!file.is_open()) {
    std::cerr << "Failed to open " << filename << "\n";
    return false;
  }
  file.write(reinterpret_cast<const char*>(data_.data()),
             data_.size() * sizeof(float));
  file.close();
  return true;
}
// Image.data() is kImageSize x kImageSize x 3, H x W x C.
// We want the N-th patch (of 256) of size kPatchSize x kPatchSize x 3.
// Patches are numbered in the usual raster ("pixel") order: row-major,
// left to right.
void Image::GetPatch(size_t patch_num, float* patch) const {
constexpr size_t kDataSize = kImageSize * kImageSize * 3;
HWY_ASSERT(size() == kDataSize);
constexpr size_t kPatchDataSize = kPatchSize * kPatchSize * 3;
int i_offs = patch_num / kNumPatches;
int j_offs = patch_num % kNumPatches;
HWY_ASSERT(0 <= i_offs && i_offs < kNumPatches);
HWY_ASSERT(0 <= j_offs && j_offs < kNumPatches);
i_offs *= kPatchSize;
j_offs *= kPatchSize;
// This can be made faster, but let's first see whether it matters.
const float* image_data = data();
for (int i = 0; i < kPatchSize; ++i) {
for (int j = 0; j < kPatchSize; ++j) {
for (int k = 0; k < 3; ++k) {
const int patch_index = (i * kPatchSize + j) * 3 + k;
HWY_ASSERT(patch_index < kPatchDataSize);
const int image_index =
((i + i_offs) * kImageSize + (j + j_offs)) * 3 + k;
HWY_ASSERT(image_index < kDataSize);
patch[patch_index] = image_data[image_index];
}
}
}
}
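A usage sketch that extracts all patches into one buffer (sizes follow from the
hardcoded constants above; the buffer name is illustrative):

  std::vector<float> patches(256 * 588);
  for (size_t n = 0; n < 256; ++n) {
    image.GetPatch(n, patches.data() + n * 588);
  }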
} // namespace gcpp

59
paligemma/image.h Normal file
View File

@ -0,0 +1,59 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_PALIGEMMA_IMAGE_H_
#define THIRD_PARTY_GEMMA_CPP_PALIGEMMA_IMAGE_H_
#include <cstddef>
#include <string>
#include <vector>
namespace gcpp {
// Very basic image loading and processing for PaliGemma-224. Does not try to be
// generic at the moment, e.g. the size to normalize to is hardcoded.
class Image {
public:
Image() = default;
// Reads a file in PPM format (P6, binary), normalizes to [-1, 1].
// Returns true on success.
bool ReadPPM(const std::string& filename);
// Resizes to 224x224 (nearest-neighbor for now, bilinear or antialias would
// be better).
void Resize();
// Writes the file as plain floats in binary. Useful to e.g. load in a colab.
bool WriteBinary(const std::string& filename) const;
// Stores the patch for the given patch number [0, 256) in `patch`.
// As sizes are hardcoded, the patch number is sufficient here.
// `patch` should have space for at least 14 * 14 * 3 = 588 floats.
// Requires that Resize() has been called, so the image is 224 x 224.
void GetPatch(size_t patch_num, float* patch) const;
float *data() { return data_.data(); }
const float *data() const { return data_.data(); }
int width() const { return width_; }
int height() const { return height_; }
size_t size() const { return data_.size(); }
operator bool() const { return data_.size() > 0; }
private:
int width_ = 0;
int height_ = 0;
std::vector<float> data_; // r, g, b
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_PALIGEMMA_IMAGE_H_
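A minimal end-to-end sketch of the class (file names hypothetical):

  gcpp::Image image;
  if (image.ReadPPM("photo.ppm")) {
    image.Resize();                      // now 224 x 224, values in [-1, 1]
    image.WriteBinary("photo_f32.bin");  // e.g. to inspect in a colab
  }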

72
paligemma/image_test.cc Normal file
View File

@ -0,0 +1,72 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paligemma/image.h"
#include <string>
#include "gtest/gtest.h"
namespace gcpp {
namespace {
float Normalize(int value) { return 2.0f * (value / 255.0f) - 1.0f; }
TEST(ImageTest, BasicFunctionality) {
return; // Need to figure out how to get the external path for the test file.
std::string path;
Image image;
EXPECT_EQ(image.width(), 0);
EXPECT_EQ(image.height(), 0);
EXPECT_EQ(image.size(), 0);
ASSERT_TRUE(image.ReadPPM(path));
EXPECT_EQ(image.width(), 256);
EXPECT_EQ(image.height(), 341);
EXPECT_EQ(image.size(), 256 * 341 * 3);
// Spot check a few values.
EXPECT_EQ(image.data()[0], Normalize(160));
EXPECT_EQ(image.data()[1], Normalize(184));
EXPECT_EQ(image.data()[2], Normalize(188));
EXPECT_EQ(image.data()[3], Normalize(163));
EXPECT_EQ(image.data()[4], Normalize(185));
EXPECT_EQ(image.data()[5], Normalize(189));
EXPECT_EQ(image.data()[30], Normalize(164));
EXPECT_EQ(image.data()[31], Normalize(185));
EXPECT_EQ(image.data()[32], Normalize(191));
EXPECT_EQ(image.data()[33], Normalize(164));
EXPECT_EQ(image.data()[34], Normalize(185));
EXPECT_EQ(image.data()[35], Normalize(191));
image.Resize();
// Check first and last pixel.
EXPECT_EQ(image.data()[0], Normalize(160));
EXPECT_EQ(image.data()[1], Normalize(184));
EXPECT_EQ(image.data()[2], Normalize(188));
EXPECT_EQ(image.data()[image.size() - 3], Normalize(90));
EXPECT_EQ(image.data()[image.size() - 2], Normalize(132));
EXPECT_EQ(image.data()[image.size() - 1], Normalize(122));
// Extract two patches.
float patch[588];
image.GetPatch(0, patch);
EXPECT_EQ(patch[0], Normalize(160));
EXPECT_EQ(patch[1], Normalize(184));
EXPECT_EQ(patch[2], Normalize(188));
image.GetPatch(18, patch);
for (int i = 0; i < 10; ++i) {
EXPECT_EQ(patch[i], image.data()[(14 * 224 + 2 * 14) * 3 + i]);
}
}
} // namespace
} // namespace gcpp

132
paligemma/paligemma_test.cc Normal file
View File

@ -0,0 +1,132 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/gemma.h"
#include <cstdio>
#include <memory>
#include <string>
#include <vector>
#include "evals/benchmark_helper.h"
#include "gemma/common.h"
#include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h"
// This test can be run manually with the downloaded PaliGemma weights.
// To run the test, pass the following flags:
// --model paligemma-224 --tokenizer <tokenizer_path> --weights <weights_path>
// It should pass for the following models:
// paligemma-3b-mix-224
namespace gcpp {
namespace {
// Shared state. Requires argc/argv, so construct in main and use the same raw
// pointer approach as in benchmarks.cc. Note that the style guide forbids
// non-local static variables with dtors.
GemmaEnv* s_env = nullptr;
class PaliGemmaTest : public ::testing::Test {
protected:
void InitVit(const std::string& path);
std::string GemmaReply(const std::string& prompt_text) const;
void TestQuestions(const char* kQA[][2], size_t num_questions);
std::unique_ptr<ImageTokens> image_tokens_;
};
void PaliGemmaTest::InitVit(const std::string& path) {
ASSERT_NE(s_env->GetModel(), nullptr);
Gemma& model = *(s_env->GetModel());
image_tokens_ = std::make_unique<ImageTokens>(256, 2048);
Image image;
HWY_ASSERT(model.Info().model == Model::PALIGEMMA_224);
HWY_ASSERT(image.ReadPPM(path));
image.Resize();
RuntimeConfig runtime_config = {.verbosity = 0, .gen = &s_env->MutableGen()};
model.GenerateImageTokens(runtime_config, image, *image_tokens_);
}
std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const {
Gemma& model = *(s_env->GetModel());
s_env->MutableGen().seed(0x12345678);
RuntimeConfig runtime_config = {.max_tokens = 1024,
.max_generated_tokens = 512,
.verbosity = 0,
.gen = &s_env->MutableGen()};
runtime_config.image_tokens = image_tokens_.get();
size_t abs_pos = 0;
std::string mutable_prompt = prompt_text;
std::vector<int> tokens =
WrapAndTokenize(model.Tokenizer(), model.Info(), abs_pos, mutable_prompt);
std::string response;
auto stream_token = [&](int token, float) {
std::string token_text;
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
response += token_text;
return true;
};
  runtime_config.stream_token = stream_token;
tokens.insert(tokens.begin(), image_tokens_->BatchSize(), 0);
size_t num_tokens = tokens.size();
size_t prefix_end = num_tokens;
runtime_config.prefill_tbatch_size = num_tokens;
TimingInfo timing_info = {.verbosity = 0};
model.Generate(runtime_config, tokens, abs_pos, prefix_end,
s_env->MutableKVCache(), timing_info);
return response;
}
void PaliGemmaTest::TestQuestions(const char* kQA[][2], size_t num_questions) {
ASSERT_NE(s_env->GetModel(), nullptr);
return; // Need to figure out how to get the external path for the test file.
std::string path;
InitVit(path);
for (size_t i = 0; i < num_questions; ++i) {
fprintf(stderr, "Question %zu\n\n", i + 1);
std::string response = GemmaReply(kQA[i][0]);
fprintf(stderr, "'%s'\n\n", response.c_str());
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
}
}
TEST_F(PaliGemmaTest, General) {
static const char* kQA[][2] = {
{"describe this image",
"A large building with two towers stands tall on the water's edge."},
{"describe image briefly",
"A large building with two towers in the middle of a city."},
{"What kind of building is it?", "church"},
{"How many towers does the church have?", "2"},
{"detect water", "<loc1022> water"},
{"segment water", "<seg010> water"},
{"Which city is this more likely? Tokio or Zurich?", "zurich"},
};
static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]);
TestQuestions(kQA, kNum);
}
} // namespace
} // namespace gcpp
int main(int argc, char** argv) {
gcpp::GemmaEnv env(argc, argv);
gcpp::s_env = &env;
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

192
paligemma/testdata/image.ppm vendored Normal file

File diff suppressed because one or more lines are too long

View File

@ -187,6 +187,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
float temperature;
bool deterministic;
bool multiturn;
Path image_file;
// Returns error string or nullptr if OK.
const char* Validate() const {
@ -221,6 +222,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
"interaction\n 1 = continue KV cache after every interaction\n "
" Default : 0 (conversation "
"resets every turn)");
visitor(image_file, "image_file", Path(), "Image file to load.");
}
void CopyTo(RuntimeConfig& runtime_config) const {