gguf-py: add type validation to GGUFWriter.add_key_value (#9095)

Currently add_key_value() accepts any Python value and silently passes it
to struct.pack, which produces garbage when types are wrong (e.g. a list
passed where a string is expected).  The failure only surfaces later as
an opaque struct.error during serialization, making bugs hard to diagnose.

This adds a _validate_value_type() classmethod that checks values against
the declared GGUFValueType before storing them, raising a clear TypeError
that includes the key name, expected type, actual type, and value.

Key design decisions:
- Reject bool for integer/float types (Python bool is int subclass)
- Accept int for float types (struct.pack handles this; conversion
  scripts commonly pass integer literals like eps=0)
- Accept numpy scalar types (np.integer, np.floating, np.bool_)
- Allow bytes for ARRAY (used by add_array for UINT8 data like
  precompiled_charsmap)
- Reject None for all types

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
FrozenScorch 2026-03-31 23:00:44 -04:00
parent 7cadbfce10
commit 22f18a4838
2 changed files with 259 additions and 0 deletions

View File

@ -82,6 +82,16 @@ class GGUFWriter:
GGUFValueType.FLOAT64: "d",
GGUFValueType.BOOL: "?",
}
_integer_value_types = frozenset({
GGUFValueType.UINT8, GGUFValueType.INT8,
GGUFValueType.UINT16, GGUFValueType.INT16,
GGUFValueType.UINT32, GGUFValueType.INT32,
GGUFValueType.UINT64, GGUFValueType.INT64,
})
_float_value_types = frozenset({
GGUFValueType.FLOAT32,
GGUFValueType.FLOAT64,
})
def __init__(
self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
@ -278,8 +288,40 @@ class GGUFWriter:
if any(key in kv_data for kv_data in self.kv_data):
logger.warning(f'Duplicated key name {key!r}, overwriting it with new value {val!r} of type {vtype.name}')
if not self._validate_value_type(val, vtype, sub_type):
raise TypeError(
f"key {key!r}: expected {vtype.name}, got {type(val).__name__}: {val!r}"
)
self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type)
@classmethod
def _validate_value_type(cls, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None) -> bool:
if val is None:
return False
if vtype == GGUFValueType.STRING:
return isinstance(val, (str, bytes, bytearray))
if vtype == GGUFValueType.ARRAY:
# str/bytearray are Sequence subclasses but are never valid array
# values in GGUF. bytes, however, *is* used by add_array() as a
# UINT8 array (e.g. precompiled_charsmap), so it must be allowed.
if not isinstance(val, Sequence) or isinstance(val, (str, bytearray)):
return False
if sub_type is not None and len(val) > 0:
return all(cls._validate_value_type(item, sub_type) for item in val)
return True
if vtype == GGUFValueType.BOOL:
return isinstance(val, (bool, np.bool_))
if vtype in cls._integer_value_types:
# Reject bool -- Python's bool is a subclass of int, which is a common
# source of silent bugs when passing boolean flags to integer metadata.
return isinstance(val, (int, np.integer)) and not isinstance(val, (bool, np.bool_))
if vtype in cls._float_value_types:
# Accept int for float -- struct.pack handles this correctly and
# conversion scripts commonly pass integer literals (e.g. eps=0).
return isinstance(val, (int, float, np.integer, np.floating)) and not isinstance(val, (bool, np.bool_))
return True
def add_uint8(self, key: str, val: int) -> None:
self.add_key_value(key,val, GGUFValueType.UINT8)

View File

