Merge 205cda5ca6 into 7cadbfce10
This commit is contained in:
commit
8c8b3f23ba
|
|
@ -82,6 +82,20 @@ class GGUFWriter:
|
|||
GGUFValueType.FLOAT64: "d",
|
||||
GGUFValueType.BOOL: "?",
|
||||
}
|
||||
# Value types that _is_valid_metadata_value treats as integer-valued:
# the eight UINT*/INT* members, listed as unsigned/signed pairs.
_integer_value_types = {
    GGUFValueType.UINT8, GGUFValueType.INT8,
    GGUFValueType.UINT16, GGUFValueType.INT16,
    GGUFValueType.UINT32, GGUFValueType.INT32,
    GGUFValueType.UINT64, GGUFValueType.INT64,
}
|
||||
# Value types that _is_valid_metadata_value treats as floating-point.
_float_value_types = {GGUFValueType.FLOAT32, GGUFValueType.FLOAT64}
|
||||
|
||||
def __init__(
|
||||
self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
|
||||
|
|
@ -278,8 +292,41 @@ class GGUFWriter:
|
|||
if any(key in kv_data for kv_data in self.kv_data):
|
||||
logger.warning(f'Duplicated key name {key!r}, overwriting it with new value {val!r} of type {vtype.name}')
|
||||
|
||||
if not self._is_valid_metadata_value(val, vtype, sub_type=sub_type):
|
||||
raise ValueError(
|
||||
f"Invalid GGUF metadata value for key {key!r}: declared {vtype.name}, got {type(val).__name__}"
|
||||
)
|
||||
|
||||
self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type)
|
||||
|
||||
@classmethod
def _is_valid_metadata_value(
    cls, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None
) -> bool:
    """Report whether ``val`` is an acceptable Python value for ``vtype``.

    Rules, by declared type:

    * STRING: ``str``, ``bytes`` or ``bytearray``.
    * BOOL: ``bool`` or ``np.bool_``.
    * Integer types (``cls._integer_value_types``): ``int`` or
      ``np.integer``, but never a boolean.
    * Float types (``cls._float_value_types``): ``float`` or
      ``np.floating``; a plain ``int`` is not accepted.
    * ARRAY: a non-empty ``Sequence``. When ``sub_type`` is given, every
      item must be valid for ``sub_type``. Otherwise ``bytes`` is accepted
      as-is, and any other sequence must have items that all map to the
      same ``GGUFValueType.get_type`` result.
    * Any other ``vtype``: accepted without inspection.

    NOTE(review): ``str`` is itself a ``Sequence`` of ``str``, so a plain
    string declared as ARRAY with ``sub_type=STRING`` passes item by item.
    Confirm that this is intended.

    Args:
        val: Candidate metadata value.
        vtype: Declared GGUF value type.
        sub_type: Declared element type; only consulted for ARRAY.

    Returns:
        ``True`` if the value matches the declared type, else ``False``.
    """
    # bool is a subclass of int, so booleans are detected once up front
    # and excluded explicitly from the integer branch below.
    is_boolean = isinstance(val, (bool, np.bool_))

    if vtype == GGUFValueType.STRING:
        return isinstance(val, (str, bytes, bytearray))

    if vtype == GGUFValueType.BOOL:
        return is_boolean

    if vtype == GGUFValueType.ARRAY:
        # An array has to be a sequence with at least one element.
        if not isinstance(val, Sequence) or len(val) == 0:
            return False

        if sub_type is not None:
            # Explicit element type: validate each item against it.
            for element in val:
                if not cls._is_valid_metadata_value(element, sub_type):
                    return False
            return True

        if isinstance(val, bytes):
            return True

        # No declared element type: all items must infer to one type.
        try:
            expected = GGUFValueType.get_type(val[0])
            for element in val[1:]:
                if GGUFValueType.get_type(element) is not expected:
                    return False
        except ValueError:
            # get_type raised for an item, so the array is not valid.
            return False
        return True

    if vtype in cls._integer_value_types:
        return not is_boolean and isinstance(val, (int, np.integer))

    if vtype in cls._float_value_types:
        return isinstance(val, (float, np.floating))

    # Declared types without a rule above are not checked here.
    return True
|
||||
|
||||
def add_uint8(self, key: str, val: int) -> None:
    """Add ``val`` under ``key``, declared as ``GGUFValueType.UINT8``.

    Delegates to ``add_key_value``.
    """
    declared_type = GGUFValueType.UINT8
    self.add_key_value(key, val, declared_type)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,104 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Necessary to load the local gguf package: unless NO_LOCAL_GGUF is set in
