diff --git a/compression/python/pytree/PYTREE_README.md b/compression/python/pytree/PYTREE_README.md new file mode 100644 index 0000000..4a04079 --- /dev/null +++ b/compression/python/pytree/PYTREE_README.md @@ -0,0 +1,8 @@ + +# General Remarks about the "PyTree" Abstraction + +The pytree wrangling code in this project does not use any of the existing +"pytree" modules. The deeper reason here is that our approach is based on an +analysis of the notion that emphasizes deeper underlying principles. This is +being discussed internally at the time of this writing. + diff --git a/compression/python/pytree/build_model_file_for_cpp_binary.py b/compression/python/pytree/build_model_file_for_cpp_binary.py new file mode 100644 index 0000000..d039639 --- /dev/null +++ b/compression/python/pytree/build_model_file_for_cpp_binary.py @@ -0,0 +1,275 @@ +"""Ad-hoc glue code for building the griffin model-file for the C++ binary. + +Usage: + +python3 -m venv $HOME/clients/griffin-venv + +. $HOME/clients/griffin-venv/bin/activate + +python3 -m pip install -r requirements.txt + +time python3 build_model_file_for_cpp_binary.py \ + $HOME/GRIFFIN/model_data \ + cpp_load_log.txt /tmp/G2B.data + +real 3m5.821s +user 2m9.205s +sys 2m46.720s + +./compress_weights --weights /tmp/G2B.data --model gr2b-it \ + --compressed_weights /tmp/G2B.compressed +./gemma --tokenizer tokenizer.spm --weights /tmp/G2B.compressed \ + --model gr2b-it + +Weights for the recurrent-gemma model that can be converted with this script +can be found at: + + https://www.kaggle.com/models/google/recurrentgemma/flax/2b-it +""" + +import pprint +import re +import sys + +from typing import Any, Mapping + +import numpy + +import orbax.checkpoint + +import ml_model_transforms +import pytree_transforms + + +def _fn_identity(x): return x + + +def _fn_transpose(x): return x.T + + +def _fn_transpose_all_heads(x): return x.transpose(0, 2, 1) + + +def _fn_scaled_softplus(a): + return -8 * numpy.logaddexp(a, 0) + + +def _fn_attention_moveaxis(a): + return a.reshape(10, 256, 2560).transpose(0, 2, 1) + + +def _aspec(pieces=(), transforms=()): + """Short-hand array-save-specification. + + Args: + pieces: Sequence of key-sequences identifying an array. + transforms: Sequence of transformations, indexed in + parallel to `pieces`, to apply to data arrays prior to saving. + Will be padded with identity-transformations to the length of `pieces`. + + Returns: + Specification as for use in _LAYETR_NAME_MAPPING. + """ + # `zip` trims to shortest sequence, so this amounts to using + # default-transforms. + # tuple() since we need a Sequence here, not a stateful-iterator zip_object. + return tuple(zip(pieces, list(transforms) + [_fn_identity] * len(pieces))) + + +_LAYER_NAME_MAPPING = pytree_transforms.deep_freeze({ + # Recurrent Layer + 'griffin_linear_x_w': _aspec( + [('recurrent_block', 'linear_x', 'kernel')], + [_fn_transpose]), + 'griffin_linear_x_biases': _aspec( + [('recurrent_block', 'linear_x', 'bias')]), + 'griffin_linear_y_w': _aspec( + [('recurrent_block', 'linear_y', 'kernel')], + [_fn_transpose]), + 'griffin_linear_y_biases': _aspec( + [('recurrent_block', 'linear_y', 'bias')]), + 'griffin_linear_out_w': _aspec( + [('recurrent_block', 'linear_out', 'kernel')], + [_fn_transpose]), + 'griffin_linear_out_biases': _aspec( + [('recurrent_block' ,'linear_out', 'bias')]), + 'griffin_conv_w': _aspec( + [('recurrent_block', 'conv_1d', 'w')]), + 'griffin_conv_biases': _aspec( + [('recurrent_block', 'conv_1d', 'b')]), + 'griffin_gate_w': _aspec( + [('recurrent_block', 'rg_lru', 'input_gate', 'w'), + ('recurrent_block', 'rg_lru', 'a_gate', 'w')], + [_fn_transpose_all_heads, _fn_transpose_all_heads]), + 'griffin_gate_biases': _aspec( + [('recurrent_block', 'rg_lru', 'input_gate', 'b'), + ('recurrent_block', 'rg_lru', 'a_gate', 'b')]), + 'griffin_a': _aspec( + [('recurrent_block', 'rg_lru', 'a_param')], + [_fn_scaled_softplus]), + # Attention Layer + 'qkv_einsum_w': _aspec( + [('attention_block', 'proj_q', 'kernel'), + ('attention_block', 'proj_k', 'kernel'), + ('attention_block', 'proj_v', 'kernel'), + ], + [_fn_transpose, _fn_transpose, _fn_transpose]), + 'attn_vec_einsum_w': _aspec( + [('attention_block', 'proj_final', 'kernel')], + [_fn_attention_moveaxis]), + 'attention_output_biases': _aspec( + [('attention_block', 'proj_final', 'bias')]), + # Common + 'pre_attention_norm_scale': _aspec( + [('temporal_pre_norm', 'scale')]), + 'pre_ffw_norm_scale': _aspec( + [('channel_pre_norm', 'scale')]), + 'gating_einsum_w': _aspec( + [('mlp_block', 'ffw_up', 'w')], + [_fn_transpose_all_heads]), + 'ffw_gating_biases': _aspec( + [('mlp_block', 'ffw_up', 'b')]), + 'linear_w': _aspec( + [('mlp_block', 'ffw_down', 'kernel')], + [_fn_transpose]), + 'ffw_output_biases': _aspec( + [('mlp_block', 'ffw_down', 'bias')]), + # Other + 'embedder_input_embedding': _aspec( + [('embedder', 'input_embedding')]), + 'final_norm_scale': _aspec( + [('final_norm', 'scale')]), +}) + + +def process_param_line(line : str) -> tuple[None | str, int, str]: + """Processes a "loading parameters" log-line from the griffin binary.""" + # This is slightly more permissive than strictly needed, to also handle + # some earlier form of the output. + matched = re.match( + r'(?a)Loading Parameters:? \(' + r'(?:layer=(?P\d+), )?' + r'size (?P\d+)\):? ' + r'(?P\S+)', + line) + if not matched: + return None + layer = matched['layer'] + wanted_size = int(matched['size']) + cpp_tag = matched['tag'] + return matched['layer'], int(matched['size']), matched['tag'] + + +def collect_pytree_keys(param_lines): + """Collects all the pytree keys and transforms for model-serialization.""" + pytree_keys = [] + array_transforms = [] + unsatisfied = [] + for maybe_spec in map(process_param_line, param_lines): + if not maybe_spec: continue # Skip non-parameter lines. + layer, wanted_size, cpp_tag = maybe_spec + pytree_key_tails_and_transforms = _LAYER_NAME_MAPPING.get(cpp_tag, ()) + if not pytree_key_tails_and_transforms: + unsatisfied.append((layer, cpp_tag)) + else: + for key_tail, array_transform in pytree_key_tails_and_transforms: + pytree_keys.append( + key_tail if layer is None + else (f'blocks.{layer}',) + key_tail) + array_transforms.append(array_transform) + return pytree_keys, array_transforms, unsatisfied + + +class UnsatisfiedArrayLoadsError(ValueError): + """Some array-loads could not be satisfied.""" + + +def flatten_model_for_cpp_binary(tree, + cpp_expectations_logfile_path : str, + out_path : str, + unsatisfied_ok : bool = False + ): + """Produces a model-parameters file readable by the C++ binary. + + Args: + tree: The pytree with model-parameters. + cpp_expectations_logfile_path: + Path to a logfile produced by the C++ binary that shows + the expected array-order. + out_path: Path to the model-weights file to be written. + unsatisfied_ok: If true, we ignore the presence of unsatisfied + array-loads and write a model-parameters file that skips these pieces. + This will lead to an unusable model-parameters file which however + still might be useful for other analysis. + + Returns: + Tuple `(unknown_keys, missing_keys)`, where `unknown_keys` + is a sequence of `(layer_or_None, name)` descriptions of the keys + in the C++ log that could not be satisfied, and `missing_keys` + is a sequence of linearized pytree key-sequences for keys + not found in the checkpoint. + + Raises: + UnsatisfiedArrayLoadsError: If some of the expected arrays + could not be included in the output and `unsatisfied_ok` + is false. + """ + with open(cpp_expectations_logfile_path, 'rt') as h_log: + pytree_keys, array_transforms, unknown_keys = collect_pytree_keys( + list(h_log)) + rank_by_pytree_key = {k: n for n, k in enumerate(pytree_keys)} + array_transform_by_pytree_key = dict(zip(pytree_keys, array_transforms)) + # + model_contents = ml_model_transforms.model_contents(tree) + missing_keys = set(pytree_keys) - model_contents.keys() + if (unknown_keys or missing_keys) and not unsatisfied_ok: + raise ValueError( + f'Unsatisfied loads: unknown_keys: {unknown_keys!r}, ' + f'missing keys: {sorted(missing_keys)!r}') + ml_model_transforms.model_save( + tree, + filepath_stem=out_path, + data_suffix='', + manifest_suffix=None, + array_transform_by_pytree_key=array_transform_by_pytree_key, + key=rank_by_pytree_key.get, + report=lambda line: print(line, file=sys.stderr), + byte_align=1) + return tuple(unknown_keys), tuple(sorted(missing_keys)) + + +def main(args): + """Creates the model-file. + + Args: + sys.argv[] parameters from command line sans the leading one. + + Returns: + The pytree with all the de-serialized variables, such as for convenient + `python3 -i` inspection. + """ + try: + model_dir, cpp_load_log, out_path = args + except Exception: + sys.exit(f'Usage: {__file__} [model_dir] [cpp_load_log] [output_filename]') + pattern = ("recurrent", "recurrent", "attention") + orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() + variables = orbax_checkpointer.restore(model_dir) + if sorted(variables) == ['params']: + print('Warning: Using `variables["params"]` as tree-root.', file=sys.stderr) + variables_to_use = variables['params'] + else: + variables_to_use = variables + unknown, missing = flatten_model_for_cpp_binary(variables_to_use, + cpp_load_log, + out_path, + unsatisfied_ok=True) + print('Model file saved.\n' + f'# unknown:\n{pprint.pformat(unknown)}\n' + f'# missing:\n{pprint.pformat(missing)}') + return variables + + +if __name__ == '__main__': + # Return value assignment is for `python3 -i ...` inspection. + pytree = main(sys.argv[1:]) diff --git a/compression/python/pytree/cpp_load_log.txt b/compression/python/pytree/cpp_load_log.txt new file mode 100644 index 0000000..cc33394 --- /dev/null +++ b/compression/python/pytree/cpp_load_log.txt @@ -0,0 +1,380 @@ +Loading Parameters (size 2622750720): embedder_input_embedding +Loading Parameters (size 10240): final_norm_scale +Loading Parameters: (layer=0, size 26214400) griffin_linear_x_w +Loading Parameters: (layer=0, size 10240) griffin_linear_x_biases +Loading Parameters: (layer=0, size 26214400) griffin_linear_y_w +Loading Parameters: (layer=0, size 10240) griffin_linear_y_biases +Loading Parameters: (layer=0, size 26214400) griffin_linear_out_w +Loading Parameters: (layer=0, size 10240) griffin_linear_out_biases +Loading Parameters: (layer=0, size 40960) griffin_conv_w +Loading Parameters: (layer=0, size 10240) griffin_conv_biases +Loading Parameters: (layer=0, size 5242880) griffin_gate_w +Loading Parameters: (layer=0, size 20480) griffin_gate_biases +Loading Parameters: (layer=0, size 10240) griffin_a +Loading Parameters: (layer=0, size 157286400) gating_einsum_w +Loading Parameters: (layer=0, size 78643200) linear_w +Loading Parameters: (layer=0, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=0, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=0, size 61440) ffw_gating_biases +Loading Parameters: (layer=0, size 10240) ffw_output_biases +Loading Parameters: (layer=1, size 26214400) griffin_linear_x_w +Loading Parameters: (layer=1, size 10240) griffin_linear_x_biases +Loading Parameters: (layer=1, size 26214400) griffin_linear_y_w +Loading Parameters: (layer=1, size 10240) griffin_linear_y_biases +Loading Parameters: (layer=1, size 26214400) griffin_linear_out_w +Loading Parameters: (layer=1, size 10240) griffin_linear_out_biases +Loading Parameters: (layer=1, size 40960) griffin_conv_w +Loading Parameters: (layer=1, size 10240) griffin_conv_biases +Loading Parameters: (layer=1, size 5242880) griffin_gate_w +Loading Parameters: (layer=1, size 20480) griffin_gate_biases +Loading Parameters: (layer=1, size 10240) griffin_a +Loading Parameters: (layer=1, size 157286400) gating_einsum_w +Loading Parameters: (layer=1, size 78643200) linear_w +Loading Parameters: (layer=1, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=1, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=1, size 61440) ffw_gating_biases +Loading Parameters: (layer=1, size 10240) ffw_output_biases +Loading Parameters: (layer=2, size 26214400) attn_vec_einsum_w +Loading Parameters: (layer=2, size 78643200) qkv_einsum_w +Loading Parameters: (layer=2, size 157286400) gating_einsum_w +Loading Parameters: (layer=2, size 78643200) linear_w +Loading Parameters: (layer=2, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=2, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=2, size 61440) ffw_gating_biases +Loading Parameters: (layer=2, size 10240) ffw_output_biases +Loading Parameters: (layer=2, size 10240) attention_output_biases +Loading Parameters: (layer=3, size 26214400) griffin_linear_x_w +Loading Parameters: (layer=3, size 10240) griffin_linear_x_biases +Loading Parameters: (layer=3, size 26214400) griffin_linear_y_w +Loading Parameters: (layer=3, size 10240) griffin_linear_y_biases +Loading Parameters: (layer=3, size 26214400) griffin_linear_out_w +Loading Parameters: (layer=3, size 10240) griffin_linear_out_biases +Loading Parameters: (layer=3, size 40960) griffin_conv_w +Loading Parameters: (layer=3, size 10240) griffin_conv_biases +Loading Parameters: (layer=3, size 5242880) griffin_gate_w +Loading Parameters: (layer=3, size 20480) griffin_gate_biases +Loading Parameters: (layer=3, size 10240) griffin_a +Loading Parameters: (layer=3, size 157286400) gating_einsum_w +Loading Parameters: (layer=3, size 78643200) linear_w +Loading Parameters: (layer=3, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=3, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=3, size 61440) ffw_gating_biases +Loading Parameters: (layer=3, size 10240) ffw_output_biases +Loading Parameters: (layer=4, size 26214400) griffin_linear_x_w +Loading Parameters: (layer=4, size 10240) griffin_linear_x_biases +Loading Parameters: (layer=4, size 26214400) griffin_linear_y_w +Loading Parameters: (layer=4, size 10240) griffin_linear_y_biases +Loading Parameters: (layer=4, size 26214400) griffin_linear_out_w +Loading Parameters: (layer=4, size 10240) griffin_linear_out_biases +Loading Parameters: (layer=4, size 40960) griffin_conv_w +Loading Parameters: (layer=4, size 10240) griffin_conv_biases +Loading Parameters: (layer=4, size 5242880) griffin_gate_w +Loading Parameters: (layer=4, size 20480) griffin_gate_biases +Loading Parameters: (layer=4, size 10240) griffin_a +Loading Parameters: (layer=4, size 157286400) gating_einsum_w +Loading Parameters: (layer=4, size 78643200) linear_w +Loading Parameters: (layer=4, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=4, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=4, size 61440) ffw_gating_biases +Loading Parameters: (layer=4, size 10240) ffw_output_biases +Loading Parameters: (layer=5, size 26214400) attn_vec_einsum_w +Loading Parameters: (layer=5, size 78643200) qkv_einsum_w +Loading Parameters: (layer=5, size 157286400) gating_einsum_w +Loading Parameters: (layer=5, size 78643200) linear_w +Loading Parameters: (layer=5, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=5, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=5, size 61440) ffw_gating_biases +Loading Parameters: (layer=5, size 10240) ffw_output_biases +Loading Parameters: (layer=5, size 10240) attention_output_biases +Loading Parameters: (layer=6, size 26214400) griffin_linear_x_w +Loading Parameters: (layer=6, size 10240) griffin_linear_x_biases +Loading Parameters: (layer=6, size 26214400) griffin_linear_y_w +Loading Parameters: (layer=6, size 10240) griffin_linear_y_biases +Loading Parameters: (layer=6, size 26214400) griffin_linear_out_w +Loading Parameters: (layer=6, size 10240) griffin_linear_out_biases +Loading Parameters: (layer=6, size 40960) griffin_conv_w +Loading Parameters: (layer=6, size 10240) griffin_conv_biases +Loading Parameters: (layer=6, size 5242880) griffin_gate_w +Loading Parameters: (layer=6, size 20480) griffin_gate_biases +Loading Parameters: (layer=6, size 10240) griffin_a +Loading Parameters: (layer=6, size 157286400) gating_einsum_w +Loading Parameters: (layer=6, size 78643200) linear_w +Loading Parameters: (layer=6, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=6, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=6, size 61440) ffw_gating_biases +Loading Parameters: (layer=6, size 10240) ffw_output_biases +Loading Parameters: (layer=7, size 26214400) griffin_linear_x_w +Loading Parameters: (layer=7, size 10240) griffin_linear_x_biases +Loading Parameters: (layer=7, size 26214400) griffin_linear_y_w +Loading Parameters: (layer=7, size 10240) griffin_linear_y_biases +Loading Parameters: (layer=7, size 26214400) griffin_linear_out_w +Loading Parameters: (layer=7, size 10240) griffin_linear_out_biases +Loading Parameters: (layer=7, size 40960) griffin_conv_w +Loading Parameters: (layer=7, size 10240) griffin_conv_biases +Loading Parameters: (layer=7, size 5242880) griffin_gate_w +Loading Parameters: (layer=7, size 20480) griffin_gate_biases +Loading Parameters: (layer=7, size 10240) griffin_a +Loading Parameters: (layer=7, size 157286400) gating_einsum_w +Loading Parameters: (layer=7, size 78643200) linear_w +Loading Parameters: (layer=7, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=7, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=7, size 61440) ffw_gating_biases +Loading Parameters: (layer=7, size 10240) ffw_output_biases +Loading Parameters: (layer=8, size 26214400) attn_vec_einsum_w +Loading Parameters: (layer=8, size 78643200) qkv_einsum_w +Loading Parameters: (layer=8, size 157286400) gating_einsum_w +Loading Parameters: (layer=8, size 78643200) linear_w +Loading Parameters: (layer=8, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=8, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=8, size 61440) ffw_gating_biases +Loading Parameters: (layer=8, size 10240) ffw_output_biases +Loading Parameters: (layer=8, size 10240) attention_output_biases +Loading Parameters: (layer=9, size 26214400) griffin_linear_x_w +Loading Parameters: (layer=9, size 10240) griffin_linear_x_biases +Loading Parameters: (layer=9, size 26214400) griffin_linear_y_w +Loading Parameters: (layer=9, size 10240) griffin_linear_y_biases +Loading Parameters: (layer=9, size 26214400) griffin_linear_out_w +Loading Parameters: (layer=9, size 10240) griffin_linear_out_biases +Loading Parameters: (layer=9, size 40960) griffin_conv_w +Loading Parameters: (layer=9, size 10240) griffin_conv_biases +Loading Parameters: (layer=9, size 5242880) griffin_gate_w +Loading Parameters: (layer=9, size 20480) griffin_gate_biases +Loading Parameters: (layer=9, size 10240) griffin_a +Loading Parameters: (layer=9, size 157286400) gating_einsum_w +Loading Parameters: (layer=9, size 78643200) linear_w +Loading Parameters: (layer=9, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=9, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=9, size 61440) ffw_gating_biases +Loading Parameters: (layer=9, size 10240) ffw_output_biases +Loading Parameters: (layer=10, size 26214400) griffin_linear_x_w +Loading Parameters: (layer=10, size 10240) griffin_linear_x_biases +Loading Parameters: (layer=10, size 26214400) griffin_linear_y_w +Loading Parameters: (layer=10, size 10240) griffin_linear_y_biases +Loading Parameters: (layer=10, size 26214400) griffin_linear_out_w +Loading Parameters: (layer=10, size 10240) griffin_linear_out_biases +Loading Parameters: (layer=10, size 40960) griffin_conv_w +Loading Parameters: (layer=10, size 10240) griffin_conv_biases +Loading Parameters: (layer=10, size 5242880) griffin_gate_w +Loading Parameters: (layer=10, size 20480) griffin_gate_biases +Loading Parameters: (layer=10, size 10240) griffin_a +Loading Parameters: (layer=10, size 157286400) gating_einsum_w +Loading Parameters: (layer=10, size 78643200) linear_w +Loading Parameters: (layer=10, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=10, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=10, size 61440) ffw_gating_biases +Loading Parameters: (layer=10, size 10240) ffw_output_biases +Loading Parameters: (layer=11, size 26214400) attn_vec_einsum_w +Loading Parameters: (layer=11, size 78643200) qkv_einsum_w +Loading Parameters: (layer=11, size 157286400) gating_einsum_w +Loading Parameters: (layer=11, size 78643200) linear_w +Loading Parameters: (layer=11, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=11, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=11, size 61440) ffw_gating_biases +Loading Parameters: (layer=11, size 10240) ffw_output_biases +Loading Parameters: (layer=11, size 10240) attention_output_biases +Loading Parameters: (layer=12, size 26214400) griffin_linear_x_w +Loading Parameters: (layer=12, size 10240) griffin_linear_x_biases +Loading Parameters: (layer=12, size 26214400) griffin_linear_y_w +Loading Parameters: (layer=12, size 10240) griffin_linear_y_biases +Loading Parameters: (layer=12, size 26214400) griffin_linear_out_w +Loading Parameters: (layer=12, size 10240) griffin_linear_out_biases +Loading Parameters: (layer=12, size 40960) griffin_conv_w +Loading Parameters: (layer=12, size 10240) griffin_conv_biases +Loading Parameters: (layer=12, size 5242880) griffin_gate_w +Loading Parameters: (layer=12, size 20480) griffin_gate_biases +Loading Parameters: (layer=12, size 10240) griffin_a +Loading Parameters: (layer=12, size 157286400) gating_einsum_w +Loading Parameters: (layer=12, size 78643200) linear_w +Loading Parameters: (layer=12, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=12, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=12, size 61440) ffw_gating_biases +Loading Parameters: (layer=12, size 10240) ffw_output_biases +Loading Parameters: (layer=13, size 26214400) griffin_linear_x_w +Loading Parameters: (layer=13, size 10240) griffin_linear_x_biases +Loading Parameters: (layer=13, size 26214400) griffin_linear_y_w +Loading Parameters: (layer=13, size 10240) griffin_linear_y_biases +Loading Parameters: (layer=13, size 26214400) griffin_linear_out_w +Loading Parameters: (layer=13, size 10240) griffin_linear_out_biases +Loading Parameters: (layer=13, size 40960) griffin_conv_w +Loading Parameters: (layer=13, size 10240) griffin_conv_biases +Loading Parameters: (layer=13, size 5242880) griffin_gate_w +Loading Parameters: (layer=13, size 20480) griffin_gate_biases +Loading Parameters: (layer=13, size 10240) griffin_a +Loading Parameters: (layer=13, size 157286400) gating_einsum_w +Loading Parameters: (layer=13, size 78643200) linear_w +Loading Parameters: (layer=13, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=13, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=13, size 61440) ffw_gating_biases +Loading Parameters: (layer=13, size 10240) ffw_output_biases +Loading Parameters: (layer=14, size 26214400) attn_vec_einsum_w +Loading Parameters: (layer=14, size 78643200) qkv_einsum_w +Loading Parameters: (layer=14, size 157286400) gating_einsum_w +Loading Parameters: (layer=14, size 78643200) linear_w +Loading Parameters: (layer=14, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=14, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=14, size 61440) ffw_gating_biases +Loading Parameters: (layer=14, size 10240) ffw_output_biases +Loading Parameters: (layer=14, size 10240) attention_output_biases +Loading Parameters: (layer=15, size 26214400) griffin_linear_x_w +Loading Parameters: (layer=15, size 10240) griffin_linear_x_biases +Loading Parameters: (layer=15, size 26214400) griffin_linear_y_w +Loading Parameters: (layer=15, size 10240) griffin_linear_y_biases +Loading Parameters: (layer=15, size 26214400) griffin_linear_out_w +Loading Parameters: (layer=15, size 10240) griffin_linear_out_biases +Loading Parameters: (layer=15, size 40960) griffin_conv_w +Loading Parameters: (layer=15, size 10240) griffin_conv_biases +Loading Parameters: (layer=15, size 5242880) griffin_gate_w +Loading Parameters: (layer=15, size 20480) griffin_gate_biases +Loading Parameters: (layer=15, size 10240) griffin_a +Loading Parameters: (layer=15, size 157286400) gating_einsum_w +Loading Parameters: (layer=15, size 78643200) linear_w +Loading Parameters: (layer=15, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=15, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=15, size 61440) ffw_gating_biases +Loading Parameters: (layer=15, size 10240) ffw_output_biases +Loading Parameters: (layer=16, size 26214400) griffin_linear_x_w +Loading Parameters: (layer=16, size 10240) griffin_linear_x_biases +Loading Parameters: (layer=16, size 26214400) griffin_linear_y_w +Loading Parameters: (layer=16, size 10240) griffin_linear_y_biases +Loading Parameters: (layer=16, size 26214400) griffin_linear_out_w +Loading Parameters: (layer=16, size 10240) griffin_linear_out_biases +Loading Parameters: (layer=16, size 40960) griffin_conv_w +Loading Parameters: (layer=16, size 10240) griffin_conv_biases +Loading Parameters: (layer=16, size 5242880) griffin_gate_w +Loading Parameters: (layer=16, size 20480) griffin_gate_biases +Loading Parameters: (layer=16, size 10240) griffin_a +Loading Parameters: (layer=16, size 157286400) gating_einsum_w +Loading Parameters: (layer=16, size 78643200) linear_w +Loading Parameters: (layer=16, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=16, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=16, size 61440) ffw_gating_biases +Loading Parameters: (layer=16, size 10240) ffw_output_biases +Loading Parameters: (layer=17, size 26214400) attn_vec_einsum_w +Loading Parameters: (layer=17, size 78643200) qkv_einsum_w +Loading Parameters: (layer=17, size 157286400) gating_einsum_w +Loading Parameters: (layer=17, size 78643200) linear_w +Loading Parameters: (layer=17, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=17, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=17, size 61440) ffw_gating_biases +Loading Parameters: (layer=17, size 10240) ffw_output_biases +Loading Parameters: (layer=17, size 10240) attention_output_biases +Loading Parameters: (layer=18, size 26214400) griffin_linear_x_w +Loading Parameters: (layer=18, size 10240) griffin_linear_x_biases +Loading Parameters: (layer=18, size 26214400) griffin_linear_y_w +Loading Parameters: (layer=18, size 10240) griffin_linear_y_biases +Loading Parameters: (layer=18, size 26214400) griffin_linear_out_w +Loading Parameters: (layer=18, size 10240) griffin_linear_out_biases +Loading Parameters: (layer=18, size 40960) griffin_conv_w +Loading Parameters: (layer=18, size 10240) griffin_conv_biases +Loading Parameters: (layer=18, size 5242880) griffin_gate_w +Loading Parameters: (layer=18, size 20480) griffin_gate_biases +Loading Parameters: (layer=18, size 10240) griffin_a +Loading Parameters: (layer=18, size 157286400) gating_einsum_w +Loading Parameters: (layer=18, size 78643200) linear_w +Loading Parameters: (layer=18, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=18, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=18, size 61440) ffw_gating_biases +Loading Parameters: (layer=18, size 10240) ffw_output_biases +Loading Parameters: (layer=19, size 26214400) griffin_linear_x_w +Loading Parameters: (layer=19, size 10240) griffin_linear_x_biases +Loading Parameters: (layer=19, size 26214400) griffin_linear_y_w +Loading Parameters: (layer=19, size 10240) griffin_linear_y_biases +Loading Parameters: (layer=19, size 26214400) griffin_linear_out_w +Loading Parameters: (layer=19, size 10240) griffin_linear_out_biases +Loading Parameters: (layer=19, size 40960) griffin_conv_w +Loading Parameters: (layer=19, size 10240) griffin_conv_biases +Loading Parameters: (layer=19, size 5242880) griffin_gate_w +Loading Parameters: (layer=19, size 20480) griffin_gate_biases +Loading Parameters: (layer=19, size 10240) griffin_a +Loading Parameters: (layer=19, size 157286400) gating_einsum_w +Loading Parameters: (layer=19, size 78643200) linear_w +Loading Parameters: (layer=19, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=19, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=19, size 61440) ffw_gating_biases +Loading Parameters: (layer=19, size 10240) ffw_output_biases +Loading Parameters: (layer=20, size 26214400) attn_vec_einsum_w +Loading Parameters: (layer=20, size 78643200) qkv_einsum_w +Loading Parameters: (layer=20, size 157286400) gating_einsum_w +Loading Parameters: (layer=20, size 78643200) linear_w +Loading Parameters: (layer=20, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=20, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=20, size 61440) ffw_gating_biases +Loading Parameters: (layer=20, size 10240) ffw_output_biases +Loading Parameters: (layer=20, size 10240) attention_output_biases +Loading Parameters: (layer=21, size 26214400) griffin_linear_x_w +Loading Parameters: (layer=21, size 10240) griffin_linear_x_biases +Loading Parameters: (layer=21, size 26214400) griffin_linear_y_w +Loading Parameters: (layer=21, size 10240) griffin_linear_y_biases +Loading Parameters: (layer=21, size 26214400) griffin_linear_out_w +Loading Parameters: (layer=21, size 10240) griffin_linear_out_biases +Loading Parameters: (layer=21, size 40960) griffin_conv_w +Loading Parameters: (layer=21, size 10240) griffin_conv_biases +Loading Parameters: (layer=21, size 5242880) griffin_gate_w +Loading Parameters: (layer=21, size 20480) griffin_gate_biases +Loading Parameters: (layer=21, size 10240) griffin_a +Loading Parameters: (layer=21, size 157286400) gating_einsum_w +Loading Parameters: (layer=21, size 78643200) linear_w +Loading Parameters: (layer=21, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=21, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=21, size 61440) ffw_gating_biases +Loading Parameters: (layer=21, size 10240) ffw_output_biases +Loading Parameters: (layer=22, size 26214400) griffin_linear_x_w +Loading Parameters: (layer=22, size 10240) griffin_linear_x_biases +Loading Parameters: (layer=22, size 26214400) griffin_linear_y_w +Loading Parameters: (layer=22, size 10240) griffin_linear_y_biases +Loading Parameters: (layer=22, size 26214400) griffin_linear_out_w +Loading Parameters: (layer=22, size 10240) griffin_linear_out_biases +Loading Parameters: (layer=22, size 40960) griffin_conv_w +Loading Parameters: (layer=22, size 10240) griffin_conv_biases +Loading Parameters: (layer=22, size 5242880) griffin_gate_w +Loading Parameters: (layer=22, size 20480) griffin_gate_biases +Loading Parameters: (layer=22, size 10240) griffin_a +Loading Parameters: (layer=22, size 157286400) gating_einsum_w +Loading Parameters: (layer=22, size 78643200) linear_w +Loading Parameters: (layer=22, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=22, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=22, size 61440) ffw_gating_biases +Loading Parameters: (layer=22, size 10240) ffw_output_biases +Loading Parameters: (layer=23, size 26214400) attn_vec_einsum_w +Loading Parameters: (layer=23, size 78643200) qkv_einsum_w +Loading Parameters: (layer=23, size 157286400) gating_einsum_w +Loading Parameters: (layer=23, size 78643200) linear_w +Loading Parameters: (layer=23, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=23, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=23, size 61440) ffw_gating_biases +Loading Parameters: (layer=23, size 10240) ffw_output_biases +Loading Parameters: (layer=23, size 10240) attention_output_biases +Loading Parameters: (layer=24, size 26214400) griffin_linear_x_w +Loading Parameters: (layer=24, size 10240) griffin_linear_x_biases +Loading Parameters: (layer=24, size 26214400) griffin_linear_y_w +Loading Parameters: (layer=24, size 10240) griffin_linear_y_biases +Loading Parameters: (layer=24, size 26214400) griffin_linear_out_w +Loading Parameters: (layer=24, size 10240) griffin_linear_out_biases +Loading Parameters: (layer=24, size 40960) griffin_conv_w +Loading Parameters: (layer=24, size 10240) griffin_conv_biases +Loading Parameters: (layer=24, size 5242880) griffin_gate_w +Loading Parameters: (layer=24, size 20480) griffin_gate_biases +Loading Parameters: (layer=24, size 10240) griffin_a +Loading Parameters: (layer=24, size 157286400) gating_einsum_w +Loading Parameters: (layer=24, size 78643200) linear_w +Loading Parameters: (layer=24, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=24, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=24, size 61440) ffw_gating_biases +Loading Parameters: (layer=24, size 10240) ffw_output_biases +Loading Parameters: (layer=25, size 26214400) griffin_linear_x_w +Loading Parameters: (layer=25, size 10240) griffin_linear_x_biases +Loading Parameters: (layer=25, size 26214400) griffin_linear_y_w +Loading Parameters: (layer=25, size 10240) griffin_linear_y_biases +Loading Parameters: (layer=25, size 26214400) griffin_linear_out_w +Loading Parameters: (layer=25, size 10240) griffin_linear_out_biases +Loading Parameters: (layer=25, size 40960) griffin_conv_w +Loading Parameters: (layer=25, size 10240) griffin_conv_biases +Loading Parameters: (layer=25, size 5242880) griffin_gate_w +Loading Parameters: (layer=25, size 20480) griffin_gate_biases +Loading Parameters: (layer=25, size 10240) griffin_a +Loading Parameters: (layer=25, size 157286400) gating_einsum_w +Loading Parameters: (layer=25, size 78643200) linear_w +Loading Parameters: (layer=25, size 10240) pre_attention_norm_scale +Loading Parameters: (layer=25, size 10240) pre_ffw_norm_scale +Loading Parameters: (layer=25, size 61440) ffw_gating_biases +Loading Parameters: (layer=25, size 10240) ffw_output_biases diff --git a/compression/python/pytree/ml_model_transforms.py b/compression/python/pytree/ml_model_transforms.py new file mode 100644 index 0000000..3605c07 --- /dev/null +++ b/compression/python/pytree/ml_model_transforms.py @@ -0,0 +1,371 @@ +"""Transformations for python-trees representing the parameters of a ML model. + +Important: This module assumes that byte-order is the same on the +machine that serializes data and the machine that deserializes +data. If, for example, numpy-data gets dumped, respectively loaded, +with a dtype-specification of numpy.float32, on-file byte-order +will be host byte order. + +""" + +import ast +import hashlib +import itertools +import pprint +import sys +import time +from typing import Any, Callable, Iterable, Iterator, Mapping, TypeVar + +import numpy +import pytree_transforms + + +NT = TypeVar('NT') + + +def ml_model_leaf_summary(path, x, sep=', '): + """Produces a textual summary of a leaf-node and its path. + + Args: + path: The path-to-root, as a reverse-order recursive + pair of path-components, with `()` as root. + x: The leaf-object. + sep: the separator between description-elements. + Default ', ' allows for convenient line-by-line processing + (such as via grep, perl -ne, etc.), but using e.g. sep=',\n ' + might be more useful for human consumption. + + Returns: + A human-readable string providing information about the node. + """ + # Using `repr` for path-components to get a faithful presentation. + # (...which still however would be somewat painful to correctly + # split into components.) + path_str = ','.join(map(repr, + pytree_transforms.linearize_revtuple_path(path))) + tx = type(x) + mod = tx.__module__ # Either a module or a string like 'builtins'. + modname = mod if isinstance(mod, str) else mod.__name__ + type_str = f'{modname}.{tx.__qualname__}' + try: + # `numpy.ndarray` instances have a `.data` property that gives access + # to a buffer via which we can hashlib-fingerprint the numerical + # contents. We here simply try to produce a fingerprint and also look + # up the .dtype of the object. Technically, there is a somewhat-unsound + # assumption here that if these operations succeed, we are indeed looking + # at a ndarray or sufficiently similar object for these operations to + # make sense. As the output is declared "for human consumption", this + # fishiness is not a problem. + fp = hashlib.sha256(x.data).hexdigest() + start = list(itertools.islice(x.flat, 5)) + stats_str = ( + f'min={numpy.min(x):.6g}, max={numpy.max(x):.6g}, ' + f'mean={numpy.mean(x):.6g}, std={numpy.std(x):.6g}') + return (f'{path_str:60s}: <{type_str}{sep}' + f'fp=0x{fp}{sep}{stats_str}{sep}shape={x.shape}, ' + f'dtype={x.dtype}{sep}start={start}>') + except (AttributeError, ValueError, TypeError): + # Fallback - trying to include information about the data-content + # of a likely-numerical-array failed. + return f'{path_str:60s}: {type_str}({repr(x)})' + + +# A specialized node-handler. +# Interface follows node-handler expectations defined in pytree_transforms. +def _ml_model_tree_node_handler(path: tuple, node : NT) -> ( + None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT], + Iterator[tuple[Any, NT]]]): + """Processes a tree-node as required by pytree-iteration and -mapping. + + Args: + path: revtuple path to the current node. + node: a tree-node in a ML-model tree that is recursively + built out of `numpy.ndarray` leaf-values and dicts mapping + node-name string-keys to other such nodes representing subtrees - + and nothing else. + + Returns: + `None` if the tree-node is to be regarded as a leaf, otherwise + a pair `(rebuilder, iterator)`, where `iterator` iterates + over the data-content of the node, each item represented as a pair + of `(lookup_path_item, value_item)`, and `rebuilder` is a function + which, when applied to `iterator` or any iterable with the same + elements, returns a node that is equivalent to the original. + + Raises: + NotAMLModelTreeNodeError: If the tree contains a node that is neither + a `dict` nor a `numpy.ndarray` instance. + """ + # The astute reader will notice that we are doing something fishy + # here - this code could not be translated to Haskell as-is, since + # `NT` cannot actually be a proper type-variable in the sense + # of parametric polymorphism. + del path # Unused. + if isinstance(node, dict): + return dict, iter(node.items()) + if isinstance(node, numpy.ndarray): + return None + raise pytree_transforms.NotAMLModelTreeNodeError( + f'Type of bad node: {type(node)}') + + +def _ml_model_extract_leaf_transform( + path: pytree_transforms.RevTuplePath, + leaf: Any): + """Maps an array-leaf to a pair `(full_path, lambda: array)`. + + The computation that produces the leaf-value is lazified underneath + a `lambda`, since if we e.g. performed a memory-expensive + transformation (such as some dtype-changes) directly at this point, + then going from an iterator over tree-items for one-by-one + consumption to a list of these items would have all the + dtype-transformed values around simultaneously. We want to avoid + situations where we can do nothing about having multiple variants + of the data simultaneously in memory. + """ + # Hack: If we are encountering a `bfloat16` numpy-array, + # we pretend to have the data as a numpy.float32 array, + # since that's about all that contemporary CPUs can process + # efficiently here. + linearized_path = pytree_transforms.linearize_revtuple_path(path) + try: + # We have to use some trickery to detect `bfloat16`. + if leaf.dtype.descr[-1] == ('', ' Any: + """Performs perl-style autovivification on a nested-dict tree. + + Args: + keys_and_vals: An iterable of pairs `(key_path, value)`, where + `key_path` is a sequence of keys to be used to navigate to + the result via iterative dict-lookup, left-to-right. + Must not have duplicate keys, and must not more than one key if + an empty-sequence key is present. If this iterable is an + iterator, it will be fully exhausted on successful execution. + + Returns: + An object representing a nested-dict structure such that + for every `key_path` from `keys_and_vals`, recursive-dict-lookup + on the elements of that path starting from this object will + produce the corresponding value. An empty `keys_and_vals` + set will return `{}`. Every dict in the nested return-value + that has been populated by autovivification is newly allocated. + """ + # Code structure is a bit gnarly here due to f(keys_and_vals=[((), x)]) + # having to evaluate to x and not a dict. + # There may be ways to prettify/simplify this. + result = None + empty = {} + for linear_path, val in keys_and_vals: + if linear_path == (): + if result is not None: + raise ValueError('Root-value seen alongside other values.') + result = val + else: + if result is None: + result = {} + elif type(result) is not dict: + # We already did encounter a root-value. + raise ValueError('Root-value seen alongside other values.') + cursor = result + for n in range(len(linear_path) - 1): + cursor = cursor.setdefault(linear_path[n], empty) + if cursor is empty: + # Regenerate `empty` if we just used it up. + empty = {} + cursor[linear_path[-1]] = val + return {} if result is None else result + + +def model_overview(tree, out=None) -> None: + """Prints a human-readable overview to `(out or sys.stdout)`.""" + actual_out = out or sys.stdout + for line in pytree_transforms.pytree_leaf_iter( + tree, ml_model_leaf_summary, + _ml_model_tree_node_handler): + print(line, file=actual_out) + + +def model_contents(tree) -> Mapping[tuple[str, ...], Any]: + """Maps a model to a {pytree_keys: data_array} mapping. + + Args: + tree: The ML-model parameter-tree, built recursively out of + dict-instances with numpy.ndarray instances as leaves. + + Returns: + A mapping from linearized pytree-key-sequence tuple to the corresponding + leaf-value. + """ + def leaf_transform(revtuple_path, leaf): + return pytree_transforms.linearize_revtuple_path(revtuple_path), leaf + return dict( + pytree_transforms.pytree_leaf_iter( + tree, leaf_transform, _ml_model_tree_node_handler)) + + +def _fn_identity(x): return x + + +def model_save(tree, + filepath_stem: str, + data_suffix: str = '.data', + manifest_suffix: str | None = '.manifest', + key: Callable[[tuple[str, ...]], Any] | None = None, + array_transform_by_pytree_key: ( + Mapping[tuple[str, ...], + Callable[[numpy.ndarray], numpy.ndarray]] | + None) = None, + report: Callable[[str], None] | None = None, + byte_align: int = 8) -> tuple[int, float]: + """Saves the content of a ML-model parameter-tree to filesystem. + + After successful execution, the file f"{filepath_stem}.data" + will hold the combined numerical model-parameters, and + f"{filepath_stem}.manifest" will contain the key for interpreting + (and rebuilding) the data. + + Args: + tree: The ML-model parameter-tree, built recursively out of + dict-instances with numpy.ndarray instances as leaves. + filepath_stem: Filesystem location for data. + data_suffix: Suffix to use for the data file. + manifest_suffix: Either `None`, in which case no manifest-file + will get written, or the suffix for the manifest-file. + key: `None` or a key-function that will be applied to the linear model-path + and used for sorting the data arrays by increasing value of the + key-function. If the key-function returns `None` on an item, + then this item is not included. + array_transform_by_pytree_key: Optional mapping from pytree-key + to an array-to-array transformation function to apply to the array + prior to serialization. + report: Optional callable for logging progress-reports. + byte_align: byte-alignment to use for numerical array data. + Numerical arrays whose size in bytes is not a multiple of this + will get padded to the next full multiple. + + Returns: + A pair of `(size, time_sec)`, where `size` is the total byte-size + of the `.data` file and `time_sec` is the elapsed time + for saving the model, in seconds. + """ + time0 = time.monotonic() + if array_transform_by_pytree_key is None: + array_transform_by_pytree_key = {} + model_lazy_items = ( + pytree_transforms.pytree_leaf_iter( + tree, _ml_model_extract_leaf_transform, + _ml_model_tree_node_handler)) + if key is not None: + to_write = [ + nkv[1:] for nkv in sorted( + (nkv for nkv in ((key(path), path, v) + for path, v in model_lazy_items) + if nkv[0] is not None), key=lambda nkv: nkv[0])] + else: + to_write = list(model_lazy_items) + # + def lazy_arr_path_shape_dtype_size(path_and_lazy_arr): + path, lazy_arr = path_and_lazy_arr + arr = array_transform_by_pytree_key.get(path, _fn_identity)(lazy_arr()) + return path, arr.shape, arr.dtype, arr.data.nbytes + arrs_path_shape_dtype_nbytes = list( + map(lazy_arr_path_shape_dtype_size, to_write)) + # We need to know the total size of all the data. + bytesizes = [nbytes for *_, nbytes in arrs_path_shape_dtype_nbytes] + padded_bytesizes = [-(-bytesize // byte_align * byte_align) + for bytesize in bytesizes] + offsets = numpy.cumsum([0] + padded_bytesizes) + membuf = numpy.memmap(filepath_stem + data_suffix, + mode='w+', shape=offsets[-1]) + try: + for (path, shape, dtype, nbytes), offset, (_, lazy_arr) in zip( + arrs_path_shape_dtype_nbytes, offsets, to_write): + # Note that if getting the array from the lazy lambda involved some + # computation, such as a copying dtype-change, that computation would + # end up being done multiple times here - including once above, to compute + # byte-sizes, and once more here. + transformed_arr = array_transform_by_pytree_key.get( + path, + _fn_identity)(lazy_arr()) + membuf[offset : offset + nbytes] = numpy.frombuffer( + transformed_arr.ravel().data, 'u1') + if report is not None: + samples = ', '.join(map(str, transformed_arr.ravel()[:5])) + report(f'# Adding: {path!r}\n bytes: {nbytes:10d}, ' + f'shape: {shape!r:30},\n start: [{samples}, ...]') + transformed_arr = None # Drop memory references to numerical arrays ASAP. + finally: + if membuf is not None: + membuf.flush() + # NumPy wart: the memory-buffer is a resource that conceptually + # should be .close()able - since mmap()ing holds on to a + # file descriptor. However, it looks as if that clean-up were done + # in the "finalizer", despite that having meanwhile been widely + # understood as dubious practice. So, the best we can do here is + # to explicitly and clearly remove our reference to the instance. + del membuf + if manifest_suffix is not None: + # We still have to serialize the data that allows us to reconstruct + # a tree that is equivalent to the original. + manifest_data = [ + dict(path=path, + dtype=dtype.descr[-1][-1], + shape=shape, + nbytes=nbytes, + offset=offset) + for (path, shape, dtype, nbytes), offset in zip( + arrs_path_shape_dtype_nbytes, offsets)] + with open(filepath_stem + '.manifest', 'wt') as h_manifest: + pprint.pprint(manifest_data, stream=h_manifest) + time_taken = time.monotonic() - time0 + return offsets[-1], time_taken + + +def model_load(filepath_stem, mmapped=True): + """Loads a model saved by `model_save`. + + Tries to load the model from f"{filepath_stem}.data" + and f"{filepath_stem}.manifest". + + Args: + filepath_stem: The model location on the filesystem. + mmapped: Whether data-arrays will be slices of a + `numpy.memmap` mapped buffer, to be paged in + on demand only, or in-memory copies of the data. + Returns: + A dict/numpy.ndarray tree representation of the model, + equivalent to the original model. + """ + with open(filepath_stem + '.manifest', 'rt') as h_manifest: + manifest = ast.literal_eval(h_manifest.read()) + membuf = numpy.memmap(filepath_stem + '.data', mode='r+') + paths_and_arrays = [] + for item in manifest: + path = item['path'] + dtype = numpy.dtype(item['dtype']) + shape = item['shape'] + nbytes = item['nbytes'] + offset = item['offset'] + data_array = numpy.frombuffer(membuf[offset : offset + nbytes].data, + dtype=dtype).reshape(shape) + paths_and_arrays.append( + (path, + data_array if mmapped else data_array.copy())) + # At this point, the memory-buffer is no longer needed. Still, if + # data-arrays retain references to the underlying data + # (i.e. when mmapped=False), this should keep the mapping + # - and hence file descriptor - open. We then are in a somewhat + # undesirable situation of clean-up of a resource that happens in a + # hard-to-predict way releasing a file descriptor. + del membuf + return revtuple_autovifify_from_linear(paths_and_arrays) diff --git a/compression/python/pytree/ml_model_transforms_test.py b/compression/python/pytree/ml_model_transforms_test.py new file mode 100644 index 0000000..9495c87 --- /dev/null +++ b/compression/python/pytree/ml_model_transforms_test.py @@ -0,0 +1,92 @@ +"""Basic tests for 'algebraic data type based pytree' transformations.""" + + +import io +import os +import tempfile +import unittest + +import numpy + +import ml_model_transforms + + +def _get_model(prefix): + return { + prefix + 'a1': numpy.arange(1000, 1024).reshape(6, 4).astype(numpy.float32), + prefix + 'a2': numpy.arange(2000, 2048).reshape(6, 8).astype(numpy.float32), + prefix + 'b1': { + prefix + 'c1': numpy.arange(100, 127).reshape(3, 3, 3).astype(numpy.int8), + prefix + 'c2': numpy.arange(100, 128).reshape(7, 4).astype(numpy.float64) + }} + + +class MLModeltransformsTest(unittest.TestCase): + """Basic correctness validation tests for ML-model transformations.""" + + def test_ml_model_leaf_summary(self): + """Tests guarantees given by `ml_model_leaf_summary`.""" + summary = ml_model_transforms.ml_model_leaf_summary( + ('a', ()), + numpy.arange(1000, 1024).reshape(6, 4).astype(numpy.int16), + sep='##') + self.assertIn('##', summary) # Separator is respected. + self.assertIn('(6, 4)', summary) # Shape is mentioned somewhere. + self.assertIn('int16', summary) # dtype is mentioned somewhere. + + def test_revtuple_autovivify_from_linear(self): + """Tests guarantees given by `revtuple_autovifify_from_linear`.""" + with self.subTest(guarantee='empty'): + self.assertEqual( + ml_model_transforms.revtuple_autovifify_from_linear([]), + {}) + with self.subTest(guarantee='generic'): + keys_vals = [(('a', 'b1', 'c1'), 1001), + (('a', 'b2'), 1002), + (('a2',), 1003), + ] + self.assertEqual( + ml_model_transforms.revtuple_autovifify_from_linear(keys_vals), + {'a': {'b1': {'c1': 1001}, 'b2': 1002}, 'a2': 1003}) + + def test_model_overview(self): + """Tests guarantees given by `model_overview`.""" + model = _get_model('xyz') + out_io = io.StringIO() + ml_model_transforms.model_overview(model, out=out_io) + overview = out_io.getvalue() + self.assertIn('xyz', overview) + + def test_model_contents(self): + """Tests guarantees given by `model_contents`.""" + model = _get_model('pq_') + contents = ml_model_transforms.model_contents(model) + fingerprints = {k: (a.shape, a.ravel()[:3].tolist()) + for k, a in contents.items()} + self.assertEqual(fingerprints, + {('pq_a1',): ((6, 4), [1000.0, 1001.0, 1002.0]), + ('pq_a2',): ((6, 8), [2000.0, 2001.0, 2002.0]), + ('pq_b1', 'pq_c1'): ((3, 3, 3), [100, 101, 102]), + ('pq_b1', 'pq_c2'): ((7, 4), [100.0, 101.0, 102.0])}) + + def test_model_save_load_basic(self): + """Tests basic guarantees given by `model_save` and `model_load`.""" + # What we care about here is that the round trip works - so + # it makes more sense to test saving and loading as one unit. + model_orig = _get_model('model_') + with tempfile.TemporaryDirectory() as tempdir: + filepath_stem = os.path.join(tempdir, 'the_model') + total_size, total_time = ml_model_transforms.model_save(model_orig, + filepath_stem) + self.assertGreater(total_size, 0) + self.assertGreater(total_time, 0) + model_reloaded = ml_model_transforms.model_load(filepath_stem) + contents_orig = ml_model_transforms.model_contents(model_orig) + contents_reloaded = ml_model_transforms.model_contents(model_reloaded) + self.assertEqual( + {k: v.tolist() for k, v in contents_orig.items()}, + {k: v.tolist() for k, v in contents_reloaded.items()}) + + +if __name__ == '__main__': + unittest.main() diff --git a/compression/python/pytree/pytree_transforms.py b/compression/python/pytree/pytree_transforms.py new file mode 100644 index 0000000..7e065af --- /dev/null +++ b/compression/python/pytree/pytree_transforms.py @@ -0,0 +1,508 @@ +"""Tools for transforming "nested python object" tree data structures. + +# Context + +The motivation for this module came from ML applications that ought to +be based on a principled handling of nested Python data structures. +Having such principled pytree-transforming code available solves +some other problems, such as doing away with a need to abuse +tree-mapping for-side-effect-only and having to use a hope-and-pray +approach to processing very deeply nested values which with a recursive +approach might trigger a RecursionError. + +We specifically want to cover the use case of having ML model +parameters that are available in a nested Python data structure for +which there "almost" is a unique-up-to-unique-isomorphism mapping from +and to this Algebraic Data Type: + +`data ModelParams a = Array a | Node [(String, ModelParams a)]` + +In this correspondence, `a` is some array-type (perhaps +`numpy.ndarray`, `jax.numpy.ndarray`, `tf.tensor`, etc.), but the +data-processing code is effectively entirely agnostic to this, and a +`Node` is "almost" an associative-list of (key, value) pairs +representing a Python dict. (Note: The "almost" here is mostly about +the conceptual wart that assoc-lists can in principle have key +duplicates, but Python dicts can not. This is however not a problem +since all we need is the transformation in one direction, +i.e. whatever data-processing `f` we want to express on the +model-parameters-pytree, we can express by specifying a "faithful" +mapping `m` into the above algebraic data type through which every +such pytree data transform factorizes, i.e. for every `f` we can find +a `g` such that `f(p) = g(m(p))`.) + +## Components + +The main workhorse in this module is the `pytree_iter` function that +maps a "PyTree (such as representing `ModelParams`)" to an iterator +over values obtained by applying a mapping-function to the "key-path" +and leaf-value for every leaf, where the "key-path" contains a +linked-list representation of the reversed sequence of keys from the +tree-root, with list-nodes being represented by pairs +`(latest_dict_key, rest_path)`, and the empty path being represented +by `()`. + +For the sake of genericity, `pytree_iter` is built in such a way that +it actually can handle any kind of traversal of PyTree-trees that do +represent algebraic data types (note however that some some do not) - +but for this to make sense, the user must have a way to define how to +interpret tree-nodes, in particular identify leaves. This requires +providing a function `node_handler` with the same signature and +behavior as described below for "node handlers". + +Additionally, this module provides mapping-over-pytrees via +`pytree_map`, which is also built in such a way that it makes the +correspondence between an algebraic data type and its Python +nested-tree representation explicit. Despite being powerful and +flexible, this, however, may in general require a bit more effort to +wire up, since node-rebuilding can be fairly nontrivial. + +Furthermore, as a prominent application, this module provides a simple +deep-freezing function that translates a nested Python data structure +to deeply-immutable form. + +## Concepts and Conventions + +"revtuple representation": + + As we iterate over a tree, we will have to keep track of the + path-to-tree-root. Naturally, two sibling nodes `n1` and `n2` + will share the same parent-path (being siblings), so it makes + sense to use a linked-list-with-shared-tail representation. + Python does not have a natural notion for that, so we use + recursively-constructed tuples `(node_tag, parent_path)` + that represent the path-from-root in-reverse-order, i.e. + for a non-empty path `p`, `p[0]` is the node-tag at the + deepest nesting level. We call this a "revtuple representation" + of the path. + +"node handler": + + A node-handler classifies a tree-node as "leaf or other node", and + for non-leaf nodes provides information about both its children and + how to rebuild it. The behavior of a node-handler function must be + in alignment with this docstring: + + '''Processes a tree-node as required by pytree-iteration and -mapping. + + Args: + revtuple_path: Revtuple-representation of the path-from-root + to the current node. + node: a tree-node in a ML-model tree that is recursively + built out of leaf-values and other nodes. + + Returns: + `None` if the tree-node is to be regarded as a leaf, otherwise + a pair `(rebuilder, iterator)`, where `iterator` iterates + over the data-content of the node, each item represented as a pair + of `(lookup_path_item, value_item)`, and `rebuilder` is a function + which, when applied to an iterable of the aforementioned value-items + (or some transformation thereof) returns a node that is equivalent + to the original (or up to a transformation of the contents). + + Raises: + InvalidTreeNodeError: If the tree contains a node of a kind + that is not expected to show up. + ''' + + Examples: + + (The behavior of a node-handler is somewhat nontrivial, so covering + two very common cases via examples is in order.) + + This node-handler would allow descending into (nested) + instances of `list` (but not subclass instances thereof): + + ```def list_node_handler(revtuple_path, obj): + ''' ... ''' + if type(obj) is list: + return list, enumerate(obj) + else: + return None + ``` + + This node-handler would allow descending into (nested) mappings, + which upon rebuilding would get turned into `dict` instances: + + ```def mapping_node_handler(revtuple_path, obj): + ''' ... ''' + if isinstance(obj, collections.abc.Mapping): + # For generic mappings, we cannot rely on key- and item-iteration + # being guaranteed to use identical iteration-order. + items = list(obj.items()) + keys = [kv[0] for kv in items] + return (lambda values: dict(zip(keys, values))), items + else: + return None + ``` + + A dict/mapping node-handler can of course rename keys, add or remove + entries, make decisions based on the item-path, or map a dict to + an associative list, etc. + +## Further Design Notes + +The `pytree_map` function requests the leaf-transform and node-handler +to be side-effect-free functions. This is both required to leave +implementation-side flexibility, and also follows the general LISP +recommendation to not abuse mapping (which should be a pure +data-transformation) for imperative data processing. Overall, if +a need for more general "nested datastructures" processing becomes +pressing, it is for the better if this leads to a proper articulation +of the specific needs, to be addressed with appropriate design, rather +than abuse of functional data-transforms becoming "a bad idiom +that turned into established practice". + +""" + +import collections.abc +import immutabledict + +import numpy + +from typing import Any, Callable, Iterable, Iterator, TypeVar + + +T = TypeVar('T') +U = TypeVar('U') + +KT = TypeVar('KT') +NT = TypeVar('NT') + + +## Type of the reverse-order-keys-to-root path. +# (This code actually illustrates why https://xkcd.com/2483/ is very misguided.) +RevTuplePath = tuple + +## Type of the `leaf_transform` function-argument used for tree-iteration. +# +# This would be the correct type we would have to specify here but cannot, +# since the design of Python's static typing at the time of this writing +# is too broken for that: +# +# type LeafTransformFunc[L, R] = Callable[[RevTuplePath, L], R] +# +# Instead, we have to settle for...: +LeafTransformFunc = Callable[[RevTuplePath, Any], Any] + + +## Type of the `tree_node_handler` function-argument used for +## tree-iteration and tree-mapping. +# +# Again, this is the correct type we would have to put here but cannot: +# +# type NodeHandlerFunc[KT] = ( +# Callable[[NT], +# None | tuple[Callable[[Iterable[tuple[KT, NT]]], NT], +# Iterator[tuple[KT, NT]]]]) +# +# ...so, we have to instead settle for: +NodeHandlerFunc = ( + Callable[[RevTuplePath, NT], + None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT], + Iterator[tuple[Any, NT]]]]) + + +Predicate = Callable[[object], bool] + + +class InvalidTreeNodeError(ValueError): + """Encountered a tree-node of invalid type.""" + + +def linearize_revtuple_path( + revtuple_path: RevTuplePath, + present_as: Callable[[Iterator[T]], U] = tuple) -> U: + """Translates a revtuple path to (typically) linear form. + + With default `present_as`, this will map a path of the form + `(key_{N}, (key_{N-1}, ..., (root, ())))` into a tuple + (root, ..., key_{N-1}, key_{N}). + + Args: + revtuple_path: A linked-list-as-recursive-pairs + reverse-order tuple-representation of the path. + Path-root is `()`, and node-key `x` relative to + earlier path `p` is represented as `(x, p)`. + present_as: Callable that consumes an iterator over + path-pieces - with the deepest-nesting level coming last - + turning it into a linearized path. Defaults to `tuple`. + + Returns: + Linearized presentation of all the node-keys in the + recursive-path in order, deepest-down path component coming last. + """ + pieces = [] + todo = revtuple_path + while todo: + node, todo = todo + pieces.append(node) + return present_as(reversed(pieces)) + + +# This function itself has type `NodeHandlerFunc`, but Python does not +# allow us to here simply type-annotate it like this. We cannot even +# introduce an abbreviation for the complicated output-type, +# since that would have to be parametric in node-type `NT` (and `KT`). +def everything_is_a_leaf_node_handler( + revtuple_path: tuple, + node : NT) -> ( + None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT], + Iterator[tuple[Any, NT]]]): + """Processes a tree-node as required by pytree-iteration and -mapping. + + Interface and signature are in alignment with the requirements for a + "node handler" function explained in the module-docstring. + + Args: + revtuple_path: the path-to-root for this node. + node: a tree-node. + + Returns: + `None`, i.e. classifying any kind of node as a leaf-node. + """ + del revtuple_path, node # Unused. + return None + + +def leaf_summary(path: RevTuplePath, x: object): + """Produces a human-readable summary-string for a leaf-node. + + Args: + path: revtuple representation of the path-to-root. + x: The leaf-value. + """ + del path # Ignored here. + tx = type(x) + mod = tx.__module__ + modname = mod if isinstance(mod, str) else mod.__name__ + type_str = f'{modname}.{tx.__qualname__}' + repr_str = repr(x) + repr_abbrev = repr_str if len(repr_str) < 40 else repr_str[:40] + ' ...' + # On str, int, float, etc. `{type_str}(repr(x))` would actually still be + # a (non-literal) Python-expression that would evaluate to the original value. + # However, we make no promises beyond "human-readable". + return f'{type_str}({repr_abbrev})' + + +# With respect to static type annotations, the limitations of Python's +# approach to static typing really become prominently visible here. +# +# Different arguments have type-parameters, but since there is no way +# to have parametric abbreviations such as `LeafTransformFunc[L, R]`, +# the only way we would have available to express relations between +# type-parameters would be to substitute in the not-abbreviated form of +# `NodeHandlerFunc` and `LeafTransformFunc`, giving us something monstrous. +# We instead here settle for "we cannot express that `tree` must +# have the same type as the input-type to `tree_node_handler` and use `Any`, +# and likewise for leaf_transform and the output. +def pytree_leaf_iter( + tree: Any, + leaf_transform: LeafTransformFunc, + node_handler: NodeHandlerFunc = everything_is_a_leaf_node_handler, + ) -> Iterator[Any]: + # ...actual return type would be `Iterator[{what leaf_transform returns}]`. + """Iterates over the leaves of a tree. + + Args: + tree: The tree to iterate over. + leaf_transform: A callable `f` that will get applied + as `f(revtuple_path, leaf)`, where `revtuple_path` + is the revtuple representation of the path to the + leaf from the root. + node_handler: A "node handler" (see module docstring) + that processes nodes encountered during iterative traversal. + + Yields: + Value of `leaf_transform(p, x)`, where `x` is the current leaf + and `p` is its revtuple-path to the root. + """ + # Note: Exit points for the code below are in non-obvious places + # and hence marked with " # ***EXIT***". + # + # Doing iteration properly is slightly nontrivial. + # One may be tempted to go for a very simple recursive implementation + # (with an extra pre-final `path` argument to `pytree_iter`): + # + # maybe_substructure = node_handler(path, tree) + # if maybe_substructure is None: + # # We are looking at a leaf-node. + # yield leaf_transform(path, tree) + # else: + # _, contents_iter = maybe_substructure + # for k, v in contents_iter: + # yield from pytree_iter(v, leaf_transform, (k, path), node_handler) + # + # That, however, would be flawed, since there is no a priori reason + # why a pytree may not be a very deeply nested structure - such as a + # long linked list. That would then risk raising `RecursionError`, + # and since Python by design(!) does not perform tail call elimination + # or any other kind of advanced CPS transforms, there is no recursive + # solution here. So, to do this properly, we have to do this iteratively. + # + # We are facing an annoying situation here: If `tree` itself is a leaf, + # we have two options: (a) wrapping it up in a one-node tree + # and processing that, or (b) special-casing "root is a leaf". + # Option (b) leads to some mild node-processing code-duplication + # for a single node (the root). + # Option (a) requires having special cases for node-processing that + # get looked at for every tree node. We go with option (b) here. + maybe_substructure = node_handler((), tree) + if maybe_substructure is None: + # The tree itself is a leaf. + yield leaf_transform((), tree) + return # ***EXIT*** + # Otherwise, we are looking at a tree. + _, contents_iter = maybe_substructure + current_revtuple_path = () + work_to_do = [contents_iter] + # Otherwise-unreachable sentinel for reliably identifying + # iterator-exhaustion without using exceptions: + sentinel = object() + while True: + current_iter = work_to_do[-1] + maybe_next_item = next(current_iter, sentinel) + if maybe_next_item is sentinel: + # We are done at this level. + work_to_do.pop() + if not work_to_do: return # ***EXIT*** + current_revtuple_path = current_revtuple_path[1] + else: + path_piece, subtree = maybe_next_item + extended_revtuple_path = (path_piece, current_revtuple_path) + maybe_subtree_substructure = node_handler(extended_revtuple_path, subtree) + if maybe_subtree_substructure is None: # Case: subtree is a leaf. + yield leaf_transform(extended_revtuple_path, subtree) + else: # Case: subtree is a tree. + current_revtuple_path = (path_piece, current_revtuple_path) + _, items_iter = maybe_subtree_substructure + work_to_do.append(items_iter) + + +# The current design approach here would be appropriate for +# applying leaf-transforms while retaining the structure of the tree - +# which closely corresponds to e.g. a (a -> b) -> (Tree a -> Tree b) functor. +# +# It is not entirely clear whether this is the abstraction that we should +# consider as being appropriately generic to flesh out explicitly - rather +# than starting from a more general approach of which this then is a special +# case. Some background: https://ncatlab.org/nlab/show/recursion+scheme +# +# On the other hand, there is a lot of flexibility via whatever +# node-rebuilder a node-handler produces - this can do quite some reshaping +# of a tree, including dropping or duplicating nodes. +def pytree_map( + tree: Any, + leaf_transform, + node_handler: NodeHandlerFunc = everything_is_a_leaf_node_handler, + ): + """Maps a (potentially nested) Python value to another such value. + + Args: + tree: The Python-object to be mapped. + leaf_transform: A callable `f` that will get applied + as `f(revtuple_path, leaf)`, where `revtuple_path` + is the revtuple representation of the path to the + leaf from the root. Must be side effect free. + node_handler: A "node handler" (see module docstring) + that processes nodes encountered during iterative traversal. + Must be side effect free. + + Returns: + The outcome of translating `tree`. + """ + # Note: Exit points for the code below are in non-obvious places + # and hence marked with " # ***EXIT***". + # + # Otherwise-inaccessible sentinel object, for reliably identifying + # missing-values via identity-check against sentinel lookup-default. + sentinel = object() + # Code structure mostly follows pytree_leaf_iter. + maybe_substructure = node_handler((), tree) + if maybe_substructure is None: + return leaf_transform((), tree) # ***EXIT*** + rebuilder, items_iter = maybe_substructure + current_revtuple_path = () + # Per-level, we have a triplet of: + # (rebuilder, remaining_items_to_iterate_over, processed). + parts_for_assembly = [(rebuilder, items_iter, [])] + while True: + this_rebuilder, this_items_iter, this_done_pieces = parts_for_assembly[-1] + maybe_next_item = next(this_items_iter, sentinel) + if maybe_next_item is sentinel: + # We are done with all the items for this level. + parts_for_assembly.pop() + built_iter = this_rebuilder(this_done_pieces) + if not parts_for_assembly: # No outer structure, so at-top-level. + return built_iter # ***EXIT*** + else: # We have outer structure. + parts_for_assembly[-1][-1].append(built_iter) + current_revtuple_path = current_revtuple_path[1] + continue # ...with next is-the-final-item-complete-check. + else: + # More constituents of the current item. + path_piece, subtree_item = maybe_next_item + extended_revtuple_path = (path_piece, current_revtuple_path) + maybe_subtree_substructure = node_handler( + extended_revtuple_path, + subtree_item) + if maybe_subtree_substructure is None: + this_done_pieces.append( + leaf_transform(extended_revtuple_path, subtree_item)) + else: + # We have a subtree. + subtree_rebuilder, subtree_items_iter = maybe_subtree_substructure + current_revtuple_path = (path_piece, + current_revtuple_path) + parts_for_assembly.append( + (subtree_rebuilder, subtree_items_iter, [])) + + +def deep_freeze( + tree, + *, + is_mapping : Predicate = lambda x: isinstance(x, collections.abc.Mapping), + is_set : Predicate = lambda x: isinstance(x, collections.abc.Set), + is_sequence : Predicate = lambda x: isinstance(x, (list, tuple)), + leaf_fn: Callable[[Any], Any] = lambda x: x, + ): + """Recursively freezes Set/Mapping/List/Tuple structures. + + Args: + tree: The potentially deeply-nested object to deep-freeze. + is_mapping: Callable that decides whether a sub-object is a mapping. + Defaults to an `isinstance()` check for `collections.abc.Mapping`. + is_set: Callable that decides whether a sub-object is a set. + Defaults to an `isinstance()` check for `collections.abc.Set`. + is_sequence: Callable that decides whether a sub-object is a sequence. + Defaults to a check for being a `tuple` or `list` instance. + leaf_fn: Function to use for translating non-mapping/set/sequence + instances. + + Returns: + Translated-to-deeply-immutable form of `tree`. + """ + idict = immutabledict.immutabledict + def freeze_node_handler(path, x): + if is_set(x): + return frozenset, ((None, y) for y in x) + if is_mapping(x): + # Mappings already have hashable, so + # (should-be-)deeply-immutable keys. + # Hence, we only need to deep-freeze the values. + # + # Note that non-`dict` mappings might not guarantee + # to respect iteration-order, so we have to be careful here: + items = list(x.items()) + keys = [kv[0] for kv in items] + values = [kv[1] for kv in items] + return ((lambda ys: idict(zip(keys, ys))), + iter(items)) + if is_sequence(x): + return tuple, enumerate(iter(x)) + # Otherwise, this should not be traversed. + return None + def leaf_transform(revtuple_path, value): + del revtuple_path # Unused. + return leaf_fn(value) + return pytree_map(tree, leaf_transform, freeze_node_handler) diff --git a/compression/python/pytree/pytree_transforms_test.py b/compression/python/pytree/pytree_transforms_test.py new file mode 100644 index 0000000..fdaec71 --- /dev/null +++ b/compression/python/pytree/pytree_transforms_test.py @@ -0,0 +1,168 @@ +"""Basic tests for 'algebraic data type based pytree' transformations.""" + + +import collections.abc +import sys +import unittest + +import pytree_transforms + + +def _get_deep_pytree(packaging_fn, bottom, depth): + current = bottom + for n in reversed(range(depth)): + current = packaging_fn(n, current) + return current + + +def _dict_node_handler(p, d): + del p # Unused. + if isinstance(d, dict): + keys = d.keys() + newdict = lambda vals: dict(zip(keys, vals)) + return (newdict, iter(d.items())) + else: + return None + + +class PyTreeTest(unittest.TestCase): + """Basic correctness validation tests for PyTree transformations.""" + + def test_linearize_revtuple_path(self): + """Tests guarantees given by `linearize_revtuple_path`.""" + linearize_revtuple_path = pytree_transforms.linearize_revtuple_path + with self.subTest(guarantee='empty'): + self.assertEqual(linearize_revtuple_path(()), ()) + with self.subTest(guarantee='typical'): + self.assertEqual(linearize_revtuple_path((30, (20, (10, ())))), + (10, 20, 30)) + with self.subTest(guarantee='present_as'): + self.assertEqual( + linearize_revtuple_path( + (30, (20, (10, ()))), present_as=list), + [10, 20, 30]) + + def test_everything_is_a_leaf_node_handler(self): + """Tests guarantees given by `everything_is_a_leaf_node_handler`.""" + everything_is_a_leaf_node_handler = ( + pytree_transforms.everything_is_a_leaf_node_handler) + self.assertEqual(everything_is_a_leaf_node_handler((), 'abc'), + None) + self.assertEqual(everything_is_a_leaf_node_handler(('b', ()), + dict(a=3)), + None) + + def test_leaf_summary(self): + """Tests guarantees given by `leaf_summary`.""" + # Since the docstring only guarantees "a human-readable presentation", + # we can and should only do loose checks. + thing = (5678, 9531) + summary = pytree_transforms.leaf_summary(('key', ()), thing) + self.assertIsInstance(summary, str) + self.assertIn(str(thing[0]), summary) + self.assertIn(str(thing[1]), summary) + + def test_pytree_leaf_iter(self): + """Tests guarantees given by `pytree_leaf_iter`.""" + pytree_leaf_iter = pytree_transforms.pytree_leaf_iter + def leaf_transform(path, leaf): + return repr(leaf) if path and path[0].startswith('R') else leaf + with self.subTest(guarantee='returns_iterator'): + result = pytree_leaf_iter(7, leaf_transform, _dict_node_handler) + self.assertIsInstance(result, collections.abc.Iterator) + with self.subTest(guarantee='totally_empty'): + result = list(pytree_leaf_iter({}, leaf_transform, _dict_node_handler)) + self.assertEqual(result, []) + with self.subTest(guarantee='no_leaves'): + result = list(pytree_leaf_iter(dict(a={}), + leaf_transform, _dict_node_handler)) + self.assertEqual(result, []) + with self.subTest(guarantee='is_leaf'): + result = list(pytree_leaf_iter(777, leaf_transform, _dict_node_handler)) + self.assertEqual(result, [777]) + with self.subTest(guarantee='generic'): + result = list(pytree_leaf_iter( + dict(n0=dict(n01=dict(n012=1002, + n013=1003, + Rn014=1004, + ), + n02=1005), + n5=1006), + leaf_transform, _dict_node_handler)) + self.assertEqual(result, [1002, 1003, '1004', 1005, 1006]) + with self.subTest(guarantee='with_keys'): + result = list(pytree_leaf_iter( + dict(n0=dict(n01=dict(n012=1002, + n013=1003)), + n1=1004), + lambda p, s: (pytree_transforms.linearize_revtuple_path(p), s), + _dict_node_handler)) + self.assertEqual(result, + [(('n0', 'n01', 'n012'), 1002), + (('n0', 'n01', 'n013'), 1003), + (('n1',), 1004)]) + + def test_pytree_map(self): + """Tests guarantees given by `pytree_map`.""" + pytree_map = pytree_transforms.pytree_map + leaf_transform = lambda p, s: repr(s) + tree1 = dict(t0=dict(t10=1001, + t11=dict(t110=1002, + t111=1003), + t12=dict(t120=1004, + t121=1005, + t122=1006)), + t1=1007) + with self.subTest(guarantee='no_leaves'): + result = pytree_map(dict(a={}), + leaf_transform, + _dict_node_handler) + self.assertEqual(result, dict(a={})) + with self.subTest(guarantee='is_leaf'): + result = pytree_map(777, leaf_transform, _dict_node_handler) + self.assertEqual(result, '777') + with self.subTest(guarantee='generic'): + result = pytree_map(tree1, leaf_transform, _dict_node_handler) + self.assertEqual(result['t0']['t10'], '1001') + + def test_deeply_nested(self): + """Tests correct behavior on deeply-nested data structures.""" + pytree_leaf_iter = pytree_transforms.pytree_leaf_iter + pytree_map = pytree_transforms.pytree_map + # + depth = max(10**5, sys.getrecursionlimit() + 100) + deep_tree = _get_deep_pytree(lambda n, t: {n: t}, + 'leaf', depth) + with self.subTest(function='pytree_leaf_iter'): + leaves = list(pytree_leaf_iter(deep_tree, + lambda p, s: s.upper(), + _dict_node_handler)) + self.assertEqual(leaves, ['LEAF']) + with self.subTest(function='pytree_map'): + mapped_deep_tree = pytree_map(deep_tree, + lambda p, s: s, + _dict_node_handler) + self.assertIsInstance(mapped_deep_tree, dict) + with self.subTest(function='combined'): + leaves = list( + pytree_leaf_iter( + pytree_map(deep_tree, + lambda p, s: s.capitalize(), + _dict_node_handler), + lambda p, s: s + s, + _dict_node_handler)) + self.assertEqual(leaves, ['LeafLeaf']) + + def test_deep_freeze(self): + """Tests guarantees given by `deep_freeze`.""" + frozen = pytree_transforms.deep_freeze( + dict(a=[1001, 1002, dict(b=(1003, [1004, {1005, 1006}]))])) + self.assertIsInstance(frozen, collections.abc.Mapping) + self.assertNotIsInstance(frozen, collections.abc.MutableMapping) + self.assertIsInstance(frozen['a'], tuple) + # `frozen` is hashable, and hashes to an integer. + self.assertIsInstance(hash(frozen), int) + + +if __name__ == '__main__': + unittest.main() diff --git a/compression/python/pytree/requirements.txt b/compression/python/pytree/requirements.txt new file mode 100644 index 0000000..90c3f39 --- /dev/null +++ b/compression/python/pytree/requirements.txt @@ -0,0 +1,4 @@ +immutabledict>=4.2.0 +numpy>=1.26.4 +orbax-checkpoint>=0.0.0 +