mirror of https://github.com/tiangolo/fastapi.git
♻️ Refactor deciding if `embed` body fields, do not overwrite fields, compute once per router, refactor internals in preparation for Pydantic models in `Form`, `Query` and others (#12117)
This commit is contained in:
parent
7213d421f5
commit
aa21814a89
|
|
@ -279,6 +279,12 @@ if PYDANTIC_V2:
|
||||||
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
|
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
|
||||||
return BodyModel
|
return BodyModel
|
||||||
|
|
||||||
|
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
|
||||||
|
return [
|
||||||
|
ModelField(field_info=field_info, name=name)
|
||||||
|
for name, field_info in model.model_fields.items()
|
||||||
|
]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
|
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
|
||||||
from pydantic import AnyUrl as Url # noqa: F401
|
from pydantic import AnyUrl as Url # noqa: F401
|
||||||
|
|
@ -513,6 +519,9 @@ else:
|
||||||
BodyModel.__fields__[f.name] = f # type: ignore[index]
|
BodyModel.__fields__[f.name] = f # type: ignore[index]
|
||||||
return BodyModel
|
return BodyModel
|
||||||
|
|
||||||
|
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
|
||||||
|
return list(model.__fields__.values()) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
def _regenerate_error_with_loc(
|
def _regenerate_error_with_loc(
|
||||||
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
|
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
|
||||||
|
|
@ -532,6 +541,12 @@ def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
|
||||||
|
|
||||||
|
|
||||||
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
|
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
|
||||||
|
origin = get_origin(annotation)
|
||||||
|
if origin is Union or origin is UnionType:
|
||||||
|
for arg in get_args(annotation):
|
||||||
|
if field_annotation_is_sequence(arg):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
return _annotation_is_sequence(annotation) or _annotation_is_sequence(
|
return _annotation_is_sequence(annotation) or _annotation_is_sequence(
|
||||||
get_origin(annotation)
|
get_origin(annotation)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,13 @@ from fastapi.utils import create_model_field, get_path_param_names
|
||||||
from pydantic.fields import FieldInfo
|
from pydantic.fields import FieldInfo
|
||||||
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
|
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
|
||||||
from starlette.concurrency import run_in_threadpool
|
from starlette.concurrency import run_in_threadpool
|
||||||
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
|
from starlette.datastructures import (
|
||||||
|
FormData,
|
||||||
|
Headers,
|
||||||
|
ImmutableMultiDict,
|
||||||
|
QueryParams,
|
||||||
|
UploadFile,
|
||||||
|
)
|
||||||
from starlette.requests import HTTPConnection, Request
|
from starlette.requests import HTTPConnection, Request
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
from starlette.websockets import WebSocket
|
from starlette.websockets import WebSocket
|
||||||
|
|
@ -282,7 +288,7 @@ def get_dependant(
|
||||||
), f"Cannot specify multiple FastAPI annotations for {param_name!r}"
|
), f"Cannot specify multiple FastAPI annotations for {param_name!r}"
|
||||||
continue
|
continue
|
||||||
assert param_details.field is not None
|
assert param_details.field is not None
|
||||||
if is_body_param(param_field=param_details.field, is_path_param=is_path_param):
|
if isinstance(param_details.field.field_info, params.Body):
|
||||||
dependant.body_params.append(param_details.field)
|
dependant.body_params.append(param_details.field)
|
||||||
else:
|
else:
|
||||||
add_param_to_fields(field=param_details.field, dependant=dependant)
|
add_param_to_fields(field=param_details.field, dependant=dependant)
|
||||||
|
|
@ -466,29 +472,16 @@ def analyze_param(
|
||||||
required=field_info.default in (Required, Undefined),
|
required=field_info.default in (Required, Undefined),
|
||||||
field_info=field_info,
|
field_info=field_info,
|
||||||
)
|
)
|
||||||
|
if is_path_param:
|
||||||
|
assert is_scalar_field(
|
||||||
|
field=field
|
||||||
|
), "Path params must be of one of the supported types"
|
||||||
|
elif isinstance(field_info, params.Query):
|
||||||
|
assert is_scalar_field(field) or is_scalar_sequence_field(field)
|
||||||
|
|
||||||
return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
|
return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
|
||||||
|
|
||||||
|
|
||||||
def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
|
|
||||||
if is_path_param:
|
|
||||||
assert is_scalar_field(
|
|
||||||
field=param_field
|
|
||||||
), "Path params must be of one of the supported types"
|
|
||||||
return False
|
|
||||||
elif is_scalar_field(field=param_field):
|
|
||||||
return False
|
|
||||||
elif isinstance(
|
|
||||||
param_field.field_info, (params.Query, params.Header)
|
|
||||||
) and is_scalar_sequence_field(param_field):
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
assert isinstance(
|
|
||||||
param_field.field_info, params.Body
|
|
||||||
), f"Param: {param_field.name} can only be a request body, using Body()"
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
|
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
|
||||||
field_info = field.field_info
|
field_info = field.field_info
|
||||||
field_info_in = getattr(field_info, "in_", None)
|
field_info_in = getattr(field_info, "in_", None)
|
||||||
|
|
@ -557,6 +550,7 @@ async def solve_dependencies(
|
||||||
dependency_overrides_provider: Optional[Any] = None,
|
dependency_overrides_provider: Optional[Any] = None,
|
||||||
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
|
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
|
||||||
async_exit_stack: AsyncExitStack,
|
async_exit_stack: AsyncExitStack,
|
||||||
|
embed_body_fields: bool,
|
||||||
) -> SolvedDependency:
|
) -> SolvedDependency:
|
||||||
values: Dict[str, Any] = {}
|
values: Dict[str, Any] = {}
|
||||||
errors: List[Any] = []
|
errors: List[Any] = []
|
||||||
|
|
@ -598,6 +592,7 @@ async def solve_dependencies(
|
||||||
dependency_overrides_provider=dependency_overrides_provider,
|
dependency_overrides_provider=dependency_overrides_provider,
|
||||||
dependency_cache=dependency_cache,
|
dependency_cache=dependency_cache,
|
||||||
async_exit_stack=async_exit_stack,
|
async_exit_stack=async_exit_stack,
|
||||||
|
embed_body_fields=embed_body_fields,
|
||||||
)
|
)
|
||||||
background_tasks = solved_result.background_tasks
|
background_tasks = solved_result.background_tasks
|
||||||
dependency_cache.update(solved_result.dependency_cache)
|
dependency_cache.update(solved_result.dependency_cache)
|
||||||
|
|
@ -640,7 +635,9 @@ async def solve_dependencies(
|
||||||
body_values,
|
body_values,
|
||||||
body_errors,
|
body_errors,
|
||||||
) = await request_body_to_args( # body_params checked above
|
) = await request_body_to_args( # body_params checked above
|
||||||
required_params=dependant.body_params, received_body=body
|
body_fields=dependant.body_params,
|
||||||
|
received_body=body,
|
||||||
|
embed_body_fields=embed_body_fields,
|
||||||
)
|
)
|
||||||
values.update(body_values)
|
values.update(body_values)
|
||||||
errors.extend(body_errors)
|
errors.extend(body_errors)
|
||||||
|
|
@ -669,138 +666,185 @@ async def solve_dependencies(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_value_with_model_field(
|
||||||
|
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
|
||||||
|
) -> Tuple[Any, List[Any]]:
|
||||||
|
if value is None:
|
||||||
|
if field.required:
|
||||||
|
return None, [get_missing_field_error(loc=loc)]
|
||||||
|
else:
|
||||||
|
return deepcopy(field.default), []
|
||||||
|
v_, errors_ = field.validate(value, values, loc=loc)
|
||||||
|
if isinstance(errors_, ErrorWrapper):
|
||||||
|
return None, [errors_]
|
||||||
|
elif isinstance(errors_, list):
|
||||||
|
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
|
||||||
|
return None, new_errors
|
||||||
|
else:
|
||||||
|
return v_, []
|
||||||
|
|
||||||
|
|
||||||
|
def _get_multidict_value(field: ModelField, values: Mapping[str, Any]) -> Any:
|
||||||
|
if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
|
||||||
|
value = values.getlist(field.alias)
|
||||||
|
else:
|
||||||
|
value = values.get(field.alias, None)
|
||||||
|
if (
|
||||||
|
value is None
|
||||||
|
or (
|
||||||
|
isinstance(field.field_info, params.Form)
|
||||||
|
and isinstance(value, str) # For type checks
|
||||||
|
and value == ""
|
||||||
|
)
|
||||||
|
or (is_sequence_field(field) and len(value) == 0)
|
||||||
|
):
|
||||||
|
if field.required:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
return deepcopy(field.default)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
def request_params_to_args(
|
def request_params_to_args(
|
||||||
required_params: Sequence[ModelField],
|
fields: Sequence[ModelField],
|
||||||
received_params: Union[Mapping[str, Any], QueryParams, Headers],
|
received_params: Union[Mapping[str, Any], QueryParams, Headers],
|
||||||
) -> Tuple[Dict[str, Any], List[Any]]:
|
) -> Tuple[Dict[str, Any], List[Any]]:
|
||||||
values = {}
|
values: Dict[str, Any] = {}
|
||||||
errors = []
|
errors = []
|
||||||
for field in required_params:
|
for field in fields:
|
||||||
if is_scalar_sequence_field(field) and isinstance(
|
value = _get_multidict_value(field, received_params)
|
||||||
received_params, (QueryParams, Headers)
|
|
||||||
):
|
|
||||||
value = received_params.getlist(field.alias) or field.default
|
|
||||||
else:
|
|
||||||
value = received_params.get(field.alias)
|
|
||||||
field_info = field.field_info
|
field_info = field.field_info
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
field_info, params.Param
|
field_info, params.Param
|
||||||
), "Params must be subclasses of Param"
|
), "Params must be subclasses of Param"
|
||||||
loc = (field_info.in_.value, field.alias)
|
loc = (field_info.in_.value, field.alias)
|
||||||
if value is None:
|
v_, errors_ = _validate_value_with_model_field(
|
||||||
if field.required:
|
field=field, value=value, values=values, loc=loc
|
||||||
errors.append(get_missing_field_error(loc=loc))
|
)
|
||||||
else:
|
if errors_:
|
||||||
values[field.name] = deepcopy(field.default)
|
errors.extend(errors_)
|
||||||
continue
|
|
||||||
v_, errors_ = field.validate(value, values, loc=loc)
|
|
||||||
if isinstance(errors_, ErrorWrapper):
|
|
||||||
errors.append(errors_)
|
|
||||||
elif isinstance(errors_, list):
|
|
||||||
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
|
|
||||||
errors.extend(new_errors)
|
|
||||||
else:
|
else:
|
||||||
values[field.name] = v_
|
values[field.name] = v_
|
||||||
return values, errors
|
return values, errors
|
||||||
|
|
||||||
|
|
||||||
async def request_body_to_args(
|
def _should_embed_body_fields(fields: List[ModelField]) -> bool:
|
||||||
required_params: List[ModelField],
|
if not fields:
|
||||||
received_body: Optional[Union[Dict[str, Any], FormData]],
|
return False
|
||||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
# More than one dependency could have the same field, it would show up as multiple
|
||||||
|
# fields but it's the same one, so count them by name
|
||||||
|
body_param_names_set = {field.name for field in fields}
|
||||||
|
# A top level field has to be a single field, not multiple
|
||||||
|
if len(body_param_names_set) > 1:
|
||||||
|
return True
|
||||||
|
first_field = fields[0]
|
||||||
|
# If it explicitly specifies it is embedded, it has to be embedded
|
||||||
|
if getattr(first_field.field_info, "embed", None):
|
||||||
|
return True
|
||||||
|
# If it's a Form (or File) field, it has to be a BaseModel to be top level
|
||||||
|
# otherwise it has to be embedded, so that the key value pair can be extracted
|
||||||
|
if isinstance(first_field.field_info, params.Form):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def _extract_form_body(
|
||||||
|
body_fields: List[ModelField],
|
||||||
|
received_body: FormData,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
values = {}
|
values = {}
|
||||||
|
first_field = body_fields[0]
|
||||||
|
first_field_info = first_field.field_info
|
||||||
|
|
||||||
|
for field in body_fields:
|
||||||
|
value = _get_multidict_value(field, received_body)
|
||||||
|
if (
|
||||||
|
isinstance(first_field_info, params.File)
|
||||||
|
and is_bytes_field(field)
|
||||||
|
and isinstance(value, UploadFile)
|
||||||
|
):
|
||||||
|
value = await value.read()
|
||||||
|
elif (
|
||||||
|
is_bytes_sequence_field(field)
|
||||||
|
and isinstance(first_field_info, params.File)
|
||||||
|
and value_is_sequence(value)
|
||||||
|
):
|
||||||
|
# For types
|
||||||
|
assert isinstance(value, sequence_types) # type: ignore[arg-type]
|
||||||
|
results: List[Union[bytes, str]] = []
|
||||||
|
|
||||||
|
async def process_fn(
|
||||||
|
fn: Callable[[], Coroutine[Any, Any, Any]],
|
||||||
|
) -> None:
|
||||||
|
result = await fn()
|
||||||
|
results.append(result) # noqa: B023
|
||||||
|
|
||||||
|
async with anyio.create_task_group() as tg:
|
||||||
|
for sub_value in value:
|
||||||
|
tg.start_soon(process_fn, sub_value.read)
|
||||||
|
value = serialize_sequence_value(field=field, value=results)
|
||||||
|
values[field.name] = value
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
async def request_body_to_args(
|
||||||
|
body_fields: List[ModelField],
|
||||||
|
received_body: Optional[Union[Dict[str, Any], FormData]],
|
||||||
|
embed_body_fields: bool,
|
||||||
|
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||||
|
values: Dict[str, Any] = {}
|
||||||
errors: List[Dict[str, Any]] = []
|
errors: List[Dict[str, Any]] = []
|
||||||
if required_params:
|
assert body_fields, "request_body_to_args() should be called with fields"
|
||||||
field = required_params[0]
|
single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields
|
||||||
field_info = field.field_info
|
first_field = body_fields[0]
|
||||||
embed = getattr(field_info, "embed", None)
|
body_to_process = received_body
|
||||||
field_alias_omitted = len(required_params) == 1 and not embed
|
if isinstance(received_body, FormData):
|
||||||
if field_alias_omitted:
|
body_to_process = await _extract_form_body(body_fields, received_body)
|
||||||
received_body = {field.alias: received_body}
|
|
||||||
|
|
||||||
for field in required_params:
|
if single_not_embedded_field:
|
||||||
loc: Tuple[str, ...]
|
loc: Tuple[str, ...] = ("body",)
|
||||||
if field_alias_omitted:
|
v_, errors_ = _validate_value_with_model_field(
|
||||||
loc = ("body",)
|
field=first_field, value=body_to_process, values=values, loc=loc
|
||||||
else:
|
)
|
||||||
loc = ("body", field.alias)
|
return {first_field.name: v_}, errors_
|
||||||
|
for field in body_fields:
|
||||||
value: Optional[Any] = None
|
loc = ("body", field.alias)
|
||||||
if received_body is not None:
|
value: Optional[Any] = None
|
||||||
if (is_sequence_field(field)) and isinstance(received_body, FormData):
|
if body_to_process is not None:
|
||||||
value = received_body.getlist(field.alias)
|
try:
|
||||||
else:
|
value = body_to_process.get(field.alias)
|
||||||
try:
|
# If the received body is a list, not a dict
|
||||||
value = received_body.get(field.alias)
|
except AttributeError:
|
||||||
except AttributeError:
|
errors.append(get_missing_field_error(loc))
|
||||||
errors.append(get_missing_field_error(loc))
|
|
||||||
continue
|
|
||||||
if (
|
|
||||||
value is None
|
|
||||||
or (isinstance(field_info, params.Form) and value == "")
|
|
||||||
or (
|
|
||||||
isinstance(field_info, params.Form)
|
|
||||||
and is_sequence_field(field)
|
|
||||||
and len(value) == 0
|
|
||||||
)
|
|
||||||
):
|
|
||||||
if field.required:
|
|
||||||
errors.append(get_missing_field_error(loc))
|
|
||||||
else:
|
|
||||||
values[field.name] = deepcopy(field.default)
|
|
||||||
continue
|
continue
|
||||||
if (
|
v_, errors_ = _validate_value_with_model_field(
|
||||||
isinstance(field_info, params.File)
|
field=field, value=value, values=values, loc=loc
|
||||||
and is_bytes_field(field)
|
)
|
||||||
and isinstance(value, UploadFile)
|
if errors_:
|
||||||
):
|
errors.extend(errors_)
|
||||||
value = await value.read()
|
else:
|
||||||
elif (
|
values[field.name] = v_
|
||||||
is_bytes_sequence_field(field)
|
|
||||||
and isinstance(field_info, params.File)
|
|
||||||
and value_is_sequence(value)
|
|
||||||
):
|
|
||||||
# For types
|
|
||||||
assert isinstance(value, sequence_types) # type: ignore[arg-type]
|
|
||||||
results: List[Union[bytes, str]] = []
|
|
||||||
|
|
||||||
async def process_fn(
|
|
||||||
fn: Callable[[], Coroutine[Any, Any, Any]],
|
|
||||||
) -> None:
|
|
||||||
result = await fn()
|
|
||||||
results.append(result) # noqa: B023
|
|
||||||
|
|
||||||
async with anyio.create_task_group() as tg:
|
|
||||||
for sub_value in value:
|
|
||||||
tg.start_soon(process_fn, sub_value.read)
|
|
||||||
value = serialize_sequence_value(field=field, value=results)
|
|
||||||
|
|
||||||
v_, errors_ = field.validate(value, values, loc=loc)
|
|
||||||
|
|
||||||
if isinstance(errors_, list):
|
|
||||||
errors.extend(errors_)
|
|
||||||
elif errors_:
|
|
||||||
errors.append(errors_)
|
|
||||||
else:
|
|
||||||
values[field.name] = v_
|
|
||||||
return values, errors
|
return values, errors
|
||||||
|
|
||||||
|
|
||||||
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
|
def get_body_field(
|
||||||
flat_dependant = get_flat_dependant(dependant)
|
*, flat_dependant: Dependant, name: str, embed_body_fields: bool
|
||||||
|
) -> Optional[ModelField]:
|
||||||
|
"""
|
||||||
|
Get a ModelField representing the request body for a path operation, combining
|
||||||
|
all body parameters into a single field if necessary.
|
||||||
|
|
||||||
|
Used to check if it's form data (with `isinstance(body_field, params.Form)`)
|
||||||
|
or JSON and to generate the JSON Schema for a request body.
|
||||||
|
|
||||||
|
This is **not** used to validate/parse the request body, that's done with each
|
||||||
|
individual body parameter.
|
||||||
|
"""
|
||||||
if not flat_dependant.body_params:
|
if not flat_dependant.body_params:
|
||||||
return None
|
return None
|
||||||
first_param = flat_dependant.body_params[0]
|
first_param = flat_dependant.body_params[0]
|
||||||
field_info = first_param.field_info
|
if not embed_body_fields:
|
||||||
embed = getattr(field_info, "embed", None)
|
|
||||||
body_param_names_set = {param.name for param in flat_dependant.body_params}
|
|
||||||
if len(body_param_names_set) == 1 and not embed:
|
|
||||||
return first_param
|
return first_param
|
||||||
# If one field requires to embed, all have to be embedded
|
|
||||||
# in case a sub-dependency is evaluated with a single unique body field
|
|
||||||
# That is combined (embedded) with other body fields
|
|
||||||
for param in flat_dependant.body_params:
|
|
||||||
setattr(param.field_info, "embed", True) # noqa: B010
|
|
||||||
model_name = "Body_" + name
|
model_name = "Body_" + name
|
||||||
BodyModel = create_body_model(
|
BodyModel = create_body_model(
|
||||||
fields=flat_dependant.body_params, model_name=model_name
|
fields=flat_dependant.body_params, model_name=model_name
|
||||||
|
|
|
||||||
|
|
@ -1282,7 +1282,7 @@ def Body( # noqa: N802
|
||||||
),
|
),
|
||||||
] = _Unset,
|
] = _Unset,
|
||||||
embed: Annotated[
|
embed: Annotated[
|
||||||
bool,
|
Union[bool, None],
|
||||||
Doc(
|
Doc(
|
||||||
"""
|
"""
|
||||||
When `embed` is `True`, the parameter will be expected in a JSON body as a
|
When `embed` is `True`, the parameter will be expected in a JSON body as a
|
||||||
|
|
@ -1294,7 +1294,7 @@ def Body( # noqa: N802
|
||||||
[FastAPI docs for Body - Multiple Parameters](https://fastapi.tiangolo.com/tutorial/body-multiple-params/#embed-a-single-body-parameter).
|
[FastAPI docs for Body - Multiple Parameters](https://fastapi.tiangolo.com/tutorial/body-multiple-params/#embed-a-single-body-parameter).
|
||||||
"""
|
"""
|
||||||
),
|
),
|
||||||
] = False,
|
] = None,
|
||||||
media_type: Annotated[
|
media_type: Annotated[
|
||||||
str,
|
str,
|
||||||
Doc(
|
Doc(
|
||||||
|
|
|
||||||
|
|
@ -479,7 +479,7 @@ class Body(FieldInfo):
|
||||||
*,
|
*,
|
||||||
default_factory: Union[Callable[[], Any], None] = _Unset,
|
default_factory: Union[Callable[[], Any], None] = _Unset,
|
||||||
annotation: Optional[Any] = None,
|
annotation: Optional[Any] = None,
|
||||||
embed: bool = False,
|
embed: Union[bool, None] = None,
|
||||||
media_type: str = "application/json",
|
media_type: str = "application/json",
|
||||||
alias: Optional[str] = None,
|
alias: Optional[str] = None,
|
||||||
alias_priority: Union[int, None] = _Unset,
|
alias_priority: Union[int, None] = _Unset,
|
||||||
|
|
@ -642,7 +642,6 @@ class Form(Body):
|
||||||
default=default,
|
default=default,
|
||||||
default_factory=default_factory,
|
default_factory=default_factory,
|
||||||
annotation=annotation,
|
annotation=annotation,
|
||||||
embed=True,
|
|
||||||
media_type=media_type,
|
media_type=media_type,
|
||||||
alias=alias,
|
alias=alias,
|
||||||
alias_priority=alias_priority,
|
alias_priority=alias_priority,
|
||||||
|
|
|
||||||
|
|
@ -33,8 +33,10 @@ from fastapi._compat import (
|
||||||
from fastapi.datastructures import Default, DefaultPlaceholder
|
from fastapi.datastructures import Default, DefaultPlaceholder
|
||||||
from fastapi.dependencies.models import Dependant
|
from fastapi.dependencies.models import Dependant
|
||||||
from fastapi.dependencies.utils import (
|
from fastapi.dependencies.utils import (
|
||||||
|
_should_embed_body_fields,
|
||||||
get_body_field,
|
get_body_field,
|
||||||
get_dependant,
|
get_dependant,
|
||||||
|
get_flat_dependant,
|
||||||
get_parameterless_sub_dependant,
|
get_parameterless_sub_dependant,
|
||||||
get_typed_return_annotation,
|
get_typed_return_annotation,
|
||||||
solve_dependencies,
|
solve_dependencies,
|
||||||
|
|
@ -225,6 +227,7 @@ def get_request_handler(
|
||||||
response_model_exclude_defaults: bool = False,
|
response_model_exclude_defaults: bool = False,
|
||||||
response_model_exclude_none: bool = False,
|
response_model_exclude_none: bool = False,
|
||||||
dependency_overrides_provider: Optional[Any] = None,
|
dependency_overrides_provider: Optional[Any] = None,
|
||||||
|
embed_body_fields: bool = False,
|
||||||
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||||
assert dependant.call is not None, "dependant.call must be a function"
|
assert dependant.call is not None, "dependant.call must be a function"
|
||||||
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
|
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
|
||||||
|
|
@ -291,6 +294,7 @@ def get_request_handler(
|
||||||
body=body,
|
body=body,
|
||||||
dependency_overrides_provider=dependency_overrides_provider,
|
dependency_overrides_provider=dependency_overrides_provider,
|
||||||
async_exit_stack=async_exit_stack,
|
async_exit_stack=async_exit_stack,
|
||||||
|
embed_body_fields=embed_body_fields,
|
||||||
)
|
)
|
||||||
errors = solved_result.errors
|
errors = solved_result.errors
|
||||||
if not errors:
|
if not errors:
|
||||||
|
|
@ -354,7 +358,9 @@ def get_request_handler(
|
||||||
|
|
||||||
|
|
||||||
def get_websocket_app(
|
def get_websocket_app(
|
||||||
dependant: Dependant, dependency_overrides_provider: Optional[Any] = None
|
dependant: Dependant,
|
||||||
|
dependency_overrides_provider: Optional[Any] = None,
|
||||||
|
embed_body_fields: bool = False,
|
||||||
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
|
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
|
||||||
async def app(websocket: WebSocket) -> None:
|
async def app(websocket: WebSocket) -> None:
|
||||||
async with AsyncExitStack() as async_exit_stack:
|
async with AsyncExitStack() as async_exit_stack:
|
||||||
|
|
@ -367,6 +373,7 @@ def get_websocket_app(
|
||||||
dependant=dependant,
|
dependant=dependant,
|
||||||
dependency_overrides_provider=dependency_overrides_provider,
|
dependency_overrides_provider=dependency_overrides_provider,
|
||||||
async_exit_stack=async_exit_stack,
|
async_exit_stack=async_exit_stack,
|
||||||
|
embed_body_fields=embed_body_fields,
|
||||||
)
|
)
|
||||||
if solved_result.errors:
|
if solved_result.errors:
|
||||||
raise WebSocketRequestValidationError(
|
raise WebSocketRequestValidationError(
|
||||||
|
|
@ -399,11 +406,15 @@ class APIWebSocketRoute(routing.WebSocketRoute):
|
||||||
0,
|
0,
|
||||||
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
||||||
)
|
)
|
||||||
|
self._flat_dependant = get_flat_dependant(self.dependant)
|
||||||
|
self._embed_body_fields = _should_embed_body_fields(
|
||||||
|
self._flat_dependant.body_params
|
||||||
|
)
|
||||||
self.app = websocket_session(
|
self.app = websocket_session(
|
||||||
get_websocket_app(
|
get_websocket_app(
|
||||||
dependant=self.dependant,
|
dependant=self.dependant,
|
||||||
dependency_overrides_provider=dependency_overrides_provider,
|
dependency_overrides_provider=dependency_overrides_provider,
|
||||||
|
embed_body_fields=self._embed_body_fields,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -544,7 +555,15 @@ class APIRoute(routing.Route):
|
||||||
0,
|
0,
|
||||||
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
||||||
)
|
)
|
||||||
self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
|
self._flat_dependant = get_flat_dependant(self.dependant)
|
||||||
|
self._embed_body_fields = _should_embed_body_fields(
|
||||||
|
self._flat_dependant.body_params
|
||||||
|
)
|
||||||
|
self.body_field = get_body_field(
|
||||||
|
flat_dependant=self._flat_dependant,
|
||||||
|
name=self.unique_id,
|
||||||
|
embed_body_fields=self._embed_body_fields,
|
||||||
|
)
|
||||||
self.app = request_response(self.get_route_handler())
|
self.app = request_response(self.get_route_handler())
|
||||||
|
|
||||||
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||||
|
|
@ -561,6 +580,7 @@ class APIRoute(routing.Route):
|
||||||
response_model_exclude_defaults=self.response_model_exclude_defaults,
|
response_model_exclude_defaults=self.response_model_exclude_defaults,
|
||||||
response_model_exclude_none=self.response_model_exclude_none,
|
response_model_exclude_none=self.response_model_exclude_none,
|
||||||
dependency_overrides_provider=self.dependency_overrides_provider,
|
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||||
|
embed_body_fields=self._embed_body_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
|
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,13 @@
|
||||||
from typing import List, Union
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
from fastapi import FastAPI, UploadFile
|
from fastapi import FastAPI, UploadFile
|
||||||
from fastapi._compat import (
|
from fastapi._compat import (
|
||||||
ModelField,
|
ModelField,
|
||||||
Undefined,
|
Undefined,
|
||||||
_get_model_config,
|
_get_model_config,
|
||||||
|
get_model_fields,
|
||||||
is_bytes_sequence_annotation,
|
is_bytes_sequence_annotation,
|
||||||
|
is_scalar_field,
|
||||||
is_uploadfile_sequence_annotation,
|
is_uploadfile_sequence_annotation,
|
||||||
)
|
)
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
@ -91,3 +93,12 @@ def test_is_uploadfile_sequence_annotation():
|
||||||
# and other types, but I'm not even sure it's a good idea to support it as a first
|
# and other types, but I'm not even sure it's a good idea to support it as a first
|
||||||
# class "feature"
|
# class "feature"
|
||||||
assert is_uploadfile_sequence_annotation(Union[List[str], List[UploadFile]])
|
assert is_uploadfile_sequence_annotation(Union[List[str], List[UploadFile]])
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_pv1_scalar_field():
|
||||||
|
# For coverage
|
||||||
|
class Model(BaseModel):
|
||||||
|
foo: Union[str, Dict[str, Any]]
|
||||||
|
|
||||||
|
fields = get_model_fields(Model)
|
||||||
|
assert not is_scalar_field(fields[0])
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,99 @@
|
||||||
|
from fastapi import FastAPI, Form
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/form/")
|
||||||
|
def post_form(username: Annotated[str, Form()]):
|
||||||
|
return username
|
||||||
|
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_form_field():
|
||||||
|
response = client.post("/form/", data={"username": "Rick"})
|
||||||
|
assert response.status_code == 200, response.text
|
||||||
|
assert response.json() == "Rick"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openapi_schema():
|
||||||
|
response = client.get("/openapi.json")
|
||||||
|
assert response.status_code == 200, response.text
|
||||||
|
assert response.json() == {
|
||||||
|
"openapi": "3.1.0",
|
||||||
|
"info": {"title": "FastAPI", "version": "0.1.0"},
|
||||||
|
"paths": {
|
||||||
|
"/form/": {
|
||||||
|
"post": {
|
||||||
|
"summary": "Post Form",
|
||||||
|
"operationId": "post_form_form__post",
|
||||||
|
"requestBody": {
|
||||||
|
"content": {
|
||||||
|
"application/x-www-form-urlencoded": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/Body_post_form_form__post"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": True,
|
||||||
|
},
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Successful Response",
|
||||||
|
"content": {"application/json": {"schema": {}}},
|
||||||
|
},
|
||||||
|
"422": {
|
||||||
|
"description": "Validation Error",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/HTTPValidationError"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"components": {
|
||||||
|
"schemas": {
|
||||||
|
"Body_post_form_form__post": {
|
||||||
|
"properties": {"username": {"type": "string", "title": "Username"}},
|
||||||
|
"type": "object",
|
||||||
|
"required": ["username"],
|
||||||
|
"title": "Body_post_form_form__post",
|
||||||
|
},
|
||||||
|
"HTTPValidationError": {
|
||||||
|
"properties": {
|
||||||
|
"detail": {
|
||||||
|
"items": {"$ref": "#/components/schemas/ValidationError"},
|
||||||
|
"type": "array",
|
||||||
|
"title": "Detail",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": "object",
|
||||||
|
"title": "HTTPValidationError",
|
||||||
|
},
|
||||||
|
"ValidationError": {
|
||||||
|
"properties": {
|
||||||
|
"loc": {
|
||||||
|
"items": {
|
||||||
|
"anyOf": [{"type": "string"}, {"type": "integer"}]
|
||||||
|
},
|
||||||
|
"type": "array",
|
||||||
|
"title": "Location",
|
||||||
|
},
|
||||||
|
"msg": {"type": "string", "title": "Message"},
|
||||||
|
"type": {"type": "string", "title": "Error Type"},
|
||||||
|
},
|
||||||
|
"type": "object",
|
||||||
|
"required": ["loc", "msg", "type"],
|
||||||
|
"title": "ValidationError",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue