mirror of https://github.com/google/gemma.cpp.git
Add support for PaliGemma Vision-LM (224x224) to gemma.cpp
See https://arxiv.org/abs/2407.07726 for a description of the model. Because PaliGemma operates as a prefix-LM on the image+prompt, add support for that. PiperOrigin-RevId: 677841119
This commit is contained in:
parent
c6c10e0a53
commit
f8835fe4a4
|
|
@ -242,6 +242,9 @@ cc_library(
|
|||
"gemma/instantiations/gemma2_2b_bf16.cc",
|
||||
"gemma/instantiations/gemma2_2b_f32.cc",
|
||||
"gemma/instantiations/gemma2_2b_sfp.cc",
|
||||
"gemma/instantiations/paligemma_224_bf16.cc",
|
||||
"gemma/instantiations/paligemma_224_f32.cc",
|
||||
"gemma/instantiations/paligemma_224_sfp.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"gemma/activations.h",
|
||||
|
|
@ -264,6 +267,7 @@ cc_library(
|
|||
":weights",
|
||||
":threading",
|
||||
"//compression:io",
|
||||
"//paligemma:image",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:bit_set",
|
||||
"@hwy//:matvec",
|
||||
|
|
@ -361,6 +365,7 @@ cc_binary(
|
|||
":gemma_lib",
|
||||
":threading",
|
||||
# Placeholder for internal dep, do not remove.,
|
||||
"//paligemma:image",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:profiler",
|
||||
"@hwy//:thread_pool",
|
||||
|
|
|
|||
|
|
@ -93,6 +93,9 @@ set(SOURCES
|
|||
gemma/instantiations/gemma2_2b_bf16.cc
|
||||
gemma/instantiations/gemma2_2b_f32.cc
|
||||
gemma/instantiations/gemma2_2b_sfp.cc
|
||||
gemma/instantiations/paligemma_224_bf16.cc
|
||||
gemma/instantiations/paligemma_224_f32.cc
|
||||
gemma/instantiations/paligemma_224_sfp.cc
|
||||
gemma/kv_cache.cc
|
||||
gemma/kv_cache.h
|
||||
gemma/tokenizer.cc
|
||||
|
|
@ -103,6 +106,8 @@ set(SOURCES
|
|||
ops/matmul-inl.h
|
||||
ops/matvec-inl.h
|
||||
ops/ops-inl.h
|
||||
paligemma/image.cc
|
||||
paligemma/image.h
|
||||
util/allocator.h
|
||||
util/app.h
|
||||
util/args.h
|
||||
|
|
@ -164,6 +169,8 @@ set(GEMMA_TEST_FILES
|
|||
ops/matmul_test.cc
|
||||
ops/gemma_matvec_test.cc
|
||||
evals/gemma_test.cc
|
||||
paligemma/image_test.cc
|
||||
paligemma/paligemma_test.cc
|
||||
)
|
||||
|
||||
foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)
|
||||
|
|
|
|||
105
README.md
105
README.md
|
|
@ -23,7 +23,7 @@ deployment-oriented C++ inference runtimes, which are not designed for
|
|||
experimentation, and Python-centric ML research frameworks, which abstract away
|
||||
low-level computation through compilation.
|
||||
|
||||
gemma.cpp provides a minimalist implementation of Gemma 2B and 7B models,
|
||||
gemma.cpp provides a minimalist implementation of Gemma-1 and Gemma-2 models,
|
||||
focusing on simplicity and directness rather than full generality. This is
|
||||
inspired by vertically-integrated model implementations such as
|
||||
[ggml](https://github.com/ggerganov/ggml),
|
||||
|
|
@ -78,17 +78,20 @@ winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "-
|
|||
### Step 1: Obtain model weights and tokenizer from Kaggle or Hugging Face Hub
|
||||
|
||||
Visit the
|
||||
[Kaggle page for Gemma](https://www.kaggle.com/models/google/gemma/frameworks/gemmaCpp),
|
||||
or [Gemma-2](https://www.kaggle.com/models/google/gemma-2/gemmaCpp), and select
|
||||
`Model Variations |> Gemma C++`.
|
||||
[Kaggle page for Gemma-2](https://www.kaggle.com/models/google/gemma-2/gemmaCpp)
|
||||
[or Gemma-1](https://www.kaggle.com/models/google/gemma/frameworks/gemmaCpp),
|
||||
and select `Model Variations |> Gemma C++`.
|
||||
|
||||
On this tab, the `Variation` dropdown includes the options below. Note bfloat16
|
||||
weights are higher fidelity, while 8-bit switched floating point weights enable
|
||||
faster inference. In general, we recommend starting with the `-sfp` checkpoints.
|
||||
|
||||
If you are unsure which model to start with, we recommend starting with the
|
||||
smallest Gemma-2 model, i.e. `2.0-2b-it-sfp`.
|
||||
|
||||
Alternatively, visit the
|
||||
[gemma.cpp](https://huggingface.co/models?other=gemma.cpp) models on the Hugging
|
||||
Face Hub. First go the the model repository of the model of interest (see
|
||||
Face Hub. First go to the model repository of the model of interest (see
|
||||
recommendations below). Then, click the `Files and versions` tab and download
|
||||
the model and tokenizer files. For programmatic downloading, if you have
|
||||
`huggingface_hub` installed, you can also download by running:
|
||||
|
|
@ -98,7 +101,7 @@ huggingface-cli login # Just the first time
|
|||
huggingface-cli download google/gemma-2b-sfp-cpp --local-dir build/
|
||||
```
|
||||
|
||||
2B instruction-tuned (`it`) and pre-trained (`pt`) models:
|
||||
Gemma-1 2B instruction-tuned (`it`) and pre-trained (`pt`) models:
|
||||
|
||||
| Model name | Description |
|
||||
| ----------- | ----------- |
|
||||
|
|
@ -107,7 +110,7 @@ huggingface-cli download google/gemma-2b-sfp-cpp --local-dir build/
|
|||
| `2b-pt` | 2 billion parameter pre-trained model, bfloat16 |
|
||||
| `2b-pt-sfp` | 2 billion parameter pre-trained model, 8-bit switched floating point |
|
||||
|
||||
7B instruction-tuned (`it`) and pre-trained (`pt`) models:
|
||||
Gemma-1 7B instruction-tuned (`it`) and pre-trained (`pt`) models:
|
||||
|
||||
| Model name | Description |
|
||||
| ----------- | ----------- |
|
||||
|
|
@ -256,6 +259,53 @@ Step 1, and run the binary as follows:
|
|||
|
||||
`./gemma --tokenizer tokenizer.spm --model gr2b-it --weights 2b-it-sfp.sbs`
|
||||
|
||||
### PaliGemma Vision-Language Model
|
||||
|
||||
This repository includes a version of the PaliGemma VLM
|
||||
([paper](https://arxiv.org/abs/2407.07726),
|
||||
[code](https://github.com/google-research/big_vision/tree/main/big_vision/configs/proj/paligemma)).
|
||||
We provide a C++ implementation of this model here.
|
||||
|
||||
To use the version of PaliGemma included in this repository, build the gemma
|
||||
binary as noted above in Step 3. Download the compressed weights and tokenizer
|
||||
from //TODO(keysers) - update location// and run the binary as follows:
|
||||
|
||||
```sh
./gemma \
|
||||
--tokenizer paligemma_tokenizer.model \
|
||||
--model paligemma-224 \
|
||||
--weights paligemma-3b-mix-224-sfp.sbs \
|
||||
--image_file paligemma/testdata/image.ppm
|
||||
```
|
||||
|
||||
Note that the image reading code is very basic to avoid depending on an image
|
||||
processing library for now. We currently only support reading binary PPMs (P6).
|
||||
So use a tool like `convert` to first convert your images into that format, e.g.
|
||||
|
||||
`convert image.jpeg -resize 224x224^ image.ppm`
|
||||
|
||||
(As the image will be resized for processing anyway, we can already resize at
|
||||
this stage for slightly faster loading.)
|
||||
|
||||
The interaction with the image (using the mix-224 checkpoint) may then look
|
||||
something like this:
|
||||
|
||||
```
|
||||
> Describe the image briefly
|
||||
A large building with two towers in the middle of a city.
|
||||
> What type of building is it?
|
||||
church
|
||||
> What color is the church?
|
||||
gray
|
||||
> caption image
|
||||
A large building with two towers stands tall on the water's edge. The building
|
||||
has a brown roof and a window on the side. A tree stands in front of the
|
||||
building, and a flag waves proudly from its top. The water is calm and blue,
|
||||
reflecting the sky above. A bridge crosses the water, and a red and white boat
|
||||
rests on its surface. The building has a window on the side, and a flag on top.
|
||||
A tall tree stands in front of the building, and a window on the building is
|
||||
visible from the water. The water is green, and the sky is blue.
|
||||
```
|
||||
|
||||
### Troubleshooting and FAQs
|
||||
|
||||
**Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
|
||||
|
|
@ -283,8 +333,8 @@ and not a pre-trained model (any model with a `-pt` suffix).
|
|||
**How do I convert my fine-tune to a `.sbs` compressed model file?**
|
||||
|
||||
We're working on a python script to convert a standard model format to `.sbs`,
|
||||
and hope have it available in the next week or so. Follow [this
|
||||
issue](https://github.com/google/gemma.cpp/issues/11) for updates.
|
||||
and hope to have it available soon. Follow
|
||||
[this issue](https://github.com/google/gemma.cpp/issues/11) for updates.
|
||||
|
||||
**What are some easy ways to make the model run faster?**
|
||||
|
||||
|
|
@ -371,7 +421,7 @@ For using the `gemma` executable as a command line tool, it may be useful to
|
|||
create an alias for gemma.cpp with arguments fully specified:
|
||||
|
||||
```sh
|
||||
alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --weights ~/gemma.cpp/build/2b-it-sfp.sbs --model 2b-it --verbosity 0"
|
||||
alias gemma2b="~/gemma.cpp/build/gemma -- --tokenizer ~/gemma.cpp/build/tokenizer.spm --weights ~/gemma.cpp/build/gemma2-2b-it-sfp.sbs --model gemma2-2b-it --verbosity 0"
|
||||
```
|
||||
|
||||
Replace the above paths with your own paths to the model and tokenizer paths
|
||||
|
|
@ -381,7 +431,7 @@ Here is an example of prompting `gemma` with a truncated input
|
|||
file (using a `gemma2b` alias as defined above):
|
||||
|
||||
```sh
|
||||
cat configs.h | tail -35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b
|
||||
cat configs.h | tail -n 35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
|
|
@ -391,27 +441,11 @@ cat configs.h | tail -35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code
|
|||
The output of the above command should look like:
|
||||
|
||||
```console
|
||||
$ cat configs.h | tail -35 | tr '\n' ' ' | xargs -0 echo "What does this C++ code do: " | gemma2b
|
||||
[ Reading prompt ] ......................................................................................................................................................................................................................................................................................................................................................................................................................................................................................
|
||||
The code defines two C++ structs, `ConfigGemma7B` and `ConfigGemma2B`, which are used for configuring a deep learning model.
|
||||
[ Reading prompt ] [...]
|
||||
This C++ code snippet defines a set of **constants** used in a large language model (LLM) implementation, likely related to the **attention mechanism**.
|
||||
|
||||
**ConfigGemma7B**:
|
||||
|
||||
* `kSeqLen`: Stores the length of the sequence to be processed. It's set to 7168.
|
||||
* `kVocabSize`: Stores the size of the vocabulary, which is 256128.
|
||||
* `kLayers`: Number of layers in the deep learning model. It's set to 28.
|
||||
* `kModelDim`: Dimension of the model's internal representation. It's set to 3072.
|
||||
* `kFFHiddenDim`: Dimension of the feedforward and recurrent layers' hidden representations. It's set to 16 * 3072 / 2.
|
||||
|
||||
**ConfigGemma2B**:
|
||||
|
||||
* `kSeqLen`: Stores the length of the sequence to be processed. It's also set to 7168.
|
||||
* `kVocabSize`: Size of the vocabulary, which is 256128.
|
||||
* `kLayers`: Number of layers in the deep learning model. It's set to 18.
|
||||
* `kModelDim`: Dimension of the model's internal representation. It's set to 2048.
|
||||
* `kFFHiddenDim`: Dimension of the feedforward and recurrent layers' hidden representations. It's set to 16 * 2048 / 2.
|
||||
|
||||
These structs are used to configure a deep learning model with specific parameters for either Gemma7B or Gemma2B architecture.
|
||||
Let's break down the code:
|
||||
[...]
|
||||
```
|
||||
|
||||
### Incorporating gemma.cpp as a Library in your Project
|
||||
|
|
@ -496,4 +530,13 @@ Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode
|
|||
Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas
|
||||
Fischbacher and Zoltan Szabadka.
|
||||
|
||||
Gemma-2 support was implemented in June/July 2024 with the help of several
|
||||
people.
|
||||
|
||||
PaliGemma support was implemented in September 2024 with contributions from
|
||||
Daniel Keysers.
|
||||
|
||||
[Jan Wassenberg](mailto:janwas@google.com) has continued to contribute many
|
||||
improvements, including major gains in efficiency, since the initial release.
|
||||
|
||||
This is not an officially supported Google product.
|
||||
|
|
|
|||
|
|
@ -322,6 +322,9 @@ namespace gcpp {
|
|||
void Run(Args& args) {
|
||||
hwy::ThreadPool pool(args.num_threads);
|
||||
const Model model_type = args.ModelType();
|
||||
if (model_type == Model::PALIGEMMA_224) {
|
||||
HWY_ABORT("PaliGemma is not supported in compress_weights.");
|
||||
}
|
||||
const Type weight_type = args.WeightType();
|
||||
GEMMA_EXPORT_AND_DISPATCH(
|
||||
model_type, weight_type, CompressWeights,
|
||||
|
|
|
|||
|
|
@ -97,7 +97,9 @@ struct Activations {
|
|||
|
||||
x = RowVectorBatch<float>(batch_size, kModelDim);
|
||||
q = RowVectorBatch<float>(batch_size, kHeads * QStride<TConfig>());
|
||||
logits = RowVectorBatch<float>(batch_size, kVocabSize);
|
||||
if constexpr (kVocabSize > 0) {
|
||||
logits = RowVectorBatch<float>(batch_size, kVocabSize);
|
||||
}
|
||||
|
||||
pre_att_rms_out = RowVectorBatch<float>(batch_size, kModelDim);
|
||||
att = RowVectorBatch<float>(batch_size, kHeads * kSeqLen);
|
||||
|
|
@ -109,7 +111,7 @@ struct Activations {
|
|||
C2 = RowVectorBatch<float>(batch_size, kFFHiddenDim);
|
||||
ffw_out = RowVectorBatch<float>(batch_size, kModelDim);
|
||||
|
||||
if (kGriffinLayers > 0) {
|
||||
if constexpr (kGriffinLayers > 0) {
|
||||
griffin_x = RowVectorBatch<float>(batch_size, kModelDim);
|
||||
griffin_y = RowVectorBatch<float>(batch_size, kModelDim);
|
||||
griffin_gate_x = RowVectorBatch<float>(batch_size, kModelDim);
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ constexpr const char* kModelFlags[] = {
|
|||
"gemma2-2b-pt", "gemma2-2b-it", // Gemma2 2B
|
||||
"9b-pt", "9b-it", // Gemma2 9B
|
||||
"27b-pt", "27b-it", // Gemma2 27B
|
||||
"paligemma-224", // PaliGemma 224
|
||||
};
|
||||
constexpr Model kModelTypes[] = {
|
||||
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
|
||||
|
|
@ -43,8 +44,9 @@ constexpr Model kModelTypes[] = {
|
|||
Model::GRIFFIN_2B, Model::GRIFFIN_2B, // RecurrentGemma
|
||||
Model::GEMMA_TINY, // Gemma Tiny
|
||||
Model::GEMMA2_2B, Model::GEMMA2_2B, // Gemma2 2B
|
||||
Model::GEMMA2_9B, Model::GEMMA2_9B, // Gemma2 9B
|
||||
Model::GEMMA2_27B, Model::GEMMA2_27B, // Gemma2 27B
|
||||
Model::GEMMA2_9B, Model::GEMMA2_9B, // Gemma2 9B
|
||||
Model::GEMMA2_27B, Model::GEMMA2_27B, // Gemma2 27B
|
||||
Model::PALIGEMMA_224, // PaliGemma 224
|
||||
};
|
||||
constexpr ModelTraining kModelTraining[] = {
|
||||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B
|
||||
|
|
@ -54,6 +56,7 @@ constexpr ModelTraining kModelTraining[] = {
|
|||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 2B
|
||||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 9B
|
||||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 27B
|
||||
ModelTraining::PALIGEMMA, // PaliGemma 224
|
||||
};
|
||||
|
||||
constexpr size_t kNumModelFlags =
|
||||
|
|
|
|||
|
|
@ -37,10 +37,11 @@ enum class Model {
|
|||
GRIFFIN_2B,
|
||||
GEMMA_TINY,
|
||||
GEMMA2_2B,
|
||||
PALIGEMMA_224,
|
||||
};
|
||||
|
||||
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
|
||||
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
|
||||
enum class ModelTraining { GEMMA_IT, GEMMA_PT, PALIGEMMA };
|
||||
|
||||
// Tensor types for loading weights. When adding a new one, also
|
||||
// update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc.
|
||||
|
|
@ -93,7 +94,9 @@ decltype(auto) CallForModel(Model model, TArgs&&... args) {
|
|||
return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||
case Model::GEMMA2_2B:
|
||||
return FuncT<ConfigGemma2_2B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||
|
||||
case Model::PALIGEMMA_224:
|
||||
return FuncT<ConfigPaliGemma_224<TWeight>>()(
|
||||
std::forward<TArgs>(args)...);
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||
}
|
||||
|
|
@ -136,8 +139,9 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
|
|||
GEMMA_FOREACH_WEIGHT(X, ConfigGemma7B) \
|
||||
GEMMA_FOREACH_WEIGHT(X, ConfigGriffin2B) \
|
||||
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_2B) \
|
||||
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_9B) \
|
||||
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_27B) \
|
||||
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_9B) \
|
||||
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_27B) \
|
||||
GEMMA_FOREACH_WEIGHT(X, ConfigPaliGemma_224) \
|
||||
static_assert(true, "Allow trailing ;")
|
||||
|
||||
// Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float),
|
||||
|
|
@ -179,6 +183,11 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
|
|||
ARGS; \
|
||||
break; \
|
||||
} \
|
||||
case Model::PALIGEMMA_224: { \
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigPaliGemma_224<TWEIGHT>>)\
|
||||
ARGS; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ using EmbedderInputT = hwy::bfloat16_t;
|
|||
enum class LayerAttentionType {
|
||||
kGemma,
|
||||
kGriffinRecurrentBlock,
|
||||
kVit,
|
||||
};
|
||||
|
||||
// Post attention and ffw normalization type.
|
||||
|
|
@ -131,7 +132,22 @@ struct CachePosSize {
|
|||
}
|
||||
};
|
||||
|
||||
struct ConfigNoSSM {
|
||||
struct ConfigNoVit {
|
||||
struct VitConfig {
|
||||
static constexpr int kLayers = 0;
|
||||
static constexpr std::array<LayerAttentionType, 0> kLayerConfig =
|
||||
FixedLayerConfig<0>(LayerAttentionType::kVit);
|
||||
static constexpr int kModelDim = 0;
|
||||
static constexpr int kFFHiddenDim = 0;
|
||||
static constexpr int kHeads = 0;
|
||||
static constexpr int kKVHeads = 0;
|
||||
static constexpr int kQKVDim = 0;
|
||||
static constexpr int kSeqLen = 0;
|
||||
static constexpr ResidualType kResidual = ResidualType::Add;
|
||||
};
|
||||
};
|
||||
|
||||
struct ConfigNoSSM : ConfigNoVit {
|
||||
static constexpr int kGriffinLayers = 0;
|
||||
|
||||
static constexpr int kConv1dWidth = 0;
|
||||
|
|
@ -247,6 +263,37 @@ struct ConfigGemma2B : public ConfigBaseGemmaV1 {
|
|||
static constexpr bool kAbsolutePE = false;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct ConfigPaliGemma_224 : public ConfigGemma2B<TWeight> {
|
||||
// On the LM side, the vocab size is one difference to Gemma1-2B in the
|
||||
// architecture. PaliGemma adds 1024 <locNNNN> and 128 <segNNN> tokens.
|
||||
static constexpr int kVocabSize = 256000 + 1024 + 128; // = 257152
|
||||
|
||||
// Sub-config for the Vision-Transformer part.
|
||||
struct VitConfig : public ConfigNoSSM {
|
||||
using Weight = TWeight;
|
||||
// The ViT parts. https://arxiv.org/abs/2305.13035
|
||||
// "SoViT-400m/14 [...] has a width of 1152, depth 27, and MLP dim 4304."
|
||||
static constexpr std::array<LayerAttentionType, 27> kLayerConfig =
|
||||
FixedLayerConfig<27>(LayerAttentionType::kVit);
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kModelDim = 1152;
|
||||
static constexpr int kFFHiddenDim = 4304;
|
||||
static constexpr int kHeads = 16;
|
||||
static constexpr int kKVHeads = 16; // standard MHA
|
||||
static constexpr int kQKVDim = 72;
|
||||
static constexpr int kSeqLen = 16 * 16; // 256
|
||||
static constexpr bool kFFBiases = true;
|
||||
// The Vit part does not have a vocabulary, the image patches are embedded.
|
||||
static constexpr int kVocabSize = 0;
|
||||
// Dimensions related to image processing.
|
||||
static constexpr int kPatchWidth = 14;
|
||||
static constexpr int kImageSize = 224;
|
||||
// Necessary constant for the layer configuration.
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
};
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct ConfigGemma2_2B : public ConfigBaseGemmaV2 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
|
@ -297,7 +344,7 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
|
|||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct ConfigGriffin2B {
|
||||
struct ConfigGriffin2B : ConfigNoVit {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
// Griffin uses local attention, so kSeqLen is actually the local attention
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@
|
|||
#include "ops/matmul-inl.h"
|
||||
#include "ops/matvec-inl.h"
|
||||
#include "ops/ops-inl.h"
|
||||
#include "paligemma/image.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
|
|
@ -209,7 +210,7 @@ class GemmaAttention {
|
|||
static constexpr size_t kSeqLen = TConfig::kSeqLen;
|
||||
static constexpr bool kIsMHA = Activations::IsMHA<TConfig>();
|
||||
|
||||
// The attention window usually starts at 0 unless unless `pos` is larger than
|
||||
// The attention window usually starts at 0 unless `pos` is larger than
|
||||
// the attention window size, then it is `pos` - window_size + 1.
|
||||
static HWY_INLINE size_t StartPos(size_t pos, size_t layer) {
|
||||
const size_t att_window_size = TConfig::kAttentionWindowSizes[layer];
|
||||
|
|
@ -318,26 +319,26 @@ class GemmaAttention {
|
|||
}
|
||||
|
||||
// Computes Q.K scores, which are "logits" (or scores) stored to head_att.
|
||||
HWY_INLINE void QDotK(const size_t start_pos, const size_t pos,
|
||||
HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
|
||||
const size_t head_offset, const float* HWY_RESTRICT q,
|
||||
const KVCache& kv_cache, float* HWY_RESTRICT head_att) {
|
||||
if (HWY_LIKELY(pos < kSeqLen)) {
|
||||
if (HWY_LIKELY(last_pos < kSeqLen)) {
|
||||
// Slightly faster: no wraparound.
|
||||
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
||||
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
|
||||
const size_t kv_offset =
|
||||
pos2 * kCachePosSize + layer_ * kCacheLayerSize + head_offset;
|
||||
pos * kCachePosSize + layer_ * kCacheLayerSize + head_offset;
|
||||
const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
|
||||
const float score = Dot(q, k, kQKVDim);
|
||||
head_att[pos2] = score;
|
||||
head_att[pos] = score;
|
||||
}
|
||||
} else {
|
||||
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
||||
const size_t cache_pos = div_seq_len_.Remainder(pos2);
|
||||
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
|
||||
const size_t cache_pos = div_seq_len_.Remainder(pos);
|
||||
const size_t kv_offset =
|
||||
cache_pos * kCachePosSize + layer_ * kCacheLayerSize + head_offset;
|
||||
const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset];
|
||||
const float score = Dot(q, k, kQKVDim);
|
||||
head_att[pos2 % kSeqLen] = score;
|
||||
head_att[pos % kSeqLen] = score;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -345,32 +346,30 @@ class GemmaAttention {
|
|||
// Accumulates the sum of v (from `kv_cache`) * probability (`head_att`) into
|
||||
// `att_out`. Equivalent in gemma/modules.py:
|
||||
// encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
|
||||
static HWY_INLINE void WeightedSumV(const size_t start_pos, const size_t pos,
|
||||
const float* HWY_RESTRICT head_att,
|
||||
const size_t layer,
|
||||
const size_t head_offset,
|
||||
const hwy::Divisor& div_seq_len,
|
||||
const KVCache& kv_cache,
|
||||
float* HWY_RESTRICT att_out) {
|
||||
static HWY_INLINE void WeightedSumV(
|
||||
const size_t start_pos, const size_t last_pos,
|
||||
const float* HWY_RESTRICT head_att, const size_t layer,
|
||||
const size_t head_offset, const hwy::Divisor& div_seq_len,
|
||||
const KVCache& kv_cache, float* HWY_RESTRICT att_out) {
|
||||
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
|
||||
|
||||
if (HWY_LIKELY(pos < kSeqLen)) {
|
||||
if (HWY_LIKELY(last_pos < kSeqLen)) {
|
||||
// Slightly faster: no wraparound.
|
||||
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
||||
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
|
||||
const size_t kv_offset =
|
||||
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
|
||||
pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
|
||||
const float* HWY_RESTRICT v =
|
||||
kv_cache.kv_cache.get() + kv_offset + kQKVDim;
|
||||
MulByConstAndAdd(head_att[pos2], v, att_out, kQKVDim);
|
||||
MulByConstAndAdd(head_att[pos], v, att_out, kQKVDim);
|
||||
}
|
||||
} else {
|
||||
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
||||
const size_t cache_pos = div_seq_len.Remainder(pos2);
|
||||
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
|
||||
const size_t cache_pos = div_seq_len.Remainder(pos);
|
||||
const size_t kv_offset =
|
||||
cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
|
||||
const float* HWY_RESTRICT v =
|
||||
kv_cache.kv_cache.get() + kv_offset + kQKVDim;
|
||||
MulByConstAndAdd(head_att[pos2 % kSeqLen], v, att_out, kQKVDim);
|
||||
MulByConstAndAdd(head_att[pos % kSeqLen], v, att_out, kQKVDim);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -402,20 +401,26 @@ class GemmaAttention {
|
|||
PositionalEncodingQK(q, pos, layer_, kQueryScale, q);
|
||||
|
||||
const size_t start_pos = StartPos(pos, layer_);
|
||||
size_t last_pos = pos;
|
||||
const size_t prefix_end = queries_prefix_end_[query_idx];
|
||||
if (prefix_end > 0 && prefix_end - 1 > last_pos) {
|
||||
// last_pos in QDotK and WeightedSumV is inclusive.
|
||||
last_pos = prefix_end - 1;
|
||||
}
|
||||
|
||||
float* HWY_RESTRICT head_att =
|
||||
activations_.att.Batch(interleaved_idx) + head * kSeqLen;
|
||||
QDotK(start_pos, pos, head_offset, q, kv_cache, head_att);
|
||||
QDotK(start_pos, last_pos, head_offset, q, kv_cache, head_att);
|
||||
// SoftMax with optional SoftCap yields "probabilities" in
|
||||
// head_att.
|
||||
const size_t head_att_len = std::min(pos + 1, kSeqLen);
|
||||
const size_t head_att_len = std::min(last_pos + 1, kSeqLen);
|
||||
MaybeLogitsSoftCap(TConfig::kAttCap, head_att, head_att_len);
|
||||
Softmax(head_att, head_att_len);
|
||||
|
||||
float* HWY_RESTRICT att_out =
|
||||
activations_.att_out.Batch(interleaved_idx) +
|
||||
head * kQKVDim;
|
||||
WeightedSumV(start_pos, pos, head_att, layer_, head_offset,
|
||||
WeightedSumV(start_pos, last_pos, head_att, layer_, head_offset,
|
||||
div_seq_len_, kv_cache, att_out);
|
||||
});
|
||||
}
|
||||
|
|
@ -435,16 +440,40 @@ class GemmaAttention {
|
|||
MatMul<kAdd>(
|
||||
num_interleaved, ConstMat(activations_.att_out.All(), kHeads * kQKVDim),
|
||||
ConstMat(layer_weights_.att_weights.data(), kHeads * kQKVDim),
|
||||
layer_weights_.attn_vec_einsum_w.scale(), bias, activations_.env,
|
||||
layer_weights_.att_weights.scale(), bias, activations_.env,
|
||||
MutableMat(activations_.att_sums.All(), kModelDim));
|
||||
}
|
||||
|
||||
public:
|
||||
// Constructor with explicit initialization of queries_prefix_end. This is
|
||||
// needed for the Prefix-LM style attention. For standard causal attention,
|
||||
// the other constructor can be used.
|
||||
GemmaAttention(const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end, size_t num_tokens,
|
||||
size_t layer, Activations& activations,
|
||||
const CompressedLayer<TConfig>* layer_weights,
|
||||
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
|
||||
: queries_pos_(queries_pos),
|
||||
queries_prefix_end_(queries_prefix_end),
|
||||
num_queries_(queries_pos.size()),
|
||||
num_tokens_(num_tokens),
|
||||
layer_(layer),
|
||||
activations_(activations),
|
||||
layer_weights_(*layer_weights),
|
||||
div_seq_len_(div_seq_len),
|
||||
kv_caches_(kv_caches),
|
||||
pool_(activations.env.Pool()) {
|
||||
HWY_DASSERT(num_queries_ <= kv_caches_.size());
|
||||
}
|
||||
// Constructor with default initialization to 0 for queries_prefix_end.
|
||||
GemmaAttention(const QueriesPos& queries_pos, size_t num_tokens, size_t layer,
|
||||
Activations& activations,
|
||||
const CompressedLayer<TConfig>* layer_weights,
|
||||
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
|
||||
: queries_pos_(queries_pos),
|
||||
queries_prefix_end_vec_(queries_pos.size(), 0),
|
||||
queries_prefix_end_(queries_prefix_end_vec_.data(),
|
||||
queries_prefix_end_vec_.size()),
|
||||
num_queries_(queries_pos.size()),
|
||||
num_tokens_(num_tokens),
|
||||
layer_(layer),
|
||||
|
|
@ -456,6 +485,7 @@ class GemmaAttention {
|
|||
HWY_DASSERT(num_queries_ <= kv_caches_.size());
|
||||
}
|
||||
|
||||
// Full attention computation in three steps.
|
||||
HWY_INLINE void operator()() {
|
||||
const size_t num_interleaved = num_tokens_ * num_queries_;
|
||||
ComputeQKV(num_interleaved);
|
||||
|
|
@ -465,6 +495,8 @@ class GemmaAttention {
|
|||
|
||||
private:
|
||||
const QueriesPos& queries_pos_;
|
||||
const std::vector<size_t> queries_prefix_end_vec_;
|
||||
const QueriesPos queries_prefix_end_;
|
||||
const size_t num_queries_;
|
||||
const size_t num_tokens_;
|
||||
const size_t layer_;
|
||||
|
|
@ -476,15 +508,15 @@ class GemmaAttention {
|
|||
};
|
||||
|
||||
template <class TConfig>
|
||||
HWY_NOINLINE void Attention(LayerAttentionType type,
|
||||
const QueriesPos& queries_pos, size_t num_tokens,
|
||||
size_t layer, Activations& activations,
|
||||
const CompressedLayer<TConfig>* layer_weights,
|
||||
const hwy::Divisor& div_seq_len,
|
||||
const KVCaches& kv_caches) {
|
||||
HWY_NOINLINE void Attention(
|
||||
LayerAttentionType type, const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end, size_t num_tokens, size_t layer,
|
||||
Activations& activations, const CompressedLayer<TConfig>* layer_weights,
|
||||
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) {
|
||||
if (type == LayerAttentionType::kGemma) {
|
||||
GemmaAttention<TConfig>(queries_pos, num_tokens, layer, activations,
|
||||
layer_weights, div_seq_len, kv_caches)();
|
||||
GemmaAttention<TConfig>(queries_pos, queries_prefix_end, num_tokens, layer,
|
||||
activations, layer_weights, div_seq_len,
|
||||
kv_caches)();
|
||||
} else {
|
||||
// Only reached if the model is Griffin. `if constexpr` prevents generating
|
||||
// this code for non-Griffin models.
|
||||
|
|
@ -496,6 +528,115 @@ HWY_NOINLINE void Attention(LayerAttentionType type,
|
|||
}
|
||||
}
|
||||
|
||||
// Wrapper class; holds arguments in member variables to shorten call sites.
|
||||
// The main differences to GemmaAttention are:
|
||||
// - no KV Cache necessary, attention is always all-to-all and not causal.
|
||||
// - no potential wrap-around, attention always goes from 0 to kSeqLen.
|
||||
// - no need for batching, as we are always computing attention for kSeqLen
|
||||
// tokens.
|
||||
// This results in a much simpler implementation. However, to avoid duplicating
|
||||
// code, we should still consider merging the two classes.
|
||||
// TODO(keysers): Refactor to share code with GemmaAttention.
|
||||
template <class TConfig>
|
||||
class VitAttention {
|
||||
static constexpr size_t kHeads = TConfig::kHeads;
|
||||
static constexpr size_t kKVHeads = TConfig::kKVHeads;
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||
// Per-head stride inside one row of activations_.q: each head stores its Q, K
// and V vectors back to back, kQKVDim elements each (see the +kQKVDim and
// +2*kQKVDim offsets used in DotSoftmaxWeightedSum).
static constexpr size_t kQStride = 3 * kQKVDim;
|
||||
static constexpr size_t kSeqLen = TConfig::kSeqLen;
|
||||
|
||||
// Computes Q, K, V for all heads, stored in activations_.q.
// This is a single MatMul with bias (kAdd=true) of
// activations_.pre_att_rms_out by vit.qkv_einsum_w; the asserts require
// activations_.q to have num_tokens_ rows of kHeads * kQStride values.
|
||||
HWY_NOINLINE void ComputeQKV() {
|
||||
PROFILER_ZONE("Gen.VitAttention.QKV");
|
||||
const auto y =
|
||||
ConstMat(activations_.pre_att_rms_out.All(), kModelDim);
|
||||
auto& qkv = activations_.q;
|
||||
HWY_ASSERT(qkv.BatchSize() == num_tokens_);
|
||||
HWY_ASSERT(qkv.Len() == kHeads * kQStride);
|
||||
MatMul</*kAdd=*/true>(
|
||||
num_tokens_, y,
|
||||
ConstMat(layer_weights_.vit.qkv_einsum_w.data_scale1(), kModelDim),
|
||||
/*scale=*/1.0f, layer_weights_.vit.qkv_einsum_b.data_scale1(),
|
||||
activations_.env, MutableMat(qkv.All(), qkv.Len()));
|
||||
}
|
||||
|
||||
// For every (token, head) pair: computes the scores of the query against the
// keys of rows 0..kSeqLen-1, applies Softmax over those kSeqLen scores, and
// accumulates the correspondingly weighted values into activations_.att_out.
// Reads Q, K and V from activations_.q as written by ComputeQKV.
HWY_NOINLINE void DotSoftmaxWeightedSum() {
|
||||
GEMMA_CONSTEXPR_SQRT float kQueryScale =
|
||||
1.0f / Sqrt(static_cast<float>(TConfig::kQKVDim));
|
||||
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
|
||||
// A "head group" in the context of GQA refers to a collection of query
|
||||
// heads that share the same key and value heads.
|
||||
static_assert(kHeads == kKVHeads, "Vit expects MHA");
|
||||
|
||||
// Compute Q.K, softmax, and weighted V.
// One task per (token, head) pair, run on pool_.
|
||||
pool_.Run(0, kHeads * num_tokens_,
|
||||
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
||||
// Tasks enumerate (token, head) pairs with head varying fastest.
const size_t head = task % kHeads;
|
||||
const size_t token = task / kHeads;
|
||||
// Compute Q.K scores, which are "logits" stored in head_att.
|
||||
float* HWY_RESTRICT q =
|
||||
activations_.q.Batch(token) + head * kQStride;
|
||||
// Presumably scales q in place by kQueryScale, i.e. this modifies the Q
// slice of activations_.q (MulByConst is defined elsewhere) - confirm.
MulByConst(kQueryScale, q, kQKVDim);
|
||||
float* HWY_RESTRICT head_att =
|
||||
activations_.att.Batch(token) + head * kSeqLen;
|
||||
// NOTE(review): keys (and values below) are read from rows 0..kSeqLen-1 of
// activations_.q, while ComputeQKV only asserts num_tokens_ rows. This
// presumably relies on num_tokens_ == kSeqLen - confirm against callers.
for (size_t i = 0; i < kSeqLen; ++i) {
|
||||
float* HWY_RESTRICT k =
|
||||
activations_.q.Batch(i) + head * kQStride + kQKVDim;
|
||||
head_att[i] = Dot(q, k, kQKVDim); // score = q.k
|
||||
}
|
||||
// SoftMax yields "probabilities" in head_att.
|
||||
Softmax(head_att, kSeqLen);
|
||||
// Compute weighted sum of v into att_out.
|
||||
float* HWY_RESTRICT att_out =
|
||||
activations_.att_out.Batch(token) + head * kQKVDim;
|
||||
// Clear the accumulator before adding the weighted V vectors.
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
|
||||
for (size_t i = 0; i < kSeqLen; ++i) {
|
||||
float* HWY_RESTRICT v =
|
||||
activations_.q.Batch(i) + head * kQStride + 2 * kQKVDim;
|
||||
MulByConstAndAdd(head_att[i], v, att_out, kQKVDim);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Sums encoded (`att_out`) over num_heads (`kHeads`) and head_dim (`kQKVDim`)
|
||||
// into output (`att_sums`).
|
||||
HWY_NOINLINE void SumHeads() {
|
||||
PROFILER_ZONE("Gen.VitAttention.SumHeads");
|
||||
auto* bias = layer_weights_.vit.attn_out_b.data_scale1();
|
||||
auto att_out = ConstMat(activations_.att_out.All(), kHeads * kQKVDim);
|
||||
auto att_weights = ConstMat(layer_weights_.vit.attn_out_w.data_scale1(),
|
||||
kHeads * kQKVDim);
|
||||
auto att_sums = MutableMat(activations_.att_sums.All(), kModelDim);
|
||||
// att_weights and att_out are concatenated heads, each of length kQKVDim.
|
||||
// Thus the [num_tokens_, kModelDim] matmul output is the sum over heads.
|
||||
MatMul</*kAdd=*/true>(num_tokens_, att_out, att_weights, /*scale=*/1.0f,
|
||||
bias, activations_.env, att_sums);
|
||||
}
|
||||
|
||||
public:
|
||||
// Stores the arguments for the later call to operator(). `layer_weights` is
// dereferenced here, so it must not be null. The thread pool is taken from
// activations.env.
VitAttention(size_t num_tokens, size_t layer, Activations& activations,
|
||||
const CompressedLayer<TConfig>* layer_weights)
|
||||
: num_tokens_(num_tokens),
|
||||
layer_(layer),
|
||||
activations_(activations),
|
||||
layer_weights_(*layer_weights),
|
||||
pool_(activations.env.Pool()) {}
|
||||
|
||||
// Runs the three stages in order; SumHeads writes the final result to
// activations_.att_sums.
HWY_INLINE void operator()() {
|
||||
ComputeQKV();
|
||||
DotSoftmaxWeightedSum();
|
||||
SumHeads();
|
||||
}
|
||||
|
||||
private:
|
||||
const size_t num_tokens_;
|
||||
// NOTE(review): layer_ is stored but not read by any method of this class.
const size_t layer_;
|
||||
Activations& activations_;
|
||||
const CompressedLayer<TConfig>& layer_weights_;
|
||||
hwy::ThreadPool& pool_;
|
||||
};
|
||||
|
||||
template <class TConfig, typename T>
|
||||
HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2,
|
||||
size_t count) {
|
||||
|
|
@ -504,6 +645,11 @@ HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2,
|
|||
using DF = hn::ScalableTag<T>;
|
||||
using VF = hn::Vec<DF>;
|
||||
// ActivationType::Gelu
|
||||
if (c2 == nullptr) { // No multiplier, just Gelu.
|
||||
Gelu(c1, count);
|
||||
return;
|
||||
};
|
||||
// Has multiplier, Gelu(c1) * c2.
|
||||
hn::Transform1(DF(), c1, count, c2, [](DF df, VF v, VF mul) HWY_ATTR {
|
||||
return hn::Mul(mul, Gelu(df, v));
|
||||
});
|
||||
|
|
@ -515,41 +661,66 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
|
|||
PROFILER_ZONE("Gen.FFW");
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
|
||||
|
||||
// MatMul expects col-major B, which is what we have: kModelDim consecutive
|
||||
// elements in memory, repeated kFFHiddenDim times.
|
||||
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
|
||||
const auto A = ConstMat(activations.bf_pre_ffw_rms_out.All(), kModelDim);
|
||||
const auto B1 = ConstMat(layer_weights->gating_einsum_w.data(), kModelDim);
|
||||
const auto B2 = ConstMat(layer_weights->gating_einsum_w.data(), kModelDim,
|
||||
kModelDim, kModelDim * kFFHiddenDim);
|
||||
const float scale = layer_weights->gating_einsum_w.scale();
|
||||
constexpr bool kAddBias = TConfig::kFFBiases;
|
||||
constexpr bool kIsVit = TConfig::kLayerConfig[0] == LayerAttentionType::kVit;
|
||||
using WeightType =
|
||||
hwy::If<kIsVit,
|
||||
typename CompressedLayer<TConfig>::WeightF32OrBF16,
|
||||
typename CompressedLayer<TConfig>::Weight>;
|
||||
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
|
||||
|
||||
// Define slightly more readable names for the weights and activations.
|
||||
const auto x = ConstMat(activations.bf_pre_ffw_rms_out.All(), kModelDim);
|
||||
Mat<const WeightType> w1;
|
||||
const float* bias1 = nullptr;
|
||||
Mat<const WeightType> w2;
|
||||
const float* bias2 = nullptr;
|
||||
float scale = 1.0f;
|
||||
Mat<const WeightType> w_output;
|
||||
const float* output_bias = nullptr;
|
||||
if constexpr (kAddBias) {
|
||||
float output_scale = 1.0f;
|
||||
auto hidden_activations = MutableMat(activations.C1.All(), kFFHiddenDim);
|
||||
auto multiplier = MutableMat(activations.C2.All(), kFFHiddenDim);
|
||||
auto ffw_out = MutableMat(activations.ffw_out.All(), kModelDim);
|
||||
|
||||
// For some of the weights and activations, it depends on the config where to
|
||||
// get them from or whether to use them at all.
|
||||
if constexpr (kAddBias && !kIsVit) {
|
||||
bias1 = layer_weights->ffw_gating_biases.data_scale1();
|
||||
bias2 = bias1 + kFFHiddenDim;
|
||||
output_bias = layer_weights->ffw_output_biases.data_scale1();
|
||||
}
|
||||
auto C1 = MutableMat(activations.C1.All(), kFFHiddenDim);
|
||||
auto C2 = MutableMat(activations.C2.All(), kFFHiddenDim);
|
||||
if constexpr (!kIsVit) {
|
||||
w1 = ConstMat(layer_weights->gating_einsum_w.data(), kModelDim);
|
||||
w2 = ConstMat(layer_weights->gating_einsum_w.data(), kModelDim, kModelDim,
|
||||
kModelDim * kFFHiddenDim);
|
||||
scale = layer_weights->gating_einsum_w.scale();
|
||||
w_output = ConstMat(layer_weights->linear_w.data(), kFFHiddenDim);
|
||||
output_scale = layer_weights->linear_w.scale();
|
||||
} else {
|
||||
w1 = ConstMat(layer_weights->vit.linear_0_w.data_scale1(), kModelDim);
|
||||
bias1 = layer_weights->vit.linear_0_b.data_scale1();
|
||||
multiplier.ptr = nullptr;
|
||||
w_output =
|
||||
ConstMat(layer_weights->vit.linear_1_w.data_scale1(), kFFHiddenDim);
|
||||
output_bias = layer_weights->vit.linear_1_b.data_scale1();
|
||||
}
|
||||
|
||||
// Will go through GELU.
|
||||
MatMul<kAddBias>(num_interleaved, A, B1, scale, bias1, activations.env, C1);
|
||||
// What to multiply by.
|
||||
MatMul<kAddBias>(num_interleaved, A, B2, scale, bias2, activations.env, C2);
|
||||
// Compute the hidden layer activations.
|
||||
MatMul<kAddBias>(num_interleaved, x, w1, scale, bias1, activations.env,
|
||||
hidden_activations);
|
||||
if constexpr (!kIsVit) {
|
||||
MatMul<kAddBias>(num_interleaved, x, w2, scale, bias2, activations.env,
|
||||
multiplier);
|
||||
}
|
||||
|
||||
// Activation (Gelu) and multiply by gate. Store activations in C1.
|
||||
Activation<TConfig>(C1.ptr, C2.ptr, kFFHiddenDim * num_interleaved);
|
||||
// Activation (Gelu) and maybe multiply by gate. Store activations in act.
|
||||
Activation<TConfig>(hidden_activations.ptr, multiplier.ptr,
|
||||
kFFHiddenDim * num_interleaved);
|
||||
|
||||
// Hidden layer -> output layer.
|
||||
MatMul<kAddBias>(num_interleaved, ConstMat(C1),
|
||||
ConstMat(layer_weights->linear_w.data(), kFFHiddenDim),
|
||||
layer_weights->linear_w.scale(), output_bias,
|
||||
activations.env,
|
||||
MutableMat(activations.ffw_out.All(), kModelDim));
|
||||
MatMul<kAddBias>(num_interleaved, ConstMat(hidden_activations), w_output,
|
||||
output_scale, output_bias, activations.env, ffw_out);
|
||||
}
|
||||
|
||||
// `batch_idx` indicates which row of `x` to write to.
|
||||
|
|
@ -557,8 +728,17 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
|
|||
// called for batches of tokens in prefill, but batches of queries in decode.
|
||||
template <class TConfig>
|
||||
HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
|
||||
size_t pos_in_prompt,
|
||||
const CompressedWeights<TConfig>& weights,
|
||||
RowVectorBatch<float>& x) {
|
||||
RowVectorBatch<float>& x,
|
||||
const ImageTokens* image_tokens) {
|
||||
// Image tokens just need to be copied.
|
||||
if (image_tokens != nullptr && pos_in_prompt < image_tokens->BatchSize()) {
|
||||
hwy::CopyBytes(image_tokens->Batch(pos_in_prompt), x.Batch(batch_idx),
|
||||
x.Len() * sizeof(x.Const()[0]));
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
|
||||
|
|
@ -598,7 +778,8 @@ void PostNorm(size_t num_interleaved, const WeightT& weights, InOutT* inout) {
|
|||
|
||||
template <class TConfig>
|
||||
HWY_NOINLINE void TransformerLayer(
|
||||
const QueriesPos& queries_pos, size_t num_tokens, size_t layer,
|
||||
const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end,
|
||||
size_t num_tokens, size_t layer,
|
||||
const CompressedLayer<TConfig>* layer_weights, Activations& activations,
|
||||
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) {
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
|
|
@ -611,8 +792,9 @@ HWY_NOINLINE void TransformerLayer(
|
|||
layer_weights->pre_attention_norm_scale.data_scale1(),
|
||||
activations.pre_att_rms_out.All(), kModelDim);
|
||||
|
||||
Attention<TConfig>(type, queries_pos, num_tokens, layer_of_type, activations,
|
||||
layer_weights, div_seq_len, kv_caches);
|
||||
Attention<TConfig>(type, queries_pos, queries_prefix_end, num_tokens,
|
||||
layer_of_type, activations, layer_weights, div_seq_len,
|
||||
kv_caches);
|
||||
|
||||
PostNorm<TConfig>(num_interleaved, layer_weights->post_attention_norm_scale,
|
||||
activations.att_sums.All());
|
||||
|
|
@ -635,6 +817,51 @@ HWY_NOINLINE void TransformerLayer(
|
|||
/*is_attention=*/false);
|
||||
}
|
||||
|
||||
// Vit transformer layer. Some comments below refer to the Vit implementation in
|
||||
// the Big Vision codebase. See
|
||||
// github.com/google-research/big_vision/blob/main/big_vision/models/vit.py
|
||||
// TODO(keysers): consider adding a wrapper for both LayerNorm and RMSNorm and
|
||||
// try merging this with TransformerLayer.
|
||||
// Applies one ViT encoder layer to activations.x; per the step comments below,
// x receives the attention output and then the MLP output as residual adds.
// `num_tokens` must equal activations.x.BatchSize(), activations.x.Len() must
// equal kModelDim, and `layer` must index a LayerAttentionType::kVit entry of
// TConfig::kLayerConfig; all three are enforced with HWY_ASSERT.
// `layer_weights` supplies the vit.* layer-norm tensors used here and is
// forwarded to VitAttention and FFW.
template <class TConfig>
|
||||
HWY_NOINLINE void VitTransformerLayer(
|
||||
size_t num_tokens, size_t layer,
|
||||
const CompressedLayer<TConfig>* layer_weights, Activations& activations) {
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
auto type = TConfig::kLayerConfig[layer];
|
||||
HWY_ASSERT(type == LayerAttentionType::kVit);
|
||||
|
||||
auto& x = activations.x;
|
||||
HWY_ASSERT(x.BatchSize() == num_tokens);
|
||||
HWY_ASSERT(x.Len() == kModelDim);
|
||||
|
||||
// y = nn.LayerNorm()(x)
|
||||
// y ~ pre_att_rms_out
|
||||
LayerNormBatched(num_tokens, x.All(),
|
||||
layer_weights->vit.layer_norm_0_scale.data_scale1(),
|
||||
layer_weights->vit.layer_norm_0_bias.data_scale1(),
|
||||
activations.pre_att_rms_out.All(), kModelDim);
|
||||
// y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
|
||||
// y ~ att_sums
|
||||
VitAttention<TConfig>(num_tokens, layer, activations, layer_weights)();
|
||||
|
||||
// x = out["+sa"] = x + y
|
||||
AddFromBatched(num_tokens, activations.att_sums.All(), x.All(), kModelDim);
|
||||
|
||||
// y = nn.LayerNorm()(x)
|
||||
// y ~ bf_pre_ffw_rms_out
|
||||
LayerNormBatched(num_tokens, x.All(),
|
||||
layer_weights->vit.layer_norm_1_scale.data_scale1(),
|
||||
layer_weights->vit.layer_norm_1_bias.data_scale1(),
|
||||
activations.bf_pre_ffw_rms_out.All(), kModelDim);
|
||||
|
||||
// y = out["mlp"] = MlpBlock(...)(y)
|
||||
// y ~ ffw_out
|
||||
FFW<TConfig>(activations, num_tokens, layer_weights);
|
||||
|
||||
// x = out["+mlp"] = x + y
|
||||
AddFromBatched(num_tokens, activations.ffw_out.All(), x.All(), kModelDim);
|
||||
}
|
||||
|
||||
// Prefill() and Transformer() increment positions in-place.
|
||||
using QueriesMutablePos = hwy::Span<size_t>;
|
||||
|
||||
|
|
@ -642,13 +869,14 @@ using QueriesMutablePos = hwy::Span<size_t>;
|
|||
template <class TConfig>
|
||||
HWY_NOINLINE void Prefill(
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesMutablePos& queries_pos, const size_t query_idx_start,
|
||||
const CompressedWeights<TConfig>& weights, Activations& activations,
|
||||
const RuntimeConfig& runtime_config, const hwy::Divisor& div_seq_len,
|
||||
const KVCaches& kv_caches) {
|
||||
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
|
||||
const size_t query_idx_start, const CompressedWeights<TConfig>& weights,
|
||||
Activations& activations, const RuntimeConfig& runtime_config,
|
||||
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) {
|
||||
PROFILER_ZONE("Gen.Prefill");
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_ASSERT(queries_pos.size() == num_queries);
|
||||
HWY_ASSERT(queries_prefix_end.size() == num_queries);
|
||||
HWY_ASSERT(kv_caches.size() == num_queries);
|
||||
|
||||
// Batches are important for amortizing loading weights over multiple tokens.
|
||||
|
|
@ -667,27 +895,50 @@ HWY_NOINLINE void Prefill(
|
|||
// Single query at a time, so pass slices of the spans because
|
||||
// GemmaAttention will only access the first KV cache and position.
|
||||
QueriesPos single_query_pos(&queries_pos[qi], 1);
|
||||
QueriesPos single_query_prefix_end(&queries_prefix_end[qi], 1);
|
||||
KVCaches single_kv_cache(&kv_caches[qi], 1);
|
||||
|
||||
const size_t prefill_per_query = queries_prompt[qi].size() - 1;
|
||||
const size_t prompt_size = queries_prompt[qi].size();
|
||||
// In autoregressive mode, we don't need to prefill the last token, so - 1.
|
||||
size_t prefill_this_query = prompt_size - 1;
|
||||
const size_t prefix_end_this_query = queries_prefix_end[qi];
|
||||
// We can't attend beyond the prompt_size.
|
||||
HWY_ASSERT(prefix_end_this_query <= prompt_size);
|
||||
// Special case: if the prefix includes the last token, we need to prefill
|
||||
// the last token, too. However, we need to rewind this for the generation
|
||||
// of the first token. So we need to keep track of this.
|
||||
// TODO: consider implementing masking instead of this logic?
|
||||
bool attend_to_last_token = (prefill_this_query < prefix_end_this_query);
|
||||
if (attend_to_last_token) {
|
||||
// The difference can be at most 1.
|
||||
prefill_this_query += 1;
|
||||
HWY_ASSERT(prefill_this_query == prefix_end_this_query);
|
||||
}
|
||||
// In prefix-LM mode, we need to look at all the tokens for the prefix in
|
||||
// one iteration through the layers, so we need a large enough batch size.
|
||||
HWY_ASSERT(max_tbatch_size >= prefill_this_query);
|
||||
|
||||
// For each batch of tokens in the query:
|
||||
for (size_t tbatch_start = 0; tbatch_start < prefill_per_query;
|
||||
for (size_t tbatch_start = 0; tbatch_start < prefill_this_query;
|
||||
tbatch_start += max_tbatch_size) {
|
||||
// Fill activations.x (much faster than TransformerLayer).
|
||||
const size_t tbatch_size =
|
||||
HWY_MIN(max_tbatch_size, prefill_per_query - tbatch_start);
|
||||
HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start);
|
||||
|
||||
// Fill activations.x (much faster than TransformerLayer).
|
||||
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
||||
const int token = queries_prompt[qi][tbatch_start + ti];
|
||||
const size_t pos = queries_pos[qi] + ti;
|
||||
EmbedToken<TConfig>(token, ti, pos, weights, activations.x);
|
||||
const size_t pos_in_prompt = tbatch_start + ti;
|
||||
const int token = queries_prompt[qi][pos_in_prompt];
|
||||
EmbedToken<TConfig>(token, ti, pos, pos_in_prompt, weights,
|
||||
activations.x, runtime_config.image_tokens);
|
||||
}
|
||||
|
||||
// Transformer with one batch of tokens from a single query.
|
||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||
const auto* layer_weights = weights.GetLayer(layer);
|
||||
TransformerLayer<TConfig>(single_query_pos, tbatch_size, layer,
|
||||
layer_weights, activations, div_seq_len,
|
||||
single_kv_cache);
|
||||
TransformerLayer<TConfig>(single_query_pos, single_query_prefix_end,
|
||||
tbatch_size, layer, layer_weights,
|
||||
activations, div_seq_len, single_kv_cache);
|
||||
}
|
||||
|
||||
// NOTE: we unconditionally call StreamToken, even if EOS.
|
||||
|
|
@ -695,19 +946,111 @@ HWY_NOINLINE void Prefill(
|
|||
const size_t pos = queries_pos[qi] + ti;
|
||||
const size_t pos_in_prompt = tbatch_start + ti;
|
||||
const int token = queries_prompt[qi][pos_in_prompt];
|
||||
runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f);
|
||||
if (pos_in_prompt < prompt_size - 1) {
|
||||
runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f);
|
||||
} else {
|
||||
// The last token will be streamed later and we should only get here
|
||||
// if we need to attend to the last token because it is in the prefix.
|
||||
HWY_ASSERT(attend_to_last_token);
|
||||
}
|
||||
}
|
||||
|
||||
queries_pos[qi] += tbatch_size;
|
||||
} // for tbatch_start
|
||||
if (attend_to_last_token) {
|
||||
// We need to rewind the position for the last token that we only
|
||||
// attended to to make sure the prefix LM sees everything.
|
||||
// This means we duplicate work on the last prompt token in autoregressive
|
||||
// decoding. Alternatives: (1) real masking; (2) always prefill the last
|
||||
// token and only generate the next one from the already prefilled
|
||||
// activations.
|
||||
queries_pos[qi] -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Gets the patches of the image and embeds them with the image embedding
|
||||
// kernel. The result is stored in activations.x.
|
||||
template <class TConfig>
|
||||
HWY_NOINLINE void EmbedImagePatches(const Image& image,
|
||||
const CompressedWeights<TConfig>& weights,
|
||||
Activations& activations) {
|
||||
// All dimensions come from the ViT sub-config, not the LLM config.
static constexpr size_t kModelDim = TConfig::VitConfig::kModelDim;
|
||||
static constexpr size_t kPatchWidth = TConfig::VitConfig::kPatchWidth;
|
||||
static constexpr size_t kSeqLen = TConfig::VitConfig::kSeqLen;
|
||||
// Number of floats per patch: kPatchWidth x kPatchWidth positions times 3
// (the channel count C in the kernel shape comment below).
constexpr size_t kPatchSize = kPatchWidth * kPatchWidth * 3;
|
||||
HWY_ASSERT(weights.vit_img_embedding_kernel.NumElements() ==
|
||||
kPatchSize * kModelDim);
|
||||
HWY_ASSERT(activations.x.Len() == kModelDim);
|
||||
// One separately allocated aligned buffer of kPatchSize floats per patch,
// passed to image.GetPatch(i, ...) - presumably filled with patch i; confirm
// against Image::GetPatch.
std::vector<hwy::AlignedFreeUniquePtr<float[]>> image_patches(kSeqLen);
|
||||
for (size_t i = 0; i < kSeqLen; ++i) {
|
||||
image_patches[i] = hwy::AllocateAligned<float>(kPatchSize);
|
||||
image.GetPatch(i, image_patches[i].get());
|
||||
}
|
||||
// img/embedding/kernel has original shape (14, 14, 3, 1152)
|
||||
// H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
|
||||
// image_patches is (256, 14 * 14 * 3)
|
||||
// This could be done as one MatMul like:
|
||||
// RowVectorBatch<float> image_patches(kSeqLen, kPatchSize);
|
||||
// [Get patches]
|
||||
// MatMul</*kAdd=*/true>(
|
||||
// kSeqLen, ConstMat(image_patches.All(), kPatchSize),
|
||||
// ConstMat(weights.vit_img_embedding_kernel.data_scale1(), kPatchSize),
|
||||
// /*scale=*/1.0f, weights.vit_img_embedding_bias.data_scale1(),
|
||||
// activations.env, MutableMat(activations.x.All(), kModelDim));
|
||||
// However, MatMul currently requires that
|
||||
// A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
|
||||
// which is not the case here. We should relax that requirement on MatMul and
|
||||
// then use the above. For now, we rely on MatVecAdd instead.
|
||||
// One MatVecAdd per patch: row i of activations.x receives the embedding of
// patch i plus vit_img_embedding_bias.
for (size_t i = 0; i < kSeqLen; ++i) {
|
||||
MatVecAdd<kModelDim, kPatchSize>(
|
||||
weights.vit_img_embedding_kernel, 0, image_patches[i].get(),
|
||||
weights.vit_img_embedding_bias.data_scale1(), activations.x.Batch(i),
|
||||
activations.env.Pools().Outer());
|
||||
}
|
||||
// Add position embeddings.
// NOTE(review): treats activations.x.All() as kSeqLen * kModelDim contiguous
// floats; only Len() == kModelDim is asserted in this function - confirm that
// the batch size is at least kSeqLen and rows are unpadded.
|
||||
AddFrom(weights.vit_img_pos_embedding.data_scale1(), activations.x.All(),
|
||||
kSeqLen * kModelDim);
|
||||
}
|
||||
|
||||
// Prefills the image tokens with the ViT encoder.
|
||||
// Steps: embed the image patches into activations.x, run every ViT layer,
// apply the final LayerNorm, then project with vit_img_head_kernel/bias into
// `image_tokens`. `activations` must have been allocated with a batch size of
// TConfig::VitConfig::kSeqLen (asserted below).
// NOTE(review): `runtime_config` is not referenced in this function body.
template <class TConfig>
|
||||
HWY_NOINLINE void PrefillVit(const CompressedWeights<TConfig>& weights,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const Image& image, ImageTokens& image_tokens,
|
||||
Activations& activations) {
|
||||
PROFILER_ZONE("Gen.PrefillVit");
|
||||
const size_t num_tokens = TConfig::VitConfig::kSeqLen;
|
||||
const size_t kVitModelDim = TConfig::VitConfig::kModelDim;
|
||||
HWY_ASSERT(num_tokens == activations.x.BatchSize());
|
||||
// Embed the image patches.
|
||||
EmbedImagePatches<TConfig>(image, weights, activations);
|
||||
// Go through all layers.
// The layers are instantiated with the ViT sub-config, not TConfig itself.
|
||||
for (size_t layer = 0; layer < TConfig::VitConfig::kLayers; ++layer) {
|
||||
const auto* layer_weights = weights.GetVitLayer(layer);
|
||||
VitTransformerLayer<typename TConfig::VitConfig>(
|
||||
num_tokens, layer, layer_weights, activations);
|
||||
}
|
||||
// Final Layernorm.
// Input and output are both activations.x, i.e. applied in place.
|
||||
LayerNormBatched(num_tokens, activations.x.All(),
|
||||
weights.vit_encoder_norm_scale.data_scale1(),
|
||||
weights.vit_encoder_norm_bias.data_scale1(),
|
||||
activations.x.All(), kVitModelDim);
|
||||
|
||||
// Apply head embedding into image_tokens of size of the LLM kModelDim.
// Input rows have kVitModelDim columns; output rows have TConfig::kModelDim.
|
||||
MatMul</*kAdd=*/true>(
|
||||
num_tokens, ConstMat(activations.x.All(), kVitModelDim),
|
||||
ConstMat(weights.vit_img_head_kernel.data_scale1(), kVitModelDim),
|
||||
/*scale=*/1.0f, weights.vit_img_head_bias.data_scale1(), activations.env,
|
||||
MutableMat(image_tokens.All(), TConfig::kModelDim));
|
||||
}
|
||||
|
||||
// Generates one token for each query. `queries_token` is the previous token
|
||||
// from each query, and `queries_pos` are their position in the sequence.
|
||||
template <class TConfig>
|
||||
HWY_NOINLINE void Transformer(
|
||||
const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end,
|
||||
const CompressedWeights<TConfig>& weights, Activations& activations,
|
||||
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
|
||||
const LayersOutputFunc& layers_output,
|
||||
|
|
@ -715,6 +1058,7 @@ HWY_NOINLINE void Transformer(
|
|||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
const size_t num_queries = queries_token.size();
|
||||
HWY_DASSERT(queries_pos.size() == num_queries);
|
||||
HWY_DASSERT(queries_prefix_end.size() == num_queries);
|
||||
|
||||
if (layers_output) {
|
||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||
|
|
@ -726,13 +1070,14 @@ HWY_NOINLINE void Transformer(
|
|||
|
||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||
EmbedToken<TConfig>(queries_token[query_idx], query_idx,
|
||||
queries_pos[query_idx], weights, activations.x);
|
||||
queries_pos[query_idx], /*pos_in_prompt=*/0, weights,
|
||||
activations.x, /*image_tokens=*/nullptr);
|
||||
}
|
||||
|
||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||
const CompressedLayer<TConfig>* layer_weights = weights.GetLayer(layer);
|
||||
TransformerLayer<TConfig>(queries_pos, /*num_tokens=*/1, layer,
|
||||
layer_weights, activations, div_seq_len,
|
||||
TransformerLayer<TConfig>(queries_pos, queries_prefix_end, /*num_tokens=*/1,
|
||||
layer, layer_weights, activations, div_seq_len,
|
||||
kv_caches);
|
||||
|
||||
if (activations_observer) {
|
||||
|
|
@ -834,8 +1179,10 @@ template <class TConfig>
|
|||
void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos_in, const size_t query_idx_start,
|
||||
const KVCaches& kv_caches, TimingInfo& timing_info) {
|
||||
const QueriesPos& queries_pos_in,
|
||||
const QueriesPos& queries_prefix_end,
|
||||
const size_t query_idx_start, const KVCaches& kv_caches,
|
||||
TimingInfo& timing_info) {
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||
const CompressedWeights<TConfig>& weights =
|
||||
|
|
@ -888,8 +1235,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|||
prefill_activations.Allocate<TConfig>(runtime_config.prefill_tbatch_size,
|
||||
activations.env.Pools());
|
||||
}
|
||||
Prefill<TConfig>(queries_prompt, queries_mutable_pos, query_idx_start,
|
||||
weights,
|
||||
Prefill<TConfig>(queries_prompt, queries_mutable_pos, queries_prefix_end,
|
||||
query_idx_start, weights,
|
||||
use_prefill_activations ? prefill_activations : activations,
|
||||
runtime_config, div_seq_len, kv_caches);
|
||||
// Compute the number of tokens that were prefilled and notify timing_info.
|
||||
|
|
@ -918,10 +1265,10 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|||
const double gen_start = hwy::platform::Now();
|
||||
for (size_t gen = 0; gen < HWY_MIN(max_tokens, max_generated_tokens); ++gen) {
|
||||
// Decode generates one token per query and increments queries_mutable_pos.
|
||||
Transformer<TConfig>(QueriesToken(gen_tokens.data(), num_queries),
|
||||
queries_mutable_pos, weights, activations, div_seq_len,
|
||||
kv_caches, runtime_config.layers_output,
|
||||
runtime_config.activations_observer);
|
||||
Transformer<TConfig>(
|
||||
QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos,
|
||||
queries_prefix_end, weights, activations, div_seq_len, kv_caches,
|
||||
runtime_config.layers_output, runtime_config.activations_observer);
|
||||
// queries_pos are incremented by Transformer.
|
||||
|
||||
bool all_queries_eos = true;
|
||||
|
|
@ -954,8 +1301,9 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|||
template <class TConfig>
|
||||
void GenerateSingleT(const ByteStorageT& weights_u8,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
|
||||
PerClusterPools& pools, TimingInfo& timing_info) {
|
||||
const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
KVCache& kv_cache, PerClusterPools& pools,
|
||||
TimingInfo& timing_info) {
|
||||
constexpr size_t kNumQueries = 1;
|
||||
const size_t qbatch_start = 0;
|
||||
|
||||
|
|
@ -963,20 +1311,24 @@ void GenerateSingleT(const ByteStorageT& weights_u8,
|
|||
Activations activations;
|
||||
activations.Allocate<TConfig>(kNumQueries, pools);
|
||||
|
||||
const QueriesPromptTokens prompt_span(&prompt, kNumQueries);
|
||||
QueriesPos pos_span(&pos, kNumQueries);
|
||||
const QueriesPromptTokens queries_prompt(&prompt, kNumQueries);
|
||||
QueriesPos queries_pos(&pos, kNumQueries);
|
||||
const QueriesPos queries_prefix_end(&prefix_end, kNumQueries);
|
||||
const KVCaches kv_caches{&kv_cache, kNumQueries};
|
||||
|
||||
GenerateT<TConfig>(weights_u8, activations, runtime_config, prompt_span,
|
||||
pos_span, qbatch_start, kv_caches, timing_info);
|
||||
GenerateT<TConfig>(weights_u8, activations, runtime_config, queries_prompt,
|
||||
queries_pos, queries_prefix_end, qbatch_start, kv_caches,
|
||||
timing_info);
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
void GenerateBatchT(const ByteStorageT& weights_u8,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos, const KVCaches& kv_caches,
|
||||
PerClusterPools& pools, TimingInfo& timing_info) {
|
||||
const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end,
|
||||
const KVCaches& kv_caches, PerClusterPools& pools,
|
||||
TimingInfo& timing_info) {
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_ASSERT(queries_pos.size() == num_queries);
|
||||
HWY_ASSERT(kv_caches.size() == num_queries);
|
||||
|
|
@ -995,9 +1347,33 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
|
|||
const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start],
|
||||
qbatch_size);
|
||||
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
|
||||
const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
|
||||
qbatch_size);
|
||||
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
|
||||
GenerateT<TConfig>(weights_u8, activations, runtime_config, qbatch_prompts,
|
||||
qbatch_pos, qbatch_start, qbatch_kv, timing_info);
|
||||
qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv,
|
||||
timing_info);
|
||||
}
|
||||
}
|
||||
|
||||
// Runs the ViT encoder on `image` and writes the result to `image_tokens`.
// Does nothing when TConfig::VitConfig::kLayers == 0; in that case the ViT
// branch is not instantiated thanks to `if constexpr`.
// Otherwise allocates a fresh set of activations sized for the ViT config with
// a batch of VitConfig::kSeqLen tokens and calls PrefillVit.
template <class TConfig>
|
||||
void GenerateImageTokensT(const ByteStorageT& weights_u8,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const Image& image, ImageTokens& image_tokens,
|
||||
PerClusterPools& pools) {
|
||||
if constexpr (TConfig::VitConfig::kLayers == 0) {
|
||||
return;
|
||||
} else {
|
||||
Activations prefill_activations;
|
||||
// Work on a copy so the caller's runtime_config is left unchanged; only
// prefill_tbatch_size is overridden, to cover all image tokens at once.
RuntimeConfig prefill_runtime_config = runtime_config;
|
||||
prefill_runtime_config.prefill_tbatch_size = TConfig::VitConfig::kSeqLen;
|
||||
prefill_activations.Allocate<typename TConfig::VitConfig>(
|
||||
prefill_runtime_config.prefill_tbatch_size, pools);
|
||||
// Weights are for the full PaliGemma model, not just the ViT part.
// The raw byte storage is reinterpreted as CompressedWeights<TConfig>.
|
||||
const CompressedWeights<TConfig>& weights =
|
||||
*reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
|
||||
PrefillVit<TConfig>(weights, prefill_runtime_config, image, image_tokens,
|
||||
prefill_activations);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1010,20 +1386,30 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
|
|||
void GenerateSingle( // NOLINT(misc-definitions-in-headers)
|
||||
GEMMA_CONFIG, const ByteStorageT& weights_u8,
|
||||
const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
|
||||
KVCache& kv_cache, PerClusterPools& pools, TimingInfo& timing_info) {
|
||||
size_t prefix_end, KVCache& kv_cache, PerClusterPools& pools,
|
||||
TimingInfo& timing_info) {
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_CONFIG>)
|
||||
(weights_u8, runtime_config, prompt, pos, kv_cache, pools, timing_info);
|
||||
(weights_u8, runtime_config, prompt, pos, prefix_end, kv_cache, pools,
|
||||
timing_info);
|
||||
}
|
||||
|
||||
void GenerateBatch( // NOLINT(misc-definitions-in-headers)
|
||||
GEMMA_CONFIG, const ByteStorageT& weights_u8,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos,
|
||||
const KVCaches& kv_caches, PerClusterPools& pools,
|
||||
TimingInfo& timing_info) {
|
||||
const QueriesPos& queries_prefix_end, const KVCaches& kv_caches,
|
||||
PerClusterPools& pools, TimingInfo& timing_info) {
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>)
|
||||
(weights_u8, runtime_config, queries_prompt, queries_pos, kv_caches, pools,
|
||||
timing_info);
|
||||
(weights_u8, runtime_config, queries_prompt, queries_pos, queries_prefix_end,
|
||||
kv_caches, pools, timing_info);
|
||||
}
|
||||
|
||||
void GenerateImageTokens( // NOLINT(misc-definitions-in-headers)
|
||||
GEMMA_CONFIG, const ByteStorageT& weights_u8,
|
||||
const RuntimeConfig& runtime_config, const Image& image,
|
||||
ImageTokens& image_tokens, PerClusterPools& pools) {
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImageTokensT<GEMMA_CONFIG>)
|
||||
(weights_u8, runtime_config, image, image_tokens, pools);
|
||||
}
|
||||
|
||||
#endif // HWY_ONCE
|
||||
|
|
|
|||
|
|
@ -24,10 +24,12 @@
|
|||
#include <string.h>
|
||||
|
||||
#include <utility> // std::move
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "paligemma/image.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/highway.h"
|
||||
|
|
@ -62,13 +64,18 @@ Gemma::~Gemma() {
|
|||
extern void GenerateSingle(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
|
||||
const RuntimeConfig& runtime_config, \
|
||||
const PromptTokens& prompt, size_t pos, \
|
||||
KVCache& kv_cache, PerClusterPools& pools, \
|
||||
TimingInfo& timing_info); \
|
||||
size_t prefix_end, KVCache& kv_cache, \
|
||||
PerClusterPools& pools, TimingInfo& timing_info); \
|
||||
extern void GenerateBatch( \
|
||||
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
|
||||
const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \
|
||||
const QueriesPos& queries_pos, const KVCaches& kv_caches, \
|
||||
PerClusterPools& pools, TimingInfo& timing_info);
|
||||
const QueriesPos& queries_pos, \
|
||||
const QueriesPos& queries_prefix_end, const KVCaches& kv_caches, \
|
||||
PerClusterPools& pools, TimingInfo& timing_info); \
|
||||
extern void GenerateImageTokens( \
|
||||
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
|
||||
const RuntimeConfig& runtime_config, const Image& image, \
|
||||
ImageTokens& image_tokens, PerClusterPools& pools);
|
||||
GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE);
|
||||
|
||||
// Adapters to select from the above overloads via CallForModelAndWeight.
|
||||
|
|
@ -76,10 +83,11 @@ template <class TConfig>
|
|||
struct GenerateSingleT {
|
||||
void operator()(const ByteStorageT& weights_u8,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
|
||||
PerClusterPools& pools, TimingInfo& timing_info) const {
|
||||
GenerateSingle(TConfig(), weights_u8, runtime_config, prompt, pos, kv_cache,
|
||||
pools, timing_info);
|
||||
const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
KVCache& kv_cache, PerClusterPools& pools,
|
||||
TimingInfo& timing_info) const {
|
||||
GenerateSingle(TConfig(), weights_u8, runtime_config, prompt, pos,
|
||||
prefix_end, kv_cache, pools, timing_info);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -88,21 +96,34 @@ struct GenerateBatchT {
|
|||
void operator()(const ByteStorageT& weights_u8,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos, const KVCaches& kv_caches,
|
||||
PerClusterPools& pools, TimingInfo& timing_info) const {
|
||||
const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end,
|
||||
const KVCaches& kv_caches, PerClusterPools& pools,
|
||||
TimingInfo& timing_info) const {
|
||||
GenerateBatch(TConfig(), weights_u8, runtime_config, queries_prompt,
|
||||
queries_pos, kv_caches, pools, timing_info);
|
||||
queries_pos, queries_prefix_end, kv_caches, pools,
|
||||
timing_info);
|
||||
}
|
||||
};
|
||||
|
||||
template <class TConfig>
|
||||
struct GenerateImageTokensT {
|
||||
void operator()(const ByteStorageT& weights_u8,
|
||||
const RuntimeConfig& runtime_config, const Image& image,
|
||||
ImageTokens& image_tokens, PerClusterPools& pools) const {
|
||||
GenerateImageTokens(TConfig(), weights_u8, runtime_config, image,
|
||||
image_tokens, pools);
|
||||
}
|
||||
};
|
||||
|
||||
void Gemma::Generate(const RuntimeConfig& runtime_config,
|
||||
const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
|
||||
TimingInfo& timing_info) {
|
||||
const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
KVCache& kv_cache, TimingInfo& timing_info) {
|
||||
if (runtime_config.use_spinning) pools_.StartSpinning();
|
||||
|
||||
CallForModelAndWeight<GenerateSingleT>(info_.model, info_.weight, weights_u8_,
|
||||
runtime_config, prompt, pos, kv_cache,
|
||||
pools_, timing_info);
|
||||
CallForModelAndWeight<GenerateSingleT>(
|
||||
info_.model, info_.weight, weights_u8_, runtime_config, prompt, pos,
|
||||
prefix_end, kv_cache, pools_, timing_info);
|
||||
|
||||
if (runtime_config.use_spinning) pools_.StopSpinning();
|
||||
}
|
||||
|
|
@ -110,12 +131,33 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
|
|||
void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end,
|
||||
const KVCaches& kv_caches, TimingInfo& timing_info) {
|
||||
// If we did not get passed prefix ends (size 0), assume 0 and pass that on.
|
||||
QueriesPos mutable_queries_prefix_end = queries_prefix_end;
|
||||
std::vector<size_t> prefix_end_vec;
|
||||
if (queries_prefix_end.size() == 0) {
|
||||
prefix_end_vec.resize(queries_prompt.size(), 0);
|
||||
mutable_queries_prefix_end =
|
||||
QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
|
||||
}
|
||||
|
||||
if (runtime_config.use_spinning) pools_.StartSpinning();
|
||||
|
||||
CallForModelAndWeight<GenerateBatchT>(
|
||||
info_.model, info_.weight, weights_u8_, runtime_config, queries_prompt,
|
||||
queries_pos, kv_caches, pools_, timing_info);
|
||||
queries_pos, mutable_queries_prefix_end, kv_caches, pools_, timing_info);
|
||||
|
||||
if (runtime_config.use_spinning) pools_.StopSpinning();
|
||||
}
|
||||
|
||||
void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
|
||||
const Image& image, ImageTokens& image_tokens) {
|
||||
if (runtime_config.use_spinning) pools_.StartSpinning();
|
||||
|
||||
CallForModelAndWeight<GenerateImageTokensT>(info_.model, info_.weight,
|
||||
weights_u8_, runtime_config,
|
||||
image, image_tokens, pools_);
|
||||
|
||||
if (runtime_config.use_spinning) pools_.StopSpinning();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@
|
|||
#include "gemma/common.h"
|
||||
#include "gemma/kv_cache.h"
|
||||
#include "gemma/tokenizer.h"
|
||||
#include "paligemma/image.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
|
@ -75,6 +76,10 @@ using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
|
|||
using ActivationsObserverFunc =
|
||||
std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;
|
||||
|
||||
// ImageTokens are represented as a RowVectorBatch, where each "batch" index
|
||||
// corresponds to a token for an image patch as computed by the image encoder.
|
||||
using ImageTokens = RowVectorBatch<float>;
|
||||
|
||||
// RuntimeConfig holds configuration for a single generation run.
|
||||
struct RuntimeConfig {
|
||||
// If not empty, batch_stream_token is called for each token in the batch,
|
||||
|
|
@ -110,6 +115,10 @@ struct RuntimeConfig {
|
|||
LayersOutputFunc layers_output; // if not empty, called after each layer.
|
||||
ActivationsObserverFunc activations_observer; // if set, called per-layer.
|
||||
|
||||
// If not null, this points to the image tokens, which are used in the
|
||||
// PaliGemma prefix-LM style attention.
|
||||
const ImageTokens *image_tokens = nullptr;
|
||||
|
||||
// Whether to use thread spinning to reduce barrier synchronization latency.
|
||||
bool use_spinning = true;
|
||||
|
||||
|
|
@ -198,14 +207,34 @@ class Gemma {
|
|||
// `pos` is the position in the KV cache. Users are responsible for
|
||||
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
|
||||
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
|
||||
size_t pos, KVCache& kv_cache, TimingInfo& timing_info);
|
||||
size_t pos, KVCache& kv_cache, TimingInfo& timing_info) {
|
||||
Generate(runtime_config, prompt, pos, /*prefix_end=*/0, kv_cache,
|
||||
timing_info);
|
||||
}
|
||||
// For prefix-LM style attention, we can pass the end of the prefix.
|
||||
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
|
||||
size_t pos, size_t prefix_end, KVCache& kv_cache,
|
||||
TimingInfo& timing_info);
|
||||
|
||||
// `queries_pos` are the positions in the KV cache. Users are responsible for
|
||||
// incrementing them in `BatchStreamFunc`, or setting to zero for single-turn.
|
||||
void GenerateBatch(const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos, const KVCaches& kv_caches,
|
||||
TimingInfo& timing_info);
|
||||
TimingInfo& timing_info) {
|
||||
GenerateBatch(runtime_config, queries_prompt, queries_pos,
|
||||
/*queries_prefix_end=*/{}, kv_caches, timing_info);
|
||||
}
|
||||
// For prefix-LM style attention, we can pass the ends of the prefixes.
|
||||
void GenerateBatch(const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end,
|
||||
const KVCaches& kv_caches, TimingInfo& timing_info);
|
||||
|
||||
// Generates the image tokens by running the image encoder ViT.
|
||||
void GenerateImageTokens(const RuntimeConfig& runtime_config,
|
||||
const Image& image, ImageTokens& image_tokens);
|
||||
|
||||
private:
|
||||
PerClusterPools& pools_;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE \
|
||||
"gemma/instantiations/paligemma_224_bf16.cc"
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#define GEMMA_CONFIG ConfigPaliGemma_224<hwy::bfloat16_t>
|
||||
#include "gemma/gemma-inl.h"
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE \
|
||||
"gemma/instantiations/paligemma_224_f32.cc"
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#define GEMMA_CONFIG ConfigPaliGemma_224<float>
|
||||
#include "gemma/gemma-inl.h"
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE \
|
||||
"gemma/instantiations/paligemma_224_sfp.cc"
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#define GEMMA_CONFIG ConfigPaliGemma_224<SfpStream>
|
||||
#include "gemma/gemma-inl.h"
|
||||
33
gemma/run.cc
33
gemma/run.cc
|
|
@ -25,6 +25,7 @@
|
|||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h" // Gemma
|
||||
#include "paligemma/image.h"
|
||||
#include "util/app.h"
|
||||
#include "util/args.h" // HasHelp
|
||||
#include "util/threading.h"
|
||||
|
|
@ -87,6 +88,17 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
std::mt19937 gen;
|
||||
InitGenerator(args, gen);
|
||||
|
||||
const bool have_image = !args.image_file.path.empty();
|
||||
Image image;
|
||||
ImageTokens image_tokens(256, 2048);
|
||||
if (have_image) {
|
||||
HWY_ASSERT(model.Info().model == Model::PALIGEMMA_224);
|
||||
HWY_ASSERT(image.ReadPPM(args.image_file.path));
|
||||
image.Resize();
|
||||
RuntimeConfig runtime_config = {.verbosity = verbosity, .gen = &gen};
|
||||
model.GenerateImageTokens(runtime_config, image, image_tokens);
|
||||
}
|
||||
|
||||
// callback function invoked for each generated token.
|
||||
auto stream_token = [&](int token, float) {
|
||||
++abs_pos;
|
||||
|
|
@ -132,7 +144,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
}
|
||||
}
|
||||
|
||||
const std::vector<int> prompt = WrapAndTokenize(
|
||||
if (have_image && abs_pos != 0) {
|
||||
// This occurs when we have hit max_generated.
|
||||
abs_pos = 0;
|
||||
}
|
||||
|
||||
std::vector<int> prompt = WrapAndTokenize(
|
||||
model.Tokenizer(), model.Info(), abs_pos, prompt_string);
|
||||
prompt_size = prompt.size();
|
||||
std::cerr << "\n"
|
||||
|
|
@ -151,7 +168,19 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
|
|||
.accept_token = accept_token,
|
||||
};
|
||||
args.CopyTo(runtime_config);
|
||||
model.Generate(runtime_config, prompt, abs_pos, kv_cache, timing_info);
|
||||
size_t prefix_end = 0;
|
||||
if (have_image) {
|
||||
runtime_config.image_tokens = &image_tokens;
|
||||
prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
|
||||
prompt_size = prompt.size();
|
||||
// The end of the prefix for prefix-LM style attention in Paligemma.
|
||||
// See Figure 2 of https://arxiv.org/abs/2407.07726.
|
||||
prefix_end = prompt_size;
|
||||
// We need to look at all the tokens for the prefix.
|
||||
runtime_config.prefill_tbatch_size = prompt_size;
|
||||
}
|
||||
model.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache,
|
||||
timing_info);
|
||||
std::cout << "\n\n";
|
||||
}
|
||||
std::cout
|
||||
|
|
|
|||
|
|
@ -107,6 +107,14 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
|
|||
if (pos == 0) {
|
||||
tokens.insert(tokens.begin(), BOS_ID);
|
||||
}
|
||||
|
||||
// PaliGemma separator. The SEP token "\n" is always tokenized separately.
|
||||
if (info.model == Model::PALIGEMMA_224) {
|
||||
std::vector<int> sep_tokens;
|
||||
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
|
||||
tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
|
|
|
|||
106
gemma/weights.h
106
gemma/weights.h
|
|
@ -48,6 +48,7 @@ struct CompressedLayer {
|
|||
static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim;
|
||||
static constexpr size_t kQKVEinsumWSize =
|
||||
(kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
|
||||
static constexpr size_t kQKVEinsumBSize = (kHeads + 2 * kKVHeads) * kQKVDim;
|
||||
// 2x for (gelu gating vector, gated vector)
|
||||
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
|
||||
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
|
||||
|
|
@ -81,6 +82,24 @@ struct CompressedLayer {
|
|||
ArrayT<float, kGriffinDim * 2> gate_biases;
|
||||
ArrayT<float, kGriffinDim> a;
|
||||
} griffin;
|
||||
|
||||
struct {
|
||||
// MultiHeadDotProductAttention.
|
||||
ArrayT<WeightF32OrBF16, kAttVecEinsumWSize> attn_out_w;
|
||||
ArrayT<float, kModelDim> attn_out_b;
|
||||
ArrayT<WeightF32OrBF16, kQKVEinsumWSize> qkv_einsum_w;
|
||||
ArrayT<float, kQKVEinsumBSize> qkv_einsum_b;
|
||||
// MlpBlock.
|
||||
ArrayT<WeightF32OrBF16, kModelDim * kFFHiddenDim> linear_0_w;
|
||||
ArrayT<float, kFFHiddenDim> linear_0_b;
|
||||
ArrayT<WeightF32OrBF16, kFFHiddenDim * kModelDim> linear_1_w;
|
||||
ArrayT<float, kModelDim> linear_1_b;
|
||||
// LayerNorm.
|
||||
ArrayT<WeightF32OrBF16, kModelDim> layer_norm_0_bias;
|
||||
ArrayT<WeightF32OrBF16, kModelDim> layer_norm_0_scale;
|
||||
ArrayT<WeightF32OrBF16, kModelDim> layer_norm_1_bias;
|
||||
ArrayT<WeightF32OrBF16, kModelDim> layer_norm_1_scale;
|
||||
} vit;
|
||||
};
|
||||
|
||||
ArrayT<Weight, kGatingEinsumWSize> gating_einsum_w;
|
||||
|
|
@ -121,6 +140,7 @@ struct CompressedLayer {
|
|||
out_row + h * kQKVDim, kQKVDim * sizeof(Weight));
|
||||
}
|
||||
}
|
||||
att_weights.set_scale(attn_vec_einsum_w.scale());
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -133,10 +153,21 @@ struct CompressedLayerPointers {
|
|||
pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) {
|
||||
this->c_layers[task] = hwy::AllocateAligned<CompressedLayer<TConfig>>(1);
|
||||
});
|
||||
if constexpr (TConfig::VitConfig::kLayers > 0) {
|
||||
pool.Run(0, TConfig::VitConfig::kLayers,
|
||||
[this](uint64_t task, size_t /*thread*/) {
|
||||
this->c_vit_layers[task] = hwy::AllocateAligned<
|
||||
CompressedLayer<typename TConfig::VitConfig>>(1);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
using CLayer = CompressedLayer<TConfig>;
|
||||
std::array<hwy::AlignedFreeUniquePtr<CLayer[]>, TConfig::kLayers> c_layers;
|
||||
using CVitLayer = CompressedLayer<typename TConfig::VitConfig>;
|
||||
std::array<hwy::AlignedFreeUniquePtr<CVitLayer[]>,
|
||||
TConfig::VitConfig::kLayers>
|
||||
c_vit_layers;
|
||||
};
|
||||
|
||||
template <class TConfig, typename = void>
|
||||
|
|
@ -159,6 +190,23 @@ struct CompressedWeights {
|
|||
hwy::If<hwy::IsSame<Weight, float>(), float, hwy::bfloat16_t>;
|
||||
CompressedArray<WeightF32OrBF16, TConfig::kModelDim> final_norm_scale;
|
||||
|
||||
// Vit parts.
|
||||
CompressedArray<WeightF32OrBF16, TConfig::VitConfig::kModelDim>
|
||||
vit_encoder_norm_bias;
|
||||
CompressedArray<WeightF32OrBF16, TConfig::VitConfig::kModelDim>
|
||||
vit_encoder_norm_scale;
|
||||
CompressedArray<float, TConfig::VitConfig::kModelDim> vit_img_embedding_bias;
|
||||
CompressedArray<WeightF32OrBF16, TConfig::VitConfig::kModelDim * 14 * 14 * 3>
|
||||
vit_img_embedding_kernel;
|
||||
CompressedArray<float, 256 * TConfig::VitConfig::kModelDim>
|
||||
vit_img_pos_embedding;
|
||||
// The head maps from VitConfig::kModelDim (Vit final layer) to
|
||||
// kModelDim (LLM input).
|
||||
CompressedArray<float, TConfig::kModelDim> vit_img_head_bias;
|
||||
CompressedArray<WeightF32OrBF16,
|
||||
TConfig::VitConfig::kModelDim * TConfig::kModelDim>
|
||||
vit_img_head_kernel;
|
||||
|
||||
// Must be last so that the other arrays remain aligned.
|
||||
CompressedLayerPointers<TConfig> c_layer_ptrs;
|
||||
|
||||
|
|
@ -174,9 +222,21 @@ struct CompressedWeights {
|
|||
void ZeroInit() {
|
||||
hwy::ZeroBytes(&embedder_input_embedding, sizeof(embedder_input_embedding));
|
||||
hwy::ZeroBytes(&final_norm_scale, sizeof(final_norm_scale));
|
||||
hwy::ZeroBytes(&vit_encoder_norm_bias, sizeof(vit_encoder_norm_bias));
|
||||
hwy::ZeroBytes(&vit_encoder_norm_scale, sizeof(vit_encoder_norm_scale));
|
||||
hwy::ZeroBytes(&vit_img_embedding_bias, sizeof(vit_img_embedding_bias));
|
||||
hwy::ZeroBytes(&vit_img_embedding_kernel, sizeof(vit_img_embedding_kernel));
|
||||
hwy::ZeroBytes(&vit_img_head_bias, sizeof(vit_img_head_bias));
|
||||
hwy::ZeroBytes(&vit_img_head_kernel, sizeof(vit_img_head_kernel));
|
||||
hwy::ZeroBytes(&vit_img_pos_embedding, sizeof(vit_img_pos_embedding));
|
||||
for (int i = 0; i < TConfig::kLayers; ++i) {
|
||||
hwy::ZeroBytes(GetLayer(i), sizeof(*GetLayer(i)));
|
||||
}
|
||||
if constexpr (TConfig::VitConfig::kLayers > 0) {
|
||||
for (int i = 0; i < TConfig::VitConfig::kLayers; ++i) {
|
||||
hwy::ZeroBytes(GetVitLayer(i), sizeof(*GetVitLayer(i)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const CompressedLayer<TConfig>* GetLayer(size_t layer) const {
|
||||
|
|
@ -185,6 +245,13 @@ struct CompressedWeights {
|
|||
CompressedLayer<TConfig>* GetLayer(size_t layer) {
|
||||
return c_layer_ptrs.c_layers[layer].get();
|
||||
}
|
||||
const CompressedLayer<typename TConfig::VitConfig>* GetVitLayer(
|
||||
size_t layer) const {
|
||||
return c_layer_ptrs.c_vit_layers[layer].get();
|
||||
}
|
||||
CompressedLayer<typename TConfig::VitConfig>* GetVitLayer(size_t layer) {
|
||||
return c_layer_ptrs.c_vit_layers[layer].get();
|
||||
}
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
|
@ -288,6 +355,16 @@ void ForEachTensor(RawWeightsPtr raw_weights,
|
|||
GEMMA_CALL_TOP_FUNC("c_embedding", embedder_input_embedding);
|
||||
GEMMA_CALL_TOP_FUNC("c_final_norm", final_norm_scale);
|
||||
|
||||
if constexpr (TConfig::VitConfig::kLayers > 0 && !kHaveRaw) {
|
||||
GEMMA_CALL_TOP_FUNC("enc_norm_bias", vit_encoder_norm_bias);
|
||||
GEMMA_CALL_TOP_FUNC("enc_norm_scale", vit_encoder_norm_scale);
|
||||
GEMMA_CALL_TOP_FUNC("img_emb_bias", vit_img_embedding_bias);
|
||||
GEMMA_CALL_TOP_FUNC("img_emb_kernel", vit_img_embedding_kernel);
|
||||
GEMMA_CALL_TOP_FUNC("img_head_bias", vit_img_head_bias);
|
||||
GEMMA_CALL_TOP_FUNC("img_head_kernel", vit_img_head_kernel);
|
||||
GEMMA_CALL_TOP_FUNC("img_pos_emb", vit_img_pos_embedding);
|
||||
}
|
||||
|
||||
char name_buf[16];
|
||||
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
|
||||
auto type = TConfig::kLayerConfig[layer_idx];
|
||||
|
|
@ -334,6 +411,35 @@ void ForEachTensor(RawWeightsPtr raw_weights,
|
|||
GEMMA_CALL_FUNC("attn_ob", attention_output_biases);
|
||||
}
|
||||
}
|
||||
|
||||
// Vit layers. Not supported for compress_weights.
|
||||
if constexpr (TConfig::VitConfig::kLayers > 0 && !kHaveRaw) {
|
||||
for (int layer_idx = 0; layer_idx < TConfig::VitConfig::kLayers;
|
||||
++layer_idx) {
|
||||
auto type = TConfig::VitConfig::kLayerConfig[layer_idx];
|
||||
HWY_ASSERT(type == LayerAttentionType::kVit);
|
||||
const size_t idx = static_cast<size_t>(layer_idx);
|
||||
const RawLayer* raw_layer = nullptr;
|
||||
CompressedLayer<typename TConfig::VitConfig>* c_layer =
|
||||
c_weights.GetVitLayer(idx);
|
||||
|
||||
// MHA.
|
||||
GEMMA_CALL_FUNC("attn_out_w", vit.attn_out_w);
|
||||
GEMMA_CALL_FUNC("attn_out_b", vit.attn_out_b);
|
||||
GEMMA_CALL_FUNC("qkv_ein_w", vit.qkv_einsum_w);
|
||||
GEMMA_CALL_FUNC("qkv_ein_b", vit.qkv_einsum_b);
|
||||
// MlpBlock.
|
||||
GEMMA_CALL_FUNC("linear_0_w", vit.linear_0_w);
|
||||
GEMMA_CALL_FUNC("linear_0_b", vit.linear_0_b);
|
||||
GEMMA_CALL_FUNC("linear_1_w", vit.linear_1_w);
|
||||
GEMMA_CALL_FUNC("linear_1_b", vit.linear_1_b);
|
||||
// LayerNorm.
|
||||
GEMMA_CALL_FUNC("ln_0_bias", vit.layer_norm_0_bias);
|
||||
GEMMA_CALL_FUNC("ln_0_scale", vit.layer_norm_0_scale);
|
||||
GEMMA_CALL_FUNC("ln_1_bias", vit.layer_norm_1_bias);
|
||||
GEMMA_CALL_FUNC("ln_1_scale", vit.layer_norm_1_scale);
|
||||
}
|
||||
}
|
||||
#undef GEMMA_CALL_FUNC
|
||||
#undef GEMMA_CALL_TOP_FUNC
|
||||
} // ForEachTensor
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
package(
|
||||
default_applicable_licenses = [
|
||||
"//:license", # Placeholder comment, do not modify
|
||||
],
|
||||
default_visibility = [
|
||||
"//:__subpackages__", # Placeholder, do not modify
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "image",
|
||||
srcs = ["image.cc"],
|
||||
hdrs = ["image.h"],
|
||||
deps = ["@hwy//:hwy"],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "image_test",
|
||||
srcs = ["image_test.cc"],
|
||||
data = ["testdata/image.ppm"],
|
||||
deps = [
|
||||
":image",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "paligemma_test",
|
||||
srcs = ["paligemma_test.cc"],
|
||||
# Requires model files
|
||||
tags = [
|
||||
"local",
|
||||
"manual",
|
||||
"no_tap",
|
||||
],
|
||||
deps = [
|
||||
"@googletest//:gtest_main",
|
||||
"//:benchmark_helper",
|
||||
"//:common",
|
||||
"//:gemma_lib",
|
||||
"//:tokenizer",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:hwy_test_util",
|
||||
],
|
||||
)
|
||||
|
|
@ -0,0 +1,182 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paligemma/image.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "hwy/base.h"
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
// Hardcoded for PaliGemma-224 ViT input.
|
||||
constexpr size_t kPatchSize = 14;
|
||||
constexpr size_t kImageSize = 224;
|
||||
constexpr size_t kNumPatches = kImageSize / kPatchSize; // 16
|
||||
|
||||
// Maps `value` in [0, from_size) linearly onto [0, to_size) and returns the
// nearest integer index, clamped to [0, to_size - 1].
// Returns 0 when either range has at most one element: this avoids a division
// by zero for from_size == 1 and an invalid clamp range for to_size <= 0.
int NearestNeighbor(int value, int from_size, int to_size) {
  if (from_size <= 1 || to_size <= 1) return 0;
  const float scale_factor = static_cast<float>(to_size - 1) / (from_size - 1);
  // Apply nearest neighbor rounding (halfway cases round away from zero).
  const int nearest = static_cast<int>(std::round(value * scale_factor));
  // Pin out-of-range inputs to the valid index range.
  return std::clamp(nearest, 0, to_size - 1);
}
|
||||
|
||||
// Linearly remaps `value` from the unit interval [0, 1] to [-1, 1].
float StretchToSigned(float value) {
  // General form:
  //   out_min + (value - in_min) * (out_max - out_min) / (in_max - in_min)
  // specialized to in = [0, 1] and out = [-1, 1].
  const float doubled = value * 2.0f;
  return doubled - 1.0f;
}
|
||||
|
||||
// Returns true if `c` is a carriage return or a line feed.
bool IsLineBreak(int c) { return c == '\r' || c == '\n'; }

// Advances `file` past any run of whitespace and '#' comment lines, leaving
// the stream positioned at the first byte that is neither. If the end of the
// stream is reached first, returns with the stream in its EOF/fail state.
void SkipWhitespaceAndComments(std::ifstream& file) {
  int value = file.get();
  while (value != EOF) {
    if (std::isspace(value)) {
      value = file.get();
    } else if (value == '#') {
      // A comment runs to the end of its line. Also stop at EOF so that an
      // unterminated trailing comment cannot loop forever.
      while (value != EOF && !IsLineBreak(value)) value = file.get();
    } else {
      break;
    }
  }
  // Rewind the one non-skippable byte that was consumed; there is nothing to
  // rewind once the end of the stream has been hit.
  if (value != EOF) file.unget();
}
|
||||
} // namespace
|
||||
|
||||
bool Image::ReadPPM(const std::string& filename) {
|
||||
std::ifstream file(filename);
|
||||
if (!file.is_open()) {
|
||||
std::cerr << "Failed to open " << filename << "\n";
|
||||
return false;
|
||||
}
|
||||
std::string format;
|
||||
file >> format;
|
||||
if (format != "P6") {
|
||||
std::cerr << "We only support binary PPM (P6) but got: " << format << "\n";
|
||||
return false;
|
||||
}
|
||||
int width, height, max_value;
|
||||
SkipWhitespaceAndComments(file);
|
||||
file >> width;
|
||||
SkipWhitespaceAndComments(file);
|
||||
file >> height;
|
||||
SkipWhitespaceAndComments(file);
|
||||
file >> max_value;
|
||||
if (max_value <= 0 || max_value > 255) {
|
||||
std::cerr << "Unsupported max value " << max_value << "\n";
|
||||
return false;
|
||||
}
|
||||
// P6 requires exactly one whitespace character after the header.
|
||||
int value = file.get();
|
||||
if (!std::isspace(value)) {
|
||||
std::cerr << "Missing whitespace after header\n";
|
||||
return false;
|
||||
}
|
||||
width_ = width;
|
||||
height_ = height;
|
||||
int data_size = width * height * 3;
|
||||
data_.resize(data_size);
|
||||
std::vector<char> data_bytes(data_size);
|
||||
file.read(data_bytes.data(), data_size);
|
||||
if (file.gcount() != data_size) {
|
||||
std::cerr << "Failed to read " << data_size << " bytes\n";
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < data_size; ++i) {
|
||||
data_[i] = StretchToSigned(static_cast<float>(data_bytes[i]) / max_value);
|
||||
}
|
||||
if (file.get() != EOF) {
|
||||
std::cerr << "Extra data in file\n";
|
||||
return false;
|
||||
}
|
||||
file.close();
|
||||
return true;
|
||||
}
|
||||
|
||||
void Image::Resize() {
|
||||
int new_width = 224;
|
||||
int new_height = kImageSize;
|
||||
std::vector<float> new_data(new_width * new_height * 3);
|
||||
// TODO: go to bilinear interpolation, or antialias.
|
||||
// E.g. consider WeightsSymmetric3Lowpass and SlowSymmetric3 from
|
||||
// jpegxl/lib/jxl/convolve_slow.cc
|
||||
// For now, just do nearest neighbor.
|
||||
for (int i = 0; i < new_height; ++i) {
|
||||
for (int j = 0; j < new_width; ++j) {
|
||||
int old_i = NearestNeighbor(i, new_height, height_);
|
||||
int old_j = NearestNeighbor(j, new_width, width_);
|
||||
for (int k = 0; k < 3; ++k) {
|
||||
new_data[(i * new_width + j) * 3 + k] =
|
||||
data_[(old_i * width_ + old_j) * 3 + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
data_ = std::move(new_data);
|
||||
height_ = new_height;
|
||||
width_ = new_width;
|
||||
}
|
||||
|
||||
// Writes all pixel values to `filename` as raw float32 values in host byte
// order, with no header. Returns false (after printing a diagnostic to
// stderr) if the file cannot be opened or the data cannot be written.
bool Image::WriteBinary(const std::string& filename) const {
  // Binary mode: the float bytes must not go through newline translation.
  std::ofstream file(filename, std::ios::binary);
  if (!file.is_open()) {
    std::cerr << "Failed to open " << filename << "\n";
    return false;
  }
  file.write(reinterpret_cast<const char*>(data_.data()),
             static_cast<std::streamsize>(data_.size() * sizeof(float)));
  // Close explicitly so that errors surfacing on flush are also detected.
  file.close();
  if (file.fail()) {
    std::cerr << "Failed to write " << filename << "\n";
    return false;
  }
  return true;
}
|
||||
|
||||
// Image.data() is kImageSize x kImageSize x 3, H x W x C.
|
||||
// We want the N-th patch (of 256) of size kPatchSize x kPatchSize x 3.
|
||||
// Patches are numbered in usual "pixel-order".
|
||||
void Image::GetPatch(size_t patch_num, float* patch) const {
|
||||
constexpr size_t kDataSize = kImageSize * kImageSize * 3;
|
||||
HWY_ASSERT(size() == kDataSize);
|
||||
constexpr size_t kPatchDataSize = kPatchSize * kPatchSize * 3;
|
||||
int i_offs = patch_num / kNumPatches;
|
||||
int j_offs = patch_num % kNumPatches;
|
||||
HWY_ASSERT(0 <= i_offs && i_offs < kNumPatches);
|
||||
HWY_ASSERT(0 <= j_offs && j_offs < kNumPatches);
|
||||
i_offs *= kPatchSize;
|
||||
j_offs *= kPatchSize;
|
||||
// This can be made faster, but let's first see whether it matters.
|
||||
const float* image_data = data();
|
||||
for (int i = 0; i < kPatchSize; ++i) {
|
||||
for (int j = 0; j < kPatchSize; ++j) {
|
||||
for (int k = 0; k < 3; ++k) {
|
||||
const int patch_index = (i * kPatchSize + j) * 3 + k;
|
||||
HWY_ASSERT(patch_index < kPatchDataSize);
|
||||
const int image_index =
|
||||
((i + i_offs) * kImageSize + (j + j_offs)) * 3 + k;
|
||||
HWY_ASSERT(image_index < kDataSize);
|
||||
patch[patch_index] = image_data[image_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace gcpp
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_PALIGEMMA_IMAGE_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_PALIGEMMA_IMAGE_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Very basic image loading and processing for PaliGemma-224. Does not try to be
// generic at the moment, e.g. the size to normalize to is hardcoded.
class Image {
 public:
  Image() = default;

  // Reads a file in PPM format (P6, binary), normalizes to [-1, 1].
  // Returns true on success.
  bool ReadPPM(const std::string& filename);

  // Resizes to 224x224 (nearest-neighbor for now, bilinear or antialias would
  // be better).
  void Resize();

  // Writes the file as plain floats in binary. Useful to e.g. load in a colab.
  // Returns true on success.
  bool WriteBinary(const std::string& filename) const;

  // Stores the patch for the given patch number [0, 256) in `patch`.
  // As sizes are hardcoded, the patch number is sufficient here.
  // `patch` should have space for at least 14 * 14 * 3 = 588 floats.
  // Requires the image to already have the hardcoded 224x224 size, i.e. that
  // Resize() has been called.
  void GetPatch(size_t patch_num, float* patch) const;

  // Pointer to the first stored float value (mutable and read-only variants).
  float* data() { return data_.data(); }
  const float* data() const { return data_.data(); }

  // Current dimensions; both are 0 for a default-constructed image.
  int width() const { return width_; }
  int height() const { return height_; }

  // Number of float values stored (not the number of pixels).
  size_t size() const { return data_.size(); }

  // True if the image holds any data.
  operator bool() const { return !data_.empty(); }

 private:
  int width_ = 0;
  int height_ = 0;
  std::vector<float> data_;  // r, g, b
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_PALIGEMMA_IMAGE_H_
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paligemma/image.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
||||
// Maps a channel value to a float via 2 * (value / 255) - 1, so that 0 becomes
// -1 and 255 becomes +1.
float Normalize(int value) {
  const float unit = value / 255.0f;
  return 2.0f * unit - 1.0f;
}
|
||||
|
||||
// Exercises ReadPPM, Resize and GetPatch on a reference image.
// Currently skipped because the path to the reference image is not available;
// the checks below document the expected values for that image.
TEST(ImageTest, BasicFunctionality) {
  // Report the test as skipped rather than silently passing without having
  // checked anything.
  GTEST_SKIP()
      << "Need to figure out how to get the external path for the test file.";
  std::string path;
  Image image;
  // A default-constructed image is empty.
  EXPECT_EQ(image.width(), 0);
  EXPECT_EQ(image.height(), 0);
  EXPECT_EQ(image.size(), 0);
  ASSERT_TRUE(image.ReadPPM(path));
  EXPECT_EQ(image.width(), 256);
  EXPECT_EQ(image.height(), 341);
  EXPECT_EQ(image.size(), 256 * 341 * 3);
  // Spot check a few values.
  EXPECT_EQ(image.data()[0], Normalize(160));
  EXPECT_EQ(image.data()[1], Normalize(184));
  EXPECT_EQ(image.data()[2], Normalize(188));
  EXPECT_EQ(image.data()[3], Normalize(163));
  EXPECT_EQ(image.data()[4], Normalize(185));
  EXPECT_EQ(image.data()[5], Normalize(189));
  EXPECT_EQ(image.data()[30], Normalize(164));
  EXPECT_EQ(image.data()[31], Normalize(185));
  EXPECT_EQ(image.data()[32], Normalize(191));
  EXPECT_EQ(image.data()[33], Normalize(164));
  EXPECT_EQ(image.data()[34], Normalize(185));
  EXPECT_EQ(image.data()[35], Normalize(191));
  image.Resize();
  // Check first and last pixel.
  EXPECT_EQ(image.data()[0], Normalize(160));
  EXPECT_EQ(image.data()[1], Normalize(184));
  EXPECT_EQ(image.data()[2], Normalize(188));
  EXPECT_EQ(image.data()[image.size() - 3], Normalize(90));
  EXPECT_EQ(image.data()[image.size() - 2], Normalize(132));
  EXPECT_EQ(image.data()[image.size() - 1], Normalize(122));
  // Extract two patches.
  float patch[588];  // 14 * 14 * 3
  image.GetPatch(0, patch);
  EXPECT_EQ(patch[0], Normalize(160));
  EXPECT_EQ(patch[1], Normalize(184));
  EXPECT_EQ(patch[2], Normalize(188));
  image.GetPatch(18, patch);
  // The start of patch 18 is expected at pixel row 14, pixel column 2 * 14 of
  // the 224-wide resized image.
  for (int i = 0; i < 10; ++i) {
    EXPECT_EQ(patch[i], image.data()[(14 * 224 + 2 * 14) * 3 + i]);
  }
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gcpp
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "gemma/gemma.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/common.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
|
||||
// This test can be run manually with the downloaded PaliGemma weights.
|
||||
// To run the test, pass the following flags:
|
||||
// --model paligemma-224 --tokenizer <tokenizer_path> --weights <weights_path>
|
||||
// It should pass for the following models:
|
||||
// paligemma-3b-mix-224
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
||||
// Shared state. Requires argc/argv, so construct in main and use the same raw
|
||||
// pointer approach as in benchmarks.cc. Note that the style guide forbids
|
||||
// non-local static variables with dtors.
|
||||
GemmaEnv* s_env = nullptr;
|
||||
|
||||
class PaliGemmaTest : public ::testing::Test {
|
||||
protected:
|
||||
void InitVit(const std::string& path);
|
||||
std::string GemmaReply(const std::string& prompt_text) const;
|
||||
void TestQuestions(const char* kQA[][2], size_t num_questions);
|
||||
|
||||
std::unique_ptr<ImageTokens> image_tokens_;
|
||||
};
|
||||
|
||||
void PaliGemmaTest::InitVit(const std::string& path) {
|
||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||
Gemma& model = *(s_env->GetModel());
|
||||
image_tokens_ = std::make_unique<ImageTokens>(256, 2048);
|
||||
Image image;
|
||||
HWY_ASSERT(model.Info().model == Model::PALIGEMMA_224);
|
||||
HWY_ASSERT(image.ReadPPM(path));
|
||||
image.Resize();
|
||||
RuntimeConfig runtime_config = {.verbosity = 0, .gen = &s_env->MutableGen()};
|
||||
model.GenerateImageTokens(runtime_config, image, *image_tokens_);
|
||||
}
|
||||
|
||||
std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
|
||||
Gemma& model = *(s_env->GetModel());
|
||||
s_env->MutableGen().seed(0x12345678);
|
||||
RuntimeConfig runtime_config = {.max_tokens = 1024,
|
||||
.max_generated_tokens = 512,
|
||||
.verbosity = 0,
|
||||
.gen = &s_env->MutableGen()};
|
||||
runtime_config.image_tokens = image_tokens_.get();
|
||||
size_t abs_pos = 0;
|
||||
std::string mutable_prompt = prompt_text;
|
||||
std::vector<int> tokens =
|
||||
WrapAndTokenize(model.Tokenizer(), model.Info(), abs_pos, mutable_prompt);
|
||||
std::string response;
|
||||
auto stream_token = [&](int token, float) {
|
||||
std::string token_text;
|
||||
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
response += token_text;
|
||||
return true;
|
||||
};
|
||||
runtime_config.stream_token = stream_token,
|
||||
tokens.insert(tokens.begin(), image_tokens_->BatchSize(), 0);
|
||||
size_t num_tokens = tokens.size();
|
||||
size_t prefix_end = num_tokens;
|
||||
runtime_config.prefill_tbatch_size = num_tokens;
|
||||
TimingInfo timing_info = {.verbosity = 0};
|
||||
model.Generate(runtime_config, tokens, abs_pos, prefix_end,
|
||||
s_env->MutableKVCache(), timing_info);
|
||||
return response;
|
||||
}
|
||||
|
||||
void PaliGemmaTest::TestQuestions(const char* kQA[][2], size_t num_questions) {
|
||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||
return; // Need to figure out how to get the external path for the test file.
|
||||
std::string path;
|
||||
InitVit(path);
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
fprintf(stderr, "Question %zu\n\n", i + 1);
|
||||
std::string response = GemmaReply(kQA[i][0]);
|
||||
fprintf(stderr, "'%s'\n\n", response.c_str());
|
||||
EXPECT_TRUE(response.find(kQA[i][1]) != std::string::npos); // NOLINT
|
||||
}
|
||||
}
|
||||
|
||||
// General questions about the test image. Each row is
// {prompt, text expected to appear in the reply}.
TEST_F(PaliGemmaTest, General) {
  static const char* kQA[][2] = {
      {"describe this image",
       "A large building with two towers stands tall on the water's edge."},
      {"describe image briefly",
       "A large building with two towers in the middle of a city."},
      {"What kind of building is it?", "church"},
      {"How many towers does the church have?", "2"},
      {"detect water", "<loc1022> water"},
      {"segment water", "<seg010> water"},
      {"Which city is this more likely? Tokio or Zurich?", "zurich"},
  };
  // Number of rows in the table above.
  static const size_t kNumQuestions = sizeof(kQA) / sizeof(kQA[0]);
  TestQuestions(kQA, kNumQuestions);
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gcpp
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
gcpp::GemmaEnv env(argc, argv);
|
||||
gcpp::s_env = &env;
|
||||
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
|
|
@ -187,6 +187,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
float temperature;
|
||||
bool deterministic;
|
||||
bool multiturn;
|
||||
Path image_file;
|
||||
|
||||
// Returns error string or nullptr if OK.
|
||||
const char* Validate() const {
|
||||
|
|
@ -221,6 +222,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
"interaction\n 1 = continue KV cache after every interaction\n "
|
||||
" Default : 0 (conversation "
|
||||
"resets every turn)");
|
||||
visitor(image_file, "image_file", Path(), "Image file to load.");
|
||||
}
|
||||
|
||||
void CopyTo(RuntimeConfig& runtime_config) const {
|
||||
|
|
|
|||
Loading…
Reference in New Issue