Remove Griffin support

Also add IsObsolete helper

PiperOrigin-RevId: 803376921
Jan Wassenberg 2025-09-05 02:34:54 -07:00 committed by Copybara-Service
parent 56186193c1
commit 2b4c16e243
28 changed files with 38 additions and 2455 deletions

View File

@@ -507,14 +507,12 @@ cc_library(
srcs = [
"gemma/attention.cc",
"gemma/gemma.cc",
"gemma/griffin.cc",
"gemma/vit.cc",
],
hdrs = [
"gemma/activations.h",
"gemma/attention.h",
"gemma/gemma.h",
"gemma/griffin.h",
"gemma/vit.h",
],
exec_properties = {

View File

@@ -83,8 +83,6 @@ set(SOURCES
gemma/gemma-inl.h
gemma/gemma.cc
gemma/gemma.h
gemma/griffin.cc
gemma/griffin.h
gemma/kv_cache.cc
gemma/kv_cache.h
gemma/model_store.cc

View File

@@ -53,7 +53,7 @@ Guidelines](https://opensource.google.com/conduct/).
- LLM
- CPU-only inference for: Gemma 2-3, Griffin(SSM), PaliGemma 2.
- CPU-only inference for: Gemma 2-3, PaliGemma 2.
- Sampling with TopK and temperature.
- Backward pass (VJP) and Adam optimizer for Gemma research.
@@ -222,23 +222,6 @@ Example invocation for the following configuration:
--tokenizer tokenizer.spm --weights gemma2-2b-it-sfp.sbs
```
### RecurrentGemma
This repository includes a version of Gemma based on Griffin
([paper](https://arxiv.org/abs/2402.19427),
[code](https://github.com/google-deepmind/recurrentgemma)). Its architecture
includes both recurrent layers and local attention, thus it is more efficient
for longer sequences and has a smaller memory footprint than standard Gemma. We
here provide a C++ implementation of this model based on the paper.
To use the recurrent version of Gemma included in this repository, build the
gemma binary as noted above in Step 3. Download the compressed weights and
tokenizer from the RecurrentGemma
[Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) as in
Step 1, and run the binary as follows:
`./gemma --tokenizer tokenizer.spm --model gr2b-it --weights 2b-it-sfp.sbs`
### PaliGemma Vision-Language Model
This repository includes a version of the PaliGemma 2 VLM
@@ -535,7 +518,7 @@ gemma.cpp was started in fall 2023 by
Griffin support was implemented in April 2024 thanks to contributions by Andrey
Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode
Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas
Fischbacher and Zoltan Szabadka.
Fischbacher and Zoltan Szabadka. It was removed in September 2025.
Gemma-2 support was implemented in June/July 2024 with the help of several
people.

View File

@@ -91,7 +91,7 @@ class CompressionTest(absltest.TestCase):
)
config = configs.ModelConfig(
configs.Model.GEMMA_TINY,
configs.Model.GEMMA2_2B,
configs.Type.kSFP,
configs.PromptWrapping.GEMMA_IT,
)
@@ -101,7 +101,7 @@ class CompressionTest(absltest.TestCase):
print("Ignore next two warnings; test does not enable model deduction.")
reader = compression.SbsReader(temp_file.full_path)
self.assertEqual(reader.config.model, configs.Model.GEMMA_TINY)
self.assertEqual(reader.config.model, configs.Model.GEMMA2_2B)
self.assertEqual(reader.config.weight, configs.Type.kSFP)
mat = reader.find_mat("tensor0")

View File

@@ -1,8 +0,0 @@
# General Remarks about the "PyTree" Abstraction
The pytree wrangling code in this project does not use any of the existing
"pytree" modules. The deeper reason here is that our approach is based on an
analysis of the notion that emphasizes deeper underlying principles. This is
being discussed internally at the time of this writing.

View File

@@ -1,275 +0,0 @@
"""Ad-hoc glue code for building the griffin model-file for the C++ binary.
Usage:
python3 -m venv $HOME/clients/griffin-venv
. $HOME/clients/griffin-venv/bin/activate
python3 -m pip install -r requirements.txt
time python3 build_model_file_for_cpp_binary.py \
$HOME/GRIFFIN/model_data \
cpp_load_log.txt /tmp/G2B.data
real 3m5.821s
user 2m9.205s
sys 2m46.720s
./compress_weights --weights /tmp/G2B.data --model gr2b-it \
--compressed_weights /tmp/G2B.compressed
./gemma --tokenizer tokenizer.spm --weights /tmp/G2B.compressed \
--model gr2b-it
Weights for the recurrent-gemma model that can be converted with this script
can be found at:
https://www.kaggle.com/models/google/recurrentgemma/flax/2b-it
"""
import pprint
import re
import sys
from typing import Any, Mapping
import numpy
import orbax.checkpoint
import ml_model_transforms
import pytree_transforms
def _fn_identity(x): return x
def _fn_transpose(x): return x.T
def _fn_transpose_all_heads(x): return x.transpose(0, 2, 1)
def _fn_scaled_softplus(a):
return -8 * numpy.logaddexp(a, 0)
def _fn_attention_moveaxis(a):
return a.reshape(10, 256, 2560).transpose(0, 2, 1)
def _aspec(pieces=(), transforms=()):
"""Short-hand array-save-specification.
Args:
pieces: Sequence of key-sequences identifying an array.
transforms: Sequence of transformations, indexed in
parallel to `pieces`, to apply to data arrays prior to saving.
Will be padded with identity-transformations to the length of `pieces`.
Returns:
Specification as for use in _LAYER_NAME_MAPPING.
"""
# `zip` trims to shortest sequence, so this amounts to using
# default-transforms.
# tuple() since we need a Sequence here, not a stateful-iterator zip_object.
return tuple(zip(pieces, list(transforms) + [_fn_identity] * len(pieces)))
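# Worked example (illustrative): with a single piece and a single transform,
# _aspec pairs them up and pads nothing, e.g.
#   _aspec([('recurrent_block', 'linear_x', 'kernel')], [_fn_transpose])
#   == ((('recurrent_block', 'linear_x', 'kernel'), _fn_transpose),)
# With no transforms given, every piece is paired with _fn_identity.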
_LAYER_NAME_MAPPING = pytree_transforms.deep_freeze({
# Recurrent Layer
'griffin_linear_x_w': _aspec(
[('recurrent_block', 'linear_x', 'kernel')],
[_fn_transpose]),
'griffin_linear_x_biases': _aspec(
[('recurrent_block', 'linear_x', 'bias')]),
'griffin_linear_y_w': _aspec(
[('recurrent_block', 'linear_y', 'kernel')],
[_fn_transpose]),
'griffin_linear_y_biases': _aspec(
[('recurrent_block', 'linear_y', 'bias')]),
'griffin_linear_out_w': _aspec(
[('recurrent_block', 'linear_out', 'kernel')],
[_fn_transpose]),
'griffin_linear_out_biases': _aspec(
[('recurrent_block' ,'linear_out', 'bias')]),
'griffin_conv_w': _aspec(
[('recurrent_block', 'conv_1d', 'w')]),
'griffin_conv_biases': _aspec(
[('recurrent_block', 'conv_1d', 'b')]),
'griffin_gate_w': _aspec(
[('recurrent_block', 'rg_lru', 'input_gate', 'w'),
('recurrent_block', 'rg_lru', 'a_gate', 'w')],
[_fn_transpose_all_heads, _fn_transpose_all_heads]),
'griffin_gate_biases': _aspec(
[('recurrent_block', 'rg_lru', 'input_gate', 'b'),
('recurrent_block', 'rg_lru', 'a_gate', 'b')]),
'griffin_a': _aspec(
[('recurrent_block', 'rg_lru', 'a_param')],
[_fn_scaled_softplus]),
# Attention Layer
'qkv_einsum_w': _aspec(
[('attention_block', 'proj_q', 'kernel'),
('attention_block', 'proj_k', 'kernel'),
('attention_block', 'proj_v', 'kernel'),
],
[_fn_transpose, _fn_transpose, _fn_transpose]),
'attn_vec_einsum_w': _aspec(
[('attention_block', 'proj_final', 'kernel')],
[_fn_attention_moveaxis]),
'attention_output_biases': _aspec(
[('attention_block', 'proj_final', 'bias')]),
# Common
'pre_attention_norm_scale': _aspec(
[('temporal_pre_norm', 'scale')]),
'pre_ffw_norm_scale': _aspec(
[('channel_pre_norm', 'scale')]),
'gating_einsum_w': _aspec(
[('mlp_block', 'ffw_up', 'w')],
[_fn_transpose_all_heads]),
'ffw_gating_biases': _aspec(
[('mlp_block', 'ffw_up', 'b')]),
'linear_w': _aspec(
[('mlp_block', 'ffw_down', 'kernel')],
[_fn_transpose]),
'ffw_output_biases': _aspec(
[('mlp_block', 'ffw_down', 'bias')]),
# Other
'embedder_input_embedding': _aspec(
[('embedder', 'input_embedding')]),
'final_norm_scale': _aspec(
[('final_norm', 'scale')]),
})
def process_param_line(line: str) -> None | tuple[None | str, int, str]:
"""Processes a "loading parameters" log-line from the griffin binary."""
# This is slightly more permissive than strictly needed, to also handle
# some earlier form of the output.
matched = re.match(
r'(?a)Loading Parameters:? \('
r'(?:layer=(?P<layer>\d+), )?'
r'size (?P<size>\d+)\):? '
r'(?P<tag>\S+)',
line)
if not matched:
return None
layer = matched['layer']
wanted_size = int(matched['size'])
cpp_tag = matched['tag']
return layer, wanted_size, cpp_tag
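# Worked example (illustrative), using lines from the accompanying load log:
#   process_param_line(
#       'Loading Parameters: (layer=0, size 26214400) griffin_linear_x_w')
#   == ('0', 26214400, 'griffin_linear_x_w')
#   process_param_line(
#       'Loading Parameters (size 2622750720): embedder_input_embedding')
#   == (None, 2622750720, 'embedder_input_embedding')
# Lines that do not match the pattern yield None.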
def collect_pytree_keys(param_lines):
"""Collects all the pytree keys and transforms for model-serialization."""
pytree_keys = []
array_transforms = []
unsatisfied = []
for maybe_spec in map(process_param_line, param_lines):
if not maybe_spec: continue # Skip non-parameter lines.
layer, wanted_size, cpp_tag = maybe_spec
pytree_key_tails_and_transforms = _LAYER_NAME_MAPPING.get(cpp_tag, ())
if not pytree_key_tails_and_transforms:
unsatisfied.append((layer, cpp_tag))
else:
for key_tail, array_transform in pytree_key_tails_and_transforms:
pytree_keys.append(
key_tail if layer is None
else (f'blocks.{layer}',) + key_tail)
array_transforms.append(array_transform)
return pytree_keys, array_transforms, unsatisfied
class UnsatisfiedArrayLoadsError(ValueError):
"""Some array-loads could not be satisfied."""
def flatten_model_for_cpp_binary(tree,
cpp_expectations_logfile_path : str,
out_path : str,
unsatisfied_ok : bool = False
):
"""Produces a model-parameters file readable by the C++ binary.
Args:
tree: The pytree with model-parameters.
cpp_expectations_logfile_path:
Path to a logfile produced by the C++ binary that shows
the expected array-order.
out_path: Path to the model-weights file to be written.
unsatisfied_ok: If true, we ignore the presence of unsatisfied
array-loads and write a model-parameters file that skips these pieces.
This will lead to an unusable model-parameters file which however
still might be useful for other analysis.
Returns:
Tuple `(unknown_keys, missing_keys)`, where `unknown_keys`
is a sequence of `(layer_or_None, name)` descriptions of the keys
in the C++ log that could not be satisfied, and `missing_keys`
is a sequence of linearized pytree key-sequences for keys
not found in the checkpoint.
Raises:
UnsatisfiedArrayLoadsError: If some of the expected arrays
could not be included in the output and `unsatisfied_ok`
is false.
"""
with open(cpp_expectations_logfile_path, 'rt') as h_log:
pytree_keys, array_transforms, unknown_keys = collect_pytree_keys(
list(h_log))
rank_by_pytree_key = {k: n for n, k in enumerate(pytree_keys)}
array_transform_by_pytree_key = dict(zip(pytree_keys, array_transforms))
#
model_contents = ml_model_transforms.model_contents(tree)
missing_keys = set(pytree_keys) - model_contents.keys()
if (unknown_keys or missing_keys) and not unsatisfied_ok:
raise UnsatisfiedArrayLoadsError(
f'Unsatisfied loads: unknown_keys: {unknown_keys!r}, '
f'missing keys: {sorted(missing_keys)!r}')
ml_model_transforms.model_save(
tree,
filepath_stem=out_path,
data_suffix='',
manifest_suffix=None,
array_transform_by_pytree_key=array_transform_by_pytree_key,
key=rank_by_pytree_key.get,
report=lambda line: print(line, file=sys.stderr),
byte_align=1)
return tuple(unknown_keys), tuple(sorted(missing_keys))
def main(args):
"""Creates the model-file.
Args:
args: sys.argv[1:], i.e. the command-line parameters without the leading program name.
Returns:
The pytree with all the de-serialized variables, such as for convenient
`python3 -i` inspection.
"""
try:
model_dir, cpp_load_log, out_path = args
except Exception:
sys.exit(f'Usage: {__file__} [model_dir] [cpp_load_log] [output_filename]')
pattern = ("recurrent", "recurrent", "attention")
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
variables = orbax_checkpointer.restore(model_dir)
if sorted(variables) == ['params']:
print('Warning: Using `variables["params"]` as tree-root.', file=sys.stderr)
variables_to_use = variables['params']
else:
variables_to_use = variables
unknown, missing = flatten_model_for_cpp_binary(variables_to_use,
cpp_load_log,
out_path,
unsatisfied_ok=True)
print('Model file saved.\n'
f'# unknown:\n{pprint.pformat(unknown)}\n'
f'# missing:\n{pprint.pformat(missing)}')
return variables
if __name__ == '__main__':
# Return value assignment is for `python3 -i ...` inspection.
pytree = main(sys.argv[1:])

View File

@@ -1,380 +0,0 @@
Loading Parameters (size 2622750720): embedder_input_embedding
Loading Parameters (size 10240): final_norm_scale
Loading Parameters: (layer=0, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=0, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=0, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=0, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=0, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=0, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=0, size 40960) griffin_conv_w
Loading Parameters: (layer=0, size 10240) griffin_conv_biases
Loading Parameters: (layer=0, size 5242880) griffin_gate_w
Loading Parameters: (layer=0, size 20480) griffin_gate_biases
Loading Parameters: (layer=0, size 10240) griffin_a
Loading Parameters: (layer=0, size 157286400) gating_einsum_w
Loading Parameters: (layer=0, size 78643200) linear_w
Loading Parameters: (layer=0, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=0, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=0, size 61440) ffw_gating_biases
Loading Parameters: (layer=0, size 10240) ffw_output_biases
Loading Parameters: (layer=1, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=1, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=1, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=1, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=1, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=1, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=1, size 40960) griffin_conv_w
Loading Parameters: (layer=1, size 10240) griffin_conv_biases
Loading Parameters: (layer=1, size 5242880) griffin_gate_w
Loading Parameters: (layer=1, size 20480) griffin_gate_biases
Loading Parameters: (layer=1, size 10240) griffin_a
Loading Parameters: (layer=1, size 157286400) gating_einsum_w
Loading Parameters: (layer=1, size 78643200) linear_w
Loading Parameters: (layer=1, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=1, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=1, size 61440) ffw_gating_biases
Loading Parameters: (layer=1, size 10240) ffw_output_biases
Loading Parameters: (layer=2, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=2, size 78643200) qkv_einsum_w
Loading Parameters: (layer=2, size 157286400) gating_einsum_w
Loading Parameters: (layer=2, size 78643200) linear_w
Loading Parameters: (layer=2, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=2, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=2, size 61440) ffw_gating_biases
Loading Parameters: (layer=2, size 10240) ffw_output_biases
Loading Parameters: (layer=2, size 10240) attention_output_biases
Loading Parameters: (layer=3, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=3, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=3, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=3, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=3, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=3, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=3, size 40960) griffin_conv_w
Loading Parameters: (layer=3, size 10240) griffin_conv_biases
Loading Parameters: (layer=3, size 5242880) griffin_gate_w
Loading Parameters: (layer=3, size 20480) griffin_gate_biases
Loading Parameters: (layer=3, size 10240) griffin_a
Loading Parameters: (layer=3, size 157286400) gating_einsum_w
Loading Parameters: (layer=3, size 78643200) linear_w
Loading Parameters: (layer=3, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=3, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=3, size 61440) ffw_gating_biases
Loading Parameters: (layer=3, size 10240) ffw_output_biases
Loading Parameters: (layer=4, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=4, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=4, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=4, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=4, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=4, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=4, size 40960) griffin_conv_w
Loading Parameters: (layer=4, size 10240) griffin_conv_biases
Loading Parameters: (layer=4, size 5242880) griffin_gate_w
Loading Parameters: (layer=4, size 20480) griffin_gate_biases
Loading Parameters: (layer=4, size 10240) griffin_a
Loading Parameters: (layer=4, size 157286400) gating_einsum_w
Loading Parameters: (layer=4, size 78643200) linear_w
Loading Parameters: (layer=4, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=4, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=4, size 61440) ffw_gating_biases
Loading Parameters: (layer=4, size 10240) ffw_output_biases
Loading Parameters: (layer=5, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=5, size 78643200) qkv_einsum_w
Loading Parameters: (layer=5, size 157286400) gating_einsum_w
Loading Parameters: (layer=5, size 78643200) linear_w
Loading Parameters: (layer=5, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=5, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=5, size 61440) ffw_gating_biases
Loading Parameters: (layer=5, size 10240) ffw_output_biases
Loading Parameters: (layer=5, size 10240) attention_output_biases
Loading Parameters: (layer=6, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=6, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=6, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=6, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=6, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=6, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=6, size 40960) griffin_conv_w
Loading Parameters: (layer=6, size 10240) griffin_conv_biases
Loading Parameters: (layer=6, size 5242880) griffin_gate_w
Loading Parameters: (layer=6, size 20480) griffin_gate_biases
Loading Parameters: (layer=6, size 10240) griffin_a
Loading Parameters: (layer=6, size 157286400) gating_einsum_w
Loading Parameters: (layer=6, size 78643200) linear_w
Loading Parameters: (layer=6, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=6, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=6, size 61440) ffw_gating_biases
Loading Parameters: (layer=6, size 10240) ffw_output_biases
Loading Parameters: (layer=7, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=7, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=7, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=7, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=7, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=7, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=7, size 40960) griffin_conv_w
Loading Parameters: (layer=7, size 10240) griffin_conv_biases
Loading Parameters: (layer=7, size 5242880) griffin_gate_w
Loading Parameters: (layer=7, size 20480) griffin_gate_biases
Loading Parameters: (layer=7, size 10240) griffin_a
Loading Parameters: (layer=7, size 157286400) gating_einsum_w
Loading Parameters: (layer=7, size 78643200) linear_w
Loading Parameters: (layer=7, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=7, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=7, size 61440) ffw_gating_biases
Loading Parameters: (layer=7, size 10240) ffw_output_biases
Loading Parameters: (layer=8, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=8, size 78643200) qkv_einsum_w
Loading Parameters: (layer=8, size 157286400) gating_einsum_w
Loading Parameters: (layer=8, size 78643200) linear_w
Loading Parameters: (layer=8, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=8, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=8, size 61440) ffw_gating_biases
Loading Parameters: (layer=8, size 10240) ffw_output_biases
Loading Parameters: (layer=8, size 10240) attention_output_biases
Loading Parameters: (layer=9, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=9, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=9, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=9, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=9, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=9, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=9, size 40960) griffin_conv_w
Loading Parameters: (layer=9, size 10240) griffin_conv_biases
Loading Parameters: (layer=9, size 5242880) griffin_gate_w
Loading Parameters: (layer=9, size 20480) griffin_gate_biases
Loading Parameters: (layer=9, size 10240) griffin_a
Loading Parameters: (layer=9, size 157286400) gating_einsum_w
Loading Parameters: (layer=9, size 78643200) linear_w
Loading Parameters: (layer=9, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=9, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=9, size 61440) ffw_gating_biases
Loading Parameters: (layer=9, size 10240) ffw_output_biases
Loading Parameters: (layer=10, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=10, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=10, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=10, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=10, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=10, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=10, size 40960) griffin_conv_w
Loading Parameters: (layer=10, size 10240) griffin_conv_biases
Loading Parameters: (layer=10, size 5242880) griffin_gate_w
Loading Parameters: (layer=10, size 20480) griffin_gate_biases
Loading Parameters: (layer=10, size 10240) griffin_a
Loading Parameters: (layer=10, size 157286400) gating_einsum_w
Loading Parameters: (layer=10, size 78643200) linear_w
Loading Parameters: (layer=10, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=10, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=10, size 61440) ffw_gating_biases
Loading Parameters: (layer=10, size 10240) ffw_output_biases
Loading Parameters: (layer=11, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=11, size 78643200) qkv_einsum_w
Loading Parameters: (layer=11, size 157286400) gating_einsum_w
Loading Parameters: (layer=11, size 78643200) linear_w
Loading Parameters: (layer=11, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=11, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=11, size 61440) ffw_gating_biases
Loading Parameters: (layer=11, size 10240) ffw_output_biases
Loading Parameters: (layer=11, size 10240) attention_output_biases
Loading Parameters: (layer=12, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=12, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=12, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=12, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=12, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=12, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=12, size 40960) griffin_conv_w
Loading Parameters: (layer=12, size 10240) griffin_conv_biases
Loading Parameters: (layer=12, size 5242880) griffin_gate_w
Loading Parameters: (layer=12, size 20480) griffin_gate_biases
Loading Parameters: (layer=12, size 10240) griffin_a
Loading Parameters: (layer=12, size 157286400) gating_einsum_w
Loading Parameters: (layer=12, size 78643200) linear_w
Loading Parameters: (layer=12, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=12, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=12, size 61440) ffw_gating_biases
Loading Parameters: (layer=12, size 10240) ffw_output_biases
Loading Parameters: (layer=13, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=13, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=13, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=13, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=13, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=13, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=13, size 40960) griffin_conv_w
Loading Parameters: (layer=13, size 10240) griffin_conv_biases
Loading Parameters: (layer=13, size 5242880) griffin_gate_w
Loading Parameters: (layer=13, size 20480) griffin_gate_biases
Loading Parameters: (layer=13, size 10240) griffin_a
Loading Parameters: (layer=13, size 157286400) gating_einsum_w
Loading Parameters: (layer=13, size 78643200) linear_w
Loading Parameters: (layer=13, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=13, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=13, size 61440) ffw_gating_biases
Loading Parameters: (layer=13, size 10240) ffw_output_biases
Loading Parameters: (layer=14, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=14, size 78643200) qkv_einsum_w
Loading Parameters: (layer=14, size 157286400) gating_einsum_w
Loading Parameters: (layer=14, size 78643200) linear_w
Loading Parameters: (layer=14, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=14, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=14, size 61440) ffw_gating_biases
Loading Parameters: (layer=14, size 10240) ffw_output_biases
Loading Parameters: (layer=14, size 10240) attention_output_biases
Loading Parameters: (layer=15, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=15, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=15, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=15, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=15, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=15, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=15, size 40960) griffin_conv_w
Loading Parameters: (layer=15, size 10240) griffin_conv_biases
Loading Parameters: (layer=15, size 5242880) griffin_gate_w
Loading Parameters: (layer=15, size 20480) griffin_gate_biases
Loading Parameters: (layer=15, size 10240) griffin_a
Loading Parameters: (layer=15, size 157286400) gating_einsum_w
Loading Parameters: (layer=15, size 78643200) linear_w
Loading Parameters: (layer=15, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=15, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=15, size 61440) ffw_gating_biases
Loading Parameters: (layer=15, size 10240) ffw_output_biases
Loading Parameters: (layer=16, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=16, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=16, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=16, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=16, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=16, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=16, size 40960) griffin_conv_w
Loading Parameters: (layer=16, size 10240) griffin_conv_biases
Loading Parameters: (layer=16, size 5242880) griffin_gate_w
Loading Parameters: (layer=16, size 20480) griffin_gate_biases
Loading Parameters: (layer=16, size 10240) griffin_a
Loading Parameters: (layer=16, size 157286400) gating_einsum_w
Loading Parameters: (layer=16, size 78643200) linear_w
Loading Parameters: (layer=16, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=16, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=16, size 61440) ffw_gating_biases
Loading Parameters: (layer=16, size 10240) ffw_output_biases
Loading Parameters: (layer=17, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=17, size 78643200) qkv_einsum_w
Loading Parameters: (layer=17, size 157286400) gating_einsum_w
Loading Parameters: (layer=17, size 78643200) linear_w
Loading Parameters: (layer=17, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=17, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=17, size 61440) ffw_gating_biases
Loading Parameters: (layer=17, size 10240) ffw_output_biases
Loading Parameters: (layer=17, size 10240) attention_output_biases
Loading Parameters: (layer=18, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=18, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=18, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=18, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=18, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=18, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=18, size 40960) griffin_conv_w
Loading Parameters: (layer=18, size 10240) griffin_conv_biases
Loading Parameters: (layer=18, size 5242880) griffin_gate_w
Loading Parameters: (layer=18, size 20480) griffin_gate_biases
Loading Parameters: (layer=18, size 10240) griffin_a
Loading Parameters: (layer=18, size 157286400) gating_einsum_w
Loading Parameters: (layer=18, size 78643200) linear_w
Loading Parameters: (layer=18, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=18, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=18, size 61440) ffw_gating_biases
Loading Parameters: (layer=18, size 10240) ffw_output_biases
Loading Parameters: (layer=19, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=19, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=19, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=19, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=19, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=19, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=19, size 40960) griffin_conv_w
Loading Parameters: (layer=19, size 10240) griffin_conv_biases
Loading Parameters: (layer=19, size 5242880) griffin_gate_w
Loading Parameters: (layer=19, size 20480) griffin_gate_biases
Loading Parameters: (layer=19, size 10240) griffin_a
Loading Parameters: (layer=19, size 157286400) gating_einsum_w
Loading Parameters: (layer=19, size 78643200) linear_w
Loading Parameters: (layer=19, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=19, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=19, size 61440) ffw_gating_biases
Loading Parameters: (layer=19, size 10240) ffw_output_biases
Loading Parameters: (layer=20, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=20, size 78643200) qkv_einsum_w
Loading Parameters: (layer=20, size 157286400) gating_einsum_w
Loading Parameters: (layer=20, size 78643200) linear_w
Loading Parameters: (layer=20, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=20, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=20, size 61440) ffw_gating_biases
Loading Parameters: (layer=20, size 10240) ffw_output_biases
Loading Parameters: (layer=20, size 10240) attention_output_biases
Loading Parameters: (layer=21, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=21, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=21, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=21, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=21, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=21, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=21, size 40960) griffin_conv_w
Loading Parameters: (layer=21, size 10240) griffin_conv_biases
Loading Parameters: (layer=21, size 5242880) griffin_gate_w
Loading Parameters: (layer=21, size 20480) griffin_gate_biases
Loading Parameters: (layer=21, size 10240) griffin_a
Loading Parameters: (layer=21, size 157286400) gating_einsum_w
Loading Parameters: (layer=21, size 78643200) linear_w
Loading Parameters: (layer=21, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=21, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=21, size 61440) ffw_gating_biases
Loading Parameters: (layer=21, size 10240) ffw_output_biases
Loading Parameters: (layer=22, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=22, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=22, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=22, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=22, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=22, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=22, size 40960) griffin_conv_w
Loading Parameters: (layer=22, size 10240) griffin_conv_biases
Loading Parameters: (layer=22, size 5242880) griffin_gate_w
Loading Parameters: (layer=22, size 20480) griffin_gate_biases
Loading Parameters: (layer=22, size 10240) griffin_a
Loading Parameters: (layer=22, size 157286400) gating_einsum_w
Loading Parameters: (layer=22, size 78643200) linear_w
Loading Parameters: (layer=22, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=22, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=22, size 61440) ffw_gating_biases
Loading Parameters: (layer=22, size 10240) ffw_output_biases
Loading Parameters: (layer=23, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=23, size 78643200) qkv_einsum_w
Loading Parameters: (layer=23, size 157286400) gating_einsum_w
Loading Parameters: (layer=23, size 78643200) linear_w
Loading Parameters: (layer=23, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=23, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=23, size 61440) ffw_gating_biases
Loading Parameters: (layer=23, size 10240) ffw_output_biases
Loading Parameters: (layer=23, size 10240) attention_output_biases
Loading Parameters: (layer=24, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=24, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=24, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=24, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=24, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=24, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=24, size 40960) griffin_conv_w
Loading Parameters: (layer=24, size 10240) griffin_conv_biases
Loading Parameters: (layer=24, size 5242880) griffin_gate_w
Loading Parameters: (layer=24, size 20480) griffin_gate_biases
Loading Parameters: (layer=24, size 10240) griffin_a
Loading Parameters: (layer=24, size 157286400) gating_einsum_w
Loading Parameters: (layer=24, size 78643200) linear_w
Loading Parameters: (layer=24, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=24, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=24, size 61440) ffw_gating_biases
Loading Parameters: (layer=24, size 10240) ffw_output_biases
Loading Parameters: (layer=25, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=25, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=25, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=25, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=25, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=25, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=25, size 40960) griffin_conv_w
Loading Parameters: (layer=25, size 10240) griffin_conv_biases
Loading Parameters: (layer=25, size 5242880) griffin_gate_w
Loading Parameters: (layer=25, size 20480) griffin_gate_biases
Loading Parameters: (layer=25, size 10240) griffin_a
Loading Parameters: (layer=25, size 157286400) gating_einsum_w
Loading Parameters: (layer=25, size 78643200) linear_w
Loading Parameters: (layer=25, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=25, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=25, size 61440) ffw_gating_biases
Loading Parameters: (layer=25, size 10240) ffw_output_biases

View File

@@ -1,371 +0,0 @@
"""Transformations for python-trees representing the parameters of a ML model.
Important: This module assumes that byte-order is the same on the
machine that serializes data and the machine that deserializes
data. If, for example, numpy-data gets dumped, respectively loaded,
with a dtype-specification of numpy.float32, on-file byte-order
will be host byte order.
"""
import ast
import hashlib
import itertools
import pprint
import sys
import time
from typing import Any, Callable, Iterable, Iterator, Mapping, TypeVar
import numpy
import pytree_transforms
NT = TypeVar('NT')
def ml_model_leaf_summary(path, x, sep=', '):
"""Produces a textual summary of a leaf-node and its path.
Args:
path: The path-to-root, as a reverse-order recursive
pair of path-components, with `()` as root.
x: The leaf-object.
sep: the separator between description-elements.
Default ', ' allows for convenient line-by-line processing
(such as via grep, perl -ne, etc.), but using e.g. sep=',\n '
might be more useful for human consumption.
Returns:
A human-readable string providing information about the node.
"""
# Using `repr` for path-components to get a faithful presentation.
# (...which however would still be somewhat painful to correctly
# split into components.)
path_str = ','.join(map(repr,
pytree_transforms.linearize_revtuple_path(path)))
tx = type(x)
mod = tx.__module__ # Either a module or a string like 'builtins'.
modname = mod if isinstance(mod, str) else mod.__name__
type_str = f'{modname}.{tx.__qualname__}'
try:
# `numpy.ndarray` instances have a `.data` property that gives access
# to a buffer via which we can hashlib-fingerprint the numerical
# contents. We here simply try to produce a fingerprint and also look
# up the .dtype of the object. Technically, there is a somewhat-unsound
# assumption here that if these operations succeed, we are indeed looking
# at a ndarray or sufficiently similar object for these operations to
# make sense. As the output is declared "for human consumption", this
# fishiness is not a problem.
fp = hashlib.sha256(x.data).hexdigest()
start = list(itertools.islice(x.flat, 5))
stats_str = (
f'min={numpy.min(x):.6g}, max={numpy.max(x):.6g}, '
f'mean={numpy.mean(x):.6g}, std={numpy.std(x):.6g}')
return (f'{path_str:60s}: <{type_str}{sep}'
f'fp=0x{fp}{sep}{stats_str}{sep}shape={x.shape}, '
f'dtype={x.dtype}{sep}start={start}>')
except (AttributeError, ValueError, TypeError):
# Fallback - trying to include information about the data-content
# of a likely-numerical-array failed.
return f'{path_str:60s}: {type_str}({repr(x)})'
# A specialized node-handler.
# Interface follows node-handler expectations defined in pytree_transforms.
def _ml_model_tree_node_handler(path: tuple, node : NT) -> (
None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT],
Iterator[tuple[Any, NT]]]):
"""Processes a tree-node as required by pytree-iteration and -mapping.
Args:
path: revtuple path to the current node.
node: a tree-node in a ML-model tree that is recursively
built out of `numpy.ndarray` leaf-values and dicts mapping
node-name string-keys to other such nodes representing subtrees -
and nothing else.
Returns:
`None` if the tree-node is to be regarded as a leaf, otherwise
a pair `(rebuilder, iterator)`, where `iterator` iterates
over the data-content of the node, each item represented as a pair
of `(lookup_path_item, value_item)`, and `rebuilder` is a function
which, when applied to `iterator` or any iterable with the same
elements, returns a node that is equivalent to the original.
Raises:
NotAMLModelTreeNodeError: If the tree contains a node that is neither
a `dict` nor a `numpy.ndarray` instance.
"""
# The astute reader will notice that we are doing something fishy
# here - this code could not be translated to Haskell as-is, since
# `NT` cannot actually be a proper type-variable in the sense
# of parametric polymorphism.
del path # Unused.
if isinstance(node, dict):
return dict, iter(node.items())
if isinstance(node, numpy.ndarray):
return None
raise pytree_transforms.NotAMLModelTreeNodeError(
f'Type of bad node: {type(node)}')
def _ml_model_extract_leaf_transform(
path: pytree_transforms.RevTuplePath,
leaf: Any):
"""Maps an array-leaf to a pair `(full_path, lambda: array)`.
The computation that produces the leaf-value is lazified underneath
a `lambda`, since if we e.g. performed a memory-expensive
transformation (such as some dtype-changes) directly at this point,
then going from an iterator over tree-items for one-by-one
consumption to a list of these items would have all the
dtype-transformed values around simultaneously. We want to avoid
situations where we can do nothing about having multiple variants
of the data simultaneously in memory.
"""
# Hack: If we are encountering a `bfloat16` numpy-array,
# we pretend to have the data as a numpy.float32 array,
# since that's about all that contemporary CPUs can process
# efficiently here.
linearized_path = pytree_transforms.linearize_revtuple_path(path)
try:
# We have to use some trickery to detect `bfloat16`.
if leaf.dtype.descr[-1] == ('', '<V2'):
return linearized_path, lambda: leaf.astype(numpy.float32)
else:
return linearized_path, lambda: leaf
except Exception:
return linearized_path, lambda: leaf
# Here, we cannot properly specify the return-type, since this can
# either be a leaf-type or something recursively-defined.
def revtuple_autovifify_from_linear(
keys_and_vals: Iterable[tuple[Any, Any]]) -> Any:
"""Performs perl-style autovivification on a nested-dict tree.
Args:
keys_and_vals: An iterable of pairs `(key_path, value)`, where
`key_path` is a sequence of keys to be used to navigate to
the result via iterative dict-lookup, left-to-right.
Must not have duplicate keys, and must not more than one key if
an empty-sequence key is present. If this iterable is an
iterator, it will be fully exhausted on successful execution.
Returns:
An object representing a nested-dict structure such that
for every `key_path` from `keys_and_vals`, recursive-dict-lookup
on the elements of that path starting from this object will
produce the corresponding value. An empty `keys_and_vals`
set will return `{}`. Every dict in the nested return-value
that has been populated by autovivification is newly allocated.
"""
# Code structure is a bit gnarly here due to f(keys_and_vals=[((), x)])
# having to evaluate to x and not a dict.
# There may be ways to prettify/simplify this.
result = None
empty = {}
for linear_path, val in keys_and_vals:
if linear_path == ():
if result is not None:
raise ValueError('Root-value seen alongside other values.')
result = val
else:
if result is None:
result = {}
elif type(result) is not dict:
# We already did encounter a root-value.
raise ValueError('Root-value seen alongside other values.')
cursor = result
for n in range(len(linear_path) - 1):
cursor = cursor.setdefault(linear_path[n], empty)
if cursor is empty:
# Regenerate `empty` if we just used it up.
empty = {}
cursor[linear_path[-1]] = val
return {} if result is None else result
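# Worked example (illustrative), mirroring the accompanying unit test:
#   revtuple_autovifify_from_linear(
#       [(('a', 'b1', 'c1'), 1001), (('a', 'b2'), 1002), (('a2',), 1003)])
#   == {'a': {'b1': {'c1': 1001}, 'b2': 1002}, 'a2': 1003}
# and revtuple_autovifify_from_linear([]) == {}.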
def model_overview(tree, out=None) -> None:
"""Prints a human-readable overview to `(out or sys.stdout)`."""
actual_out = out or sys.stdout
for line in pytree_transforms.pytree_leaf_iter(
tree, ml_model_leaf_summary,
_ml_model_tree_node_handler):
print(line, file=actual_out)
def model_contents(tree) -> Mapping[tuple[str, ...], Any]:
"""Maps a model to a {pytree_keys: data_array} mapping.
Args:
tree: The ML-model parameter-tree, built recursively out of
dict-instances with numpy.ndarray instances as leaves.
Returns:
A mapping from linearized pytree-key-sequence tuple to the corresponding
leaf-value.
"""
def leaf_transform(revtuple_path, leaf):
return pytree_transforms.linearize_revtuple_path(revtuple_path), leaf
return dict(
pytree_transforms.pytree_leaf_iter(
tree, leaf_transform, _ml_model_tree_node_handler))
def _fn_identity(x): return x
def model_save(tree,
filepath_stem: str,
data_suffix: str = '.data',
manifest_suffix: str | None = '.manifest',
key: Callable[[tuple[str, ...]], Any] | None = None,
array_transform_by_pytree_key: (
Mapping[tuple[str, ...],
Callable[[numpy.ndarray], numpy.ndarray]] |
None) = None,
report: Callable[[str], None] | None = None,
byte_align: int = 8) -> tuple[int, float]:
"""Saves the content of a ML-model parameter-tree to filesystem.
After successful execution, the file f"{filepath_stem}.data"
will hold the combined numerical model-parameters, and
f"{filepath_stem}.manifest" will contain the key for interpreting
(and rebuilding) the data.
Args:
tree: The ML-model parameter-tree, built recursively out of
dict-instances with numpy.ndarray instances as leaves.
filepath_stem: Filesystem location for data.
data_suffix: Suffix to use for the data file.
manifest_suffix: Either `None`, in which case no manifest-file
will get written, or the suffix for the manifest-file.
key: `None` or a key-function that will be applied to the linear model-path
and used for sorting the data arrays by increasing value of the
key-function. If the key-function returns `None` on an item,
then this item is not included.
array_transform_by_pytree_key: Optional mapping from pytree-key
to an array-to-array transformation function to apply to the array
prior to serialization.
report: Optional callable for logging progress-reports.
byte_align: byte-alignment to use for numerical array data.
Numerical arrays whose size in bytes is not a multiple of this
will get padded to the next full multiple.
Returns:
A pair of `(size, time_sec)`, where `size` is the total byte-size
of the `.data` file and `time_sec` is the elapsed time
for saving the model, in seconds.
"""
time0 = time.monotonic()
if array_transform_by_pytree_key is None:
array_transform_by_pytree_key = {}
model_lazy_items = (
pytree_transforms.pytree_leaf_iter(
tree, _ml_model_extract_leaf_transform,
_ml_model_tree_node_handler))
if key is not None:
to_write = [
nkv[1:] for nkv in sorted(
(nkv for nkv in ((key(path), path, v)
for path, v in model_lazy_items)
if nkv[0] is not None), key=lambda nkv: nkv[0])]
else:
to_write = list(model_lazy_items)
#
def lazy_arr_path_shape_dtype_size(path_and_lazy_arr):
path, lazy_arr = path_and_lazy_arr
arr = array_transform_by_pytree_key.get(path, _fn_identity)(lazy_arr())
return path, arr.shape, arr.dtype, arr.data.nbytes
arrs_path_shape_dtype_nbytes = list(
map(lazy_arr_path_shape_dtype_size, to_write))
# We need to know the total size of all the data.
bytesizes = [nbytes for *_, nbytes in arrs_path_shape_dtype_nbytes]
# Round each byte-size up to the next multiple of `byte_align`.
padded_bytesizes = [-(-bytesize // byte_align) * byte_align
for bytesize in bytesizes]
offsets = numpy.cumsum([0] + padded_bytesizes)
membuf = numpy.memmap(filepath_stem + data_suffix,
mode='w+', shape=offsets[-1])
try:
for (path, shape, dtype, nbytes), offset, (_, lazy_arr) in zip(
arrs_path_shape_dtype_nbytes, offsets, to_write):
# Note that if getting the array from the lazy lambda involved some
# computation, such as a copying dtype-change, that computation would
# end up being done multiple times here - including once above, to compute
# byte-sizes, and once more here.
transformed_arr = array_transform_by_pytree_key.get(
path,
_fn_identity)(lazy_arr())
membuf[offset : offset + nbytes] = numpy.frombuffer(
transformed_arr.ravel().data, 'u1')
if report is not None:
samples = ', '.join(map(str, transformed_arr.ravel()[:5]))
report(f'# Adding: {path!r}\n bytes: {nbytes:10d}, '
f'shape: {shape!r:30},\n start: [{samples}, ...]')
transformed_arr = None # Drop memory references to numerical arrays ASAP.
finally:
if membuf is not None:
membuf.flush()
# NumPy wart: the memory-buffer is a resource that conceptually
# should be .close()able - since mmap()ing holds on to a
# file descriptor. However, it looks as if that clean-up were done
# in the "finalizer", despite that having meanwhile been widely
# understood as dubious practice. So, the best we can do here is
# to explicitly and clearly remove our reference to the instance.
del membuf
if manifest_suffix is not None:
# We still have to serialize the data that allows us to reconstruct
# a tree that is equivalent to the original.
manifest_data = [
dict(path=path,
dtype=dtype.descr[-1][-1],
shape=shape,
nbytes=nbytes,
offset=offset)
for (path, shape, dtype, nbytes), offset in zip(
arrs_path_shape_dtype_nbytes, offsets)]
with open(filepath_stem + manifest_suffix, 'wt') as h_manifest:
pprint.pprint(manifest_data, stream=h_manifest)
time_taken = time.monotonic() - time0
return offsets[-1], time_taken
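# Worked example (illustrative values): each manifest entry written above is a
# plain dict such as
#   {'path': ('blocks.0', 'mlp_block', 'ffw_down', 'kernel'),
#    'dtype': '<f4',
#    'shape': (7680, 2560),
#    'nbytes': 78643200,
#    'offset': 0}
# `model_load` below reads the manifest back with ast.literal_eval, which is
# why it is written with pprint rather than in a binary format.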
def model_load(filepath_stem, mmapped=True):
"""Loads a model saved by `model_save`.
Tries to load the model from f"{filepath_stem}.data"
and f"{filepath_stem}.manifest".
Args:
filepath_stem: The model location on the filesystem.
mmapped: Whether data-arrays will be slices of a
`numpy.memmap` mapped buffer, to be paged in
on demand only, or in-memory copies of the data.
Returns:
A dict/numpy.ndarray tree representation of the model,
equivalent to the original model.
"""
with open(filepath_stem + '.manifest', 'rt') as h_manifest:
manifest = ast.literal_eval(h_manifest.read())
membuf = numpy.memmap(filepath_stem + '.data', mode='r+')
paths_and_arrays = []
for item in manifest:
path = item['path']
dtype = numpy.dtype(item['dtype'])
shape = item['shape']
nbytes = item['nbytes']
offset = item['offset']
data_array = numpy.frombuffer(membuf[offset : offset + nbytes].data,
dtype=dtype).reshape(shape)
paths_and_arrays.append(
(path,
data_array if mmapped else data_array.copy()))
# At this point, the memory-buffer is no longer needed. Still, if
# data-arrays retain references to the underlying data
# (i.e. when mmapped=True), this should keep the mapping
# - and hence file descriptor - open. We then are in a somewhat
# undesirable situation of clean-up of a resource that happens in a
# hard-to-predict way releasing a file descriptor.
del membuf
return revtuple_autovifify_from_linear(paths_and_arrays)

View File

@@ -1,92 +0,0 @@
"""Basic tests for 'algebraic data type based pytree' transformations."""
import io
import os
import tempfile
import unittest
import numpy
import ml_model_transforms
def _get_model(prefix):
return {
prefix + 'a1': numpy.arange(1000, 1024).reshape(6, 4).astype(numpy.float32),
prefix + 'a2': numpy.arange(2000, 2048).reshape(6, 8).astype(numpy.float32),
prefix + 'b1': {
prefix + 'c1': numpy.arange(100, 127).reshape(3, 3, 3).astype(numpy.int8),
prefix + 'c2': numpy.arange(100, 128).reshape(7, 4).astype(numpy.float64)
}}
class MLModeltransformsTest(unittest.TestCase):
"""Basic correctness validation tests for ML-model transformations."""
def test_ml_model_leaf_summary(self):
"""Tests guarantees given by `ml_model_leaf_summary`."""
summary = ml_model_transforms.ml_model_leaf_summary(
('a', ()),
numpy.arange(1000, 1024).reshape(6, 4).astype(numpy.int16),
sep='##')
self.assertIn('##', summary) # Separator is respected.
self.assertIn('(6, 4)', summary) # Shape is mentioned somewhere.
self.assertIn('int16', summary) # dtype is mentioned somewhere.
def test_revtuple_autovivify_from_linear(self):
"""Tests guarantees given by `revtuple_autovifify_from_linear`."""
with self.subTest(guarantee='empty'):
self.assertEqual(
ml_model_transforms.revtuple_autovifify_from_linear([]),
{})
with self.subTest(guarantee='generic'):
keys_vals = [(('a', 'b1', 'c1'), 1001),
(('a', 'b2'), 1002),
(('a2',), 1003),
]
self.assertEqual(
ml_model_transforms.revtuple_autovifify_from_linear(keys_vals),
{'a': {'b1': {'c1': 1001}, 'b2': 1002}, 'a2': 1003})
def test_model_overview(self):
"""Tests guarantees given by `model_overview`."""
model = _get_model('xyz')
out_io = io.StringIO()
ml_model_transforms.model_overview(model, out=out_io)
overview = out_io.getvalue()
self.assertIn('xyz', overview)
def test_model_contents(self):
"""Tests guarantees given by `model_contents`."""
model = _get_model('pq_')
contents = ml_model_transforms.model_contents(model)
fingerprints = {k: (a.shape, a.ravel()[:3].tolist())
for k, a in contents.items()}
self.assertEqual(fingerprints,
{('pq_a1',): ((6, 4), [1000.0, 1001.0, 1002.0]),
('pq_a2',): ((6, 8), [2000.0, 2001.0, 2002.0]),
('pq_b1', 'pq_c1'): ((3, 3, 3), [100, 101, 102]),
('pq_b1', 'pq_c2'): ((7, 4), [100.0, 101.0, 102.0])})
def test_model_save_load_basic(self):
"""Tests basic guarantees given by `model_save` and `model_load`."""
# What we care about here is that the round trip works - so
# it makes more sense to test saving and loading as one unit.
model_orig = _get_model('model_')
with tempfile.TemporaryDirectory() as tempdir:
filepath_stem = os.path.join(tempdir, 'the_model')
total_size, total_time = ml_model_transforms.model_save(model_orig,
filepath_stem)
self.assertGreater(total_size, 0)
self.assertGreater(total_time, 0)
model_reloaded = ml_model_transforms.model_load(filepath_stem)
contents_orig = ml_model_transforms.model_contents(model_orig)
contents_reloaded = ml_model_transforms.model_contents(model_reloaded)
self.assertEqual(
{k: v.tolist() for k, v in contents_orig.items()},
{k: v.tolist() for k, v in contents_reloaded.items()})
if __name__ == '__main__':
unittest.main()

View File

@@ -1,508 +0,0 @@
"""Tools for transforming "nested python object" tree data structures.
# Context
The motivation for this module came from ML applications that ought to
be based on a principled handling of nested Python data structures.
Having such principled pytree-transforming code available solves
some other problems as well, such as removing the need to abuse
tree-mapping for side effects only, and avoiding a hope-and-pray
approach to processing very deeply nested values, which with a recursive
implementation might trigger a RecursionError.
We specifically want to cover the use case of having ML model
parameters that are available in a nested Python data structure for
which there "almost" is a unique-up-to-unique-isomorphism mapping from
and to this Algebraic Data Type:
`data ModelParams a = Array a | Node [(String, ModelParams a)]`
In this correspondence, `a` is some array-type (perhaps
`numpy.ndarray`, `jax.numpy.ndarray`, `tf.tensor`, etc.), but the
data-processing code is effectively entirely agnostic to this, and a
`Node` is "almost" an associative-list of (key, value) pairs
representing a Python dict. (Note: The "almost" here is mostly about
the conceptual wart that assoc-lists can in principle have key
duplicates, but Python dicts can not. This is however not a problem
since all we need is the transformation in one direction,
i.e. whatever data-processing `f` we want to express on the
model-parameters-pytree, we can express by specifying a "faithful"
mapping `m` into the above algebraic data type through which every
such pytree data transform factorizes, i.e. for every `f` we can find
a `g` such that `f(p) = g(m(p))`.)
## Components
The main workhorse in this module is the `pytree_iter` function that
maps a "PyTree (such as representing `ModelParams`)" to an iterator
over values obtained by applying a mapping-function to the "key-path"
and leaf-value for every leaf, where the "key-path" contains a
linked-list representation of the reversed sequence of keys from the
tree-root, with list-nodes being represented by pairs
`(latest_dict_key, rest_path)`, and the empty path being represented
by `()`.
For the sake of genericity, `pytree_iter` is built in such a way that
it actually can handle any kind of traversal of PyTree-trees that do
represent algebraic data types (note however that some do not) -
but for this to make sense, the user must have a way to define how to
interpret tree-nodes, in particular identify leaves. This requires
providing a function `node_handler` with the same signature and
behavior as described below for "node handlers".
Additionally, this module provides mapping-over-pytrees via
`pytree_map`, which is also built in such a way that it makes the
correspondence between an algebraic data type and its Python
nested-tree representation explicit. Despite being powerful and
flexible, this, however, may in general require a bit more effort to
wire up, since node-rebuilding can be fairly nontrivial.
Furthermore, as a prominent application, this module provides a simple
deep-freezing function that translates a nested Python data structure
to deeply-immutable form.
## Concepts and Conventions
"revtuple representation":
As we iterate over a tree, we will have to keep track of the
path-to-tree-root. Naturally, two sibling nodes `n1` and `n2`
will share the same parent-path (being siblings), so it makes
sense to use a linked-list-with-shared-tail representation.
Python does not have a natural notion for that, so we use
recursively-constructed tuples `(node_tag, parent_path)`
that represent the path-from-root in-reverse-order, i.e.
for a non-empty path `p`, `p[0]` is the node-tag at the
deepest nesting level. We call this a "revtuple representation"
of the path. For example, the root-to-leaf key-path ('a', 'b', 'c')
is represented as the revtuple ('c', ('b', ('a', ()))).
"node handler":
A node-handler classifies a tree-node as "leaf or other node", and
for non-leaf nodes provides information about both its children and
how to rebuild it. The behavior of a node-handler function must be
in alignment with this docstring:
'''Processes a tree-node as required by pytree-iteration and -mapping.
Args:
revtuple_path: Revtuple-representation of the path-from-root
to the current node.
node: a tree-node in a ML-model tree that is recursively
built out of leaf-values and other nodes.
Returns:
`None` if the tree-node is to be regarded as a leaf, otherwise
a pair `(rebuilder, iterator)`, where `iterator` iterates
over the data-content of the node, each item represented as a pair
of `(lookup_path_item, value_item)`, and `rebuilder` is a function
which, when applied to an iterable of the aforementioned value-items
(or some transformation thereof), returns a node that is equivalent
to the original (up to that transformation of the contents).
Raises:
InvalidTreeNodeError: If the tree contains a node of a kind
that is not expected to show up.
'''
Examples:
(The behavior of a node-handler is somewhat nontrivial, so covering
two very common cases via examples is in order.)
This node-handler would allow descending into (nested)
instances of `list` (but not subclass instances thereof):
```def list_node_handler(revtuple_path, obj):
''' ... '''
if type(obj) is list:
return list, enumerate(obj)
else:
return None
```
This node-handler would allow descending into (nested) mappings,
which upon rebuilding would get turned into `dict` instances:
```def mapping_node_handler(revtuple_path, obj):
''' ... '''
if isinstance(obj, collections.abc.Mapping):
# For generic mappings, we cannot rely on key- and item-iteration
# being guaranteed to use identical iteration-order.
items = list(obj.items())
keys = [kv[0] for kv in items]
return (lambda values: dict(zip(keys, values))), items
else:
return None
```
A dict/mapping node-handler can of course rename keys, add or remove
entries, make decisions based on the item-path, or map a dict to
an associative list, etc.
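For example, this (hypothetical) node-handler skips entries whose keys
start with an underscore and upper-cases the remaining (string) keys
upon rebuilding:
```def filtering_node_handler(revtuple_path, obj):
  ''' ... '''
  if not isinstance(obj, dict):
    return None
  kept = [(k, v) for k, v in obj.items() if not k.startswith('_')]
  keys = [k.upper() for k, _ in kept]
  return (lambda values: dict(zip(keys, values))), iter(kept)
```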
## Further Design Notes
The `pytree_map` function requires the leaf-transform and node-handler
to be side-effect-free functions. This both preserves
implementation-side flexibility and follows the general LISP
recommendation not to abuse mapping (which should be a pure
data-transformation) for imperative data processing. Overall, if a
need for more general "nested data structures" processing becomes
pressing, it is better if this leads to a proper articulation of the
specific needs, addressed with an appropriate design, rather than the
abuse of functional data-transforms becoming "a bad idiom that turned
into established practice".
"""
import collections.abc
import immutabledict
import numpy
from typing import Any, Callable, Iterable, Iterator, TypeVar
T = TypeVar('T')
U = TypeVar('U')
KT = TypeVar('KT')
NT = TypeVar('NT')
## Type of the reverse-order-keys-to-root path.
# (This code actually illustrates why https://xkcd.com/2483/ is very misguided.)
RevTuplePath = tuple
## Type of the `leaf_transform` function-argument used for tree-iteration.
#
# This would be the correct type we would have to specify here but cannot,
# since the design of Python's static typing at the time of this writing
# is too broken for that:
#
# type LeafTransformFunc[L, R] = Callable[[RevTuplePath, L], R]
#
# Instead, we have to settle for...:
LeafTransformFunc = Callable[[RevTuplePath, Any], Any]
## Type of the `tree_node_handler` function-argument used for
## tree-iteration and tree-mapping.
#
# Again, this is the correct type we would have to put here but cannot:
#
# type NodeHandlerFunc[KT] = (
# Callable[[NT],
# None | tuple[Callable[[Iterable[tuple[KT, NT]]], NT],
# Iterator[tuple[KT, NT]]]])
#
# ...so, we have to instead settle for:
NodeHandlerFunc = (
Callable[[RevTuplePath, NT],
None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT],
Iterator[tuple[Any, NT]]]])
Predicate = Callable[[object], bool]
class InvalidTreeNodeError(ValueError):
"""Encountered a tree-node of invalid type."""
def linearize_revtuple_path(
revtuple_path: RevTuplePath,
present_as: Callable[[Iterator[T]], U] = tuple) -> U:
"""Translates a revtuple path to (typically) linear form.
With default `present_as`, this will map a path of the form
`(key_{N}, (key_{N-1}, ..., (root, ())))` into a tuple
(root, ..., key_{N-1}, key_{N}).
Args:
revtuple_path: A linked-list-as-recursive-pairs
reverse-order tuple-representation of the path.
Path-root is `()`, and node-key `x` relative to
earlier path `p` is represented as `(x, p)`.
present_as: Callable that consumes an iterator over
path-pieces - with the deepest-nesting level coming last -
turning it into a linearized path. Defaults to `tuple`.
Returns:
Linearized presentation of all the node-keys in the
recursive-path in order, deepest-down path component coming last.
"""
pieces = []
todo = revtuple_path
while todo:
node, todo = todo
pieces.append(node)
return present_as(reversed(pieces))
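# Usage sketch (illustrative values only):
#   linearize_revtuple_path(('w', ('layer_0', ('encoder', ()))))
#     -> ('encoder', 'layer_0', 'w')
#   linearize_revtuple_path(('w', ('layer_0', ())), present_as=list)
#     -> ['layer_0', 'w']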
# This function itself has type `NodeHandlerFunc`, but Python does not
# allow us to simply annotate it as such here. We cannot even
# introduce an abbreviation for the complicated output-type,
# since that would have to be parametric in node-type `NT` (and `KT`).
def everything_is_a_leaf_node_handler(
revtuple_path: tuple,
node : NT) -> (
None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT],
Iterator[tuple[Any, NT]]]):
"""Processes a tree-node as required by pytree-iteration and -mapping.
Interface and signature are in alignment with the requirements for a
"node handler" function explained in the module-docstring.
Args:
revtuple_path: the path-to-root for this node.
node: a tree-node.
Returns:
`None`, i.e. classifying any kind of node as a leaf-node.
"""
del revtuple_path, node # Unused.
return None
def leaf_summary(path: RevTuplePath, x: object):
"""Produces a human-readable summary-string for a leaf-node.
Args:
path: revtuple representation of the path-to-root.
x: The leaf-value.
"""
del path # Ignored here.
tx = type(x)
mod = tx.__module__
modname = mod if isinstance(mod, str) else mod.__name__
type_str = f'{modname}.{tx.__qualname__}'
repr_str = repr(x)
repr_abbrev = repr_str if len(repr_str) < 40 else repr_str[:40] + ' ...'
# On str, int, float, etc. `{type_str}(repr(x))` would actually still be
# a (non-literal) Python-expression that would evaluate to the original value.
# However, we make no promises beyond "human-readable".
return f'{type_str}({repr_abbrev})'
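# Example (with the current implementation; the docstring only promises a
# human-readable string): `leaf_summary((), (5678, 9531))` returns
# "builtins.tuple((5678, 9531))".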
# With respect to static type annotations, the limitations of Python's
# approach to static typing really become prominently visible here.
#
# Different arguments have type-parameters, but since there is no way
# to have parametric abbreviations such as `LeafTransformFunc[L, R]`,
# the only way we would have available to express relations between
# type-parameters would be to substitute in the not-abbreviated form of
# `NodeHandlerFunc` and `LeafTransformFunc`, giving us something monstrous.
# We instead here settle for "we cannot express that `tree` must
# have the same type as the input-type to `tree_node_handler` and use `Any`,
# and likewise for leaf_transform and the output.
def pytree_leaf_iter(
tree: Any,
leaf_transform: LeafTransformFunc,
node_handler: NodeHandlerFunc = everything_is_a_leaf_node_handler,
) -> Iterator[Any]:
# ...actual return type would be `Iterator[{what leaf_transform returns}]`.
"""Iterates over the leaves of a tree.
Args:
tree: The tree to iterate over.
leaf_transform: A callable `f` that will get applied
as `f(revtuple_path, leaf)`, where `revtuple_path`
is the revtuple representation of the path to the
leaf from the root.
node_handler: A "node handler" (see module docstring)
that processes nodes encountered during iterative traversal.
Yields:
Value of `leaf_transform(p, x)`, where `x` is the current leaf
and `p` is its revtuple-path to the root.
"""
# Note: Exit points for the code below are in non-obvious places
# and hence marked with " # ***EXIT***".
#
# Doing iteration properly is slightly nontrivial.
# One may be tempted to go for a very simple recursive implementation
# (with an extra pre-final `path` argument to `pytree_leaf_iter`):
#
# maybe_substructure = node_handler(path, tree)
# if maybe_substructure is None:
# # We are looking at a leaf-node.
# yield leaf_transform(path, tree)
# else:
# _, contents_iter = maybe_substructure
# for k, v in contents_iter:
# yield from pytree_leaf_iter(v, leaf_transform, (k, path), node_handler)
#
# That, however, would be flawed, since there is no a priori reason
# why a pytree may not be a very deeply nested structure - such as a
# long linked list. That would then risk raising `RecursionError`,
# and since Python by design(!) does not perform tail call elimination
# or any other kind of advanced CPS transforms, there is no recursive
# solution here. So, to do this properly, we have to proceed iteratively.
#
# We are facing an annoying situation here: If `tree` itself is a leaf,
# we have two options: (a) wrapping it up in a one-node tree
# and processing that, or (b) special-casing "root is a leaf".
# Option (b) leads to some mild node-processing code-duplication
# for a single node (the root).
# Option (a) requires having special cases for node-processing that
# get looked at for every tree node. We go with option (b) here.
maybe_substructure = node_handler((), tree)
if maybe_substructure is None:
# The tree itself is a leaf.
yield leaf_transform((), tree)
return # ***EXIT***
# Otherwise, we are looking at a tree.
_, contents_iter = maybe_substructure
current_revtuple_path = ()
work_to_do = [contents_iter]
# Otherwise-unreachable sentinel for reliably identifying
# iterator-exhaustion without using exceptions:
sentinel = object()
while True:
current_iter = work_to_do[-1]
maybe_next_item = next(current_iter, sentinel)
if maybe_next_item is sentinel:
# We are done at this level.
work_to_do.pop()
if not work_to_do: return # ***EXIT***
current_revtuple_path = current_revtuple_path[1]
else:
path_piece, subtree = maybe_next_item
extended_revtuple_path = (path_piece, current_revtuple_path)
maybe_subtree_substructure = node_handler(extended_revtuple_path, subtree)
if maybe_subtree_substructure is None: # Case: subtree is a leaf.
yield leaf_transform(extended_revtuple_path, subtree)
else: # Case: subtree is a tree.
current_revtuple_path = (path_piece, current_revtuple_path)
_, items_iter = maybe_subtree_substructure
work_to_do.append(items_iter)
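# Usage sketch for `pytree_leaf_iter` (assumes a plain-dict node handler
# along the lines of the `mapping_node_handler` example in the module
# docstring):
#
#   def dict_node_handler(revtuple_path, node):
#     if type(node) is dict:
#       keys = list(node.keys())
#       return (lambda values: dict(zip(keys, values))), iter(node.items())
#     return None
#
#   list(pytree_leaf_iter(dict(a=dict(b=1, c=2), d=3),
#                         lambda path, leaf: leaf,
#                         dict_node_handler))
#   # -> [1, 2, 3]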
# The current design approach here would be appropriate for
# applying leaf-transforms while retaining the structure of the tree -
# which closely corresponds to e.g. a (a -> b) -> (Tree a -> Tree b) functor.
#
# It is not entirely clear whether this is the abstraction that we should
# consider as being appropriately generic to flesh out explicitly - rather
# than starting from a more general approach of which this then is a special
# case. Some background: https://ncatlab.org/nlab/show/recursion+scheme
#
# On the other hand, there is a lot of flexibility via whatever
# node-rebuilder a node-handler produces - this can do quite some reshaping
# of a tree, including dropping or duplicating nodes.
def pytree_map(
tree: Any,
leaf_transform,
node_handler: NodeHandlerFunc = everything_is_a_leaf_node_handler,
):
"""Maps a (potentially nested) Python value to another such value.
Args:
tree: The Python-object to be mapped.
leaf_transform: A callable `f` that will get applied
as `f(revtuple_path, leaf)`, where `revtuple_path`
is the revtuple representation of the path to the
leaf from the root. Must be side effect free.
node_handler: A "node handler" (see module docstring)
that processes nodes encountered during iterative traversal.
Must be side effect free.
Returns:
The outcome of translating `tree`.
"""
# Note: Exit points for the code below are in non-obvious places
# and hence marked with " # ***EXIT***".
#
# Otherwise-inaccessible sentinel object, for reliably identifying
# missing-values via identity-check against sentinel lookup-default.
sentinel = object()
# Code structure mostly follows pytree_leaf_iter.
maybe_substructure = node_handler((), tree)
if maybe_substructure is None:
return leaf_transform((), tree) # ***EXIT***
rebuilder, items_iter = maybe_substructure
current_revtuple_path = ()
# Per-level, we have a triplet of:
# (rebuilder, remaining_items_to_iterate_over, processed).
parts_for_assembly = [(rebuilder, items_iter, [])]
while True:
this_rebuilder, this_items_iter, this_done_pieces = parts_for_assembly[-1]
maybe_next_item = next(this_items_iter, sentinel)
if maybe_next_item is sentinel:
# We are done with all the items for this level.
parts_for_assembly.pop()
built_iter = this_rebuilder(this_done_pieces)
if not parts_for_assembly: # No outer structure, so at-top-level.
return built_iter # ***EXIT***
else: # We have outer structure.
parts_for_assembly[-1][-1].append(built_iter)
current_revtuple_path = current_revtuple_path[1]
continue # ...with next is-the-final-item-complete-check.
else:
# More constituents of the current item.
path_piece, subtree_item = maybe_next_item
extended_revtuple_path = (path_piece, current_revtuple_path)
maybe_subtree_substructure = node_handler(
extended_revtuple_path,
subtree_item)
if maybe_subtree_substructure is None:
this_done_pieces.append(
leaf_transform(extended_revtuple_path, subtree_item))
else:
# We have a subtree.
subtree_rebuilder, subtree_items_iter = maybe_subtree_substructure
current_revtuple_path = (path_piece,
current_revtuple_path)
parts_for_assembly.append(
(subtree_rebuilder, subtree_items_iter, []))
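# Usage sketch for `pytree_map` (reusing the hypothetical
# `dict_node_handler` sketched after `pytree_leaf_iter` above):
#
#   pytree_map(dict(a=dict(b=1), c=2),
#              lambda path, leaf: leaf * 10,
#              dict_node_handler)
#   # -> {'a': {'b': 10}, 'c': 20}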
def deep_freeze(
tree,
*,
is_mapping : Predicate = lambda x: isinstance(x, collections.abc.Mapping),
is_set : Predicate = lambda x: isinstance(x, collections.abc.Set),
is_sequence : Predicate = lambda x: isinstance(x, (list, tuple)),
leaf_fn: Callable[[Any], Any] = lambda x: x,
):
"""Recursively freezes Set/Mapping/List/Tuple structures.
Args:
tree: The potentially deeply-nested object to deep-freeze.
is_mapping: Callable that decides whether a sub-object is a mapping.
Defaults to an `isinstance()` check for `collections.abc.Mapping`.
is_set: Callable that decides whether a sub-object is a set.
Defaults to an `isinstance()` check for `collections.abc.Set`.
is_sequence: Callable that decides whether a sub-object is a sequence.
Defaults to a check for being a `tuple` or `list` instance.
leaf_fn: Function to use for translating non-mapping/set/sequence
instances.
Returns:
Translated-to-deeply-immutable form of `tree`.
"""
idict = immutabledict.immutabledict
def freeze_node_handler(path, x):
if is_set(x):
return frozenset, ((None, y) for y in x)
if is_mapping(x):
# Mapping keys are already hashable, and hence (should be)
# deeply immutable, so we only need to deep-freeze the values.
#
# Note that non-`dict` mappings might not guarantee
# to respect iteration-order, so we have to be careful here:
items = list(x.items())
keys = [kv[0] for kv in items]
values = [kv[1] for kv in items]
return ((lambda ys: idict(zip(keys, ys))),
iter(items))
if is_sequence(x):
return tuple, enumerate(iter(x))
# Otherwise, this should not be traversed.
return None
def leaf_transform(revtuple_path, value):
del revtuple_path # Unused.
return leaf_fn(value)
return pytree_map(tree, leaf_transform, freeze_node_handler)
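# Usage sketch for `deep_freeze` (mirrors the behavior exercised in the
# accompanying unit test):
#
#   frozen = deep_freeze(dict(a=[1, 2, dict(b=(3, {4, 5}))]))
#   # `frozen` is an immutabledict, `frozen['a']` is a tuple, the nested
#   # dict becomes an immutabledict, the set a frozenset, and the whole
#   # structure is hashable:
#   isinstance(hash(frozen), int)  # -> True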

View File

@ -1,168 +0,0 @@
"""Basic tests for 'algebraic data type based pytree' transformations."""
import collections.abc
import sys
import unittest
import pytree_transforms
def _get_deep_pytree(packaging_fn, bottom, depth):
current = bottom
for n in reversed(range(depth)):
current = packaging_fn(n, current)
return current
def _dict_node_handler(p, d):
del p # Unused.
if isinstance(d, dict):
keys = d.keys()
newdict = lambda vals: dict(zip(keys, vals))
return (newdict, iter(d.items()))
else:
return None
class PyTreeTest(unittest.TestCase):
"""Basic correctness validation tests for PyTree transformations."""
def test_linearize_revtuple_path(self):
"""Tests guarantees given by `linearize_revtuple_path`."""
linearize_revtuple_path = pytree_transforms.linearize_revtuple_path
with self.subTest(guarantee='empty'):
self.assertEqual(linearize_revtuple_path(()), ())
with self.subTest(guarantee='typical'):
self.assertEqual(linearize_revtuple_path((30, (20, (10, ())))),
(10, 20, 30))
with self.subTest(guarantee='present_as'):
self.assertEqual(
linearize_revtuple_path(
(30, (20, (10, ()))), present_as=list),
[10, 20, 30])
def test_everything_is_a_leaf_node_handler(self):
"""Tests guarantees given by `everything_is_a_leaf_node_handler`."""
everything_is_a_leaf_node_handler = (
pytree_transforms.everything_is_a_leaf_node_handler)
self.assertEqual(everything_is_a_leaf_node_handler((), 'abc'),
None)
self.assertEqual(everything_is_a_leaf_node_handler(('b', ()),
dict(a=3)),
None)
def test_leaf_summary(self):
"""Tests guarantees given by `leaf_summary`."""
# Since the docstring only guarantees "a human-readable presentation",
# we can and should only do loose checks.
thing = (5678, 9531)
summary = pytree_transforms.leaf_summary(('key', ()), thing)
self.assertIsInstance(summary, str)
self.assertIn(str(thing[0]), summary)
self.assertIn(str(thing[1]), summary)
def test_pytree_leaf_iter(self):
"""Tests guarantees given by `pytree_leaf_iter`."""
pytree_leaf_iter = pytree_transforms.pytree_leaf_iter
def leaf_transform(path, leaf):
return repr(leaf) if path and path[0].startswith('R') else leaf
with self.subTest(guarantee='returns_iterator'):
result = pytree_leaf_iter(7, leaf_transform, _dict_node_handler)
self.assertIsInstance(result, collections.abc.Iterator)
with self.subTest(guarantee='totally_empty'):
result = list(pytree_leaf_iter({}, leaf_transform, _dict_node_handler))
self.assertEqual(result, [])
with self.subTest(guarantee='no_leaves'):
result = list(pytree_leaf_iter(dict(a={}),
leaf_transform, _dict_node_handler))
self.assertEqual(result, [])
with self.subTest(guarantee='is_leaf'):
result = list(pytree_leaf_iter(777, leaf_transform, _dict_node_handler))
self.assertEqual(result, [777])
with self.subTest(guarantee='generic'):
result = list(pytree_leaf_iter(
dict(n0=dict(n01=dict(n012=1002,
n013=1003,
Rn014=1004,
),
n02=1005),
n5=1006),
leaf_transform, _dict_node_handler))
self.assertEqual(result, [1002, 1003, '1004', 1005, 1006])
with self.subTest(guarantee='with_keys'):
result = list(pytree_leaf_iter(
dict(n0=dict(n01=dict(n012=1002,
n013=1003)),
n1=1004),
lambda p, s: (pytree_transforms.linearize_revtuple_path(p), s),
_dict_node_handler))
self.assertEqual(result,
[(('n0', 'n01', 'n012'), 1002),
(('n0', 'n01', 'n013'), 1003),
(('n1',), 1004)])
def test_pytree_map(self):
"""Tests guarantees given by `pytree_map`."""
pytree_map = pytree_transforms.pytree_map
leaf_transform = lambda p, s: repr(s)
tree1 = dict(t0=dict(t10=1001,
t11=dict(t110=1002,
t111=1003),
t12=dict(t120=1004,
t121=1005,
t122=1006)),
t1=1007)
with self.subTest(guarantee='no_leaves'):
result = pytree_map(dict(a={}),
leaf_transform,
_dict_node_handler)
self.assertEqual(result, dict(a={}))
with self.subTest(guarantee='is_leaf'):
result = pytree_map(777, leaf_transform, _dict_node_handler)
self.assertEqual(result, '777')
with self.subTest(guarantee='generic'):
result = pytree_map(tree1, leaf_transform, _dict_node_handler)
self.assertEqual(result['t0']['t10'], '1001')
def test_deeply_nested(self):
"""Tests correct behavior on deeply-nested data structures."""
pytree_leaf_iter = pytree_transforms.pytree_leaf_iter
pytree_map = pytree_transforms.pytree_map
#
depth = max(10**5, sys.getrecursionlimit() + 100)
deep_tree = _get_deep_pytree(lambda n, t: {n: t},
'leaf', depth)
with self.subTest(function='pytree_leaf_iter'):
leaves = list(pytree_leaf_iter(deep_tree,
lambda p, s: s.upper(),
_dict_node_handler))
self.assertEqual(leaves, ['LEAF'])
with self.subTest(function='pytree_map'):
mapped_deep_tree = pytree_map(deep_tree,
lambda p, s: s,
_dict_node_handler)
self.assertIsInstance(mapped_deep_tree, dict)
with self.subTest(function='combined'):
leaves = list(
pytree_leaf_iter(
pytree_map(deep_tree,
lambda p, s: s.capitalize(),
_dict_node_handler),
lambda p, s: s + s,
_dict_node_handler))
self.assertEqual(leaves, ['LeafLeaf'])
def test_deep_freeze(self):
"""Tests guarantees given by `deep_freeze`."""
frozen = pytree_transforms.deep_freeze(
dict(a=[1001, 1002, dict(b=(1003, [1004, {1005, 1006}]))]))
self.assertIsInstance(frozen, collections.abc.Mapping)
self.assertNotIsInstance(frozen, collections.abc.MutableMapping)
self.assertIsInstance(frozen['a'], tuple)
# `frozen` is hashable, and hashes to an integer.
self.assertIsInstance(hash(frozen), int)
if __name__ == '__main__':
unittest.main()

View File

@ -1,4 +0,0 @@
immutabledict>=4.2.0
numpy>=1.26.4
orbax-checkpoint>=0.0.0

View File

@ -158,9 +158,6 @@ TEST_F(GemmaTest, CrossEntropySmall) {
float entropy = s_env->CrossEntropy(kSmall);
fprintf(stderr, "per-token entropy: %f\n", entropy);
switch (config.model) {
case gcpp::Model::GRIFFIN_2B:
EXPECT_NEAR(entropy, 2.61f, 0.02f);
break;
case gcpp::Model::GEMMA2_2B:
EXPECT_NEAR(entropy, 1.14f, 0.02f);
break;

View File

@ -31,32 +31,6 @@
namespace gcpp {
struct GriffinActivations {
GriffinActivations(const ModelConfig& config, size_t batch_size,
const Allocator& allocator)
: griffin_x(
MatFactory("griffin_x", batch_size, config.model_dim, allocator)),
griffin_y(
MatFactory("griffin_y", batch_size, config.model_dim, allocator)),
griffin_gate_x(MatFactory("griffin_gate_x", batch_size,
config.model_dim, allocator)),
griffin_multiplier(MatFactory("griffin_mul", batch_size,
config.model_dim, allocator)) {}
void SetBatchSize(size_t batch_size) {
if (griffin_x.Rows() == 0) return;
griffin_x.OverrideRows(batch_size);
griffin_y.OverrideRows(batch_size);
griffin_gate_x.OverrideRows(batch_size);
griffin_multiplier.OverrideRows(batch_size);
}
MatStorageT<float> griffin_x;
MatStorageT<float> griffin_y;
MatStorageT<float> griffin_gate_x;
MatStorageT<float> griffin_multiplier;
};
struct AttentionActivations {
// Returns the scale value to use for the query in the attention computation.
// Also called by ops_test.
@ -143,7 +117,7 @@ struct AttentionActivations {
MatStorageT<float> inv_timescale_global;
hwy::Divisor div_seq_len;
// Unfortunately, some models (Griffin) have non-power-of-two heads.
// Unfortunately, some models have had non-power-of-two heads.
hwy::Divisor div_heads;
float query_scale;
};
@ -169,9 +143,7 @@ struct Activations {
MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)),
attention(config, layer_config, batch_size, seq_len, ctx.allocator,
row_ptrs),
griffin(config, config.model == Model::GRIFFIN_2B ? batch_size : 0,
ctx.allocator) {
row_ptrs) {
HWY_ASSERT(batch_size != 0);
// For MatMul outputs, precompute their row pointers.
@ -199,7 +171,6 @@ struct Activations {
ffw_out.OverrideRows(batch_size);
attention.SetBatchSize(batch_size);
griffin.SetBatchSize(batch_size);
}
const LayerConfig& layer_config;
@ -215,7 +186,6 @@ struct Activations {
MatStorageT<float> ffw_out;
AttentionActivations attention;
GriffinActivations griffin;
};
} // namespace gcpp

View File

@ -327,6 +327,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
MatMulEnv& env) {
PROFILER_ZONE("Gen.Attention.SumHeads");
const LayerConfig& layer_config = layer.layer_config;
(void)layer_config; // For HWY_DASSERT
// att_weights and att_out are concatenated heads, each of length
// layer_config.qkv_dim. Thus the [num_interleaved,
// layer_config.model_dim] matmul output is the sum over heads. Compare
@ -334,10 +335,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
// encoded)
HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 &&
layer_config.qkv_dim != 0);
const float* add = layer_config.softmax_attn_output_biases
? layer.attention_output_biases.PackedScale1()
: nullptr;
CallMatMul(activations.att_out, layer.att_weights, add, env,
CallMatMul(activations.att_out, layer.att_weights, /*add=*/nullptr, env,
activations.att_sums);
}

View File

@ -133,78 +133,6 @@ static ModelConfig ConfigGemma2_2B() {
return config;
}
static LayerConfig LayerConfigGemmaTiny(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
config.ff_hidden_dim = 256;
config.heads = 4;
config.kv_heads = 1;
config.qkv_dim = 16;
return config;
}
static ModelConfig ConfigGemmaTiny() {
ModelConfig config = ConfigNoSSM();
config.display_name = "GemmaTiny";
config.model = Model::GEMMA_TINY;
config.wrapping = PromptWrapping::GEMMA_IT;
config.model_dim = 32;
config.vocab_size = 32; // at least two f32 vectors
config.max_seq_len = 32;
LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim);
config.num_layers = 2;
config.layer_configs = {config.num_layers, layer_config};
config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
config.att_cap = 50.0f;
config.final_cap = 30.0f;
config.eos_id = 11;
config.secondary_eos_id = 11;
return config;
}
static LayerConfig LayerConfigGriffin2B(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
config.griffin_dim = model_dim;
config.ff_hidden_dim = 7680;
config.heads = 10;
config.kv_heads = 1;
config.qkv_dim = 256;
config.conv1d_width = 4;
HWY_DASSERT(config.conv1d_width <= kMaxConv1DWidth);
config.ff_biases = true;
config.softmax_attn_output_biases = true;
config.optimized_gating = false;
config.type = LayerAttentionType::kGriffinRecurrentBlock;
config.activation = ActivationType::Gelu;
config.post_qk = PostQKType::HalfRope;
return config;
}
static ModelConfig ConfigGriffin2B() {
ModelConfig config = ConfigNoSSM();
config.display_name = "Griffin2B";
config.model = Model::GRIFFIN_2B;
// Griffin uses local attention, so max_seq_len is actually the local
// attention window.
config.model_dim = 2560;
config.vocab_size = kVocabSize;
config.max_seq_len = 2048;
LayerConfig layer_config = LayerConfigGriffin2B(config.model_dim);
config.num_layers = 26;
config.layer_configs = {config.num_layers, layer_config};
for (size_t i = 2; i < config.num_layers; i += 3) {
config.layer_configs[i].type = LayerAttentionType::kGemma;
config.layer_configs[i].griffin_dim = 0;
}
config.attention_window_sizes =
FixedAttentionWindowSizes<26>(config.max_seq_len);
config.use_local_attention = true;
config.final_cap = 0.0f;
return config;
}
static LayerConfig LayerConfigVit(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
@ -510,10 +438,6 @@ static ModelConfig ConfigFromModel(Model model) {
return ConfigGemma2_9B();
case Model::GEMMA2_27B:
return ConfigGemma2_27B();
case Model::GRIFFIN_2B:
return ConfigGriffin2B();
case Model::GEMMA_TINY:
return ConfigGemmaTiny();
case Model::PALIGEMMA2_3B_224:
return ConfigPaliGemma2_3B_224();
case Model::PALIGEMMA2_3B_448:
@ -547,10 +471,6 @@ const char* ModelPrefix(Model model) {
return "9b";
case Model::GEMMA2_27B:
return "27b";
case Model::GRIFFIN_2B:
return "gr2b";
case Model::GEMMA_TINY:
return "tiny";
case Model::PALIGEMMA2_3B_224:
return "paligemma2-3b-224";
case Model::PALIGEMMA2_3B_448:
@ -750,13 +670,10 @@ bool ModelConfig::OverwriteWithCanonical() {
Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
switch (layers) {
case 2:
return Model::GEMMA_TINY;
case 18:
return Model::GEMMA3_270M;
case 26:
if (layer_types & kDeducedGriffin) return Model::GRIFFIN_2B;
if (layer_types & kDeducedViT) return Model::GEMMA3_1B;
return Model::GEMMA2_2B;
case 27:

View File

@ -68,14 +68,11 @@ static inline bool EnumValid(PromptWrapping wrapping) {
enum class LayerAttentionType {
kGemma,
kGriffinRecurrentBlock,
kVit,
};
static inline bool EnumValid(LayerAttentionType type) {
return type == LayerAttentionType::kGemma ||
type == LayerAttentionType::kGriffinRecurrentBlock ||
type == LayerAttentionType::kVit;
return type == LayerAttentionType::kGemma || type == LayerAttentionType::kVit;
}
// Post attention and ffw normalization type.
@ -163,9 +160,8 @@ enum class Model {
// 1 and 2 are obsolete.
GEMMA2_9B = 3,
GEMMA2_27B,
GRIFFIN_2B,
GEMMA_TINY, // for testing only
GEMMA2_2B,
// 5 and 6 are obsolete.
GEMMA2_2B = 7,
// 8 and 9 are obsolete.
PALIGEMMA2_3B_224 = 10,
PALIGEMMA2_3B_448,
@ -199,13 +195,19 @@ static inline bool IsPaliGemma(Model model) {
return false;
}
static inline bool IsObsolete(Model model) {
const size_t i = static_cast<size_t>(model);
if (i == 5 || i == 6 || i == 8 || i == 9) return true;
return false;
}
// Visits every valid model enum, skipping `UNKNOWN` and `kSentinel`.
template <class Func>
void ForEachModel(const Func& func) {
for (size_t i = static_cast<size_t>(Model::GEMMA2_9B);
i < static_cast<size_t>(Model::kSentinel); ++i) {
if (i == 8 || i == 9) continue;
func(static_cast<Model>(i));
const Model model = static_cast<Model>(i);
if (!IsObsolete(model)) func(model);
}
}
@ -214,7 +216,7 @@ static inline bool EnumValid(Model model) {
if (model == Model::UNKNOWN) return true;
const size_t i = static_cast<size_t>(model);
if (i >= static_cast<size_t>(Model::GEMMA2_9B) &&
i < static_cast<size_t>(Model::kSentinel) && i != 8 && i != 9) {
i < static_cast<size_t>(Model::kSentinel) && !IsObsolete(model)) {
return true;
}
return false;
@ -235,15 +237,20 @@ struct LayerConfig : public IFields {
// Source of truth for field ordering.
void VisitFields(IFieldsVisitor& visitor) override {
// Formerly used for Griffin.
uint32_t unused_griffin_dim = 0;
uint32_t unused_conv1d_width = 0;
bool unused_softmax_attn_output_biases = false;
visitor(model_dim);
visitor(griffin_dim);
visitor(unused_griffin_dim);
visitor(ff_hidden_dim);
visitor(heads);
visitor(kv_heads);
visitor(qkv_dim);
visitor(conv1d_width);
visitor(unused_conv1d_width);
visitor(ff_biases);
visitor(softmax_attn_output_biases);
visitor(unused_softmax_attn_output_biases);
visitor(optimized_gating);
visitor(post_norm);
visitor(type);
@ -263,14 +270,11 @@ struct LayerConfig : public IFields {
bool IsMHA() const { return heads == kv_heads; }
uint32_t model_dim = 0;
uint32_t griffin_dim = 0;
uint32_t ff_hidden_dim = 0;
uint32_t heads = 0;
uint32_t kv_heads = 0;
uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous).
uint32_t conv1d_width = 0; // Griffin only
uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous).
bool ff_biases = false;
bool softmax_attn_output_biases = false; // for Griffin
bool optimized_gating = true; // for Gemma3
PostNormType post_norm = PostNormType::None;
LayerAttentionType type = LayerAttentionType::kGemma;
@ -358,7 +362,8 @@ struct ModelConfig : public IFields {
visitor(final_cap);
visitor(absolute_pe);
visitor(use_local_attention);
bool unused_use_local_attention = false; // formerly used for Griffin
visitor(unused_use_local_attention);
visitor(query_scale);
visitor(layer_configs);
visitor(attention_window_sizes);
@ -454,7 +459,6 @@ struct ModelConfig : public IFields {
float final_cap = 0.0f;
bool absolute_pe = false;
bool use_local_attention = false; // Griffin only
QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
std::vector<LayerConfig> layer_configs;
std::vector<uint32_t> attention_window_sizes;
@ -478,7 +482,6 @@ struct ModelConfig : public IFields {
ModelConfig GetVitConfig(const ModelConfig& config);
enum DeducedLayerTypes {
kDeducedGriffin = 1,
kDeducedViT = 2,
kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224.
};

View File

@ -34,7 +34,6 @@
// After highway.h
#include "gemma/attention.h" // includes highway.h
#include "gemma/gemma-inl.h"
#include "gemma/griffin.h" // includes highway.h
#include "gemma/vit.h" // includes highway.h
#ifndef GEMMA_CC_ONCE
@ -77,14 +76,6 @@ void Attention(LayerAttentionType type, const size_t num_tokens,
GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
env,
/*flags=*/0);
} else {
HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
// KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer,
// so map `layer` to the Griffin layer index.
const size_t griffin_layer =
activations.attention.config.NumLayersOfTypeBefore(type, layer_idx);
GriffinRecurrent(num_tokens, griffin_layer, &layer, activations, qbatch,
env);
}
}
@ -484,13 +475,6 @@ static void GenerateT(const ModelConfig& config,
const AesCtrEngine& engine, const WeightsPtrs& weights,
Activations& activations, QBatch& qbatch, MatMulEnv& env,
TimingInfo& timing_info) {
// Griffin assumes that the recurrent block cache is zero-initialized.
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
if (qbatch.MutablePos(qi) == 0) {
qbatch.KV(qi).ZeroGriffinCache(); // No-op for non-Griffin models.
}
}
size_t max_prompt_size = 0;
bool all_prefix_end_are_zero = true;
size_t total_prefill_tokens = 0; // only for throughput stats.

View File

@ -1,192 +0,0 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stddef.h>
#include <stdint.h>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
#include "gemma/activations.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/weights.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/griffin.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "ops/matvec-inl.h"
#include "ops/ops-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
const LayerWeightsPtrs* layer_weights,
Activations& activations, QBatch& qbatch,
MatMulEnv& env) {
PROFILER_ZONE("Gen.Griffin");
hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const D df;
const size_t model_dim = layer_weights->layer_config.model_dim;
HWY_DASSERT(model_dim % hn::Lanes(df) == 0);
const size_t heads = layer_weights->layer_config.heads;
const size_t conv_1d_width = layer_weights->layer_config.conv1d_width;
HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
const size_t kHeadDim = model_dim / heads;
const size_t kMatrixSize = kHeadDim * kHeadDim;
const size_t num_interleaved = num_tokens * qbatch.Size();
const hwy::Divisor div_qbatch(static_cast<uint32_t>(qbatch.Size()));
GriffinActivations& griffin = activations.griffin;
// X / Y linear layers.
// TODO: MatMul
HWY_DASSERT(griffin.griffin_y.Rows() == griffin.griffin_x.Rows());
HWY_DASSERT(num_interleaved == griffin.griffin_y.Rows());
CallUpcastedSame(
&layer_weights->griffin.linear_x_w, &layer_weights->griffin.linear_y_w,
[&](const auto* wx, const auto* wy) {
for (size_t r = 0; r < num_interleaved; ++r) {
float* HWY_RESTRICT y = griffin.griffin_y.Row(r);
float* HWY_RESTRICT x = griffin.griffin_x.Row(r);
TwoMatVecAdd(
*wx, *wy, 0, model_dim, model_dim,
activations.attention.pre_att_rms_out.Row(r),
/*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
/*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
/*out0=*/x, /*out1=*/y, pool);
Gelu(y, model_dim);
}
});
// Conv1D.
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const size_t qi = div_qbatch.Remainder(interleaved_idx);
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
const size_t pos = qbatch.Pos(qi) + batch_idx;
float* HWY_RESTRICT x = griffin.griffin_x.Row(qi);
// cache[i] = input at time t-i.
float* HWY_RESTRICT cache[kMaxConv1DWidth];
cache[0] = x;
for (size_t i = 1; i < conv_1d_width; i++) {
cache[i] =
qbatch.KV(qi).conv1d_cache.Row(griffin_layer) +
((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
}
for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
auto xv = hn::Load(df, x + i);
auto accum0 =
hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i);
auto accum1 = hn::Zero(df);
for (size_t l = 0; 2 * l < conv_1d_width; l++) {
auto wv0 =
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
(conv_1d_width - 1 - 2 * l) * model_dim + i);
auto wv1 =
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
(conv_1d_width - 2 - 2 * l) * model_dim + i);
accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
}
hn::Store(hn::Add(accum0, accum1), df, x + i);
hn::Store(xv, df, cache[HWY_MAX(conv_1d_width, 1) - 1] + i);
}
}
// RGLRU
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const size_t qi = div_qbatch.Remainder(interleaved_idx);
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
const size_t pos = qbatch.Pos(qi) + batch_idx;
float* HWY_RESTRICT x = griffin.griffin_x.Row(qi);
float* HWY_RESTRICT y = griffin.griffin_y.Row(qi);
float* HWY_RESTRICT gate_x = griffin.griffin_gate_x.Row(qi);
float* HWY_RESTRICT a = griffin.griffin_multiplier.Row(qi);
float* HWY_RESTRICT rnn_state =
qbatch.KV(qi).rglru_cache.Row(griffin_layer);
pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
size_t head_offset = head * kHeadDim;
CallUpcasted(&layer_weights->griffin.gate_w, [&](const auto* gate_w) {
TwoOfsMatVecAddLoop(
*gate_w, kMatrixSize * head, kMatrixSize * (heads + head), kHeadDim,
kHeadDim, x + head_offset,
/*add0=*/layer_weights->griffin.gate_biases.PackedScale1() +
head_offset,
/*add1=*/layer_weights->griffin.gate_biases.PackedScale1() +
model_dim + head_offset,
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
});
Sigmoid(gate_x + head_offset, kHeadDim);
Sigmoid(a + head_offset, kHeadDim);
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
HWY_ATTR { return hn::Mul(x, gate_x); };
hn::Transform1(D(), a + head_offset, kHeadDim,
layer_weights->griffin.a.PackedScale1() + head_offset,
fn_mul);
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
fn_mul);
// RNN scan
HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0);
for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) {
auto log_a = hn::Load(df, a + head_offset + i);
auto gated_x = hn::Load(df, x + head_offset + i);
auto rnn = hn::Load(df, rnn_state + head_offset + i);
auto a = hn::Exp(df, log_a);
auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0f)));
if (pos == 0) {
x_multiplier = hn::Set(df, 1.0f);
}
auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
hn::Store(new_x, df, rnn_state + head_offset + i);
// Join branches.
auto yv = hn::Load(df, y + head_offset + i);
auto pre_out = hn::Mul(yv, new_x);
hn::Store(pre_out, df, x + head_offset + i);
}
});
} // interleaved_idx
// Final linear layer.
CallMatMul(griffin.griffin_x, layer_weights->griffin.linear_out_w,
layer_weights->griffin.linear_out_biases.PackedScale1(), env,
activations.attention.att_sums);
} // GriffinRecurrent
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();

View File

@ -1,47 +0,0 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_
// Declares GriffinRecurrent for all SIMD targets.
#include <stddef.h>
#include "gemma/gemma.h"
#include "hwy/highway.h"
namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_GRIFFIN(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void GriffinRecurrent(size_t num_tokens, size_t griffin_layer, \
const LayerWeightsPtrs* layer_weights, \
Activations& activations, QBatch& qbatch, \
MatMulEnv& env); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE
// Function declarations for each SIMD target. Allows direct call from the
// per-target namespace. We may later replace this with dynamic dispatch if
// the overhead is acceptable.
HWY_VISIT_TARGETS(GEMMA_DECL_GRIFFIN)
#undef GEMMA_DECL_GRIFFIN
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_

View File

@ -24,26 +24,6 @@
namespace gcpp {
void KVCache::ZeroGriffinCache() {
if (conv1d_cache.Rows() == 0) return;
ZeroInit(conv1d_cache);
ZeroInit(rglru_cache);
}
static size_t GriffinLayers(const ModelConfig& config) {
return config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock);
}
static size_t GriffinConv1dCols(const ModelConfig& config) {
size_t conv1d_width = 0;
for (const auto& layer_config : config.layer_configs) {
conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width);
}
// The row offset, in blocks of model_dim is computed mod (conv1d_width - 1),
// hence allocate conv1d_width * model_dim total columns.
return conv1d_width * config.model_dim;
}
// Number of rows for KV cache. Note that both rows and cols are u32, and
// the total number of elements can exceed 2^32.
static size_t CappedSeqLen(const ModelConfig& config,
@ -56,30 +36,18 @@ static size_t CappedSeqLen(const ModelConfig& config,
return inference_args.seq_len;
}
KVCache::KVCache(const Extents2D& conv1d_extents,
const Extents2D& rglru_extents, const Extents2D& kv_extents,
const Allocator& allocator)
: conv1d_cache("conv1d_cache", conv1d_extents, allocator, MatPadding::kOdd),
rglru_cache("rglru_cache", rglru_extents, allocator, MatPadding::kOdd),
kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator)
: kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
allocator_(allocator) {}
KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const Allocator& allocator)
: KVCache(
Extents2D(GriffinLayers(config), GriffinConv1dCols(config)),
Extents2D(GriffinLayers(config), config.model_dim),
Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()),
allocator) {}
KVCache KVCache::Copy() {
KVCache copy(conv1d_cache.Extents(), rglru_cache.Extents(),
kv_cache.Extents(), allocator_);
if (conv1d_cache.Rows() != 0) {
CopyMat(conv1d_cache, copy.conv1d_cache);
CopyMat(rglru_cache, copy.rglru_cache);
}
KVCache copy(kv_cache.Extents(), allocator_);
CopyMat(kv_cache, copy.kv_cache);

View File

@ -35,24 +35,15 @@ struct KVCache {
// copy ctor to make the cost explicit.
KVCache Copy();
// Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache
// and rglru_cache.
void ZeroGriffinCache();
size_t SeqLen() const { return kv_cache.Rows(); }
// [griffin_layers, griffin_conv1d_cols * model_dim]
MatStorageT<float> conv1d_cache;
MatStorageT<float> rglru_cache; // [griffin_layers, model_dim]
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
private:
const Allocator& allocator_;
// For use by other ctor and Copy()
KVCache(const Extents2D& conv1d_extents, const Extents2D& rglru_extents,
const Extents2D& kv_extents, const Allocator& allocator);
KVCache(const Extents2D& kv_extents, const Allocator& allocator);
};
} // namespace gcpp

View File

@ -221,9 +221,6 @@ static int DeduceLayerTypes(const BlobReader& reader) {
int layer_types = 0;
for (size_t key_idx = 0; key_idx < reader.Keys().size(); ++key_idx) {
const std::string& key = reader.Keys()[key_idx];
if (key.find("gr_conv_w") != std::string::npos) { // NOLINT
return kDeducedGriffin;
}
if (key.find("qkv_ein_w") != std::string::npos) { // NOLINT
layer_types |= kDeducedViT;
}
@ -293,7 +290,7 @@ static std::vector<float> ReadScales(BlobReader& reader,
const ModelConfig& config) {
std::vector<float> scales;
// Check first to prevent `CallWithSpan` from printing a warning. This blob is
// optional even in pre-2025 format; Griffin was the first to include it.
// optional even in pre-2025 format.
if (reader.Find(kDecoratedScalesName)) {
HWY_ASSERT(reader.CallWithSpan<float>(
kDecoratedScalesName,

View File

@ -277,122 +277,6 @@ void TensorInfoRegistry::AddImageLayerTensors(const ModelConfig& config,
});
}
void TensorInfoRegistry::AddGriffinLayerTensors(const LayerConfig& layer_config,
const size_t layer_idx) {
const std::string suffix = LayerSuffix(layer_idx);
Add(suffix, {
.base_name = "gr_lin_x_w",
.source_names = {"recurrent_block/linear_x/kernel"},
.axes = {1, 0},
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
});
Add(suffix, {
.base_name = "gr_lin_x_b",
.source_names = {"recurrent_block/linear_x/bias"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr_lin_y_w",
.source_names = {"recurrent_block/linear_y/kernel"},
.axes = {1, 0},
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
});
Add(suffix, {
.base_name = "gr_lin_y_b",
.source_names = {"recurrent_block/linear_y/bias"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr_lin_out_w",
.source_names = {"recurrent_block/linear_out/kernel"},
.axes = {1, 0},
.shape = {layer_config.griffin_dim, layer_config.griffin_dim},
});
Add(suffix, {
.base_name = "gr_lin_out_b",
.source_names = {"recurrent_block/linear_out/bias"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
});
Add(suffix,
{
.base_name = "gr_conv_w",
.source_names = {"recurrent_block/conv_1d/w"},
.axes = {0, 1},
.shape = {layer_config.conv1d_width, layer_config.griffin_dim},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr_conv_b",
.source_names = {"recurrent_block/conv_1d/b"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr1_gate_w",
.source_names = {"recurrent_block/rg_lru/input_gate/w"},
.axes = {0, 2, 1},
.shape = {layer_config.heads,
layer_config.griffin_dim / layer_config.heads,
layer_config.griffin_dim / layer_config.heads},
.concat_names = {"gr_gate_w", "gr2_gate_w"},
});
Add(suffix, {
.base_name = "gr2_gate_w",
.source_names = {"recurrent_block/rg_lru/a_gate/w"},
.axes = {0, 2, 1},
.shape = {layer_config.heads,
layer_config.griffin_dim / layer_config.heads,
layer_config.griffin_dim / layer_config.heads},
.concat_names = {""},
});
Add(suffix, {
.base_name = "gr_gate_w",
.source_names = {"recurrent_block/rg_lru/gate/w"},
.axes = {0, 2, 1},
.shape = {2 * layer_config.heads,
layer_config.griffin_dim / layer_config.heads,
layer_config.griffin_dim / layer_config.heads},
});
Add(suffix, {
.base_name = "gr1_gate_b",
.source_names = {"recurrent_block/rg_lru/input_gate/b"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.concat_names = {"gr_gate_b", "gr2_gate_b"},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr2_gate_b",
.source_names = {"recurrent_block/rg_lru/a_gate/b"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.concat_names = {""},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr_gate_b",
.source_names = {"recurrent_block/rg_lru/input_gate/b"},
.axes = {0, 1},
.shape = {2 * layer_config.griffin_dim},
.min_size = Type::kF32,
});
Add(suffix, {
.base_name = "gr_a",
.source_names = {"recurrent_block/rg_lru/a_param"},
.axes = {0},
.shape = {layer_config.griffin_dim},
.min_size = Type::kF32,
.scaled_softplus = true,
});
}
void TensorInfoRegistry::AddLayerTensors(const ModelConfig& config,
const LayerConfig& layer_config,
const size_t layer_idx) {
@ -553,10 +437,6 @@ void TensorInfoRegistry::AddLayerTensors(const ModelConfig& config,
.shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
.cols_take_extra_dims = true,
});
if (config.model == Model::GRIFFIN_2B) {
AddGriffinLayerTensors(layer_config, layer_idx);
}
}
TensorInfoRegistry::TensorInfoRegistry(const ModelConfig& config) {

View File

@ -124,8 +124,6 @@ class TensorInfoRegistry {
void AddModelTensors(const ModelConfig& config);
void AddLayerTensors(const ModelConfig& config,
const LayerConfig& layer_config, size_t layer_idx);
void AddGriffinLayerTensors(const LayerConfig& layer_config,
size_t layer_idx);
void AddImageLayerTensors(const ModelConfig& config,
const LayerConfig& layer_config,

View File

@ -88,7 +88,7 @@ void LayerWeightsPtrs::InitAttWeights(std::vector<MatOwner>& mat_owners,
// For FFN. Fast, only updates pointers.
void LayerWeightsPtrs::SplitW1() {
// Used for Gemma and Griffin layers; FFWVit uses different tensors.
// Used for Gemma layers; FFWVit uses different tensors.
if (layer_config.type == LayerAttentionType::kVit) return;
// Files have both or neither of w1 and w2.

View File

@ -57,8 +57,7 @@ struct TensorArgs {
// the _w1/_w2 tensors are not always present.
kMaybeRead = 1,
// Avoid padding tensor rows when reading. Used for some Griffin tensors
// whose index computations do not use Row() accessors.
// Avoid padding tensor rows when reading.
kPacked = 2,
};
const int flags;
@ -102,17 +101,6 @@ struct LayerWeightsPtrs {
qkv_einsum_w1(finder_("qkv1_w")),
qkv_einsum_w2(finder_("qkv2_w")),
attention_output_biases(finder_("attn_ob")),
griffin({.linear_x_w = finder_("gr_lin_x_w"),
.linear_x_biases = finder_("gr_lin_x_b"),
.linear_y_w = finder_("gr_lin_y_w"),
.linear_y_biases = finder_("gr_lin_y_b"),
.linear_out_w = finder_("gr_lin_out_w"),
.linear_out_biases = finder_("gr_lin_out_b"),
.conv_w = finder_("gr_conv_w"),
.conv_biases = finder_("gr_conv_b"),
.gate_w = finder_("gr_gate_w"),
.gate_biases = finder_("gr_gate_b"),
.a = finder_("gr_a")}),
// MultiHeadDotProductAttention.
vit({.attn_out_w = finder_("attn_out_w"),
.attn_out_b = finder_("attn_out_b"),
@ -156,20 +144,6 @@ struct LayerWeightsPtrs {
MatPtr qkv_einsum_w2;
MatPtrT<float> attention_output_biases;
struct {
MatPtr linear_x_w;
MatPtrT<float> linear_x_biases;
MatPtr linear_y_w;
MatPtrT<float> linear_y_biases;
MatPtr linear_out_w;
MatPtrT<float> linear_out_biases;
MatPtrT<float> conv_w;
MatPtrT<float> conv_biases;
MatPtr gate_w;
MatPtrT<float> gate_biases;
MatPtrT<float> a;
} griffin;
struct {
// MultiHeadDotProductAttention.
MatPtr attn_out_w; // at least BF16.
@ -244,20 +218,6 @@ struct LayerWeightsPtrs {
func(TENSOR_ARGS(qkv_einsum_w, kMaybeRead));
func(TENSOR_ARGS(qkv_einsum_w1, kMaybeRead));
func(TENSOR_ARGS(qkv_einsum_w2, kMaybeRead));
} else {
func(TENSOR_ARGS(griffin.linear_x_w, kMustRead));
func(TENSOR_ARGS(griffin.linear_x_biases, kMustRead));
func(TENSOR_ARGS(griffin.linear_y_w, kMustRead));
func(TENSOR_ARGS(griffin.linear_y_biases, kMustRead));
func(TENSOR_ARGS(griffin.linear_out_w, kMustRead));
func(TENSOR_ARGS(griffin.linear_out_biases, kMustRead));
// conv_w and gate_w are not accessed via Row(), hence must not be padded.
// Note that *biases are 1D, hence packing/padding does not matter.
func(TENSOR_ARGS(griffin.conv_w, kMustRead | TensorArgs::kPacked));
func(TENSOR_ARGS(griffin.conv_biases, kMustRead));
func(TENSOR_ARGS(griffin.gate_w, kMustRead | TensorArgs::kPacked));
func(TENSOR_ARGS(griffin.gate_biases, kMustRead));
func(TENSOR_ARGS(griffin.a, kMustRead));
}
{
func(TENSOR_ARGS(gating_einsum_w, kMaybeRead));
@ -281,11 +241,6 @@ struct LayerWeightsPtrs {
func(TENSOR_ARGS(ffw_gating_biases, kMustRead));
func(TENSOR_ARGS(ffw_output_biases, kMustRead));
}
if (layer_config.softmax_attn_output_biases &&
layer_config.type == LayerAttentionType::kGemma) {
func(TENSOR_ARGS(attention_output_biases, kMustRead));
}
} // `ForEachTensor`
// Zero-initializes all allocated tensors in the layer.

View File

@ -57,8 +57,6 @@ PYBIND11_MODULE(configs, py_module) {
enum_<LayerAttentionType>(py_module, "LayerAttentionType")
.value("kGemma", LayerAttentionType::kGemma)
.value("kGriffinRecurrentBlock",
LayerAttentionType::kGriffinRecurrentBlock)
.value("kVit", LayerAttentionType::kVit);
enum_<PostNormType>(py_module, "PostNormType")
@ -84,8 +82,6 @@ PYBIND11_MODULE(configs, py_module) {
.value("UNKNOWN", Model::UNKNOWN)
.value("GEMMA2_9B", Model::GEMMA2_9B)
.value("GEMMA2_27B", Model::GEMMA2_27B)
.value("GRIFFIN_2B", Model::GRIFFIN_2B)
.value("GEMMA_TINY", Model::GEMMA_TINY)
.value("GEMMA2_2B", Model::GEMMA2_2B)
.value("PALIGEMMA2_3B_224", Model::PALIGEMMA2_3B_224)
.value("PALIGEMMA2_10B_224", Model::PALIGEMMA2_10B_224)
@ -121,15 +117,11 @@ PYBIND11_MODULE(configs, py_module) {
class_<LayerConfig>(py_module, "LayerConfig")
.def(init())
.def_readwrite("model_dim", &LayerConfig::model_dim)
.def_readwrite("griffin_dim", &LayerConfig::griffin_dim)
.def_readwrite("ff_hidden_dim", &LayerConfig::ff_hidden_dim)
.def_readwrite("heads", &LayerConfig::heads)
.def_readwrite("kv_heads", &LayerConfig::kv_heads)
.def_readwrite("qkv_dim", &LayerConfig::qkv_dim)
.def_readwrite("conv1d_width", &LayerConfig::conv1d_width)
.def_readwrite("ff_biases", &LayerConfig::ff_biases)
.def_readwrite("softmax_attn_output_biases",
&LayerConfig::softmax_attn_output_biases)
.def_readwrite("optimized_gating", &LayerConfig::optimized_gating)
.def_readwrite("post_norm", &LayerConfig::post_norm)
.def_readwrite("type", &LayerConfig::type)
@ -166,7 +158,6 @@ PYBIND11_MODULE(configs, py_module) {
.def_readwrite("att_cap", &ModelConfig::att_cap)
.def_readwrite("final_cap", &ModelConfig::final_cap)
.def_readwrite("absolute_pe", &ModelConfig::absolute_pe)
.def_readwrite("use_local_attention", &ModelConfig::use_local_attention)
.def_readwrite("query_scale", &ModelConfig::query_scale)
.def_readwrite("layer_configs", &ModelConfig::layer_configs)
.def_readwrite("attention_window_sizes",