diff --git a/BUILD.bazel b/BUILD.bazel
index cbfb342..52c2df3 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -507,14 +507,12 @@ cc_library(
     srcs = [
         "gemma/attention.cc",
         "gemma/gemma.cc",
-        "gemma/griffin.cc",
         "gemma/vit.cc",
     ],
     hdrs = [
         "gemma/activations.h",
         "gemma/attention.h",
         "gemma/gemma.h",
-        "gemma/griffin.h",
         "gemma/vit.h",
     ],
     exec_properties = {
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d3a66fd..4bc0e80 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -83,8 +83,6 @@ set(SOURCES
   gemma/gemma-inl.h
   gemma/gemma.cc
   gemma/gemma.h
-  gemma/griffin.cc
-  gemma/griffin.h
   gemma/kv_cache.cc
   gemma/kv_cache.h
   gemma/model_store.cc
diff --git a/README.md b/README.md
index 067051d..2963bf6 100644
--- a/README.md
+++ b/README.md
@@ -53,7 +53,7 @@ Guidelines](https://opensource.google.com/conduct/).
 - LLM
-  - CPU-only inference for: Gemma 2-3, Griffin(SSM), PaliGemma 2.
+  - CPU-only inference for: Gemma 2-3, PaliGemma 2.
   - Sampling with TopK and temperature.
   - Backward pass (VJP) and Adam optimizer for Gemma research.
@@ -222,23 +222,6 @@ Example invocation for the following configuration:
 --tokenizer tokenizer.spm
 --weights gemma2-2b-it-sfp.sbs
 ```
-### RecurrentGemma
-
-This repository includes a version of Gemma based on Griffin
-([paper](https://arxiv.org/abs/2402.19427),
-[code](https://github.com/google-deepmind/recurrentgemma)). Its architecture
-includes both recurrent layers and local attention, thus it is more efficient
-for longer sequences and has a smaller memory footprint than standard Gemma. We
-here provide a C++ implementation of this model based on the paper.
-
-To use the recurrent version of Gemma included in this repository, build the
-gemma binary as noted above in Step 3. Download the compressed weights and
-tokenizer from the RecurrentGemma
-[Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) as in
-Step 1, and run the binary as follows:
-
-`./gemma --tokenizer tokenizer.spm --model gr2b-it --weights 2b-it-sfp.sbs`
-
 ### PaliGemma Vision-Language Model
 
 This repository includes a version of the PaliGemma 2 VLM
@@ -535,7 +518,7 @@ gemma.cpp was started in fall 2023 by
 Griffin support was implemented in April 2024 thanks to contributions by Andrey
 Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode
 Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas
-Fischbacher and Zoltan Szabadka.
+Fischbacher and Zoltan Szabadka. Griffin support was removed in September 2025.
 
 Gemma-2 support was implemented in June/July 2024 with the help of several
 people.
diff --git a/compression/python/compression_test.py b/compression/python/compression_test.py
index e8244ed..957f0ec 100644
--- a/compression/python/compression_test.py
+++ b/compression/python/compression_test.py
@@ -91,7 +91,7 @@ class CompressionTest(absltest.TestCase):
     )
 
     config = configs.ModelConfig(
-        configs.Model.GEMMA_TINY,
+        configs.Model.GEMMA2_2B,
         configs.Type.kSFP,
         configs.PromptWrapping.GEMMA_IT,
     )
@@ -101,7 +101,7 @@ class CompressionTest(absltest.TestCase):
     print("Ignore next two warnings; test does not enable model deduction.")
     reader = compression.SbsReader(temp_file.full_path)
 
-    self.assertEqual(reader.config.model, configs.Model.GEMMA_TINY)
+    self.assertEqual(reader.config.model, configs.Model.GEMMA2_2B)
     self.assertEqual(reader.config.weight, configs.Type.kSFP)
 
     mat = reader.find_mat("tensor0")
diff --git a/compression/python/pytree/PYTREE_README.md b/compression/python/pytree/PYTREE_README.md
deleted file mode 100644
index 4a04079..0000000
--- a/compression/python/pytree/PYTREE_README.md
+++ /dev/null
@@ -1,8 +0,0 @@
-
-# General Remarks about the "PyTree" Abstraction
-
-The pytree wrangling code in this project does not use any of the existing
-"pytree" modules. The deeper reason here is that our approach is based on an
-analysis of the notion that emphasizes deeper underlying principles. This is
-being discussed internally at the time of this writing.
-
diff --git a/compression/python/pytree/build_model_file_for_cpp_binary.py b/compression/python/pytree/build_model_file_for_cpp_binary.py
deleted file mode 100644
index d039639..0000000
--- a/compression/python/pytree/build_model_file_for_cpp_binary.py
+++ /dev/null
@@ -1,275 +0,0 @@
-"""Ad-hoc glue code for building the griffin model-file for the C++ binary.
-
-Usage:
-
-python3 -m venv $HOME/clients/griffin-venv
-
-. $HOME/clients/griffin-venv/bin/activate
-
-python3 -m pip install -r requirements.txt
-
-time python3 build_model_file_for_cpp_binary.py \
-  $HOME/GRIFFIN/model_data \
-  cpp_load_log.txt /tmp/G2B.data
-
-real 3m5.821s
-user 2m9.205s
-sys 2m46.720s
-
-./compress_weights --weights /tmp/G2B.data --model gr2b-it \
-  --compressed_weights /tmp/G2B.compressed
-./gemma --tokenizer tokenizer.spm --weights /tmp/G2B.compressed \
-  --model gr2b-it
-
-Weights for the recurrent-gemma model that can be converted with this script
-can be found at:
-
-  https://www.kaggle.com/models/google/recurrentgemma/flax/2b-it
-"""
-
-import pprint
-import re
-import sys
-
-from typing import Any, Mapping
-
-import numpy
-
-import orbax.checkpoint
-
-import ml_model_transforms
-import pytree_transforms
-
-
-def _fn_identity(x): return x
-
-
-def _fn_transpose(x): return x.T
-
-
-def _fn_transpose_all_heads(x): return x.transpose(0, 2, 1)
-
-
-def _fn_scaled_softplus(a):
-  return -8 * numpy.logaddexp(a, 0)
-
-
-def _fn_attention_moveaxis(a):
-  return a.reshape(10, 256, 2560).transpose(0, 2, 1)
-
-
-def _aspec(pieces=(), transforms=()):
-  """Short-hand array-save-specification.
-
-  Args:
-    pieces: Sequence of key-sequences identifying an array.
-    transforms: Sequence of transformations, indexed in
-      parallel to `pieces`, to apply to data arrays prior to saving.
-      Will be padded with identity-transformations to the length of `pieces`.
-
-  Returns:
-    Specification as for use in _LAYER_NAME_MAPPING.
-  """
-  # `zip` trims to shortest sequence, so this amounts to using
-  # default-transforms.
-  # tuple() since we need a Sequence here, not a stateful-iterator zip_object.
-  return tuple(zip(pieces, list(transforms) + [_fn_identity] * len(pieces)))
-
-
-_LAYER_NAME_MAPPING = pytree_transforms.deep_freeze({
-    # Recurrent Layer
-    'griffin_linear_x_w': _aspec(
-        [('recurrent_block', 'linear_x', 'kernel')],
-        [_fn_transpose]),
-    'griffin_linear_x_biases': _aspec(
-        [('recurrent_block', 'linear_x', 'bias')]),
-    'griffin_linear_y_w': _aspec(
-        [('recurrent_block', 'linear_y', 'kernel')],
-        [_fn_transpose]),
-    'griffin_linear_y_biases': _aspec(
-        [('recurrent_block', 'linear_y', 'bias')]),
-    'griffin_linear_out_w': _aspec(
-        [('recurrent_block', 'linear_out', 'kernel')],
-        [_fn_transpose]),
-    'griffin_linear_out_biases': _aspec(
-        [('recurrent_block' ,'linear_out', 'bias')]),
-    'griffin_conv_w': _aspec(
-        [('recurrent_block', 'conv_1d', 'w')]),
-    'griffin_conv_biases': _aspec(
-        [('recurrent_block', 'conv_1d', 'b')]),
-    'griffin_gate_w': _aspec(
-        [('recurrent_block', 'rg_lru', 'input_gate', 'w'),
-         ('recurrent_block', 'rg_lru', 'a_gate', 'w')],
-        [_fn_transpose_all_heads, _fn_transpose_all_heads]),
-    'griffin_gate_biases': _aspec(
-        [('recurrent_block', 'rg_lru', 'input_gate', 'b'),
-         ('recurrent_block', 'rg_lru', 'a_gate', 'b')]),
-    'griffin_a': _aspec(
-        [('recurrent_block', 'rg_lru', 'a_param')],
-        [_fn_scaled_softplus]),
-    # Attention Layer
-    'qkv_einsum_w': _aspec(
-        [('attention_block', 'proj_q', 'kernel'),
-         ('attention_block', 'proj_k', 'kernel'),
-         ('attention_block', 'proj_v', 'kernel'),
-        ],
-        [_fn_transpose, _fn_transpose, _fn_transpose]),
-    'attn_vec_einsum_w': _aspec(
-        [('attention_block', 'proj_final', 'kernel')],
-        [_fn_attention_moveaxis]),
-    'attention_output_biases': _aspec(
-        [('attention_block', 'proj_final', 'bias')]),
-    # Common
-    'pre_attention_norm_scale': _aspec(
-        [('temporal_pre_norm', 'scale')]),
-    'pre_ffw_norm_scale': _aspec(
-        [('channel_pre_norm', 'scale')]),
-    'gating_einsum_w': _aspec(
-        [('mlp_block', 'ffw_up', 'w')],
-        [_fn_transpose_all_heads]),
-    'ffw_gating_biases': _aspec(
-        [('mlp_block', 'ffw_up', 'b')]),
-    'linear_w': _aspec(
-        [('mlp_block', 'ffw_down', 'kernel')],
-        [_fn_transpose]),
-    'ffw_output_biases': _aspec(
-        [('mlp_block', 'ffw_down', 'bias')]),
-    # Other
-    'embedder_input_embedding': _aspec(
-        [('embedder', 'input_embedding')]),
-    'final_norm_scale': _aspec(
-        [('final_norm', 'scale')]),
-})
-
-
-def process_param_line(line : str) -> tuple[None | str, int, str]:
-  """Processes a "loading parameters" log-line from the griffin binary."""
-  # This is slightly more permissive than strictly needed, to also handle
-  # some earlier form of the output.
-  matched = re.match(
-      r'(?a)Loading Parameters:? \('
-      r'(?:layer=(?P<layer>\d+), )?'
-      r'size (?P<size>\d+)\):? '
-      r'(?P<tag>\S+)',
-      line)
-  if not matched:
-    return None
-  layer = matched['layer']
-  wanted_size = int(matched['size'])
-  cpp_tag = matched['tag']
-  return matched['layer'], int(matched['size']), matched['tag']
-
-
-def collect_pytree_keys(param_lines):
-  """Collects all the pytree keys and transforms for model-serialization."""
-  pytree_keys = []
-  array_transforms = []
-  unsatisfied = []
-  for maybe_spec in map(process_param_line, param_lines):
-    if not maybe_spec: continue  # Skip non-parameter lines.
- layer, wanted_size, cpp_tag = maybe_spec - pytree_key_tails_and_transforms = _LAYER_NAME_MAPPING.get(cpp_tag, ()) - if not pytree_key_tails_and_transforms: - unsatisfied.append((layer, cpp_tag)) - else: - for key_tail, array_transform in pytree_key_tails_and_transforms: - pytree_keys.append( - key_tail if layer is None - else (f'blocks.{layer}',) + key_tail) - array_transforms.append(array_transform) - return pytree_keys, array_transforms, unsatisfied - - -class UnsatisfiedArrayLoadsError(ValueError): - """Some array-loads could not be satisfied.""" - - -def flatten_model_for_cpp_binary(tree, - cpp_expectations_logfile_path : str, - out_path : str, - unsatisfied_ok : bool = False - ): - """Produces a model-parameters file readable by the C++ binary. - - Args: - tree: The pytree with model-parameters. - cpp_expectations_logfile_path: - Path to a logfile produced by the C++ binary that shows - the expected array-order. - out_path: Path to the model-weights file to be written. - unsatisfied_ok: If true, we ignore the presence of unsatisfied - array-loads and write a model-parameters file that skips these pieces. - This will lead to an unusable model-parameters file which however - still might be useful for other analysis. - - Returns: - Tuple `(unknown_keys, missing_keys)`, where `unknown_keys` - is a sequence of `(layer_or_None, name)` descriptions of the keys - in the C++ log that could not be satisfied, and `missing_keys` - is a sequence of linearized pytree key-sequences for keys - not found in the checkpoint. - - Raises: - UnsatisfiedArrayLoadsError: If some of the expected arrays - could not be included in the output and `unsatisfied_ok` - is false. - """ - with open(cpp_expectations_logfile_path, 'rt') as h_log: - pytree_keys, array_transforms, unknown_keys = collect_pytree_keys( - list(h_log)) - rank_by_pytree_key = {k: n for n, k in enumerate(pytree_keys)} - array_transform_by_pytree_key = dict(zip(pytree_keys, array_transforms)) - # - model_contents = ml_model_transforms.model_contents(tree) - missing_keys = set(pytree_keys) - model_contents.keys() - if (unknown_keys or missing_keys) and not unsatisfied_ok: - raise ValueError( - f'Unsatisfied loads: unknown_keys: {unknown_keys!r}, ' - f'missing keys: {sorted(missing_keys)!r}') - ml_model_transforms.model_save( - tree, - filepath_stem=out_path, - data_suffix='', - manifest_suffix=None, - array_transform_by_pytree_key=array_transform_by_pytree_key, - key=rank_by_pytree_key.get, - report=lambda line: print(line, file=sys.stderr), - byte_align=1) - return tuple(unknown_keys), tuple(sorted(missing_keys)) - - -def main(args): - """Creates the model-file. - - Args: - sys.argv[] parameters from command line sans the leading one. - - Returns: - The pytree with all the de-serialized variables, such as for convenient - `python3 -i` inspection. 
- """ - try: - model_dir, cpp_load_log, out_path = args - except Exception: - sys.exit(f'Usage: {__file__} [model_dir] [cpp_load_log] [output_filename]') - pattern = ("recurrent", "recurrent", "attention") - orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() - variables = orbax_checkpointer.restore(model_dir) - if sorted(variables) == ['params']: - print('Warning: Using `variables["params"]` as tree-root.', file=sys.stderr) - variables_to_use = variables['params'] - else: - variables_to_use = variables - unknown, missing = flatten_model_for_cpp_binary(variables_to_use, - cpp_load_log, - out_path, - unsatisfied_ok=True) - print('Model file saved.\n' - f'# unknown:\n{pprint.pformat(unknown)}\n' - f'# missing:\n{pprint.pformat(missing)}') - return variables - - -if __name__ == '__main__': - # Return value assignment is for `python3 -i ...` inspection. - pytree = main(sys.argv[1:]) diff --git a/compression/python/pytree/cpp_load_log.txt b/compression/python/pytree/cpp_load_log.txt deleted file mode 100644 index cc33394..0000000 --- a/compression/python/pytree/cpp_load_log.txt +++ /dev/null @@ -1,380 +0,0 @@ -Loading Parameters (size 2622750720): embedder_input_embedding -Loading Parameters (size 10240): final_norm_scale -Loading Parameters: (layer=0, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=0, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=0, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=0, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=0, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=0, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=0, size 40960) griffin_conv_w -Loading Parameters: (layer=0, size 10240) griffin_conv_biases -Loading Parameters: (layer=0, size 5242880) griffin_gate_w -Loading Parameters: (layer=0, size 20480) griffin_gate_biases -Loading Parameters: (layer=0, size 10240) griffin_a -Loading Parameters: (layer=0, size 157286400) gating_einsum_w -Loading Parameters: (layer=0, size 78643200) linear_w -Loading Parameters: (layer=0, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=0, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=0, size 61440) ffw_gating_biases -Loading Parameters: (layer=0, size 10240) ffw_output_biases -Loading Parameters: (layer=1, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=1, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=1, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=1, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=1, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=1, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=1, size 40960) griffin_conv_w -Loading Parameters: (layer=1, size 10240) griffin_conv_biases -Loading Parameters: (layer=1, size 5242880) griffin_gate_w -Loading Parameters: (layer=1, size 20480) griffin_gate_biases -Loading Parameters: (layer=1, size 10240) griffin_a -Loading Parameters: (layer=1, size 157286400) gating_einsum_w -Loading Parameters: (layer=1, size 78643200) linear_w -Loading Parameters: (layer=1, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=1, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=1, size 61440) ffw_gating_biases -Loading Parameters: (layer=1, size 10240) ffw_output_biases -Loading Parameters: (layer=2, size 26214400) attn_vec_einsum_w -Loading Parameters: (layer=2, size 78643200) qkv_einsum_w -Loading Parameters: (layer=2, size 157286400) 
gating_einsum_w -Loading Parameters: (layer=2, size 78643200) linear_w -Loading Parameters: (layer=2, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=2, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=2, size 61440) ffw_gating_biases -Loading Parameters: (layer=2, size 10240) ffw_output_biases -Loading Parameters: (layer=2, size 10240) attention_output_biases -Loading Parameters: (layer=3, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=3, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=3, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=3, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=3, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=3, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=3, size 40960) griffin_conv_w -Loading Parameters: (layer=3, size 10240) griffin_conv_biases -Loading Parameters: (layer=3, size 5242880) griffin_gate_w -Loading Parameters: (layer=3, size 20480) griffin_gate_biases -Loading Parameters: (layer=3, size 10240) griffin_a -Loading Parameters: (layer=3, size 157286400) gating_einsum_w -Loading Parameters: (layer=3, size 78643200) linear_w -Loading Parameters: (layer=3, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=3, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=3, size 61440) ffw_gating_biases -Loading Parameters: (layer=3, size 10240) ffw_output_biases -Loading Parameters: (layer=4, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=4, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=4, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=4, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=4, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=4, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=4, size 40960) griffin_conv_w -Loading Parameters: (layer=4, size 10240) griffin_conv_biases -Loading Parameters: (layer=4, size 5242880) griffin_gate_w -Loading Parameters: (layer=4, size 20480) griffin_gate_biases -Loading Parameters: (layer=4, size 10240) griffin_a -Loading Parameters: (layer=4, size 157286400) gating_einsum_w -Loading Parameters: (layer=4, size 78643200) linear_w -Loading Parameters: (layer=4, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=4, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=4, size 61440) ffw_gating_biases -Loading Parameters: (layer=4, size 10240) ffw_output_biases -Loading Parameters: (layer=5, size 26214400) attn_vec_einsum_w -Loading Parameters: (layer=5, size 78643200) qkv_einsum_w -Loading Parameters: (layer=5, size 157286400) gating_einsum_w -Loading Parameters: (layer=5, size 78643200) linear_w -Loading Parameters: (layer=5, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=5, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=5, size 61440) ffw_gating_biases -Loading Parameters: (layer=5, size 10240) ffw_output_biases -Loading Parameters: (layer=5, size 10240) attention_output_biases -Loading Parameters: (layer=6, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=6, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=6, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=6, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=6, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=6, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=6, size 40960) griffin_conv_w -Loading 
Parameters: (layer=6, size 10240) griffin_conv_biases -Loading Parameters: (layer=6, size 5242880) griffin_gate_w -Loading Parameters: (layer=6, size 20480) griffin_gate_biases -Loading Parameters: (layer=6, size 10240) griffin_a -Loading Parameters: (layer=6, size 157286400) gating_einsum_w -Loading Parameters: (layer=6, size 78643200) linear_w -Loading Parameters: (layer=6, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=6, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=6, size 61440) ffw_gating_biases -Loading Parameters: (layer=6, size 10240) ffw_output_biases -Loading Parameters: (layer=7, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=7, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=7, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=7, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=7, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=7, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=7, size 40960) griffin_conv_w -Loading Parameters: (layer=7, size 10240) griffin_conv_biases -Loading Parameters: (layer=7, size 5242880) griffin_gate_w -Loading Parameters: (layer=7, size 20480) griffin_gate_biases -Loading Parameters: (layer=7, size 10240) griffin_a -Loading Parameters: (layer=7, size 157286400) gating_einsum_w -Loading Parameters: (layer=7, size 78643200) linear_w -Loading Parameters: (layer=7, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=7, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=7, size 61440) ffw_gating_biases -Loading Parameters: (layer=7, size 10240) ffw_output_biases -Loading Parameters: (layer=8, size 26214400) attn_vec_einsum_w -Loading Parameters: (layer=8, size 78643200) qkv_einsum_w -Loading Parameters: (layer=8, size 157286400) gating_einsum_w -Loading Parameters: (layer=8, size 78643200) linear_w -Loading Parameters: (layer=8, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=8, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=8, size 61440) ffw_gating_biases -Loading Parameters: (layer=8, size 10240) ffw_output_biases -Loading Parameters: (layer=8, size 10240) attention_output_biases -Loading Parameters: (layer=9, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=9, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=9, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=9, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=9, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=9, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=9, size 40960) griffin_conv_w -Loading Parameters: (layer=9, size 10240) griffin_conv_biases -Loading Parameters: (layer=9, size 5242880) griffin_gate_w -Loading Parameters: (layer=9, size 20480) griffin_gate_biases -Loading Parameters: (layer=9, size 10240) griffin_a -Loading Parameters: (layer=9, size 157286400) gating_einsum_w -Loading Parameters: (layer=9, size 78643200) linear_w -Loading Parameters: (layer=9, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=9, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=9, size 61440) ffw_gating_biases -Loading Parameters: (layer=9, size 10240) ffw_output_biases -Loading Parameters: (layer=10, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=10, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=10, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=10, size 10240) 
griffin_linear_y_biases -Loading Parameters: (layer=10, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=10, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=10, size 40960) griffin_conv_w -Loading Parameters: (layer=10, size 10240) griffin_conv_biases -Loading Parameters: (layer=10, size 5242880) griffin_gate_w -Loading Parameters: (layer=10, size 20480) griffin_gate_biases -Loading Parameters: (layer=10, size 10240) griffin_a -Loading Parameters: (layer=10, size 157286400) gating_einsum_w -Loading Parameters: (layer=10, size 78643200) linear_w -Loading Parameters: (layer=10, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=10, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=10, size 61440) ffw_gating_biases -Loading Parameters: (layer=10, size 10240) ffw_output_biases -Loading Parameters: (layer=11, size 26214400) attn_vec_einsum_w -Loading Parameters: (layer=11, size 78643200) qkv_einsum_w -Loading Parameters: (layer=11, size 157286400) gating_einsum_w -Loading Parameters: (layer=11, size 78643200) linear_w -Loading Parameters: (layer=11, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=11, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=11, size 61440) ffw_gating_biases -Loading Parameters: (layer=11, size 10240) ffw_output_biases -Loading Parameters: (layer=11, size 10240) attention_output_biases -Loading Parameters: (layer=12, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=12, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=12, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=12, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=12, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=12, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=12, size 40960) griffin_conv_w -Loading Parameters: (layer=12, size 10240) griffin_conv_biases -Loading Parameters: (layer=12, size 5242880) griffin_gate_w -Loading Parameters: (layer=12, size 20480) griffin_gate_biases -Loading Parameters: (layer=12, size 10240) griffin_a -Loading Parameters: (layer=12, size 157286400) gating_einsum_w -Loading Parameters: (layer=12, size 78643200) linear_w -Loading Parameters: (layer=12, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=12, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=12, size 61440) ffw_gating_biases -Loading Parameters: (layer=12, size 10240) ffw_output_biases -Loading Parameters: (layer=13, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=13, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=13, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=13, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=13, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=13, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=13, size 40960) griffin_conv_w -Loading Parameters: (layer=13, size 10240) griffin_conv_biases -Loading Parameters: (layer=13, size 5242880) griffin_gate_w -Loading Parameters: (layer=13, size 20480) griffin_gate_biases -Loading Parameters: (layer=13, size 10240) griffin_a -Loading Parameters: (layer=13, size 157286400) gating_einsum_w -Loading Parameters: (layer=13, size 78643200) linear_w -Loading Parameters: (layer=13, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=13, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=13, size 61440) ffw_gating_biases -Loading Parameters: (layer=13, size 
10240) ffw_output_biases -Loading Parameters: (layer=14, size 26214400) attn_vec_einsum_w -Loading Parameters: (layer=14, size 78643200) qkv_einsum_w -Loading Parameters: (layer=14, size 157286400) gating_einsum_w -Loading Parameters: (layer=14, size 78643200) linear_w -Loading Parameters: (layer=14, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=14, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=14, size 61440) ffw_gating_biases -Loading Parameters: (layer=14, size 10240) ffw_output_biases -Loading Parameters: (layer=14, size 10240) attention_output_biases -Loading Parameters: (layer=15, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=15, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=15, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=15, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=15, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=15, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=15, size 40960) griffin_conv_w -Loading Parameters: (layer=15, size 10240) griffin_conv_biases -Loading Parameters: (layer=15, size 5242880) griffin_gate_w -Loading Parameters: (layer=15, size 20480) griffin_gate_biases -Loading Parameters: (layer=15, size 10240) griffin_a -Loading Parameters: (layer=15, size 157286400) gating_einsum_w -Loading Parameters: (layer=15, size 78643200) linear_w -Loading Parameters: (layer=15, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=15, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=15, size 61440) ffw_gating_biases -Loading Parameters: (layer=15, size 10240) ffw_output_biases -Loading Parameters: (layer=16, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=16, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=16, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=16, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=16, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=16, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=16, size 40960) griffin_conv_w -Loading Parameters: (layer=16, size 10240) griffin_conv_biases -Loading Parameters: (layer=16, size 5242880) griffin_gate_w -Loading Parameters: (layer=16, size 20480) griffin_gate_biases -Loading Parameters: (layer=16, size 10240) griffin_a -Loading Parameters: (layer=16, size 157286400) gating_einsum_w -Loading Parameters: (layer=16, size 78643200) linear_w -Loading Parameters: (layer=16, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=16, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=16, size 61440) ffw_gating_biases -Loading Parameters: (layer=16, size 10240) ffw_output_biases -Loading Parameters: (layer=17, size 26214400) attn_vec_einsum_w -Loading Parameters: (layer=17, size 78643200) qkv_einsum_w -Loading Parameters: (layer=17, size 157286400) gating_einsum_w -Loading Parameters: (layer=17, size 78643200) linear_w -Loading Parameters: (layer=17, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=17, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=17, size 61440) ffw_gating_biases -Loading Parameters: (layer=17, size 10240) ffw_output_biases -Loading Parameters: (layer=17, size 10240) attention_output_biases -Loading Parameters: (layer=18, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=18, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=18, size 26214400) griffin_linear_y_w -Loading Parameters: 
(layer=18, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=18, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=18, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=18, size 40960) griffin_conv_w -Loading Parameters: (layer=18, size 10240) griffin_conv_biases -Loading Parameters: (layer=18, size 5242880) griffin_gate_w -Loading Parameters: (layer=18, size 20480) griffin_gate_biases -Loading Parameters: (layer=18, size 10240) griffin_a -Loading Parameters: (layer=18, size 157286400) gating_einsum_w -Loading Parameters: (layer=18, size 78643200) linear_w -Loading Parameters: (layer=18, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=18, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=18, size 61440) ffw_gating_biases -Loading Parameters: (layer=18, size 10240) ffw_output_biases -Loading Parameters: (layer=19, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=19, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=19, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=19, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=19, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=19, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=19, size 40960) griffin_conv_w -Loading Parameters: (layer=19, size 10240) griffin_conv_biases -Loading Parameters: (layer=19, size 5242880) griffin_gate_w -Loading Parameters: (layer=19, size 20480) griffin_gate_biases -Loading Parameters: (layer=19, size 10240) griffin_a -Loading Parameters: (layer=19, size 157286400) gating_einsum_w -Loading Parameters: (layer=19, size 78643200) linear_w -Loading Parameters: (layer=19, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=19, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=19, size 61440) ffw_gating_biases -Loading Parameters: (layer=19, size 10240) ffw_output_biases -Loading Parameters: (layer=20, size 26214400) attn_vec_einsum_w -Loading Parameters: (layer=20, size 78643200) qkv_einsum_w -Loading Parameters: (layer=20, size 157286400) gating_einsum_w -Loading Parameters: (layer=20, size 78643200) linear_w -Loading Parameters: (layer=20, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=20, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=20, size 61440) ffw_gating_biases -Loading Parameters: (layer=20, size 10240) ffw_output_biases -Loading Parameters: (layer=20, size 10240) attention_output_biases -Loading Parameters: (layer=21, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=21, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=21, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=21, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=21, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=21, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=21, size 40960) griffin_conv_w -Loading Parameters: (layer=21, size 10240) griffin_conv_biases -Loading Parameters: (layer=21, size 5242880) griffin_gate_w -Loading Parameters: (layer=21, size 20480) griffin_gate_biases -Loading Parameters: (layer=21, size 10240) griffin_a -Loading Parameters: (layer=21, size 157286400) gating_einsum_w -Loading Parameters: (layer=21, size 78643200) linear_w -Loading Parameters: (layer=21, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=21, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=21, size 61440) ffw_gating_biases -Loading 
Parameters: (layer=21, size 10240) ffw_output_biases -Loading Parameters: (layer=22, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=22, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=22, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=22, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=22, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=22, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=22, size 40960) griffin_conv_w -Loading Parameters: (layer=22, size 10240) griffin_conv_biases -Loading Parameters: (layer=22, size 5242880) griffin_gate_w -Loading Parameters: (layer=22, size 20480) griffin_gate_biases -Loading Parameters: (layer=22, size 10240) griffin_a -Loading Parameters: (layer=22, size 157286400) gating_einsum_w -Loading Parameters: (layer=22, size 78643200) linear_w -Loading Parameters: (layer=22, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=22, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=22, size 61440) ffw_gating_biases -Loading Parameters: (layer=22, size 10240) ffw_output_biases -Loading Parameters: (layer=23, size 26214400) attn_vec_einsum_w -Loading Parameters: (layer=23, size 78643200) qkv_einsum_w -Loading Parameters: (layer=23, size 157286400) gating_einsum_w -Loading Parameters: (layer=23, size 78643200) linear_w -Loading Parameters: (layer=23, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=23, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=23, size 61440) ffw_gating_biases -Loading Parameters: (layer=23, size 10240) ffw_output_biases -Loading Parameters: (layer=23, size 10240) attention_output_biases -Loading Parameters: (layer=24, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=24, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=24, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=24, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=24, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=24, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=24, size 40960) griffin_conv_w -Loading Parameters: (layer=24, size 10240) griffin_conv_biases -Loading Parameters: (layer=24, size 5242880) griffin_gate_w -Loading Parameters: (layer=24, size 20480) griffin_gate_biases -Loading Parameters: (layer=24, size 10240) griffin_a -Loading Parameters: (layer=24, size 157286400) gating_einsum_w -Loading Parameters: (layer=24, size 78643200) linear_w -Loading Parameters: (layer=24, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=24, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=24, size 61440) ffw_gating_biases -Loading Parameters: (layer=24, size 10240) ffw_output_biases -Loading Parameters: (layer=25, size 26214400) griffin_linear_x_w -Loading Parameters: (layer=25, size 10240) griffin_linear_x_biases -Loading Parameters: (layer=25, size 26214400) griffin_linear_y_w -Loading Parameters: (layer=25, size 10240) griffin_linear_y_biases -Loading Parameters: (layer=25, size 26214400) griffin_linear_out_w -Loading Parameters: (layer=25, size 10240) griffin_linear_out_biases -Loading Parameters: (layer=25, size 40960) griffin_conv_w -Loading Parameters: (layer=25, size 10240) griffin_conv_biases -Loading Parameters: (layer=25, size 5242880) griffin_gate_w -Loading Parameters: (layer=25, size 20480) griffin_gate_biases -Loading Parameters: (layer=25, size 10240) griffin_a -Loading Parameters: (layer=25, size 157286400) 
gating_einsum_w -Loading Parameters: (layer=25, size 78643200) linear_w -Loading Parameters: (layer=25, size 10240) pre_attention_norm_scale -Loading Parameters: (layer=25, size 10240) pre_ffw_norm_scale -Loading Parameters: (layer=25, size 61440) ffw_gating_biases -Loading Parameters: (layer=25, size 10240) ffw_output_biases diff --git a/compression/python/pytree/ml_model_transforms.py b/compression/python/pytree/ml_model_transforms.py deleted file mode 100644 index 3605c07..0000000 --- a/compression/python/pytree/ml_model_transforms.py +++ /dev/null @@ -1,371 +0,0 @@ -"""Transformations for python-trees representing the parameters of a ML model. - -Important: This module assumes that byte-order is the same on the -machine that serializes data and the machine that deserializes -data. If, for example, numpy-data gets dumped, respectively loaded, -with a dtype-specification of numpy.float32, on-file byte-order -will be host byte order. - -""" - -import ast -import hashlib -import itertools -import pprint -import sys -import time -from typing import Any, Callable, Iterable, Iterator, Mapping, TypeVar - -import numpy -import pytree_transforms - - -NT = TypeVar('NT') - - -def ml_model_leaf_summary(path, x, sep=', '): - """Produces a textual summary of a leaf-node and its path. - - Args: - path: The path-to-root, as a reverse-order recursive - pair of path-components, with `()` as root. - x: The leaf-object. - sep: the separator between description-elements. - Default ', ' allows for convenient line-by-line processing - (such as via grep, perl -ne, etc.), but using e.g. sep=',\n ' - might be more useful for human consumption. - - Returns: - A human-readable string providing information about the node. - """ - # Using `repr` for path-components to get a faithful presentation. - # (...which still however would be somewat painful to correctly - # split into components.) - path_str = ','.join(map(repr, - pytree_transforms.linearize_revtuple_path(path))) - tx = type(x) - mod = tx.__module__ # Either a module or a string like 'builtins'. - modname = mod if isinstance(mod, str) else mod.__name__ - type_str = f'{modname}.{tx.__qualname__}' - try: - # `numpy.ndarray` instances have a `.data` property that gives access - # to a buffer via which we can hashlib-fingerprint the numerical - # contents. We here simply try to produce a fingerprint and also look - # up the .dtype of the object. Technically, there is a somewhat-unsound - # assumption here that if these operations succeed, we are indeed looking - # at a ndarray or sufficiently similar object for these operations to - # make sense. As the output is declared "for human consumption", this - # fishiness is not a problem. - fp = hashlib.sha256(x.data).hexdigest() - start = list(itertools.islice(x.flat, 5)) - stats_str = ( - f'min={numpy.min(x):.6g}, max={numpy.max(x):.6g}, ' - f'mean={numpy.mean(x):.6g}, std={numpy.std(x):.6g}') - return (f'{path_str:60s}: <{type_str}{sep}' - f'fp=0x{fp}{sep}{stats_str}{sep}shape={x.shape}, ' - f'dtype={x.dtype}{sep}start={start}>') - except (AttributeError, ValueError, TypeError): - # Fallback - trying to include information about the data-content - # of a likely-numerical-array failed. - return f'{path_str:60s}: {type_str}({repr(x)})' - - -# A specialized node-handler. -# Interface follows node-handler expectations defined in pytree_transforms. 
-def _ml_model_tree_node_handler(path: tuple, node : NT) -> ( - None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT], - Iterator[tuple[Any, NT]]]): - """Processes a tree-node as required by pytree-iteration and -mapping. - - Args: - path: revtuple path to the current node. - node: a tree-node in a ML-model tree that is recursively - built out of `numpy.ndarray` leaf-values and dicts mapping - node-name string-keys to other such nodes representing subtrees - - and nothing else. - - Returns: - `None` if the tree-node is to be regarded as a leaf, otherwise - a pair `(rebuilder, iterator)`, where `iterator` iterates - over the data-content of the node, each item represented as a pair - of `(lookup_path_item, value_item)`, and `rebuilder` is a function - which, when applied to `iterator` or any iterable with the same - elements, returns a node that is equivalent to the original. - - Raises: - NotAMLModelTreeNodeError: If the tree contains a node that is neither - a `dict` nor a `numpy.ndarray` instance. - """ - # The astute reader will notice that we are doing something fishy - # here - this code could not be translated to Haskell as-is, since - # `NT` cannot actually be a proper type-variable in the sense - # of parametric polymorphism. - del path # Unused. - if isinstance(node, dict): - return dict, iter(node.items()) - if isinstance(node, numpy.ndarray): - return None - raise pytree_transforms.NotAMLModelTreeNodeError( - f'Type of bad node: {type(node)}') - - -def _ml_model_extract_leaf_transform( - path: pytree_transforms.RevTuplePath, - leaf: Any): - """Maps an array-leaf to a pair `(full_path, lambda: array)`. - - The computation that produces the leaf-value is lazified underneath - a `lambda`, since if we e.g. performed a memory-expensive - transformation (such as some dtype-changes) directly at this point, - then going from an iterator over tree-items for one-by-one - consumption to a list of these items would have all the - dtype-transformed values around simultaneously. We want to avoid - situations where we can do nothing about having multiple variants - of the data simultaneously in memory. - """ - # Hack: If we are encountering a `bfloat16` numpy-array, - # we pretend to have the data as a numpy.float32 array, - # since that's about all that contemporary CPUs can process - # efficiently here. - linearized_path = pytree_transforms.linearize_revtuple_path(path) - try: - # We have to use some trickery to detect `bfloat16`. - if leaf.dtype.descr[-1] == ('', ' Any: - """Performs perl-style autovivification on a nested-dict tree. - - Args: - keys_and_vals: An iterable of pairs `(key_path, value)`, where - `key_path` is a sequence of keys to be used to navigate to - the result via iterative dict-lookup, left-to-right. - Must not have duplicate keys, and must not more than one key if - an empty-sequence key is present. If this iterable is an - iterator, it will be fully exhausted on successful execution. - - Returns: - An object representing a nested-dict structure such that - for every `key_path` from `keys_and_vals`, recursive-dict-lookup - on the elements of that path starting from this object will - produce the corresponding value. An empty `keys_and_vals` - set will return `{}`. Every dict in the nested return-value - that has been populated by autovivification is newly allocated. - """ - # Code structure is a bit gnarly here due to f(keys_and_vals=[((), x)]) - # having to evaluate to x and not a dict. - # There may be ways to prettify/simplify this. 
- result = None - empty = {} - for linear_path, val in keys_and_vals: - if linear_path == (): - if result is not None: - raise ValueError('Root-value seen alongside other values.') - result = val - else: - if result is None: - result = {} - elif type(result) is not dict: - # We already did encounter a root-value. - raise ValueError('Root-value seen alongside other values.') - cursor = result - for n in range(len(linear_path) - 1): - cursor = cursor.setdefault(linear_path[n], empty) - if cursor is empty: - # Regenerate `empty` if we just used it up. - empty = {} - cursor[linear_path[-1]] = val - return {} if result is None else result - - -def model_overview(tree, out=None) -> None: - """Prints a human-readable overview to `(out or sys.stdout)`.""" - actual_out = out or sys.stdout - for line in pytree_transforms.pytree_leaf_iter( - tree, ml_model_leaf_summary, - _ml_model_tree_node_handler): - print(line, file=actual_out) - - -def model_contents(tree) -> Mapping[tuple[str, ...], Any]: - """Maps a model to a {pytree_keys: data_array} mapping. - - Args: - tree: The ML-model parameter-tree, built recursively out of - dict-instances with numpy.ndarray instances as leaves. - - Returns: - A mapping from linearized pytree-key-sequence tuple to the corresponding - leaf-value. - """ - def leaf_transform(revtuple_path, leaf): - return pytree_transforms.linearize_revtuple_path(revtuple_path), leaf - return dict( - pytree_transforms.pytree_leaf_iter( - tree, leaf_transform, _ml_model_tree_node_handler)) - - -def _fn_identity(x): return x - - -def model_save(tree, - filepath_stem: str, - data_suffix: str = '.data', - manifest_suffix: str | None = '.manifest', - key: Callable[[tuple[str, ...]], Any] | None = None, - array_transform_by_pytree_key: ( - Mapping[tuple[str, ...], - Callable[[numpy.ndarray], numpy.ndarray]] | - None) = None, - report: Callable[[str], None] | None = None, - byte_align: int = 8) -> tuple[int, float]: - """Saves the content of a ML-model parameter-tree to filesystem. - - After successful execution, the file f"{filepath_stem}.data" - will hold the combined numerical model-parameters, and - f"{filepath_stem}.manifest" will contain the key for interpreting - (and rebuilding) the data. - - Args: - tree: The ML-model parameter-tree, built recursively out of - dict-instances with numpy.ndarray instances as leaves. - filepath_stem: Filesystem location for data. - data_suffix: Suffix to use for the data file. - manifest_suffix: Either `None`, in which case no manifest-file - will get written, or the suffix for the manifest-file. - key: `None` or a key-function that will be applied to the linear model-path - and used for sorting the data arrays by increasing value of the - key-function. If the key-function returns `None` on an item, - then this item is not included. - array_transform_by_pytree_key: Optional mapping from pytree-key - to an array-to-array transformation function to apply to the array - prior to serialization. - report: Optional callable for logging progress-reports. - byte_align: byte-alignment to use for numerical array data. - Numerical arrays whose size in bytes is not a multiple of this - will get padded to the next full multiple. - - Returns: - A pair of `(size, time_sec)`, where `size` is the total byte-size - of the `.data` file and `time_sec` is the elapsed time - for saving the model, in seconds. 
- """ - time0 = time.monotonic() - if array_transform_by_pytree_key is None: - array_transform_by_pytree_key = {} - model_lazy_items = ( - pytree_transforms.pytree_leaf_iter( - tree, _ml_model_extract_leaf_transform, - _ml_model_tree_node_handler)) - if key is not None: - to_write = [ - nkv[1:] for nkv in sorted( - (nkv for nkv in ((key(path), path, v) - for path, v in model_lazy_items) - if nkv[0] is not None), key=lambda nkv: nkv[0])] - else: - to_write = list(model_lazy_items) - # - def lazy_arr_path_shape_dtype_size(path_and_lazy_arr): - path, lazy_arr = path_and_lazy_arr - arr = array_transform_by_pytree_key.get(path, _fn_identity)(lazy_arr()) - return path, arr.shape, arr.dtype, arr.data.nbytes - arrs_path_shape_dtype_nbytes = list( - map(lazy_arr_path_shape_dtype_size, to_write)) - # We need to know the total size of all the data. - bytesizes = [nbytes for *_, nbytes in arrs_path_shape_dtype_nbytes] - padded_bytesizes = [-(-bytesize // byte_align * byte_align) - for bytesize in bytesizes] - offsets = numpy.cumsum([0] + padded_bytesizes) - membuf = numpy.memmap(filepath_stem + data_suffix, - mode='w+', shape=offsets[-1]) - try: - for (path, shape, dtype, nbytes), offset, (_, lazy_arr) in zip( - arrs_path_shape_dtype_nbytes, offsets, to_write): - # Note that if getting the array from the lazy lambda involved some - # computation, such as a copying dtype-change, that computation would - # end up being done multiple times here - including once above, to compute - # byte-sizes, and once more here. - transformed_arr = array_transform_by_pytree_key.get( - path, - _fn_identity)(lazy_arr()) - membuf[offset : offset + nbytes] = numpy.frombuffer( - transformed_arr.ravel().data, 'u1') - if report is not None: - samples = ', '.join(map(str, transformed_arr.ravel()[:5])) - report(f'# Adding: {path!r}\n bytes: {nbytes:10d}, ' - f'shape: {shape!r:30},\n start: [{samples}, ...]') - transformed_arr = None # Drop memory references to numerical arrays ASAP. - finally: - if membuf is not None: - membuf.flush() - # NumPy wart: the memory-buffer is a resource that conceptually - # should be .close()able - since mmap()ing holds on to a - # file descriptor. However, it looks as if that clean-up were done - # in the "finalizer", despite that having meanwhile been widely - # understood as dubious practice. So, the best we can do here is - # to explicitly and clearly remove our reference to the instance. - del membuf - if manifest_suffix is not None: - # We still have to serialize the data that allows us to reconstruct - # a tree that is equivalent to the original. - manifest_data = [ - dict(path=path, - dtype=dtype.descr[-1][-1], - shape=shape, - nbytes=nbytes, - offset=offset) - for (path, shape, dtype, nbytes), offset in zip( - arrs_path_shape_dtype_nbytes, offsets)] - with open(filepath_stem + '.manifest', 'wt') as h_manifest: - pprint.pprint(manifest_data, stream=h_manifest) - time_taken = time.monotonic() - time0 - return offsets[-1], time_taken - - -def model_load(filepath_stem, mmapped=True): - """Loads a model saved by `model_save`. - - Tries to load the model from f"{filepath_stem}.data" - and f"{filepath_stem}.manifest". - - Args: - filepath_stem: The model location on the filesystem. - mmapped: Whether data-arrays will be slices of a - `numpy.memmap` mapped buffer, to be paged in - on demand only, or in-memory copies of the data. - Returns: - A dict/numpy.ndarray tree representation of the model, - equivalent to the original model. 
- """ - with open(filepath_stem + '.manifest', 'rt') as h_manifest: - manifest = ast.literal_eval(h_manifest.read()) - membuf = numpy.memmap(filepath_stem + '.data', mode='r+') - paths_and_arrays = [] - for item in manifest: - path = item['path'] - dtype = numpy.dtype(item['dtype']) - shape = item['shape'] - nbytes = item['nbytes'] - offset = item['offset'] - data_array = numpy.frombuffer(membuf[offset : offset + nbytes].data, - dtype=dtype).reshape(shape) - paths_and_arrays.append( - (path, - data_array if mmapped else data_array.copy())) - # At this point, the memory-buffer is no longer needed. Still, if - # data-arrays retain references to the underlying data - # (i.e. when mmapped=False), this should keep the mapping - # - and hence file descriptor - open. We then are in a somewhat - # undesirable situation of clean-up of a resource that happens in a - # hard-to-predict way releasing a file descriptor. - del membuf - return revtuple_autovifify_from_linear(paths_and_arrays) diff --git a/compression/python/pytree/ml_model_transforms_test.py b/compression/python/pytree/ml_model_transforms_test.py deleted file mode 100644 index 9495c87..0000000 --- a/compression/python/pytree/ml_model_transforms_test.py +++ /dev/null @@ -1,92 +0,0 @@ -"""Basic tests for 'algebraic data type based pytree' transformations.""" - - -import io -import os -import tempfile -import unittest - -import numpy - -import ml_model_transforms - - -def _get_model(prefix): - return { - prefix + 'a1': numpy.arange(1000, 1024).reshape(6, 4).astype(numpy.float32), - prefix + 'a2': numpy.arange(2000, 2048).reshape(6, 8).astype(numpy.float32), - prefix + 'b1': { - prefix + 'c1': numpy.arange(100, 127).reshape(3, 3, 3).astype(numpy.int8), - prefix + 'c2': numpy.arange(100, 128).reshape(7, 4).astype(numpy.float64) - }} - - -class MLModeltransformsTest(unittest.TestCase): - """Basic correctness validation tests for ML-model transformations.""" - - def test_ml_model_leaf_summary(self): - """Tests guarantees given by `ml_model_leaf_summary`.""" - summary = ml_model_transforms.ml_model_leaf_summary( - ('a', ()), - numpy.arange(1000, 1024).reshape(6, 4).astype(numpy.int16), - sep='##') - self.assertIn('##', summary) # Separator is respected. - self.assertIn('(6, 4)', summary) # Shape is mentioned somewhere. - self.assertIn('int16', summary) # dtype is mentioned somewhere. 
- - def test_revtuple_autovivify_from_linear(self): - """Tests guarantees given by `revtuple_autovifify_from_linear`.""" - with self.subTest(guarantee='empty'): - self.assertEqual( - ml_model_transforms.revtuple_autovifify_from_linear([]), - {}) - with self.subTest(guarantee='generic'): - keys_vals = [(('a', 'b1', 'c1'), 1001), - (('a', 'b2'), 1002), - (('a2',), 1003), - ] - self.assertEqual( - ml_model_transforms.revtuple_autovifify_from_linear(keys_vals), - {'a': {'b1': {'c1': 1001}, 'b2': 1002}, 'a2': 1003}) - - def test_model_overview(self): - """Tests guarantees given by `model_overview`.""" - model = _get_model('xyz') - out_io = io.StringIO() - ml_model_transforms.model_overview(model, out=out_io) - overview = out_io.getvalue() - self.assertIn('xyz', overview) - - def test_model_contents(self): - """Tests guarantees given by `model_contents`.""" - model = _get_model('pq_') - contents = ml_model_transforms.model_contents(model) - fingerprints = {k: (a.shape, a.ravel()[:3].tolist()) - for k, a in contents.items()} - self.assertEqual(fingerprints, - {('pq_a1',): ((6, 4), [1000.0, 1001.0, 1002.0]), - ('pq_a2',): ((6, 8), [2000.0, 2001.0, 2002.0]), - ('pq_b1', 'pq_c1'): ((3, 3, 3), [100, 101, 102]), - ('pq_b1', 'pq_c2'): ((7, 4), [100.0, 101.0, 102.0])}) - - def test_model_save_load_basic(self): - """Tests basic guarantees given by `model_save` and `model_load`.""" - # What we care about here is that the round trip works - so - # it makes more sense to test saving and loading as one unit. - model_orig = _get_model('model_') - with tempfile.TemporaryDirectory() as tempdir: - filepath_stem = os.path.join(tempdir, 'the_model') - total_size, total_time = ml_model_transforms.model_save(model_orig, - filepath_stem) - self.assertGreater(total_size, 0) - self.assertGreater(total_time, 0) - model_reloaded = ml_model_transforms.model_load(filepath_stem) - contents_orig = ml_model_transforms.model_contents(model_orig) - contents_reloaded = ml_model_transforms.model_contents(model_reloaded) - self.assertEqual( - {k: v.tolist() for k, v in contents_orig.items()}, - {k: v.tolist() for k, v in contents_reloaded.items()}) - - -if __name__ == '__main__': - unittest.main() diff --git a/compression/python/pytree/pytree_transforms.py b/compression/python/pytree/pytree_transforms.py deleted file mode 100644 index 7e065af..0000000 --- a/compression/python/pytree/pytree_transforms.py +++ /dev/null @@ -1,508 +0,0 @@ -"""Tools for transforming "nested python object" tree data structures. - -# Context - -The motivation for this module came from ML applications that ought to -be based on a principled handling of nested Python data structures. -Having such principled pytree-transforming code available solves -some other problems, such as doing away with a need to abuse -tree-mapping for-side-effect-only and having to use a hope-and-pray -approach to processing very deeply nested values which with a recursive -approach might trigger a RecursionError. 
- -We specifically want to cover the use case of having ML model -parameters that are available in a nested Python data structure for -which there "almost" is a unique-up-to-unique-isomorphism mapping from -and to this Algebraic Data Type: - -`data ModelParams a = Array a | Node [(String, ModelParams a)]` - -In this correspondence, `a` is some array-type (perhaps -`numpy.ndarray`, `jax.numpy.ndarray`, `tf.tensor`, etc.), but the -data-processing code is effectively entirely agnostic to this, and a -`Node` is "almost" an associative-list of (key, value) pairs -representing a Python dict. (Note: The "almost" here is mostly about -the conceptual wart that assoc-lists can in principle have key -duplicates, but Python dicts can not. This is however not a problem -since all we need is the transformation in one direction, -i.e. whatever data-processing `f` we want to express on the -model-parameters-pytree, we can express by specifying a "faithful" -mapping `m` into the above algebraic data type through which every -such pytree data transform factorizes, i.e. for every `f` we can find -a `g` such that `f(p) = g(m(p))`.) - -## Components - -The main workhorse in this module is the `pytree_iter` function that -maps a "PyTree (such as representing `ModelParams`)" to an iterator -over values obtained by applying a mapping-function to the "key-path" -and leaf-value for every leaf, where the "key-path" contains a -linked-list representation of the reversed sequence of keys from the -tree-root, with list-nodes being represented by pairs -`(latest_dict_key, rest_path)`, and the empty path being represented -by `()`. - -For the sake of genericity, `pytree_iter` is built in such a way that -it actually can handle any kind of traversal of PyTree-trees that do -represent algebraic data types (note however that some some do not) - -but for this to make sense, the user must have a way to define how to -interpret tree-nodes, in particular identify leaves. This requires -providing a function `node_handler` with the same signature and -behavior as described below for "node handlers". - -Additionally, this module provides mapping-over-pytrees via -`pytree_map`, which is also built in such a way that it makes the -correspondence between an algebraic data type and its Python -nested-tree representation explicit. Despite being powerful and -flexible, this, however, may in general require a bit more effort to -wire up, since node-rebuilding can be fairly nontrivial. - -Furthermore, as a prominent application, this module provides a simple -deep-freezing function that translates a nested Python data structure -to deeply-immutable form. - -## Concepts and Conventions - -"revtuple representation": - - As we iterate over a tree, we will have to keep track of the - path-to-tree-root. Naturally, two sibling nodes `n1` and `n2` - will share the same parent-path (being siblings), so it makes - sense to use a linked-list-with-shared-tail representation. - Python does not have a natural notion for that, so we use - recursively-constructed tuples `(node_tag, parent_path)` - that represent the path-from-root in-reverse-order, i.e. - for a non-empty path `p`, `p[0]` is the node-tag at the - deepest nesting level. We call this a "revtuple representation" - of the path. - -"node handler": - - A node-handler classifies a tree-node as "leaf or other node", and - for non-leaf nodes provides information about both its children and - how to rebuild it. 
The behavior of a node-handler function must be - in alignment with this docstring: - - '''Processes a tree-node as required by pytree-iteration and -mapping. - - Args: - revtuple_path: Revtuple-representation of the path-from-root - to the current node. - node: a tree-node in a ML-model tree that is recursively - built out of leaf-values and other nodes. - - Returns: - `None` if the tree-node is to be regarded as a leaf, otherwise - a pair `(rebuilder, iterator)`, where `iterator` iterates - over the data-content of the node, each item represented as a pair - of `(lookup_path_item, value_item)`, and `rebuilder` is a function - which, when applied to an iterable of the aforementioned value-items - (or some transformation thereof) returns a node that is equivalent - to the original (or up to a transformation of the contents). - - Raises: - InvalidTreeNodeError: If the tree contains a node of a kind - that is not expected to show up. - ''' - - Examples: - - (The behavior of a node-handler is somewhat nontrivial, so covering - two very common cases via examples is in order.) - - This node-handler would allow descending into (nested) - instances of `list` (but not subclass instances thereof): - - ```def list_node_handler(revtuple_path, obj): - ''' ... ''' - if type(obj) is list: - return list, enumerate(obj) - else: - return None - ``` - - This node-handler would allow descending into (nested) mappings, - which upon rebuilding would get turned into `dict` instances: - - ```def mapping_node_handler(revtuple_path, obj): - ''' ... ''' - if isinstance(obj, collections.abc.Mapping): - # For generic mappings, we cannot rely on key- and item-iteration - # being guaranteed to use identical iteration-order. - items = list(obj.items()) - keys = [kv[0] for kv in items] - return (lambda values: dict(zip(keys, values))), items - else: - return None - ``` - - A dict/mapping node-handler can of course rename keys, add or remove - entries, make decisions based on the item-path, or map a dict to - an associative list, etc. - -## Further Design Notes - -The `pytree_map` function requests the leaf-transform and node-handler -to be side-effect-free functions. This is both required to leave -implementation-side flexibility, and also follows the general LISP -recommendation to not abuse mapping (which should be a pure -data-transformation) for imperative data processing. Overall, if -a need for more general "nested datastructures" processing becomes -pressing, it is for the better if this leads to a proper articulation -of the specific needs, to be addressed with appropriate design, rather -than abuse of functional data-transforms becoming "a bad idiom -that turned into established practice". - -""" - -import collections.abc -import immutabledict - -import numpy - -from typing import Any, Callable, Iterable, Iterator, TypeVar - - -T = TypeVar('T') -U = TypeVar('U') - -KT = TypeVar('KT') -NT = TypeVar('NT') - - -## Type of the reverse-order-keys-to-root path. -# (This code actually illustrates why https://xkcd.com/2483/ is very misguided.) -RevTuplePath = tuple - -## Type of the `leaf_transform` function-argument used for tree-iteration. 
-# -# This would be the correct type we would have to specify here but cannot, -# since the design of Python's static typing at the time of this writing -# is too broken for that: -# -# type LeafTransformFunc[L, R] = Callable[[RevTuplePath, L], R] -# -# Instead, we have to settle for...: -LeafTransformFunc = Callable[[RevTuplePath, Any], Any] - - -## Type of the `tree_node_handler` function-argument used for -## tree-iteration and tree-mapping. -# -# Again, this is the correct type we would have to put here but cannot: -# -# type NodeHandlerFunc[KT] = ( -# Callable[[NT], -# None | tuple[Callable[[Iterable[tuple[KT, NT]]], NT], -# Iterator[tuple[KT, NT]]]]) -# -# ...so, we have to instead settle for: -NodeHandlerFunc = ( - Callable[[RevTuplePath, NT], - None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT], - Iterator[tuple[Any, NT]]]]) - - -Predicate = Callable[[object], bool] - - -class InvalidTreeNodeError(ValueError): - """Encountered a tree-node of invalid type.""" - - -def linearize_revtuple_path( - revtuple_path: RevTuplePath, - present_as: Callable[[Iterator[T]], U] = tuple) -> U: - """Translates a revtuple path to (typically) linear form. - - With default `present_as`, this will map a path of the form - `(key_{N}, (key_{N-1}, ..., (root, ())))` into a tuple - (root, ..., key_{N-1}, key_{N}). - - Args: - revtuple_path: A linked-list-as-recursive-pairs - reverse-order tuple-representation of the path. - Path-root is `()`, and node-key `x` relative to - earlier path `p` is represented as `(x, p)`. - present_as: Callable that consumes an iterator over - path-pieces - with the deepest-nesting level coming last - - turning it into a linearized path. Defaults to `tuple`. - - Returns: - Linearized presentation of all the node-keys in the - recursive-path in order, deepest-down path component coming last. - """ - pieces = [] - todo = revtuple_path - while todo: - node, todo = todo - pieces.append(node) - return present_as(reversed(pieces)) - - -# This function itself has type `NodeHandlerFunc`, but Python does not -# allow us to here simply type-annotate it like this. We cannot even -# introduce an abbreviation for the complicated output-type, -# since that would have to be parametric in node-type `NT` (and `KT`). -def everything_is_a_leaf_node_handler( - revtuple_path: tuple, - node : NT) -> ( - None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT], - Iterator[tuple[Any, NT]]]): - """Processes a tree-node as required by pytree-iteration and -mapping. - - Interface and signature are in alignment with the requirements for a - "node handler" function explained in the module-docstring. - - Args: - revtuple_path: the path-to-root for this node. - node: a tree-node. - - Returns: - `None`, i.e. classifying any kind of node as a leaf-node. - """ - del revtuple_path, node # Unused. - return None - - -def leaf_summary(path: RevTuplePath, x: object): - """Produces a human-readable summary-string for a leaf-node. - - Args: - path: revtuple representation of the path-to-root. - x: The leaf-value. - """ - del path # Ignored here. - tx = type(x) - mod = tx.__module__ - modname = mod if isinstance(mod, str) else mod.__name__ - type_str = f'{modname}.{tx.__qualname__}' - repr_str = repr(x) - repr_abbrev = repr_str if len(repr_str) < 40 else repr_str[:40] + ' ...' - # On str, int, float, etc. `{type_str}(repr(x))` would actually still be - # a (non-literal) Python-expression that would evaluate to the original value. - # However, we make no promises beyond "human-readable". 
- return f'{type_str}({repr_abbrev})' - - -# With respect to static type annotations, the limitations of Python's -# approach to static typing really become prominently visible here. -# -# Different arguments have type-parameters, but since there is no way -# to have parametric abbreviations such as `LeafTransformFunc[L, R]`, -# the only way we would have available to express relations between -# type-parameters would be to substitute in the not-abbreviated form of -# `NodeHandlerFunc` and `LeafTransformFunc`, giving us something monstrous. -# We instead here settle for "we cannot express that `tree` must -# have the same type as the input-type to `tree_node_handler` and use `Any`, -# and likewise for leaf_transform and the output. -def pytree_leaf_iter( - tree: Any, - leaf_transform: LeafTransformFunc, - node_handler: NodeHandlerFunc = everything_is_a_leaf_node_handler, - ) -> Iterator[Any]: - # ...actual return type would be `Iterator[{what leaf_transform returns}]`. - """Iterates over the leaves of a tree. - - Args: - tree: The tree to iterate over. - leaf_transform: A callable `f` that will get applied - as `f(revtuple_path, leaf)`, where `revtuple_path` - is the revtuple representation of the path to the - leaf from the root. - node_handler: A "node handler" (see module docstring) - that processes nodes encountered during iterative traversal. - - Yields: - Value of `leaf_transform(p, x)`, where `x` is the current leaf - and `p` is its revtuple-path to the root. - """ - # Note: Exit points for the code below are in non-obvious places - # and hence marked with " # ***EXIT***". - # - # Doing iteration properly is slightly nontrivial. - # One may be tempted to go for a very simple recursive implementation - # (with an extra pre-final `path` argument to `pytree_iter`): - # - # maybe_substructure = node_handler(path, tree) - # if maybe_substructure is None: - # # We are looking at a leaf-node. - # yield leaf_transform(path, tree) - # else: - # _, contents_iter = maybe_substructure - # for k, v in contents_iter: - # yield from pytree_iter(v, leaf_transform, (k, path), node_handler) - # - # That, however, would be flawed, since there is no a priori reason - # why a pytree may not be a very deeply nested structure - such as a - # long linked list. That would then risk raising `RecursionError`, - # and since Python by design(!) does not perform tail call elimination - # or any other kind of advanced CPS transforms, there is no recursive - # solution here. So, to do this properly, we have to do this iteratively. - # - # We are facing an annoying situation here: If `tree` itself is a leaf, - # we have two options: (a) wrapping it up in a one-node tree - # and processing that, or (b) special-casing "root is a leaf". - # Option (b) leads to some mild node-processing code-duplication - # for a single node (the root). - # Option (a) requires having special cases for node-processing that - # get looked at for every tree node. We go with option (b) here. - maybe_substructure = node_handler((), tree) - if maybe_substructure is None: - # The tree itself is a leaf. - yield leaf_transform((), tree) - return # ***EXIT*** - # Otherwise, we are looking at a tree. 
- _, contents_iter = maybe_substructure - current_revtuple_path = () - work_to_do = [contents_iter] - # Otherwise-unreachable sentinel for reliably identifying - # iterator-exhaustion without using exceptions: - sentinel = object() - while True: - current_iter = work_to_do[-1] - maybe_next_item = next(current_iter, sentinel) - if maybe_next_item is sentinel: - # We are done at this level. - work_to_do.pop() - if not work_to_do: return # ***EXIT*** - current_revtuple_path = current_revtuple_path[1] - else: - path_piece, subtree = maybe_next_item - extended_revtuple_path = (path_piece, current_revtuple_path) - maybe_subtree_substructure = node_handler(extended_revtuple_path, subtree) - if maybe_subtree_substructure is None: # Case: subtree is a leaf. - yield leaf_transform(extended_revtuple_path, subtree) - else: # Case: subtree is a tree. - current_revtuple_path = (path_piece, current_revtuple_path) - _, items_iter = maybe_subtree_substructure - work_to_do.append(items_iter) - - -# The current design approach here would be appropriate for -# applying leaf-transforms while retaining the structure of the tree - -# which closely corresponds to e.g. a (a -> b) -> (Tree a -> Tree b) functor. -# -# It is not entirely clear whether this is the abstraction that we should -# consider as being appropriately generic to flesh out explicitly - rather -# than starting from a more general approach of which this then is a special -# case. Some background: https://ncatlab.org/nlab/show/recursion+scheme -# -# On the other hand, there is a lot of flexibility via whatever -# node-rebuilder a node-handler produces - this can do quite some reshaping -# of a tree, including dropping or duplicating nodes. -def pytree_map( - tree: Any, - leaf_transform, - node_handler: NodeHandlerFunc = everything_is_a_leaf_node_handler, - ): - """Maps a (potentially nested) Python value to another such value. - - Args: - tree: The Python-object to be mapped. - leaf_transform: A callable `f` that will get applied - as `f(revtuple_path, leaf)`, where `revtuple_path` - is the revtuple representation of the path to the - leaf from the root. Must be side effect free. - node_handler: A "node handler" (see module docstring) - that processes nodes encountered during iterative traversal. - Must be side effect free. - - Returns: - The outcome of translating `tree`. - """ - # Note: Exit points for the code below are in non-obvious places - # and hence marked with " # ***EXIT***". - # - # Otherwise-inaccessible sentinel object, for reliably identifying - # missing-values via identity-check against sentinel lookup-default. - sentinel = object() - # Code structure mostly follows pytree_leaf_iter. - maybe_substructure = node_handler((), tree) - if maybe_substructure is None: - return leaf_transform((), tree) # ***EXIT*** - rebuilder, items_iter = maybe_substructure - current_revtuple_path = () - # Per-level, we have a triplet of: - # (rebuilder, remaining_items_to_iterate_over, processed). - parts_for_assembly = [(rebuilder, items_iter, [])] - while True: - this_rebuilder, this_items_iter, this_done_pieces = parts_for_assembly[-1] - maybe_next_item = next(this_items_iter, sentinel) - if maybe_next_item is sentinel: - # We are done with all the items for this level. - parts_for_assembly.pop() - built_iter = this_rebuilder(this_done_pieces) - if not parts_for_assembly: # No outer structure, so at-top-level. - return built_iter # ***EXIT*** - else: # We have outer structure. 
- parts_for_assembly[-1][-1].append(built_iter) - current_revtuple_path = current_revtuple_path[1] - continue # ...with next is-the-final-item-complete-check. - else: - # More constituents of the current item. - path_piece, subtree_item = maybe_next_item - extended_revtuple_path = (path_piece, current_revtuple_path) - maybe_subtree_substructure = node_handler( - extended_revtuple_path, - subtree_item) - if maybe_subtree_substructure is None: - this_done_pieces.append( - leaf_transform(extended_revtuple_path, subtree_item)) - else: - # We have a subtree. - subtree_rebuilder, subtree_items_iter = maybe_subtree_substructure - current_revtuple_path = (path_piece, - current_revtuple_path) - parts_for_assembly.append( - (subtree_rebuilder, subtree_items_iter, [])) - - -def deep_freeze( - tree, - *, - is_mapping : Predicate = lambda x: isinstance(x, collections.abc.Mapping), - is_set : Predicate = lambda x: isinstance(x, collections.abc.Set), - is_sequence : Predicate = lambda x: isinstance(x, (list, tuple)), - leaf_fn: Callable[[Any], Any] = lambda x: x, - ): - """Recursively freezes Set/Mapping/List/Tuple structures. - - Args: - tree: The potentially deeply-nested object to deep-freeze. - is_mapping: Callable that decides whether a sub-object is a mapping. - Defaults to an `isinstance()` check for `collections.abc.Mapping`. - is_set: Callable that decides whether a sub-object is a set. - Defaults to an `isinstance()` check for `collections.abc.Set`. - is_sequence: Callable that decides whether a sub-object is a sequence. - Defaults to a check for being a `tuple` or `list` instance. - leaf_fn: Function to use for translating non-mapping/set/sequence - instances. - - Returns: - Translated-to-deeply-immutable form of `tree`. - """ - idict = immutabledict.immutabledict - def freeze_node_handler(path, x): - if is_set(x): - return frozenset, ((None, y) for y in x) - if is_mapping(x): - # Mappings already have hashable, so - # (should-be-)deeply-immutable keys. - # Hence, we only need to deep-freeze the values. - # - # Note that non-`dict` mappings might not guarantee - # to respect iteration-order, so we have to be careful here: - items = list(x.items()) - keys = [kv[0] for kv in items] - values = [kv[1] for kv in items] - return ((lambda ys: idict(zip(keys, ys))), - iter(items)) - if is_sequence(x): - return tuple, enumerate(iter(x)) - # Otherwise, this should not be traversed. - return None - def leaf_transform(revtuple_path, value): - del revtuple_path # Unused. - return leaf_fn(value) - return pytree_map(tree, leaf_transform, freeze_node_handler) diff --git a/compression/python/pytree/pytree_transforms_test.py b/compression/python/pytree/pytree_transforms_test.py deleted file mode 100644 index fdaec71..0000000 --- a/compression/python/pytree/pytree_transforms_test.py +++ /dev/null @@ -1,168 +0,0 @@ -"""Basic tests for 'algebraic data type based pytree' transformations.""" - - -import collections.abc -import sys -import unittest - -import pytree_transforms - - -def _get_deep_pytree(packaging_fn, bottom, depth): - current = bottom - for n in reversed(range(depth)): - current = packaging_fn(n, current) - return current - - -def _dict_node_handler(p, d): - del p # Unused. 
- if isinstance(d, dict): - keys = d.keys() - newdict = lambda vals: dict(zip(keys, vals)) - return (newdict, iter(d.items())) - else: - return None - - -class PyTreeTest(unittest.TestCase): - """Basic correctness validation tests for PyTree transformations.""" - - def test_linearize_revtuple_path(self): - """Tests guarantees given by `linearize_revtuple_path`.""" - linearize_revtuple_path = pytree_transforms.linearize_revtuple_path - with self.subTest(guarantee='empty'): - self.assertEqual(linearize_revtuple_path(()), ()) - with self.subTest(guarantee='typical'): - self.assertEqual(linearize_revtuple_path((30, (20, (10, ())))), - (10, 20, 30)) - with self.subTest(guarantee='present_as'): - self.assertEqual( - linearize_revtuple_path( - (30, (20, (10, ()))), present_as=list), - [10, 20, 30]) - - def test_everything_is_a_leaf_node_handler(self): - """Tests guarantees given by `everything_is_a_leaf_node_handler`.""" - everything_is_a_leaf_node_handler = ( - pytree_transforms.everything_is_a_leaf_node_handler) - self.assertEqual(everything_is_a_leaf_node_handler((), 'abc'), - None) - self.assertEqual(everything_is_a_leaf_node_handler(('b', ()), - dict(a=3)), - None) - - def test_leaf_summary(self): - """Tests guarantees given by `leaf_summary`.""" - # Since the docstring only guarantees "a human-readable presentation", - # we can and should only do loose checks. - thing = (5678, 9531) - summary = pytree_transforms.leaf_summary(('key', ()), thing) - self.assertIsInstance(summary, str) - self.assertIn(str(thing[0]), summary) - self.assertIn(str(thing[1]), summary) - - def test_pytree_leaf_iter(self): - """Tests guarantees given by `pytree_leaf_iter`.""" - pytree_leaf_iter = pytree_transforms.pytree_leaf_iter - def leaf_transform(path, leaf): - return repr(leaf) if path and path[0].startswith('R') else leaf - with self.subTest(guarantee='returns_iterator'): - result = pytree_leaf_iter(7, leaf_transform, _dict_node_handler) - self.assertIsInstance(result, collections.abc.Iterator) - with self.subTest(guarantee='totally_empty'): - result = list(pytree_leaf_iter({}, leaf_transform, _dict_node_handler)) - self.assertEqual(result, []) - with self.subTest(guarantee='no_leaves'): - result = list(pytree_leaf_iter(dict(a={}), - leaf_transform, _dict_node_handler)) - self.assertEqual(result, []) - with self.subTest(guarantee='is_leaf'): - result = list(pytree_leaf_iter(777, leaf_transform, _dict_node_handler)) - self.assertEqual(result, [777]) - with self.subTest(guarantee='generic'): - result = list(pytree_leaf_iter( - dict(n0=dict(n01=dict(n012=1002, - n013=1003, - Rn014=1004, - ), - n02=1005), - n5=1006), - leaf_transform, _dict_node_handler)) - self.assertEqual(result, [1002, 1003, '1004', 1005, 1006]) - with self.subTest(guarantee='with_keys'): - result = list(pytree_leaf_iter( - dict(n0=dict(n01=dict(n012=1002, - n013=1003)), - n1=1004), - lambda p, s: (pytree_transforms.linearize_revtuple_path(p), s), - _dict_node_handler)) - self.assertEqual(result, - [(('n0', 'n01', 'n012'), 1002), - (('n0', 'n01', 'n013'), 1003), - (('n1',), 1004)]) - - def test_pytree_map(self): - """Tests guarantees given by `pytree_map`.""" - pytree_map = pytree_transforms.pytree_map - leaf_transform = lambda p, s: repr(s) - tree1 = dict(t0=dict(t10=1001, - t11=dict(t110=1002, - t111=1003), - t12=dict(t120=1004, - t121=1005, - t122=1006)), - t1=1007) - with self.subTest(guarantee='no_leaves'): - result = pytree_map(dict(a={}), - leaf_transform, - _dict_node_handler) - self.assertEqual(result, dict(a={})) - with 
self.subTest(guarantee='is_leaf'): - result = pytree_map(777, leaf_transform, _dict_node_handler) - self.assertEqual(result, '777') - with self.subTest(guarantee='generic'): - result = pytree_map(tree1, leaf_transform, _dict_node_handler) - self.assertEqual(result['t0']['t10'], '1001') - - def test_deeply_nested(self): - """Tests correct behavior on deeply-nested data structures.""" - pytree_leaf_iter = pytree_transforms.pytree_leaf_iter - pytree_map = pytree_transforms.pytree_map - # - depth = max(10**5, sys.getrecursionlimit() + 100) - deep_tree = _get_deep_pytree(lambda n, t: {n: t}, - 'leaf', depth) - with self.subTest(function='pytree_leaf_iter'): - leaves = list(pytree_leaf_iter(deep_tree, - lambda p, s: s.upper(), - _dict_node_handler)) - self.assertEqual(leaves, ['LEAF']) - with self.subTest(function='pytree_map'): - mapped_deep_tree = pytree_map(deep_tree, - lambda p, s: s, - _dict_node_handler) - self.assertIsInstance(mapped_deep_tree, dict) - with self.subTest(function='combined'): - leaves = list( - pytree_leaf_iter( - pytree_map(deep_tree, - lambda p, s: s.capitalize(), - _dict_node_handler), - lambda p, s: s + s, - _dict_node_handler)) - self.assertEqual(leaves, ['LeafLeaf']) - - def test_deep_freeze(self): - """Tests guarantees given by `deep_freeze`.""" - frozen = pytree_transforms.deep_freeze( - dict(a=[1001, 1002, dict(b=(1003, [1004, {1005, 1006}]))])) - self.assertIsInstance(frozen, collections.abc.Mapping) - self.assertNotIsInstance(frozen, collections.abc.MutableMapping) - self.assertIsInstance(frozen['a'], tuple) - # `frozen` is hashable, and hashes to an integer. - self.assertIsInstance(hash(frozen), int) - - -if __name__ == '__main__': - unittest.main() diff --git a/compression/python/pytree/requirements.txt b/compression/python/pytree/requirements.txt deleted file mode 100644 index 90c3f39..0000000 --- a/compression/python/pytree/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -immutabledict>=4.2.0 -numpy>=1.26.4 -orbax-checkpoint>=0.0.0 - diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc index 26313c1..04eb20a 100644 --- a/evals/gemma_test.cc +++ b/evals/gemma_test.cc @@ -158,9 +158,6 @@ TEST_F(GemmaTest, CrossEntropySmall) { float entropy = s_env->CrossEntropy(kSmall); fprintf(stderr, "per-token entropy: %f\n", entropy); switch (config.model) { - case gcpp::Model::GRIFFIN_2B: - EXPECT_NEAR(entropy, 2.61f, 0.02f); - break; case gcpp::Model::GEMMA2_2B: EXPECT_NEAR(entropy, 1.14f, 0.02f); break; diff --git a/gemma/activations.h b/gemma/activations.h index 63b3153..67e1eba 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -31,32 +31,6 @@ namespace gcpp { -struct GriffinActivations { - GriffinActivations(const ModelConfig& config, size_t batch_size, - const Allocator& allocator) - : griffin_x( - MatFactory("griffin_x", batch_size, config.model_dim, allocator)), - griffin_y( - MatFactory("griffin_y", batch_size, config.model_dim, allocator)), - griffin_gate_x(MatFactory("griffin_gate_x", batch_size, - config.model_dim, allocator)), - griffin_multiplier(MatFactory("griffin_mul", batch_size, - config.model_dim, allocator)) {} - - void SetBatchSize(size_t batch_size) { - if (griffin_x.Rows() == 0) return; - griffin_x.OverrideRows(batch_size); - griffin_y.OverrideRows(batch_size); - griffin_gate_x.OverrideRows(batch_size); - griffin_multiplier.OverrideRows(batch_size); - } - - MatStorageT griffin_x; - MatStorageT griffin_y; - MatStorageT griffin_gate_x; - MatStorageT griffin_multiplier; -}; - struct AttentionActivations { // Returns the scale value to use 
for the query in the attention computation. // Also called by ops_test. @@ -143,7 +117,7 @@ struct AttentionActivations { MatStorageT inv_timescale_global; hwy::Divisor div_seq_len; - // Unfortunately, some models (Griffin) have non-power-of-two heads. + // Unfortunately, some models have had non-power-of-two heads. hwy::Divisor div_heads; float query_scale; }; @@ -169,9 +143,7 @@ struct Activations { MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)), attention(config, layer_config, batch_size, seq_len, ctx.allocator, - row_ptrs), - griffin(config, config.model == Model::GRIFFIN_2B ? batch_size : 0, - ctx.allocator) { + row_ptrs) { HWY_ASSERT(batch_size != 0); // For MatMul outputs, precompute their row pointers. @@ -199,7 +171,6 @@ struct Activations { ffw_out.OverrideRows(batch_size); attention.SetBatchSize(batch_size); - griffin.SetBatchSize(batch_size); } const LayerConfig& layer_config; @@ -215,7 +186,6 @@ struct Activations { MatStorageT ffw_out; AttentionActivations attention; - GriffinActivations griffin; }; } // namespace gcpp diff --git a/gemma/attention.cc b/gemma/attention.cc index 8afd561..31ed4d1 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -327,6 +327,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer, MatMulEnv& env) { PROFILER_ZONE("Gen.Attention.SumHeads"); const LayerConfig& layer_config = layer.layer_config; + (void)layer_config; // For HWY_DASSERT // att_weights and att_out are concatenated heads, each of length // layer_config.qkv_dim. Thus the [num_interleaved, // layer_config.model_dim] matmul output is the sum over heads. Compare @@ -334,10 +335,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer, // encoded) HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 && layer_config.qkv_dim != 0); - const float* add = layer_config.softmax_attn_output_biases - ? 
layer.attention_output_biases.PackedScale1() - : nullptr; - CallMatMul(activations.att_out, layer.att_weights, add, env, + CallMatMul(activations.att_out, layer.att_weights, /*add=*/nullptr, env, activations.att_sums); } diff --git a/gemma/configs.cc b/gemma/configs.cc index f19d30d..8856203 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -133,78 +133,6 @@ static ModelConfig ConfigGemma2_2B() { return config; } -static LayerConfig LayerConfigGemmaTiny(size_t model_dim) { - LayerConfig config; - config.model_dim = model_dim; - config.ff_hidden_dim = 256; - config.heads = 4; - config.kv_heads = 1; - config.qkv_dim = 16; - return config; -} - -static ModelConfig ConfigGemmaTiny() { - ModelConfig config = ConfigNoSSM(); - config.display_name = "GemmaTiny"; - config.model = Model::GEMMA_TINY; - config.wrapping = PromptWrapping::GEMMA_IT; - config.model_dim = 32; - config.vocab_size = 32; // at least two f32 vectors - config.max_seq_len = 32; - LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim); - config.num_layers = 2; - config.layer_configs = {config.num_layers, layer_config}; - config.query_scale = QueryScaleType::SqrtKeySize; - config.attention_window_sizes = FixedAttentionWindowSizes<2>(32); - config.att_cap = 50.0f; - config.final_cap = 30.0f; - config.eos_id = 11; - config.secondary_eos_id = 11; - return config; -} - -static LayerConfig LayerConfigGriffin2B(size_t model_dim) { - LayerConfig config; - config.model_dim = model_dim; - config.griffin_dim = model_dim; - config.ff_hidden_dim = 7680; - config.heads = 10; - config.kv_heads = 1; - config.qkv_dim = 256; - config.conv1d_width = 4; - HWY_DASSERT(config.conv1d_width <= kMaxConv1DWidth); - config.ff_biases = true; - config.softmax_attn_output_biases = true; - config.optimized_gating = false; - config.type = LayerAttentionType::kGriffinRecurrentBlock; - config.activation = ActivationType::Gelu; - config.post_qk = PostQKType::HalfRope; - return config; -} - -static ModelConfig ConfigGriffin2B() { - ModelConfig config = ConfigNoSSM(); - config.display_name = "Griffin2B"; - config.model = Model::GRIFFIN_2B; - // Griffin uses local attention, so max_seq_len is actually the local - // attention window. 
- config.model_dim = 2560; - config.vocab_size = kVocabSize; - config.max_seq_len = 2048; - LayerConfig layer_config = LayerConfigGriffin2B(config.model_dim); - config.num_layers = 26; - config.layer_configs = {config.num_layers, layer_config}; - for (size_t i = 2; i < config.num_layers; i += 3) { - config.layer_configs[i].type = LayerAttentionType::kGemma; - config.layer_configs[i].griffin_dim = 0; - } - config.attention_window_sizes = - FixedAttentionWindowSizes<26>(config.max_seq_len); - config.use_local_attention = true; - config.final_cap = 0.0f; - return config; -} - static LayerConfig LayerConfigVit(size_t model_dim) { LayerConfig config; config.model_dim = model_dim; @@ -510,10 +438,6 @@ static ModelConfig ConfigFromModel(Model model) { return ConfigGemma2_9B(); case Model::GEMMA2_27B: return ConfigGemma2_27B(); - case Model::GRIFFIN_2B: - return ConfigGriffin2B(); - case Model::GEMMA_TINY: - return ConfigGemmaTiny(); case Model::PALIGEMMA2_3B_224: return ConfigPaliGemma2_3B_224(); case Model::PALIGEMMA2_3B_448: @@ -547,10 +471,6 @@ const char* ModelPrefix(Model model) { return "9b"; case Model::GEMMA2_27B: return "27b"; - case Model::GRIFFIN_2B: - return "gr2b"; - case Model::GEMMA_TINY: - return "tiny"; case Model::PALIGEMMA2_3B_224: return "paligemma2-3b-224"; case Model::PALIGEMMA2_3B_448: @@ -750,13 +670,10 @@ bool ModelConfig::OverwriteWithCanonical() { Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) { switch (layers) { - case 2: - return Model::GEMMA_TINY; case 18: return Model::GEMMA3_270M; case 26: - if (layer_types & kDeducedGriffin) return Model::GRIFFIN_2B; if (layer_types & kDeducedViT) return Model::GEMMA3_1B; return Model::GEMMA2_2B; case 27: diff --git a/gemma/configs.h b/gemma/configs.h index 0c93e30..e4a26b8 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -68,14 +68,11 @@ static inline bool EnumValid(PromptWrapping wrapping) { enum class LayerAttentionType { kGemma, - kGriffinRecurrentBlock, kVit, }; static inline bool EnumValid(LayerAttentionType type) { - return type == LayerAttentionType::kGemma || - type == LayerAttentionType::kGriffinRecurrentBlock || - type == LayerAttentionType::kVit; + return type == LayerAttentionType::kGemma || type == LayerAttentionType::kVit; } // Post attention and ffw normalization type. @@ -163,9 +160,8 @@ enum class Model { // 1 and 2 are obsolete. GEMMA2_9B = 3, GEMMA2_27B, - GRIFFIN_2B, - GEMMA_TINY, // for testing only - GEMMA2_2B, + // 5 and 6 are obsolete. + GEMMA2_2B = 7, // 8 and 9 are obsolete. PALIGEMMA2_3B_224 = 10, PALIGEMMA2_3B_448, @@ -199,13 +195,19 @@ static inline bool IsPaliGemma(Model model) { return false; } +static inline bool IsObsolete(Model model) { + const size_t i = static_cast(model); + if (i == 5 || i == 6 || i == 8 || i == 9) return true; + return false; +} + // Visits every valid model enum, skipping `UNKNOWN` and `kSentinel`. 
template void ForEachModel(const Func& func) { for (size_t i = static_cast(Model::GEMMA2_9B); i < static_cast(Model::kSentinel); ++i) { - if (i == 8 || i == 9) continue; - func(static_cast(i)); + const Model model = static_cast(i); + if (!IsObsolete(model)) func(model); } } @@ -214,7 +216,7 @@ static inline bool EnumValid(Model model) { if (model == Model::UNKNOWN) return true; const size_t i = static_cast(model); if (i >= static_cast(Model::GEMMA2_9B) && - i < static_cast(Model::kSentinel) && i != 8 && i != 9) { + i < static_cast(Model::kSentinel) && !IsObsolete(model)) { return true; } return false; @@ -235,15 +237,20 @@ struct LayerConfig : public IFields { // Source of truth for field ordering. void VisitFields(IFieldsVisitor& visitor) override { + // Formerly used for Griffin. + uint32_t unused_griffin_dim = 0; + uint32_t unused_conv1d_width = 0; + bool unused_softmax_attn_output_biases = false; + visitor(model_dim); - visitor(griffin_dim); + visitor(unused_griffin_dim); visitor(ff_hidden_dim); visitor(heads); visitor(kv_heads); visitor(qkv_dim); - visitor(conv1d_width); + visitor(unused_conv1d_width); visitor(ff_biases); - visitor(softmax_attn_output_biases); + visitor(unused_softmax_attn_output_biases); visitor(optimized_gating); visitor(post_norm); visitor(type); @@ -263,14 +270,11 @@ struct LayerConfig : public IFields { bool IsMHA() const { return heads == kv_heads; } uint32_t model_dim = 0; - uint32_t griffin_dim = 0; uint32_t ff_hidden_dim = 0; uint32_t heads = 0; uint32_t kv_heads = 0; - uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous). - uint32_t conv1d_width = 0; // Griffin only + uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous). bool ff_biases = false; - bool softmax_attn_output_biases = false; // for Griffin bool optimized_gating = true; // for Gemma3 PostNormType post_norm = PostNormType::None; LayerAttentionType type = LayerAttentionType::kGemma; @@ -358,7 +362,8 @@ struct ModelConfig : public IFields { visitor(final_cap); visitor(absolute_pe); - visitor(use_local_attention); + bool unused_use_local_attention = false; // formerly used for Griffin + visitor(unused_use_local_attention); visitor(query_scale); visitor(layer_configs); visitor(attention_window_sizes); @@ -454,7 +459,6 @@ struct ModelConfig : public IFields { float final_cap = 0.0f; bool absolute_pe = false; - bool use_local_attention = false; // Griffin only QueryScaleType query_scale = QueryScaleType::SqrtKeySize; std::vector layer_configs; std::vector attention_window_sizes; @@ -478,7 +482,6 @@ struct ModelConfig : public IFields { ModelConfig GetVitConfig(const ModelConfig& config); enum DeducedLayerTypes { - kDeducedGriffin = 1, kDeducedViT = 2, kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224. }; diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 0177c92..62288ff 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -34,7 +34,6 @@ // After highway.h #include "gemma/attention.h" // includes highway.h #include "gemma/gemma-inl.h" -#include "gemma/griffin.h" // includes highway.h #include "gemma/vit.h" // includes highway.h #ifndef GEMMA_CC_ONCE @@ -77,14 +76,6 @@ void Attention(LayerAttentionType type, const size_t num_tokens, GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch, env, /*flags=*/0); - } else { - HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock); - // KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer, - // so map `layer` to the Griffin layer index. 
- const size_t griffin_layer = - activations.attention.config.NumLayersOfTypeBefore(type, layer_idx); - GriffinRecurrent(num_tokens, griffin_layer, &layer, activations, qbatch, - env); } } @@ -484,13 +475,6 @@ static void GenerateT(const ModelConfig& config, const AesCtrEngine& engine, const WeightsPtrs& weights, Activations& activations, QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) { - // Griffin assumes that the recurrent block cache is zero-initialized. - for (size_t qi = 0; qi < qbatch.Size(); ++qi) { - if (qbatch.MutablePos(qi) == 0) { - qbatch.KV(qi).ZeroGriffinCache(); // No-op for non-Griffin models. - } - } - size_t max_prompt_size = 0; bool all_prefix_end_are_zero = true; size_t total_prefill_tokens = 0; // only for throughput stats. diff --git a/gemma/griffin.cc b/gemma/griffin.cc deleted file mode 100644 index 35bf29a..0000000 --- a/gemma/griffin.cc +++ /dev/null @@ -1,192 +0,0 @@ -// Copyright 2025 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "compression/types.h" // GEMMA_DISABLED_TARGETS -#ifndef HWY_DISABLED_TARGETS -#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS -#endif // HWY_DISABLED_TARGETS - -#include "gemma/activations.h" -#include "gemma/gemma.h" -#include "gemma/gemma_args.h" -#include "gemma/weights.h" -#include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/profiler.h" - -// Compiles this file for multiple architectures via "foreach_target.h", to -// which we pass the filename via macro 'argument'. -// clang-format off -#undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "gemma/griffin.cc" // NOLINT -// clang-format on -#include "hwy/foreach_target.h" // IWYU pragma: keep -#include "hwy/highway.h" -// After highway.h -#include "ops/matvec-inl.h" -#include "ops/ops-inl.h" - -HWY_BEFORE_NAMESPACE(); -namespace gcpp { -namespace HWY_NAMESPACE { - -void GriffinRecurrent(size_t num_tokens, size_t griffin_layer, - const LayerWeightsPtrs* layer_weights, - Activations& activations, QBatch& qbatch, - MatMulEnv& env) { - PROFILER_ZONE("Gen.Griffin"); - hwy::ThreadPool& pool = env.ctx.pools.Pool(0); - namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - const D df; - - const size_t model_dim = layer_weights->layer_config.model_dim; - HWY_DASSERT(model_dim % hn::Lanes(df) == 0); - - const size_t heads = layer_weights->layer_config.heads; - const size_t conv_1d_width = layer_weights->layer_config.conv1d_width; - HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even"); - const size_t kHeadDim = model_dim / heads; - const size_t kMatrixSize = kHeadDim * kHeadDim; - - const size_t num_interleaved = num_tokens * qbatch.Size(); - const hwy::Divisor div_qbatch(static_cast(qbatch.Size())); - GriffinActivations& griffin = activations.griffin; - - // X / Y linear layers. 
- // TODO: MatMul - HWY_DASSERT(griffin.griffin_y.Rows() == griffin.griffin_x.Rows()); - HWY_DASSERT(num_interleaved == griffin.griffin_y.Rows()); - CallUpcastedSame( - &layer_weights->griffin.linear_x_w, &layer_weights->griffin.linear_y_w, - [&](const auto* wx, const auto* wy) { - for (size_t r = 0; r < num_interleaved; ++r) { - float* HWY_RESTRICT y = griffin.griffin_y.Row(r); - float* HWY_RESTRICT x = griffin.griffin_x.Row(r); - TwoMatVecAdd( - *wx, *wy, 0, model_dim, model_dim, - activations.attention.pre_att_rms_out.Row(r), - /*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(), - /*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(), - /*out0=*/x, /*out1=*/y, pool); - Gelu(y, model_dim); - } - }); - - // Conv1D. - for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; - ++interleaved_idx) { - const size_t qi = div_qbatch.Remainder(interleaved_idx); - const size_t batch_idx = div_qbatch.Divide(interleaved_idx); - const size_t pos = qbatch.Pos(qi) + batch_idx; - float* HWY_RESTRICT x = griffin.griffin_x.Row(qi); - - // cache[i] = input at time t-i. - float* HWY_RESTRICT cache[kMaxConv1DWidth]; - cache[0] = x; - for (size_t i = 1; i < conv_1d_width; i++) { - cache[i] = - qbatch.KV(qi).conv1d_cache.Row(griffin_layer) + - ((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim; - } - for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) { - auto xv = hn::Load(df, x + i); - auto accum0 = - hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i); - auto accum1 = hn::Zero(df); - for (size_t l = 0; 2 * l < conv_1d_width; l++) { - auto wv0 = - hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() + - (conv_1d_width - 1 - 2 * l) * model_dim + i); - auto wv1 = - hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() + - (conv_1d_width - 2 - 2 * l) * model_dim + i); - accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0); - accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1); - } - hn::Store(hn::Add(accum0, accum1), df, x + i); - hn::Store(xv, df, cache[HWY_MAX(conv_1d_width, 1) - 1] + i); - } - } - - // RGLRU - for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; - ++interleaved_idx) { - const size_t qi = div_qbatch.Remainder(interleaved_idx); - const size_t batch_idx = div_qbatch.Divide(interleaved_idx); - const size_t pos = qbatch.Pos(qi) + batch_idx; - - float* HWY_RESTRICT x = griffin.griffin_x.Row(qi); - float* HWY_RESTRICT y = griffin.griffin_y.Row(qi); - float* HWY_RESTRICT gate_x = griffin.griffin_gate_x.Row(qi); - float* HWY_RESTRICT a = griffin.griffin_multiplier.Row(qi); - float* HWY_RESTRICT rnn_state = - qbatch.KV(qi).rglru_cache.Row(griffin_layer); - - pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { - size_t head_offset = head * kHeadDim; - CallUpcasted(&layer_weights->griffin.gate_w, [&](const auto* gate_w) { - TwoOfsMatVecAddLoop( - *gate_w, kMatrixSize * head, kMatrixSize * (heads + head), kHeadDim, - kHeadDim, x + head_offset, - /*add0=*/layer_weights->griffin.gate_biases.PackedScale1() + - head_offset, - /*add1=*/layer_weights->griffin.gate_biases.PackedScale1() + - model_dim + head_offset, - /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset); - }); - Sigmoid(gate_x + head_offset, kHeadDim); - Sigmoid(a + head_offset, kHeadDim); - const auto fn_mul = [](D d, hn::Vec x, hn::Vec gate_x) - HWY_ATTR { return hn::Mul(x, gate_x); }; - hn::Transform1(D(), a + head_offset, kHeadDim, - layer_weights->griffin.a.PackedScale1() + head_offset, - fn_mul); - 
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset, - fn_mul); - // RNN scan - HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0); - for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) { - auto log_a = hn::Load(df, a + head_offset + i); - auto gated_x = hn::Load(df, x + head_offset + i); - auto rnn = hn::Load(df, rnn_state + head_offset + i); - auto a = hn::Exp(df, log_a); - auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0f))); - if (pos == 0) { - x_multiplier = hn::Set(df, 1.0f); - } - auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn)); - hn::Store(new_x, df, rnn_state + head_offset + i); - - // Join branches. - auto yv = hn::Load(df, y + head_offset + i); - auto pre_out = hn::Mul(yv, new_x); - hn::Store(pre_out, df, x + head_offset + i); - } - }); - } // interleaved_idx - - // Final linear layer. - CallMatMul(griffin.griffin_x, layer_weights->griffin.linear_out_w, - layer_weights->griffin.linear_out_biases.PackedScale1(), env, - activations.attention.att_sums); -} // GriffinRecurrent - -// NOLINTNEXTLINE(google-readability-namespace-comments) -} // namespace HWY_NAMESPACE -} // namespace gcpp -HWY_AFTER_NAMESPACE(); diff --git a/gemma/griffin.h b/gemma/griffin.h deleted file mode 100644 index 0ba6a23..0000000 --- a/gemma/griffin.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2025 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_ - -// Declares GriffinRecurrent for all SIMD targets. - -#include - -#include "gemma/gemma.h" -#include "hwy/highway.h" - -namespace gcpp { - -// Passed to HWY_VISIT_TARGETS; declares for one target. -#define GEMMA_DECL_GRIFFIN(TARGET, NAMESPACE) \ - namespace NAMESPACE { \ - void GriffinRecurrent(size_t num_tokens, size_t griffin_layer, \ - const LayerWeightsPtrs* layer_weights, \ - Activations& activations, QBatch& qbatch, \ - MatMulEnv& env); \ - /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ - } // namespace NAMESPACE - -// Function declarations for each SIMD target. Allows direct call from the -// per-target namespace. We may later replace this with dynamic dispatch if -// the overhead is acceptable. 
-HWY_VISIT_TARGETS(GEMMA_DECL_GRIFFIN) - -#undef GEMMA_DECL_GRIFFIN - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GRIFFIN_H_ diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index 9d107e8..ca814f4 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -24,26 +24,6 @@ namespace gcpp { -void KVCache::ZeroGriffinCache() { - if (conv1d_cache.Rows() == 0) return; - ZeroInit(conv1d_cache); - ZeroInit(rglru_cache); -} - -static size_t GriffinLayers(const ModelConfig& config) { - return config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock); -} - -static size_t GriffinConv1dCols(const ModelConfig& config) { - size_t conv1d_width = 0; - for (const auto& layer_config : config.layer_configs) { - conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width); - } - // The row offset, in blocks of model_dim is computed mod (conv1d_width - 1), - // hence allocate conv1d_width * model_dim total columns. - return conv1d_width * config.model_dim; -} - // Number of rows for KV cache. Note that both rows and cols are u32, and // the total number of elements can exceed 2^32. static size_t CappedSeqLen(const ModelConfig& config, @@ -56,30 +36,18 @@ static size_t CappedSeqLen(const ModelConfig& config, return inference_args.seq_len; } -KVCache::KVCache(const Extents2D& conv1d_extents, - const Extents2D& rglru_extents, const Extents2D& kv_extents, - const Allocator& allocator) - : conv1d_cache("conv1d_cache", conv1d_extents, allocator, MatPadding::kOdd), - rglru_cache("rglru_cache", rglru_extents, allocator, MatPadding::kOdd), - kv_cache("kv", kv_extents, allocator, MatPadding::kOdd), +KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator) + : kv_cache("kv", kv_extents, allocator, MatPadding::kOdd), allocator_(allocator) {} KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, const Allocator& allocator) : KVCache( - Extents2D(GriffinLayers(config), GriffinConv1dCols(config)), - Extents2D(GriffinLayers(config), config.model_dim), Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()), allocator) {} KVCache KVCache::Copy() { - KVCache copy(conv1d_cache.Extents(), rglru_cache.Extents(), - kv_cache.Extents(), allocator_); - - if (conv1d_cache.Rows() != 0) { - CopyMat(conv1d_cache, copy.conv1d_cache); - CopyMat(rglru_cache, copy.rglru_cache); - } + KVCache copy(kv_cache.Extents(), allocator_); CopyMat(kv_cache, copy.kv_cache); diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index 7b5b88d..31e964b 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -35,24 +35,15 @@ struct KVCache { // copy ctor to make the cost explicit. KVCache Copy(); - // Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache - // and rglru_cache. 
- void ZeroGriffinCache(); - size_t SeqLen() const { return kv_cache.Rows(); } - // [griffin_layers, griffin_conv1d_cols * model_dim] - MatStorageT conv1d_cache; - MatStorageT rglru_cache; // [griffin_layers, model_dim] - MatStorageT kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2] private: const Allocator& allocator_; // For use by other ctor and Copy() - KVCache(const Extents2D& conv1d_extents, const Extents2D& rglru_extents, - const Extents2D& kv_extents, const Allocator& allocator); + KVCache(const Extents2D& kv_extents, const Allocator& allocator); }; } // namespace gcpp diff --git a/gemma/model_store.cc b/gemma/model_store.cc index 2aab1f5..a20caf2 100644 --- a/gemma/model_store.cc +++ b/gemma/model_store.cc @@ -221,9 +221,6 @@ static int DeduceLayerTypes(const BlobReader& reader) { int layer_types = 0; for (size_t key_idx = 0; key_idx < reader.Keys().size(); ++key_idx) { const std::string& key = reader.Keys()[key_idx]; - if (key.find("gr_conv_w") != std::string::npos) { // NOLINT - return kDeducedGriffin; - } if (key.find("qkv_ein_w") != std::string::npos) { // NOLINT layer_types |= kDeducedViT; } @@ -293,7 +290,7 @@ static std::vector ReadScales(BlobReader& reader, const ModelConfig& config) { std::vector scales; // Check first to prevent `CallWithSpan` from printing a warning. This blob is - // optional even in pre-2025 format; Griffin was the first to include it. + // optional even in pre-2025 format. if (reader.Find(kDecoratedScalesName)) { HWY_ASSERT(reader.CallWithSpan( kDecoratedScalesName, diff --git a/gemma/tensor_info.cc b/gemma/tensor_info.cc index de93cf9..05f829b 100644 --- a/gemma/tensor_info.cc +++ b/gemma/tensor_info.cc @@ -277,122 +277,6 @@ void TensorInfoRegistry::AddImageLayerTensors(const ModelConfig& config, }); } -void TensorInfoRegistry::AddGriffinLayerTensors(const LayerConfig& layer_config, - const size_t layer_idx) { - const std::string suffix = LayerSuffix(layer_idx); - Add(suffix, { - .base_name = "gr_lin_x_w", - .source_names = {"recurrent_block/linear_x/kernel"}, - .axes = {1, 0}, - .shape = {layer_config.griffin_dim, layer_config.griffin_dim}, - }); - Add(suffix, { - .base_name = "gr_lin_x_b", - .source_names = {"recurrent_block/linear_x/bias"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .min_size = Type::kF32, - }); - Add(suffix, { - .base_name = "gr_lin_y_w", - .source_names = {"recurrent_block/linear_y/kernel"}, - .axes = {1, 0}, - .shape = {layer_config.griffin_dim, layer_config.griffin_dim}, - }); - Add(suffix, { - .base_name = "gr_lin_y_b", - .source_names = {"recurrent_block/linear_y/bias"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .min_size = Type::kF32, - }); - Add(suffix, { - .base_name = "gr_lin_out_w", - .source_names = {"recurrent_block/linear_out/kernel"}, - .axes = {1, 0}, - .shape = {layer_config.griffin_dim, layer_config.griffin_dim}, - }); - Add(suffix, { - .base_name = "gr_lin_out_b", - .source_names = {"recurrent_block/linear_out/bias"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .min_size = Type::kF32, - }); - Add(suffix, - { - .base_name = "gr_conv_w", - .source_names = {"recurrent_block/conv_1d/w"}, - .axes = {0, 1}, - .shape = {layer_config.conv1d_width, layer_config.griffin_dim}, - .min_size = Type::kF32, - }); - Add(suffix, { - .base_name = "gr_conv_b", - .source_names = {"recurrent_block/conv_1d/b"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .min_size = Type::kF32, - }); - Add(suffix, { - .base_name = "gr1_gate_w", - .source_names = 
{"recurrent_block/rg_lru/input_gate/w"}, - .axes = {0, 2, 1}, - .shape = {layer_config.heads, - layer_config.griffin_dim / layer_config.heads, - layer_config.griffin_dim / layer_config.heads}, - .concat_names = {"gr_gate_w", "gr2_gate_w"}, - }); - Add(suffix, { - .base_name = "gr2_gate_w", - .source_names = {"recurrent_block/rg_lru/a_gate/w"}, - .axes = {0, 2, 1}, - .shape = {layer_config.heads, - layer_config.griffin_dim / layer_config.heads, - layer_config.griffin_dim / layer_config.heads}, - .concat_names = {""}, - }); - Add(suffix, { - .base_name = "gr_gate_w", - .source_names = {"recurrent_block/rg_lru/gate/w"}, - .axes = {0, 2, 1}, - .shape = {2 * layer_config.heads, - layer_config.griffin_dim / layer_config.heads, - layer_config.griffin_dim / layer_config.heads}, - }); - Add(suffix, { - .base_name = "gr1_gate_b", - .source_names = {"recurrent_block/rg_lru/input_gate/b"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .concat_names = {"gr_gate_b", "gr2_gate_b"}, - .min_size = Type::kF32, - }); - Add(suffix, { - .base_name = "gr2_gate_b", - .source_names = {"recurrent_block/rg_lru/a_gate/b"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .concat_names = {""}, - .min_size = Type::kF32, - }); - Add(suffix, { - .base_name = "gr_gate_b", - .source_names = {"recurrent_block/rg_lru/input_gate/b"}, - .axes = {0, 1}, - .shape = {2 * layer_config.griffin_dim}, - .min_size = Type::kF32, - }); - Add(suffix, { - .base_name = "gr_a", - .source_names = {"recurrent_block/rg_lru/a_param"}, - .axes = {0}, - .shape = {layer_config.griffin_dim}, - .min_size = Type::kF32, - .scaled_softplus = true, - }); -} - void TensorInfoRegistry::AddLayerTensors(const ModelConfig& config, const LayerConfig& layer_config, const size_t layer_idx) { @@ -553,10 +437,6 @@ void TensorInfoRegistry::AddLayerTensors(const ModelConfig& config, .shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim}, .cols_take_extra_dims = true, }); - - if (config.model == Model::GRIFFIN_2B) { - AddGriffinLayerTensors(layer_config, layer_idx); - } } TensorInfoRegistry::TensorInfoRegistry(const ModelConfig& config) { diff --git a/gemma/tensor_info.h b/gemma/tensor_info.h index c8252a4..d2b25d9 100644 --- a/gemma/tensor_info.h +++ b/gemma/tensor_info.h @@ -124,8 +124,6 @@ class TensorInfoRegistry { void AddModelTensors(const ModelConfig& config); void AddLayerTensors(const ModelConfig& config, const LayerConfig& layer_config, size_t layer_idx); - void AddGriffinLayerTensors(const LayerConfig& layer_config, - size_t layer_idx); void AddImageLayerTensors(const ModelConfig& config, const LayerConfig& layer_config, diff --git a/gemma/weights.cc b/gemma/weights.cc index 3d1d43e..8191bd9 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -88,7 +88,7 @@ void LayerWeightsPtrs::InitAttWeights(std::vector& mat_owners, // For FFN. Fast, only updates pointers. void LayerWeightsPtrs::SplitW1() { - // Used for Gemma and Griffin layers; FFWVit uses different tensors. + // Used for Gemma layers; FFWVit uses different tensors. if (layer_config.type == LayerAttentionType::kVit) return; // Files have both or neither of w1 and w2. diff --git a/gemma/weights.h b/gemma/weights.h index de3652a..06c0186 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -57,8 +57,7 @@ struct TensorArgs { // the _w1/_w2 tensors are not always present. kMaybeRead = 1, - // Avoid padding tensor rows when reading. Used for some Griffin tensors - // whose index computations do not use Row() accessors. + // Avoid padding tensor rows when reading. 
kPacked = 2, }; const int flags; @@ -102,17 +101,6 @@ struct LayerWeightsPtrs { qkv_einsum_w1(finder_("qkv1_w")), qkv_einsum_w2(finder_("qkv2_w")), attention_output_biases(finder_("attn_ob")), - griffin({.linear_x_w = finder_("gr_lin_x_w"), - .linear_x_biases = finder_("gr_lin_x_b"), - .linear_y_w = finder_("gr_lin_y_w"), - .linear_y_biases = finder_("gr_lin_y_b"), - .linear_out_w = finder_("gr_lin_out_w"), - .linear_out_biases = finder_("gr_lin_out_b"), - .conv_w = finder_("gr_conv_w"), - .conv_biases = finder_("gr_conv_b"), - .gate_w = finder_("gr_gate_w"), - .gate_biases = finder_("gr_gate_b"), - .a = finder_("gr_a")}), // MultiHeadDotProductAttention. vit({.attn_out_w = finder_("attn_out_w"), .attn_out_b = finder_("attn_out_b"), @@ -156,20 +144,6 @@ struct LayerWeightsPtrs { MatPtr qkv_einsum_w2; MatPtrT attention_output_biases; - struct { - MatPtr linear_x_w; - MatPtrT linear_x_biases; - MatPtr linear_y_w; - MatPtrT linear_y_biases; - MatPtr linear_out_w; - MatPtrT linear_out_biases; - MatPtrT conv_w; - MatPtrT conv_biases; - MatPtr gate_w; - MatPtrT gate_biases; - MatPtrT a; - } griffin; - struct { // MultiHeadDotProductAttention. MatPtr attn_out_w; // at least BF16. @@ -244,20 +218,6 @@ struct LayerWeightsPtrs { func(TENSOR_ARGS(qkv_einsum_w, kMaybeRead)); func(TENSOR_ARGS(qkv_einsum_w1, kMaybeRead)); func(TENSOR_ARGS(qkv_einsum_w2, kMaybeRead)); - } else { - func(TENSOR_ARGS(griffin.linear_x_w, kMustRead)); - func(TENSOR_ARGS(griffin.linear_x_biases, kMustRead)); - func(TENSOR_ARGS(griffin.linear_y_w, kMustRead)); - func(TENSOR_ARGS(griffin.linear_y_biases, kMustRead)); - func(TENSOR_ARGS(griffin.linear_out_w, kMustRead)); - func(TENSOR_ARGS(griffin.linear_out_biases, kMustRead)); - // conv_w and gate_w are not accessed via Row(), hence must not be padded. - // Note that *biases are 1D, hence packing/padding does not matter. - func(TENSOR_ARGS(griffin.conv_w, kMustRead | TensorArgs::kPacked)); - func(TENSOR_ARGS(griffin.conv_biases, kMustRead)); - func(TENSOR_ARGS(griffin.gate_w, kMustRead | TensorArgs::kPacked)); - func(TENSOR_ARGS(griffin.gate_biases, kMustRead)); - func(TENSOR_ARGS(griffin.a, kMustRead)); } { func(TENSOR_ARGS(gating_einsum_w, kMaybeRead)); @@ -281,11 +241,6 @@ struct LayerWeightsPtrs { func(TENSOR_ARGS(ffw_gating_biases, kMustRead)); func(TENSOR_ARGS(ffw_output_biases, kMustRead)); } - - if (layer_config.softmax_attn_output_biases && - layer_config.type == LayerAttentionType::kGemma) { - func(TENSOR_ARGS(attention_output_biases, kMustRead)); - } } // `ForEachTensor` // Zero-initializes all allocated tensors in the layer. 
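
The gemma/configs.h hunk above retires the GRIFFIN_2B and GEMMA_TINY slots (values 5 and 6) rather than renumbering, so GEMMA2_2B keeps its serialized value of 7, and the new IsObsolete() helper centralizes the skip logic used by ForEachModel() and EnumValid(); the python/configs.cc hunk below drops the matching .value() bindings. The following is a minimal Python sketch of that slot-retirement pattern, assuming a stand-alone IntEnum rather than the project's pybind11 `configs` module; the truncated member list, the SENTINEL placeholder and the helper names are illustrative only.

```python
from enum import IntEnum


# Hypothetical mirror of the slot-retirement pattern in gemma/configs.h;
# values 5, 6, 8 and 9 stay reserved so configs serialized before the
# removal still map to the right surviving models.
class Model(IntEnum):
    UNKNOWN = 0
    # 1 and 2 are obsolete.
    GEMMA2_9B = 3
    GEMMA2_27B = 4
    # 5 (formerly GRIFFIN_2B) and 6 (formerly GEMMA_TINY) are obsolete.
    GEMMA2_2B = 7
    # 8 and 9 are obsolete.
    PALIGEMMA2_3B_224 = 10
    SENTINEL = 11  # placeholder upper bound for this sketch only


OBSOLETE_SLOTS = frozenset({5, 6, 8, 9})


def is_obsolete(value: int) -> bool:
    """Counterpart of IsObsolete(): true for retired enum slots."""
    return value in OBSOLETE_SLOTS


def enum_valid(value: int) -> bool:
    """Counterpart of EnumValid(): UNKNOWN is allowed, retired slots are not."""
    if value == Model.UNKNOWN:
        return True
    return Model.GEMMA2_9B <= value < Model.SENTINEL and not is_obsolete(value)


def for_each_model(func) -> None:
    """Counterpart of ForEachModel(): visit live models, skip retired slots."""
    for value in range(Model.GEMMA2_9B, Model.SENTINEL):
        if not is_obsolete(value):
            func(Model(value))


if __name__ == "__main__":
    assert not enum_valid(5) and not enum_valid(6)  # removed Griffin/tiny slots
    assert enum_valid(Model.GEMMA2_2B)              # value 7 is unchanged
    for_each_model(lambda m: print(int(m), m.name))
```

Keeping the numeric gaps (and the unused placeholder fields in the LayerConfig/ModelConfig visitors) presumably trades a sparse enum for continued compatibility with previously written model files.
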
diff --git a/python/configs.cc b/python/configs.cc index f8121bf..086c691 100644 --- a/python/configs.cc +++ b/python/configs.cc @@ -57,8 +57,6 @@ PYBIND11_MODULE(configs, py_module) { enum_(py_module, "LayerAttentionType") .value("kGemma", LayerAttentionType::kGemma) - .value("kGriffinRecurrentBlock", - LayerAttentionType::kGriffinRecurrentBlock) .value("kVit", LayerAttentionType::kVit); enum_(py_module, "PostNormType") @@ -84,8 +82,6 @@ PYBIND11_MODULE(configs, py_module) { .value("UNKNOWN", Model::UNKNOWN) .value("GEMMA2_9B", Model::GEMMA2_9B) .value("GEMMA2_27B", Model::GEMMA2_27B) - .value("GRIFFIN_2B", Model::GRIFFIN_2B) - .value("GEMMA_TINY", Model::GEMMA_TINY) .value("GEMMA2_2B", Model::GEMMA2_2B) .value("PALIGEMMA2_3B_224", Model::PALIGEMMA2_3B_224) .value("PALIGEMMA2_10B_224", Model::PALIGEMMA2_10B_224) @@ -121,15 +117,11 @@ PYBIND11_MODULE(configs, py_module) { class_(py_module, "LayerConfig") .def(init()) .def_readwrite("model_dim", &LayerConfig::model_dim) - .def_readwrite("griffin_dim", &LayerConfig::griffin_dim) .def_readwrite("ff_hidden_dim", &LayerConfig::ff_hidden_dim) .def_readwrite("heads", &LayerConfig::heads) .def_readwrite("kv_heads", &LayerConfig::kv_heads) .def_readwrite("qkv_dim", &LayerConfig::qkv_dim) - .def_readwrite("conv1d_width", &LayerConfig::conv1d_width) .def_readwrite("ff_biases", &LayerConfig::ff_biases) - .def_readwrite("softmax_attn_output_biases", - &LayerConfig::softmax_attn_output_biases) .def_readwrite("optimized_gating", &LayerConfig::optimized_gating) .def_readwrite("post_norm", &LayerConfig::post_norm) .def_readwrite("type", &LayerConfig::type) @@ -166,7 +158,6 @@ PYBIND11_MODULE(configs, py_module) { .def_readwrite("att_cap", &ModelConfig::att_cap) .def_readwrite("final_cap", &ModelConfig::final_cap) .def_readwrite("absolute_pe", &ModelConfig::absolute_pe) - .def_readwrite("use_local_attention", &ModelConfig::use_local_attention) .def_readwrite("query_scale", &ModelConfig::query_scale) .def_readwrite("layer_configs", &ModelConfig::layer_configs) .def_readwrite("attention_window_sizes",
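
One identity from the gemma/attention.cc hunk is worth spelling out: SumHeads now passes add=nullptr to CallMatMul (the attention-output bias was only used by Griffin), and the surrounding comment notes that because att_out and att_weights are concatenated per-head blocks of length qkv_dim, a single [num_interleaved, model_dim] matmul already yields the sum over heads. The NumPy check below verifies that block-matrix identity; the sizes and the [heads*qkv_dim, model_dim] weight layout are illustrative assumptions, not the exact storage layout CallMatMul uses.

```python
import numpy as np

# Illustrative sizes only; not a real model configuration.
num_interleaved, heads, qkv_dim, model_dim = 4, 3, 8, 16
rng = np.random.default_rng(0)

# Per-head attention outputs and per-head output-projection matrices.
att_out_heads = rng.standard_normal((heads, num_interleaved, qkv_dim))
w_heads = rng.standard_normal((heads, qkv_dim, model_dim))

# Concatenate head outputs along the feature axis and stack the per-head
# weights along their input axis.
att_out = np.concatenate(list(att_out_heads), axis=1)  # [num_interleaved, heads*qkv_dim]
att_weights = np.concatenate(list(w_heads), axis=0)    # [heads*qkv_dim, model_dim]

# One matmul over the concatenated heads ...
att_sums = att_out @ att_weights                        # [num_interleaved, model_dim]
# ... equals the explicit per-head projections summed over heads.
expected = sum(att_out_heads[h] @ w_heads[h] for h in range(heads))
assert np.allclose(att_sums, expected)
```
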