# the environment, and only when a "gguf-py" directory exists three levels
# above this file, put this file's grandparent directory first on sys.path.
if "NO_LOCAL_GGUF" not in os.environ:
    if (Path(__file__).parent.parent.parent / "gguf-py").exists():
        sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
import gguf
|
||||
|
||||
|
||||
class TestGGUFWriter(unittest.TestCase):
    """Type-validation tests for ``GGUFWriter.add_key_value`` and ``add_uint32``.

    Each test gets a fresh writer from ``setUp``. The tests only call the
    ``add_*`` methods and never ask the writer to write its output.
    """

    def setUp(self):
        # One new writer per test so that keys added by one test never
        # leak into another.
        self.writer = gguf.GGUFWriter("/tmp/test.gguf", "llama")

    def _assert_rejected(self, pattern, *args, **kwargs):
        """Assert that ``add_key_value(*args, **kwargs)`` raises a ValueError matching ``pattern``."""
        with self.assertRaisesRegex(ValueError, pattern):
            self.writer.add_key_value(*args, **kwargs)

    # --- STRING / ARRAY -------------------------------------------------

    def test_add_key_value_rejects_string_declared_type_with_list_value(self):
        self._assert_rejected(
            "declared STRING",
            "general.license", ["apache-2.0"], gguf.GGUFValueType.STRING,
        )

    def test_add_key_value_rejects_array_with_invalid_declared_sub_type_item(self):
        self._assert_rejected(
            "declared ARRAY",
            "general.tags", [1], gguf.GGUFValueType.ARRAY,
            sub_type=gguf.GGUFValueType.STRING,
        )

    def test_add_key_value_rejects_mixed_type_array_value(self):
        self._assert_rejected(
            "declared ARRAY",
            "general.tags", [1, "apache-2.0"], gguf.GGUFValueType.ARRAY,
        )

    def test_add_key_value_rejects_empty_array_value(self):
        self._assert_rejected(
            "declared ARRAY",
            "general.tags", [], gguf.GGUFValueType.ARRAY,
        )

    def test_add_key_value_accepts_string_array_with_declared_string_sub_type(self):
        self.writer.add_key_value(
            "general.tags", ["apache-2.0"], gguf.GGUFValueType.ARRAY,
            sub_type=gguf.GGUFValueType.STRING,
        )

    def test_add_key_value_accepts_bytes_array_value(self):
        self.writer.add_key_value(
            "tokenizer.ggml.precompiled_charsmap", b"\x01\x02", gguf.GGUFValueType.ARRAY,
        )

    def test_add_key_value_accepts_homogeneous_integer_array_value(self):
        self.writer.add_key_value("test.array", [1, 2], gguf.GGUFValueType.ARRAY)

    # --- integers ---------------------------------------------------------

    def test_add_uint32_still_accepts_python_int(self):
        self.writer.add_uint32("answer", 42)

    def test_add_key_value_accepts_numpy_uint32_for_declared_uint32(self):
        self.writer.add_key_value("answer", np.uint32(42), gguf.GGUFValueType.UINT32)

    def test_add_key_value_rejects_bool_for_declared_uint32(self):
        self._assert_rejected(
            "declared UINT32",
            "answer", True, gguf.GGUFValueType.UINT32,
        )

    def test_add_key_value_rejects_numpy_bool_for_declared_uint32(self):
        self._assert_rejected(
            "declared UINT32",
            "answer", np.bool_(True), gguf.GGUFValueType.UINT32,
        )

    # --- floats -----------------------------------------------------------

    def test_add_key_value_accepts_numpy_float32_for_declared_float32(self):
        self.writer.add_key_value("temperature", np.float32(1.5), gguf.GGUFValueType.FLOAT32)

    def test_add_key_value_rejects_python_int_for_declared_float32(self):
        self._assert_rejected(
            "declared FLOAT32",
            "temperature", 1, gguf.GGUFValueType.FLOAT32,
        )

    def test_add_key_value_rejects_python_int_for_declared_float64(self):
        self._assert_rejected(
            "declared FLOAT64",
            "temperature", 1, gguf.GGUFValueType.FLOAT64,
        )

    # --- booleans ---------------------------------------------------------

    def test_add_key_value_accepts_numpy_bool_for_declared_bool(self):
        self.writer.add_key_value("flag", np.bool_(True), gguf.GGUFValueType.BOOL)
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue