mirror of https://github.com/tiangolo/fastapi.git
Don't revalidate the response content if it is of the same type as response_model
This commit is contained in:
parent
3824664620
commit
ec08e235ea
|
|
@ -143,6 +143,7 @@ def _merge_lifespan_context(
|
||||||
async def serialize_response(
|
async def serialize_response(
|
||||||
*,
|
*,
|
||||||
field: Optional[ModelField] = None,
|
field: Optional[ModelField] = None,
|
||||||
|
response_model: Any = Default(None),
|
||||||
response_content: Any,
|
response_content: Any,
|
||||||
include: Optional[IncEx] = None,
|
include: Optional[IncEx] = None,
|
||||||
exclude: Optional[IncEx] = None,
|
exclude: Optional[IncEx] = None,
|
||||||
|
|
@ -152,7 +153,10 @@ async def serialize_response(
|
||||||
exclude_none: bool = False,
|
exclude_none: bool = False,
|
||||||
is_coroutine: bool = True,
|
is_coroutine: bool = True,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if field:
|
if not field:
|
||||||
|
return jsonable_encoder(response_content)
|
||||||
|
|
||||||
|
if type(response_content) is not response_model:
|
||||||
errors = []
|
errors = []
|
||||||
if not hasattr(field, "serialize"):
|
if not hasattr(field, "serialize"):
|
||||||
# pydantic v1
|
# pydantic v1
|
||||||
|
|
@ -187,18 +191,18 @@ async def serialize_response(
|
||||||
exclude_defaults=exclude_defaults,
|
exclude_defaults=exclude_defaults,
|
||||||
exclude_none=exclude_none,
|
exclude_none=exclude_none,
|
||||||
)
|
)
|
||||||
|
|
||||||
return jsonable_encoder(
|
|
||||||
value,
|
|
||||||
include=include,
|
|
||||||
exclude=exclude,
|
|
||||||
by_alias=by_alias,
|
|
||||||
exclude_unset=exclude_unset,
|
|
||||||
exclude_defaults=exclude_defaults,
|
|
||||||
exclude_none=exclude_none,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return jsonable_encoder(response_content)
|
value = response_content
|
||||||
|
|
||||||
|
return jsonable_encoder(
|
||||||
|
value,
|
||||||
|
include=include,
|
||||||
|
exclude=exclude,
|
||||||
|
by_alias=by_alias,
|
||||||
|
exclude_unset=exclude_unset,
|
||||||
|
exclude_defaults=exclude_defaults,
|
||||||
|
exclude_none=exclude_none,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def run_endpoint_function(
|
async def run_endpoint_function(
|
||||||
|
|
@ -220,6 +224,7 @@ def get_request_handler(
|
||||||
status_code: Optional[int] = None,
|
status_code: Optional[int] = None,
|
||||||
response_class: Union[Type[Response], DefaultPlaceholder] = Default(JSONResponse),
|
response_class: Union[Type[Response], DefaultPlaceholder] = Default(JSONResponse),
|
||||||
response_field: Optional[ModelField] = None,
|
response_field: Optional[ModelField] = None,
|
||||||
|
response_model: Any = Default(None),
|
||||||
response_model_include: Optional[IncEx] = None,
|
response_model_include: Optional[IncEx] = None,
|
||||||
response_model_exclude: Optional[IncEx] = None,
|
response_model_exclude: Optional[IncEx] = None,
|
||||||
response_model_by_alias: bool = True,
|
response_model_by_alias: bool = True,
|
||||||
|
|
@ -327,6 +332,7 @@ def get_request_handler(
|
||||||
content = await serialize_response(
|
content = await serialize_response(
|
||||||
field=response_field,
|
field=response_field,
|
||||||
response_content=raw_response,
|
response_content=raw_response,
|
||||||
|
response_model=response_model,
|
||||||
include=response_model_include,
|
include=response_model_include,
|
||||||
exclude=response_model_exclude,
|
exclude=response_model_exclude,
|
||||||
by_alias=response_model_by_alias,
|
by_alias=response_model_by_alias,
|
||||||
|
|
@ -575,6 +581,7 @@ class APIRoute(routing.Route):
|
||||||
status_code=self.status_code,
|
status_code=self.status_code,
|
||||||
response_class=self.response_class,
|
response_class=self.response_class,
|
||||||
response_field=self.secure_cloned_response_field,
|
response_field=self.secure_cloned_response_field,
|
||||||
|
response_model=self.response_model,
|
||||||
response_model_include=self.response_model_include,
|
response_model_include=self.response_model_include,
|
||||||
response_model_exclude=self.response_model_exclude,
|
response_model_exclude=self.response_model_exclude,
|
||||||
response_model_by_alias=self.response_model_by_alias,
|
response_model_by_alias=self.response_model_by_alias,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
from fastapi._compat import PYDANTIC_V2
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from starlette.testclient import TestClient
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
|
@ -152,3 +153,23 @@ def test_validdict_exclude_unset():
|
||||||
"k2": {"aliased_name": "bar", "price": 1.0},
|
"k2": {"aliased_name": "bar", "price": 1.0},
|
||||||
"k3": {"aliased_name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
|
"k3": {"aliased_name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if not PYDANTIC_V2:
|
||||||
|
from pydantic import validator
|
||||||
|
|
||||||
|
class AutoIncrement(BaseModel):
|
||||||
|
count: int
|
||||||
|
|
||||||
|
@validator("count")
|
||||||
|
def auto_increment(cls, count: int):
|
||||||
|
return count + 1
|
||||||
|
|
||||||
|
@app.post("/increment", response_model=AutoIncrement)
|
||||||
|
async def increment():
|
||||||
|
return AutoIncrement(count=0)
|
||||||
|
|
||||||
|
def test_response_model_should_not_revalidate_response_content_if_they_had_same_type():
|
||||||
|
response = client.post("/increment")
|
||||||
|
response.raise_for_status()
|
||||||
|
assert response.json() == {"count": 1}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue