mirror of https://github.com/tiangolo/fastapi.git
Merge af05854e91 into 272204c0c7
This commit is contained in:
commit
821e24d96b
|
|
@ -1,6 +1,7 @@
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import inspect
|
import inspect
|
||||||
import sys
|
import sys
|
||||||
|
import types
|
||||||
from contextlib import AsyncExitStack, contextmanager
|
from contextlib import AsyncExitStack, contextmanager
|
||||||
from copy import copy, deepcopy
|
from copy import copy, deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
@ -27,6 +28,7 @@ from fastapi._compat import (
|
||||||
ModelField,
|
ModelField,
|
||||||
RequiredParam,
|
RequiredParam,
|
||||||
Undefined,
|
Undefined,
|
||||||
|
UndefinedType,
|
||||||
_is_error_wrapper,
|
_is_error_wrapper,
|
||||||
_is_model_class,
|
_is_model_class,
|
||||||
copy_field_info,
|
copy_field_info,
|
||||||
|
|
@ -595,7 +597,7 @@ async def solve_dependencies(
|
||||||
*,
|
*,
|
||||||
request: Union[Request, WebSocket],
|
request: Union[Request, WebSocket],
|
||||||
dependant: Dependant,
|
dependant: Dependant,
|
||||||
body: Optional[Union[Dict[str, Any], FormData]] = None,
|
body: Optional[Union[Dict[str, Any], FormData, UndefinedType]] = Undefined,
|
||||||
background_tasks: Optional[StarletteBackgroundTasks] = None,
|
background_tasks: Optional[StarletteBackgroundTasks] = None,
|
||||||
response: Optional[Response] = None,
|
response: Optional[Response] = None,
|
||||||
dependency_overrides_provider: Optional[Any] = None,
|
dependency_overrides_provider: Optional[Any] = None,
|
||||||
|
|
@ -731,10 +733,36 @@ async def solve_dependencies(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
if sys.hexversion >= 0x30A0000:
|
||||||
|
|
||||||
|
def _allows_none(field: ModelField) -> bool:
|
||||||
|
origin = get_origin(field.type_)
|
||||||
|
return (origin is Union or origin is types.UnionType) and type(
|
||||||
|
None
|
||||||
|
) in get_args(field.type_)
|
||||||
|
else:
|
||||||
|
|
||||||
|
def _allows_none(field: ModelField) -> bool:
|
||||||
|
origin = get_origin(field.type_)
|
||||||
|
return origin is Union and type(None) in get_args(field.type_)
|
||||||
|
else:
|
||||||
|
|
||||||
|
def _allows_none(field: ModelField) -> bool:
|
||||||
|
return field.allow_none # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def _validate_value_with_model_field(
|
def _validate_value_with_model_field(
|
||||||
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
|
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
|
||||||
) -> Tuple[Any, List[Any]]:
|
) -> Tuple[Any, List[Any]]:
|
||||||
|
if value is Undefined:
|
||||||
|
if field.required:
|
||||||
|
return None, [get_missing_field_error(loc=loc)]
|
||||||
|
else:
|
||||||
|
return deepcopy(field.default), []
|
||||||
if value is None:
|
if value is None:
|
||||||
|
if _allows_none(field):
|
||||||
|
return value, []
|
||||||
if field.required:
|
if field.required:
|
||||||
return None, [get_missing_field_error(loc=loc)]
|
return None, [get_missing_field_error(loc=loc)]
|
||||||
else:
|
else:
|
||||||
|
|
@ -753,12 +781,13 @@ def _get_multidict_value(
|
||||||
field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
|
field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
|
||||||
) -> Any:
|
) -> Any:
|
||||||
alias = alias or get_validation_alias(field)
|
alias = alias or get_validation_alias(field)
|
||||||
|
value: Any
|
||||||
if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
|
if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
|
||||||
value = values.getlist(alias)
|
value = values.getlist(alias)
|
||||||
else:
|
else:
|
||||||
value = values.get(alias, None)
|
value = values.get(alias, Undefined)
|
||||||
if (
|
if (
|
||||||
value is None
|
value is Undefined
|
||||||
or (
|
or (
|
||||||
isinstance(field.field_info, (params.Form, temp_pydantic_v1_params.Form))
|
isinstance(field.field_info, (params.Form, temp_pydantic_v1_params.Form))
|
||||||
and isinstance(value, str) # For type checks
|
and isinstance(value, str) # For type checks
|
||||||
|
|
@ -767,7 +796,7 @@ def _get_multidict_value(
|
||||||
or (is_sequence_field(field) and len(value) == 0)
|
or (is_sequence_field(field) and len(value) == 0)
|
||||||
):
|
):
|
||||||
if field.required:
|
if field.required:
|
||||||
return
|
return Undefined
|
||||||
else:
|
else:
|
||||||
return deepcopy(field.default)
|
return deepcopy(field.default)
|
||||||
return value
|
return value
|
||||||
|
|
@ -813,7 +842,7 @@ def request_params_to_args(
|
||||||
if alias == field.name:
|
if alias == field.name:
|
||||||
alias = alias.replace("_", "-")
|
alias = alias.replace("_", "-")
|
||||||
value = _get_multidict_value(field, received_params, alias=alias)
|
value = _get_multidict_value(field, received_params, alias=alias)
|
||||||
if value is not None:
|
if value is not Undefined and value is not None:
|
||||||
params_to_process[get_validation_alias(field)] = value
|
params_to_process[get_validation_alias(field)] = value
|
||||||
processed_keys.add(alias or get_validation_alias(field))
|
processed_keys.add(alias or get_validation_alias(field))
|
||||||
|
|
||||||
|
|
@ -933,7 +962,7 @@ async def _extract_form_body(
|
||||||
for sub_value in value:
|
for sub_value in value:
|
||||||
tg.start_soon(process_fn, sub_value.read)
|
tg.start_soon(process_fn, sub_value.read)
|
||||||
value = serialize_sequence_value(field=field, value=results)
|
value = serialize_sequence_value(field=field, value=results)
|
||||||
if value is not None:
|
if value is not Undefined and value is not None:
|
||||||
values[get_validation_alias(field)] = value
|
values[get_validation_alias(field)] = value
|
||||||
field_aliases = {get_validation_alias(field) for field in body_fields}
|
field_aliases = {get_validation_alias(field) for field in body_fields}
|
||||||
for key in received_body.keys():
|
for key in received_body.keys():
|
||||||
|
|
@ -948,7 +977,7 @@ async def _extract_form_body(
|
||||||
|
|
||||||
async def request_body_to_args(
|
async def request_body_to_args(
|
||||||
body_fields: List[ModelField],
|
body_fields: List[ModelField],
|
||||||
received_body: Optional[Union[Dict[str, Any], FormData]],
|
received_body: Optional[Union[Dict[str, Any], FormData, UndefinedType]],
|
||||||
embed_body_fields: bool,
|
embed_body_fields: bool,
|
||||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||||
values: Dict[str, Any] = {}
|
values: Dict[str, Any] = {}
|
||||||
|
|
@ -978,10 +1007,12 @@ async def request_body_to_args(
|
||||||
return {first_field.name: v_}, errors_
|
return {first_field.name: v_}, errors_
|
||||||
for field in body_fields:
|
for field in body_fields:
|
||||||
loc = ("body", get_validation_alias(field))
|
loc = ("body", get_validation_alias(field))
|
||||||
value: Optional[Any] = None
|
value: Optional[Any] = Undefined
|
||||||
if body_to_process is not None:
|
if body_to_process is not None and not isinstance(
|
||||||
|
body_to_process, UndefinedType
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
value = body_to_process.get(get_validation_alias(field))
|
value = body_to_process.get(get_validation_alias(field), Undefined)
|
||||||
# If the received body is a list, not a dict
|
# If the received body is a list, not a dict
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
errors.append(get_missing_field_error(loc))
|
errors.append(get_missing_field_error(loc))
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,101 @@
|
||||||
|
import sys
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from dirty_equals import IsDict
|
||||||
|
from fastapi import Body, FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
DEFAULT = 1234567890
|
||||||
|
|
||||||
|
endpoints = []
|
||||||
|
|
||||||
|
if sys.hexversion >= 0x30A0000:
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
@app.post("/api1")
|
||||||
|
def api1(
|
||||||
|
integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT,
|
||||||
|
) -> dict:
|
||||||
|
return {"received": integer_or_null}
|
||||||
|
|
||||||
|
endpoints.append("/api1")
|
||||||
|
|
||||||
|
|
||||||
|
if sys.hexversion >= 0x3090000:
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
@app.post("/api2")
|
||||||
|
def api2(
|
||||||
|
integer_or_null: Annotated[Optional[int], Body(embed=True)] = DEFAULT,
|
||||||
|
) -> dict:
|
||||||
|
return {"received": integer_or_null}
|
||||||
|
|
||||||
|
endpoints.append("/api2")
|
||||||
|
|
||||||
|
@app.post("/api3")
|
||||||
|
def api3(
|
||||||
|
integer_or_null: Annotated[Union[int, None], Body(embed=True)] = DEFAULT,
|
||||||
|
) -> dict:
|
||||||
|
return {"received": integer_or_null}
|
||||||
|
|
||||||
|
endpoints.append("/api3")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api4")
|
||||||
|
def api4(integer_or_null: Optional[int] = Body(embed=True, default=DEFAULT)) -> dict:
|
||||||
|
return {"received": integer_or_null}
|
||||||
|
|
||||||
|
|
||||||
|
endpoints.append("/api4")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api5")
|
||||||
|
def api5(integer: int = Body(embed=True)) -> dict:
|
||||||
|
return {"received": integer}
|
||||||
|
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("api", endpoints)
|
||||||
|
def test_apis(api):
|
||||||
|
response = client.post(api, json={"integer_or_null": 100})
|
||||||
|
assert response.status_code == 200, response.text
|
||||||
|
assert response.json() == {"received": 100}
|
||||||
|
|
||||||
|
response = client.post(api, json={"integer_or_null": None})
|
||||||
|
assert response.status_code == 200, response.text
|
||||||
|
assert response.json() == {"received": None}
|
||||||
|
|
||||||
|
|
||||||
|
def test_required_field():
|
||||||
|
response = client.post("/api5", json={"integer": 100})
|
||||||
|
assert response.status_code == 200, response.text
|
||||||
|
assert response.json() == {"received": 100}
|
||||||
|
|
||||||
|
response = client.post("/api5", json={"integer": None})
|
||||||
|
assert response.status_code == 422, response.text
|
||||||
|
assert response.json() == IsDict(
|
||||||
|
{
|
||||||
|
"detail": [
|
||||||
|
{
|
||||||
|
"loc": ["body", "integer"],
|
||||||
|
"msg": "Field required",
|
||||||
|
"type": "missing",
|
||||||
|
"input": None,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
) | IsDict(
|
||||||
|
{
|
||||||
|
"detail": [
|
||||||
|
{
|
||||||
|
"loc": ["body", "integer"],
|
||||||
|
"msg": "field required",
|
||||||
|
"type": "value_error.missing",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
Loading…
Reference in New Issue