"""Transformations for python-trees representing the parameters of a ML model. Important: This module assumes that byte-order is the same on the machine that serializes data and the machine that deserializes data. If, for example, numpy-data gets dumped, respectively loaded, with a dtype-specification of numpy.float32, on-file byte-order will be host byte order. """ import ast import hashlib import itertools import pprint import sys import time from typing import Any, Callable, Iterable, Iterator, Mapping, TypeVar import numpy import pytree_transforms NT = TypeVar('NT') def ml_model_leaf_summary(path, x, sep=', '): """Produces a textual summary of a leaf-node and its path. Args: path: The path-to-root, as a reverse-order recursive pair of path-components, with `()` as root. x: The leaf-object. sep: the separator between description-elements. Default ', ' allows for convenient line-by-line processing (such as via grep, perl -ne, etc.), but using e.g. sep=',\n ' might be more useful for human consumption. Returns: A human-readable string providing information about the node. """ # Using `repr` for path-components to get a faithful presentation. # (...which still however would be somewat painful to correctly # split into components.) path_str = ','.join(map(repr, pytree_transforms.linearize_revtuple_path(path))) tx = type(x) mod = tx.__module__ # Either a module or a string like 'builtins'. modname = mod if isinstance(mod, str) else mod.__name__ type_str = f'{modname}.{tx.__qualname__}' try: # `numpy.ndarray` instances have a `.data` property that gives access # to a buffer via which we can hashlib-fingerprint the numerical # contents. We here simply try to produce a fingerprint and also look # up the .dtype of the object. Technically, there is a somewhat-unsound # assumption here that if these operations succeed, we are indeed looking # at a ndarray or sufficiently similar object for these operations to # make sense. As the output is declared "for human consumption", this # fishiness is not a problem. fp = hashlib.sha256(x.data).hexdigest() start = list(itertools.islice(x.flat, 5)) stats_str = ( f'min={numpy.min(x):.6g}, max={numpy.max(x):.6g}, ' f'mean={numpy.mean(x):.6g}, std={numpy.std(x):.6g}') return (f'{path_str:60s}: <{type_str}{sep}' f'fp=0x{fp}{sep}{stats_str}{sep}shape={x.shape}, ' f'dtype={x.dtype}{sep}start={start}>') except (AttributeError, ValueError, TypeError): # Fallback - trying to include information about the data-content # of a likely-numerical-array failed. return f'{path_str:60s}: {type_str}({repr(x)})' # A specialized node-handler. # Interface follows node-handler expectations defined in pytree_transforms. def _ml_model_tree_node_handler(path: tuple, node : NT) -> ( None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT], Iterator[tuple[Any, NT]]]): """Processes a tree-node as required by pytree-iteration and -mapping. Args: path: revtuple path to the current node. node: a tree-node in a ML-model tree that is recursively built out of `numpy.ndarray` leaf-values and dicts mapping node-name string-keys to other such nodes representing subtrees - and nothing else. Returns: `None` if the tree-node is to be regarded as a leaf, otherwise a pair `(rebuilder, iterator)`, where `iterator` iterates over the data-content of the node, each item represented as a pair of `(lookup_path_item, value_item)`, and `rebuilder` is a function which, when applied to `iterator` or any iterable with the same elements, returns a node that is equivalent to the original. 


# A specialized node-handler.
# Interface follows node-handler expectations defined in pytree_transforms.
def _ml_model_tree_node_handler(path: tuple, node: NT) -> (
    None | tuple[Callable[[Iterable[tuple[Any, NT]]], NT],
                 Iterator[tuple[Any, NT]]]):
  """Processes a tree-node as required by pytree-iteration and -mapping.

  Args:
    path: revtuple path to the current node.
    node: a tree-node in a ML-model tree that is recursively built out of
      `numpy.ndarray` leaf-values and dicts mapping node-name string-keys
      to other such nodes representing subtrees - and nothing else.

  Returns:
    `None` if the tree-node is to be regarded as a leaf, otherwise a pair
    `(rebuilder, iterator)`, where `iterator` iterates over the data-content
    of the node, each item represented as a pair of
    `(lookup_path_item, value_item)`, and `rebuilder` is a function which,
    when applied to `iterator` or any iterable with the same elements,
    returns a node that is equivalent to the original.

  Raises:
    NotAMLModelTreeNodeError: If the tree contains a node that is neither
      a `dict` nor a `numpy.ndarray` instance.
  """
  # The astute reader will notice that we are doing something fishy
  # here - this code could not be translated to Haskell as-is, since
  # `NT` cannot actually be a proper type-variable in the sense
  # of parametric polymorphism.
  del path  # Unused.
  if isinstance(node, dict):
    return dict, iter(node.items())
  if isinstance(node, numpy.ndarray):
    return None
  raise pytree_transforms.NotAMLModelTreeNodeError(
      f'Type of bad node: {type(node)}')


def _ml_model_extract_leaf_transform(
    path: pytree_transforms.RevTuplePath,
    leaf: Any):
  """Maps an array-leaf to a pair `(full_path, lambda: array)`.

  The computation that produces the leaf-value is lazified underneath a
  `lambda`: if we e.g. performed a memory-expensive transformation (such
  as some dtype-changes) directly at this point, then going from an
  iterator over tree-items for one-by-one consumption to a list of these
  items would have all the dtype-transformed values around simultaneously.
  We want to avoid situations where we can do nothing about having
  multiple variants of the data in memory at the same time.
  """
  # Hack: If we are encountering a `bfloat16` numpy-array,
  # we pretend to have the data as a numpy.float32 array,
  # since that's about all that contemporary CPUs can process
  # efficiently here.
  linearized_path = pytree_transforms.linearize_revtuple_path(path)
  try:
    # We have to use some trickery to detect `bfloat16`, which shows up
    # in the dtype-descr as an anonymous 2-byte field.
    if leaf.dtype.descr[-1] == ('', '<V2'):
      return linearized_path, lambda: leaf.astype(numpy.float32)
  except (AttributeError, TypeError, ValueError):
    pass
  return linearized_path, lambda: leaf


def revtuple_autovivify_from_linear(keys_and_vals) -> Any:
  """Performs perl-style autovivification on a nested-dict tree.

  Args:
    keys_and_vals: An iterable of pairs `(key_path, value)`, where
      `key_path` is a sequence of keys to be used to navigate to the
      result via iterative dict-lookup, left-to-right. Must not have
      duplicate keys, and must not have more than one key if an
      empty-sequence key is present. If this iterable is an iterator,
      it will be fully exhausted on successful execution.

  Returns:
    An object representing a nested-dict structure such that for every
    `key_path` from `keys_and_vals`, recursive-dict-lookup on the elements
    of that path starting from this object will produce the corresponding
    value. An empty `keys_and_vals` will produce `{}`. Every dict in the
    nested return-value that has been populated by autovivification is
    newly allocated.
  """
  # Code structure is a bit gnarly here due to f(keys_and_vals=[((), x)])
  # having to evaluate to x and not a dict.
  # There may be ways to prettify/simplify this.
  result = None
  empty = {}
  for linear_path, val in keys_and_vals:
    if linear_path == ():
      if result is not None:
        raise ValueError('Root-value seen alongside other values.')
      result = val
    else:
      if result is None:
        result = {}
      elif type(result) is not dict:
        # We already did encounter a root-value.
        raise ValueError('Root-value seen alongside other values.')
      cursor = result
      for n in range(len(linear_path) - 1):
        cursor = cursor.setdefault(linear_path[n], empty)
        if cursor is empty:
          # Regenerate `empty` if we just used it up.
          empty = {}
      cursor[linear_path[-1]] = val
  return {} if result is None else result


def model_overview(tree, out=None) -> None:
  """Prints a human-readable overview to `(out or sys.stdout)`."""
  actual_out = out or sys.stdout
  for line in pytree_transforms.pytree_leaf_iter(
      tree, ml_model_leaf_summary, _ml_model_tree_node_handler):
    print(line, file=actual_out)
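

# A small usage sketch (illustrative, not part of the API): building a
# parameter-tree from linearized (key_path, value) pairs via
# `revtuple_autovivify_from_linear`, then printing one summary line per
# leaf with `model_overview`. The layer/parameter names are placeholders.
def _demo_autovivify_and_overview():
  tree = revtuple_autovivify_from_linear(
      [(('dense0', 'kernel'), numpy.zeros((4, 4), dtype=numpy.float32)),
       (('dense0', 'bias'), numpy.zeros((4,), dtype=numpy.float32))])
  # `tree` is now {'dense0': {'kernel': <4x4 array>, 'bias': <4-array>}}.
  model_overview(tree)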


def model_contents(tree) -> Mapping[tuple[str, ...], Any]:
  """Maps a model to a {pytree_keys: data_array} mapping.

  Args:
    tree: The ML-model parameter-tree, built recursively out of
      dict-instances with numpy.ndarray instances as leaves.

  Returns:
    A mapping from linearized pytree-key-sequence tuple to the
    corresponding leaf-value.
  """
  def leaf_transform(revtuple_path, leaf):
    return pytree_transforms.linearize_revtuple_path(revtuple_path), leaf
  return dict(
      pytree_transforms.pytree_leaf_iter(
          tree, leaf_transform, _ml_model_tree_node_handler))


def _fn_identity(x):
  return x


def model_save(tree,
               filepath_stem: str,
               data_suffix: str = '.data',
               manifest_suffix: str | None = '.manifest',
               key: Callable[[tuple[str, ...]], Any] | None = None,
               array_transform_by_pytree_key: (
                   Mapping[tuple[str, ...],
                           Callable[[numpy.ndarray], numpy.ndarray]]
                   | None) = None,
               report: Callable[[str], None] | None = None,
               byte_align: int = 8) -> tuple[int, float]:
  """Saves the content of a ML-model parameter-tree to the filesystem.

  After successful execution, the file f"{filepath_stem}{data_suffix}"
  will hold the combined numerical model-parameters, and (unless
  `manifest_suffix` is `None`) f"{filepath_stem}{manifest_suffix}" will
  contain the key for interpreting (and rebuilding) the data.

  Args:
    tree: The ML-model parameter-tree, built recursively out of
      dict-instances with numpy.ndarray instances as leaves.
    filepath_stem: Filesystem location for data.
    data_suffix: Suffix to use for the data file.
    manifest_suffix: Either `None`, in which case no manifest-file will
      get written, or the suffix for the manifest-file.
    key: `None` or a key-function that will be applied to the linear
      model-path and used for sorting the data arrays by increasing
      value of the key-function. If the key-function returns `None` on
      an item, then this item is not included.
    array_transform_by_pytree_key: Optional mapping from pytree-key to an
      array-to-array transformation function to apply to the array prior
      to serialization.
    report: Optional callable for logging progress-reports.
    byte_align: Byte-alignment to use for numerical array data. Numerical
      arrays whose size in bytes is not a multiple of this will get padded
      to the next full multiple.

  Returns:
    A pair of `(size, time_sec)`, where `size` is the total byte-size of
    the data file and `time_sec` is the elapsed time for saving the
    model, in seconds.
  """
  time0 = time.monotonic()
  if array_transform_by_pytree_key is None:
    array_transform_by_pytree_key = {}
  model_lazy_items = (
      pytree_transforms.pytree_leaf_iter(
          tree, _ml_model_extract_leaf_transform,
          _ml_model_tree_node_handler))
  if key is not None:
    to_write = [
        nkv[1:] for nkv in sorted(
            (nkv for nkv in ((key(path), path, v)
                             for path, v in model_lazy_items)
             if nkv[0] is not None),
            key=lambda nkv: nkv[0])]
  else:
    to_write = list(model_lazy_items)
  #
  def lazy_arr_path_shape_dtype_size(path_and_lazy_arr):
    path, lazy_arr = path_and_lazy_arr
    arr = array_transform_by_pytree_key.get(path, _fn_identity)(lazy_arr())
    return path, arr.shape, arr.dtype, arr.data.nbytes
  arrs_path_shape_dtype_nbytes = list(
      map(lazy_arr_path_shape_dtype_size, to_write))
  # We need to know the total size of all the data.
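  # Worked example for the padding arithmetic below, with byte_align=8:
  # a 10-byte array gets padded to 16 bytes, since
  #   -(-10 // 8 * 8) == -((-10 // 8) * 8) == -(-2 * 8) == 16.
  # In general, `-(-n // a * a)` rounds `n` up to the next multiple of
  # `a`: floor-dividing the negated value implements ceiling-division.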
  bytesizes = [nbytes for *_, nbytes in arrs_path_shape_dtype_nbytes]
  padded_bytesizes = [-(-bytesize // byte_align * byte_align)
                      for bytesize in bytesizes]
  offsets = numpy.cumsum([0] + padded_bytesizes)
  membuf = numpy.memmap(filepath_stem + data_suffix, mode='w+',
                        shape=offsets[-1])
  try:
    for (path, shape, dtype, nbytes), offset, (_, lazy_arr) in zip(
        arrs_path_shape_dtype_nbytes, offsets, to_write):
      # Note that if getting the array from the lazy lambda involved some
      # computation, such as a copying dtype-change, that computation ends
      # up being done multiple times - once above, to compute byte-sizes,
      # and once more here.
      transformed_arr = array_transform_by_pytree_key.get(
          path, _fn_identity)(lazy_arr())
      membuf[offset : offset + nbytes] = numpy.frombuffer(
          transformed_arr.ravel().data, 'u1')
      if report is not None:
        samples = ', '.join(map(str, transformed_arr.ravel()[:5]))
        report(f'# Adding: {path!r}\n  bytes: {nbytes:10d}, '
               f'shape: {shape!r:30},\n  start: [{samples}, ...]')
      transformed_arr = None  # Drop references to numerical arrays ASAP.
  finally:
    if membuf is not None:
      membuf.flush()
      # NumPy wart: the memory-buffer is a resource that conceptually
      # should be .close()able - since mmap()ing holds on to a
      # file descriptor. However, it looks as if that clean-up were done
      # in the "finalizer", despite that having meanwhile been widely
      # understood as dubious practice. So, the best we can do here is
      # to explicitly and clearly remove our reference to the instance.
      del membuf
  if manifest_suffix is not None:
    # We still have to serialize the data that allows us to reconstruct
    # a tree that is equivalent to the original.
    manifest_data = [
        # Plain-int offsets keep the pprint-ed manifest parseable via
        # `ast.literal_eval`.
        dict(path=path, dtype=dtype.descr[-1][-1], shape=shape,
             nbytes=nbytes, offset=int(offset))
        for (path, shape, dtype, nbytes), offset in zip(
            arrs_path_shape_dtype_nbytes, offsets)]
    with open(filepath_stem + manifest_suffix, 'wt') as h_manifest:
      pprint.pprint(manifest_data, stream=h_manifest)
  time_taken = time.monotonic() - time0
  return int(offsets[-1]), time_taken


def model_load(filepath_stem, mmapped=True):
  """Loads a model saved by `model_save`.

  Tries to load the model from f"{filepath_stem}.data" and
  f"{filepath_stem}.manifest".

  Args:
    filepath_stem: The model location on the filesystem.
    mmapped: Whether data-arrays will be slices of a `numpy.memmap`
      mapped buffer, to be paged in on demand only, or in-memory copies
      of the data.

  Returns:
    A dict/numpy.ndarray tree representation of the model, equivalent
    to the original model.
  """
  with open(filepath_stem + '.manifest', 'rt') as h_manifest:
    manifest = ast.literal_eval(h_manifest.read())
  membuf = numpy.memmap(filepath_stem + '.data', mode='r+')
  paths_and_arrays = []
  for item in manifest:
    path = item['path']
    dtype = numpy.dtype(item['dtype'])
    shape = item['shape']
    nbytes = item['nbytes']
    offset = item['offset']
    data_array = numpy.frombuffer(
        membuf[offset : offset + nbytes].data, dtype=dtype).reshape(shape)
    paths_and_arrays.append(
        (path, data_array if mmapped else data_array.copy()))
  # At this point, the memory-buffer is no longer needed. Still, if the
  # data-arrays retain references to the underlying data (i.e. when
  # mmapped=True), this should keep the mapping - and hence file
  # descriptor - open. We then are in the somewhat undesirable situation
  # that clean-up of the resource, releasing the file descriptor, happens
  # in a hard-to-predict way.
  del membuf
  return revtuple_autovivify_from_linear(paths_and_arrays)
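

# A hedged end-to-end sketch of the save/load round-trip (illustrative,
# not part of the API). The '/tmp/demo_model' path is a placeholder, and
# the equality check assumes that no lossy `array_transform_by_pytree_key`
# transformation was applied when saving.
def _demo_save_load_roundtrip():
  tree = {'dense0': {'kernel': numpy.ones((4, 4), dtype=numpy.float32),
                     'bias': numpy.zeros((4,), dtype=numpy.float32)}}
  size, time_sec = model_save(tree, '/tmp/demo_model', report=print)
  print(f'Wrote {size} bytes in {time_sec:.3f} sec.')
  restored = model_load('/tmp/demo_model', mmapped=False)
  assert numpy.array_equal(restored['dense0']['kernel'],
                           tree['dense0']['kernel'])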