For extra-types safety

This commit is contained in:
Pedro Lobato 2026-02-10 10:04:37 -05:00
parent 3c1866474d
commit 74417cec5c
1 changed files with 11 additions and 3 deletions

View File

@ -34,9 +34,15 @@ from pydantic.types import SecretBytes, SecretStr
from pydantic_core import PydanticUndefinedType from pydantic_core import PydanticUndefinedType
try: try:
from pydantic_extra_types.coordinate import Coordinate from pydantic_extra_types import color as et_color
from pydantic_extra_types import coordinate
encoders_by_extra_type: dict[type[Any], Callable[[Any], Any]] = {
coordinate.Coordinate: str,
et_color.Color: str
}
except ImportError: except ImportError:
Coordinate = dict encoders_by_extra_type = {}
from ._compat import ( from ._compat import (
Url, Url,
@ -79,7 +85,6 @@ def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = { ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = {
bytes: lambda o: o.decode(), bytes: lambda o: o.decode(),
Color: str, Color: str,
Coordinate: str,
datetime.date: isoformat, datetime.date: isoformat,
datetime.datetime: isoformat, datetime.datetime: isoformat,
datetime.time: isoformat, datetime.time: isoformat,
@ -119,6 +124,7 @@ def generate_encoders_by_class_tuples(
encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
encoders_by_class_tuples.update(generate_encoders_by_class_tuples(encoders_by_extra_type))
def jsonable_encoder( def jsonable_encoder(
@ -358,6 +364,8 @@ def jsonable_encoder(
sqlalchemy_safe=sqlalchemy_safe, sqlalchemy_safe=sqlalchemy_safe,
) )
if type(obj) in encoders_by_extra_type:
return encoders_by_extra_type[type(obj)](obj)
if type(obj) in ENCODERS_BY_TYPE: if type(obj) in ENCODERS_BY_TYPE:
return ENCODERS_BY_TYPE[type(obj)](obj) return ENCODERS_BY_TYPE[type(obj)](obj)
for encoder, classes_tuple in encoders_by_class_tuples.items(): for encoder, classes_tuple in encoders_by_class_tuples.items():