mirror of https://github.com/tiangolo/fastapi.git
✨ Add support for subtypes of main types in jsonable_encoder
This commit is contained in:
parent
08fc2a41ca
commit
b85b2e3942
|
|
@ -1,6 +1,6 @@
|
|||
from enum import Enum
|
||||
from types import GeneratorType
|
||||
from typing import Any, Dict, List, Set, Union
|
||||
from typing import Any, Callable, Dict, List, Set, Tuple, Union
|
||||
|
||||
from fastapi.logger import logger
|
||||
from fastapi.utils import PYDANTIC_1
|
||||
|
|
@ -11,6 +11,21 @@ SetIntStr = Set[Union[int, str]]
|
|||
DictIntStrAny = Dict[Union[int, str], Any]
|
||||
|
||||
|
||||
def generate_encoders_by_class_tuples(
|
||||
type_encoder_map: Dict[Any, Callable]
|
||||
) -> Dict[Callable, Tuple]:
|
||||
encoders_by_classes: Dict[Callable, List] = {}
|
||||
for type_, encoder in type_encoder_map.items():
|
||||
encoders_by_classes.setdefault(encoder, []).append(type_)
|
||||
encoders_by_class_tuples: Dict[Callable, Tuple] = {}
|
||||
for encoder, classes in encoders_by_classes.items():
|
||||
encoders_by_class_tuples[encoder] = tuple(classes)
|
||||
return encoders_by_class_tuples
|
||||
|
||||
|
||||
encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
|
||||
|
||||
|
||||
def jsonable_encoder(
|
||||
obj: Any,
|
||||
include: Union[SetIntStr, DictIntStrAny] = None,
|
||||
|
|
@ -106,24 +121,31 @@ def jsonable_encoder(
|
|||
)
|
||||
)
|
||||
return encoded_list
|
||||
|
||||
if custom_encoder:
|
||||
if type(obj) in custom_encoder:
|
||||
return custom_encoder[type(obj)](obj)
|
||||
else:
|
||||
for encoder_type, encoder in custom_encoder.items():
|
||||
if isinstance(obj, encoder_type):
|
||||
return encoder(obj)
|
||||
|
||||
if type(obj) in ENCODERS_BY_TYPE:
|
||||
return ENCODERS_BY_TYPE[type(obj)](obj)
|
||||
for encoder, classes_tuple in encoders_by_class_tuples.items():
|
||||
if isinstance(obj, classes_tuple):
|
||||
return encoder(obj)
|
||||
|
||||
errors: List[Exception] = []
|
||||
try:
|
||||
if custom_encoder and type(obj) in custom_encoder:
|
||||
encoder = custom_encoder[type(obj)]
|
||||
else:
|
||||
encoder = ENCODERS_BY_TYPE[type(obj)]
|
||||
return encoder(obj)
|
||||
except KeyError as e:
|
||||
data = dict(obj)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
try:
|
||||
data = dict(obj)
|
||||
data = vars(obj)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
try:
|
||||
data = vars(obj)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
raise ValueError(errors)
|
||||
raise ValueError(errors)
|
||||
return jsonable_encoder(
|
||||
data,
|
||||
by_alias=by_alias,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,73 @@
|
|||
import uuid
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class MyUuid:
|
||||
def __init__(self, uuid_string: str):
|
||||
self.uuid = uuid_string
|
||||
|
||||
def __str__(self):
|
||||
return self.uuid
|
||||
|
||||
@property
|
||||
def __class__(self):
|
||||
return uuid.UUID
|
||||
|
||||
@property
|
||||
def __dict__(self):
|
||||
"""Spoof a missing __dict__ by raising TypeError, this is how
|
||||
asyncpg.pgroto.pgproto.UUID behaves"""
|
||||
raise TypeError("vars() argument must have __dict__ attribute")
|
||||
|
||||
|
||||
@app.get("/fast_uuid")
|
||||
def return_fast_uuid():
|
||||
# I don't want to import asyncpg for this test so I made my own UUID
|
||||
# Import asyncpg and uncomment the two lines below for the actual bug
|
||||
|
||||
# from asyncpg.pgproto import pgproto
|
||||
# asyncpg_uuid = pgproto.UUID("a10ff360-3b1e-4984-a26f-d3ab460bdb51")
|
||||
|
||||
asyncpg_uuid = MyUuid("a10ff360-3b1e-4984-a26f-d3ab460bdb51")
|
||||
assert isinstance(asyncpg_uuid, uuid.UUID)
|
||||
assert type(asyncpg_uuid) != uuid.UUID
|
||||
with pytest.raises(TypeError):
|
||||
vars(asyncpg_uuid)
|
||||
return {"fast_uuid": asyncpg_uuid}
|
||||
|
||||
|
||||
class SomeCustomClass(BaseModel):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {uuid.UUID: str}
|
||||
|
||||
a_uuid: MyUuid
|
||||
|
||||
|
||||
@app.get("/get_custom_class")
|
||||
def return_some_user():
|
||||
# Test that the fix also works for custom pydantic classes
|
||||
return SomeCustomClass(a_uuid=MyUuid("b8799909-f914-42de-91bc-95c819218d01"))
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_dt():
|
||||
with client:
|
||||
response_simple = client.get("/fast_uuid")
|
||||
response_pydantic = client.get("/get_custom_class")
|
||||
|
||||
assert response_simple.json() == {
|
||||
"fast_uuid": "a10ff360-3b1e-4984-a26f-d3ab460bdb51"
|
||||
}
|
||||
|
||||
assert response_pydantic.json() == {
|
||||
"a_uuid": "b8799909-f914-42de-91bc-95c819218d01"
|
||||
}
|
||||
Loading…
Reference in New Issue