mirror of https://github.com/google/gemma.cpp.git

Remove Griffin support

Also add IsObsolete helper.

PiperOrigin-RevId: 803376921

parent: 56186193c1
commit: 2b4c16e243

@@ -507,14 +507,12 @@ cc_library(
     srcs = [
         "gemma/attention.cc",
         "gemma/gemma.cc",
-        "gemma/griffin.cc",
         "gemma/vit.cc",
     ],
     hdrs = [
         "gemma/activations.h",
         "gemma/attention.h",
         "gemma/gemma.h",
-        "gemma/griffin.h",
         "gemma/vit.h",
     ],
     exec_properties = {

@@ -83,8 +83,6 @@ set(SOURCES
   gemma/gemma-inl.h
   gemma/gemma.cc
   gemma/gemma.h
-  gemma/griffin.cc
-  gemma/griffin.h
   gemma/kv_cache.cc
   gemma/kv_cache.h
   gemma/model_store.cc

README.md

@@ -53,7 +53,7 @@ Guidelines](https://opensource.google.com/conduct/).
 
 - LLM
 
-- CPU-only inference for: Gemma 2-3, Griffin(SSM), PaliGemma 2.
+- CPU-only inference for: Gemma 2-3, PaliGemma 2.
 - Sampling with TopK and temperature.
 - Backward pass (VJP) and Adam optimizer for Gemma research.

@@ -222,23 +222,6 @@ Example invocation for the following configuration:
 --tokenizer tokenizer.spm --weights gemma2-2b-it-sfp.sbs
 ```
 
-### RecurrentGemma
-
-This repository includes a version of Gemma based on Griffin
-([paper](https://arxiv.org/abs/2402.19427),
-[code](https://github.com/google-deepmind/recurrentgemma)). Its architecture
-includes both recurrent layers and local attention, thus it is more efficient
-for longer sequences and has a smaller memory footprint than standard Gemma. We
-here provide a C++ implementation of this model based on the paper.
-
-To use the recurrent version of Gemma included in this repository, build the
-gemma binary as noted above in Step 3. Download the compressed weights and
-tokenizer from the RecurrentGemma
-[Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) as in
-Step 1, and run the binary as follows:
-
-`./gemma --tokenizer tokenizer.spm --model gr2b-it --weights 2b-it-sfp.sbs`
-
 ### PaliGemma Vision-Language Model
 
 This repository includes a version of the PaliGemma 2 VLM

@@ -535,7 +518,7 @@ gemma.cpp was started in fall 2023 by
 Griffin support was implemented in April 2024 thanks to contributions by Andrey
 Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode
 Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas
-Fischbacher and Zoltan Szabadka.
+Fischbacher and Zoltan Szabadka. It was removed in 2025-09.
 
 Gemma-2 support was implemented in June/July 2024 with the help of several
 people.

@@ -91,7 +91,7 @@ class CompressionTest(absltest.TestCase):
     )
 
     config = configs.ModelConfig(
-        configs.Model.GEMMA_TINY,
+        configs.Model.GEMMA2_2B,
         configs.Type.kSFP,
         configs.PromptWrapping.GEMMA_IT,
     )

@@ -101,7 +101,7 @@ class CompressionTest(absltest.TestCase):
     print("Ignore next two warnings; test does not enable model deduction.")
     reader = compression.SbsReader(temp_file.full_path)
 
-    self.assertEqual(reader.config.model, configs.Model.GEMMA_TINY)
+    self.assertEqual(reader.config.model, configs.Model.GEMMA2_2B)
     self.assertEqual(reader.config.weight, configs.Type.kSFP)
 
     mat = reader.find_mat("tensor0")

@@ -1,8 +0,0 @@
# General Remarks about the "PyTree" Abstraction

The pytree wrangling code in this project does not use any of the existing
"pytree" modules. The deeper reason here is that our approach is based on an
analysis of the notion that emphasizes deeper underlying principles. This is
being discussed internally at the time of this writing.

@@ -1,275 +0,0 @@
"""Ad-hoc glue code for building the griffin model-file for the C++ binary.

Usage:

  python3 -m venv $HOME/clients/griffin-venv
  . $HOME/clients/griffin-venv/bin/activate
  python3 -m pip install -r requirements.txt

  time python3 build_model_file_for_cpp_binary.py \
    $HOME/GRIFFIN/model_data \
    cpp_load_log.txt /tmp/G2B.data

  real 3m5.821s
  user 2m9.205s
  sys 2m46.720s

  ./compress_weights --weights /tmp/G2B.data --model gr2b-it \
    --compressed_weights /tmp/G2B.compressed
  ./gemma --tokenizer tokenizer.spm --weights /tmp/G2B.compressed \
    --model gr2b-it

Weights for the recurrent-gemma model that can be converted with this script
can be found at:

  https://www.kaggle.com/models/google/recurrentgemma/flax/2b-it
"""

import pprint
import re
import sys

from typing import Any, Mapping

import numpy

import orbax.checkpoint

import ml_model_transforms
import pytree_transforms


def _fn_identity(x): return x


def _fn_transpose(x): return x.T


def _fn_transpose_all_heads(x): return x.transpose(0, 2, 1)
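

# numpy.logaddexp(a, 0) is softplus(a) = log(1 + exp(a)), so this emits
# -8 * softplus(a_param) - presumably the log-space recurrence weight the C++
# Griffin code expects for `griffin_a` (the paper's a = exp(-c * softplus(L))
# with c = 8).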
def _fn_scaled_softplus(a):
  return -8 * numpy.logaddexp(a, 0)
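

# Reshapes the attention output-projection kernel; the hard-coded dimensions
# (10 heads, 256 per-head dim, 2560 model width) appear to match the 2B
# recurrent-gemma configuration this script was written for.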
def _fn_attention_moveaxis(a):
  return a.reshape(10, 256, 2560).transpose(0, 2, 1)


def _aspec(pieces=(), transforms=()):
  """Short-hand array-save-specification.

  Args:
    pieces: Sequence of key-sequences identifying an array.
    transforms: Sequence of transformations, indexed in
      parallel to `pieces`, to apply to data arrays prior to saving.
      Will be padded with identity-transformations to the length of `pieces`.

  Returns:
    Specification as for use in _LAYER_NAME_MAPPING.
  """
  # `zip` trims to shortest sequence, so this amounts to using
  # default-transforms.
  # tuple() since we need a Sequence here, not a stateful-iterator zip_object.
  return tuple(zip(pieces, list(transforms) + [_fn_identity] * len(pieces)))


_LAYER_NAME_MAPPING = pytree_transforms.deep_freeze({
    # Recurrent Layer
    'griffin_linear_x_w': _aspec(
        [('recurrent_block', 'linear_x', 'kernel')],
        [_fn_transpose]),
    'griffin_linear_x_biases': _aspec(
        [('recurrent_block', 'linear_x', 'bias')]),
    'griffin_linear_y_w': _aspec(
        [('recurrent_block', 'linear_y', 'kernel')],
        [_fn_transpose]),
    'griffin_linear_y_biases': _aspec(
        [('recurrent_block', 'linear_y', 'bias')]),
    'griffin_linear_out_w': _aspec(
        [('recurrent_block', 'linear_out', 'kernel')],
        [_fn_transpose]),
    'griffin_linear_out_biases': _aspec(
        [('recurrent_block', 'linear_out', 'bias')]),
    'griffin_conv_w': _aspec(
        [('recurrent_block', 'conv_1d', 'w')]),
    'griffin_conv_biases': _aspec(
        [('recurrent_block', 'conv_1d', 'b')]),
    'griffin_gate_w': _aspec(
        [('recurrent_block', 'rg_lru', 'input_gate', 'w'),
         ('recurrent_block', 'rg_lru', 'a_gate', 'w')],
        [_fn_transpose_all_heads, _fn_transpose_all_heads]),
    'griffin_gate_biases': _aspec(
        [('recurrent_block', 'rg_lru', 'input_gate', 'b'),
         ('recurrent_block', 'rg_lru', 'a_gate', 'b')]),
    'griffin_a': _aspec(
        [('recurrent_block', 'rg_lru', 'a_param')],
        [_fn_scaled_softplus]),
    # Attention Layer
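    # The checkpoint keeps separate q/k/v projection kernels; listing all
    # three under the single 'qkv_einsum_w' tag makes them get written
    # back-to-back where the C++ loader requests one combined tensor.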
    'qkv_einsum_w': _aspec(
        [('attention_block', 'proj_q', 'kernel'),
         ('attention_block', 'proj_k', 'kernel'),
         ('attention_block', 'proj_v', 'kernel'),
         ],
        [_fn_transpose, _fn_transpose, _fn_transpose]),
    'attn_vec_einsum_w': _aspec(
        [('attention_block', 'proj_final', 'kernel')],
        [_fn_attention_moveaxis]),
    'attention_output_biases': _aspec(
        [('attention_block', 'proj_final', 'bias')]),
    # Common
    'pre_attention_norm_scale': _aspec(
        [('temporal_pre_norm', 'scale')]),
    'pre_ffw_norm_scale': _aspec(
        [('channel_pre_norm', 'scale')]),
    'gating_einsum_w': _aspec(
        [('mlp_block', 'ffw_up', 'w')],
        [_fn_transpose_all_heads]),
    'ffw_gating_biases': _aspec(
        [('mlp_block', 'ffw_up', 'b')]),
    'linear_w': _aspec(
        [('mlp_block', 'ffw_down', 'kernel')],
        [_fn_transpose]),
    'ffw_output_biases': _aspec(
        [('mlp_block', 'ffw_down', 'bias')]),
    # Other
    'embedder_input_embedding': _aspec(
        [('embedder', 'input_embedding')]),
    'final_norm_scale': _aspec(
        [('final_norm', 'scale')]),
})
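

# Example log lines accepted by the parser below (cf. cpp_load_log.txt in the
# usage example above):
#   Loading Parameters (size 2622750720): embedder_input_embedding
#   Loading Parameters: (layer=0, size 26214400) griffin_linear_x_w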
def process_param_line(line: str) -> None | tuple[str | None, int, str]:
  """Processes a "loading parameters" log-line from the griffin binary."""
  # This is slightly more permissive than strictly needed, to also handle
  # some earlier form of the output.
  matched = re.match(
      r'(?a)Loading Parameters:? \('
      r'(?:layer=(?P<layer>\d+), )?'
      r'size (?P<size>\d+)\):? '
      r'(?P<tag>\S+)',
      line)
  if not matched:
    return None
  layer = matched['layer']
  wanted_size = int(matched['size'])
  cpp_tag = matched['tag']
  return layer, wanted_size, cpp_tag


def collect_pytree_keys(param_lines):
  """Collects all the pytree keys and transforms for model-serialization."""
  pytree_keys = []
  array_transforms = []
  unsatisfied = []
  for maybe_spec in map(process_param_line, param_lines):
    if not maybe_spec: continue  # Skip non-parameter lines.
    layer, wanted_size, cpp_tag = maybe_spec
    pytree_key_tails_and_transforms = _LAYER_NAME_MAPPING.get(cpp_tag, ())
    if not pytree_key_tails_and_transforms:
      unsatisfied.append((layer, cpp_tag))
    else:
      for key_tail, array_transform in pytree_key_tails_and_transforms:
        pytree_keys.append(
            key_tail if layer is None
            else (f'blocks.{layer}',) + key_tail)
        array_transforms.append(array_transform)
  return pytree_keys, array_transforms, unsatisfied


class UnsatisfiedArrayLoadsError(ValueError):
  """Some array-loads could not be satisfied."""
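

# The C++ binary's load log defines the exact array order it expects, so the
# position ("rank") of each pytree key in the parsed log is reused below as
# the sort key when serializing the arrays.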
def flatten_model_for_cpp_binary(tree,
                                 cpp_expectations_logfile_path: str,
                                 out_path: str,
                                 unsatisfied_ok: bool = False
                                 ):
  """Produces a model-parameters file readable by the C++ binary.

  Args:
    tree: The pytree with model-parameters.
    cpp_expectations_logfile_path:
      Path to a logfile produced by the C++ binary that shows
      the expected array-order.
    out_path: Path to the model-weights file to be written.
    unsatisfied_ok: If true, we ignore the presence of unsatisfied
      array-loads and write a model-parameters file that skips these pieces.
      This will lead to an unusable model-parameters file which however
      still might be useful for other analysis.

  Returns:
    Tuple `(unknown_keys, missing_keys)`, where `unknown_keys`
    is a sequence of `(layer_or_None, name)` descriptions of the keys
    in the C++ log that could not be satisfied, and `missing_keys`
    is a sequence of linearized pytree key-sequences for keys
    not found in the checkpoint.

  Raises:
    UnsatisfiedArrayLoadsError: If some of the expected arrays
      could not be included in the output and `unsatisfied_ok`
      is false.
  """
  with open(cpp_expectations_logfile_path, 'rt') as h_log:
    pytree_keys, array_transforms, unknown_keys = collect_pytree_keys(
        list(h_log))
  rank_by_pytree_key = {k: n for n, k in enumerate(pytree_keys)}
  array_transform_by_pytree_key = dict(zip(pytree_keys, array_transforms))
  #
  model_contents = ml_model_transforms.model_contents(tree)
  missing_keys = set(pytree_keys) - model_contents.keys()
  if (unknown_keys or missing_keys) and not unsatisfied_ok:
    raise UnsatisfiedArrayLoadsError(
        f'Unsatisfied loads: unknown_keys: {unknown_keys!r}, '
        f'missing keys: {sorted(missing_keys)!r}')
  ml_model_transforms.model_save(
      tree,
      filepath_stem=out_path,
      data_suffix='',
      manifest_suffix=None,
      array_transform_by_pytree_key=array_transform_by_pytree_key,
      key=rank_by_pytree_key.get,
      report=lambda line: print(line, file=sys.stderr),
      byte_align=1)
  return tuple(unknown_keys), tuple(sorted(missing_keys))


def main(args):
  """Creates the model-file.

  Args:
    args: sys.argv[] parameters from command line, sans the leading one.

  Returns:
    The pytree with all the de-serialized variables, such as for convenient
    `python3 -i` inspection.
  """
  try:
    model_dir, cpp_load_log, out_path = args
  except Exception:
    sys.exit(f'Usage: {__file__} [model_dir] [cpp_load_log] [output_filename]')
  # Layer-type pattern of the recurrent-gemma architecture (two recurrent
  # blocks followed by one attention block); documentation only, not used
  # below.
  pattern = ("recurrent", "recurrent", "attention")
  orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
  variables = orbax_checkpointer.restore(model_dir)
  if sorted(variables) == ['params']:
    print('Warning: Using `variables["params"]` as tree-root.', file=sys.stderr)
    variables_to_use = variables['params']
  else:
    variables_to_use = variables
  unknown, missing = flatten_model_for_cpp_binary(variables_to_use,
                                                  cpp_load_log,
                                                  out_path,
                                                  unsatisfied_ok=True)
  print('Model file saved.\n'
        f'# unknown:\n{pprint.pformat(unknown)}\n'
        f'# missing:\n{pprint.pformat(missing)}')
  return variables


if __name__ == '__main__':
  # Return value assignment is for `python3 -i ...` inspection.
  pytree = main(sys.argv[1:])

@@ -1,380 +0,0 @@
Loading Parameters (size 2622750720): embedder_input_embedding
|
||||
Loading Parameters (size 10240): final_norm_scale
|
||||
Loading Parameters: (layer=0, size 26214400) griffin_linear_x_w
|
||||
Loading Parameters: (layer=0, size 10240) griffin_linear_x_biases
|
||||
Loading Parameters: (layer=0, size 26214400) griffin_linear_y_w
|
||||
Loading Parameters: (layer=0, size 10240) griffin_linear_y_biases
|
||||
Loading Parameters: (layer=0, size 26214400) griffin_linear_out_w
|
||||
Loading Parameters: (layer=0, size 10240) griffin_linear_out_biases
|
||||
Loading Parameters: (layer=0, size 40960) griffin_conv_w
|
||||
Loading Parameters: (layer=0, size 10240) griffin_conv_biases
|
||||
Loading Parameters: (layer=0, size 5242880) griffin_gate_w
|
||||
Loading Parameters: (layer=0, size 20480) griffin_gate_biases
|
||||
Loading Parameters: (layer=0, size 10240) griffin_a
|
||||
Loading Parameters: (layer=0, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=0, size 78643200) linear_w
|
||||
Loading Parameters: (layer=0, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=0, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=0, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=0, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=1, size 26214400) griffin_linear_x_w
|
||||
Loading Parameters: (layer=1, size 10240) griffin_linear_x_biases
|
||||
Loading Parameters: (layer=1, size 26214400) griffin_linear_y_w
|
||||
Loading Parameters: (layer=1, size 10240) griffin_linear_y_biases
|
||||
Loading Parameters: (layer=1, size 26214400) griffin_linear_out_w
|
||||
Loading Parameters: (layer=1, size 10240) griffin_linear_out_biases
|
||||
Loading Parameters: (layer=1, size 40960) griffin_conv_w
|
||||
Loading Parameters: (layer=1, size 10240) griffin_conv_biases
|
||||
Loading Parameters: (layer=1, size 5242880) griffin_gate_w
|
||||
Loading Parameters: (layer=1, size 20480) griffin_gate_biases
|
||||
Loading Parameters: (layer=1, size 10240) griffin_a
|
||||
Loading Parameters: (layer=1, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=1, size 78643200) linear_w
|
||||
Loading Parameters: (layer=1, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=1, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=1, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=1, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=2, size 26214400) attn_vec_einsum_w
|
||||
Loading Parameters: (layer=2, size 78643200) qkv_einsum_w
|
||||
Loading Parameters: (layer=2, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=2, size 78643200) linear_w
|
||||
Loading Parameters: (layer=2, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=2, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=2, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=2, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=2, size 10240) attention_output_biases
|
||||
Loading Parameters: (layer=3, size 26214400) griffin_linear_x_w
|
||||
Loading Parameters: (layer=3, size 10240) griffin_linear_x_biases
|
||||
Loading Parameters: (layer=3, size 26214400) griffin_linear_y_w
|
||||
Loading Parameters: (layer=3, size 10240) griffin_linear_y_biases
|
||||
Loading Parameters: (layer=3, size 26214400) griffin_linear_out_w
|
||||
Loading Parameters: (layer=3, size 10240) griffin_linear_out_biases
|
||||
Loading Parameters: (layer=3, size 40960) griffin_conv_w
|
||||
Loading Parameters: (layer=3, size 10240) griffin_conv_biases
|
||||
Loading Parameters: (layer=3, size 5242880) griffin_gate_w
|
||||
Loading Parameters: (layer=3, size 20480) griffin_gate_biases
|
||||
Loading Parameters: (layer=3, size 10240) griffin_a
|
||||
Loading Parameters: (layer=3, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=3, size 78643200) linear_w
|
||||
Loading Parameters: (layer=3, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=3, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=3, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=3, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=4, size 26214400) griffin_linear_x_w
|
||||
Loading Parameters: (layer=4, size 10240) griffin_linear_x_biases
|
||||
Loading Parameters: (layer=4, size 26214400) griffin_linear_y_w
|
||||
Loading Parameters: (layer=4, size 10240) griffin_linear_y_biases
|
||||
Loading Parameters: (layer=4, size 26214400) griffin_linear_out_w
|
||||
Loading Parameters: (layer=4, size 10240) griffin_linear_out_biases
|
||||
Loading Parameters: (layer=4, size 40960) griffin_conv_w
|
||||
Loading Parameters: (layer=4, size 10240) griffin_conv_biases
|
||||
Loading Parameters: (layer=4, size 5242880) griffin_gate_w
|
||||
Loading Parameters: (layer=4, size 20480) griffin_gate_biases
|
||||
Loading Parameters: (layer=4, size 10240) griffin_a
|
||||
Loading Parameters: (layer=4, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=4, size 78643200) linear_w
|
||||
Loading Parameters: (layer=4, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=4, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=4, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=4, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=5, size 26214400) attn_vec_einsum_w
|
||||
Loading Parameters: (layer=5, size 78643200) qkv_einsum_w
|
||||
Loading Parameters: (layer=5, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=5, size 78643200) linear_w
|
||||
Loading Parameters: (layer=5, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=5, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=5, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=5, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=5, size 10240) attention_output_biases
|
||||
Loading Parameters: (layer=6, size 26214400) griffin_linear_x_w
|
||||
Loading Parameters: (layer=6, size 10240) griffin_linear_x_biases
|
||||
Loading Parameters: (layer=6, size 26214400) griffin_linear_y_w
|
||||
Loading Parameters: (layer=6, size 10240) griffin_linear_y_biases
|
||||
Loading Parameters: (layer=6, size 26214400) griffin_linear_out_w
|
||||
Loading Parameters: (layer=6, size 10240) griffin_linear_out_biases
|
||||
Loading Parameters: (layer=6, size 40960) griffin_conv_w
|
||||
Loading Parameters: (layer=6, size 10240) griffin_conv_biases
|
||||
Loading Parameters: (layer=6, size 5242880) griffin_gate_w
|
||||
Loading Parameters: (layer=6, size 20480) griffin_gate_biases
|
||||
Loading Parameters: (layer=6, size 10240) griffin_a
|
||||
Loading Parameters: (layer=6, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=6, size 78643200) linear_w
|
||||
Loading Parameters: (layer=6, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=6, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=6, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=6, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=7, size 26214400) griffin_linear_x_w
|
||||
Loading Parameters: (layer=7, size 10240) griffin_linear_x_biases
|
||||
Loading Parameters: (layer=7, size 26214400) griffin_linear_y_w
|
||||
Loading Parameters: (layer=7, size 10240) griffin_linear_y_biases
|
||||
Loading Parameters: (layer=7, size 26214400) griffin_linear_out_w
|
||||
Loading Parameters: (layer=7, size 10240) griffin_linear_out_biases
|
||||
Loading Parameters: (layer=7, size 40960) griffin_conv_w
|
||||
Loading Parameters: (layer=7, size 10240) griffin_conv_biases
|
||||
Loading Parameters: (layer=7, size 5242880) griffin_gate_w
|
||||
Loading Parameters: (layer=7, size 20480) griffin_gate_biases
|
||||
Loading Parameters: (layer=7, size 10240) griffin_a
|
||||
Loading Parameters: (layer=7, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=7, size 78643200) linear_w
|
||||
Loading Parameters: (layer=7, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=7, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=7, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=7, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=8, size 26214400) attn_vec_einsum_w
|
||||
Loading Parameters: (layer=8, size 78643200) qkv_einsum_w
|
||||
Loading Parameters: (layer=8, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=8, size 78643200) linear_w
|
||||
Loading Parameters: (layer=8, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=8, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=8, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=8, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=8, size 10240) attention_output_biases
|
||||
Loading Parameters: (layer=9, size 26214400) griffin_linear_x_w
|
||||
Loading Parameters: (layer=9, size 10240) griffin_linear_x_biases
|
||||
Loading Parameters: (layer=9, size 26214400) griffin_linear_y_w
|
||||
Loading Parameters: (layer=9, size 10240) griffin_linear_y_biases
|
||||
Loading Parameters: (layer=9, size 26214400) griffin_linear_out_w
|
||||
Loading Parameters: (layer=9, size 10240) griffin_linear_out_biases
|
||||
Loading Parameters: (layer=9, size 40960) griffin_conv_w
|
||||
Loading Parameters: (layer=9, size 10240) griffin_conv_biases
|
||||
Loading Parameters: (layer=9, size 5242880) griffin_gate_w
|
||||
Loading Parameters: (layer=9, size 20480) griffin_gate_biases
|
||||
Loading Parameters: (layer=9, size 10240) griffin_a
|
||||
Loading Parameters: (layer=9, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=9, size 78643200) linear_w
|
||||
Loading Parameters: (layer=9, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=9, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=9, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=9, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=10, size 26214400) griffin_linear_x_w
|
||||
Loading Parameters: (layer=10, size 10240) griffin_linear_x_biases
|
||||
Loading Parameters: (layer=10, size 26214400) griffin_linear_y_w
|
||||
Loading Parameters: (layer=10, size 10240) griffin_linear_y_biases
|
||||
Loading Parameters: (layer=10, size 26214400) griffin_linear_out_w
|
||||
Loading Parameters: (layer=10, size 10240) griffin_linear_out_biases
|
||||
Loading Parameters: (layer=10, size 40960) griffin_conv_w
|
||||
Loading Parameters: (layer=10, size 10240) griffin_conv_biases
|
||||
Loading Parameters: (layer=10, size 5242880) griffin_gate_w
|
||||
Loading Parameters: (layer=10, size 20480) griffin_gate_biases
|
||||
Loading Parameters: (layer=10, size 10240) griffin_a
|
||||
Loading Parameters: (layer=10, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=10, size 78643200) linear_w
|
||||
Loading Parameters: (layer=10, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=10, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=10, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=10, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=11, size 26214400) attn_vec_einsum_w
|
||||
Loading Parameters: (layer=11, size 78643200) qkv_einsum_w
|
||||
Loading Parameters: (layer=11, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=11, size 78643200) linear_w
|
||||
Loading Parameters: (layer=11, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=11, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=11, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=11, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=11, size 10240) attention_output_biases
|
||||
Loading Parameters: (layer=12, size 26214400) griffin_linear_x_w
|
||||
Loading Parameters: (layer=12, size 10240) griffin_linear_x_biases
|
||||
Loading Parameters: (layer=12, size 26214400) griffin_linear_y_w
|
||||
Loading Parameters: (layer=12, size 10240) griffin_linear_y_biases
|
||||
Loading Parameters: (layer=12, size 26214400) griffin_linear_out_w
|
||||
Loading Parameters: (layer=12, size 10240) griffin_linear_out_biases
|
||||
Loading Parameters: (layer=12, size 40960) griffin_conv_w
|
||||
Loading Parameters: (layer=12, size 10240) griffin_conv_biases
|
||||
Loading Parameters: (layer=12, size 5242880) griffin_gate_w
|
||||
Loading Parameters: (layer=12, size 20480) griffin_gate_biases
|
||||
Loading Parameters: (layer=12, size 10240) griffin_a
|
||||
Loading Parameters: (layer=12, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=12, size 78643200) linear_w
|
||||
Loading Parameters: (layer=12, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=12, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=12, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=12, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=13, size 26214400) griffin_linear_x_w
|
||||
Loading Parameters: (layer=13, size 10240) griffin_linear_x_biases
|
||||
Loading Parameters: (layer=13, size 26214400) griffin_linear_y_w
|
||||
Loading Parameters: (layer=13, size 10240) griffin_linear_y_biases
|
||||
Loading Parameters: (layer=13, size 26214400) griffin_linear_out_w
|
||||
Loading Parameters: (layer=13, size 10240) griffin_linear_out_biases
|
||||
Loading Parameters: (layer=13, size 40960) griffin_conv_w
|
||||
Loading Parameters: (layer=13, size 10240) griffin_conv_biases
|
||||
Loading Parameters: (layer=13, size 5242880) griffin_gate_w
|
||||
Loading Parameters: (layer=13, size 20480) griffin_gate_biases
|
||||
Loading Parameters: (layer=13, size 10240) griffin_a
|
||||
Loading Parameters: (layer=13, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=13, size 78643200) linear_w
|
||||
Loading Parameters: (layer=13, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=13, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=13, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=13, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=14, size 26214400) attn_vec_einsum_w
|
||||
Loading Parameters: (layer=14, size 78643200) qkv_einsum_w
|
||||
Loading Parameters: (layer=14, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=14, size 78643200) linear_w
|
||||
Loading Parameters: (layer=14, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=14, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=14, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=14, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=14, size 10240) attention_output_biases
|
||||
Loading Parameters: (layer=15, size 26214400) griffin_linear_x_w
|
||||
Loading Parameters: (layer=15, size 10240) griffin_linear_x_biases
|
||||
Loading Parameters: (layer=15, size 26214400) griffin_linear_y_w
|
||||
Loading Parameters: (layer=15, size 10240) griffin_linear_y_biases
|
||||
Loading Parameters: (layer=15, size 26214400) griffin_linear_out_w
|
||||
Loading Parameters: (layer=15, size 10240) griffin_linear_out_biases
|
||||
Loading Parameters: (layer=15, size 40960) griffin_conv_w
|
||||
Loading Parameters: (layer=15, size 10240) griffin_conv_biases
|
||||
Loading Parameters: (layer=15, size 5242880) griffin_gate_w
|
||||
Loading Parameters: (layer=15, size 20480) griffin_gate_biases
|
||||
Loading Parameters: (layer=15, size 10240) griffin_a
|
||||
Loading Parameters: (layer=15, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=15, size 78643200) linear_w
|
||||
Loading Parameters: (layer=15, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=15, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=15, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=15, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=16, size 26214400) griffin_linear_x_w
|
||||
Loading Parameters: (layer=16, size 10240) griffin_linear_x_biases
|
||||
Loading Parameters: (layer=16, size 26214400) griffin_linear_y_w
|
||||
Loading Parameters: (layer=16, size 10240) griffin_linear_y_biases
|
||||
Loading Parameters: (layer=16, size 26214400) griffin_linear_out_w
|
||||
Loading Parameters: (layer=16, size 10240) griffin_linear_out_biases
|
||||
Loading Parameters: (layer=16, size 40960) griffin_conv_w
|
||||
Loading Parameters: (layer=16, size 10240) griffin_conv_biases
|
||||
Loading Parameters: (layer=16, size 5242880) griffin_gate_w
|
||||
Loading Parameters: (layer=16, size 20480) griffin_gate_biases
|
||||
Loading Parameters: (layer=16, size 10240) griffin_a
|
||||
Loading Parameters: (layer=16, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=16, size 78643200) linear_w
|
||||
Loading Parameters: (layer=16, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=16, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=16, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=16, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=17, size 26214400) attn_vec_einsum_w
|
||||
Loading Parameters: (layer=17, size 78643200) qkv_einsum_w
|
||||
Loading Parameters: (layer=17, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=17, size 78643200) linear_w
|
||||
Loading Parameters: (layer=17, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=17, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=17, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=17, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=17, size 10240) attention_output_biases
|
||||
Loading Parameters: (layer=18, size 26214400) griffin_linear_x_w
|
||||
Loading Parameters: (layer=18, size 10240) griffin_linear_x_biases
|
||||
Loading Parameters: (layer=18, size 26214400) griffin_linear_y_w
|
||||
Loading Parameters: (layer=18, size 10240) griffin_linear_y_biases
|
||||
Loading Parameters: (layer=18, size 26214400) griffin_linear_out_w
|
||||
Loading Parameters: (layer=18, size 10240) griffin_linear_out_biases
|
||||
Loading Parameters: (layer=18, size 40960) griffin_conv_w
|
||||
Loading Parameters: (layer=18, size 10240) griffin_conv_biases
|
||||
Loading Parameters: (layer=18, size 5242880) griffin_gate_w
|
||||
Loading Parameters: (layer=18, size 20480) griffin_gate_biases
|
||||
Loading Parameters: (layer=18, size 10240) griffin_a
|
||||
Loading Parameters: (layer=18, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=18, size 78643200) linear_w
|
||||
Loading Parameters: (layer=18, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=18, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=18, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=18, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=19, size 26214400) griffin_linear_x_w
|
||||
Loading Parameters: (layer=19, size 10240) griffin_linear_x_biases
|
||||
Loading Parameters: (layer=19, size 26214400) griffin_linear_y_w
|
||||
Loading Parameters: (layer=19, size 10240) griffin_linear_y_biases
|
||||
Loading Parameters: (layer=19, size 26214400) griffin_linear_out_w
|
||||
Loading Parameters: (layer=19, size 10240) griffin_linear_out_biases
|
||||
Loading Parameters: (layer=19, size 40960) griffin_conv_w
|
||||
Loading Parameters: (layer=19, size 10240) griffin_conv_biases
|
||||
Loading Parameters: (layer=19, size 5242880) griffin_gate_w
|
||||
Loading Parameters: (layer=19, size 20480) griffin_gate_biases
|
||||
Loading Parameters: (layer=19, size 10240) griffin_a
|
||||
Loading Parameters: (layer=19, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=19, size 78643200) linear_w
|
||||
Loading Parameters: (layer=19, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=19, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=19, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=19, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=20, size 26214400) attn_vec_einsum_w
|
||||
Loading Parameters: (layer=20, size 78643200) qkv_einsum_w
|
||||
Loading Parameters: (layer=20, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=20, size 78643200) linear_w
|
||||
Loading Parameters: (layer=20, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=20, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=20, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=20, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=20, size 10240) attention_output_biases
|
||||
Loading Parameters: (layer=21, size 26214400) griffin_linear_x_w
|
||||
Loading Parameters: (layer=21, size 10240) griffin_linear_x_biases
|
||||
Loading Parameters: (layer=21, size 26214400) griffin_linear_y_w
|
||||
Loading Parameters: (layer=21, size 10240) griffin_linear_y_biases
|
||||
Loading Parameters: (layer=21, size 26214400) griffin_linear_out_w
|
||||
Loading Parameters: (layer=21, size 10240) griffin_linear_out_biases
|
||||
Loading Parameters: (layer=21, size 40960) griffin_conv_w
|
||||
Loading Parameters: (layer=21, size 10240) griffin_conv_biases
|
||||
Loading Parameters: (layer=21, size 5242880) griffin_gate_w
|
||||
Loading Parameters: (layer=21, size 20480) griffin_gate_biases
|
||||
Loading Parameters: (layer=21, size 10240) griffin_a
|
||||
Loading Parameters: (layer=21, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=21, size 78643200) linear_w
|
||||
Loading Parameters: (layer=21, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=21, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=21, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=21, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=22, size 26214400) griffin_linear_x_w
|
||||
Loading Parameters: (layer=22, size 10240) griffin_linear_x_biases
|
||||
Loading Parameters: (layer=22, size 26214400) griffin_linear_y_w
|
||||
Loading Parameters: (layer=22, size 10240) griffin_linear_y_biases
|
||||
Loading Parameters: (layer=22, size 26214400) griffin_linear_out_w
|
||||
Loading Parameters: (layer=22, size 10240) griffin_linear_out_biases
|
||||
Loading Parameters: (layer=22, size 40960) griffin_conv_w
|
||||
Loading Parameters: (layer=22, size 10240) griffin_conv_biases
|
||||
Loading Parameters: (layer=22, size 5242880) griffin_gate_w
|
||||
Loading Parameters: (layer=22, size 20480) griffin_gate_biases
|
||||
Loading Parameters: (layer=22, size 10240) griffin_a
|
||||
Loading Parameters: (layer=22, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=22, size 78643200) linear_w
|
||||
Loading Parameters: (layer=22, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=22, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=22, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=22, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=23, size 26214400) attn_vec_einsum_w
|
||||
Loading Parameters: (layer=23, size 78643200) qkv_einsum_w
|
||||
Loading Parameters: (layer=23, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=23, size 78643200) linear_w
|
||||
Loading Parameters: (layer=23, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=23, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=23, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=23, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=23, size 10240) attention_output_biases
|
||||
Loading Parameters: (layer=24, size 26214400) griffin_linear_x_w
|
||||
Loading Parameters: (layer=24, size 10240) griffin_linear_x_biases
|
||||
Loading Parameters: (layer=24, size 26214400) griffin_linear_y_w
|
||||
Loading Parameters: (layer=24, size 10240) griffin_linear_y_biases
|
||||
Loading Parameters: (layer=24, size 26214400) griffin_linear_out_w
|
||||
Loading Parameters: (layer=24, size 10240) griffin_linear_out_biases
|
||||
Loading Parameters: (layer=24, size 40960) griffin_conv_w
|
||||
Loading Parameters: (layer=24, size 10240) griffin_conv_biases
|
||||
Loading Parameters: (layer=24, size 5242880) griffin_gate_w
|
||||
Loading Parameters: (layer=24, size 20480) griffin_gate_biases
|
||||
Loading Parameters: (layer=24, size 10240) griffin_a
|
||||
Loading Parameters: (layer=24, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=24, size 78643200) linear_w
|
||||
Loading Parameters: (layer=24, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=24, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=24, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=24, size 10240) ffw_output_biases
|
||||
Loading Parameters: (layer=25, size 26214400) griffin_linear_x_w
|
||||
Loading Parameters: (layer=25, size 10240) griffin_linear_x_biases
|
||||
Loading Parameters: (layer=25, size 26214400) griffin_linear_y_w
|
||||
Loading Parameters: (layer=25, size 10240) griffin_linear_y_biases
|
||||
Loading Parameters: (layer=25, size 26214400) griffin_linear_out_w
|
||||
Loading Parameters: (layer=25, size 10240) griffin_linear_out_biases
|
||||
Loading Parameters: (layer=25, size 40960) griffin_conv_w
|
||||
Loading Parameters: (layer=25, size 10240) griffin_conv_biases
|
||||
Loading Parameters: (layer=25, size 5242880) griffin_gate_w
|
||||
Loading Parameters: (layer=25, size 20480) griffin_gate_biases
|
||||
Loading Parameters: (layer=25, size 10240) griffin_a
|
||||
Loading Parameters: (layer=25, size 157286400) gating_einsum_w
|
||||
Loading Parameters: (layer=25, size 78643200) linear_w
|
||||
Loading Parameters: (layer=25, size 10240) pre_attention_norm_scale
|
||||
Loading Parameters: (layer=25, size 10240) pre_ffw_norm_scale
|
||||
Loading Parameters: (layer=25, size 61440) ffw_gating_biases
|
||||
Loading Parameters: (layer=25, size 10240) ffw_output_biases
|
||||
|
|

@@ -1,371 +0,0 @@
"""Transformations for python-trees representing the parameters of a ML model.

Important: This module assumes that byte-order is the same on the
machine that serializes data and the machine that deserializes
data. If, for example, numpy-data gets dumped, respectively loaded,
with a dtype-specification of numpy.float32, on-file byte-order
will be host byte order.
"""

import ast
import hashlib
import itertools
import pprint
import sys
import time
from typing import Any, Callable, Iterable, Iterator, Mapping, TypeVar

import numpy
import pytree_transforms


NT = TypeVar('NT')


def ml_model_leaf_summary(path, x, sep=', '):
  """Produces a textual summary of a leaf-node and its path.

  Args:
    path: The path-to-root, as a reverse-order recursive
      pair of path-components, with `()` as root.
    x: The leaf-object.
    sep: the separator between description-elements.
      Default ', ' allows for convenient line-by-line processing
      (such as via grep, perl -ne, etc.), but using e.g. sep=',\n '
      might be more useful for human consumption.

  Returns:
    A human-readable string providing information about the node.
  """
  # Using `repr` for path-components to get a faithful presentation.
  # (...which still however would be somewhat painful to correctly
  # split into components.)
  path_str = ','.join(map(repr,
                          pytree_transforms.linearize_revtuple_path(path)))
  tx = type(x)
  mod = tx.__module__  # Either a module or a string like 'builtins'.
  modname = mod if isinstance(mod, str) else mod.__name__
  type_str = f'{modname}.{tx.__qualname__}'
  try:
    # `numpy.ndarray` instances have a `.data` property that gives access
    # to a buffer via which we can hashlib-fingerprint the numerical
    # contents. We here simply try to produce a fingerprint and also look
    # up the .dtype of the object. Technically, there is a somewhat-unsound
    # assumption here that if these operations succeed, we are indeed looking
    # at a ndarray or sufficiently similar object for these operations to
    # make sense. As the output is declared "for human consumption", this
    # fishiness is not a problem.
    fp = hashlib.sha256(x.data).hexdigest()
    start = list(itertools.islice(x.flat, 5))
    stats_str = (
        f'min={numpy.min(x):.6g}, max={numpy.max(x):.6g}, '
        f'mean={numpy.mean(x):.6g}, std={numpy.std(x):.6g}')
    return (f'{path_str:60s}: <{type_str}{sep}'
            f'fp=0x{fp}{sep}{stats_str}{sep}shape={x.shape}, '
            f'dtype={x.dtype}{sep}start={start}>')
  except (AttributeError, ValueError, TypeError):
    # Fallback - trying to include information about the data-content
    # of a likely-numerical-array failed.
    return f'{path_str:60s}: {type_str}({repr(x)})'


# A specialized node-handler.
# Interface follows node-handler expectations defined in pytree_transforms.
def _ml_model_tree_node_handler(path: tuple, node: NT) -> (
    None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT],
                 Iterator[tuple[Any, NT]]]):
  """Processes a tree-node as required by pytree-iteration and -mapping.

  Args:
    path: revtuple path to the current node.
    node: a tree-node in a ML-model tree that is recursively
      built out of `numpy.ndarray` leaf-values and dicts mapping
      node-name string-keys to other such nodes representing subtrees -
      and nothing else.

  Returns:
    `None` if the tree-node is to be regarded as a leaf, otherwise
    a pair `(rebuilder, iterator)`, where `iterator` iterates
    over the data-content of the node, each item represented as a pair
    of `(lookup_path_item, value_item)`, and `rebuilder` is a function
    which, when applied to `iterator` or any iterable with the same
    elements, returns a node that is equivalent to the original.

  Raises:
    NotAMLModelTreeNodeError: If the tree contains a node that is neither
      a `dict` nor a `numpy.ndarray` instance.
  """
  # The astute reader will notice that we are doing something fishy
  # here - this code could not be translated to Haskell as-is, since
  # `NT` cannot actually be a proper type-variable in the sense
  # of parametric polymorphism.
  del path  # Unused.
  if isinstance(node, dict):
    return dict, iter(node.items())
  if isinstance(node, numpy.ndarray):
    return None
  raise pytree_transforms.NotAMLModelTreeNodeError(
      f'Type of bad node: {type(node)}')


def _ml_model_extract_leaf_transform(
    path: pytree_transforms.RevTuplePath,
    leaf: Any):
  """Maps an array-leaf to a pair `(full_path, lambda: array)`.

  The computation that produces the leaf-value is lazified underneath
  a `lambda`, since if we e.g. performed a memory-expensive
  transformation (such as some dtype-changes) directly at this point,
  then going from an iterator over tree-items for one-by-one
  consumption to a list of these items would have all the
  dtype-transformed values around simultaneously. We want to avoid
  situations where we can do nothing about having multiple variants
  of the data simultaneously in memory.
  """
  # Hack: If we are encountering a `bfloat16` numpy-array,
  # we pretend to have the data as a numpy.float32 array,
  # since that's about all that contemporary CPUs can process
  # efficiently here.
  linearized_path = pytree_transforms.linearize_revtuple_path(path)
  try:
    # We have to use some trickery to detect `bfloat16`.
    if leaf.dtype.descr[-1] == ('', '<V2'):
      return linearized_path, lambda: leaf.astype(numpy.float32)
    else:
      return linearized_path, lambda: leaf
  except Exception:
    return linearized_path, lambda: leaf


# Here, we cannot properly specify the return-type, since this can
# either be a leaf-type or something recursively-defined.
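# For example, [(('a', 'b'), 1), (('a2',), 2)] autovivifies to
# {'a': {'b': 1}, 'a2': 2}.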
def revtuple_autovifify_from_linear(
    keys_and_vals: Iterable[tuple[Any, Any]]) -> Any:
  """Performs perl-style autovivification on a nested-dict tree.

  Args:
    keys_and_vals: An iterable of pairs `(key_path, value)`, where
      `key_path` is a sequence of keys to be used to navigate to
      the result via iterative dict-lookup, left-to-right.
      Must not have duplicate keys, and must not have more than one key if
      an empty-sequence key is present. If this iterable is an
      iterator, it will be fully exhausted on successful execution.

  Returns:
    An object representing a nested-dict structure such that
    for every `key_path` from `keys_and_vals`, recursive-dict-lookup
    on the elements of that path starting from this object will
    produce the corresponding value. An empty `keys_and_vals`
    set will return `{}`. Every dict in the nested return-value
    that has been populated by autovivification is newly allocated.
  """
  # Code structure is a bit gnarly here due to f(keys_and_vals=[((), x)])
  # having to evaluate to x and not a dict.
  # There may be ways to prettify/simplify this.
  result = None
  empty = {}
  for linear_path, val in keys_and_vals:
    if linear_path == ():
      if result is not None:
        raise ValueError('Root-value seen alongside other values.')
      result = val
    else:
      if result is None:
        result = {}
      elif type(result) is not dict:
        # We already did encounter a root-value.
        raise ValueError('Root-value seen alongside other values.')
      cursor = result
      for n in range(len(linear_path) - 1):
        cursor = cursor.setdefault(linear_path[n], empty)
        if cursor is empty:
          # Regenerate `empty` if we just used it up.
          empty = {}
      cursor[linear_path[-1]] = val
  return {} if result is None else result


def model_overview(tree, out=None) -> None:
  """Prints a human-readable overview to `(out or sys.stdout)`."""
  actual_out = out or sys.stdout
  for line in pytree_transforms.pytree_leaf_iter(
      tree, ml_model_leaf_summary,
      _ml_model_tree_node_handler):
    print(line, file=actual_out)


def model_contents(tree) -> Mapping[tuple[str, ...], Any]:
  """Maps a model to a {pytree_keys: data_array} mapping.

  Args:
    tree: The ML-model parameter-tree, built recursively out of
      dict-instances with numpy.ndarray instances as leaves.

  Returns:
    A mapping from linearized pytree-key-sequence tuple to the corresponding
    leaf-value.
  """
  def leaf_transform(revtuple_path, leaf):
    return pytree_transforms.linearize_revtuple_path(revtuple_path), leaf
  return dict(
      pytree_transforms.pytree_leaf_iter(
          tree, leaf_transform, _ml_model_tree_node_handler))


def _fn_identity(x): return x


def model_save(tree,
               filepath_stem: str,
               data_suffix: str = '.data',
               manifest_suffix: str | None = '.manifest',
               key: Callable[[tuple[str, ...]], Any] | None = None,
               array_transform_by_pytree_key: (
                   Mapping[tuple[str, ...],
                           Callable[[numpy.ndarray], numpy.ndarray]] |
                   None) = None,
               report: Callable[[str], None] | None = None,
               byte_align: int = 8) -> tuple[int, float]:
  """Saves the content of a ML-model parameter-tree to filesystem.

  After successful execution, the file f"{filepath_stem}.data"
  will hold the combined numerical model-parameters, and
  f"{filepath_stem}.manifest" will contain the key for interpreting
  (and rebuilding) the data.

  Args:
    tree: The ML-model parameter-tree, built recursively out of
      dict-instances with numpy.ndarray instances as leaves.
    filepath_stem: Filesystem location for data.
    data_suffix: Suffix to use for the data file.
    manifest_suffix: Either `None`, in which case no manifest-file
      will get written, or the suffix for the manifest-file.
    key: `None` or a key-function that will be applied to the linear model-path
      and used for sorting the data arrays by increasing value of the
      key-function. If the key-function returns `None` on an item,
      then this item is not included.
    array_transform_by_pytree_key: Optional mapping from pytree-key
      to an array-to-array transformation function to apply to the array
      prior to serialization.
    report: Optional callable for logging progress-reports.
    byte_align: byte-alignment to use for numerical array data.
      Numerical arrays whose size in bytes is not a multiple of this
      will get padded to the next full multiple.

  Returns:
    A pair of `(size, time_sec)`, where `size` is the total byte-size
    of the `.data` file and `time_sec` is the elapsed time
    for saving the model, in seconds.
  """
  time0 = time.monotonic()
  if array_transform_by_pytree_key is None:
    array_transform_by_pytree_key = {}
  model_lazy_items = (
      pytree_transforms.pytree_leaf_iter(
          tree, _ml_model_extract_leaf_transform,
          _ml_model_tree_node_handler))
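  # If a key-function was given, drop items it maps to None and emit the rest
  # in ascending key order; otherwise keep the tree-traversal order unchanged.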
  if key is not None:
    to_write = [
        nkv[1:] for nkv in sorted(
            (nkv for nkv in ((key(path), path, v)
                             for path, v in model_lazy_items)
             if nkv[0] is not None), key=lambda nkv: nkv[0])]
  else:
    to_write = list(model_lazy_items)
  #
  def lazy_arr_path_shape_dtype_size(path_and_lazy_arr):
    path, lazy_arr = path_and_lazy_arr
    arr = array_transform_by_pytree_key.get(path, _fn_identity)(lazy_arr())
    return path, arr.shape, arr.dtype, arr.data.nbytes
  arrs_path_shape_dtype_nbytes = list(
      map(lazy_arr_path_shape_dtype_size, to_write))
  # We need to know the total size of all the data.
  bytesizes = [nbytes for *_, nbytes in arrs_path_shape_dtype_nbytes]
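  # Round each byte-size up to the next multiple of `byte_align`; the negated
  # floor-division is the usual integer ceiling-division idiom.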
  padded_bytesizes = [-(-bytesize // byte_align * byte_align)
                      for bytesize in bytesizes]
  offsets = numpy.cumsum([0] + padded_bytesizes)
  membuf = numpy.memmap(filepath_stem + data_suffix,
                        mode='w+', shape=offsets[-1])
  try:
    for (path, shape, dtype, nbytes), offset, (_, lazy_arr) in zip(
        arrs_path_shape_dtype_nbytes, offsets, to_write):
      # Note that if getting the array from the lazy lambda involved some
      # computation, such as a copying dtype-change, that computation would
      # end up being done multiple times here - including once above, to compute
      # byte-sizes, and once more here.
      transformed_arr = array_transform_by_pytree_key.get(
          path,
          _fn_identity)(lazy_arr())
      membuf[offset : offset + nbytes] = numpy.frombuffer(
          transformed_arr.ravel().data, 'u1')
      if report is not None:
        samples = ', '.join(map(str, transformed_arr.ravel()[:5]))
        report(f'# Adding: {path!r}\n  bytes: {nbytes:10d}, '
               f'shape: {shape!r:30},\n  start: [{samples}, ...]')
      transformed_arr = None  # Drop memory references to numerical arrays ASAP.
  finally:
    if membuf is not None:
      membuf.flush()
    # NumPy wart: the memory-buffer is a resource that conceptually
    # should be .close()able - since mmap()ing holds on to a
    # file descriptor. However, it looks as if that clean-up were done
    # in the "finalizer", despite that having meanwhile been widely
    # understood as dubious practice. So, the best we can do here is
    # to explicitly and clearly remove our reference to the instance.
    del membuf
  if manifest_suffix is not None:
    # We still have to serialize the data that allows us to reconstruct
    # a tree that is equivalent to the original.
    manifest_data = [
        dict(path=path,
             dtype=dtype.descr[-1][-1],
             shape=shape,
             nbytes=nbytes,
             offset=offset)
        for (path, shape, dtype, nbytes), offset in zip(
            arrs_path_shape_dtype_nbytes, offsets)]
    with open(filepath_stem + manifest_suffix, 'wt') as h_manifest:
      pprint.pprint(manifest_data, stream=h_manifest)
  time_taken = time.monotonic() - time0
  return offsets[-1], time_taken


def model_load(filepath_stem, mmapped=True):
  """Loads a model saved by `model_save`.

  Tries to load the model from f"{filepath_stem}.data"
  and f"{filepath_stem}.manifest".

  Args:
    filepath_stem: The model location on the filesystem.
    mmapped: Whether data-arrays will be slices of a
      `numpy.memmap` mapped buffer, to be paged in
      on demand only, or in-memory copies of the data.

  Returns:
    A dict/numpy.ndarray tree representation of the model,
    equivalent to the original model.
  """
  with open(filepath_stem + '.manifest', 'rt') as h_manifest:
    manifest = ast.literal_eval(h_manifest.read())
  membuf = numpy.memmap(filepath_stem + '.data', mode='r+')
  paths_and_arrays = []
  for item in manifest:
    path = item['path']
    dtype = numpy.dtype(item['dtype'])
    shape = item['shape']
    nbytes = item['nbytes']
    offset = item['offset']
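    # numpy.frombuffer over the memmap slice is zero-copy: the resulting
    # array is a view into the mapped file, not an in-memory copy.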
data_array = numpy.frombuffer(membuf[offset : offset + nbytes].data,
|
||||
dtype=dtype).reshape(shape)
|
||||
paths_and_arrays.append(
|
||||
(path,
|
||||
data_array if mmapped else data_array.copy()))
|
||||
# At this point, the memory-buffer is no longer needed here. Still, if
# the data-arrays retain references to the underlying data
# (i.e. when mmapped=True), those references keep the mapping
# - and hence the file descriptor - open. We are then in the somewhat
# undesirable situation that the file descriptor is only released
# whenever clean-up of that resource happens, at a hard-to-predict time.
del membuf
|
||||
return revtuple_autovifify_from_linear(paths_and_arrays)
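A minimal usage sketch, assuming a model was previously written with `model_save` to the hypothetical stem `/tmp/the_model`:

```
import ml_model_transforms

# Default: arrays are views into a numpy.memmap of '/tmp/the_model.data',
# paged in on demand.
params = ml_model_transforms.model_load('/tmp/the_model')

# With mmapped=False, arrays are in-memory copies, so nothing keeps the
# underlying file mapping alive.
params_copy = ml_model_transforms.model_load('/tmp/the_model', mmapped=False)
```
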
@ -1,92 +0,0 @@
|
|||
"""Basic tests for 'algebraic data type based pytree' transformations."""
|
||||
|
||||
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy
|
||||
|
||||
import ml_model_transforms
|
||||
|
||||
|
||||
def _get_model(prefix):
|
||||
return {
|
||||
prefix + 'a1': numpy.arange(1000, 1024).reshape(6, 4).astype(numpy.float32),
|
||||
prefix + 'a2': numpy.arange(2000, 2048).reshape(6, 8).astype(numpy.float32),
|
||||
prefix + 'b1': {
|
||||
prefix + 'c1': numpy.arange(100, 127).reshape(3, 3, 3).astype(numpy.int8),
|
||||
prefix + 'c2': numpy.arange(100, 128).reshape(7, 4).astype(numpy.float64)
|
||||
}}
|
||||
|
||||
|
||||
class MLModeltransformsTest(unittest.TestCase):
|
||||
"""Basic correctness validation tests for ML-model transformations."""
|
||||
|
||||
def test_ml_model_leaf_summary(self):
|
||||
"""Tests guarantees given by `ml_model_leaf_summary`."""
|
||||
summary = ml_model_transforms.ml_model_leaf_summary(
|
||||
('a', ()),
|
||||
numpy.arange(1000, 1024).reshape(6, 4).astype(numpy.int16),
|
||||
sep='##')
|
||||
self.assertIn('##', summary) # Separator is respected.
|
||||
self.assertIn('(6, 4)', summary) # Shape is mentioned somewhere.
|
||||
self.assertIn('int16', summary) # dtype is mentioned somewhere.
|
||||
|
||||
def test_revtuple_autovivify_from_linear(self):
|
||||
"""Tests guarantees given by `revtuple_autovifify_from_linear`."""
|
||||
with self.subTest(guarantee='empty'):
|
||||
self.assertEqual(
|
||||
ml_model_transforms.revtuple_autovifify_from_linear([]),
|
||||
{})
|
||||
with self.subTest(guarantee='generic'):
|
||||
keys_vals = [(('a', 'b1', 'c1'), 1001),
|
||||
(('a', 'b2'), 1002),
|
||||
(('a2',), 1003),
|
||||
]
|
||||
self.assertEqual(
|
||||
ml_model_transforms.revtuple_autovifify_from_linear(keys_vals),
|
||||
{'a': {'b1': {'c1': 1001}, 'b2': 1002}, 'a2': 1003})
|
||||
|
||||
def test_model_overview(self):
|
||||
"""Tests guarantees given by `model_overview`."""
|
||||
model = _get_model('xyz')
|
||||
out_io = io.StringIO()
|
||||
ml_model_transforms.model_overview(model, out=out_io)
|
||||
overview = out_io.getvalue()
|
||||
self.assertIn('xyz', overview)
|
||||
|
||||
def test_model_contents(self):
|
||||
"""Tests guarantees given by `model_contents`."""
|
||||
model = _get_model('pq_')
|
||||
contents = ml_model_transforms.model_contents(model)
|
||||
fingerprints = {k: (a.shape, a.ravel()[:3].tolist())
|
||||
for k, a in contents.items()}
|
||||
self.assertEqual(fingerprints,
|
||||
{('pq_a1',): ((6, 4), [1000.0, 1001.0, 1002.0]),
|
||||
('pq_a2',): ((6, 8), [2000.0, 2001.0, 2002.0]),
|
||||
('pq_b1', 'pq_c1'): ((3, 3, 3), [100, 101, 102]),
|
||||
('pq_b1', 'pq_c2'): ((7, 4), [100.0, 101.0, 102.0])})
|
||||
|
||||
def test_model_save_load_basic(self):
|
||||
"""Tests basic guarantees given by `model_save` and `model_load`."""
|
||||
# What we care about here is that the round trip works - so
|
||||
# it makes more sense to test saving and loading as one unit.
|
||||
model_orig = _get_model('model_')
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
filepath_stem = os.path.join(tempdir, 'the_model')
|
||||
total_size, total_time = ml_model_transforms.model_save(model_orig,
|
||||
filepath_stem)
|
||||
self.assertGreater(total_size, 0)
|
||||
self.assertGreater(total_time, 0)
|
||||
model_reloaded = ml_model_transforms.model_load(filepath_stem)
|
||||
contents_orig = ml_model_transforms.model_contents(model_orig)
|
||||
contents_reloaded = ml_model_transforms.model_contents(model_reloaded)
|
||||
self.assertEqual(
|
||||
{k: v.tolist() for k, v in contents_orig.items()},
|
||||
{k: v.tolist() for k, v in contents_reloaded.items()})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -1,508 +0,0 @@
|
|||
"""Tools for transforming "nested python object" tree data structures.
|
||||
|
||||
# Context
|
||||
|
||||
The motivation for this module came from ML applications that ought to
|
||||
be based on a principled handling of nested Python data structures.
|
||||
Having such principled pytree-transforming code available also solves
some other problems: it does away with the need to abuse
tree-mapping for side effects only, and with the hope-and-pray
approach to processing very deeply nested values, which with a
recursive implementation might trigger a RecursionError.
|
||||
|
||||
We specifically want to cover the use case of having ML model
|
||||
parameters that are available in a nested Python data structure for
|
||||
which there "almost" is a unique-up-to-unique-isomorphism mapping from
|
||||
and to this Algebraic Data Type:
|
||||
|
||||
`data ModelParams a = Array a | Node [(String, ModelParams a)]`
|
||||
|
||||
In this correspondence, `a` is some array-type (perhaps
|
||||
`numpy.ndarray`, `jax.numpy.ndarray`, `tf.tensor`, etc.), but the
|
||||
data-processing code is effectively entirely agnostic to this, and a
|
||||
`Node` is "almost" an associative-list of (key, value) pairs
|
||||
representing a Python dict. (Note: The "almost" here is mostly about
|
||||
the conceptual wart that assoc-lists can in principle have key
|
||||
duplicates, but Python dicts cannot. This is, however, not a problem
|
||||
since all we need is the transformation in one direction,
|
||||
i.e. whatever data-processing `f` we want to express on the
|
||||
model-parameters-pytree, we can express by specifying a "faithful"
|
||||
mapping `m` into the above algebraic data type through which every
|
||||
such pytree data transform factorizes, i.e. for every `f` we can find
|
||||
a `g` such that `f(p) = g(m(p))`.)
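
As a concrete (hypothetical) illustration of such a `ModelParams`-style value, with `numpy.ndarray` as the array-type `a`:

```
import numpy

# Each dict level corresponds to a `Node`, each ndarray leaf to an `Array`:
model_params = {
    'embed': {'w': numpy.zeros((8, 4), dtype=numpy.float32)},
    'mlp': {'w_in': numpy.zeros((4, 16), dtype=numpy.float32),
            'w_out': numpy.zeros((16, 4), dtype=numpy.float32)},
}
```
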
## Components
|
||||
|
||||
The main workhorse in this module is the `pytree_leaf_iter` function that
|
||||
maps a "PyTree (such as representing `ModelParams`)" to an iterator
|
||||
over values obtained by applying a mapping-function to the "key-path"
|
||||
and leaf-value for every leaf, where the "key-path" contains a
|
||||
linked-list representation of the reversed sequence of keys from the
|
||||
tree-root, with list-nodes being represented by pairs
|
||||
`(latest_dict_key, rest_path)`, and the empty path being represented
|
||||
by `()`.
|
||||
|
||||
For the sake of genericity, `pytree_leaf_iter` is built in such a way
that it can handle any kind of traversal of pytrees that do
represent algebraic data types (note however that some do not) -
|
||||
but for this to make sense, the user must have a way to define how to
|
||||
interpret tree-nodes, in particular identify leaves. This requires
|
||||
providing a function `node_handler` with the same signature and
|
||||
behavior as described below for "node handlers".
|
||||
|
||||
Additionally, this module provides mapping-over-pytrees via
|
||||
`pytree_map`, which is also built in such a way that it makes the
|
||||
correspondence between an algebraic data type and its Python
|
||||
nested-tree representation explicit. Despite being powerful and
|
||||
flexible, this, however, may in general require a bit more effort to
|
||||
wire up, since node-rebuilding can be fairly nontrivial.
|
||||
|
||||
Furthermore, as a prominent application, this module provides a simple
|
||||
deep-freezing function that translates a nested Python data structure
|
||||
to deeply-immutable form.
|
||||
|
||||
## Concepts and Conventions
|
||||
|
||||
"revtuple representation":
|
||||
|
||||
As we iterate over a tree, we will have to keep track of the
|
||||
path-to-tree-root. Naturally, two sibling nodes `n1` and `n2`
|
||||
will share the same parent-path (being siblings), so it makes
|
||||
sense to use a linked-list-with-shared-tail representation.
|
||||
Python does not have a natural notion for that, so we use
|
||||
recursively-constructed tuples `(node_tag, parent_path)`
|
||||
that represent the path-from-root in-reverse-order, i.e.
|
||||
for a non-empty path `p`, `p[0]` is the node-tag at the
|
||||
deepest nesting level. We call this a "revtuple representation"
|
||||
of the path.
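
As a small worked example (key names hypothetical): the path
root -> 'block' -> 'attn' -> 'w_q' has the revtuple representation

```
('w_q', ('attn', ('block', ())))
```

which is built up step by step as `()`, then `('block', ())`, then
`('attn', ('block', ()))`, and finally the full path; sibling leaves
under 'attn' would share the `('attn', ('block', ()))` tail.
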
"node handler":
|
||||
|
||||
A node-handler classifies a tree-node as "leaf or other node", and
|
||||
for non-leaf nodes provides information about both its children and
|
||||
how to rebuild it. The behavior of a node-handler function must be
|
||||
in alignment with this docstring:
|
||||
|
||||
'''Processes a tree-node as required by pytree-iteration and -mapping.
|
||||
|
||||
Args:
|
||||
revtuple_path: Revtuple-representation of the path-from-root
|
||||
to the current node.
|
||||
node: a tree-node in a ML-model tree that is recursively
|
||||
built out of leaf-values and other nodes.
|
||||
|
||||
Returns:
|
||||
`None` if the tree-node is to be regarded as a leaf, otherwise
|
||||
a pair `(rebuilder, iterator)`, where `iterator` iterates
|
||||
over the data-content of the node, each item represented as a pair
|
||||
of `(lookup_path_item, value_item)`, and `rebuilder` is a function
|
||||
which, when applied to an iterable of the aforementioned value-items
|
||||
(or some transformation thereof) returns a node that is equivalent
|
||||
to the original (or up to a transformation of the contents).
|
||||
|
||||
Raises:
|
||||
InvalidTreeNodeError: If the tree contains a node of a kind
|
||||
that is not expected to show up.
|
||||
'''
|
||||
|
||||
Examples:
|
||||
|
||||
(The behavior of a node-handler is somewhat nontrivial, so covering
|
||||
two very common cases via examples is in order.)
|
||||
|
||||
This node-handler would allow descending into (nested)
|
||||
instances of `list` (but not subclass instances thereof):
|
||||
|
||||
```
def list_node_handler(revtuple_path, obj):
  ''' ... '''
  if type(obj) is list:
    return list, enumerate(obj)
  else:
    return None
```
|
||||
|
||||
This node-handler would allow descending into (nested) mappings,
|
||||
which upon rebuilding would get turned into `dict` instances:
|
||||
|
||||
```
def mapping_node_handler(revtuple_path, obj):
  ''' ... '''
  if isinstance(obj, collections.abc.Mapping):
    # For generic mappings, we cannot rely on key- and item-iteration
    # being guaranteed to use identical iteration-order.
    items = list(obj.items())
    keys = [kv[0] for kv in items]
    return (lambda values: dict(zip(keys, values))), iter(items)
  else:
    return None
```
|
||||
|
||||
A dict/mapping node-handler can of course rename keys, add or remove
|
||||
entries, make decisions based on the item-path, or map a dict to
|
||||
an associative list, etc.
|
||||
|
||||
## Further Design Notes
|
||||
|
||||
The `pytree_map` function requires the leaf-transform and node-handler
to be side-effect-free functions. This both preserves
implementation-side flexibility and follows the general LISP
recommendation to not abuse mapping (which should be a pure
data-transformation) for imperative data processing. Overall, if
a need for more general "nested data structures" processing becomes
pressing, it is better if this leads to a proper articulation
of the specific needs, to be addressed with an appropriate design,
rather than abuse of functional data-transforms becoming "a bad idiom
that turned into established practice".
|
||||
|
||||
"""
|
||||
|
||||
import collections.abc
|
||||
import immutabledict
|
||||
|
||||
import numpy
|
||||
|
||||
from typing import Any, Callable, Iterable, Iterator, TypeVar
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
U = TypeVar('U')
|
||||
|
||||
KT = TypeVar('KT')
|
||||
NT = TypeVar('NT')
|
||||
|
||||
|
||||
## Type of the reverse-order-keys-to-root path.
|
||||
# (This code actually illustrates why https://xkcd.com/2483/ is very misguided.)
|
||||
RevTuplePath = tuple
|
||||
|
||||
## Type of the `leaf_transform` function-argument used for tree-iteration.
|
||||
#
|
||||
# This would be the correct type we would have to specify here but cannot,
|
||||
# since the design of Python's static typing at the time of this writing
|
||||
# is too broken for that:
|
||||
#
|
||||
# type LeafTransformFunc[L, R] = Callable[[RevTuplePath, L], R]
|
||||
#
|
||||
# Instead, we have to settle for...:
|
||||
LeafTransformFunc = Callable[[RevTuplePath, Any], Any]
|
||||
|
||||
|
||||
## Type of the `tree_node_handler` function-argument used for
|
||||
## tree-iteration and tree-mapping.
|
||||
#
|
||||
# Again, this is the correct type we would have to put here but cannot:
|
||||
#
|
||||
# type NodeHandlerFunc[KT] = (
|
||||
# Callable[[NT],
|
||||
# None | tuple[Callable[[Iterable[tuple[KT, NT]]], NT],
|
||||
# Iterator[tuple[KT, NT]]]])
|
||||
#
|
||||
# ...so, we have to instead settle for:
|
||||
NodeHandlerFunc = (
|
||||
Callable[[RevTuplePath, NT],
|
||||
None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT],
|
||||
Iterator[tuple[Any, NT]]]])
|
||||
|
||||
|
||||
Predicate = Callable[[object], bool]
|
||||
|
||||
|
||||
class InvalidTreeNodeError(ValueError):
|
||||
"""Encountered a tree-node of invalid type."""
|
||||
|
||||
|
||||
def linearize_revtuple_path(
|
||||
revtuple_path: RevTuplePath,
|
||||
present_as: Callable[[Iterator[T]], U] = tuple) -> U:
|
||||
"""Translates a revtuple path to (typically) linear form.
|
||||
|
||||
With default `present_as`, this will map a path of the form
|
||||
`(key_{N}, (key_{N-1}, ..., (root, ())))` into a tuple
|
||||
(root, ..., key_{N-1}, key_{N}).
|
||||
|
||||
Args:
|
||||
revtuple_path: A linked-list-as-recursive-pairs
|
||||
reverse-order tuple-representation of the path.
|
||||
Path-root is `()`, and node-key `x` relative to
|
||||
earlier path `p` is represented as `(x, p)`.
|
||||
present_as: Callable that consumes an iterator over
|
||||
path-pieces - with the deepest-nesting level coming last -
|
||||
turning it into a linearized path. Defaults to `tuple`.
|
||||
|
||||
Returns:
|
||||
Linearized presentation of all the node-keys in the
|
||||
recursive-path in order, deepest-down path component coming last.
|
||||
"""
|
||||
pieces = []
|
||||
todo = revtuple_path
|
||||
while todo:
|
||||
node, todo = todo
|
||||
pieces.append(node)
|
||||
return present_as(reversed(pieces))
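
A minimal usage sketch (the concrete path values are arbitrary, mirroring the unit tests):

```
# Revtuple path for root -> 10 -> 20 -> 30:
path = (30, (20, (10, ())))

assert linearize_revtuple_path(path) == (10, 20, 30)
assert linearize_revtuple_path(path, present_as=list) == [10, 20, 30]
assert linearize_revtuple_path(()) == ()
```
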
# This function itself has type `NodeHandlerFunc`, but Python does not
|
||||
# allow us to here simply type-annotate it like this. We cannot even
|
||||
# introduce an abbreviation for the complicated output-type,
|
||||
# since that would have to be parametric in node-type `NT` (and `KT`).
|
||||
def everything_is_a_leaf_node_handler(
|
||||
revtuple_path: tuple,
|
||||
node : NT) -> (
|
||||
None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT],
|
||||
Iterator[tuple[Any, NT]]]):
|
||||
"""Processes a tree-node as required by pytree-iteration and -mapping.
|
||||
|
||||
Interface and signature are in alignment with the requirements for a
|
||||
"node handler" function explained in the module-docstring.
|
||||
|
||||
Args:
|
||||
revtuple_path: the path-to-root for this node.
|
||||
node: a tree-node.
|
||||
|
||||
Returns:
|
||||
`None`, i.e. classifying any kind of node as a leaf-node.
|
||||
"""
|
||||
del revtuple_path, node # Unused.
|
||||
return None
|
||||
|
||||
|
||||
def leaf_summary(path: RevTuplePath, x: object):
|
||||
"""Produces a human-readable summary-string for a leaf-node.
|
||||
|
||||
Args:
|
||||
path: revtuple representation of the path-to-root.
|
||||
x: The leaf-value.
|
||||
"""
|
||||
del path # Ignored here.
|
||||
tx = type(x)
|
||||
mod = tx.__module__
|
||||
modname = mod if isinstance(mod, str) else mod.__name__
|
||||
type_str = f'{modname}.{tx.__qualname__}'
|
||||
repr_str = repr(x)
|
||||
repr_abbrev = repr_str if len(repr_str) < 40 else repr_str[:40] + ' ...'
|
||||
# On str, int, float, etc. `{type_str}(repr(x))` would actually still be
|
||||
# a (non-literal) Python-expression that would evaluate to the original value.
|
||||
# However, we make no promises beyond "human-readable".
|
||||
return f'{type_str}({repr_abbrev})'
|
||||
|
||||
|
||||
# With respect to static type annotations, the limitations of Python's
|
||||
# approach to static typing really become prominently visible here.
|
||||
#
|
||||
# Different arguments have type-parameters, but since there is no way
|
||||
# to have parametric abbreviations such as `LeafTransformFunc[L, R]`,
|
||||
# the only way we would have available to express relations between
|
||||
# type-parameters would be to substitute in the not-abbreviated form of
|
||||
# `NodeHandlerFunc` and `LeafTransformFunc`, giving us something monstrous.
|
||||
# We instead settle for "we cannot express that `tree` must
# have the same type as the input-type to `tree_node_handler`", use `Any`,
# and likewise for `leaf_transform` and the output.
|
||||
def pytree_leaf_iter(
|
||||
tree: Any,
|
||||
leaf_transform: LeafTransformFunc,
|
||||
node_handler: NodeHandlerFunc = everything_is_a_leaf_node_handler,
|
||||
) -> Iterator[Any]:
|
||||
# ...actual return type would be `Iterator[{what leaf_transform returns}]`.
|
||||
"""Iterates over the leaves of a tree.
|
||||
|
||||
Args:
|
||||
tree: The tree to iterate over.
|
||||
leaf_transform: A callable `f` that will get applied
|
||||
as `f(revtuple_path, leaf)`, where `revtuple_path`
|
||||
is the revtuple representation of the path to the
|
||||
leaf from the root.
|
||||
node_handler: A "node handler" (see module docstring)
|
||||
that processes nodes encountered during iterative traversal.
|
||||
|
||||
Yields:
|
||||
Value of `leaf_transform(p, x)`, where `x` is the current leaf
|
||||
and `p` is its revtuple-path to the root.
|
||||
"""
|
||||
# Note: Exit points for the code below are in non-obvious places
|
||||
# and hence marked with " # ***EXIT***".
|
||||
#
|
||||
# Doing iteration properly is slightly nontrivial.
|
||||
# One may be tempted to go for a very simple recursive implementation
|
||||
# (with an extra pre-final `path` argument to `pytree_leaf_iter`):
|
||||
#
|
||||
# maybe_substructure = node_handler(path, tree)
|
||||
# if maybe_substructure is None:
|
||||
# # We are looking at a leaf-node.
|
||||
# yield leaf_transform(path, tree)
|
||||
# else:
|
||||
# _, contents_iter = maybe_substructure
|
||||
# for k, v in contents_iter:
|
||||
# yield from pytree_leaf_iter(v, leaf_transform, (k, path), node_handler)
|
||||
#
|
||||
# That, however, would be flawed, since there is no a priori reason
|
||||
# why a pytree may not be a very deeply nested structure - such as a
|
||||
# long linked list. That would then risk raising `RecursionError`,
|
||||
# and since Python by design(!) does not perform tail call elimination
|
||||
# or any other kind of advanced CPS transforms, there is no recursive
|
||||
# solution here. So, to do this properly, we have to do this iteratively.
|
||||
#
|
||||
# We are facing an annoying situation here: If `tree` itself is a leaf,
|
||||
# we have two options: (a) wrapping it up in a one-node tree
|
||||
# and processing that, or (b) special-casing "root is a leaf".
|
||||
# Option (b) leads to some mild node-processing code-duplication
|
||||
# for a single node (the root).
|
||||
# Option (a) requires having special cases for node-processing that
|
||||
# get looked at for every tree node. We go with option (b) here.
|
||||
maybe_substructure = node_handler((), tree)
|
||||
if maybe_substructure is None:
|
||||
# The tree itself is a leaf.
|
||||
yield leaf_transform((), tree)
|
||||
return # ***EXIT***
|
||||
# Otherwise, we are looking at a tree.
|
||||
_, contents_iter = maybe_substructure
|
||||
current_revtuple_path = ()
|
||||
work_to_do = [contents_iter]
|
||||
# Otherwise-unreachable sentinel for reliably identifying
|
||||
# iterator-exhaustion without using exceptions:
|
||||
sentinel = object()
|
||||
while True:
|
||||
current_iter = work_to_do[-1]
|
||||
maybe_next_item = next(current_iter, sentinel)
|
||||
if maybe_next_item is sentinel:
|
||||
# We are done at this level.
|
||||
work_to_do.pop()
|
||||
if not work_to_do: return # ***EXIT***
|
||||
current_revtuple_path = current_revtuple_path[1]
|
||||
else:
|
||||
path_piece, subtree = maybe_next_item
|
||||
extended_revtuple_path = (path_piece, current_revtuple_path)
|
||||
maybe_subtree_substructure = node_handler(extended_revtuple_path, subtree)
|
||||
if maybe_subtree_substructure is None: # Case: subtree is a leaf.
|
||||
yield leaf_transform(extended_revtuple_path, subtree)
|
||||
else: # Case: subtree is a tree.
|
||||
current_revtuple_path = (path_piece, current_revtuple_path)
|
||||
_, items_iter = maybe_subtree_substructure
|
||||
work_to_do.append(items_iter)
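
A usage sketch with a dict-descending node handler; the handler below is an illustrative assumption about a typical caller, not part of this module (the unit tests use an equivalent one):

```
def dict_node_handler(revtuple_path, node):
  """Descends into plain dicts; everything else is treated as a leaf."""
  del revtuple_path  # Unused.
  if isinstance(node, dict):
    keys = list(node.keys())
    return (lambda values: dict(zip(keys, values))), iter(node.items())
  return None


leaves = list(pytree_leaf_iter(
    {'a': {'b': 1, 'c': 2}, 'd': 3},
    leaf_transform=lambda path, leaf: (linearize_revtuple_path(path), leaf),
    node_handler=dict_node_handler))
assert leaves == [(('a', 'b'), 1), (('a', 'c'), 2), (('d',), 3)]
```
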
# The current design approach here would be appropriate for
|
||||
# applying leaf-transforms while retaining the structure of the tree -
|
||||
# which closely corresponds to e.g. a (a -> b) -> (Tree a -> Tree b) functor.
|
||||
#
|
||||
# It is not entirely clear whether this is the abstraction that we should
|
||||
# consider as being appropriately generic to flesh out explicitly - rather
|
||||
# than starting from a more general approach of which this then is a special
|
||||
# case. Some background: https://ncatlab.org/nlab/show/recursion+scheme
|
||||
#
|
||||
# On the other hand, there is a lot of flexibility via whatever
|
||||
# node-rebuilder a node-handler produces - this can do quite some reshaping
|
||||
# of a tree, including dropping or duplicating nodes.
|
||||
def pytree_map(
|
||||
tree: Any,
|
||||
leaf_transform,
|
||||
node_handler: NodeHandlerFunc = everything_is_a_leaf_node_handler,
|
||||
):
|
||||
"""Maps a (potentially nested) Python value to another such value.
|
||||
|
||||
Args:
|
||||
tree: The Python-object to be mapped.
|
||||
leaf_transform: A callable `f` that will get applied
|
||||
as `f(revtuple_path, leaf)`, where `revtuple_path`
|
||||
is the revtuple representation of the path to the
|
||||
leaf from the root. Must be side effect free.
|
||||
node_handler: A "node handler" (see module docstring)
|
||||
that processes nodes encountered during iterative traversal.
|
||||
Must be side effect free.
|
||||
|
||||
Returns:
|
||||
The outcome of translating `tree`.
|
||||
"""
|
||||
# Note: Exit points for the code below are in non-obvious places
|
||||
# and hence marked with " # ***EXIT***".
|
||||
#
|
||||
# Otherwise-inaccessible sentinel object, for reliably identifying
|
||||
# missing-values via identity-check against sentinel lookup-default.
|
||||
sentinel = object()
|
||||
# Code structure mostly follows pytree_leaf_iter.
|
||||
maybe_substructure = node_handler((), tree)
|
||||
if maybe_substructure is None:
|
||||
return leaf_transform((), tree) # ***EXIT***
|
||||
rebuilder, items_iter = maybe_substructure
|
||||
current_revtuple_path = ()
|
||||
# Per-level, we have a triplet of:
|
||||
# (rebuilder, remaining_items_to_iterate_over, processed).
|
||||
parts_for_assembly = [(rebuilder, items_iter, [])]
|
||||
while True:
|
||||
this_rebuilder, this_items_iter, this_done_pieces = parts_for_assembly[-1]
|
||||
maybe_next_item = next(this_items_iter, sentinel)
|
||||
if maybe_next_item is sentinel:
|
||||
# We are done with all the items for this level.
|
||||
parts_for_assembly.pop()
|
||||
built_iter = this_rebuilder(this_done_pieces)
|
||||
if not parts_for_assembly: # No outer structure, so at-top-level.
|
||||
return built_iter # ***EXIT***
|
||||
else: # We have outer structure.
|
||||
parts_for_assembly[-1][-1].append(built_iter)
|
||||
current_revtuple_path = current_revtuple_path[1]
|
||||
continue # ...with next is-the-final-item-complete-check.
|
||||
else:
|
||||
# More constituents of the current item.
|
||||
path_piece, subtree_item = maybe_next_item
|
||||
extended_revtuple_path = (path_piece, current_revtuple_path)
|
||||
maybe_subtree_substructure = node_handler(
|
||||
extended_revtuple_path,
|
||||
subtree_item)
|
||||
if maybe_subtree_substructure is None:
|
||||
this_done_pieces.append(
|
||||
leaf_transform(extended_revtuple_path, subtree_item))
|
||||
else:
|
||||
# We have a subtree.
|
||||
subtree_rebuilder, subtree_items_iter = maybe_subtree_substructure
|
||||
current_revtuple_path = (path_piece,
|
||||
current_revtuple_path)
|
||||
parts_for_assembly.append(
|
||||
(subtree_rebuilder, subtree_items_iter, []))
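
A corresponding `pytree_map` sketch, reusing the hypothetical `dict_node_handler` from the `pytree_leaf_iter` example above:

```
doubled = pytree_map(
    {'a': {'b': 1, 'c': 2}, 'd': 3},
    leaf_transform=lambda path, leaf: 2 * leaf,
    node_handler=dict_node_handler)
assert doubled == {'a': {'b': 2, 'c': 4}, 'd': 6}
```
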
def deep_freeze(
|
||||
tree,
|
||||
*,
|
||||
is_mapping : Predicate = lambda x: isinstance(x, collections.abc.Mapping),
|
||||
is_set : Predicate = lambda x: isinstance(x, collections.abc.Set),
|
||||
is_sequence : Predicate = lambda x: isinstance(x, (list, tuple)),
|
||||
leaf_fn: Callable[[Any], Any] = lambda x: x,
|
||||
):
|
||||
"""Recursively freezes Set/Mapping/List/Tuple structures.
|
||||
|
||||
Args:
|
||||
tree: The potentially deeply-nested object to deep-freeze.
|
||||
is_mapping: Callable that decides whether a sub-object is a mapping.
|
||||
Defaults to an `isinstance()` check for `collections.abc.Mapping`.
|
||||
is_set: Callable that decides whether a sub-object is a set.
|
||||
Defaults to an `isinstance()` check for `collections.abc.Set`.
|
||||
is_sequence: Callable that decides whether a sub-object is a sequence.
|
||||
Defaults to a check for being a `tuple` or `list` instance.
|
||||
leaf_fn: Function to use for translating non-mapping/set/sequence
|
||||
instances.
|
||||
|
||||
Returns:
|
||||
Translated-to-deeply-immutable form of `tree`.
|
||||
"""
|
||||
idict = immutabledict.immutabledict
|
||||
def freeze_node_handler(path, x):
|
||||
if is_set(x):
|
||||
return frozenset, ((None, y) for y in x)
|
||||
if is_mapping(x):
|
||||
# Mapping keys are already hashable, and hence
# (should-be) deeply-immutable.
# Hence, we only need to deep-freeze the values.
|
||||
#
|
||||
# Note that non-`dict` mappings might not guarantee
|
||||
# to respect iteration-order, so we have to be careful here:
|
||||
items = list(x.items())
|
||||
keys = [kv[0] for kv in items]
|
||||
values = [kv[1] for kv in items]
|
||||
return ((lambda ys: idict(zip(keys, ys))),
|
||||
iter(items))
|
||||
if is_sequence(x):
|
||||
return tuple, enumerate(iter(x))
|
||||
# Otherwise, this should not be traversed.
|
||||
return None
|
||||
def leaf_transform(revtuple_path, value):
|
||||
del revtuple_path # Unused.
|
||||
return leaf_fn(value)
|
||||
return pytree_map(tree, leaf_transform, freeze_node_handler)
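
A minimal usage sketch (mirroring the unit test):

```
import collections.abc

frozen = deep_freeze({'a': [1, 2, {'b': (3, {4, 5})}]})

# Mappings become immutabledicts, lists/tuples become tuples and sets become
# frozensets, so the whole value ends up hashable.
assert isinstance(frozen, collections.abc.Mapping)
assert not isinstance(frozen, collections.abc.MutableMapping)
assert frozen['a'][:2] == (1, 2)
assert isinstance(frozen['a'][2]['b'][1], frozenset)
assert isinstance(hash(frozen), int)
```
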
@ -1,168 +0,0 @@
|
|||
"""Basic tests for 'algebraic data type based pytree' transformations."""
|
||||
|
||||
|
||||
import collections.abc
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import pytree_transforms
|
||||
|
||||
|
||||
def _get_deep_pytree(packaging_fn, bottom, depth):
|
||||
current = bottom
|
||||
for n in reversed(range(depth)):
|
||||
current = packaging_fn(n, current)
|
||||
return current
|
||||
|
||||
|
||||
def _dict_node_handler(p, d):
|
||||
del p # Unused.
|
||||
if isinstance(d, dict):
|
||||
keys = d.keys()
|
||||
newdict = lambda vals: dict(zip(keys, vals))
|
||||
return (newdict, iter(d.items()))
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class PyTreeTest(unittest.TestCase):
|
||||
"""Basic correctness validation tests for PyTree transformations."""
|
||||
|
||||
def test_linearize_revtuple_path(self):
|
||||
"""Tests guarantees given by `linearize_revtuple_path`."""
|
||||
linearize_revtuple_path = pytree_transforms.linearize_revtuple_path
|
||||
with self.subTest(guarantee='empty'):
|
||||
self.assertEqual(linearize_revtuple_path(()), ())
|
||||
with self.subTest(guarantee='typical'):
|
||||
self.assertEqual(linearize_revtuple_path((30, (20, (10, ())))),
|
||||
(10, 20, 30))
|
||||
with self.subTest(guarantee='present_as'):
|
||||
self.assertEqual(
|
||||
linearize_revtuple_path(
|
||||
(30, (20, (10, ()))), present_as=list),
|
||||
[10, 20, 30])
|
||||
|
||||
def test_everything_is_a_leaf_node_handler(self):
|
||||
"""Tests guarantees given by `everything_is_a_leaf_node_handler`."""
|
||||
everything_is_a_leaf_node_handler = (
|
||||
pytree_transforms.everything_is_a_leaf_node_handler)
|
||||
self.assertEqual(everything_is_a_leaf_node_handler((), 'abc'),
|
||||
None)
|
||||
self.assertEqual(everything_is_a_leaf_node_handler(('b', ()),
|
||||
dict(a=3)),
|
||||
None)
|
||||
|
||||
def test_leaf_summary(self):
|
||||
"""Tests guarantees given by `leaf_summary`."""
|
||||
# Since the docstring only guarantees "a human-readable presentation",
|
||||
# we can and should only do loose checks.
|
||||
thing = (5678, 9531)
|
||||
summary = pytree_transforms.leaf_summary(('key', ()), thing)
|
||||
self.assertIsInstance(summary, str)
|
||||
self.assertIn(str(thing[0]), summary)
|
||||
self.assertIn(str(thing[1]), summary)
|
||||
|
||||
def test_pytree_leaf_iter(self):
|
||||
"""Tests guarantees given by `pytree_leaf_iter`."""
|
||||
pytree_leaf_iter = pytree_transforms.pytree_leaf_iter
|
||||
def leaf_transform(path, leaf):
|
||||
return repr(leaf) if path and path[0].startswith('R') else leaf
|
||||
with self.subTest(guarantee='returns_iterator'):
|
||||
result = pytree_leaf_iter(7, leaf_transform, _dict_node_handler)
|
||||
self.assertIsInstance(result, collections.abc.Iterator)
|
||||
with self.subTest(guarantee='totally_empty'):
|
||||
result = list(pytree_leaf_iter({}, leaf_transform, _dict_node_handler))
|
||||
self.assertEqual(result, [])
|
||||
with self.subTest(guarantee='no_leaves'):
|
||||
result = list(pytree_leaf_iter(dict(a={}),
|
||||
leaf_transform, _dict_node_handler))
|
||||
self.assertEqual(result, [])
|
||||
with self.subTest(guarantee='is_leaf'):
|
||||
result = list(pytree_leaf_iter(777, leaf_transform, _dict_node_handler))
|
||||
self.assertEqual(result, [777])
|
||||
with self.subTest(guarantee='generic'):
|
||||
result = list(pytree_leaf_iter(
|
||||
dict(n0=dict(n01=dict(n012=1002,
|
||||
n013=1003,
|
||||
Rn014=1004,
|
||||
),
|
||||
n02=1005),
|
||||
n5=1006),
|
||||
leaf_transform, _dict_node_handler))
|
||||
self.assertEqual(result, [1002, 1003, '1004', 1005, 1006])
|
||||
with self.subTest(guarantee='with_keys'):
|
||||
result = list(pytree_leaf_iter(
|
||||
dict(n0=dict(n01=dict(n012=1002,
|
||||
n013=1003)),
|
||||
n1=1004),
|
||||
lambda p, s: (pytree_transforms.linearize_revtuple_path(p), s),
|
||||
_dict_node_handler))
|
||||
self.assertEqual(result,
|
||||
[(('n0', 'n01', 'n012'), 1002),
|
||||
(('n0', 'n01', 'n013'), 1003),
|
||||
(('n1',), 1004)])
|
||||
|
||||
def test_pytree_map(self):
|
||||
"""Tests guarantees given by `pytree_map`."""
|
||||
pytree_map = pytree_transforms.pytree_map
|
||||
leaf_transform = lambda p, s: repr(s)
|
||||
tree1 = dict(t0=dict(t10=1001,
|
||||
t11=dict(t110=1002,
|
||||
t111=1003),
|
||||
t12=dict(t120=1004,
|
||||
t121=1005,
|
||||
t122=1006)),
|
||||
t1=1007)
|
||||
with self.subTest(guarantee='no_leaves'):
|
||||
result = pytree_map(dict(a={}),
|
||||
leaf_transform,
|
||||
_dict_node_handler)
|
||||
self.assertEqual(result, dict(a={}))
|
||||
with self.subTest(guarantee='is_leaf'):
|
||||
result = pytree_map(777, leaf_transform, _dict_node_handler)
|
||||
self.assertEqual(result, '777')
|
||||
with self.subTest(guarantee='generic'):
|
||||
result = pytree_map(tree1, leaf_transform, _dict_node_handler)
|
||||
self.assertEqual(result['t0']['t10'], '1001')
|
||||
|
||||
def test_deeply_nested(self):
|
||||
"""Tests correct behavior on deeply-nested data structures."""
|
||||
pytree_leaf_iter = pytree_transforms.pytree_leaf_iter
|
||||
pytree_map = pytree_transforms.pytree_map
|
||||
#
|
||||
depth = max(10**5, sys.getrecursionlimit() + 100)
|
||||
deep_tree = _get_deep_pytree(lambda n, t: {n: t},
|
||||
'leaf', depth)
|
||||
with self.subTest(function='pytree_leaf_iter'):
|
||||
leaves = list(pytree_leaf_iter(deep_tree,
|
||||
lambda p, s: s.upper(),
|
||||
_dict_node_handler))
|
||||
self.assertEqual(leaves, ['LEAF'])
|
||||
with self.subTest(function='pytree_map'):
|
||||
mapped_deep_tree = pytree_map(deep_tree,
|
||||
lambda p, s: s,
|
||||
_dict_node_handler)
|
||||
self.assertIsInstance(mapped_deep_tree, dict)
|
||||
with self.subTest(function='combined'):
|
||||
leaves = list(
|
||||
pytree_leaf_iter(
|
||||
pytree_map(deep_tree,
|
||||
lambda p, s: s.capitalize(),
|
||||
_dict_node_handler),
|
||||
lambda p, s: s + s,
|
||||
_dict_node_handler))
|
||||
self.assertEqual(leaves, ['LeafLeaf'])
|
||||
|
||||
def test_deep_freeze(self):
|
||||
"""Tests guarantees given by `deep_freeze`."""
|
||||
frozen = pytree_transforms.deep_freeze(
|
||||
dict(a=[1001, 1002, dict(b=(1003, [1004, {1005, 1006}]))]))
|
||||
self.assertIsInstance(frozen, collections.abc.Mapping)
|
||||
self.assertNotIsInstance(frozen, collections.abc.MutableMapping)
|
||||
self.assertIsInstance(frozen['a'], tuple)
|
||||
# `frozen` is hashable, and hashes to an integer.
|
||||
self.assertIsInstance(hash(frozen), int)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
immutabledict>=4.2.0
|
||||
numpy>=1.26.4
|
||||
orbax-checkpoint>=0.0.0
|
||||
|
||||
|
|
@ -158,9 +158,6 @@ TEST_F(GemmaTest, CrossEntropySmall) {
|
|||
float entropy = s_env->CrossEntropy(kSmall);
|
||||
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
||||
switch (config.model) {
|
||||
case gcpp::Model::GRIFFIN_2B:
|
||||
EXPECT_NEAR(entropy, 2.61f, 0.02f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA2_2B:
|
||||
EXPECT_NEAR(entropy, 1.14f, 0.02f);
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -31,32 +31,6 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
struct GriffinActivations {
|
||||
GriffinActivations(const ModelConfig& config, size_t batch_size,
|
||||
const Allocator& allocator)
|
||||
: griffin_x(
|
||||
MatFactory("griffin_x", batch_size, config.model_dim, allocator)),
|
||||
griffin_y(
|
||||
MatFactory("griffin_y", batch_size, config.model_dim, allocator)),
|
||||
griffin_gate_x(MatFactory("griffin_gate_x", batch_size,
|
||||
config.model_dim, allocator)),
|
||||
griffin_multiplier(MatFactory("griffin_mul", batch_size,
|
||||
config.model_dim, allocator)) {}
|
||||
|
||||
void SetBatchSize(size_t batch_size) {
|
||||
if (griffin_x.Rows() == 0) return;
|
||||
griffin_x.OverrideRows(batch_size);
|
||||
griffin_y.OverrideRows(batch_size);
|
||||
griffin_gate_x.OverrideRows(batch_size);
|
||||
griffin_multiplier.OverrideRows(batch_size);
|
||||
}
|
||||
|
||||
MatStorageT<float> griffin_x;
|
||||
MatStorageT<float> griffin_y;
|
||||
MatStorageT<float> griffin_gate_x;
|
||||
MatStorageT<float> griffin_multiplier;
|
||||
};
|
||||
|
||||
struct AttentionActivations {
|
||||
// Returns the scale value to use for the query in the attention computation.
|
||||
// Also called by ops_test.
|
||||
|
|
@ -143,7 +117,7 @@ struct AttentionActivations {
|
|||
MatStorageT<float> inv_timescale_global;
|
||||
|
||||
hwy::Divisor div_seq_len;
|
||||
// Unfortunately, some models (Griffin) have non-power-of-two heads.
|
||||
// Unfortunately, some models have had non-power-of-two heads.
|
||||
hwy::Divisor div_heads;
|
||||
float query_scale;
|
||||
};
|
||||
|
|
@ -169,9 +143,7 @@ struct Activations {
|
|||
MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)),
|
||||
|
||||
attention(config, layer_config, batch_size, seq_len, ctx.allocator,
|
||||
row_ptrs),
|
||||
griffin(config, config.model == Model::GRIFFIN_2B ? batch_size : 0,
|
||||
ctx.allocator) {
|
||||
row_ptrs) {
|
||||
HWY_ASSERT(batch_size != 0);
|
||||
|
||||
// For MatMul outputs, precompute their row pointers.
|
||||
|
|
@ -199,7 +171,6 @@ struct Activations {
|
|||
ffw_out.OverrideRows(batch_size);
|
||||
|
||||
attention.SetBatchSize(batch_size);
|
||||
griffin.SetBatchSize(batch_size);
|
||||
}
|
||||
|
||||
const LayerConfig& layer_config;
|
||||
|
|
@ -215,7 +186,6 @@ struct Activations {
|
|||
MatStorageT<float> ffw_out;
|
||||
|
||||
AttentionActivations attention;
|
||||
GriffinActivations griffin;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -327,6 +327,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
|
|||
MatMulEnv& env) {
|
||||
PROFILER_ZONE("Gen.Attention.SumHeads");
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
(void)layer_config; // For HWY_DASSERT
|
||||
// att_weights and att_out are concatenated heads, each of length
|
||||
// layer_config.qkv_dim. Thus the [num_interleaved,
|
||||
// layer_config.model_dim] matmul output is the sum over heads. Compare
|
||||
|
|
@ -334,10 +335,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
|
|||
// encoded)
|
||||
HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 &&
|
||||
layer_config.qkv_dim != 0);
|
||||
const float* add = layer_config.softmax_attn_output_biases
|
||||
? layer.attention_output_biases.PackedScale1()
|
||||
: nullptr;
|
||||
CallMatMul(activations.att_out, layer.att_weights, add, env,
|
||||
CallMatMul(activations.att_out, layer.att_weights, /*add=*/nullptr, env,
|
||||
activations.att_sums);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -133,78 +133,6 @@ static ModelConfig ConfigGemma2_2B() {
|
|||
return config;
|
||||
}
|
||||
|
||||
static LayerConfig LayerConfigGemmaTiny(size_t model_dim) {
|
||||
LayerConfig config;
|
||||
config.model_dim = model_dim;
|
||||
config.ff_hidden_dim = 256;
|
||||
config.heads = 4;
|
||||
config.kv_heads = 1;
|
||||
config.qkv_dim = 16;
|
||||
return config;
|
||||
}
|
||||
|
||||
static ModelConfig ConfigGemmaTiny() {
|
||||
ModelConfig config = ConfigNoSSM();
|
||||
config.display_name = "GemmaTiny";
|
||||
config.model = Model::GEMMA_TINY;
|
||||
config.wrapping = PromptWrapping::GEMMA_IT;
|
||||
config.model_dim = 32;
|
||||
config.vocab_size = 32; // at least two f32 vectors
|
||||
config.max_seq_len = 32;
|
||||
LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim);
|
||||
config.num_layers = 2;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
|
||||
config.att_cap = 50.0f;
|
||||
config.final_cap = 30.0f;
|
||||
config.eos_id = 11;
|
||||
config.secondary_eos_id = 11;
|
||||
return config;
|
||||
}
|
||||
|
||||
static LayerConfig LayerConfigGriffin2B(size_t model_dim) {
|
||||
LayerConfig config;
|
||||
config.model_dim = model_dim;
|
||||
config.griffin_dim = model_dim;
|
||||
config.ff_hidden_dim = 7680;
|
||||
config.heads = 10;
|
||||
config.kv_heads = 1;
|
||||
config.qkv_dim = 256;
|
||||
config.conv1d_width = 4;
|
||||
HWY_DASSERT(config.conv1d_width <= kMaxConv1DWidth);
|
||||
config.ff_biases = true;
|
||||
config.softmax_attn_output_biases = true;
|
||||
config.optimized_gating = false;
|
||||
config.type = LayerAttentionType::kGriffinRecurrentBlock;
|
||||
config.activation = ActivationType::Gelu;
|
||||
config.post_qk = PostQKType::HalfRope;
|
||||
return config;
|
||||
}
|
||||
|
||||
static ModelConfig ConfigGriffin2B() {
|
||||
ModelConfig config = ConfigNoSSM();
|
||||
config.display_name = "Griffin2B";
|
||||
config.model = Model::GRIFFIN_2B;
|
||||
// Griffin uses local attention, so max_seq_len is actually the local
|
||||
// attention window.
|
||||
config.model_dim = 2560;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.max_seq_len = 2048;
|
||||
LayerConfig layer_config = LayerConfigGriffin2B(config.model_dim);
|
||||
config.num_layers = 26;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
for (size_t i = 2; i < config.num_layers; i += 3) {
|
||||
config.layer_configs[i].type = LayerAttentionType::kGemma;
|
||||
config.layer_configs[i].griffin_dim = 0;
|
||||
}
|
||||
config.attention_window_sizes =
|
||||
FixedAttentionWindowSizes<26>(config.max_seq_len);
|
||||
config.use_local_attention = true;
|
||||
config.final_cap = 0.0f;
|
||||
return config;
|
||||
}
|
||||
|
||||
static LayerConfig LayerConfigVit(size_t model_dim) {
|
||||
LayerConfig config;
|
||||
config.model_dim = model_dim;
|
||||
|
|
@ -510,10 +438,6 @@ static ModelConfig ConfigFromModel(Model model) {
|
|||
return ConfigGemma2_9B();
|
||||
case Model::GEMMA2_27B:
|
||||
return ConfigGemma2_27B();
|
||||
case Model::GRIFFIN_2B:
|
||||
return ConfigGriffin2B();
|
||||
case Model::GEMMA_TINY:
|
||||
return ConfigGemmaTiny();
|
||||
case Model::PALIGEMMA2_3B_224:
|
||||
return ConfigPaliGemma2_3B_224();
|
||||
case Model::PALIGEMMA2_3B_448:
|
||||
|
|
@ -547,10 +471,6 @@ const char* ModelPrefix(Model model) {
|
|||
return "9b";
|
||||
case Model::GEMMA2_27B:
|
||||
return "27b";
|
||||
case Model::GRIFFIN_2B:
|
||||
return "gr2b";
|
||||
case Model::GEMMA_TINY:
|
||||
return "tiny";
|
||||
case Model::PALIGEMMA2_3B_224:
|
||||
return "paligemma2-3b-224";
|
||||
case Model::PALIGEMMA2_3B_448:
|
||||
|
|
@ -750,13 +670,10 @@ bool ModelConfig::OverwriteWithCanonical() {
|
|||
|
||||
Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
|
||||
switch (layers) {
|
||||
case 2:
|
||||
return Model::GEMMA_TINY;
|
||||
case 18:
|
||||
return Model::GEMMA3_270M;
|
||||
|
||||
case 26:
|
||||
if (layer_types & kDeducedGriffin) return Model::GRIFFIN_2B;
|
||||
if (layer_types & kDeducedViT) return Model::GEMMA3_1B;
|
||||
return Model::GEMMA2_2B;
|
||||
case 27:
|
||||
|
|
|
|||
|
|
@ -68,14 +68,11 @@ static inline bool EnumValid(PromptWrapping wrapping) {
|
|||
|
||||
enum class LayerAttentionType {
|
||||
kGemma,
|
||||
kGriffinRecurrentBlock,
|
||||
kVit,
|
||||
};
|
||||
|
||||
static inline bool EnumValid(LayerAttentionType type) {
|
||||
return type == LayerAttentionType::kGemma ||
|
||||
type == LayerAttentionType::kGriffinRecurrentBlock ||
|
||||
type == LayerAttentionType::kVit;
|
||||
return type == LayerAttentionType::kGemma || type == LayerAttentionType::kVit;
|
||||
}
|
||||
|
||||
// Post attention and ffw normalization type.
|
||||
|
|
@ -163,9 +160,8 @@ enum class Model {
|
|||
// 1 and 2 are obsolete.
|
||||
GEMMA2_9B = 3,
|
||||
GEMMA2_27B,
|
||||
GRIFFIN_2B,
|
||||
GEMMA_TINY, // for testing only
|
||||
GEMMA2_2B,
|
||||
// 5 and 6 are obsolete.
|
||||
GEMMA2_2B = 7,
|
||||
// 8 and 9 are obsolete.
|
||||
PALIGEMMA2_3B_224 = 10,
|
||||
PALIGEMMA2_3B_448,
|
||||
|
|
@ -199,13 +195,19 @@ static inline bool IsPaliGemma(Model model) {
|
|||
return false;
|
||||
}
|
||||
|
||||
static inline bool IsObsolete(Model model) {
|
||||
const size_t i = static_cast<size_t>(model);
|
||||
if (i == 5 || i == 6 || i == 8 || i == 9) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Visits every valid model enum, skipping `UNKNOWN` and `kSentinel`.
|
||||
template <class Func>
|
||||
void ForEachModel(const Func& func) {
|
||||
for (size_t i = static_cast<size_t>(Model::GEMMA2_9B);
|
||||
i < static_cast<size_t>(Model::kSentinel); ++i) {
|
||||
if (i == 8 || i == 9) continue;
|
||||
func(static_cast<Model>(i));
|
||||
const Model model = static_cast<Model>(i);
|
||||
if (!IsObsolete(model)) func(model);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -214,7 +216,7 @@ static inline bool EnumValid(Model model) {
|
|||
if (model == Model::UNKNOWN) return true;
|
||||
const size_t i = static_cast<size_t>(model);
|
||||
if (i >= static_cast<size_t>(Model::GEMMA2_9B) &&
|
||||
i < static_cast<size_t>(Model::kSentinel) && i != 8 && i != 9) {
|
||||
i < static_cast<size_t>(Model::kSentinel) && !IsObsolete(model)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
|
@ -235,15 +237,20 @@ struct LayerConfig : public IFields {
|
|||
|
||||
// Source of truth for field ordering.
|
||||
void VisitFields(IFieldsVisitor& visitor) override {
|
||||
// Formerly used for Griffin.
|
||||
uint32_t unused_griffin_dim = 0;
|
||||
uint32_t unused_conv1d_width = 0;
|
||||
bool unused_softmax_attn_output_biases = false;
|
||||
|
||||
visitor(model_dim);
|
||||
visitor(griffin_dim);
|
||||
visitor(unused_griffin_dim);
|
||||
visitor(ff_hidden_dim);
|
||||
visitor(heads);
|
||||
visitor(kv_heads);
|
||||
visitor(qkv_dim);
|
||||
visitor(conv1d_width);
|
||||
visitor(unused_conv1d_width);
|
||||
visitor(ff_biases);
|
||||
visitor(softmax_attn_output_biases);
|
||||
visitor(unused_softmax_attn_output_biases);
|
||||
visitor(optimized_gating);
|
||||
visitor(post_norm);
|
||||
visitor(type);
|
||||
|
|
@ -263,14 +270,11 @@ struct LayerConfig : public IFields {
|
|||
bool IsMHA() const { return heads == kv_heads; }
|
||||
|
||||
uint32_t model_dim = 0;
|
||||
uint32_t griffin_dim = 0;
|
||||
uint32_t ff_hidden_dim = 0;
|
||||
uint32_t heads = 0;
|
||||
uint32_t kv_heads = 0;
|
||||
uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous).
|
||||
uint32_t conv1d_width = 0; // Griffin only
|
||||
uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous).
|
||||
bool ff_biases = false;
|
||||
bool softmax_attn_output_biases = false; // for Griffin
|
||||
bool optimized_gating = true; // for Gemma3
|
||||
PostNormType post_norm = PostNormType::None;
|
||||
LayerAttentionType type = LayerAttentionType::kGemma;
|
||||
|
|
@ -358,7 +362,8 @@ struct ModelConfig : public IFields {
|
|||
visitor(final_cap);
|
||||
|
||||
visitor(absolute_pe);
|
||||
visitor(use_local_attention);
|
||||
bool unused_use_local_attention = false; // formerly used for Griffin
|
||||
visitor(unused_use_local_attention);
|
||||
visitor(query_scale);
|
||||
visitor(layer_configs);
|
||||
visitor(attention_window_sizes);
|
||||
|
|
@ -454,7 +459,6 @@ struct ModelConfig : public IFields {
|
|||
float final_cap = 0.0f;
|
||||
|
||||
bool absolute_pe = false;
|
||||
bool use_local_attention = false; // Griffin only
|
||||
QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
|
||||
std::vector<LayerConfig> layer_configs;
|
||||
std::vector<uint32_t> attention_window_sizes;
|
||||
|
|
@ -478,7 +482,6 @@ struct ModelConfig : public IFields {
|
|||
ModelConfig GetVitConfig(const ModelConfig& config);
|
||||
|
||||
enum DeducedLayerTypes {
|
||||
kDeducedGriffin = 1,
|
||||
kDeducedViT = 2,
|
||||
kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224.
|
||||
};
|
||||
|
|
|
|||
|
|
@ -34,7 +34,6 @@
|
|||
// After highway.h
|
||||
#include "gemma/attention.h" // includes highway.h
|
||||
#include "gemma/gemma-inl.h"
|
||||
#include "gemma/griffin.h" // includes highway.h
|
||||
#include "gemma/vit.h" // includes highway.h
|
||||
|
||||
#ifndef GEMMA_CC_ONCE
|
||||
|
|
@ -77,14 +76,6 @@ void Attention(LayerAttentionType type, const size_t num_tokens,
|
|||
GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
|
||||
env,
|
||||
/*flags=*/0);
|
||||
} else {
|
||||
HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
|
||||
// KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer,
|
||||
// so map `layer` to the Griffin layer index.
|
||||
const size_t griffin_layer =
|
||||
activations.attention.config.NumLayersOfTypeBefore(type, layer_idx);
|
||||
GriffinRecurrent(num_tokens, griffin_layer, &layer, activations, qbatch,
|
||||
env);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -484,13 +475,6 @@ static void GenerateT(const ModelConfig& config,
|
|||
const AesCtrEngine& engine, const WeightsPtrs& weights,
|
||||
Activations& activations, QBatch& qbatch, MatMulEnv& env,
|
||||
TimingInfo& timing_info) {
|
||||
// Griffin assumes that the recurrent block cache is zero-initialized.
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
if (qbatch.MutablePos(qi) == 0) {
|
||||
qbatch.KV(qi).ZeroGriffinCache(); // No-op for non-Griffin models.
|
||||
}
|
||||
}
|
||||
|
||||
size_t max_prompt_size = 0;
|
||||
bool all_prefix_end_are_zero = true;
|
||||
size_t total_prefill_tokens = 0; // only for throughput stats.
|
||||
|
|
|
|||
192
gemma/griffin.cc
192
gemma/griffin.cc
|
|
@ -1,192 +0,0 @@
|
|||
// Copyright 2025 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||
// which we pass the filename via macro 'argument'.
|
||||
// clang-format off
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE "gemma/griffin.cc" // NOLINT
|
||||
// clang-format on
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "ops/matvec-inl.h"
|
||||
#include "ops/ops-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
|
||||
const LayerWeightsPtrs* layer_weights,
|
||||
Activations& activations, QBatch& qbatch,
|
||||
MatMulEnv& env) {
|
||||
PROFILER_ZONE("Gen.Griffin");
|
||||
hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
const D df;
|
||||
|
||||
const size_t model_dim = layer_weights->layer_config.model_dim;
|
||||
HWY_DASSERT(model_dim % hn::Lanes(df) == 0);
|
||||
|
||||
const size_t heads = layer_weights->layer_config.heads;
|
||||
const size_t conv_1d_width = layer_weights->layer_config.conv1d_width;
|
||||
HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
|
||||
const size_t kHeadDim = model_dim / heads;
|
||||
const size_t kMatrixSize = kHeadDim * kHeadDim;
|
||||
|
||||
const size_t num_interleaved = num_tokens * qbatch.Size();
|
||||
const hwy::Divisor div_qbatch(static_cast<uint32_t>(qbatch.Size()));
|
||||
GriffinActivations& griffin = activations.griffin;
|
||||
|
||||
// X / Y linear layers.
|
||||
// TODO: MatMul
|
||||
HWY_DASSERT(griffin.griffin_y.Rows() == griffin.griffin_x.Rows());
|
||||
HWY_DASSERT(num_interleaved == griffin.griffin_y.Rows());
|
||||
CallUpcastedSame(
|
||||
&layer_weights->griffin.linear_x_w, &layer_weights->griffin.linear_y_w,
|
||||
[&](const auto* wx, const auto* wy) {
|
||||
for (size_t r = 0; r < num_interleaved; ++r) {
|
||||
float* HWY_RESTRICT y = griffin.griffin_y.Row(r);
|
||||
float* HWY_RESTRICT x = griffin.griffin_x.Row(r);
|
||||
TwoMatVecAdd(
|
||||
*wx, *wy, 0, model_dim, model_dim,
|
||||
activations.attention.pre_att_rms_out.Row(r),
|
||||
/*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
|
||||
/*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
|
||||
/*out0=*/x, /*out1=*/y, pool);
|
||||
Gelu(y, model_dim);
|
||||
}
|
||||
});
|
||||
|
||||
// Conv1D.
|
||||
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
||||
++interleaved_idx) {
|
||||
const size_t qi = div_qbatch.Remainder(interleaved_idx);
|
||||
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
|
||||
const size_t pos = qbatch.Pos(qi) + batch_idx;
|
||||
float* HWY_RESTRICT x = griffin.griffin_x.Row(qi);
|
||||
|
||||
// cache[i] = input at time t-i.
|
||||
float* HWY_RESTRICT cache[kMaxConv1DWidth];
|
||||
cache[0] = x;
|
||||
for (size_t i = 1; i < conv_1d_width; i++) {
|
||||
cache[i] =
|
||||
qbatch.KV(qi).conv1d_cache.Row(griffin_layer) +
|
||||
((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
|
||||
}
|
||||
for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
|
||||
auto xv = hn::Load(df, x + i);
|
||||
auto accum0 =
|
||||
hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i);
|
||||
auto accum1 = hn::Zero(df);
|
||||
for (size_t l = 0; 2 * l < conv_1d_width; l++) {
|
||||
auto wv0 =
|
||||
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
|
||||
(conv_1d_width - 1 - 2 * l) * model_dim + i);
|
||||
auto wv1 =
|
||||
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
|
||||
(conv_1d_width - 2 - 2 * l) * model_dim + i);
|
||||
accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
|
||||
accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
|
||||
}
|
||||
hn::Store(hn::Add(accum0, accum1), df, x + i);
|
||||
hn::Store(xv, df, cache[HWY_MAX(conv_1d_width, 1) - 1] + i);
|
||||
}
|
||||
}
|
||||
|
||||
// RGLRU
|
||||
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
||||
++interleaved_idx) {
|
||||
const size_t qi = div_qbatch.Remainder(interleaved_idx);
|
||||
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
|
||||
const size_t pos = qbatch.Pos(qi) + batch_idx;
|
||||
|
||||
float* HWY_RESTRICT x = griffin.griffin_x.Row(qi);
|
||||
float* HWY_RESTRICT y = griffin.griffin_y.Row(qi);
|
||||
float* HWY_RESTRICT gate_x = griffin.griffin_gate_x.Row(qi);
|
||||
float* HWY_RESTRICT a = griffin.griffin_multiplier.Row(qi);
|
||||
float* HWY_RESTRICT rnn_state =
|
||||
qbatch.KV(qi).rglru_cache.Row(griffin_layer);
|
||||
|
||||
pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
||||
size_t head_offset = head * kHeadDim;
|
||||
CallUpcasted(&layer_weights->griffin.gate_w, [&](const auto* gate_w) {
|
||||
TwoOfsMatVecAddLoop(
|
||||
*gate_w, kMatrixSize * head, kMatrixSize * (heads + head), kHeadDim,
|
||||
kHeadDim, x + head_offset,
|
||||
/*add0=*/layer_weights->griffin.gate_biases.PackedScale1() +
|
||||
head_offset,
|
||||
/*add1=*/layer_weights->griffin.gate_biases.PackedScale1() +
|
||||
model_dim + head_offset,
|
||||
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
|
||||
});
|
||||
Sigmoid(gate_x + head_offset, kHeadDim);
|
||||
Sigmoid(a + head_offset, kHeadDim);
|
||||
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
|
||||
HWY_ATTR { return hn::Mul(x, gate_x); };
|
||||
hn::Transform1(D(), a + head_offset, kHeadDim,
|
||||
layer_weights->griffin.a.PackedScale1() + head_offset,
|
||||
fn_mul);
|
||||
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
|
||||
fn_mul);
|
||||
// RNN scan
|
||||
HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0);
|
||||
for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) {
|
||||
auto log_a = hn::Load(df, a + head_offset + i);
|
||||
auto gated_x = hn::Load(df, x + head_offset + i);
|
||||
auto rnn = hn::Load(df, rnn_state + head_offset + i);
|
||||
auto a = hn::Exp(df, log_a);
|
||||
auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0f)));
|
||||
if (pos == 0) {
|
||||
x_multiplier = hn::Set(df, 1.0f);
|
||||
}
|
||||
auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
|
||||
hn::Store(new_x, df, rnn_state + head_offset + i);
|
||||
|
||||
// Join branches.
|
||||
auto yv = hn::Load(df, y + head_offset + i);
|
||||
auto pre_out = hn::Mul(yv, new_x);
|
||||
hn::Store(pre_out, df, x + head_offset + i);
|
||||
}
|
||||
});
|
||||
} // interleaved_idx
|
||||
|
||||
// Final linear layer.
|
||||
CallMatMul(griffin.griffin_x, layer_weights->griffin.linear_out_w,
|
||||
layer_weights->griffin.linear_out_biases.PackedScale1(), env,
|
||||
activations.attention.att_sums);
|
||||
} // GriffinRecurrent
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
|
@@ -1,47 +0,0 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_

// Declares GriffinRecurrent for all SIMD targets.

#include <stddef.h>

#include "gemma/gemma.h"
#include "hwy/highway.h"

namespace gcpp {

// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_GRIFFIN(TARGET, NAMESPACE)                       \
  namespace NAMESPACE {                                             \
  void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,    \
                        const LayerWeightsPtrs* layer_weights,      \
                        Activations& activations, QBatch& qbatch,   \
                        MatMulEnv& env);                            \
  /* NOLINTNEXTLINE(google-readability-namespace-comments) */       \
  }  // namespace NAMESPACE

// Function declarations for each SIMD target. Allows direct call from the
// per-target namespace. We may later replace this with dynamic dispatch if
// the overhead is acceptable.
HWY_VISIT_TARGETS(GEMMA_DECL_GRIFFIN)

#undef GEMMA_DECL_GRIFFIN

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_
@@ -24,26 +24,6 @@

namespace gcpp {

void KVCache::ZeroGriffinCache() {
  if (conv1d_cache.Rows() == 0) return;
  ZeroInit(conv1d_cache);
  ZeroInit(rglru_cache);
}

static size_t GriffinLayers(const ModelConfig& config) {
  return config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock);
}

static size_t GriffinConv1dCols(const ModelConfig& config) {
  size_t conv1d_width = 0;
  for (const auto& layer_config : config.layer_configs) {
    conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width);
  }
  // The row offset, in blocks of model_dim is computed mod (conv1d_width - 1),
  // hence allocate conv1d_width * model_dim total columns.
  return conv1d_width * config.model_dim;
}

// Number of rows for KV cache. Note that both rows and cols are u32, and
// the total number of elements can exceed 2^32.
static size_t CappedSeqLen(const ModelConfig& config,
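For context on the allocation being removed: the conv1d cache behaved as a small ring buffer over recent token activations. A sketch of the indexing implied by the comment above; the helper name and the use of the token position are assumptions for illustration, not code from the repository:

```cpp
#include <cstddef>

// Hypothetical helper illustrating the layout described in the removed
// comment: each Griffin layer's row holds conv1d_width blocks of model_dim
// floats, and the block offset wraps modulo (conv1d_width - 1), leaving one
// extra block of headroom in the allocation.
size_t Conv1dCacheColOffset(size_t pos, size_t conv1d_width, size_t model_dim) {
  return (pos % (conv1d_width - 1)) * model_dim;
}
```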
@@ -56,30 +36,18 @@ static size_t CappedSeqLen(const ModelConfig& config,
  return inference_args.seq_len;
}

KVCache::KVCache(const Extents2D& conv1d_extents,
                 const Extents2D& rglru_extents, const Extents2D& kv_extents,
                 const Allocator& allocator)
    : conv1d_cache("conv1d_cache", conv1d_extents, allocator, MatPadding::kOdd),
      rglru_cache("rglru_cache", rglru_extents, allocator, MatPadding::kOdd),
      kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator)
    : kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
      allocator_(allocator) {}

KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
                 const Allocator& allocator)
    : KVCache(
          Extents2D(GriffinLayers(config), GriffinConv1dCols(config)),
          Extents2D(GriffinLayers(config), config.model_dim),
          Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()),
          allocator) {}

KVCache KVCache::Copy() {
  KVCache copy(conv1d_cache.Extents(), rglru_cache.Extents(),
               kv_cache.Extents(), allocator_);

  if (conv1d_cache.Rows() != 0) {
    CopyMat(conv1d_cache, copy.conv1d_cache);
    CopyMat(rglru_cache, copy.rglru_cache);
  }
  KVCache copy(kv_cache.Extents(), allocator_);

  CopyMat(kv_cache, copy.kv_cache);
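After this hunk, the public way to build a cache is still from the model config and inference arguments, and `Copy()` remains an explicit deep copy; only the attention KV storage is copied now. A minimal usage sketch, assuming `config`, `inference_args`, and `allocator` already exist (variable names are illustrative):

```cpp
// Sketch: constructing and explicitly copying a KV cache after this change.
gcpp::KVCache cache(config, inference_args, allocator);
gcpp::KVCache snapshot = cache.Copy();  // deep-copies kv_cache only
```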
@@ -35,24 +35,15 @@ struct KVCache {
  // copy ctor to make the cost explicit.
  KVCache Copy();

  // Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache
  // and rglru_cache.
  void ZeroGriffinCache();

  size_t SeqLen() const { return kv_cache.Rows(); }

  // [griffin_layers, griffin_conv1d_cols * model_dim]
  MatStorageT<float> conv1d_cache;
  MatStorageT<float> rglru_cache;  // [griffin_layers, model_dim]

  MatStorageT<KV_t> kv_cache;  // [seq_len, layers * kv_heads * qkv_dim * 2]

 private:
  const Allocator& allocator_;

  // For use by other ctor and Copy()
  KVCache(const Extents2D& conv1d_extents, const Extents2D& rglru_extents,
          const Extents2D& kv_extents, const Allocator& allocator);
  KVCache(const Extents2D& kv_extents, const Allocator& allocator);
};

}  // namespace gcpp
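The shape comment on the surviving `kv_cache` member is the whole story once the Griffin members are gone: per token (row), every layer stores one K and one V vector for each KV head. A sketch of that column count; the function name is illustrative, the real value comes from `config.KVCacheCols()`:

```cpp
#include <cstddef>

// Illustrative recomputation of the kv_cache column count from the shape
// comment above: K and V (factor 2) per layer and KV head.
size_t KVCacheColsSketch(size_t layers, size_t kv_heads, size_t qkv_dim) {
  return layers * kv_heads * qkv_dim * 2;
}
```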
@@ -221,9 +221,6 @@ static int DeduceLayerTypes(const BlobReader& reader) {
  int layer_types = 0;
  for (size_t key_idx = 0; key_idx < reader.Keys().size(); ++key_idx) {
    const std::string& key = reader.Keys()[key_idx];
    if (key.find("gr_conv_w") != std::string::npos) {  // NOLINT
      return kDeducedGriffin;
    }
    if (key.find("qkv_ein_w") != std::string::npos) {  // NOLINT
      layer_types |= kDeducedViT;
    }
@@ -293,7 +290,7 @@ static std::vector<float> ReadScales(BlobReader& reader,
                                     const ModelConfig& config) {
  std::vector<float> scales;
  // Check first to prevent `CallWithSpan` from printing a warning. This blob is
  // optional even in pre-2025 format; Griffin was the first to include it.
  // optional even in pre-2025 format.
  if (reader.Find(kDecoratedScalesName)) {
    HWY_ASSERT(reader.CallWithSpan<float>(
        kDecoratedScalesName,
@@ -277,122 +277,6 @@ void TensorInfoRegistry::AddImageLayerTensors(const ModelConfig& config,
      });
}

void TensorInfoRegistry::AddGriffinLayerTensors(const LayerConfig& layer_config,
                                                const size_t layer_idx) {
  const std::string suffix = LayerSuffix(layer_idx);
  Add(suffix, {
      .base_name = "gr_lin_x_w",
      .source_names = {"recurrent_block/linear_x/kernel"},
      .axes = {1, 0},
      .shape = {layer_config.griffin_dim, layer_config.griffin_dim},
  });
  Add(suffix, {
      .base_name = "gr_lin_x_b",
      .source_names = {"recurrent_block/linear_x/bias"},
      .axes = {0},
      .shape = {layer_config.griffin_dim},
      .min_size = Type::kF32,
  });
  Add(suffix, {
      .base_name = "gr_lin_y_w",
      .source_names = {"recurrent_block/linear_y/kernel"},
      .axes = {1, 0},
      .shape = {layer_config.griffin_dim, layer_config.griffin_dim},
  });
  Add(suffix, {
      .base_name = "gr_lin_y_b",
      .source_names = {"recurrent_block/linear_y/bias"},
      .axes = {0},
      .shape = {layer_config.griffin_dim},
      .min_size = Type::kF32,
  });
  Add(suffix, {
      .base_name = "gr_lin_out_w",
      .source_names = {"recurrent_block/linear_out/kernel"},
      .axes = {1, 0},
      .shape = {layer_config.griffin_dim, layer_config.griffin_dim},
  });
  Add(suffix, {
      .base_name = "gr_lin_out_b",
      .source_names = {"recurrent_block/linear_out/bias"},
      .axes = {0},
      .shape = {layer_config.griffin_dim},
      .min_size = Type::kF32,
  });
  Add(suffix, {
      .base_name = "gr_conv_w",
      .source_names = {"recurrent_block/conv_1d/w"},
      .axes = {0, 1},
      .shape = {layer_config.conv1d_width, layer_config.griffin_dim},
      .min_size = Type::kF32,
  });
  Add(suffix, {
      .base_name = "gr_conv_b",
      .source_names = {"recurrent_block/conv_1d/b"},
      .axes = {0},
      .shape = {layer_config.griffin_dim},
      .min_size = Type::kF32,
  });
  Add(suffix, {
      .base_name = "gr1_gate_w",
      .source_names = {"recurrent_block/rg_lru/input_gate/w"},
      .axes = {0, 2, 1},
      .shape = {layer_config.heads,
                layer_config.griffin_dim / layer_config.heads,
                layer_config.griffin_dim / layer_config.heads},
      .concat_names = {"gr_gate_w", "gr2_gate_w"},
  });
  Add(suffix, {
      .base_name = "gr2_gate_w",
      .source_names = {"recurrent_block/rg_lru/a_gate/w"},
      .axes = {0, 2, 1},
      .shape = {layer_config.heads,
                layer_config.griffin_dim / layer_config.heads,
                layer_config.griffin_dim / layer_config.heads},
      .concat_names = {""},
  });
  Add(suffix, {
      .base_name = "gr_gate_w",
      .source_names = {"recurrent_block/rg_lru/gate/w"},
      .axes = {0, 2, 1},
      .shape = {2 * layer_config.heads,
                layer_config.griffin_dim / layer_config.heads,
                layer_config.griffin_dim / layer_config.heads},
  });
  Add(suffix, {
      .base_name = "gr1_gate_b",
      .source_names = {"recurrent_block/rg_lru/input_gate/b"},
      .axes = {0},
      .shape = {layer_config.griffin_dim},
      .concat_names = {"gr_gate_b", "gr2_gate_b"},
      .min_size = Type::kF32,
  });
  Add(suffix, {
      .base_name = "gr2_gate_b",
      .source_names = {"recurrent_block/rg_lru/a_gate/b"},
      .axes = {0},
      .shape = {layer_config.griffin_dim},
      .concat_names = {""},
      .min_size = Type::kF32,
  });
  Add(suffix, {
      .base_name = "gr_gate_b",
      .source_names = {"recurrent_block/rg_lru/input_gate/b"},
      .axes = {0, 1},
      .shape = {2 * layer_config.griffin_dim},
      .min_size = Type::kF32,
  });
  Add(suffix, {
      .base_name = "gr_a",
      .source_names = {"recurrent_block/rg_lru/a_param"},
      .axes = {0},
      .shape = {layer_config.griffin_dim},
      .min_size = Type::kF32,
      .scaled_softplus = true,
  });
}

void TensorInfoRegistry::AddLayerTensors(const ModelConfig& config,
                                         const LayerConfig& layer_config,
                                         const size_t layer_idx) {
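One detail worth noting in the removed registration: the exporter's two per-gate tensors (`gr1_gate_w` from the input gate and `gr2_gate_w` from the a-gate) were stitched together via `concat_names` into the single `gr_gate_w` that the weights code reads, doubling the leading heads dimension; the gate biases were combined the same way into `gr_gate_b`. A sketch of the resulting extents under that reading (helper name is illustrative):

```cpp
#include <array>
#include <cstddef>

// Sketch of the concatenated gate tensor shape implied by the entries above:
// two [heads, head_dim, head_dim] tensors become one
// [2 * heads, head_dim, head_dim] tensor registered as "gr_gate_w".
std::array<size_t, 3> CombinedGateShape(size_t heads, size_t griffin_dim) {
  const size_t head_dim = griffin_dim / heads;
  return {2 * heads, head_dim, head_dim};
}
```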
@@ -553,10 +437,6 @@ void TensorInfoRegistry::AddLayerTensors(const ModelConfig& config,
      .shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim},
      .cols_take_extra_dims = true,
  });

  if (config.model == Model::GRIFFIN_2B) {
    AddGriffinLayerTensors(layer_config, layer_idx);
  }
}

TensorInfoRegistry::TensorInfoRegistry(const ModelConfig& config) {
@@ -124,8 +124,6 @@ class TensorInfoRegistry {
  void AddModelTensors(const ModelConfig& config);
  void AddLayerTensors(const ModelConfig& config,
                       const LayerConfig& layer_config, size_t layer_idx);
  void AddGriffinLayerTensors(const LayerConfig& layer_config,
                              size_t layer_idx);

  void AddImageLayerTensors(const ModelConfig& config,
                            const LayerConfig& layer_config,
@@ -88,7 +88,7 @@ void LayerWeightsPtrs::InitAttWeights(std::vector<MatOwner>& mat_owners,

// For FFN. Fast, only updates pointers.
void LayerWeightsPtrs::SplitW1() {
  // Used for Gemma and Griffin layers; FFWVit uses different tensors.
  // Used for Gemma layers; FFWVit uses different tensors.
  if (layer_config.type == LayerAttentionType::kVit) return;

  // Files have both or neither of w1 and w2.
@@ -57,8 +57,7 @@ struct TensorArgs {
    // the _w1/_w2 tensors are not always present.
    kMaybeRead = 1,

    // Avoid padding tensor rows when reading. Used for some Griffin tensors
    // whose index computations do not use Row() accessors.
    // Avoid padding tensor rows when reading.
    kPacked = 2,
  };
  const int flags;
@@ -102,17 +101,6 @@ struct LayerWeightsPtrs {
        qkv_einsum_w1(finder_("qkv1_w")),
        qkv_einsum_w2(finder_("qkv2_w")),
        attention_output_biases(finder_("attn_ob")),
        griffin({.linear_x_w = finder_("gr_lin_x_w"),
                 .linear_x_biases = finder_("gr_lin_x_b"),
                 .linear_y_w = finder_("gr_lin_y_w"),
                 .linear_y_biases = finder_("gr_lin_y_b"),
                 .linear_out_w = finder_("gr_lin_out_w"),
                 .linear_out_biases = finder_("gr_lin_out_b"),
                 .conv_w = finder_("gr_conv_w"),
                 .conv_biases = finder_("gr_conv_b"),
                 .gate_w = finder_("gr_gate_w"),
                 .gate_biases = finder_("gr_gate_b"),
                 .a = finder_("gr_a")}),
        // MultiHeadDotProductAttention.
        vit({.attn_out_w = finder_("attn_out_w"),
             .attn_out_b = finder_("attn_out_b"),
@@ -156,20 +144,6 @@ struct LayerWeightsPtrs {
  MatPtr qkv_einsum_w2;
  MatPtrT<float> attention_output_biases;

  struct {
    MatPtr linear_x_w;
    MatPtrT<float> linear_x_biases;
    MatPtr linear_y_w;
    MatPtrT<float> linear_y_biases;
    MatPtr linear_out_w;
    MatPtrT<float> linear_out_biases;
    MatPtrT<float> conv_w;
    MatPtrT<float> conv_biases;
    MatPtr gate_w;
    MatPtrT<float> gate_biases;
    MatPtrT<float> a;
  } griffin;

  struct {
    // MultiHeadDotProductAttention.
    MatPtr attn_out_w;  // at least BF16.
@@ -244,20 +218,6 @@ struct LayerWeightsPtrs {
      func(TENSOR_ARGS(qkv_einsum_w, kMaybeRead));
      func(TENSOR_ARGS(qkv_einsum_w1, kMaybeRead));
      func(TENSOR_ARGS(qkv_einsum_w2, kMaybeRead));
    } else {
      func(TENSOR_ARGS(griffin.linear_x_w, kMustRead));
      func(TENSOR_ARGS(griffin.linear_x_biases, kMustRead));
      func(TENSOR_ARGS(griffin.linear_y_w, kMustRead));
      func(TENSOR_ARGS(griffin.linear_y_biases, kMustRead));
      func(TENSOR_ARGS(griffin.linear_out_w, kMustRead));
      func(TENSOR_ARGS(griffin.linear_out_biases, kMustRead));
      // conv_w and gate_w are not accessed via Row(), hence must not be padded.
      // Note that *biases are 1D, hence packing/padding does not matter.
      func(TENSOR_ARGS(griffin.conv_w, kMustRead | TensorArgs::kPacked));
      func(TENSOR_ARGS(griffin.conv_biases, kMustRead));
      func(TENSOR_ARGS(griffin.gate_w, kMustRead | TensorArgs::kPacked));
      func(TENSOR_ARGS(griffin.gate_biases, kMustRead));
      func(TENSOR_ARGS(griffin.a, kMustRead));
    }
    {
      func(TENSOR_ARGS(gating_einsum_w, kMaybeRead));
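The `kPacked` flag in the removed branch existed because the Griffin kernels indexed `conv_w` and `gate_w` as flat arrays (for example the `kMatrixSize * head` offsets in the GriffinRecurrent code above) rather than through per-row accessors, so any row padding would shift those offsets. A small sketch of why that matters; the helper names are hypothetical, not the repository's accessors:

```cpp
#include <cstddef>

// Hypothetical illustration: with padded rows, row r starts at r * stride
// (stride >= cols), so a flat offset of r * cols + c no longer points at the
// intended element.
float FlatAt(const float* packed, size_t cols, size_t r, size_t c) {
  return packed[r * cols + c];    // valid only if rows are stored unpadded
}
float StridedAt(const float* padded, size_t stride, size_t r, size_t c) {
  return padded[r * stride + c];  // what row-aware access accounts for
}
```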
@@ -281,11 +241,6 @@ struct LayerWeightsPtrs {
      func(TENSOR_ARGS(ffw_gating_biases, kMustRead));
      func(TENSOR_ARGS(ffw_output_biases, kMustRead));
    }

    if (layer_config.softmax_attn_output_biases &&
        layer_config.type == LayerAttentionType::kGemma) {
      func(TENSOR_ARGS(attention_output_biases, kMustRead));
    }
  }  // `ForEachTensor`

  // Zero-initializes all allocated tensors in the layer.
@@ -57,8 +57,6 @@ PYBIND11_MODULE(configs, py_module) {

  enum_<LayerAttentionType>(py_module, "LayerAttentionType")
      .value("kGemma", LayerAttentionType::kGemma)
      .value("kGriffinRecurrentBlock",
             LayerAttentionType::kGriffinRecurrentBlock)
      .value("kVit", LayerAttentionType::kVit);

  enum_<PostNormType>(py_module, "PostNormType")
@@ -84,8 +82,6 @@ PYBIND11_MODULE(configs, py_module) {
      .value("UNKNOWN", Model::UNKNOWN)
      .value("GEMMA2_9B", Model::GEMMA2_9B)
      .value("GEMMA2_27B", Model::GEMMA2_27B)
      .value("GRIFFIN_2B", Model::GRIFFIN_2B)
      .value("GEMMA_TINY", Model::GEMMA_TINY)
      .value("GEMMA2_2B", Model::GEMMA2_2B)
      .value("PALIGEMMA2_3B_224", Model::PALIGEMMA2_3B_224)
      .value("PALIGEMMA2_10B_224", Model::PALIGEMMA2_10B_224)
@@ -121,15 +117,11 @@ PYBIND11_MODULE(configs, py_module) {
  class_<LayerConfig>(py_module, "LayerConfig")
      .def(init())
      .def_readwrite("model_dim", &LayerConfig::model_dim)
      .def_readwrite("griffin_dim", &LayerConfig::griffin_dim)
      .def_readwrite("ff_hidden_dim", &LayerConfig::ff_hidden_dim)
      .def_readwrite("heads", &LayerConfig::heads)
      .def_readwrite("kv_heads", &LayerConfig::kv_heads)
      .def_readwrite("qkv_dim", &LayerConfig::qkv_dim)
      .def_readwrite("conv1d_width", &LayerConfig::conv1d_width)
      .def_readwrite("ff_biases", &LayerConfig::ff_biases)
      .def_readwrite("softmax_attn_output_biases",
                     &LayerConfig::softmax_attn_output_biases)
      .def_readwrite("optimized_gating", &LayerConfig::optimized_gating)
      .def_readwrite("post_norm", &LayerConfig::post_norm)
      .def_readwrite("type", &LayerConfig::type)
@@ -166,7 +158,6 @@ PYBIND11_MODULE(configs, py_module) {
      .def_readwrite("att_cap", &ModelConfig::att_cap)
      .def_readwrite("final_cap", &ModelConfig::final_cap)
      .def_readwrite("absolute_pe", &ModelConfig::absolute_pe)
      .def_readwrite("use_local_attention", &ModelConfig::use_local_attention)
      .def_readwrite("query_scale", &ModelConfig::query_scale)
      .def_readwrite("layer_configs", &ModelConfig::layer_configs)
      .def_readwrite("attention_window_sizes",