mirror of https://github.com/google/gemma.cpp.git
Add Python code for converting Griffin Orbax weights. Refs #301
PiperOrigin-RevId: 657296255
commit d9f86f8e4d (parent f27683152c)
@@ -0,0 +1,8 @@
# General Remarks about the "PyTree" Abstraction

The pytree wrangling code in this project does not use any of the existing
"pytree" modules. The deeper reason is that our approach is based on an
analysis of the "pytree" notion that emphasizes the underlying principles.
This is being discussed internally at the time of this writing.
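The code in this commit represents tree-paths as "revtuples": a path is either `()` (the root) or a pair `(component, parent_path)`. As an editorial sketch only (the referenced `pytree_transforms` module is not part of this diff), linearizing such a path could look like:

def linearize_revtuple_path(path):
  """Turns ('c', ('b', ('a', ()))) into ('a', 'b', 'c')."""
  components = []
  while path != ():
    component, path = path
    components.append(component)
  return tuple(reversed(components))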
@@ -0,0 +1,275 @@
"""Ad-hoc glue code for building the griffin model-file for the C++ binary.

Usage:

python3 -m venv $HOME/clients/griffin-venv

. $HOME/clients/griffin-venv/bin/activate

python3 -m pip install -r requirements.txt

time python3 build_model_file_for_cpp_binary.py \
  $HOME/GRIFFIN/model_data \
  cpp_load_log.txt /tmp/G2B.data

real    3m5.821s
user    2m9.205s
sys     2m46.720s

./compress_weights --weights /tmp/G2B.data --model gr2b-it \
  --compressed_weights /tmp/G2B.compressed
./gemma --tokenizer tokenizer.spm --weights /tmp/G2B.compressed \
  --model gr2b-it

Weights for the recurrent-gemma model that can be converted with this script
can be found at:

https://www.kaggle.com/models/google/recurrentgemma/flax/2b-it
"""

import pprint
import re
import sys

import numpy

import orbax.checkpoint

import ml_model_transforms
import pytree_transforms


def _fn_identity(x): return x


def _fn_transpose(x): return x.T


def _fn_transpose_all_heads(x): return x.transpose(0, 2, 1)


def _fn_scaled_softplus(a):
  # -8 * softplus(a); numpy.logaddexp(a, 0) is a numerically stable softplus.
  return -8 * numpy.logaddexp(a, 0)


def _fn_attention_moveaxis(a):
  # Reshape to (10 heads, 256, 2560) and swap the last two axes.
  return a.reshape(10, 256, 2560).transpose(0, 2, 1)


def _aspec(pieces=(), transforms=()):
  """Short-hand array-save-specification.

  Args:
    pieces: Sequence of key-sequences identifying an array.
    transforms: Sequence of transformations, indexed in
      parallel to `pieces`, to apply to data arrays prior to saving.
      Will be padded with identity-transformations to the length of `pieces`.

  Returns:
    Specification as for use in _LAYER_NAME_MAPPING.
  """
  # `zip` trims to the shortest sequence, so this amounts to using
  # default-transforms.
  # tuple() since we need a Sequence here, not a stateful-iterator zip object.
  return tuple(zip(pieces, list(transforms) + [_fn_identity] * len(pieces)))
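# Editorial illustration (not part of the original commit): `_aspec` pads
# missing transforms with the identity, so only the kernel below would get
# transposed:
#
#   _aspec(pieces=[('recurrent_block', 'linear_x', 'kernel'),
#                  ('recurrent_block', 'linear_x', 'bias')],
#          transforms=[_fn_transpose])
#   == ((('recurrent_block', 'linear_x', 'kernel'), _fn_transpose),
#       (('recurrent_block', 'linear_x', 'bias'), _fn_identity))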


_LAYER_NAME_MAPPING = pytree_transforms.deep_freeze({
    # Recurrent Layer
    'griffin_linear_x_w': _aspec(
        [('recurrent_block', 'linear_x', 'kernel')],
        [_fn_transpose]),
    'griffin_linear_x_biases': _aspec(
        [('recurrent_block', 'linear_x', 'bias')]),
    'griffin_linear_y_w': _aspec(
        [('recurrent_block', 'linear_y', 'kernel')],
        [_fn_transpose]),
    'griffin_linear_y_biases': _aspec(
        [('recurrent_block', 'linear_y', 'bias')]),
    'griffin_linear_out_w': _aspec(
        [('recurrent_block', 'linear_out', 'kernel')],
        [_fn_transpose]),
    'griffin_linear_out_biases': _aspec(
        [('recurrent_block', 'linear_out', 'bias')]),
    'griffin_conv_w': _aspec(
        [('recurrent_block', 'conv_1d', 'w')]),
    'griffin_conv_biases': _aspec(
        [('recurrent_block', 'conv_1d', 'b')]),
    'griffin_gate_w': _aspec(
        [('recurrent_block', 'rg_lru', 'input_gate', 'w'),
         ('recurrent_block', 'rg_lru', 'a_gate', 'w')],
        [_fn_transpose_all_heads, _fn_transpose_all_heads]),
    'griffin_gate_biases': _aspec(
        [('recurrent_block', 'rg_lru', 'input_gate', 'b'),
         ('recurrent_block', 'rg_lru', 'a_gate', 'b')]),
    'griffin_a': _aspec(
        [('recurrent_block', 'rg_lru', 'a_param')],
        [_fn_scaled_softplus]),
    # Attention Layer
    'qkv_einsum_w': _aspec(
        [('attention_block', 'proj_q', 'kernel'),
         ('attention_block', 'proj_k', 'kernel'),
         ('attention_block', 'proj_v', 'kernel'),
         ],
        [_fn_transpose, _fn_transpose, _fn_transpose]),
    'attn_vec_einsum_w': _aspec(
        [('attention_block', 'proj_final', 'kernel')],
        [_fn_attention_moveaxis]),
    'attention_output_biases': _aspec(
        [('attention_block', 'proj_final', 'bias')]),
    # Common
    'pre_attention_norm_scale': _aspec(
        [('temporal_pre_norm', 'scale')]),
    'pre_ffw_norm_scale': _aspec(
        [('channel_pre_norm', 'scale')]),
    'gating_einsum_w': _aspec(
        [('mlp_block', 'ffw_up', 'w')],
        [_fn_transpose_all_heads]),
    'ffw_gating_biases': _aspec(
        [('mlp_block', 'ffw_up', 'b')]),
    'linear_w': _aspec(
        [('mlp_block', 'ffw_down', 'kernel')],
        [_fn_transpose]),
    'ffw_output_biases': _aspec(
        [('mlp_block', 'ffw_down', 'bias')]),
    # Other
    'embedder_input_embedding': _aspec(
        [('embedder', 'input_embedding')]),
    'final_norm_scale': _aspec(
        [('final_norm', 'scale')]),
})


def process_param_line(line: str) -> None | tuple[None | str, int, str]:
  """Processes a "loading parameters" log-line from the griffin binary."""
  # This is slightly more permissive than strictly needed, to also handle
  # some earlier form of the output.
  matched = re.match(
      r'(?a)Loading Parameters:? \('
      r'(?:layer=(?P<layer>\d+), )?'
      r'size (?P<size>\d+)\):? '
      r'(?P<tag>\S+)',
      line)
  if not matched:
    return None
  return matched['layer'], int(matched['size']), matched['tag']
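# Editorial check (not from the commit): both log-line formats that occur in
# the committed cpp_load_log.txt further below parse as intended:
#
#   process_param_line(
#       'Loading Parameters (size 2622750720): embedder_input_embedding')
#   == (None, 2622750720, 'embedder_input_embedding')
#
#   process_param_line(
#       'Loading Parameters: (layer=0, size 26214400) griffin_linear_x_w')
#   == ('0', 26214400, 'griffin_linear_x_w')   # `layer` stays a string.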


def collect_pytree_keys(param_lines):
  """Collects all the pytree keys and transforms for model-serialization."""
  pytree_keys = []
  array_transforms = []
  unsatisfied = []
  for maybe_spec in map(process_param_line, param_lines):
    if not maybe_spec: continue  # Skip non-parameter lines.
    layer, wanted_size, cpp_tag = maybe_spec
    pytree_key_tails_and_transforms = _LAYER_NAME_MAPPING.get(cpp_tag, ())
    if not pytree_key_tails_and_transforms:
      unsatisfied.append((layer, cpp_tag))
    else:
      for key_tail, array_transform in pytree_key_tails_and_transforms:
        pytree_keys.append(
            key_tail if layer is None
            else (f'blocks.{layer}',) + key_tail)
        array_transforms.append(array_transform)
  return pytree_keys, array_transforms, unsatisfied


class UnsatisfiedArrayLoadsError(ValueError):
  """Some array-loads could not be satisfied."""


def flatten_model_for_cpp_binary(tree,
                                 cpp_expectations_logfile_path: str,
                                 out_path: str,
                                 unsatisfied_ok: bool = False):
  """Produces a model-parameters file readable by the C++ binary.

  Args:
    tree: The pytree with model-parameters.
    cpp_expectations_logfile_path: Path to a logfile produced by the
      C++ binary that shows the expected array-order.
    out_path: Path to the model-weights file to be written.
    unsatisfied_ok: If true, we ignore the presence of unsatisfied
      array-loads and write a model-parameters file that skips these pieces.
      This will lead to an unusable model-parameters file which however
      still might be useful for other analysis.

  Returns:
    Tuple `(unknown_keys, missing_keys)`, where `unknown_keys`
    is a sequence of `(layer_or_None, name)` descriptions of the keys
    in the C++ log that could not be satisfied, and `missing_keys`
    is a sequence of linearized pytree key-sequences for keys
    not found in the checkpoint.

  Raises:
    UnsatisfiedArrayLoadsError: If some of the expected arrays
      could not be included in the output and `unsatisfied_ok` is false.
  """
  with open(cpp_expectations_logfile_path, 'rt') as h_log:
    pytree_keys, array_transforms, unknown_keys = collect_pytree_keys(
        list(h_log))
  rank_by_pytree_key = {k: n for n, k in enumerate(pytree_keys)}
  array_transform_by_pytree_key = dict(zip(pytree_keys, array_transforms))

  model_contents = ml_model_transforms.model_contents(tree)
  missing_keys = set(pytree_keys) - model_contents.keys()
  if (unknown_keys or missing_keys) and not unsatisfied_ok:
    raise UnsatisfiedArrayLoadsError(
        f'Unsatisfied loads: unknown_keys: {unknown_keys!r}, '
        f'missing keys: {sorted(missing_keys)!r}')
  ml_model_transforms.model_save(
      tree,
      filepath_stem=out_path,
      data_suffix='',
      manifest_suffix=None,
      array_transform_by_pytree_key=array_transform_by_pytree_key,
      key=rank_by_pytree_key.get,
      report=lambda line: print(line, file=sys.stderr),
      byte_align=1)
  return tuple(unknown_keys), tuple(sorted(missing_keys))


def main(args):
  """Creates the model-file.

  Args:
    args: sys.argv[1:]-style command line parameters, i.e. without the
      leading program name.

  Returns:
    The pytree with all the de-serialized variables, such as for convenient
    `python3 -i` inspection.
  """
  try:
    model_dir, cpp_load_log, out_path = args
  except Exception:
    sys.exit(f'Usage: {__file__} [model_dir] [cpp_load_log] [output_filename]')
  # The 2B model's layer pattern is ("recurrent", "recurrent", "attention"),
  # repeated - as also seen in the committed cpp_load_log.txt.
  orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
  variables = orbax_checkpointer.restore(model_dir)
  if sorted(variables) == ['params']:
    print('Warning: Using `variables["params"]` as tree-root.',
          file=sys.stderr)
    variables_to_use = variables['params']
  else:
    variables_to_use = variables
  unknown, missing = flatten_model_for_cpp_binary(variables_to_use,
                                                  cpp_load_log,
                                                  out_path,
                                                  unsatisfied_ok=True)
  print('Model file saved.\n'
        f'# unknown:\n{pprint.pformat(unknown)}\n'
        f'# missing:\n{pprint.pformat(missing)}')
  return variables


if __name__ == '__main__':
  # Return value assignment is for `python3 -i ...` inspection.
  pytree = main(sys.argv[1:])
@@ -0,0 +1,380 @@
Loading Parameters (size 2622750720): embedder_input_embedding
Loading Parameters (size 10240): final_norm_scale
Loading Parameters: (layer=0, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=0, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=0, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=0, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=0, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=0, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=0, size 40960) griffin_conv_w
Loading Parameters: (layer=0, size 10240) griffin_conv_biases
Loading Parameters: (layer=0, size 5242880) griffin_gate_w
Loading Parameters: (layer=0, size 20480) griffin_gate_biases
Loading Parameters: (layer=0, size 10240) griffin_a
Loading Parameters: (layer=0, size 157286400) gating_einsum_w
Loading Parameters: (layer=0, size 78643200) linear_w
Loading Parameters: (layer=0, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=0, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=0, size 61440) ffw_gating_biases
Loading Parameters: (layer=0, size 10240) ffw_output_biases
Loading Parameters: (layer=1, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=1, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=1, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=1, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=1, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=1, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=1, size 40960) griffin_conv_w
Loading Parameters: (layer=1, size 10240) griffin_conv_biases
Loading Parameters: (layer=1, size 5242880) griffin_gate_w
Loading Parameters: (layer=1, size 20480) griffin_gate_biases
Loading Parameters: (layer=1, size 10240) griffin_a
Loading Parameters: (layer=1, size 157286400) gating_einsum_w
Loading Parameters: (layer=1, size 78643200) linear_w
Loading Parameters: (layer=1, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=1, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=1, size 61440) ffw_gating_biases
Loading Parameters: (layer=1, size 10240) ffw_output_biases
Loading Parameters: (layer=2, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=2, size 78643200) qkv_einsum_w
Loading Parameters: (layer=2, size 157286400) gating_einsum_w
Loading Parameters: (layer=2, size 78643200) linear_w
Loading Parameters: (layer=2, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=2, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=2, size 61440) ffw_gating_biases
Loading Parameters: (layer=2, size 10240) ffw_output_biases
Loading Parameters: (layer=2, size 10240) attention_output_biases
Loading Parameters: (layer=3, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=3, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=3, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=3, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=3, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=3, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=3, size 40960) griffin_conv_w
Loading Parameters: (layer=3, size 10240) griffin_conv_biases
Loading Parameters: (layer=3, size 5242880) griffin_gate_w
Loading Parameters: (layer=3, size 20480) griffin_gate_biases
Loading Parameters: (layer=3, size 10240) griffin_a
Loading Parameters: (layer=3, size 157286400) gating_einsum_w
Loading Parameters: (layer=3, size 78643200) linear_w
Loading Parameters: (layer=3, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=3, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=3, size 61440) ffw_gating_biases
Loading Parameters: (layer=3, size 10240) ffw_output_biases
Loading Parameters: (layer=4, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=4, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=4, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=4, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=4, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=4, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=4, size 40960) griffin_conv_w
Loading Parameters: (layer=4, size 10240) griffin_conv_biases
Loading Parameters: (layer=4, size 5242880) griffin_gate_w
Loading Parameters: (layer=4, size 20480) griffin_gate_biases
Loading Parameters: (layer=4, size 10240) griffin_a
Loading Parameters: (layer=4, size 157286400) gating_einsum_w
Loading Parameters: (layer=4, size 78643200) linear_w
Loading Parameters: (layer=4, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=4, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=4, size 61440) ffw_gating_biases
Loading Parameters: (layer=4, size 10240) ffw_output_biases
Loading Parameters: (layer=5, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=5, size 78643200) qkv_einsum_w
Loading Parameters: (layer=5, size 157286400) gating_einsum_w
Loading Parameters: (layer=5, size 78643200) linear_w
Loading Parameters: (layer=5, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=5, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=5, size 61440) ffw_gating_biases
Loading Parameters: (layer=5, size 10240) ffw_output_biases
Loading Parameters: (layer=5, size 10240) attention_output_biases
Loading Parameters: (layer=6, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=6, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=6, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=6, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=6, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=6, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=6, size 40960) griffin_conv_w
Loading Parameters: (layer=6, size 10240) griffin_conv_biases
Loading Parameters: (layer=6, size 5242880) griffin_gate_w
Loading Parameters: (layer=6, size 20480) griffin_gate_biases
Loading Parameters: (layer=6, size 10240) griffin_a
Loading Parameters: (layer=6, size 157286400) gating_einsum_w
Loading Parameters: (layer=6, size 78643200) linear_w
Loading Parameters: (layer=6, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=6, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=6, size 61440) ffw_gating_biases
Loading Parameters: (layer=6, size 10240) ffw_output_biases
Loading Parameters: (layer=7, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=7, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=7, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=7, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=7, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=7, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=7, size 40960) griffin_conv_w
Loading Parameters: (layer=7, size 10240) griffin_conv_biases
Loading Parameters: (layer=7, size 5242880) griffin_gate_w
Loading Parameters: (layer=7, size 20480) griffin_gate_biases
Loading Parameters: (layer=7, size 10240) griffin_a
Loading Parameters: (layer=7, size 157286400) gating_einsum_w
Loading Parameters: (layer=7, size 78643200) linear_w
Loading Parameters: (layer=7, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=7, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=7, size 61440) ffw_gating_biases
Loading Parameters: (layer=7, size 10240) ffw_output_biases
Loading Parameters: (layer=8, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=8, size 78643200) qkv_einsum_w
Loading Parameters: (layer=8, size 157286400) gating_einsum_w
Loading Parameters: (layer=8, size 78643200) linear_w
Loading Parameters: (layer=8, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=8, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=8, size 61440) ffw_gating_biases
Loading Parameters: (layer=8, size 10240) ffw_output_biases
Loading Parameters: (layer=8, size 10240) attention_output_biases
Loading Parameters: (layer=9, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=9, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=9, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=9, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=9, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=9, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=9, size 40960) griffin_conv_w
Loading Parameters: (layer=9, size 10240) griffin_conv_biases
Loading Parameters: (layer=9, size 5242880) griffin_gate_w
Loading Parameters: (layer=9, size 20480) griffin_gate_biases
Loading Parameters: (layer=9, size 10240) griffin_a
Loading Parameters: (layer=9, size 157286400) gating_einsum_w
Loading Parameters: (layer=9, size 78643200) linear_w
Loading Parameters: (layer=9, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=9, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=9, size 61440) ffw_gating_biases
Loading Parameters: (layer=9, size 10240) ffw_output_biases
Loading Parameters: (layer=10, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=10, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=10, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=10, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=10, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=10, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=10, size 40960) griffin_conv_w
Loading Parameters: (layer=10, size 10240) griffin_conv_biases
Loading Parameters: (layer=10, size 5242880) griffin_gate_w
Loading Parameters: (layer=10, size 20480) griffin_gate_biases
Loading Parameters: (layer=10, size 10240) griffin_a
Loading Parameters: (layer=10, size 157286400) gating_einsum_w
Loading Parameters: (layer=10, size 78643200) linear_w
Loading Parameters: (layer=10, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=10, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=10, size 61440) ffw_gating_biases
Loading Parameters: (layer=10, size 10240) ffw_output_biases
Loading Parameters: (layer=11, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=11, size 78643200) qkv_einsum_w
Loading Parameters: (layer=11, size 157286400) gating_einsum_w
Loading Parameters: (layer=11, size 78643200) linear_w
Loading Parameters: (layer=11, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=11, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=11, size 61440) ffw_gating_biases
Loading Parameters: (layer=11, size 10240) ffw_output_biases
Loading Parameters: (layer=11, size 10240) attention_output_biases
Loading Parameters: (layer=12, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=12, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=12, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=12, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=12, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=12, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=12, size 40960) griffin_conv_w
Loading Parameters: (layer=12, size 10240) griffin_conv_biases
Loading Parameters: (layer=12, size 5242880) griffin_gate_w
Loading Parameters: (layer=12, size 20480) griffin_gate_biases
Loading Parameters: (layer=12, size 10240) griffin_a
Loading Parameters: (layer=12, size 157286400) gating_einsum_w
Loading Parameters: (layer=12, size 78643200) linear_w
Loading Parameters: (layer=12, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=12, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=12, size 61440) ffw_gating_biases
Loading Parameters: (layer=12, size 10240) ffw_output_biases
Loading Parameters: (layer=13, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=13, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=13, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=13, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=13, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=13, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=13, size 40960) griffin_conv_w
Loading Parameters: (layer=13, size 10240) griffin_conv_biases
Loading Parameters: (layer=13, size 5242880) griffin_gate_w
Loading Parameters: (layer=13, size 20480) griffin_gate_biases
Loading Parameters: (layer=13, size 10240) griffin_a
Loading Parameters: (layer=13, size 157286400) gating_einsum_w
Loading Parameters: (layer=13, size 78643200) linear_w
Loading Parameters: (layer=13, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=13, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=13, size 61440) ffw_gating_biases
Loading Parameters: (layer=13, size 10240) ffw_output_biases
Loading Parameters: (layer=14, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=14, size 78643200) qkv_einsum_w
Loading Parameters: (layer=14, size 157286400) gating_einsum_w
Loading Parameters: (layer=14, size 78643200) linear_w
Loading Parameters: (layer=14, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=14, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=14, size 61440) ffw_gating_biases
Loading Parameters: (layer=14, size 10240) ffw_output_biases
Loading Parameters: (layer=14, size 10240) attention_output_biases
Loading Parameters: (layer=15, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=15, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=15, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=15, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=15, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=15, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=15, size 40960) griffin_conv_w
Loading Parameters: (layer=15, size 10240) griffin_conv_biases
Loading Parameters: (layer=15, size 5242880) griffin_gate_w
Loading Parameters: (layer=15, size 20480) griffin_gate_biases
Loading Parameters: (layer=15, size 10240) griffin_a
Loading Parameters: (layer=15, size 157286400) gating_einsum_w
Loading Parameters: (layer=15, size 78643200) linear_w
Loading Parameters: (layer=15, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=15, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=15, size 61440) ffw_gating_biases
Loading Parameters: (layer=15, size 10240) ffw_output_biases
Loading Parameters: (layer=16, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=16, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=16, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=16, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=16, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=16, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=16, size 40960) griffin_conv_w
Loading Parameters: (layer=16, size 10240) griffin_conv_biases
Loading Parameters: (layer=16, size 5242880) griffin_gate_w
Loading Parameters: (layer=16, size 20480) griffin_gate_biases
Loading Parameters: (layer=16, size 10240) griffin_a
Loading Parameters: (layer=16, size 157286400) gating_einsum_w
Loading Parameters: (layer=16, size 78643200) linear_w
Loading Parameters: (layer=16, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=16, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=16, size 61440) ffw_gating_biases
Loading Parameters: (layer=16, size 10240) ffw_output_biases
Loading Parameters: (layer=17, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=17, size 78643200) qkv_einsum_w
Loading Parameters: (layer=17, size 157286400) gating_einsum_w
Loading Parameters: (layer=17, size 78643200) linear_w
Loading Parameters: (layer=17, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=17, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=17, size 61440) ffw_gating_biases
Loading Parameters: (layer=17, size 10240) ffw_output_biases
Loading Parameters: (layer=17, size 10240) attention_output_biases
Loading Parameters: (layer=18, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=18, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=18, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=18, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=18, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=18, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=18, size 40960) griffin_conv_w
Loading Parameters: (layer=18, size 10240) griffin_conv_biases
Loading Parameters: (layer=18, size 5242880) griffin_gate_w
Loading Parameters: (layer=18, size 20480) griffin_gate_biases
Loading Parameters: (layer=18, size 10240) griffin_a
Loading Parameters: (layer=18, size 157286400) gating_einsum_w
Loading Parameters: (layer=18, size 78643200) linear_w
Loading Parameters: (layer=18, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=18, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=18, size 61440) ffw_gating_biases
Loading Parameters: (layer=18, size 10240) ffw_output_biases
Loading Parameters: (layer=19, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=19, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=19, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=19, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=19, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=19, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=19, size 40960) griffin_conv_w
Loading Parameters: (layer=19, size 10240) griffin_conv_biases
Loading Parameters: (layer=19, size 5242880) griffin_gate_w
Loading Parameters: (layer=19, size 20480) griffin_gate_biases
Loading Parameters: (layer=19, size 10240) griffin_a
Loading Parameters: (layer=19, size 157286400) gating_einsum_w
Loading Parameters: (layer=19, size 78643200) linear_w
Loading Parameters: (layer=19, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=19, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=19, size 61440) ffw_gating_biases
Loading Parameters: (layer=19, size 10240) ffw_output_biases
Loading Parameters: (layer=20, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=20, size 78643200) qkv_einsum_w
Loading Parameters: (layer=20, size 157286400) gating_einsum_w
Loading Parameters: (layer=20, size 78643200) linear_w
Loading Parameters: (layer=20, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=20, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=20, size 61440) ffw_gating_biases
Loading Parameters: (layer=20, size 10240) ffw_output_biases
Loading Parameters: (layer=20, size 10240) attention_output_biases
Loading Parameters: (layer=21, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=21, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=21, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=21, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=21, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=21, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=21, size 40960) griffin_conv_w
Loading Parameters: (layer=21, size 10240) griffin_conv_biases
Loading Parameters: (layer=21, size 5242880) griffin_gate_w
Loading Parameters: (layer=21, size 20480) griffin_gate_biases
Loading Parameters: (layer=21, size 10240) griffin_a
Loading Parameters: (layer=21, size 157286400) gating_einsum_w
Loading Parameters: (layer=21, size 78643200) linear_w
Loading Parameters: (layer=21, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=21, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=21, size 61440) ffw_gating_biases
Loading Parameters: (layer=21, size 10240) ffw_output_biases
Loading Parameters: (layer=22, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=22, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=22, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=22, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=22, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=22, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=22, size 40960) griffin_conv_w
Loading Parameters: (layer=22, size 10240) griffin_conv_biases
Loading Parameters: (layer=22, size 5242880) griffin_gate_w
Loading Parameters: (layer=22, size 20480) griffin_gate_biases
Loading Parameters: (layer=22, size 10240) griffin_a
Loading Parameters: (layer=22, size 157286400) gating_einsum_w
Loading Parameters: (layer=22, size 78643200) linear_w
Loading Parameters: (layer=22, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=22, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=22, size 61440) ffw_gating_biases
Loading Parameters: (layer=22, size 10240) ffw_output_biases
Loading Parameters: (layer=23, size 26214400) attn_vec_einsum_w
Loading Parameters: (layer=23, size 78643200) qkv_einsum_w
Loading Parameters: (layer=23, size 157286400) gating_einsum_w
Loading Parameters: (layer=23, size 78643200) linear_w
Loading Parameters: (layer=23, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=23, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=23, size 61440) ffw_gating_biases
Loading Parameters: (layer=23, size 10240) ffw_output_biases
Loading Parameters: (layer=23, size 10240) attention_output_biases
Loading Parameters: (layer=24, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=24, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=24, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=24, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=24, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=24, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=24, size 40960) griffin_conv_w
Loading Parameters: (layer=24, size 10240) griffin_conv_biases
Loading Parameters: (layer=24, size 5242880) griffin_gate_w
Loading Parameters: (layer=24, size 20480) griffin_gate_biases
Loading Parameters: (layer=24, size 10240) griffin_a
Loading Parameters: (layer=24, size 157286400) gating_einsum_w
Loading Parameters: (layer=24, size 78643200) linear_w
Loading Parameters: (layer=24, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=24, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=24, size 61440) ffw_gating_biases
Loading Parameters: (layer=24, size 10240) ffw_output_biases
Loading Parameters: (layer=25, size 26214400) griffin_linear_x_w
Loading Parameters: (layer=25, size 10240) griffin_linear_x_biases
Loading Parameters: (layer=25, size 26214400) griffin_linear_y_w
Loading Parameters: (layer=25, size 10240) griffin_linear_y_biases
Loading Parameters: (layer=25, size 26214400) griffin_linear_out_w
Loading Parameters: (layer=25, size 10240) griffin_linear_out_biases
Loading Parameters: (layer=25, size 40960) griffin_conv_w
Loading Parameters: (layer=25, size 10240) griffin_conv_biases
Loading Parameters: (layer=25, size 5242880) griffin_gate_w
Loading Parameters: (layer=25, size 20480) griffin_gate_biases
Loading Parameters: (layer=25, size 10240) griffin_a
Loading Parameters: (layer=25, size 157286400) gating_einsum_w
Loading Parameters: (layer=25, size 78643200) linear_w
Loading Parameters: (layer=25, size 10240) pre_attention_norm_scale
Loading Parameters: (layer=25, size 10240) pre_ffw_norm_scale
Loading Parameters: (layer=25, size 61440) ffw_gating_biases
Loading Parameters: (layer=25, size 10240) ffw_output_biases
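Editorial aside: the sizes in this log appear to be byte-counts of float32 data (the converter widens bfloat16 to float32; see ml_model_transforms below), and they are consistent with a model width of 2560:

model_dim = 2560
assert 26214400 == 4 * model_dim * model_dim        # griffin_linear_x_w
assert 78643200 == 4 * 3 * model_dim * model_dim    # qkv_einsum_w holds q, k, v
assert 2622750720 == 4 * 256128 * model_dim         # the embedding table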
@@ -0,0 +1,371 @@
"""Transformations for python-trees representing the parameters of a ML model.

Important: This module assumes that byte-order is the same on the
machine that serializes data and the machine that deserializes
data. If, for example, numpy data gets dumped or loaded with a
dtype-specification of numpy.float32, the on-file byte-order
will be the host byte-order.
"""

import ast
import hashlib
import itertools
import pprint
import sys
import time
from typing import Any, Callable, Iterable, Iterator, Mapping, TypeVar

import numpy
import pytree_transforms


NT = TypeVar('NT')


def ml_model_leaf_summary(path, x, sep=', '):
  """Produces a textual summary of a leaf-node and its path.

  Args:
    path: The path-to-root, as a reverse-order recursive
      pair of path-components, with `()` as root.
    x: The leaf-object.
    sep: The separator between description-elements.
      The default ', ' allows for convenient line-by-line processing
      (such as via grep, perl -ne, etc.), but using e.g. sep=',\n '
      might be more useful for human consumption.

  Returns:
    A human-readable string providing information about the node.
  """
  # Using `repr` for path-components to get a faithful presentation.
  # (...which still however would be somewhat painful to correctly
  # split into components.)
  path_str = ','.join(map(repr,
                          pytree_transforms.linearize_revtuple_path(path)))
  tx = type(x)
  mod = tx.__module__  # Either a module or a string like 'builtins'.
  modname = mod if isinstance(mod, str) else mod.__name__
  type_str = f'{modname}.{tx.__qualname__}'
  try:
    # `numpy.ndarray` instances have a `.data` property that gives access
    # to a buffer via which we can hashlib-fingerprint the numerical
    # contents. We here simply try to produce a fingerprint and also look
    # up the .dtype of the object. Technically, there is a somewhat-unsound
    # assumption here that if these operations succeed, we are indeed looking
    # at a ndarray or sufficiently similar object for these operations to
    # make sense. As the output is declared "for human consumption", this
    # fishiness is not a problem.
    fp = hashlib.sha256(x.data).hexdigest()
    start = list(itertools.islice(x.flat, 5))
    stats_str = (
        f'min={numpy.min(x):.6g}, max={numpy.max(x):.6g}, '
        f'mean={numpy.mean(x):.6g}, std={numpy.std(x):.6g}')
    return (f'{path_str:60s}: <{type_str}{sep}'
            f'fp=0x{fp}{sep}{stats_str}{sep}shape={x.shape}, '
            f'dtype={x.dtype}{sep}start={start}>')
  except (AttributeError, ValueError, TypeError):
    # Fallback - trying to include information about the data-content
    # of a likely-numerical-array failed.
    return f'{path_str:60s}: {type_str}({x!r})'


# A specialized node-handler.
# Interface follows node-handler expectations defined in pytree_transforms.
def _ml_model_tree_node_handler(path: tuple, node: NT) -> (
    None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT],
                 Iterator[tuple[Any, NT]]]):
  """Processes a tree-node as required by pytree-iteration and -mapping.

  Args:
    path: revtuple path to the current node.
    node: a tree-node in a ML-model tree that is recursively
      built out of `numpy.ndarray` leaf-values and dicts mapping
      node-name string-keys to other such nodes representing subtrees -
      and nothing else.

  Returns:
    `None` if the tree-node is to be regarded as a leaf, otherwise
    a pair `(rebuilder, iterator)`, where `iterator` iterates
    over the data-content of the node, each item represented as a pair
    of `(lookup_path_item, value_item)`, and `rebuilder` is a function
    which, when applied to `iterator` or any iterable with the same
    elements, returns a node that is equivalent to the original.

  Raises:
    NotAMLModelTreeNodeError: If the tree contains a node that is neither
      a `dict` nor a `numpy.ndarray` instance.
  """
  # The astute reader will notice that we are doing something fishy
  # here - this code could not be translated to Haskell as-is, since
  # `NT` cannot actually be a proper type-variable in the sense
  # of parametric polymorphism.
  del path  # Unused.
  if isinstance(node, dict):
    return dict, iter(node.items())
  if isinstance(node, numpy.ndarray):
    return None
  raise pytree_transforms.NotAMLModelTreeNodeError(
      f'Type of bad node: {type(node)}')


def _ml_model_extract_leaf_transform(
    path: pytree_transforms.RevTuplePath,
    leaf: Any):
  """Maps an array-leaf to a pair `(full_path, lambda: array)`.

  The computation that produces the leaf-value is lazified underneath
  a `lambda`: if we e.g. performed a memory-expensive transformation
  (such as some dtype-changes) directly at this point, then going from
  an iterator over tree-items for one-by-one consumption to a list of
  these items would have all the dtype-transformed values around
  simultaneously. We want to avoid situations where we can do nothing
  about having multiple variants of the data simultaneously in memory.
  """
  # Hack: If we are encountering a `bfloat16` numpy-array,
  # we pretend to have the data as a numpy.float32 array,
  # since that's about all that contemporary CPUs can process
  # efficiently here.
  linearized_path = pytree_transforms.linearize_revtuple_path(path)
  try:
    # We have to use some trickery to detect `bfloat16`: it is reported
    # as a 2-byte void type ('<V2') in the dtype-descriptor.
    if leaf.dtype.descr[-1] == ('', '<V2'):
      return linearized_path, lambda: leaf.astype(numpy.float32)
    else:
      return linearized_path, lambda: leaf
  except Exception:
    return linearized_path, lambda: leaf


# Here, we cannot properly specify the return-type, since this can
# either be a leaf-type or something recursively-defined.
def revtuple_autovifify_from_linear(
    keys_and_vals: Iterable[tuple[Any, Any]]) -> Any:
  """Performs Perl-style autovivification on a nested-dict tree.

  Args:
    keys_and_vals: An iterable of pairs `(key_path, value)`, where
      `key_path` is a sequence of keys to be used to navigate to
      the result via iterative dict-lookup, left-to-right.
      Must not have duplicate keys, and must not have more than one
      entry if an empty-sequence key is present. If this iterable is
      an iterator, it will be fully exhausted on successful execution.

  Returns:
    An object representing a nested-dict structure such that
    for every `key_path` from `keys_and_vals`, recursive-dict-lookup
    on the elements of that path starting from this object will
    produce the corresponding value. An empty `keys_and_vals`
    set will return `{}`. Every dict in the nested return-value
    that has been populated by autovivification is newly allocated.
  """
  # Code structure is a bit gnarly here due to f(keys_and_vals=[((), x)])
  # having to evaluate to x and not a dict.
  # There may be ways to prettify/simplify this.
  result = None
  empty = {}
  for linear_path, val in keys_and_vals:
    if linear_path == ():
      if result is not None:
        raise ValueError('Root-value seen alongside other values.')
      result = val
    else:
      if result is None:
        result = {}
      elif type(result) is not dict:
        # We already did encounter a root-value.
        raise ValueError('Root-value seen alongside other values.')
      cursor = result
      for n in range(len(linear_path) - 1):
        cursor = cursor.setdefault(linear_path[n], empty)
        if cursor is empty:
          # Regenerate `empty` if we just used it up.
          empty = {}
      cursor[linear_path[-1]] = val
  return {} if result is None else result
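# Editorial note: this mirrors the "generic" sub-test in the unit test at the
# end of this commit:
#
#   revtuple_autovifify_from_linear(
#       [(('a', 'b1', 'c1'), 1001),
#        (('a', 'b2'), 1002),
#        (('a2',), 1003)])
#   == {'a': {'b1': {'c1': 1001}, 'b2': 1002}, 'a2': 1003}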


def model_overview(tree, out=None) -> None:
  """Prints a human-readable overview to `(out or sys.stdout)`."""
  actual_out = out or sys.stdout
  for line in pytree_transforms.pytree_leaf_iter(
      tree, ml_model_leaf_summary,
      _ml_model_tree_node_handler):
    print(line, file=actual_out)


def model_contents(tree) -> Mapping[tuple[str, ...], Any]:
  """Maps a model to a {pytree_keys: data_array} mapping.

  Args:
    tree: The ML-model parameter-tree, built recursively out of
      dict-instances with numpy.ndarray instances as leaves.

  Returns:
    A mapping from linearized pytree-key-sequence tuple to the corresponding
    leaf-value.
  """
  def leaf_transform(revtuple_path, leaf):
    return pytree_transforms.linearize_revtuple_path(revtuple_path), leaf
  return dict(
      pytree_transforms.pytree_leaf_iter(
          tree, leaf_transform, _ml_model_tree_node_handler))


def _fn_identity(x): return x


def model_save(tree,
               filepath_stem: str,
               data_suffix: str = '.data',
               manifest_suffix: str | None = '.manifest',
               key: Callable[[tuple[str, ...]], Any] | None = None,
               array_transform_by_pytree_key: (
                   Mapping[tuple[str, ...],
                           Callable[[numpy.ndarray], numpy.ndarray]] |
                   None) = None,
               report: Callable[[str], None] | None = None,
               byte_align: int = 8) -> tuple[int, float]:
  """Saves the content of a ML-model parameter-tree to filesystem.

  After successful execution, the file f"{filepath_stem}{data_suffix}"
  will hold the combined numerical model-parameters, and (unless
  `manifest_suffix` is `None`) f"{filepath_stem}{manifest_suffix}"
  will contain the key for interpreting (and rebuilding) the data.

  Args:
    tree: The ML-model parameter-tree, built recursively out of
      dict-instances with numpy.ndarray instances as leaves.
    filepath_stem: Filesystem location for data.
    data_suffix: Suffix to use for the data file.
    manifest_suffix: Either `None`, in which case no manifest-file
      will get written, or the suffix for the manifest-file.
    key: `None` or a key-function that will be applied to the linear
      model-path and used for sorting the data arrays by increasing value
      of the key-function. If the key-function returns `None` on an item,
      then this item is not included.
    array_transform_by_pytree_key: Optional mapping from pytree-key
      to an array-to-array transformation function to apply to the array
      prior to serialization.
    report: Optional callable for logging progress-reports.
    byte_align: Byte-alignment to use for numerical array data.
      Numerical arrays whose size in bytes is not a multiple of this
      will get padded to the next full multiple.

  Returns:
    A pair of `(size, time_sec)`, where `size` is the total byte-size
    of the data file and `time_sec` is the elapsed time
    for saving the model, in seconds.
  """
  time0 = time.monotonic()
  if array_transform_by_pytree_key is None:
    array_transform_by_pytree_key = {}
  model_lazy_items = (
      pytree_transforms.pytree_leaf_iter(
          tree, _ml_model_extract_leaf_transform,
          _ml_model_tree_node_handler))
  if key is not None:
    to_write = [
        nkv[1:] for nkv in sorted(
            (nkv for nkv in ((key(path), path, v)
                             for path, v in model_lazy_items)
             if nkv[0] is not None), key=lambda nkv: nkv[0])]
  else:
    to_write = list(model_lazy_items)

  def lazy_arr_path_shape_dtype_size(path_and_lazy_arr):
    path, lazy_arr = path_and_lazy_arr
    arr = array_transform_by_pytree_key.get(path, _fn_identity)(lazy_arr())
    return path, arr.shape, arr.dtype, arr.data.nbytes
  arrs_path_shape_dtype_nbytes = list(
      map(lazy_arr_path_shape_dtype_size, to_write))
  # We need to know the total size of all the data.
  bytesizes = [nbytes for *_, nbytes in arrs_path_shape_dtype_nbytes]
  # -(-n // d) * d is ceiling-division times d, i.e. this rounds each
  # size up to the next full multiple of `byte_align`.
  padded_bytesizes = [-(-bytesize // byte_align * byte_align)
                      for bytesize in bytesizes]
  offsets = numpy.cumsum([0] + padded_bytesizes)
  membuf = numpy.memmap(filepath_stem + data_suffix,
                        mode='w+', shape=offsets[-1])
  try:
    for (path, shape, dtype, nbytes), offset, (_, lazy_arr) in zip(
        arrs_path_shape_dtype_nbytes, offsets, to_write):
      # Note that if getting the array from the lazy lambda involves some
      # computation, such as a copying dtype-change, that computation ends
      # up being done multiple times - once above, to compute byte-sizes,
      # and once more here.
      transformed_arr = array_transform_by_pytree_key.get(
          path,
          _fn_identity)(lazy_arr())
      membuf[offset : offset + nbytes] = numpy.frombuffer(
          transformed_arr.ravel().data, 'u1')
      if report is not None:
        samples = ', '.join(map(str, transformed_arr.ravel()[:5]))
        report(f'# Adding: {path!r}\n bytes: {nbytes:10d}, '
               f'shape: {shape!r:30},\n start: [{samples}, ...]')
      transformed_arr = None  # Drop memory references to numerical arrays ASAP.
  finally:
    if membuf is not None:
      membuf.flush()
    # NumPy wart: the memory-buffer is a resource that conceptually
    # should be .close()able - since mmap()ing holds on to a
    # file descriptor. However, it looks as if that clean-up were done
    # in the "finalizer", despite that having meanwhile been widely
    # understood as dubious practice. So, the best we can do here is
    # to explicitly and clearly remove our reference to the instance.
    del membuf
  if manifest_suffix is not None:
    # We still have to serialize the data that allows us to reconstruct
    # a tree that is equivalent to the original.
    manifest_data = [
        dict(path=path,
             dtype=dtype.descr[-1][-1],
             shape=shape,
             nbytes=nbytes,
             offset=offset)
        for (path, shape, dtype, nbytes), offset in zip(
            arrs_path_shape_dtype_nbytes, offsets)]
    with open(filepath_stem + manifest_suffix, 'wt') as h_manifest:
      pprint.pprint(manifest_data, stream=h_manifest)
  time_taken = time.monotonic() - time0
  return offsets[-1], time_taken
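# Editorial illustration (values made up): the manifest is a pprint-ed Python
# literal that `model_load` below reads back via `ast.literal_eval`, one dict
# per stored array, e.g.:
#
#   [{'path': ('final_norm', 'scale'),
#     'dtype': '<f4',    # little-endian float32, as per dtype.descr
#     'shape': (2560,),
#     'nbytes': 10240,
#     'offset': 0}]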
|
||||||
|
|
||||||
|
|
||||||
|
def model_load(filepath_stem, mmapped=True):
|
||||||
|
"""Loads a model saved by `model_save`.
|
||||||
|
|
||||||
|
Tries to load the model from f"{filepath_stem}.data"
|
||||||
|
and f"{filepath_stem}.manifest".
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filepath_stem: The model location on the filesystem.
|
||||||
|
mmapped: Whether data-arrays will be slices of a
|
||||||
|
`numpy.memmap` mapped buffer, to be paged in
|
||||||
|
on demand only, or in-memory copies of the data.
|
||||||
|
Returns:
|
||||||
|
A dict/numpy.ndarray tree representation of the model,
|
||||||
|
equivalent to the original model.
|
||||||
|
"""
|
||||||
|
with open(filepath_stem + '.manifest', 'rt') as h_manifest:
|
||||||
|
manifest = ast.literal_eval(h_manifest.read())
|
||||||
|
membuf = numpy.memmap(filepath_stem + '.data', mode='r+')
|
||||||
|
paths_and_arrays = []
|
||||||
|
for item in manifest:
|
||||||
|
path = item['path']
|
||||||
|
dtype = numpy.dtype(item['dtype'])
|
||||||
|
shape = item['shape']
|
||||||
|
nbytes = item['nbytes']
|
||||||
|
offset = item['offset']
|
||||||
|
data_array = numpy.frombuffer(membuf[offset : offset + nbytes].data,
|
||||||
|
dtype=dtype).reshape(shape)
|
||||||
|
paths_and_arrays.append(
|
||||||
|
(path,
|
||||||
|
data_array if mmapped else data_array.copy()))
|
||||||
|
# At this point, the memory-buffer is no longer needed. Still, if
|
||||||
|
# data-arrays retain references to the underlying data
|
||||||
|
# (i.e. when mmapped=False), this should keep the mapping
|
||||||
|
# - and hence file descriptor - open. We then are in a somewhat
|
||||||
|
# undesirable situation of clean-up of a resource that happens in a
|
||||||
|
# hard-to-predict way releasing a file descriptor.
|
||||||
|
del membuf
|
||||||
|
return revtuple_autovifify_from_linear(paths_and_arrays)
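

# Minimal round-trip usage sketch (hypothetical path; `model` would be a
# dict/numpy.ndarray pytree, as exercised in the tests):
#
#   total_size, total_time = model_save(model, '/tmp/my_model')
#   reloaded = model_load('/tmp/my_model', mmapped=True)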

@ -0,0 +1,92 @@
"""Basic tests for ML-model transformations."""


import io
import os
import tempfile
import unittest

import numpy

import ml_model_transforms


def _get_model(prefix):
  return {
      prefix + 'a1': numpy.arange(1000, 1024).reshape(6, 4).astype(numpy.float32),
      prefix + 'a2': numpy.arange(2000, 2048).reshape(6, 8).astype(numpy.float32),
      prefix + 'b1': {
          prefix + 'c1': numpy.arange(100, 127).reshape(3, 3, 3).astype(numpy.int8),
          prefix + 'c2': numpy.arange(100, 128).reshape(7, 4).astype(numpy.float64)
      }}


class MLModelTransformsTest(unittest.TestCase):
  """Basic correctness validation tests for ML-model transformations."""

  def test_ml_model_leaf_summary(self):
    """Tests guarantees given by `ml_model_leaf_summary`."""
    summary = ml_model_transforms.ml_model_leaf_summary(
        ('a', ()),
        numpy.arange(1000, 1024).reshape(6, 4).astype(numpy.int16),
        sep='##')
    self.assertIn('##', summary)  # Separator is respected.
    self.assertIn('(6, 4)', summary)  # Shape is mentioned somewhere.
    self.assertIn('int16', summary)  # dtype is mentioned somewhere.

  def test_revtuple_autovivify_from_linear(self):
    """Tests guarantees given by `revtuple_autovifify_from_linear`."""
    with self.subTest(guarantee='empty'):
      self.assertEqual(
          ml_model_transforms.revtuple_autovifify_from_linear([]),
          {})
    with self.subTest(guarantee='generic'):
      keys_vals = [(('a', 'b1', 'c1'), 1001),
                   (('a', 'b2'), 1002),
                   (('a2',), 1003),
                   ]
      self.assertEqual(
          ml_model_transforms.revtuple_autovifify_from_linear(keys_vals),
          {'a': {'b1': {'c1': 1001}, 'b2': 1002}, 'a2': 1003})

  def test_model_overview(self):
    """Tests guarantees given by `model_overview`."""
    model = _get_model('xyz')
    out_io = io.StringIO()
    ml_model_transforms.model_overview(model, out=out_io)
    overview = out_io.getvalue()
    self.assertIn('xyz', overview)

  def test_model_contents(self):
    """Tests guarantees given by `model_contents`."""
    model = _get_model('pq_')
    contents = ml_model_transforms.model_contents(model)
    fingerprints = {k: (a.shape, a.ravel()[:3].tolist())
                    for k, a in contents.items()}
    self.assertEqual(fingerprints,
                     {('pq_a1',): ((6, 4), [1000.0, 1001.0, 1002.0]),
                      ('pq_a2',): ((6, 8), [2000.0, 2001.0, 2002.0]),
                      ('pq_b1', 'pq_c1'): ((3, 3, 3), [100, 101, 102]),
                      ('pq_b1', 'pq_c2'): ((7, 4), [100.0, 101.0, 102.0])})

  def test_model_save_load_basic(self):
    """Tests basic guarantees given by `model_save` and `model_load`."""
    # What we care about here is that the round trip works - so
    # it makes more sense to test saving and loading as one unit.
    model_orig = _get_model('model_')
    with tempfile.TemporaryDirectory() as tempdir:
      filepath_stem = os.path.join(tempdir, 'the_model')
      total_size, total_time = ml_model_transforms.model_save(model_orig,
                                                              filepath_stem)
      self.assertGreater(total_size, 0)
      self.assertGreater(total_time, 0)
      model_reloaded = ml_model_transforms.model_load(filepath_stem)
      contents_orig = ml_model_transforms.model_contents(model_orig)
      contents_reloaded = ml_model_transforms.model_contents(model_reloaded)
      self.assertEqual(
          {k: v.tolist() for k, v in contents_orig.items()},
          {k: v.tolist() for k, v in contents_reloaded.items()})


if __name__ == '__main__':
  unittest.main()

@ -0,0 +1,508 @@
"""Tools for transforming "nested python object" tree data structures.

# Context

The motivation for this module came from ML applications that ought to
be based on a principled handling of nested Python data structures.
Having such principled pytree-transforming code available also solves
some other problems, such as removing the need to abuse tree-mapping
for side effects only, and the need to take a hope-and-pray approach
to processing very deeply nested values, which with a recursive
approach might trigger a RecursionError.

We specifically want to cover the use case of having ML model
parameters that are available in a nested Python data structure for
which there "almost" is a unique-up-to-unique-isomorphism mapping from
and to this Algebraic Data Type:

`data ModelParams a = Array a | Node [(String, ModelParams a)]`

In this correspondence, `a` is some array-type (perhaps
`numpy.ndarray`, `jax.numpy.ndarray`, `tf.tensor`, etc.), but the
data-processing code is effectively entirely agnostic to this, and a
`Node` is "almost" an associative-list of (key, value) pairs
representing a Python dict. (Note: The "almost" here is mostly about
the conceptual wart that assoc-lists can in principle have key
duplicates, but Python dicts can not. This is however not a problem
since all we need is the transformation in one direction,
i.e. whatever data-processing `f` we want to express on the
model-parameters-pytree, we can express by specifying a "faithful"
mapping `m` into the above algebraic data type through which every
such pytree data transform factorizes, i.e. for every `f` we can find
a `g` such that `f(p) = g(m(p))`.)

## Components

The main workhorse in this module is the `pytree_leaf_iter` function
that maps a "PyTree (such as representing `ModelParams`)" to an iterator
over values obtained by applying a mapping-function to the "key-path"
and leaf-value for every leaf, where the "key-path" contains a
linked-list representation of the reversed sequence of keys from the
tree-root, with list-nodes being represented by pairs
`(latest_dict_key, rest_path)`, and the empty path being represented
by `()`.

For the sake of genericity, `pytree_leaf_iter` is built in such a way
that it actually can handle any kind of traversal of PyTree-trees that
do represent algebraic data types (note however that some do not) -
but for this to make sense, the user must have a way to define how to
interpret tree-nodes, in particular identify leaves. This requires
providing a function `node_handler` with the same signature and
behavior as described below for "node handlers".

Additionally, this module provides mapping-over-pytrees via
`pytree_map`, which is also built in such a way that it makes the
correspondence between an algebraic data type and its Python
nested-tree representation explicit. Despite being powerful and
flexible, this, however, may in general require a bit more effort to
wire up, since node-rebuilding can be fairly nontrivial.

Furthermore, as a prominent application, this module provides a simple
deep-freezing function that translates a nested Python data structure
to deeply-immutable form.

## Concepts and Conventions

"revtuple representation":

  As we iterate over a tree, we will have to keep track of the
  path-to-tree-root. Naturally, two sibling nodes `n1` and `n2`
  will share the same parent-path (being siblings), so it makes
  sense to use a linked-list-with-shared-tail representation.
  Python does not have a natural notion for that, so we use
  recursively-constructed tuples `(node_tag, parent_path)`
  that represent the path-from-root in-reverse-order, i.e.
  for a non-empty path `p`, `p[0]` is the node-tag at the
  deepest nesting level. We call this a "revtuple representation"
  of the path.
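
  For example, a leaf reached from the root via key 'a' and then
  key 'b' has the revtuple path `('b', ('a', ()))`.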

"node handler":

  A node-handler classifies a tree-node as "leaf or other node", and
  for non-leaf nodes provides information about both its children and
  how to rebuild it. The behavior of a node-handler function must be
  in alignment with this docstring:

  '''Processes a tree-node as required by pytree-iteration and -mapping.

  Args:
    revtuple_path: Revtuple-representation of the path-from-root
      to the current node.
    node: a tree-node in a ML-model tree that is recursively
      built out of leaf-values and other nodes.

  Returns:
    `None` if the tree-node is to be regarded as a leaf, otherwise
    a pair `(rebuilder, iterator)`, where `iterator` iterates
    over the data-content of the node, each item represented as a pair
    of `(lookup_path_item, value_item)`, and `rebuilder` is a function
    which, when applied to an iterable of the aforementioned value-items
    (or some transformation thereof) returns a node that is equivalent
    to the original (or up to a transformation of the contents).

  Raises:
    InvalidTreeNodeError: If the tree contains a node of a kind
      that is not expected to show up.
  '''

Examples:

(The behavior of a node-handler is somewhat nontrivial, so covering
two very common cases via examples is in order.)

This node-handler would allow descending into (nested)
instances of `list` (but not subclass instances thereof):

```def list_node_handler(revtuple_path, obj):
  ''' ... '''
  if type(obj) is list:
    return list, enumerate(obj)
  else:
    return None
```

This node-handler would allow descending into (nested) mappings,
which upon rebuilding would get turned into `dict` instances:

```def mapping_node_handler(revtuple_path, obj):
  ''' ... '''
  if isinstance(obj, collections.abc.Mapping):
    # For generic mappings, we cannot rely on key- and item-iteration
    # being guaranteed to use identical iteration-order.
    items = list(obj.items())
    keys = [kv[0] for kv in items]
    return (lambda values: dict(zip(keys, values))), iter(items)
  else:
    return None
```

A dict/mapping node-handler can of course rename keys, add or remove
entries, make decisions based on the item-path, or map a dict to
an associative list, etc.

## Further Design Notes

The `pytree_map` function requests the leaf-transform and node-handler
to be side-effect-free functions. This is both required to leave
implementation-side flexibility, and also follows the general LISP
recommendation to not abuse mapping (which should be a pure
data-transformation) for imperative data processing. Overall, if
a need for more general "nested datastructures" processing becomes
pressing, it is for the better if this leads to a proper articulation
of the specific needs, to be addressed with appropriate design, rather
than abuse of functional data-transforms becoming "a bad idiom
that turned into established practice".

"""

import collections.abc
import immutabledict

import numpy

from typing import Any, Callable, Iterable, Iterator, TypeVar


T = TypeVar('T')
U = TypeVar('U')

KT = TypeVar('KT')
NT = TypeVar('NT')


## Type of the reverse-order-keys-to-root path.
# (This code actually illustrates why https://xkcd.com/2483/ is very misguided.)
RevTuplePath = tuple

## Type of the `leaf_transform` function-argument used for tree-iteration.
#
# This would be the correct type we would have to specify here but cannot,
# since the design of Python's static typing at the time of this writing
# is too broken for that:
#
# type LeafTransformFunc[L, R] = Callable[[RevTuplePath, L], R]
#
# Instead, we have to settle for...:
LeafTransformFunc = Callable[[RevTuplePath, Any], Any]


## Type of the `tree_node_handler` function-argument used for
## tree-iteration and tree-mapping.
#
# Again, this is the correct type we would have to put here but cannot:
#
# type NodeHandlerFunc[KT] = (
#   Callable[[RevTuplePath, NT],
#            None | tuple[Callable[[Iterable[tuple[KT, NT]]], NT],
#                         Iterator[tuple[KT, NT]]]])
#
# ...so, we have to instead settle for:
NodeHandlerFunc = (
    Callable[[RevTuplePath, NT],
             None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT],
                          Iterator[tuple[Any, NT]]]])


Predicate = Callable[[object], bool]


class InvalidTreeNodeError(ValueError):
  """Encountered a tree-node of invalid type."""


def linearize_revtuple_path(
    revtuple_path: RevTuplePath,
    present_as: Callable[[Iterator[T]], U] = tuple) -> U:
  """Translates a revtuple path to (typically) linear form.

  With default `present_as`, this will map a path of the form
  `(key_{N}, (key_{N-1}, ..., (root, ())))` into a tuple
  `(root, ..., key_{N-1}, key_{N})`.

  Args:
    revtuple_path: A linked-list-as-recursive-pairs
      reverse-order tuple-representation of the path.
      Path-root is `()`, and node-key `x` relative to
      earlier path `p` is represented as `(x, p)`.
    present_as: Callable that consumes an iterator over
      path-pieces - with the deepest-nesting level coming last -
      turning it into a linearized path. Defaults to `tuple`.

  Returns:
    Linearized presentation of all the node-keys in the
    recursive-path in order, deepest-down path component coming last.
  """
  pieces = []
  todo = revtuple_path
  while todo:
    node, todo = todo
    pieces.append(node)
  return present_as(reversed(pieces))
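

# For instance (as also exercised in the tests):
#   linearize_revtuple_path((30, (20, (10, ())))) == (10, 20, 30)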


# This function itself has type `NodeHandlerFunc`, but Python does not
# allow us to simply type-annotate it as such here. We cannot even
# introduce an abbreviation for the complicated output-type,
# since that would have to be parametric in node-type `NT` (and `KT`).
def everything_is_a_leaf_node_handler(
    revtuple_path: tuple,
    node: NT) -> (
        None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT],
                     Iterator[tuple[Any, NT]]]):
  """Processes a tree-node as required by pytree-iteration and -mapping.

  Interface and signature are in alignment with the requirements for a
  "node handler" function explained in the module-docstring.

  Args:
    revtuple_path: the path-to-root for this node.
    node: a tree-node.

  Returns:
    `None`, i.e. classifying any kind of node as a leaf-node.
  """
  del revtuple_path, node  # Unused.
  return None


def leaf_summary(path: RevTuplePath, x: object):
  """Produces a human-readable summary-string for a leaf-node.

  Args:
    path: revtuple representation of the path-to-root.
    x: The leaf-value.
  """
  del path  # Ignored here.
  tx = type(x)
  mod = tx.__module__
  modname = mod if isinstance(mod, str) else mod.__name__
  type_str = f'{modname}.{tx.__qualname__}'
  repr_str = repr(x)
  repr_abbrev = repr_str if len(repr_str) < 40 else repr_str[:40] + ' ...'
  # On str, int, float, etc. `{type_str}(repr(x))` would actually still be
  # a (non-literal) Python-expression that would evaluate to the original
  # value. However, we make no promises beyond "human-readable".
  return f'{type_str}({repr_abbrev})'
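

# For instance, leaf_summary((), 42) yields the summary-string
# 'builtins.int(42)'.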


# With respect to static type annotations, the limitations of Python's
# approach to static typing really become prominently visible here.
#
# Different arguments have type-parameters, but since there is no way
# to have parametric abbreviations such as `LeafTransformFunc[L, R]`,
# the only way we would have available to express relations between
# type-parameters would be to substitute in the not-abbreviated form of
# `NodeHandlerFunc` and `LeafTransformFunc`, giving us something monstrous.
# We instead here settle for "we cannot express that `tree` must
# have the same type as the input-type to `tree_node_handler`" and use `Any`,
# and likewise for `leaf_transform` and the output.
def pytree_leaf_iter(
    tree: Any,
    leaf_transform: LeafTransformFunc,
    node_handler: NodeHandlerFunc = everything_is_a_leaf_node_handler,
) -> Iterator[Any]:
  # ...actual return type would be `Iterator[{what leaf_transform returns}]`.
  """Iterates over the leaves of a tree.

  Args:
    tree: The tree to iterate over.
    leaf_transform: A callable `f` that will get applied
      as `f(revtuple_path, leaf)`, where `revtuple_path`
      is the revtuple representation of the path to the
      leaf from the root.
    node_handler: A "node handler" (see module docstring)
      that processes nodes encountered during iterative traversal.

  Yields:
    Value of `leaf_transform(p, x)`, where `x` is the current leaf
    and `p` is its revtuple-path to the root.
  """
  # Note: Exit points for the code below are in non-obvious places
  # and hence marked with " # ***EXIT***".
  #
  # Doing iteration properly is slightly nontrivial.
  # One may be tempted to go for a very simple recursive implementation
  # (with an extra pre-final `path` argument to `pytree_leaf_iter`):
  #
  # maybe_substructure = node_handler(path, tree)
  # if maybe_substructure is None:
  #   # We are looking at a leaf-node.
  #   yield leaf_transform(path, tree)
  # else:
  #   _, contents_iter = maybe_substructure
  #   for k, v in contents_iter:
  #     yield from pytree_leaf_iter(v, leaf_transform, (k, path), node_handler)
  #
  # That, however, would be flawed, since there is no a priori reason
  # why a pytree may not be a very deeply nested structure - such as a
  # long linked list. That would then risk raising `RecursionError`,
  # and since Python by design(!) does not perform tail call elimination
  # or any other kind of advanced CPS transforms, there is no recursive
  # solution here. So, to do this properly, we have to do this iteratively.
  #
  # We are facing an annoying situation here: If `tree` itself is a leaf,
  # we have two options: (a) wrapping it up in a one-node tree
  # and processing that, or (b) special-casing "root is a leaf".
  # Option (b) leads to some mild node-processing code-duplication
  # for a single node (the root).
  # Option (a) requires having special cases for node-processing that
  # get looked at for every tree node. We go with option (b) here.
  maybe_substructure = node_handler((), tree)
  if maybe_substructure is None:
    # The tree itself is a leaf.
    yield leaf_transform((), tree)
    return  # ***EXIT***
  # Otherwise, we are looking at a tree.
  _, contents_iter = maybe_substructure
  current_revtuple_path = ()
  work_to_do = [contents_iter]
  # Otherwise-unreachable sentinel for reliably identifying
  # iterator-exhaustion without using exceptions:
  sentinel = object()
  while True:
    current_iter = work_to_do[-1]
    maybe_next_item = next(current_iter, sentinel)
    if maybe_next_item is sentinel:
      # We are done at this level.
      work_to_do.pop()
      if not work_to_do: return  # ***EXIT***
      current_revtuple_path = current_revtuple_path[1]
    else:
      path_piece, subtree = maybe_next_item
      extended_revtuple_path = (path_piece, current_revtuple_path)
      maybe_subtree_substructure = node_handler(extended_revtuple_path, subtree)
      if maybe_subtree_substructure is None:  # Case: subtree is a leaf.
        yield leaf_transform(extended_revtuple_path, subtree)
      else:  # Case: subtree is a tree.
        current_revtuple_path = extended_revtuple_path
        _, items_iter = maybe_subtree_substructure
        work_to_do.append(items_iter)
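

# For instance, with a mapping-aware node-handler like the
# `mapping_node_handler` sketched in the module docstring, one would
# get (as also exercised in the tests):
#
#   list(pytree_leaf_iter(
#       {'a': {'b': 1}},
#       lambda p, leaf: (linearize_revtuple_path(p), leaf),
#       mapping_node_handler))
#   == [(('a', 'b'), 1)]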


# The current design approach here would be appropriate for
# applying leaf-transforms while retaining the structure of the tree -
# which closely corresponds to e.g. a (a -> b) -> (Tree a -> Tree b) functor.
#
# It is not entirely clear whether this is the abstraction that we should
# consider as being appropriately generic to flesh out explicitly - rather
# than starting from a more general approach of which this then is a special
# case. Some background: https://ncatlab.org/nlab/show/recursion+scheme
#
# On the other hand, there is a lot of flexibility via whatever
# node-rebuilder a node-handler produces - this can do quite some reshaping
# of a tree, including dropping or duplicating nodes.
def pytree_map(
    tree: Any,
    leaf_transform,
    node_handler: NodeHandlerFunc = everything_is_a_leaf_node_handler,
):
  """Maps a (potentially nested) Python value to another such value.

  Args:
    tree: The Python-object to be mapped.
    leaf_transform: A callable `f` that will get applied
      as `f(revtuple_path, leaf)`, where `revtuple_path`
      is the revtuple representation of the path to the
      leaf from the root. Must be side effect free.
    node_handler: A "node handler" (see module docstring)
      that processes nodes encountered during iterative traversal.
      Must be side effect free.

  Returns:
    The outcome of translating `tree`.
  """
  # Note: Exit points for the code below are in non-obvious places
  # and hence marked with " # ***EXIT***".
  #
  # Otherwise-inaccessible sentinel object, for reliably identifying
  # missing-values via identity-check against sentinel lookup-default.
  sentinel = object()
  # Code structure mostly follows pytree_leaf_iter.
  maybe_substructure = node_handler((), tree)
  if maybe_substructure is None:
    return leaf_transform((), tree)  # ***EXIT***
  rebuilder, items_iter = maybe_substructure
  current_revtuple_path = ()
  # Per-level, we have a triplet of:
  # (rebuilder, remaining_items_to_iterate_over, processed).
  parts_for_assembly = [(rebuilder, items_iter, [])]
  while True:
    this_rebuilder, this_items_iter, this_done_pieces = parts_for_assembly[-1]
    maybe_next_item = next(this_items_iter, sentinel)
    if maybe_next_item is sentinel:
      # We are done with all the items for this level.
      parts_for_assembly.pop()
      built_node = this_rebuilder(this_done_pieces)
      if not parts_for_assembly:  # No outer structure, so at-top-level.
        return built_node  # ***EXIT***
      else:  # We have outer structure.
        parts_for_assembly[-1][-1].append(built_node)
        current_revtuple_path = current_revtuple_path[1]
        continue  # ...with next is-the-final-item-complete-check.
    else:
      # More constituents of the current item.
      path_piece, subtree_item = maybe_next_item
      extended_revtuple_path = (path_piece, current_revtuple_path)
      maybe_subtree_substructure = node_handler(
          extended_revtuple_path,
          subtree_item)
      if maybe_subtree_substructure is None:
        this_done_pieces.append(
            leaf_transform(extended_revtuple_path, subtree_item))
      else:
        # We have a subtree.
        subtree_rebuilder, subtree_items_iter = maybe_subtree_substructure
        current_revtuple_path = extended_revtuple_path
        parts_for_assembly.append(
            (subtree_rebuilder, subtree_items_iter, []))
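

# For instance (as also exercised in the tests), mapping `repr` over
# the leaves of a plain dict-tree with a dict-aware node-handler:
#
#   pytree_map({'a': {'b': 777}}, lambda p, s: repr(s), dict_node_handler)
#   == {'a': {'b': '777'}}
#
# ...where `dict_node_handler` is a node-handler along the lines of the
# `mapping_node_handler` sketch in the module docstring.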


def deep_freeze(
    tree,
    *,
    is_mapping: Predicate = lambda x: isinstance(x, collections.abc.Mapping),
    is_set: Predicate = lambda x: isinstance(x, collections.abc.Set),
    is_sequence: Predicate = lambda x: isinstance(x, (list, tuple)),
    leaf_fn: Callable[[Any], Any] = lambda x: x,
):
  """Recursively freezes Set/Mapping/List/Tuple structures.

  Args:
    tree: The potentially deeply-nested object to deep-freeze.
    is_mapping: Callable that decides whether a sub-object is a mapping.
      Defaults to an `isinstance()` check for `collections.abc.Mapping`.
    is_set: Callable that decides whether a sub-object is a set.
      Defaults to an `isinstance()` check for `collections.abc.Set`.
    is_sequence: Callable that decides whether a sub-object is a sequence.
      Defaults to a check for being a `tuple` or `list` instance.
    leaf_fn: Function to use for translating non-mapping/set/sequence
      instances.

  Returns:
    Translated-to-deeply-immutable form of `tree`.
  """
  idict = immutabledict.immutabledict
  def freeze_node_handler(path, x):
    if is_set(x):
      return frozenset, ((None, y) for y in x)
    if is_mapping(x):
      # Mappings already have hashable - and hence
      # (should-be-)deeply-immutable - keys.
      # Hence, we only need to deep-freeze the values.
      #
      # Note that non-`dict` mappings might not guarantee
      # to respect iteration-order, so we have to be careful here:
      items = list(x.items())
      keys = [kv[0] for kv in items]
      return ((lambda ys: idict(zip(keys, ys))),
              iter(items))
    if is_sequence(x):
      return tuple, enumerate(iter(x))
    # Otherwise, this should not be traversed.
    return None
  def leaf_transform(revtuple_path, value):
    del revtuple_path  # Unused.
    return leaf_fn(value)
  return pytree_map(tree, leaf_transform, freeze_node_handler)
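

# For instance (as also exercised in the tests), deep_freeze maps
# dict -> immutabledict, list/tuple -> tuple, set -> frozenset:
#
#   frozen = deep_freeze({'a': [1, {'b': {2, 3}}]})
#   # frozen == immutabledict(
#   #     {'a': (1, immutabledict({'b': frozenset({2, 3})}))})
#   # ...and `frozen` is hashable.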


@ -0,0 +1,168 @@
"""Basic tests for 'algebraic data type based pytree' transformations."""


import collections.abc
import sys
import unittest

import pytree_transforms


def _get_deep_pytree(packaging_fn, bottom, depth):
  current = bottom
  for n in reversed(range(depth)):
    current = packaging_fn(n, current)
  return current


def _dict_node_handler(p, d):
  del p  # Unused.
  if isinstance(d, dict):
    keys = d.keys()
    newdict = lambda vals: dict(zip(keys, vals))
    return (newdict, iter(d.items()))
  else:
    return None


class PyTreeTest(unittest.TestCase):
  """Basic correctness validation tests for PyTree transformations."""

  def test_linearize_revtuple_path(self):
    """Tests guarantees given by `linearize_revtuple_path`."""
    linearize_revtuple_path = pytree_transforms.linearize_revtuple_path
    with self.subTest(guarantee='empty'):
      self.assertEqual(linearize_revtuple_path(()), ())
    with self.subTest(guarantee='typical'):
      self.assertEqual(linearize_revtuple_path((30, (20, (10, ())))),
                       (10, 20, 30))
    with self.subTest(guarantee='present_as'):
      self.assertEqual(
          linearize_revtuple_path(
              (30, (20, (10, ()))), present_as=list),
          [10, 20, 30])

  def test_everything_is_a_leaf_node_handler(self):
    """Tests guarantees given by `everything_is_a_leaf_node_handler`."""
    everything_is_a_leaf_node_handler = (
        pytree_transforms.everything_is_a_leaf_node_handler)
    self.assertEqual(everything_is_a_leaf_node_handler((), 'abc'),
                     None)
    self.assertEqual(everything_is_a_leaf_node_handler(('b', ()),
                                                       dict(a=3)),
                     None)

  def test_leaf_summary(self):
    """Tests guarantees given by `leaf_summary`."""
    # Since the docstring only guarantees "a human-readable presentation",
    # we can and should only do loose checks.
    thing = (5678, 9531)
    summary = pytree_transforms.leaf_summary(('key', ()), thing)
    self.assertIsInstance(summary, str)
    self.assertIn(str(thing[0]), summary)
    self.assertIn(str(thing[1]), summary)

  def test_pytree_leaf_iter(self):
    """Tests guarantees given by `pytree_leaf_iter`."""
    pytree_leaf_iter = pytree_transforms.pytree_leaf_iter
    def leaf_transform(path, leaf):
      return repr(leaf) if path and path[0].startswith('R') else leaf
    with self.subTest(guarantee='returns_iterator'):
      result = pytree_leaf_iter(7, leaf_transform, _dict_node_handler)
      self.assertIsInstance(result, collections.abc.Iterator)
    with self.subTest(guarantee='totally_empty'):
      result = list(pytree_leaf_iter({}, leaf_transform, _dict_node_handler))
      self.assertEqual(result, [])
    with self.subTest(guarantee='no_leaves'):
      result = list(pytree_leaf_iter(dict(a={}),
                                     leaf_transform, _dict_node_handler))
      self.assertEqual(result, [])
    with self.subTest(guarantee='is_leaf'):
      result = list(pytree_leaf_iter(777, leaf_transform, _dict_node_handler))
      self.assertEqual(result, [777])
    with self.subTest(guarantee='generic'):
      result = list(pytree_leaf_iter(
          dict(n0=dict(n01=dict(n012=1002,
                                n013=1003,
                                Rn014=1004,
                                ),
                       n02=1005),
               n5=1006),
          leaf_transform, _dict_node_handler))
      self.assertEqual(result, [1002, 1003, '1004', 1005, 1006])
    with self.subTest(guarantee='with_keys'):
      result = list(pytree_leaf_iter(
          dict(n0=dict(n01=dict(n012=1002,
                                n013=1003)),
               n1=1004),
          lambda p, s: (pytree_transforms.linearize_revtuple_path(p), s),
          _dict_node_handler))
      self.assertEqual(result,
                       [(('n0', 'n01', 'n012'), 1002),
                        (('n0', 'n01', 'n013'), 1003),
                        (('n1',), 1004)])

  def test_pytree_map(self):
    """Tests guarantees given by `pytree_map`."""
    pytree_map = pytree_transforms.pytree_map
    leaf_transform = lambda p, s: repr(s)
    tree1 = dict(t0=dict(t10=1001,
                         t11=dict(t110=1002,
                                  t111=1003),
                         t12=dict(t120=1004,
                                  t121=1005,
                                  t122=1006)),
                 t1=1007)
    with self.subTest(guarantee='no_leaves'):
      result = pytree_map(dict(a={}),
                          leaf_transform,
                          _dict_node_handler)
      self.assertEqual(result, dict(a={}))
    with self.subTest(guarantee='is_leaf'):
      result = pytree_map(777, leaf_transform, _dict_node_handler)
      self.assertEqual(result, '777')
    with self.subTest(guarantee='generic'):
      result = pytree_map(tree1, leaf_transform, _dict_node_handler)
      self.assertEqual(result['t0']['t10'], '1001')

  def test_deeply_nested(self):
    """Tests correct behavior on deeply-nested data structures."""
    pytree_leaf_iter = pytree_transforms.pytree_leaf_iter
    pytree_map = pytree_transforms.pytree_map
    #
    depth = max(10**5, sys.getrecursionlimit() + 100)
    deep_tree = _get_deep_pytree(lambda n, t: {n: t},
                                 'leaf', depth)
    with self.subTest(function='pytree_leaf_iter'):
      leaves = list(pytree_leaf_iter(deep_tree,
                                     lambda p, s: s.upper(),
                                     _dict_node_handler))
      self.assertEqual(leaves, ['LEAF'])
    with self.subTest(function='pytree_map'):
      mapped_deep_tree = pytree_map(deep_tree,
                                    lambda p, s: s,
                                    _dict_node_handler)
      self.assertIsInstance(mapped_deep_tree, dict)
    with self.subTest(function='combined'):
      leaves = list(
          pytree_leaf_iter(
              pytree_map(deep_tree,
                         lambda p, s: s.capitalize(),
                         _dict_node_handler),
              lambda p, s: s + s,
              _dict_node_handler))
      self.assertEqual(leaves, ['LeafLeaf'])

  def test_deep_freeze(self):
    """Tests guarantees given by `deep_freeze`."""
    frozen = pytree_transforms.deep_freeze(
        dict(a=[1001, 1002, dict(b=(1003, [1004, {1005, 1006}]))]))
    self.assertIsInstance(frozen, collections.abc.Mapping)
    self.assertNotIsInstance(frozen, collections.abc.MutableMapping)
    self.assertIsInstance(frozen['a'], tuple)
    # `frozen` is hashable, and hashes to an integer.
    self.assertIsInstance(hash(frozen), int)


if __name__ == '__main__':
  unittest.main()

@ -0,0 +1,4 @@
immutabledict>=4.2.0
numpy>=1.26.4
orbax-checkpoint>=0.0.0