mirror of https://github.com/tiangolo/fastapi.git
✨ Add support for subtypes of main types in jsonable_encoder
This commit is contained in:
parent
08fc2a41ca
commit
b85b2e3942
|
|
@ -1,6 +1,6 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from types import GeneratorType
|
from types import GeneratorType
|
||||||
from typing import Any, Dict, List, Set, Union
|
from typing import Any, Callable, Dict, List, Set, Tuple, Union
|
||||||
|
|
||||||
from fastapi.logger import logger
|
from fastapi.logger import logger
|
||||||
from fastapi.utils import PYDANTIC_1
|
from fastapi.utils import PYDANTIC_1
|
||||||
|
|
@ -11,6 +11,21 @@ SetIntStr = Set[Union[int, str]]
|
||||||
DictIntStrAny = Dict[Union[int, str], Any]
|
DictIntStrAny = Dict[Union[int, str], Any]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_encoders_by_class_tuples(
|
||||||
|
type_encoder_map: Dict[Any, Callable]
|
||||||
|
) -> Dict[Callable, Tuple]:
|
||||||
|
encoders_by_classes: Dict[Callable, List] = {}
|
||||||
|
for type_, encoder in type_encoder_map.items():
|
||||||
|
encoders_by_classes.setdefault(encoder, []).append(type_)
|
||||||
|
encoders_by_class_tuples: Dict[Callable, Tuple] = {}
|
||||||
|
for encoder, classes in encoders_by_classes.items():
|
||||||
|
encoders_by_class_tuples[encoder] = tuple(classes)
|
||||||
|
return encoders_by_class_tuples
|
||||||
|
|
||||||
|
|
||||||
|
encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
|
||||||
|
|
||||||
|
|
||||||
def jsonable_encoder(
|
def jsonable_encoder(
|
||||||
obj: Any,
|
obj: Any,
|
||||||
include: Union[SetIntStr, DictIntStrAny] = None,
|
include: Union[SetIntStr, DictIntStrAny] = None,
|
||||||
|
|
@ -106,15 +121,22 @@ def jsonable_encoder(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return encoded_list
|
return encoded_list
|
||||||
errors: List[Exception] = []
|
|
||||||
try:
|
if custom_encoder:
|
||||||
if custom_encoder and type(obj) in custom_encoder:
|
if type(obj) in custom_encoder:
|
||||||
encoder = custom_encoder[type(obj)]
|
return custom_encoder[type(obj)](obj)
|
||||||
else:
|
else:
|
||||||
encoder = ENCODERS_BY_TYPE[type(obj)]
|
for encoder_type, encoder in custom_encoder.items():
|
||||||
|
if isinstance(obj, encoder_type):
|
||||||
return encoder(obj)
|
return encoder(obj)
|
||||||
except KeyError as e:
|
|
||||||
errors.append(e)
|
if type(obj) in ENCODERS_BY_TYPE:
|
||||||
|
return ENCODERS_BY_TYPE[type(obj)](obj)
|
||||||
|
for encoder, classes_tuple in encoders_by_class_tuples.items():
|
||||||
|
if isinstance(obj, classes_tuple):
|
||||||
|
return encoder(obj)
|
||||||
|
|
||||||
|
errors: List[Exception] = []
|
||||||
try:
|
try:
|
||||||
data = dict(obj)
|
data = dict(obj)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,73 @@
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
class MyUuid:
|
||||||
|
def __init__(self, uuid_string: str):
|
||||||
|
self.uuid = uuid_string
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.uuid
|
||||||
|
|
||||||
|
@property
|
||||||
|
def __class__(self):
|
||||||
|
return uuid.UUID
|
||||||
|
|
||||||
|
@property
|
||||||
|
def __dict__(self):
|
||||||
|
"""Spoof a missing __dict__ by raising TypeError, this is how
|
||||||
|
asyncpg.pgroto.pgproto.UUID behaves"""
|
||||||
|
raise TypeError("vars() argument must have __dict__ attribute")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/fast_uuid")
|
||||||
|
def return_fast_uuid():
|
||||||
|
# I don't want to import asyncpg for this test so I made my own UUID
|
||||||
|
# Import asyncpg and uncomment the two lines below for the actual bug
|
||||||
|
|
||||||
|
# from asyncpg.pgproto import pgproto
|
||||||
|
# asyncpg_uuid = pgproto.UUID("a10ff360-3b1e-4984-a26f-d3ab460bdb51")
|
||||||
|
|
||||||
|
asyncpg_uuid = MyUuid("a10ff360-3b1e-4984-a26f-d3ab460bdb51")
|
||||||
|
assert isinstance(asyncpg_uuid, uuid.UUID)
|
||||||
|
assert type(asyncpg_uuid) != uuid.UUID
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
vars(asyncpg_uuid)
|
||||||
|
return {"fast_uuid": asyncpg_uuid}
|
||||||
|
|
||||||
|
|
||||||
|
class SomeCustomClass(BaseModel):
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
json_encoders = {uuid.UUID: str}
|
||||||
|
|
||||||
|
a_uuid: MyUuid
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/get_custom_class")
|
||||||
|
def return_some_user():
|
||||||
|
# Test that the fix also works for custom pydantic classes
|
||||||
|
return SomeCustomClass(a_uuid=MyUuid("b8799909-f914-42de-91bc-95c819218d01"))
|
||||||
|
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dt():
|
||||||
|
with client:
|
||||||
|
response_simple = client.get("/fast_uuid")
|
||||||
|
response_pydantic = client.get("/get_custom_class")
|
||||||
|
|
||||||
|
assert response_simple.json() == {
|
||||||
|
"fast_uuid": "a10ff360-3b1e-4984-a26f-d3ab460bdb51"
|
||||||
|
}
|
||||||
|
|
||||||
|
assert response_pydantic.json() == {
|
||||||
|
"a_uuid": "b8799909-f914-42de-91bc-95c819218d01"
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue