mirror of https://github.com/google/gemma.cpp.git
"""Tools for transforming "nested python object" tree data structures.

# Context

The motivation for this module came from ML applications that ought to
be based on a principled handling of nested Python data structures.
Having such principled pytree-transforming code available also solves
some other problems, such as doing away with a need to abuse
tree-mapping for-side-effect-only and having to use a hope-and-pray
approach to processing very deeply nested values, which with a recursive
approach might trigger a `RecursionError`.

We specifically want to cover the use case of having ML model
parameters that are available in a nested Python data structure for
which there "almost" is a unique-up-to-unique-isomorphism mapping from
and to this Algebraic Data Type:

`data ModelParams a = Array a | Node [(String, ModelParams a)]`

In this correspondence, `a` is some array-type (perhaps
`numpy.ndarray`, `jax.numpy.ndarray`, `tf.Tensor`, etc.), but the
data-processing code is effectively entirely agnostic to this, and a
`Node` is "almost" an associative-list of (key, value) pairs
representing a Python dict. (Note: The "almost" here is mostly about
the conceptual wart that assoc-lists can in principle have key
duplicates, but Python dicts cannot. This is however not a problem
since all we need is the transformation in one direction,
i.e. whatever data-processing `f` we want to express on the
model-parameters-pytree, we can express by specifying a "faithful"
mapping `m` into the above algebraic data type through which every
such pytree data transform factorizes, i.e. for every `f` we can find
a `g` such that `f(p) = g(m(p))`.)
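
As a rough illustration of the factorization `f(p) = g(m(p))`, the
following sketch spells out one possible `m` for plain nested dicts.
The names `to_adt` and `count_arrays` are hypothetical and not part of
this module, and plain recursion is used only for brevity (it would run
into the recursion limit on very deeply nested values):

```python
def to_adt(params):
  # Hypothetical `m`: tags leaves as 'Array' and dicts as 'Node',
  # the latter carrying an associative list of (key, subtree) pairs.
  if isinstance(params, dict):
    return ('Node', [(key, to_adt(value)) for key, value in params.items()])
  return ('Array', params)


def count_arrays(adt):
  # Hypothetical `g`: counts the `Array` leaves of an ADT value.
  tag, payload = adt
  if tag == 'Array':
    return 1
  return sum(count_arrays(subtree) for _, subtree in payload)


params = {'embedding': 1.0, 'layer_0': {'w': 2.0, 'b': 3.0}}
adt = to_adt(params)
num_arrays = count_arrays(adt)
```

Here, `num_arrays` plays the role of `f(p)` for the transform
`f` = "count the parameter arrays", computed as `g(m(p))`.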

## Components

The main workhorse in this module is the `pytree_leaf_iter` function that
maps a "PyTree (such as representing `ModelParams`)" to an iterator
over values obtained by applying a mapping-function to the "key-path"
and leaf-value for every leaf, where the "key-path" contains a
linked-list representation of the reversed sequence of keys from the
tree-root, with list-nodes being represented by pairs
`(latest_dict_key, rest_path)`, and the empty path being represented
by `()`.
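
To make the key-path representation concrete, here is a small sketch
that builds two such paths by hand and turns them back into root-first
order. The helper name `linearize` and the example keys are
assumptions made for illustration only:

```python
def linearize(revtuple_path):
  # Walks the linked list from the deepest key towards the root,
  # then reverses the collected keys to obtain root-first order.
  keys = []
  while revtuple_path:
    key, revtuple_path = revtuple_path
    keys.append(key)
  return tuple(reversed(keys))


# Path root -> 'params' -> 'layer_0', stored deepest-key-first.
parent_path = ('layer_0', ('params', ()))
# Two sibling leaves extend the very same parent-path object.
path_w = ('w', parent_path)
path_b = ('b', parent_path)

linear_w = linearize(path_w)
linear_b = linearize(path_b)
shares_tail = path_w[1] is path_b[1]
```

Extending a path is a constant-time operation here, and siblings share
their parent-path rather than copying it.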

For the sake of genericity, `pytree_leaf_iter` is built in such a way that
it actually can handle any kind of traversal of PyTree-trees that do
represent algebraic data types (note however that some do not) -
but for this to make sense, the user must have a way to define how to
interpret tree-nodes, in particular identify leaves. This requires
providing a function `node_handler` with the same signature and
behavior as described below for "node handlers".
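
The traversal idea can be sketched in a few lines. This is a
simplified illustration under stated assumptions, not this module's
implementation: `iter_leaves` and `dict_node_handler` are hypothetical
names, and the sketch yields `(key_path, leaf)` pairs directly instead
of applying a mapping-function. It uses an explicit stack rather than
recursion:

```python
def dict_node_handler(revtuple_path, obj):
  # Hypothetical handler: descends into plain `dict` instances only.
  del revtuple_path  # Unused.
  if type(obj) is dict:
    keys = list(obj)
    return (lambda values: dict(zip(keys, values))), iter(obj.items())
  return None


def iter_leaves(tree, node_handler):
  # Yields (key_path, leaf) pairs, using an explicit stack of iterators.
  top = node_handler((), tree)
  if top is None:
    yield (), tree
    return
  exhausted = object()
  stack = [((), top[1])]
  while stack:
    path, items = stack[-1]
    item = next(items, exhausted)
    if item is exhausted:
      stack.pop()
      continue
    key, subtree = item
    subpath = (key, path)
    handled = node_handler(subpath, subtree)
    if handled is None:
      yield subpath, subtree
    else:
      stack.append((subpath, handled[1]))


leaves = list(iter_leaves({'a': {'b': 1, 'c': 2}, 'd': 3}, dict_node_handler))

# A chain nested far deeper than the default recursion limit.
deep = 'leaf'
for _ in range(5000):
  deep = {'k': deep}
num_deep_leaves = sum(1 for _ in iter_leaves(deep, dict_node_handler))
```

Since the nesting depth only grows a Python list, the deeply nested
chain at the end is processed without raising `RecursionError`.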

Additionally, this module provides mapping-over-pytrees via
`pytree_map`, which is also built in such a way that it makes the
correspondence between an algebraic data type and its Python
nested-tree representation explicit. While powerful and flexible,
this may in general require a bit more effort to wire up,
since node-rebuilding can be fairly nontrivial.
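
The following sketch shows where the node-rebuilding effort comes in.
It is a simplified illustration, assuming a node handler that returns
`None` for leaves and a `(rebuilder, iterator)` pair otherwise;
`map_tree` and `list_handler` are hypothetical names and not this
module's API:

```python
def map_tree(tree, leaf_fn, node_handler):
  # Hypothetical, simplified stand-in for a structure-preserving tree map.
  top = node_handler((), tree)
  if top is None:
    return leaf_fn((), tree)
  exhausted = object()
  # One entry per nesting level:
  # (path, rebuilder, remaining items, already-processed values).
  stack = [((), top[0], top[1], [])]
  while True:
    path, rebuilder, items, processed = stack[-1]
    item = next(items, exhausted)
    if item is exhausted:
      # This level is complete: rebuild it and hand it to its parent.
      stack.pop()
      rebuilt = rebuilder(processed)
      if not stack:
        return rebuilt
      stack[-1][3].append(rebuilt)
      continue
    key, subtree = item
    subpath = (key, path)
    handled = node_handler(subpath, subtree)
    if handled is None:
      processed.append(leaf_fn(subpath, subtree))
    else:
      stack.append((subpath, handled[0], handled[1], []))


def list_handler(revtuple_path, obj):
  # Descends into `list` instances; `list` itself serves as rebuilder.
  del revtuple_path  # Unused.
  if type(obj) is list:
    return list, enumerate(obj)
  return None


scaled = map_tree([1, [2, 3], []], lambda path, x: x * 10, list_handler)
```

Each level collects its processed children and only calls its rebuilder
once all of them are available, which is why a node handler has to
supply the rebuilder together with the item iterator.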

Furthermore, as a prominent application, this module provides a simple
deep-freezing function, `deep_freeze`, that translates a nested Python
data structure to deeply-immutable form.
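
What "deeply-immutable form" means can be sketched with the standard
library alone. This is an illustration under assumptions, not this
module's implementation: `freeze_sketch` is a hypothetical name, it uses
plain recursion for brevity, and it uses `types.MappingProxyType` as a
stand-in for an immutable mapping type:

```python
import types


def freeze_sketch(x):
  # Sets become frozensets, dicts become read-only mapping views,
  # lists and tuples become tuples; anything else is kept as-is.
  if isinstance(x, (set, frozenset)):
    return frozenset(freeze_sketch(y) for y in x)
  if isinstance(x, dict):
    return types.MappingProxyType(
        {key: freeze_sketch(value) for key, value in x.items()})
  if isinstance(x, (list, tuple)):
    return tuple(freeze_sketch(y) for y in x)
  return x


frozen = freeze_sketch({'a': [1, {2, 3}], 'b': {'c': [4]}})

try:
  frozen['b']['c'] = ()
  mutation_rejected = False
except TypeError:
  mutation_rejected = True
```

Note that, unlike a hashable immutable-dict type, a `MappingProxyType`
is only a read-only view and is not hashable.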

## Concepts and Conventions

"revtuple representation":

  As we iterate over a tree, we will have to keep track of the
  path-to-tree-root. Naturally, two sibling nodes `n1` and `n2`
  will share the same parent-path (being siblings), so it makes
  sense to use a linked-list-with-shared-tail representation.
  Python does not have a natural notion for that, so we use
  recursively-constructed tuples `(node_tag, parent_path)`
  that represent the path-from-root in-reverse-order, i.e.
  for a non-empty path `p`, `p[0]` is the node-tag at the
  deepest nesting level. We call this a "revtuple representation"
  of the path.

"node handler":

  A node-handler classifies a tree-node as "leaf or other node", and
  for non-leaf nodes provides information about both its children and
  how to rebuild it. The behavior of a node-handler function must be
  in alignment with this docstring:

  '''Processes a tree-node as required by pytree-iteration and -mapping.

  Args:
    revtuple_path: Revtuple-representation of the path-from-root
      to the current node.
    node: a tree-node in an ML-model tree that is recursively
      built out of leaf-values and other nodes.

  Returns:
    `None` if the tree-node is to be regarded as a leaf, otherwise
    a pair `(rebuilder, iterator)`, where `iterator` iterates
    over the data-content of the node, each item represented as a pair
    of `(lookup_path_item, value_item)`, and `rebuilder` is a function
    which, when applied to an iterable of the aforementioned value-items
    (or some transformation thereof) returns a node that is equivalent
    to the original (or up to a transformation of the contents).

  Raises:
    InvalidTreeNodeError: If the tree contains a node of a kind
      that is not expected to show up.
  '''

  Examples:

  (The behavior of a node-handler is somewhat nontrivial, so covering
  two very common cases via examples is in order.)

  This node-handler would allow descending into (nested)
  instances of `list` (but not subclass instances thereof):

  ```
  def list_node_handler(revtuple_path, obj):
    ''' ... '''
    if type(obj) is list:
      return list, enumerate(obj)
    else:
      return None
  ```

  This node-handler would allow descending into (nested) mappings,
  which upon rebuilding would get turned into `dict` instances:

  ```
  def mapping_node_handler(revtuple_path, obj):
    ''' ... '''
    if isinstance(obj, collections.abc.Mapping):
      # For generic mappings, we cannot rely on key- and item-iteration
      # being guaranteed to use identical iteration-order.
      items = list(obj.items())
      keys = [kv[0] for kv in items]
      # The traversal advances the returned object via `next()`,
      # so this must be an iterator, not merely an iterable.
      return (lambda values: dict(zip(keys, values))), iter(items)
    else:
      return None
  ```

  A dict/mapping node-handler can of course rename keys, add or remove
  entries, make decisions based on the item-path, or map a dict to
  an associative list, etc.
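
The `(rebuilder, iterator)` contract can be exercised by hand for a
single tree level, as in this small sketch. The handler name
`sorted_mapping_node_handler` is hypothetical; it is a variation that
visits mapping entries in sorted-key order, assuming mutually
comparable keys:

```python
import collections.abc


def sorted_mapping_node_handler(revtuple_path, obj):
  # Hypothetical handler: descends into mappings in sorted-key order.
  del revtuple_path  # Unused.
  if isinstance(obj, collections.abc.Mapping):
    items = sorted(obj.items())
    keys = [key for key, _ in items]
    return (lambda values: dict(zip(keys, values))), iter(items)
  return None


# Non-leaf node: we get a rebuilder plus an iterator over (key, value).
rebuilder, items_iter = sorted_mapping_node_handler((), {'b': 2, 'a': 1})
# Transform the values, then rebuild a node from the transformed values.
doubled = rebuilder(value * 2 for _, value in items_iter)
doubled_key_order = list(doubled)

# Leaf node: the handler signals this by returning `None`.
leaf_result = sorted_mapping_node_handler((), 3)
```

The rebuilder only ever sees values, so the handler is responsible
for remembering the keys in the same order in which it iterates.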

## Further Design Notes

The `pytree_map` function requires the leaf-transform and node-handler
to be side-effect-free functions. This is both required to leave
implementation-side flexibility, and also follows the general LISP
recommendation to not abuse mapping (which should be a pure
data-transformation) for imperative data processing. Overall, if
a need for more general "nested datastructures" processing becomes
pressing, it is for the better if this leads to a proper articulation
of the specific needs, to be addressed with appropriate design, rather
than abuse of functional data-transforms becoming "a bad idiom
that turned into established practice".

"""

import collections.abc
import immutabledict

import numpy

from typing import Any, Callable, Iterable, Iterator, TypeVar


T = TypeVar('T')
U = TypeVar('U')

KT = TypeVar('KT')
NT = TypeVar('NT')


## Type of the reverse-order-keys-to-root path.
# (This code actually illustrates why https://xkcd.com/2483/ is very misguided.)
RevTuplePath = tuple

## Type of the `leaf_transform` function-argument used for tree-iteration.
#
# This would be the correct type we would have to specify here but cannot,
# since the design of Python's static typing at the time of this writing
# is too broken for that:
#
#   type LeafTransformFunc[L, R] = Callable[[RevTuplePath, L], R]
#
# Instead, we have to settle for...:
LeafTransformFunc = Callable[[RevTuplePath, Any], Any]


## Type of the `tree_node_handler` function-argument used for
## tree-iteration and tree-mapping.
#
# Again, this is the correct type we would have to put here but cannot:
#
#   type NodeHandlerFunc[KT] = (
#     Callable[[NT],
#              None | tuple[Callable[[Iterable[tuple[KT, NT]]], NT],
#                           Iterator[tuple[KT, NT]]]])
#
# ...so, we have to instead settle for:
NodeHandlerFunc = (
    Callable[[RevTuplePath, NT],
             None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT],
                          Iterator[tuple[Any, NT]]]])


Predicate = Callable[[object], bool]


class InvalidTreeNodeError(ValueError):
  """Encountered a tree-node of invalid type."""


def linearize_revtuple_path(
    revtuple_path: RevTuplePath,
    present_as: Callable[[Iterator[T]], U] = tuple) -> U:
  """Translates a revtuple path to (typically) linear form.

  With default `present_as`, this will map a path of the form
  `(key_{N}, (key_{N-1}, ..., (root, ())))` into a tuple
  `(root, ..., key_{N-1}, key_{N})`.

  Args:
    revtuple_path: A linked-list-as-recursive-pairs
      reverse-order tuple-representation of the path.
      Path-root is `()`, and node-key `x` relative to
      earlier path `p` is represented as `(x, p)`.
    present_as: Callable that consumes an iterator over
      path-pieces - with the deepest-nesting level coming last -
      turning it into a linearized path. Defaults to `tuple`.

  Returns:
    Linearized presentation of all the node-keys in the
    recursive-path in order, deepest-down path component coming last.
  """
  pieces = []
  todo = revtuple_path
  while todo:
    node, todo = todo
    pieces.append(node)
  return present_as(reversed(pieces))


# This function itself has type `NodeHandlerFunc`, but Python does not
# allow us to here simply type-annotate it like this. We cannot even
# introduce an abbreviation for the complicated output-type,
# since that would have to be parametric in node-type `NT` (and `KT`).
def everything_is_a_leaf_node_handler(
    revtuple_path: tuple,
    node: NT) -> (
        None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT],
                     Iterator[tuple[Any, NT]]]):
  """Processes a tree-node as required by pytree-iteration and -mapping.

  Interface and signature are in alignment with the requirements for a
  "node handler" function explained in the module-docstring.

  Args:
    revtuple_path: the path-to-root for this node.
    node: a tree-node.

  Returns:
    `None`, i.e. classifying any kind of node as a leaf-node.
  """
  del revtuple_path, node  # Unused.
  return None


def leaf_summary(path: RevTuplePath, x: object) -> str:
  """Produces a human-readable summary-string for a leaf-node.

  Args:
    path: revtuple representation of the path-to-root.
    x: The leaf-value.

  Returns:
    A string of the form `{module}.{type}({abbreviated repr})`.
  """
  del path  # Ignored here.
  tx = type(x)
  mod = tx.__module__
  modname = mod if isinstance(mod, str) else mod.__name__
  type_str = f'{modname}.{tx.__qualname__}'
  repr_str = repr(x)
  repr_abbrev = repr_str if len(repr_str) < 40 else repr_str[:40] + ' ...'
  # On str, int, float, etc. `{type_str}(repr(x))` would actually still be
  # a (non-literal) Python-expression that would evaluate to the original value.
  # However, we make no promises beyond "human-readable".
  return f'{type_str}({repr_abbrev})'


# With respect to static type annotations, the limitations of Python's
# approach to static typing really become prominently visible here.
#
# Different arguments have type-parameters, but since there is no way
# to have parametric abbreviations such as `LeafTransformFunc[L, R]`,
# the only way we would have available to express relations between
# type-parameters would be to substitute in the not-abbreviated form of
# `NodeHandlerFunc` and `LeafTransformFunc`, giving us something monstrous.
# We instead here settle for "we cannot express that `tree` must
# have the same type as the input-type to `node_handler`" and use `Any`,
# and likewise for `leaf_transform` and the output.
def pytree_leaf_iter(
    tree: Any,
    leaf_transform: LeafTransformFunc,
    node_handler: NodeHandlerFunc = everything_is_a_leaf_node_handler,
    ) -> Iterator[Any]:
  # ...actual return type would be `Iterator[{what leaf_transform returns}]`.
  """Iterates over the leaves of a tree.

  Args:
    tree: The tree to iterate over.
    leaf_transform: A callable `f` that will get applied
      as `f(revtuple_path, leaf)`, where `revtuple_path`
      is the revtuple representation of the path to the
      leaf from the root.
    node_handler: A "node handler" (see module docstring)
      that processes nodes encountered during iterative traversal.

  Yields:
    Value of `leaf_transform(p, x)`, where `x` is the current leaf
    and `p` is its revtuple-path to the root.
  """
  # Note: Exit points for the code below are in non-obvious places
  # and hence marked with " # ***EXIT***".
  #
  # Doing iteration properly is slightly nontrivial.
  # One may be tempted to go for a very simple recursive implementation
  # (with an extra pre-final `path` argument to `pytree_leaf_iter`):
  #
  # maybe_substructure = node_handler(path, tree)
  # if maybe_substructure is None:
  #   # We are looking at a leaf-node.
  #   yield leaf_transform(path, tree)
  # else:
  #   _, contents_iter = maybe_substructure
  #   for k, v in contents_iter:
  #     yield from pytree_leaf_iter(
  #         v, leaf_transform, (k, path), node_handler)
  #
  # That, however, would be flawed, since there is no a priori reason
  # why a pytree may not be a very deeply nested structure - such as a
  # long linked list. That would then risk raising `RecursionError`,
  # and since Python by design(!) does not perform tail call elimination
  # or any other kind of advanced CPS transforms, there is no recursive
  # solution here. So, to do this properly, we have to do this iteratively.
  #
  # We are facing an annoying situation here: If `tree` itself is a leaf,
  # we have two options: (a) wrapping it up in a one-node tree
  # and processing that, or (b) special-casing "root is a leaf".
  # Option (b) leads to some mild node-processing code-duplication
  # for a single node (the root).
  # Option (a) requires having special cases for node-processing that
  # get looked at for every tree node. We go with option (b) here.
  maybe_substructure = node_handler((), tree)
  if maybe_substructure is None:
    # The tree itself is a leaf.
    yield leaf_transform((), tree)
    return  # ***EXIT***
  # Otherwise, we are looking at a tree.
  _, contents_iter = maybe_substructure
  current_revtuple_path = ()
  work_to_do = [contents_iter]
  # Otherwise-unreachable sentinel for reliably identifying
  # iterator-exhaustion without using exceptions:
  sentinel = object()
  while True:
    current_iter = work_to_do[-1]
    maybe_next_item = next(current_iter, sentinel)
    if maybe_next_item is sentinel:
      # We are done at this level.
      work_to_do.pop()
      if not work_to_do:
        return  # ***EXIT***
      current_revtuple_path = current_revtuple_path[1]
    else:
      path_piece, subtree = maybe_next_item
      extended_revtuple_path = (path_piece, current_revtuple_path)
      maybe_subtree_substructure = node_handler(extended_revtuple_path, subtree)
      if maybe_subtree_substructure is None:  # Case: subtree is a leaf.
        yield leaf_transform(extended_revtuple_path, subtree)
      else:  # Case: subtree is a tree.
        current_revtuple_path = (path_piece, current_revtuple_path)
        _, items_iter = maybe_subtree_substructure
        work_to_do.append(items_iter)


# The current design approach here would be appropriate for
# applying leaf-transforms while retaining the structure of the tree -
# which closely corresponds to e.g. a (a -> b) -> (Tree a -> Tree b) functor.
#
# It is not entirely clear whether this is the abstraction that we should
# consider as being appropriately generic to flesh out explicitly - rather
# than starting from a more general approach of which this then is a special
# case. Some background: https://ncatlab.org/nlab/show/recursion+scheme
#
# On the other hand, there is a lot of flexibility via whatever
# node-rebuilder a node-handler produces - this can do quite some reshaping
# of a tree, including dropping or duplicating nodes.
def pytree_map(
    tree: Any,
    leaf_transform: LeafTransformFunc,
    node_handler: NodeHandlerFunc = everything_is_a_leaf_node_handler,
    ):
  """Maps a (potentially nested) Python value to another such value.

  Args:
    tree: The Python-object to be mapped.
    leaf_transform: A callable `f` that will get applied
      as `f(revtuple_path, leaf)`, where `revtuple_path`
      is the revtuple representation of the path to the
      leaf from the root. Must be side effect free.
    node_handler: A "node handler" (see module docstring)
      that processes nodes encountered during iterative traversal.
      Must be side effect free.

  Returns:
    The outcome of translating `tree`.
  """
  # Note: Exit points for the code below are in non-obvious places
  # and hence marked with " # ***EXIT***".
  #
  # Otherwise-inaccessible sentinel object, for reliably identifying
  # missing-values via identity-check against sentinel lookup-default.
  sentinel = object()
  # Code structure mostly follows pytree_leaf_iter.
  maybe_substructure = node_handler((), tree)
  if maybe_substructure is None:
    return leaf_transform((), tree)  # ***EXIT***
  rebuilder, items_iter = maybe_substructure
  current_revtuple_path = ()
  # Per-level, we have a triplet of:
  # (rebuilder, remaining_items_to_iterate_over, processed).
  parts_for_assembly = [(rebuilder, items_iter, [])]
  while True:
    this_rebuilder, this_items_iter, this_done_pieces = parts_for_assembly[-1]
    maybe_next_item = next(this_items_iter, sentinel)
    if maybe_next_item is sentinel:
      # We are done with all the items for this level.
      parts_for_assembly.pop()
      built_node = this_rebuilder(this_done_pieces)
      if not parts_for_assembly:  # No outer structure, so at-top-level.
        return built_node  # ***EXIT***
      else:  # We have outer structure.
        parts_for_assembly[-1][-1].append(built_node)
        current_revtuple_path = current_revtuple_path[1]
        continue  # ...with next is-the-final-item-complete-check.
    else:
      # More constituents of the current item.
      path_piece, subtree_item = maybe_next_item
      extended_revtuple_path = (path_piece, current_revtuple_path)
      maybe_subtree_substructure = node_handler(
          extended_revtuple_path,
          subtree_item)
      if maybe_subtree_substructure is None:
        this_done_pieces.append(
            leaf_transform(extended_revtuple_path, subtree_item))
      else:
        # We have a subtree.
        subtree_rebuilder, subtree_items_iter = maybe_subtree_substructure
        current_revtuple_path = (path_piece,
                                 current_revtuple_path)
        parts_for_assembly.append(
            (subtree_rebuilder, subtree_items_iter, []))


def deep_freeze(
    tree,
    *,
    is_mapping: Predicate = lambda x: isinstance(x, collections.abc.Mapping),
    is_set: Predicate = lambda x: isinstance(x, collections.abc.Set),
    is_sequence: Predicate = lambda x: isinstance(x, (list, tuple)),
    leaf_fn: Callable[[Any], Any] = lambda x: x,
    ):
  """Recursively freezes Set/Mapping/List/Tuple structures.

  Args:
    tree: The potentially deeply-nested object to deep-freeze.
    is_mapping: Callable that decides whether a sub-object is a mapping.
      Defaults to an `isinstance()` check for `collections.abc.Mapping`.
    is_set: Callable that decides whether a sub-object is a set.
      Defaults to an `isinstance()` check for `collections.abc.Set`.
    is_sequence: Callable that decides whether a sub-object is a sequence.
      Defaults to a check for being a `tuple` or `list` instance.
    leaf_fn: Function to use for translating non-mapping/set/sequence
      instances.

  Returns:
    Translated-to-deeply-immutable form of `tree`.
  """
  idict = immutabledict.immutabledict
  def freeze_node_handler(path, x):
    del path  # Unused.
    if is_set(x):
      return frozenset, ((None, y) for y in x)
    if is_mapping(x):
      # Mapping keys are already hashable, hence
      # (should-be-)deeply-immutable.
      # Hence, we only need to deep-freeze the values.
      #
      # Note that non-`dict` mappings might not guarantee
      # to respect iteration-order, so we have to be careful here:
      items = list(x.items())
      keys = [kv[0] for kv in items]
      return ((lambda ys: idict(zip(keys, ys))),
              iter(items))
    if is_sequence(x):
      return tuple, enumerate(iter(x))
    # Otherwise, this should not be traversed.
    return None
  def leaf_transform(revtuple_path, value):
    del revtuple_path  # Unused.
    return leaf_fn(value)
  return pytree_map(tree, leaf_transform, freeze_node_handler)