diff --git a/fastapi/encoders.py b/fastapi/encoders.py index e20255c11..c6fae1577 100644 --- a/fastapi/encoders.py +++ b/fastapi/encoders.py @@ -1,7 +1,7 @@ import dataclasses import datetime from collections import defaultdict, deque -from collections.abc import Callable +from collections.abc import Callable, Mapping, Sequence from decimal import Decimal from enum import Enum from ipaddress import ( @@ -15,7 +15,10 @@ from ipaddress import ( from pathlib import Path, PurePath from re import Pattern from types import GeneratorType -from typing import Annotated, Any +from typing import ( + Annotated, + Any, +) from uuid import UUID from annotated_doc import Doc @@ -27,11 +30,23 @@ from pydantic.networks import AnyUrl, NameEmail from pydantic.types import SecretBytes, SecretStr from pydantic_core import PydanticUndefinedType +# Dropped support for Pydantic v1 so we can remove the try-except import and the related code +from pydantic_extra_types import color as et_color + from ._compat import ( Url, is_pydantic_v1_model_instance, ) +encoders_by_extra_type: dict[type[Any], Callable[[Any], Any]] = {et_color.Color: str} + +try: + from pydantic_extra_types import coordinate + + encoders_by_extra_type[coordinate.Coordinate] = str +except ImportError: + pass + # Taken from Pydantic v1 as is def isoformat(o: datetime.date | datetime.time) -> str: @@ -106,7 +121,9 @@ def generate_encoders_by_class_tuples( return encoders_by_class_tuples -encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) +encoders_by_class_tuples = generate_encoders_by_class_tuples( + ENCODERS_BY_TYPE | encoders_by_extra_type +) def jsonable_encoder( @@ -198,6 +215,18 @@ def jsonable_encoder( """ ), ] = True, + named_tuple_as_dict: Annotated[ + bool, + Doc( + """ + Whether to encode named tuples as dicts instead of lists. + + This is useful when you want to preserve the field names of named tuples + in the JSON output, which can make it easier to understand and work with + the data on the client side. + """ + ), + ] = False, ) -> Any: """ Convert any object to something that can be encoded in JSON. @@ -239,6 +268,10 @@ def jsonable_encoder( exclude_defaults=exclude_defaults, sqlalchemy_safe=sqlalchemy_safe, ) + # The extra types have their own encoders, so we check for them before checking for dataclasses, + # because some of them are also dataclasses, and we want to use their custom encoders instead of encoding them as dataclasses. + if type(obj) in encoders_by_extra_type: + return encoders_by_extra_type[type(obj)](obj) if dataclasses.is_dataclass(obj): assert not isinstance(obj, type) obj_dict = dataclasses.asdict(obj) @@ -261,7 +294,7 @@ def jsonable_encoder( return obj if isinstance(obj, PydanticUndefinedType): return None - if isinstance(obj, dict): + if isinstance(obj, Mapping): encoded_dict = {} allowed_keys = set(obj.keys()) if include is not None: @@ -296,7 +329,28 @@ def jsonable_encoder( ) encoded_dict[encoded_key] = encoded_value return encoded_dict - if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)): + + # Check if it's a named tuple, and if so, encode it as a dict (instead of a list) if `named_tuple_as_dict` is `True`. + if ( + named_tuple_as_dict + and getattr(obj, "_asdict", None) is not None + and callable(obj._asdict) + ): + return jsonable_encoder( + obj._asdict(), + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + + # Note that we check for `Sequence` and not `list` because we want to support any kind of sequence, like `list`, `tuple`, `set`, etc. + # Also, we check that it's not a `bytes` object, because `bytes` is also a `Sequence`, but we want to rely on the TYPE_ENCODERS for `bytes` and avoid code duplication. + if isinstance(obj, (Sequence, GeneratorType)) and not isinstance(obj, bytes): encoded_list = [] for item in obj: encoded_list.append( diff --git a/tests/test_jsonable_encoder.py b/tests/test_jsonable_encoder.py index 4528dff44..cab509aee 100644 --- a/tests/test_jsonable_encoder.py +++ b/tests/test_jsonable_encoder.py @@ -1,12 +1,13 @@ import warnings -from collections import deque +from collections import deque, namedtuple +from collections.abc import Sequence from dataclasses import dataclass from datetime import datetime, timezone from decimal import Decimal from enum import Enum from math import isinf, isnan from pathlib import PurePath, PurePosixPath, PureWindowsPath -from typing import Optional, TypedDict +from typing import NamedTuple, TypedDict import pytest from fastapi._compat import Undefined @@ -57,7 +58,7 @@ class RoleEnum(Enum): class ModelWithConfig(BaseModel): - role: Optional[RoleEnum] = None + role: RoleEnum | None = None model_config = {"use_enum_values": True} @@ -311,3 +312,124 @@ def test_encode_deque_encodes_child_models(): def test_encode_pydantic_undefined(): data = {"value": Undefined} assert jsonable_encoder(data) == {"value": None} + + +def test_encode_sequence(): + class SequenceModel(Sequence[str]): + def __init__(self, items: list[str]): + self._items = items + + def __getitem__(self, index: int | slice) -> str | Sequence[str]: + return self._items[index] + + def __len__(self) -> int: + return len(self._items) + + seq = SequenceModel(["item1", "item2", "item3"]) + assert len(seq) == 3 + assert jsonable_encoder(seq) == ["item1", "item2", "item3"] + + +def test_encode_bytes(): + assert jsonable_encoder(b"hello") == "hello" + + +def test_encode_bytes_in_dict(): + data = {"content": b"hello", "name": "test"} + assert jsonable_encoder(data) == {"content": "hello", "name": "test"} + + +def test_encode_list_of_bytes(): + data = [b"hello", b"world"] + assert jsonable_encoder(data) == ["hello", "world"] + + +def test_encode_generator(): + def gen(): + yield 1 + yield 2 + yield 3 + + assert jsonable_encoder(gen()) == [1, 2, 3] + + +def test_encode_generator_of_bytes(): + def gen(): + yield b"hello" + yield b"world" + + assert jsonable_encoder(gen()) == ["hello", "world"] + + +def test_encode_named_tuple_as_list(): + Point = namedtuple("Point", ["x", "y"]) + p = Point(1, 2) + assert jsonable_encoder(p) == [1, 2] + + +def test_encode_named_tuple_as_dict(): + Point = namedtuple("Point", ["x", "y"]) + p = Point(1, 2) + assert jsonable_encoder(p, named_tuple_as_dict=True) == {"x": 1, "y": 2} + + +def test_encode_typed_named_tuple_as_list(): + class Point(NamedTuple): + x: int + y: int + + p = Point(1, 2) + assert jsonable_encoder(p) == [1, 2] + + +def test_encode_typed_named_tuple_as_dict(): + class Point(NamedTuple): + x: int + y: int + + p = Point(1, 2) + assert jsonable_encoder(p, named_tuple_as_dict=True) == {"x": 1, "y": 2} + + +def test_encode_sqlalchemy_safe_filters_sa_keys(): + data = {"name": "test", "_sa_instance_state": "internal"} + assert jsonable_encoder(data, sqlalchemy_safe=True) == {"name": "test"} + assert jsonable_encoder(data, sqlalchemy_safe=False) == { + "name": "test", + "_sa_instance_state": "internal", + } + + +def test_encode_sqlalchemy_row_as_list(): + sa = pytest.importorskip("sqlalchemy") + engine = sa.create_engine("sqlite:///:memory:") + with engine.connect() as conn: + row = conn.execute(sa.text("SELECT 1 AS x, 2 AS y")).fetchone() + engine.dispose() + assert row is not None + assert jsonable_encoder(row) == [1, 2] + + +def test_encode_sqlalchemy_row_as_dict(): + sa = pytest.importorskip("sqlalchemy") + engine = sa.create_engine("sqlite:///:memory:") + with engine.connect() as conn: + row = conn.execute(sa.text("SELECT 1 AS x, 2 AS y")).fetchone() + engine.dispose() + assert row is not None + assert jsonable_encoder(row, named_tuple_as_dict=True) == {"x": 1, "y": 2} + + +def test_encode_pydantic_extra_types_coordinate(): + coordinate = pytest.importorskip("pydantic_extra_types.coordinate") + coord = coordinate.Coordinate(latitude=1.0, longitude=2.0) + # Dataclass output shouldn't be the result + assert jsonable_encoder(coord) != {"latitude": 1.0, "longitude": 2.0} + # The custom encoder should be the result + assert jsonable_encoder(coord) == str(coord) + + +def test_encode_pydantic_extra_types_color(): + et_color = pytest.importorskip("pydantic_extra_types.color") + color = et_color.Color("red") + assert jsonable_encoder(color) == str(color)