mirror of https://github.com/tiangolo/fastapi.git
Refactor code for improved readability and update tests to handle Pydantic v1 and v2 differences.
This commit is contained in:
parent
6b907a57f8
commit
93e98d5cb7
|
|
@ -85,8 +85,7 @@ def _prepare_response_content(
|
||||||
exclude_none: bool = False,
|
exclude_none: bool = False,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if isinstance(res, BaseModel):
|
if isinstance(res, BaseModel):
|
||||||
read_with_orm_mode = getattr(
|
read_with_orm_mode = getattr(_get_model_config(res), "read_with_orm_mode", None)
|
||||||
_get_model_config(res), "read_with_orm_mode", None)
|
|
||||||
if read_with_orm_mode:
|
if read_with_orm_mode:
|
||||||
# Let from_orm extract the data from this model instead of converting
|
# Let from_orm extract the data from this model instead of converting
|
||||||
# it now to a dict.
|
# it now to a dict.
|
||||||
|
|
@ -165,8 +164,7 @@ async def serialize_response(
|
||||||
exclude_none=exclude_none,
|
exclude_none=exclude_none,
|
||||||
)
|
)
|
||||||
if is_coroutine:
|
if is_coroutine:
|
||||||
value, errors_ = field.validate(
|
value, errors_ = field.validate(response_content, {}, loc=("response",))
|
||||||
response_content, {}, loc=("response",))
|
|
||||||
else:
|
else:
|
||||||
value, errors_ = await run_in_threadpool(
|
value, errors_ = await run_in_threadpool(
|
||||||
field.validate, response_content, {}, loc=("response",)
|
field.validate, response_content, {}, loc=("response",)
|
||||||
|
|
@ -221,8 +219,7 @@ def get_request_handler(
|
||||||
dependant: Dependant,
|
dependant: Dependant,
|
||||||
body_field: Optional[ModelField] = None,
|
body_field: Optional[ModelField] = None,
|
||||||
status_code: Optional[int] = None,
|
status_code: Optional[int] = None,
|
||||||
response_class: Union[Type[Response],
|
response_class: Union[Type[Response], DefaultPlaceholder] = Default(JSONResponse),
|
||||||
DefaultPlaceholder] = Default(JSONResponse),
|
|
||||||
response_field: Optional[ModelField] = None,
|
response_field: Optional[ModelField] = None,
|
||||||
response_model_include: Optional[IncEx] = None,
|
response_model_include: Optional[IncEx] = None,
|
||||||
response_model_exclude: Optional[IncEx] = None,
|
response_model_exclude: Optional[IncEx] = None,
|
||||||
|
|
@ -235,8 +232,7 @@ def get_request_handler(
|
||||||
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||||
assert dependant.call is not None, "dependant.call must be a function"
|
assert dependant.call is not None, "dependant.call must be a function"
|
||||||
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
|
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
|
||||||
is_body_form = body_field and isinstance(
|
is_body_form = body_field and isinstance(body_field.field_info, params.Form)
|
||||||
body_field.field_info, params.Form)
|
|
||||||
if isinstance(response_class, DefaultPlaceholder):
|
if isinstance(response_class, DefaultPlaceholder):
|
||||||
actual_response_class: Type[Response] = response_class.value
|
actual_response_class: Type[Response] = response_class.value
|
||||||
else:
|
else:
|
||||||
|
|
@ -255,8 +251,7 @@ def get_request_handler(
|
||||||
body_bytes = await request.body()
|
body_bytes = await request.body()
|
||||||
if body_bytes:
|
if body_bytes:
|
||||||
json_body: Any = Undefined
|
json_body: Any = Undefined
|
||||||
content_type_value = request.headers.get(
|
content_type_value = request.headers.get("content-type")
|
||||||
"content-type")
|
|
||||||
if not content_type_value:
|
if not content_type_value:
|
||||||
json_body = await request.json()
|
json_body = await request.json()
|
||||||
else:
|
else:
|
||||||
|
|
@ -273,8 +268,7 @@ def get_request_handler(
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
lines_before = e.doc[: e.pos].split("\n")
|
lines_before = e.doc[: e.pos].split("\n")
|
||||||
line_number = len(lines_before)
|
line_number = len(lines_before)
|
||||||
column_number = len(
|
column_number = len(lines_before[-1]) + 1 if lines_before else 1
|
||||||
lines_before[-1]) + 1 if lines_before else 1
|
|
||||||
|
|
||||||
start_pos = max(0, e.pos - 40)
|
start_pos = max(0, e.pos - 40)
|
||||||
end_pos = min(len(e.doc), e.pos + 40)
|
end_pos = min(len(e.doc), e.pos + 40)
|
||||||
|
|
@ -359,12 +353,10 @@ def get_request_handler(
|
||||||
exclude_none=response_model_exclude_none,
|
exclude_none=response_model_exclude_none,
|
||||||
is_coroutine=is_coroutine,
|
is_coroutine=is_coroutine,
|
||||||
)
|
)
|
||||||
response = actual_response_class(
|
response = actual_response_class(content, **response_args)
|
||||||
content, **response_args)
|
|
||||||
if not is_body_allowed_for_status_code(response.status_code):
|
if not is_body_allowed_for_status_code(response.status_code):
|
||||||
response.body = b""
|
response.body = b""
|
||||||
response.headers.raw.extend(
|
response.headers.raw.extend(solved_result.response.headers.raw)
|
||||||
solved_result.response.headers.raw)
|
|
||||||
if errors:
|
if errors:
|
||||||
validation_error = RequestValidationError(
|
validation_error = RequestValidationError(
|
||||||
_normalize_errors(errors), body=body
|
_normalize_errors(errors), body=body
|
||||||
|
|
@ -425,15 +417,12 @@ class APIWebSocketRoute(routing.WebSocketRoute):
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.name = get_name(endpoint) if name is None else name
|
self.name = get_name(endpoint) if name is None else name
|
||||||
self.dependencies = list(dependencies or [])
|
self.dependencies = list(dependencies or [])
|
||||||
self.path_regex, self.path_format, self.param_convertors = compile_path(
|
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||||
path)
|
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
|
||||||
self.dependant = get_dependant(
|
|
||||||
path=self.path_format, call=self.endpoint)
|
|
||||||
for depends in self.dependencies[::-1]:
|
for depends in self.dependencies[::-1]:
|
||||||
self.dependant.dependencies.insert(
|
self.dependant.dependencies.insert(
|
||||||
0,
|
0,
|
||||||
get_parameterless_sub_dependant(
|
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
||||||
depends=depends, path=self.path_format),
|
|
||||||
)
|
)
|
||||||
self._flat_dependant = get_flat_dependant(self.dependant)
|
self._flat_dependant = get_flat_dependant(self.dependant)
|
||||||
self._embed_body_fields = _should_embed_body_fields(
|
self._embed_body_fields = _should_embed_body_fields(
|
||||||
|
|
@ -517,8 +506,7 @@ class APIRoute(routing.Route):
|
||||||
self.tags = tags or []
|
self.tags = tags or []
|
||||||
self.responses = responses or {}
|
self.responses = responses or {}
|
||||||
self.name = get_name(endpoint) if name is None else name
|
self.name = get_name(endpoint) if name is None else name
|
||||||
self.path_regex, self.path_format, self.param_convertors = compile_path(
|
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||||
path)
|
|
||||||
if methods is None:
|
if methods is None:
|
||||||
methods = ["GET"]
|
methods = ["GET"]
|
||||||
self.methods: Set[str] = {method.upper() for method in methods}
|
self.methods: Set[str] = {method.upper() for method in methods}
|
||||||
|
|
@ -558,15 +546,13 @@ class APIRoute(routing.Route):
|
||||||
self.response_field = None # type: ignore
|
self.response_field = None # type: ignore
|
||||||
self.secure_cloned_response_field = None
|
self.secure_cloned_response_field = None
|
||||||
self.dependencies = list(dependencies or [])
|
self.dependencies = list(dependencies or [])
|
||||||
self.description = description or inspect.cleandoc(
|
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
|
||||||
self.endpoint.__doc__ or "")
|
|
||||||
# if a "form feed" character (page break) is found in the description text,
|
# if a "form feed" character (page break) is found in the description text,
|
||||||
# truncate description text to the content preceding the first "form feed"
|
# truncate description text to the content preceding the first "form feed"
|
||||||
self.description = self.description.split("\f")[0].strip()
|
self.description = self.description.split("\f")[0].strip()
|
||||||
response_fields = {}
|
response_fields = {}
|
||||||
for additional_status_code, response in self.responses.items():
|
for additional_status_code, response in self.responses.items():
|
||||||
assert isinstance(
|
assert isinstance(response, dict), "An additional response must be a dict"
|
||||||
response, dict), "An additional response must be a dict"
|
|
||||||
model = response.get("model")
|
model = response.get("model")
|
||||||
if model:
|
if model:
|
||||||
assert is_body_allowed_for_status_code(additional_status_code), (
|
assert is_body_allowed_for_status_code(additional_status_code), (
|
||||||
|
|
@ -578,19 +564,16 @@ class APIRoute(routing.Route):
|
||||||
)
|
)
|
||||||
response_fields[additional_status_code] = response_field
|
response_fields[additional_status_code] = response_field
|
||||||
if response_fields:
|
if response_fields:
|
||||||
self.response_fields: Dict[Union[int,
|
self.response_fields: Dict[Union[int, str], ModelField] = response_fields
|
||||||
str], ModelField] = response_fields
|
|
||||||
else:
|
else:
|
||||||
self.response_fields = {}
|
self.response_fields = {}
|
||||||
|
|
||||||
assert callable(endpoint), "An endpoint must be a callable"
|
assert callable(endpoint), "An endpoint must be a callable"
|
||||||
self.dependant = get_dependant(
|
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
|
||||||
path=self.path_format, call=self.endpoint)
|
|
||||||
for depends in self.dependencies[::-1]:
|
for depends in self.dependencies[::-1]:
|
||||||
self.dependant.dependencies.insert(
|
self.dependant.dependencies.insert(
|
||||||
0,
|
0,
|
||||||
get_parameterless_sub_dependant(
|
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
||||||
depends=depends, path=self.path_format),
|
|
||||||
)
|
)
|
||||||
self._flat_dependant = get_flat_dependant(self.dependant)
|
self._flat_dependant = get_flat_dependant(self.dependant)
|
||||||
self._embed_body_fields = _should_embed_body_fields(
|
self._embed_body_fields = _should_embed_body_fields(
|
||||||
|
|
@ -657,8 +640,7 @@ class APIRouter(routing.Router):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
prefix: Annotated[str, Doc(
|
prefix: Annotated[str, Doc("An optional path prefix for the router.")] = "",
|
||||||
"An optional path prefix for the router.")] = "",
|
|
||||||
tags: Annotated[
|
tags: Annotated[
|
||||||
Optional[List[Union[str, Enum]]],
|
Optional[List[Union[str, Enum]]],
|
||||||
Doc(
|
Doc(
|
||||||
|
|
@ -1159,8 +1141,7 @@ class APIRouter(routing.Router):
|
||||||
self,
|
self,
|
||||||
router: Annotated["APIRouter", Doc("The `APIRouter` to include.")],
|
router: Annotated["APIRouter", Doc("The `APIRouter` to include.")],
|
||||||
*,
|
*,
|
||||||
prefix: Annotated[str, Doc(
|
prefix: Annotated[str, Doc("An optional path prefix for the router.")] = "",
|
||||||
"An optional path prefix for the router.")] = "",
|
|
||||||
tags: Annotated[
|
tags: Annotated[
|
||||||
Optional[List[Union[str, Enum]]],
|
Optional[List[Union[str, Enum]]],
|
||||||
Doc(
|
Doc(
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
from dirty_equals import IsDict
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
@ -80,11 +81,30 @@ def test_json_decode_error_empty_body():
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 422
|
assert response.status_code == 422
|
||||||
error = response.json()["detail"][0]
|
# Handle both Pydantic v1 and v2 - empty body is handled differently
|
||||||
|
assert response.json() == IsDict(
|
||||||
# Empty body is handled differently, not as a JSON decode error
|
{
|
||||||
assert error["loc"] == ["body"]
|
"detail": [
|
||||||
assert error["type"] == "missing"
|
{
|
||||||
|
"loc": ["body"],
|
||||||
|
"msg": "Field required",
|
||||||
|
"type": "missing",
|
||||||
|
"input": None,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
) | IsDict(
|
||||||
|
# Pydantic v1
|
||||||
|
{
|
||||||
|
"detail": [
|
||||||
|
{
|
||||||
|
"loc": ["body"],
|
||||||
|
"msg": "field required",
|
||||||
|
"type": "value_error.missing",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_json_decode_error_unclosed_brace():
|
def test_json_decode_error_unclosed_brace():
|
||||||
|
|
|
||||||
|
|
@ -60,8 +60,7 @@ def test_post_with_str_float_description(client: TestClient):
|
||||||
def test_post_with_str_float_description_tax(client: TestClient):
|
def test_post_with_str_float_description_tax(client: TestClient):
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/items/",
|
"/items/",
|
||||||
json={"name": "Foo", "price": "50.5",
|
json={"name": "Foo", "price": "50.5", "description": "Some Foo", "tax": 0.3},
|
||||||
"description": "Some Foo", "tax": 0.3},
|
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {
|
assert response.json() == {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue