gguf-py: validate metadata values against declared types

This commit is contained in:
Eyüp Can Akman 2026-03-08 19:18:26 +03:00
parent d088d5b74f
commit 205cda5ca6
2 changed files with 151 additions and 0 deletions

View File

@ -82,6 +82,20 @@ class GGUFWriter:
GGUFValueType.FLOAT64: "d",
GGUFValueType.BOOL: "?",
}
_integer_value_types = {
GGUFValueType.UINT8,
GGUFValueType.INT8,
GGUFValueType.UINT16,
GGUFValueType.INT16,
GGUFValueType.UINT32,
GGUFValueType.INT32,
GGUFValueType.UINT64,
GGUFValueType.INT64,
}
_float_value_types = {
GGUFValueType.FLOAT32,
GGUFValueType.FLOAT64,
}
def __init__(
self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
@ -275,8 +289,41 @@ class GGUFWriter:
if any(key in kv_data for kv_data in self.kv_data):
logger.warning(f'Duplicated key name {key!r}, overwriting it with new value {val!r} of type {vtype.name}')
if not self._is_valid_metadata_value(val, vtype, sub_type=sub_type):
raise ValueError(
f"Invalid GGUF metadata value for key {key!r}: declared {vtype.name}, got {type(val).__name__}"
)
self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type)
@classmethod
def _is_valid_metadata_value(
cls, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None
) -> bool:
if vtype == GGUFValueType.STRING:
return isinstance(val, (str, bytes, bytearray))
if vtype == GGUFValueType.ARRAY:
if not isinstance(val, Sequence):
return False
if len(val) == 0:
return False
if sub_type is not None:
return all(cls._is_valid_metadata_value(item, sub_type) for item in val)
if isinstance(val, bytes):
return True
try:
item_type = GGUFValueType.get_type(val[0])
return all(GGUFValueType.get_type(item) is item_type for item in val[1:])
except ValueError:
return False
if vtype == GGUFValueType.BOOL:
return isinstance(val, (bool, np.bool_))
if vtype in cls._integer_value_types:
return isinstance(val, (int, np.integer)) and not isinstance(val, (bool, np.bool_))
if vtype in cls._float_value_types:
return isinstance(val, (float, np.floating))
return True
def add_uint8(self, key: str, val: int) -> None:
self.add_key_value(key,val, GGUFValueType.UINT8)

View File

@ -0,0 +1,104 @@
#!/usr/bin/env python3
import unittest
from pathlib import Path
import os
import sys
import numpy as np
# Necessary to load the local gguf package
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / "gguf-py").exists():
sys.path.insert(0, str(Path(__file__).parent.parent))
import gguf
class TestGGUFWriter(unittest.TestCase):
def test_add_key_value_rejects_string_declared_type_with_list_value(self):
writer = gguf.GGUFWriter("/tmp/test.gguf", "llama")
with self.assertRaisesRegex(ValueError, "declared STRING"):
writer.add_key_value("general.license", ["apache-2.0"], gguf.GGUFValueType.STRING)
def test_add_key_value_rejects_array_with_invalid_declared_sub_type_item(self):
writer = gguf.GGUFWriter("/tmp/test.gguf", "llama")
with self.assertRaisesRegex(ValueError, "declared ARRAY"):
writer.add_key_value("general.tags", [1], gguf.GGUFValueType.ARRAY, sub_type=gguf.GGUFValueType.STRING)
def test_add_key_value_rejects_mixed_type_array_value(self):
writer = gguf.GGUFWriter("/tmp/test.gguf", "llama")
with self.assertRaisesRegex(ValueError, "declared ARRAY"):
writer.add_key_value("general.tags", [1, "apache-2.0"], gguf.GGUFValueType.ARRAY)
def test_add_key_value_rejects_empty_array_value(self):
writer = gguf.GGUFWriter("/tmp/test.gguf", "llama")
with self.assertRaisesRegex(ValueError, "declared ARRAY"):
writer.add_key_value("general.tags", [], gguf.GGUFValueType.ARRAY)
def test_add_key_value_accepts_string_array_with_declared_string_sub_type(self):
writer = gguf.GGUFWriter("/tmp/test.gguf", "llama")
writer.add_key_value("general.tags", ["apache-2.0"], gguf.GGUFValueType.ARRAY, sub_type=gguf.GGUFValueType.STRING)
def test_add_key_value_accepts_bytes_array_value(self):
writer = gguf.GGUFWriter("/tmp/test.gguf", "llama")
writer.add_key_value("tokenizer.ggml.precompiled_charsmap", b"\x01\x02", gguf.GGUFValueType.ARRAY)
def test_add_key_value_accepts_homogeneous_integer_array_value(self):
writer = gguf.GGUFWriter("/tmp/test.gguf", "llama")
writer.add_key_value("test.array", [1, 2], gguf.GGUFValueType.ARRAY)
def test_add_uint32_still_accepts_python_int(self):
writer = gguf.GGUFWriter("/tmp/test.gguf", "llama")
writer.add_uint32("answer", 42)
def test_add_key_value_accepts_numpy_uint32_for_declared_uint32(self):
writer = gguf.GGUFWriter("/tmp/test.gguf", "llama")
writer.add_key_value("answer", np.uint32(42), gguf.GGUFValueType.UINT32)
def test_add_key_value_rejects_bool_for_declared_uint32(self):
writer = gguf.GGUFWriter("/tmp/test.gguf", "llama")
with self.assertRaisesRegex(ValueError, "declared UINT32"):
writer.add_key_value("answer", True, gguf.GGUFValueType.UINT32)
def test_add_key_value_rejects_numpy_bool_for_declared_uint32(self):
writer = gguf.GGUFWriter("/tmp/test.gguf", "llama")
with self.assertRaisesRegex(ValueError, "declared UINT32"):
writer.add_key_value("answer", np.bool_(True), gguf.GGUFValueType.UINT32)
def test_add_key_value_accepts_numpy_float32_for_declared_float32(self):
writer = gguf.GGUFWriter("/tmp/test.gguf", "llama")
writer.add_key_value("temperature", np.float32(1.5), gguf.GGUFValueType.FLOAT32)
def test_add_key_value_rejects_python_int_for_declared_float32(self):
writer = gguf.GGUFWriter("/tmp/test.gguf", "llama")
with self.assertRaisesRegex(ValueError, "declared FLOAT32"):
writer.add_key_value("temperature", 1, gguf.GGUFValueType.FLOAT32)
def test_add_key_value_rejects_python_int_for_declared_float64(self):
writer = gguf.GGUFWriter("/tmp/test.gguf", "llama")
with self.assertRaisesRegex(ValueError, "declared FLOAT64"):
writer.add_key_value("temperature", 1, gguf.GGUFValueType.FLOAT64)
def test_add_key_value_accepts_numpy_bool_for_declared_bool(self):
writer = gguf.GGUFWriter("/tmp/test.gguf", "llama")
writer.add_key_value("flag", np.bool_(True), gguf.GGUFValueType.BOOL)
if __name__ == "__main__":
unittest.main()