mirror of https://github.com/google/gemma.cpp.git
169 lines
6.6 KiB
Python
169 lines
6.6 KiB
Python
"""Basic tests for 'algebraic data type based pytree' transformations."""
|
|
|
|
|
|
import collections.abc
|
|
import sys
|
|
import unittest
|
|
|
|
import pytree_transforms
|
|
|
|
|
|
def _get_deep_pytree(packaging_fn, bottom, depth):
|
|
current = bottom
|
|
for n in reversed(range(depth)):
|
|
current = packaging_fn(n, current)
|
|
return current
|
|
|
|
|
|
def _dict_node_handler(p, d):
|
|
del p # Unused.
|
|
if isinstance(d, dict):
|
|
keys = d.keys()
|
|
newdict = lambda vals: dict(zip(keys, vals))
|
|
return (newdict, iter(d.items()))
|
|
else:
|
|
return None
|
|
|
|
|
|
class PyTreeTest(unittest.TestCase):
  """Correctness checks for the transformations in `pytree_transforms`."""

  def test_linearize_revtuple_path(self):
    """Checks the documented behavior of `linearize_revtuple_path`."""
    linearize = pytree_transforms.linearize_revtuple_path
    # Reverse-linked path: the last path element is outermost.
    revtuple_path = (30, (20, (10, ())))
    with self.subTest(guarantee='empty'):
      self.assertEqual(linearize(()), ())
    with self.subTest(guarantee='typical'):
      self.assertEqual(linearize(revtuple_path), (10, 20, 30))
    with self.subTest(guarantee='present_as'):
      self.assertEqual(linearize(revtuple_path, present_as=list),
                       [10, 20, 30])

  def test_everything_is_a_leaf_node_handler(self):
    """Checks that `everything_is_a_leaf_node_handler` yields `None`."""
    handler = pytree_transforms.everything_is_a_leaf_node_handler
    # A string at the root and a dict below a key are both reported as leaves.
    self.assertEqual(handler((), 'abc'), None)
    self.assertEqual(handler(('b', ()), {'a': 3}), None)

  def test_leaf_summary(self):
    """Checks the loose guarantees of `leaf_summary`."""
    # Only "a human-readable presentation" is promised by the docstring of
    # `leaf_summary`, so the checks here are deliberately loose.
    numbers = (5678, 9531)
    summary = pytree_transforms.leaf_summary(('key', ()), numbers)
    self.assertIsInstance(summary, str)
    for number in numbers:
      self.assertIn(str(number), summary)

  def test_pytree_leaf_iter(self):
    """Checks the documented behavior of `pytree_leaf_iter`."""
    leaf_iter = pytree_transforms.pytree_leaf_iter
    linearize = pytree_transforms.linearize_revtuple_path

    def leaf_transform(path, leaf):
      # Leaves whose innermost key starts with 'R' are replaced by their repr.
      if path and path[0].startswith('R'):
        return repr(leaf)
      return leaf

    def collect(tree, transform=leaf_transform):
      return list(leaf_iter(tree, transform, _dict_node_handler))

    with self.subTest(guarantee='returns_iterator'):
      iterator = leaf_iter(7, leaf_transform, _dict_node_handler)
      self.assertIsInstance(iterator, collections.abc.Iterator)
    with self.subTest(guarantee='totally_empty'):
      self.assertEqual(collect({}), [])
    with self.subTest(guarantee='no_leaves'):
      self.assertEqual(collect({'a': {}}), [])
    with self.subTest(guarantee='is_leaf'):
      self.assertEqual(collect(777), [777])
    with self.subTest(guarantee='generic'):
      tree = {
          'n0': {
              'n01': {'n012': 1002, 'n013': 1003, 'Rn014': 1004},
              'n02': 1005,
          },
          'n5': 1006,
      }
      self.assertEqual(collect(tree), [1002, 1003, '1004', 1005, 1006])
    with self.subTest(guarantee='with_keys'):
      tree = {
          'n0': {'n01': {'n012': 1002, 'n013': 1003}},
          'n1': 1004,
      }

      def keyed_leaf(path, leaf):
        return (linearize(path), leaf)

      expected = [
          (('n0', 'n01', 'n012'), 1002),
          (('n0', 'n01', 'n013'), 1003),
          (('n1',), 1004),
      ]
      self.assertEqual(collect(tree, keyed_leaf), expected)

  def test_pytree_map(self):
    """Checks the documented behavior of `pytree_map`."""
    pytree_map = pytree_transforms.pytree_map

    def to_repr(path, leaf):
      del path  # Unused.
      return repr(leaf)

    sample_tree = {
        't0': {
            't10': 1001,
            't11': {'t110': 1002, 't111': 1003},
            't12': {'t120': 1004, 't121': 1005, 't122': 1006},
        },
        't1': 1007,
    }
    with self.subTest(guarantee='no_leaves'):
      mapped = pytree_map({'a': {}}, to_repr, _dict_node_handler)
      self.assertEqual(mapped, {'a': {}})
    with self.subTest(guarantee='is_leaf'):
      mapped = pytree_map(777, to_repr, _dict_node_handler)
      self.assertEqual(mapped, '777')
    with self.subTest(guarantee='generic'):
      mapped = pytree_map(sample_tree, to_repr, _dict_node_handler)
      self.assertEqual(mapped['t0']['t10'], '1001')

  def test_deeply_nested(self):
    """Checks that very deep trees are processed correctly."""
    leaf_iter = pytree_transforms.pytree_leaf_iter
    pytree_map = pytree_transforms.pytree_map
    # Nest deeper than the interpreter's recursion limit (and at least 10**5
    # levels), presumably so that a naively recursive traversal would fail.
    nesting = max(10**5, sys.getrecursionlimit() + 100)
    deep_tree = _get_deep_pytree(lambda n, t: {n: t}, 'leaf', nesting)
    with self.subTest(function='pytree_leaf_iter'):
      found = list(
          leaf_iter(deep_tree, lambda p, s: s.upper(), _dict_node_handler))
      self.assertEqual(found, ['LEAF'])
    with self.subTest(function='pytree_map'):
      identity_mapped = pytree_map(
          deep_tree, lambda p, s: s, _dict_node_handler)
      self.assertIsInstance(identity_mapped, dict)
    with self.subTest(function='combined'):
      capitalized = pytree_map(
          deep_tree, lambda p, s: s.capitalize(), _dict_node_handler)
      found = list(
          leaf_iter(capitalized, lambda p, s: s + s, _dict_node_handler))
      self.assertEqual(found, ['LeafLeaf'])

  def test_deep_freeze(self):
    """Checks the documented behavior of `deep_freeze`."""
    mutable_tree = {'a': [1001, 1002, {'b': (1003, [1004, {1005, 1006}])}]}
    frozen_tree = pytree_transforms.deep_freeze(mutable_tree)
    # The result is a read-only mapping whose list value became a tuple.
    self.assertIsInstance(frozen_tree, collections.abc.Mapping)
    self.assertNotIsInstance(frozen_tree, collections.abc.MutableMapping)
    self.assertIsInstance(frozen_tree['a'], tuple)
    # The frozen result must be hashable, with an integer hash.
    self.assertIsInstance(hash(frozen_tree), int)
|
|
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()
|