mirror of https://github.com/google/gemma.cpp.git
93 lines
3.6 KiB
Python
93 lines
3.6 KiB
Python
"""Basic tests for 'algebraic data type based pytree' transformations."""
|
|
|
|
|
|
import io
|
|
import os
|
|
import tempfile
|
|
import unittest
|
|
|
|
import numpy
|
|
|
|
import ml_model_transforms
|
|
|
|
|
|
def _get_model(prefix):
|
|
return {
|
|
prefix + 'a1': numpy.arange(1000, 1024).reshape(6, 4).astype(numpy.float32),
|
|
prefix + 'a2': numpy.arange(2000, 2048).reshape(6, 8).astype(numpy.float32),
|
|
prefix + 'b1': {
|
|
prefix + 'c1': numpy.arange(100, 127).reshape(3, 3, 3).astype(numpy.int8),
|
|
prefix + 'c2': numpy.arange(100, 128).reshape(7, 4).astype(numpy.float64)
|
|
}}
|
|
|
|
|
|
class MLModeltransformsTest(unittest.TestCase):
|
|
"""Basic correctness validation tests for ML-model transformations."""
|
|
|
|
def test_ml_model_leaf_summary(self):
|
|
"""Tests guarantees given by `ml_model_leaf_summary`."""
|
|
summary = ml_model_transforms.ml_model_leaf_summary(
|
|
('a', ()),
|
|
numpy.arange(1000, 1024).reshape(6, 4).astype(numpy.int16),
|
|
sep='##')
|
|
self.assertIn('##', summary) # Separator is respected.
|
|
self.assertIn('(6, 4)', summary) # Shape is mentioned somewhere.
|
|
self.assertIn('int16', summary) # dtype is mentioned somewhere.
|
|
|
|
def test_revtuple_autovivify_from_linear(self):
|
|
"""Tests guarantees given by `revtuple_autovifify_from_linear`."""
|
|
with self.subTest(guarantee='empty'):
|
|
self.assertEqual(
|
|
ml_model_transforms.revtuple_autovifify_from_linear([]),
|
|
{})
|
|
with self.subTest(guarantee='generic'):
|
|
keys_vals = [(('a', 'b1', 'c1'), 1001),
|
|
(('a', 'b2'), 1002),
|
|
(('a2',), 1003),
|
|
]
|
|
self.assertEqual(
|
|
ml_model_transforms.revtuple_autovifify_from_linear(keys_vals),
|
|
{'a': {'b1': {'c1': 1001}, 'b2': 1002}, 'a2': 1003})
|
|
|
|
def test_model_overview(self):
|
|
"""Tests guarantees given by `model_overview`."""
|
|
model = _get_model('xyz')
|
|
out_io = io.StringIO()
|
|
ml_model_transforms.model_overview(model, out=out_io)
|
|
overview = out_io.getvalue()
|
|
self.assertIn('xyz', overview)
|
|
|
|
def test_model_contents(self):
|
|
"""Tests guarantees given by `model_contents`."""
|
|
model = _get_model('pq_')
|
|
contents = ml_model_transforms.model_contents(model)
|
|
fingerprints = {k: (a.shape, a.ravel()[:3].tolist())
|
|
for k, a in contents.items()}
|
|
self.assertEqual(fingerprints,
|
|
{('pq_a1',): ((6, 4), [1000.0, 1001.0, 1002.0]),
|
|
('pq_a2',): ((6, 8), [2000.0, 2001.0, 2002.0]),
|
|
('pq_b1', 'pq_c1'): ((3, 3, 3), [100, 101, 102]),
|
|
('pq_b1', 'pq_c2'): ((7, 4), [100.0, 101.0, 102.0])})
|
|
|
|
def test_model_save_load_basic(self):
|
|
"""Tests basic guarantees given by `model_save` and `model_load`."""
|
|
# What we care about here is that the round trip works - so
|
|
# it makes more sense to test saving and loading as one unit.
|
|
model_orig = _get_model('model_')
|
|
with tempfile.TemporaryDirectory() as tempdir:
|
|
filepath_stem = os.path.join(tempdir, 'the_model')
|
|
total_size, total_time = ml_model_transforms.model_save(model_orig,
|
|
filepath_stem)
|
|
self.assertGreater(total_size, 0)
|
|
self.assertGreater(total_time, 0)
|
|
model_reloaded = ml_model_transforms.model_load(filepath_stem)
|
|
contents_orig = ml_model_transforms.model_contents(model_orig)
|
|
contents_reloaded = ml_model_transforms.model_contents(model_reloaded)
|
|
self.assertEqual(
|
|
{k: v.tolist() for k, v in contents_orig.items()},
|
|
{k: v.tolist() for k, v in contents_reloaded.items()})
|
|
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|