diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 9ee3ac9e8f..55dda69fff 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -82,6 +82,20 @@ class GGUFWriter: GGUFValueType.FLOAT64: "d", GGUFValueType.BOOL: "?", } + _integer_value_types = { + GGUFValueType.UINT8, + GGUFValueType.INT8, + GGUFValueType.UINT16, + GGUFValueType.INT16, + GGUFValueType.UINT32, + GGUFValueType.INT32, + GGUFValueType.UINT64, + GGUFValueType.INT64, + } + _float_value_types = { + GGUFValueType.FLOAT32, + GGUFValueType.FLOAT64, + } def __init__( self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE, @@ -275,8 +289,41 @@ class GGUFWriter: if any(key in kv_data for kv_data in self.kv_data): logger.warning(f'Duplicated key name {key!r}, overwriting it with new value {val!r} of type {vtype.name}') + if not self._is_valid_metadata_value(val, vtype, sub_type=sub_type): + raise ValueError( + f"Invalid GGUF metadata value for key {key!r}: declared {vtype.name}, got {type(val).__name__}" + ) + self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type) + @classmethod + def _is_valid_metadata_value( + cls, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None + ) -> bool: + if vtype == GGUFValueType.STRING: + return isinstance(val, (str, bytes, bytearray)) + if vtype == GGUFValueType.ARRAY: + if not isinstance(val, Sequence): + return False + if len(val) == 0: + return False + if sub_type is not None: + return all(cls._is_valid_metadata_value(item, sub_type) for item in val) + if isinstance(val, bytes): + return True + try: + item_type = GGUFValueType.get_type(val[0]) + return all(GGUFValueType.get_type(item) is item_type for item in val[1:]) + except ValueError: + return False + if vtype == GGUFValueType.BOOL: + return isinstance(val, (bool, np.bool_)) + if vtype in cls._integer_value_types: + return isinstance(val, (int, np.integer)) and not isinstance(val, (bool, np.bool_)) + if vtype in cls._float_value_types: + return isinstance(val, (float, np.floating)) + return True + def add_uint8(self, key: str, val: int) -> None: self.add_key_value(key,val, GGUFValueType.UINT8) diff --git a/gguf-py/tests/test_gguf_writer.py b/gguf-py/tests/test_gguf_writer.py new file mode 100644 index 0000000000..5911da29ba --- /dev/null +++ b/gguf-py/tests/test_gguf_writer.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 + +import unittest +from pathlib import Path +import os +import sys + +import numpy as np + +# Necessary to load the local gguf package +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / "gguf-py").exists(): + sys.path.insert(0, str(Path(__file__).parent.parent)) + +import gguf + + +class TestGGUFWriter(unittest.TestCase): + + def test_add_key_value_rejects_string_declared_type_with_list_value(self): + writer = gguf.GGUFWriter("/tmp/test.gguf", "llama") + + with self.assertRaisesRegex(ValueError, "declared STRING"): + writer.add_key_value("general.license", ["apache-2.0"], gguf.GGUFValueType.STRING) + + def test_add_key_value_rejects_array_with_invalid_declared_sub_type_item(self): + writer = gguf.GGUFWriter("/tmp/test.gguf", "llama") + + with self.assertRaisesRegex(ValueError, "declared ARRAY"): + writer.add_key_value("general.tags", [1], gguf.GGUFValueType.ARRAY, sub_type=gguf.GGUFValueType.STRING) + + def test_add_key_value_rejects_mixed_type_array_value(self): + writer = gguf.GGUFWriter("/tmp/test.gguf", "llama") + + with self.assertRaisesRegex(ValueError, "declared ARRAY"): + writer.add_key_value("general.tags", [1, "apache-2.0"], gguf.GGUFValueType.ARRAY) + + def test_add_key_value_rejects_empty_array_value(self): + writer = gguf.GGUFWriter("/tmp/test.gguf", "llama") + + with self.assertRaisesRegex(ValueError, "declared ARRAY"): + writer.add_key_value("general.tags", [], gguf.GGUFValueType.ARRAY) + + def test_add_key_value_accepts_string_array_with_declared_string_sub_type(self): + writer = gguf.GGUFWriter("/tmp/test.gguf", "llama") + + writer.add_key_value("general.tags", ["apache-2.0"], gguf.GGUFValueType.ARRAY, sub_type=gguf.GGUFValueType.STRING) + + def test_add_key_value_accepts_bytes_array_value(self): + writer = gguf.GGUFWriter("/tmp/test.gguf", "llama") + + writer.add_key_value("tokenizer.ggml.precompiled_charsmap", b"\x01\x02", gguf.GGUFValueType.ARRAY) + + def test_add_key_value_accepts_homogeneous_integer_array_value(self): + writer = gguf.GGUFWriter("/tmp/test.gguf", "llama") + + writer.add_key_value("test.array", [1, 2], gguf.GGUFValueType.ARRAY) + + def test_add_uint32_still_accepts_python_int(self): + writer = gguf.GGUFWriter("/tmp/test.gguf", "llama") + + writer.add_uint32("answer", 42) + + def test_add_key_value_accepts_numpy_uint32_for_declared_uint32(self): + writer = gguf.GGUFWriter("/tmp/test.gguf", "llama") + + writer.add_key_value("answer", np.uint32(42), gguf.GGUFValueType.UINT32) + + def test_add_key_value_rejects_bool_for_declared_uint32(self): + writer = gguf.GGUFWriter("/tmp/test.gguf", "llama") + + with self.assertRaisesRegex(ValueError, "declared UINT32"): + writer.add_key_value("answer", True, gguf.GGUFValueType.UINT32) + + def test_add_key_value_rejects_numpy_bool_for_declared_uint32(self): + writer = gguf.GGUFWriter("/tmp/test.gguf", "llama") + + with self.assertRaisesRegex(ValueError, "declared UINT32"): + writer.add_key_value("answer", np.bool_(True), gguf.GGUFValueType.UINT32) + + def test_add_key_value_accepts_numpy_float32_for_declared_float32(self): + writer = gguf.GGUFWriter("/tmp/test.gguf", "llama") + + writer.add_key_value("temperature", np.float32(1.5), gguf.GGUFValueType.FLOAT32) + + def test_add_key_value_rejects_python_int_for_declared_float32(self): + writer = gguf.GGUFWriter("/tmp/test.gguf", "llama") + + with self.assertRaisesRegex(ValueError, "declared FLOAT32"): + writer.add_key_value("temperature", 1, gguf.GGUFValueType.FLOAT32) + + def test_add_key_value_rejects_python_int_for_declared_float64(self): + writer = gguf.GGUFWriter("/tmp/test.gguf", "llama") + + with self.assertRaisesRegex(ValueError, "declared FLOAT64"): + writer.add_key_value("temperature", 1, gguf.GGUFValueType.FLOAT64) + + def test_add_key_value_accepts_numpy_bool_for_declared_bool(self): + writer = gguf.GGUFWriter("/tmp/test.gguf", "llama") + + writer.add_key_value("flag", np.bool_(True), gguf.GGUFValueType.BOOL) + + +if __name__ == "__main__": + unittest.main()