gemma.cpp/compression/python/pytree/pytree_transforms.py

"""Tools for transforming "nested python object" tree data structures.

# Context

The motivation for this module came from ML applications that ought to
be based on a principled handling of nested Python data structures.
Having such principled pytree-transforming code available solves
some other problems, such as doing away with a need to abuse
tree-mapping for-side-effect-only and having to use a hope-and-pray
approach to processing very deeply nested values, which with a recursive
approach might trigger a `RecursionError`.

We specifically want to cover the use case of having ML model
parameters that are available in a nested Python data structure for
which there "almost" is a unique-up-to-unique-isomorphism mapping from
and to this Algebraic Data Type:

`data ModelParams a = Array a | Node [(String, ModelParams a)]`

In this correspondence, `a` is some array-type (perhaps
`numpy.ndarray`, `jax.numpy.ndarray`, `tf.Tensor`, etc.), but the
data-processing code is effectively entirely agnostic to this, and a
`Node` is "almost" an associative-list of (key, value) pairs
representing a Python dict. (Note: The "almost" here is mostly about
the conceptual wart that assoc-lists can in principle have key
duplicates, but Python dicts cannot. This is however not a problem
since all we need is the transformation in one direction,
i.e. whatever data-processing `f` we want to express on the
model-parameters-pytree, we can express by specifying a "faithful"
mapping `m` into the above algebraic data type through which every
such pytree data transform factorizes, i.e. for every `f` we can find
a `g` such that `f(p) = g(m(p))`.)
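
As a minimal sketch of this factorization (not part of this module; the
names `to_model_params` and `count_arrays` are hypothetical, plain lists
stand in for arrays, and recursion is used only for brevity), a mapping
`m` into a tagged-tuple encoding of `ModelParams` and one possible `g`
might look like this:

```python
def to_model_params(p):
  # Hypothetical mapping `m`: a dict becomes ('Node', assoc-list),
  # anything else is regarded as an array-leaf.
  if isinstance(p, dict):
    return ('Node', [(k, to_model_params(v)) for k, v in p.items()])
  return ('Array', p)


def count_arrays(adt):
  # Hypothetical `g`: counts the `Array` leaves of the ADT value.
  tag, payload = adt
  if tag == 'Array':
    return 1
  return sum(count_arrays(child) for _, child in payload)


# Lists are used as stand-ins for array leaves here.
params = {'layer0': {'w': [1.0, 2.0], 'b': [0.0]}, 'scale': [3.0]}
# `f(p) = g(m(p))`, with `f` being "count the arrays in the pytree".
num_arrays = count_arrays(to_model_params(params))
```

Here, `f` never has to inspect the Python-level structure directly; all
structural knowledge sits in `m`.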

## Components

The main workhorse in this module is the `pytree_leaf_iter` function that
maps a "PyTree (such as representing `ModelParams`)" to an iterator
over values obtained by applying a mapping-function to the "key-path"
and leaf-value for every leaf, where the "key-path" contains a
linked-list representation of the reversed sequence of keys from the
tree-root, with list-nodes being represented by pairs
`(latest_dict_key, rest_path)`, and the empty path being represented
by `()`.

For the sake of genericity, `pytree_leaf_iter` is built in such a way that
it actually can handle any kind of traversal of PyTree-trees that do
represent algebraic data types (note, however, that some do not) -
but for this to make sense, the user must have a way to define how to
interpret tree-nodes, in particular how to identify leaves. This requires
providing a function `node_handler` with the same signature and
behavior as described below for "node handlers".

Additionally, this module provides mapping-over-pytrees via
`pytree_map`, which is also built in such a way that it makes the
correspondence between an algebraic data type and its Python
nested-tree representation explicit. While powerful and flexible,
this may in general require a bit more effort to wire up, since
node-rebuilding can be fairly nontrivial.

Furthermore, as a prominent application, this module provides a simple
deep-freezing function, `deep_freeze`, that translates a nested Python
data structure to deeply-immutable form.

## Concepts and Conventions

"revtuple representation":

  As we iterate over a tree, we will have to keep track of the
  path-to-tree-root. Naturally, two sibling nodes `n1` and `n2`
  will share the same parent-path (being siblings), so it makes
  sense to use a linked-list-with-shared-tail representation.
  Python does not have a natural notion for that, so we use
  recursively-constructed tuples `(node_tag, parent_path)`
  that represent the path-from-root in-reverse-order, i.e.
  for a non-empty path `p`, `p[0]` is the node-tag at the
  deepest nesting level. We call this a "revtuple representation"
  of the path.
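
  The following sketch illustrates the shared-tail property for two
  sibling paths. The helper `linearize` is a simplified stand-in for
  this module's `linearize_revtuple_path` with default arguments, and
  the key names are made up for illustration:

```python
root_path = ()
parent_path = ('layer0', root_path)
# Two siblings extend the very same parent-path object.
path_w = ('w', parent_path)
path_b = ('b', parent_path)


def linearize(revtuple_path):
  # Walks the linked list from the deepest key towards the root,
  # then reverses to obtain root-first order.
  pieces = []
  while revtuple_path:
    key, revtuple_path = revtuple_path
    pieces.append(key)
  return tuple(reversed(pieces))


shared_tail = path_w[1] is path_b[1]
linear_w = linearize(path_w)
linear_b = linearize(path_b)
```

  Extending a path costs one small tuple allocation, independent of the
  depth of the path, which is the point of this representation.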
"node handler":

  A node-handler classifies a tree-node as "leaf or other node", and
  for non-leaf nodes provides information about both its children and
  how to rebuild it. The behavior of a node-handler function must be
  in alignment with this docstring:

  '''Processes a tree-node as required by pytree-iteration and -mapping.

  Args:
    revtuple_path: Revtuple-representation of the path-from-root
      to the current node.
    node: a tree-node in an ML-model tree that is recursively
      built out of leaf-values and other nodes.

  Returns:
    `None` if the tree-node is to be regarded as a leaf, otherwise
    a pair `(rebuilder, iterator)`, where `iterator` iterates
    over the data-content of the node, each item represented as a pair
    of `(lookup_path_item, value_item)`, and `rebuilder` is a function
    which, when applied to an iterable of the aforementioned value-items
    (or some transformation thereof) returns a node that is equivalent
    to the original (or to the original up to a transformation of
    the contents).

  Raises:
    InvalidTreeNodeError: If the tree contains a node of a kind
      that is not expected to show up.
  '''

  Examples:

  (The behavior of a node-handler is somewhat nontrivial, so covering
  two very common cases via examples is in order.)

  This node-handler would allow descending into (nested)
  instances of `list` (but not subclass instances thereof):

  ```
  def list_node_handler(revtuple_path, obj):
    ''' ... '''
    if type(obj) is list:
      return list, enumerate(obj)
    else:
      return None
  ```

  This node-handler would allow descending into (nested) mappings,
  which upon rebuilding would get turned into `dict` instances:

  ```
  def mapping_node_handler(revtuple_path, obj):
    ''' ... '''
    if isinstance(obj, collections.abc.Mapping):
      # For generic mappings, we cannot rely on key- and item-iteration
      # being guaranteed to use identical iteration-order.
      items = list(obj.items())
      keys = [kv[0] for kv in items]
      # The second entry must be an iterator (not merely an iterable),
      # since traversal advances it via `next()`.
      return (lambda values: dict(zip(keys, values))), iter(items)
    else:
      return None
  ```

  A dict/mapping node-handler can of course rename keys, add or remove
  entries, make decisions based on the item-path, or map a dict to
  an associative list, etc.
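
  As a sketch of the key-renaming case (the handler name and the
  upper-casing rule are hypothetical, chosen only for illustration),
  a handler can hand the original keys to the traversal while the
  rebuilder installs different ones:

```python
def renaming_node_handler(revtuple_path, obj):
  # Hypothetical handler: descends into `dict` instances and
  # upper-cases string keys on rebuild; anything else is a leaf.
  del revtuple_path  # Unused.
  if type(obj) is dict:
    items = list(obj.items())
    new_keys = [k.upper() if isinstance(k, str) else k for k, _ in items]
    return (lambda values: dict(zip(new_keys, values))), iter(items)
  return None


# Exercising the handler contract directly, without any traversal:
rebuilder, items_iter = renaming_node_handler((), {'w': 1, 'b': 2})
rebuilt = rebuilder(value for _, value in items_iter)
leaf_result = renaming_node_handler((), 42)
```

  Note that paths seen by a leaf-transform would still carry the
  original keys, since these come from the iterator, not the rebuilder.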

## Further Design Notes

The `pytree_map` function requires the leaf-transform and node-handler
to be side-effect-free functions. This is both needed to leave
implementation-side flexibility, and also follows the general LISP
recommendation to not abuse mapping (which should be a pure
data-transformation) for imperative data processing. Overall, if
a need for more general "nested datastructures" processing becomes
pressing, it is for the better if this leads to a proper articulation
of the specific needs, to be addressed with appropriate design, rather
than abuse of functional data-transforms becoming "a bad idiom
that turned into established practice".
"""
import collections.abc
import immutabledict
import numpy
from typing import Any, Callable, Iterable, Iterator, TypeVar
T = TypeVar('T')
U = TypeVar('U')
KT = TypeVar('KT')
NT = TypeVar('NT')
## Type of the reverse-order-keys-to-root path.
# (This code actually illustrates why https://xkcd.com/2483/ is very misguided.)
RevTuplePath = tuple
## Type of the `leaf_transform` function-argument used for tree-iteration.
#
# This would be the correct type we would have to specify here but cannot,
# since the design of Python's static typing at the time of this writing
# is too broken for that:
#
# type LeafTransformFunc[L, R] = Callable[[RevTuplePath, L], R]
#
# Instead, we have to settle for...:
LeafTransformFunc = Callable[[RevTuplePath, Any], Any]
## Type of the `node_handler` function-argument used for
## tree-iteration and tree-mapping.
#
# Again, this is the correct type we would have to put here but cannot:
#
# type NodeHandlerFunc[KT, NT] = (
#     Callable[[RevTuplePath, NT],
#              None | tuple[Callable[[Iterable[tuple[KT, NT]]], NT],
#                           Iterator[tuple[KT, NT]]]])
#
# ...so, we have to instead settle for:
NodeHandlerFunc = (
    Callable[[RevTuplePath, NT],
             None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT],
                          Iterator[tuple[Any, NT]]]])
Predicate = Callable[[object], bool]
class InvalidTreeNodeError(ValueError):
  """Encountered a tree-node of invalid type."""


def linearize_revtuple_path(
    revtuple_path: RevTuplePath,
    present_as: Callable[[Iterator[T]], U] = tuple) -> U:
  """Translates a revtuple path to (typically) linear form.

  With default `present_as`, this will map a path of the form
  `(key_{N}, (key_{N-1}, ..., (root, ())))` into a tuple
  (root, ..., key_{N-1}, key_{N}).

  Args:
    revtuple_path: A linked-list-as-recursive-pairs
      reverse-order tuple-representation of the path.
      Path-root is `()`, and node-key `x` relative to
      earlier path `p` is represented as `(x, p)`.
    present_as: Callable that consumes an iterator over
      path-pieces - with the deepest-nesting level coming last -
      turning it into a linearized path. Defaults to `tuple`.

  Returns:
    Linearized presentation of all the node-keys in the
    recursive-path in order, deepest-down path component coming last.
  """
  pieces = []
  todo = revtuple_path
  while todo:
    node, todo = todo
    pieces.append(node)
  return present_as(reversed(pieces))
# This function itself has type `NodeHandlerFunc`, but Python does not
# allow us to here simply type-annotate it like this. We cannot even
# introduce an abbreviation for the complicated output-type,
# since that would have to be parametric in node-type `NT` (and `KT`).
def everything_is_a_leaf_node_handler(
    revtuple_path: tuple,
    node: NT) -> (
        None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT],
                     Iterator[tuple[Any, NT]]]):
  """Processes a tree-node as required by pytree-iteration and -mapping.

  Interface and signature are in alignment with the requirements for a
  "node handler" function explained in the module-docstring.

  Args:
    revtuple_path: the path-to-root for this node.
    node: a tree-node.

  Returns:
    `None`, i.e. classifying any kind of node as a leaf-node.
  """
  del revtuple_path, node  # Unused.
  return None
def leaf_summary(path: RevTuplePath, x: object) -> str:
  """Produces a human-readable summary-string for a leaf-node.

  Args:
    path: revtuple representation of the path-to-root.
    x: The leaf-value.

  Returns:
    A string of the form `{type}({possibly abbreviated repr})`.
  """
  del path  # Ignored here.
  tx = type(x)
  mod = tx.__module__
  modname = mod if isinstance(mod, str) else mod.__name__
  type_str = f'{modname}.{tx.__qualname__}'
  repr_str = repr(x)
  repr_abbrev = repr_str if len(repr_str) < 40 else repr_str[:40] + ' ...'
  # On str, int, float, etc. `{type_str}(repr(x))` would actually still be
  # a (non-literal) Python-expression that would evaluate to the original value.
  # However, we make no promises beyond "human-readable".
  return f'{type_str}({repr_abbrev})'
# With respect to static type annotations, the limitations of Python's
# approach to static typing really become prominently visible here.
#
# Different arguments have type-parameters, but since there is no way
# to have parametric abbreviations such as `LeafTransformFunc[L, R]`,
# the only way we would have available to express relations between
# type-parameters would be to substitute in the not-abbreviated form of
# `NodeHandlerFunc` and `LeafTransformFunc`, giving us something monstrous.
# We instead here settle for "we cannot express that `tree` must
# have the same type as the input-type to `node_handler`" and use `Any`,
# and likewise for `leaf_transform` and the output.
def pytree_leaf_iter(
    tree: Any,
    leaf_transform: LeafTransformFunc,
    node_handler: NodeHandlerFunc = everything_is_a_leaf_node_handler,
    ) -> Iterator[Any]:
  # ...actual return type would be `Iterator[{what leaf_transform returns}]`.
  """Iterates over the leaves of a tree.

  Args:
    tree: The tree to iterate over.
    leaf_transform: A callable `f` that will get applied
      as `f(revtuple_path, leaf)`, where `revtuple_path`
      is the revtuple representation of the path to the
      leaf from the root.
    node_handler: A "node handler" (see module docstring)
      that processes nodes encountered during iterative traversal.

  Yields:
    Value of `leaf_transform(p, x)`, where `x` is the current leaf
    and `p` is its revtuple-path to the root.
  """
  # Note: Exit points for the code below are in non-obvious places
  # and hence marked with " # ***EXIT***".
  #
  # Doing iteration properly is slightly nontrivial.
  # One may be tempted to go for a very simple recursive implementation
  # (with an extra pre-final `path` argument to `pytree_leaf_iter`):
  #
  # maybe_substructure = node_handler(path, tree)
  # if maybe_substructure is None:
  #   # We are looking at a leaf-node.
  #   yield leaf_transform(path, tree)
  # else:
  #   _, contents_iter = maybe_substructure
  #   for k, v in contents_iter:
  #     yield from pytree_leaf_iter(v, leaf_transform, (k, path), node_handler)
  #
  # That, however, would be flawed, since there is no a priori reason
  # why a pytree may not be a very deeply nested structure - such as a
  # long linked list. That would then risk raising `RecursionError`,
  # and since Python by design(!) does not perform tail call elimination
  # or any other kind of advanced CPS transforms, there is no recursive
  # solution here. So, to do this properly, we have to do this iteratively.
  #
  # We are facing an annoying situation here: If `tree` itself is a leaf,
  # we have two options: (a) wrapping it up in a one-node tree
  # and processing that, or (b) special-casing "root is a leaf".
  # Option (b) leads to some mild node-processing code-duplication
  # for a single node (the root).
  # Option (a) requires having special cases for node-processing that
  # get looked at for every tree node. We go with option (b) here.
  maybe_substructure = node_handler((), tree)
  if maybe_substructure is None:
    # The tree itself is a leaf.
    yield leaf_transform((), tree)
    return  # ***EXIT***
  # Otherwise, we are looking at a tree.
  _, contents_iter = maybe_substructure
  current_revtuple_path = ()
  # Stack of per-level item-iterators; the innermost level comes last.
  work_to_do = [contents_iter]
  # Otherwise-unreachable sentinel for reliably identifying
  # iterator-exhaustion without using exceptions:
  sentinel = object()
  while True:
    current_iter = work_to_do[-1]
    maybe_next_item = next(current_iter, sentinel)
    if maybe_next_item is sentinel:
      # We are done at this level.
      work_to_do.pop()
      if not work_to_do: return  # ***EXIT***
      current_revtuple_path = current_revtuple_path[1]
    else:
      path_piece, subtree = maybe_next_item
      extended_revtuple_path = (path_piece, current_revtuple_path)
      maybe_subtree_substructure = node_handler(extended_revtuple_path, subtree)
      if maybe_subtree_substructure is None:  # Case: subtree is a leaf.
        yield leaf_transform(extended_revtuple_path, subtree)
      else:  # Case: subtree is a tree.
        current_revtuple_path = (path_piece, current_revtuple_path)
        _, items_iter = maybe_subtree_substructure
        work_to_do.append(items_iter)
# The current design approach here would be appropriate for
# applying leaf-transforms while retaining the structure of the tree -
# which closely corresponds to e.g. a (a -> b) -> (Tree a -> Tree b) functor.
#
# It is not entirely clear whether this is the abstraction that we should
# consider as being appropriately generic to flesh out explicitly - rather
# than starting from a more general approach of which this then is a special
# case. Some background: https://ncatlab.org/nlab/show/recursion+scheme
#
# On the other hand, there is a lot of flexibility via whatever
# node-rebuilder a node-handler produces - this can do quite some reshaping
# of a tree, including dropping or duplicating nodes.
def pytree_map(
    tree: Any,
    leaf_transform: LeafTransformFunc,
    node_handler: NodeHandlerFunc = everything_is_a_leaf_node_handler,
    ):
  """Maps a (potentially nested) Python value to another such value.

  Args:
    tree: The Python-object to be mapped.
    leaf_transform: A callable `f` that will get applied
      as `f(revtuple_path, leaf)`, where `revtuple_path`
      is the revtuple representation of the path to the
      leaf from the root. Must be side effect free.
    node_handler: A "node handler" (see module docstring)
      that processes nodes encountered during iterative traversal.
      Must be side effect free.

  Returns:
    The outcome of translating `tree`.
  """
  # Note: Exit points for the code below are in non-obvious places
  # and hence marked with " # ***EXIT***".
  #
  # Otherwise-inaccessible sentinel object, for reliably identifying
  # missing-values via identity-check against sentinel lookup-default.
  sentinel = object()
  # Code structure mostly follows pytree_leaf_iter.
  maybe_substructure = node_handler((), tree)
  if maybe_substructure is None:
    return leaf_transform((), tree)  # ***EXIT***
  rebuilder, items_iter = maybe_substructure
  current_revtuple_path = ()
  # Per-level, we have a triplet of:
  # (rebuilder, remaining_items_to_iterate_over, processed).
  parts_for_assembly = [(rebuilder, items_iter, [])]
  while True:
    this_rebuilder, this_items_iter, this_done_pieces = parts_for_assembly[-1]
    maybe_next_item = next(this_items_iter, sentinel)
    if maybe_next_item is sentinel:
      # We are done with all the items for this level.
      parts_for_assembly.pop()
      built_iter = this_rebuilder(this_done_pieces)
      if not parts_for_assembly:  # No outer structure, so at-top-level.
        return built_iter  # ***EXIT***
      else:  # We have outer structure.
        parts_for_assembly[-1][-1].append(built_iter)
        current_revtuple_path = current_revtuple_path[1]
        continue  # ...with next is-the-final-item-complete-check.
    else:
      # More constituents of the current item.
      path_piece, subtree_item = maybe_next_item
      extended_revtuple_path = (path_piece, current_revtuple_path)
      maybe_subtree_substructure = node_handler(
          extended_revtuple_path,
          subtree_item)
      if maybe_subtree_substructure is None:
        this_done_pieces.append(
            leaf_transform(extended_revtuple_path, subtree_item))
      else:
        # We have a subtree.
        subtree_rebuilder, subtree_items_iter = maybe_subtree_substructure
        current_revtuple_path = (path_piece,
                                 current_revtuple_path)
        parts_for_assembly.append(
            (subtree_rebuilder, subtree_items_iter, []))
def deep_freeze(
    tree,
    *,
    is_mapping: Predicate = lambda x: isinstance(x, collections.abc.Mapping),
    is_set: Predicate = lambda x: isinstance(x, collections.abc.Set),
    is_sequence: Predicate = lambda x: isinstance(x, (list, tuple)),
    leaf_fn: Callable[[Any], Any] = lambda x: x,
    ):
  """Recursively freezes Set/Mapping/List/Tuple structures.

  Args:
    tree: The potentially deeply-nested object to deep-freeze.
    is_mapping: Callable that decides whether a sub-object is a mapping.
      Defaults to an `isinstance()` check for `collections.abc.Mapping`.
    is_set: Callable that decides whether a sub-object is a set.
      Defaults to an `isinstance()` check for `collections.abc.Set`.
    is_sequence: Callable that decides whether a sub-object is a sequence.
      Defaults to a check for being a `tuple` or `list` instance.
    leaf_fn: Function to use for translating non-mapping/set/sequence
      instances.

  Returns:
    Translated-to-deeply-immutable form of `tree`.
  """
  idict = immutabledict.immutabledict
  def freeze_node_handler(path, x):
    del path  # Unused.
    if is_set(x):
      return frozenset, ((None, y) for y in x)
    if is_mapping(x):
      # Mapping keys are hashable already, and hence
      # (should-be-)deeply-immutable.
      # So, we only need to deep-freeze the values.
      #
      # Note that non-`dict` mappings might not guarantee
      # to respect iteration-order, so we have to be careful here:
      items = list(x.items())
      keys = [kv[0] for kv in items]
      return ((lambda ys: idict(zip(keys, ys))),
              iter(items))
    if is_sequence(x):
      return tuple, enumerate(iter(x))
    # Otherwise, this should not be traversed.
    return None
  def leaf_transform(revtuple_path, value):
    del revtuple_path  # Unused.
    return leaf_fn(value)
  return pytree_map(tree, leaf_transform, freeze_node_handler)