@ -0,0 +1,217 @@
#!/usr/bin/env python3
import unittest
from pathlib import Path
import os
import sys
import tempfile
import numpy as np
# Necessary to load the local gguf package
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / "gguf-py").exists():
sys.path.insert(0, str(Path(__file__).parent.parent))
import gguf
class TestGGUFWriterTypeValidation(unittest.TestCase):
def _make_writer(self):
fd, path = tempfile.mkstemp(suffix=".gguf")
os.close(fd)
self.addCleanup(os.unlink, path)
return gguf.GGUFWriter(path, "llama")
# --- STRING validation ---
def test_rejects_list_for_string(self):
"""The original bug: model card license as a list instead of string."""
w = self._make_writer()
with self.assertRaises(TypeError) as ctx:
w.add_key_value("general.license", ["Apache-2.0"], gguf.GGUFValueType.STRING)
self.assertIn("general.license", str(ctx.exception))
self.assertIn("STRING", str(ctx.exception))
def test_rejects_int_for_string(self):
w = self._make_writer()
with self.assertRaises(TypeError) as ctx:
w.add_key_value("general.name", 42, gguf.GGUFValueType.STRING)
self.assertIn("STRING", str(ctx.exception))
def test_accepts_str_for_string(self):
w = self._make_writer()
w.add_key_value("general.name", "test", gguf.GGUFValueType.STRING)
def test_accepts_bytes_for_string(self):
w = self._make_writer()
w.add_key_value("tokenizer.ggml.precompiled_charsmap", b"\x01\x02", gguf.GGUFValueType.STRING)
# --- BOOL vs INT trap ---
def test_rejects_bool_for_uint32(self):
"""Python bool is subclass of int -- must be explicitly rejected."""
w = self._make_writer()
with self.assertRaises(TypeError) as ctx:
w.add_key_value("general.file_type", True, gguf.GGUFValueType.UINT32)
self.assertIn("UINT32", str(ctx.exception))
def test_rejects_bool_for_int32(self):
w = self._make_writer()
with self.assertRaises(TypeError) as ctx:
w.add_key_value("answer", False, gguf.GGUFValueType.INT32)
self.assertIn("INT32", str(ctx.exception))
def test_rejects_numpy_bool_for_uint32(self):
"""np.bool_ must also be rejected for integer types."""
w = self._make_writer()
with self.assertRaises(TypeError):
w.add_key_value("general.file_type", np.bool_(True), gguf.GGUFValueType.UINT32)
def test_accepts_python_bool_for_bool(self):
w = self._make_writer()
w.add_key_value("general.use_parallel_residual", True, gguf.GGUFValueType.BOOL)
def test_accepts_numpy_bool_for_bool(self):
w = self._make_writer()
w.add_key_value("general.use_parallel_residual", np.bool_(True), gguf.GGUFValueType.BOOL)
# --- FLOAT accepts int ---
def test_accepts_int_for_float32(self):
"""int literals should be accepted for float types (struct.pack handles this)."""
w = self._make_writer()
w.add_key_value("attention.layer_norm_rms_eps", 0, gguf.GGUFValueType.FLOAT32)
def test_accepts_int_for_float64(self):
w = self._make_writer()
w.add_key_value("test", 1, gguf.GGUFValueType.FLOAT64)
def test_rejects_bool_for_float32(self):
"""bool should NOT be accepted for float types even though bool is int subclass."""
w = self._make_writer()
with self.assertRaises(TypeError):
w.add_key_value("temperature", True, gguf.GGUFValueType.FLOAT32)
def test_accepts_float_for_float32(self):
w = self._make_writer()
w.add_key_value("temperature", 0.7, gguf.GGUFValueType.FLOAT32)
# --- NumPy scalars ---
def test_accepts_numpy_int32_for_uint32(self):
w = self._make_writer()
w.add_key_value("general.file_type", np.int32(7), gguf.GGUFValueType.UINT32)
def test_accepts_numpy_uint64_for_uint64(self):
w = self._make_writer()
w.add_key_value("test", np.uint64(42), gguf.GGUFValueType.UINT64)
def test_accepts_numpy_float32_for_float32(self):
w = self._make_writer()
w.add_key_value("temperature", np.float32(0.5), gguf.GGUFValueType.FLOAT32)
def test_accepts_numpy_float64_for_float64(self):
w = self._make_writer()
w.add_key_value("test", np.float64(1.5), gguf.GGUFValueType.FLOAT64)
# --- ARRAY validation ---
def test_rejects_non_sequence_for_array(self):
w = self._make_writer()
with self.assertRaises(TypeError):
w.add_key_value("general.tags", 42, gguf.GGUFValueType.ARRAY)
def test_rejects_mixed_type_array_with_string_sub_type(self):
w = self._make_writer()
with self.assertRaises(TypeError):
w.add_key_value("general.tags", ["ok", 1], gguf.GGUFValueType.ARRAY, sub_type=gguf.GGUFValueType.STRING)
def test_accepts_string_array_with_string_sub_type(self):
w = self._make_writer()
w.add_key_value("general.tags", ["conversational", "code"], gguf.GGUFValueType.ARRAY, sub_type=gguf.GGUFValueType.STRING)
def test_accepts_int_array_with_int32_sub_type(self):
w = self._make_writer()
w.add_key_value("test", [1, 2, 3], gguf.GGUFValueType.ARRAY, sub_type=gguf.GGUFValueType.INT32)
def test_rejects_bool_in_int_array(self):
"""bool must be rejected even inside integer arrays."""
w = self._make_writer()
with self.assertRaises(TypeError):
w.add_key_value("test", [1, True], gguf.GGUFValueType.ARRAY, sub_type=gguf.GGUFValueType.INT32)
# --- convenience methods still work ---
def test_add_uint32_still_works(self):
w = self._make_writer()
w.add_uint32("general.file_type", 7)
def test_add_float32_still_works(self):
w = self._make_writer()
w.add_float32("attention.layer_norm_rms_eps", 1e-5)
def test_add_bool_still_works(self):
w = self._make_writer()
w.add_bool("general.use_parallel_residual", False)
def test_add_string_still_works(self):
w = self._make_writer()
w.add_string("general.name", "TestModel")
def test_add_array_still_works(self):
w = self._make_writer()
w.add_array("general.tags", ["conversational", "code"])
# --- additional edge cases from code review ---
def test_rejects_str_for_array(self):
"""str is a Sequence subclass but never a valid GGUF array value."""
w = self._make_writer()
with self.assertRaises(TypeError) as ctx:
w.add_key_value("general.tags", "not-a-list", gguf.GGUFValueType.ARRAY)
self.assertIn("ARRAY", str(ctx.exception))
self.assertIn("str", str(ctx.exception))
def test_accepts_bytes_for_array(self):
"""bytes is a legitimate UINT8 array representation (e.g. precompiled_charsmap)."""
w = self._make_writer()
w.add_key_value("test", b"\x01\x02", gguf.GGUFValueType.ARRAY)
def test_rejects_none_for_string(self):
w = self._make_writer()
with self.assertRaises(TypeError):
w.add_key_value("general.name", None, gguf.GGUFValueType.STRING)
def test_rejects_none_for_uint32(self):
w = self._make_writer()
with self.assertRaises(TypeError):
w.add_key_value("general.file_type", None, gguf.GGUFValueType.UINT32)
def test_rejects_float_for_uint32(self):
"""float must be rejected for integer types."""
w = self._make_writer()
with self.assertRaises(TypeError) as ctx:
w.add_key_value("general.file_type", 7.5, gguf.GGUFValueType.UINT32)
self.assertIn("UINT32", str(ctx.exception))
def test_rejects_float_for_int32(self):
w = self._make_writer()
with self.assertRaises(TypeError):
w.add_key_value("test", 1.0, gguf.GGUFValueType.INT32)
# --- error message quality ---
def test_error_message_includes_key_name_and_value(self):
"""Error message must show the key and actual value for easy debugging."""
w = self._make_writer()
with self.assertRaises(TypeError) as ctx:
w.add_key_value("general.license", ["Apache-2.0"], gguf.GGUFValueType.STRING)
msg = str(ctx.exception)
self.assertIn("general.license", msg)
self.assertIn("list", msg)
self.assertIn("['Apache-2.0']", msg)
if __name__ == "__main__":
unittest.main()