mirror of https://github.com/tiangolo/fastapi.git
♻️ Refactor deciding if `embed` body fields, do not overwrite fields, compute once per router, refactor internals in preparation for Pydantic models in `Form`, `Query` and others (#12117)
This commit is contained in:
parent
7213d421f5
commit
aa21814a89
|
|
@ -279,6 +279,12 @@ if PYDANTIC_V2:
|
|||
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
|
||||
return BodyModel
|
||||
|
||||
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
|
||||
return [
|
||||
ModelField(field_info=field_info, name=name)
|
||||
for name, field_info in model.model_fields.items()
|
||||
]
|
||||
|
||||
else:
|
||||
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
|
||||
from pydantic import AnyUrl as Url # noqa: F401
|
||||
|
|
@ -513,6 +519,9 @@ else:
|
|||
BodyModel.__fields__[f.name] = f # type: ignore[index]
|
||||
return BodyModel
|
||||
|
||||
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
|
||||
return list(model.__fields__.values()) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _regenerate_error_with_loc(
|
||||
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
|
||||
|
|
@ -532,6 +541,12 @@ def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
|
|||
|
||||
|
||||
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
for arg in get_args(annotation):
|
||||
if field_annotation_is_sequence(arg):
|
||||
return True
|
||||
return False
|
||||
return _annotation_is_sequence(annotation) or _annotation_is_sequence(
|
||||
get_origin(annotation)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -59,7 +59,13 @@ from fastapi.utils import create_model_field, get_path_param_names
|
|||
from pydantic.fields import FieldInfo
|
||||
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
|
||||
from starlette.datastructures import (
|
||||
FormData,
|
||||
Headers,
|
||||
ImmutableMultiDict,
|
||||
QueryParams,
|
||||
UploadFile,
|
||||
)
|
||||
from starlette.requests import HTTPConnection, Request
|
||||
from starlette.responses import Response
|
||||
from starlette.websockets import WebSocket
|
||||
|
|
@ -282,7 +288,7 @@ def get_dependant(
|
|||
), f"Cannot specify multiple FastAPI annotations for {param_name!r}"
|
||||
continue
|
||||
assert param_details.field is not None
|
||||
if is_body_param(param_field=param_details.field, is_path_param=is_path_param):
|
||||
if isinstance(param_details.field.field_info, params.Body):
|
||||
dependant.body_params.append(param_details.field)
|
||||
else:
|
||||
add_param_to_fields(field=param_details.field, dependant=dependant)
|
||||
|
|
@ -466,29 +472,16 @@ def analyze_param(
|
|||
required=field_info.default in (Required, Undefined),
|
||||
field_info=field_info,
|
||||
)
|
||||
if is_path_param:
|
||||
assert is_scalar_field(
|
||||
field=field
|
||||
), "Path params must be of one of the supported types"
|
||||
elif isinstance(field_info, params.Query):
|
||||
assert is_scalar_field(field) or is_scalar_sequence_field(field)
|
||||
|
||||
return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
|
||||
|
||||
|
||||
def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
|
||||
if is_path_param:
|
||||
assert is_scalar_field(
|
||||
field=param_field
|
||||
), "Path params must be of one of the supported types"
|
||||
return False
|
||||
elif is_scalar_field(field=param_field):
|
||||
return False
|
||||
elif isinstance(
|
||||
param_field.field_info, (params.Query, params.Header)
|
||||
) and is_scalar_sequence_field(param_field):
|
||||
return False
|
||||
else:
|
||||
assert isinstance(
|
||||
param_field.field_info, params.Body
|
||||
), f"Param: {param_field.name} can only be a request body, using Body()"
|
||||
return True
|
||||
|
||||
|
||||
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
|
||||
field_info = field.field_info
|
||||
field_info_in = getattr(field_info, "in_", None)
|
||||
|
|
@ -557,6 +550,7 @@ async def solve_dependencies(
|
|||
dependency_overrides_provider: Optional[Any] = None,
|
||||
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
|
||||
async_exit_stack: AsyncExitStack,
|
||||
embed_body_fields: bool,
|
||||
) -> SolvedDependency:
|
||||
values: Dict[str, Any] = {}
|
||||
errors: List[Any] = []
|
||||
|
|
@ -598,6 +592,7 @@ async def solve_dependencies(
|
|||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
dependency_cache=dependency_cache,
|
||||
async_exit_stack=async_exit_stack,
|
||||
embed_body_fields=embed_body_fields,
|
||||
)
|
||||
background_tasks = solved_result.background_tasks
|
||||
dependency_cache.update(solved_result.dependency_cache)
|
||||
|
|
@ -640,7 +635,9 @@ async def solve_dependencies(
|
|||
body_values,
|
||||
body_errors,
|
||||
) = await request_body_to_args( # body_params checked above
|
||||
required_params=dependant.body_params, received_body=body
|
||||
body_fields=dependant.body_params,
|
||||
received_body=body,
|
||||
embed_body_fields=embed_body_fields,
|
||||
)
|
||||
values.update(body_values)
|
||||
errors.extend(body_errors)
|
||||
|
|
@ -669,138 +666,185 @@ async def solve_dependencies(
|
|||
)
|
||||
|
||||
|
||||
def _validate_value_with_model_field(
|
||||
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
|
||||
) -> Tuple[Any, List[Any]]:
|
||||
if value is None:
|
||||
if field.required:
|
||||
return None, [get_missing_field_error(loc=loc)]
|
||||
else:
|
||||
return deepcopy(field.default), []
|
||||
v_, errors_ = field.validate(value, values, loc=loc)
|
||||
if isinstance(errors_, ErrorWrapper):
|
||||
return None, [errors_]
|
||||
elif isinstance(errors_, list):
|
||||
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
|
||||
return None, new_errors
|
||||
else:
|
||||
return v_, []
|
||||
|
||||
|
||||
def _get_multidict_value(field: ModelField, values: Mapping[str, Any]) -> Any:
|
||||
if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
|
||||
value = values.getlist(field.alias)
|
||||
else:
|
||||
value = values.get(field.alias, None)
|
||||
if (
|
||||
value is None
|
||||
or (
|
||||
isinstance(field.field_info, params.Form)
|
||||
and isinstance(value, str) # For type checks
|
||||
and value == ""
|
||||
)
|
||||
or (is_sequence_field(field) and len(value) == 0)
|
||||
):
|
||||
if field.required:
|
||||
return
|
||||
else:
|
||||
return deepcopy(field.default)
|
||||
return value
|
||||
|
||||
|
||||
def request_params_to_args(
|
||||
required_params: Sequence[ModelField],
|
||||
fields: Sequence[ModelField],
|
||||
received_params: Union[Mapping[str, Any], QueryParams, Headers],
|
||||
) -> Tuple[Dict[str, Any], List[Any]]:
|
||||
values = {}
|
||||
values: Dict[str, Any] = {}
|
||||
errors = []
|
||||
for field in required_params:
|
||||
if is_scalar_sequence_field(field) and isinstance(
|
||||
received_params, (QueryParams, Headers)
|
||||
):
|
||||
value = received_params.getlist(field.alias) or field.default
|
||||
else:
|
||||
value = received_params.get(field.alias)
|
||||
for field in fields:
|
||||
value = _get_multidict_value(field, received_params)
|
||||
field_info = field.field_info
|
||||
assert isinstance(
|
||||
field_info, params.Param
|
||||
), "Params must be subclasses of Param"
|
||||
loc = (field_info.in_.value, field.alias)
|
||||
if value is None:
|
||||
if field.required:
|
||||
errors.append(get_missing_field_error(loc=loc))
|
||||
else:
|
||||
values[field.name] = deepcopy(field.default)
|
||||
continue
|
||||
v_, errors_ = field.validate(value, values, loc=loc)
|
||||
if isinstance(errors_, ErrorWrapper):
|
||||
errors.append(errors_)
|
||||
elif isinstance(errors_, list):
|
||||
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
|
||||
errors.extend(new_errors)
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=field, value=value, values=values, loc=loc
|
||||
)
|
||||
if errors_:
|
||||
errors.extend(errors_)
|
||||
else:
|
||||
values[field.name] = v_
|
||||
return values, errors
|
||||
|
||||
|
||||
async def request_body_to_args(
|
||||
required_params: List[ModelField],
|
||||
received_body: Optional[Union[Dict[str, Any], FormData]],
|
||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
def _should_embed_body_fields(fields: List[ModelField]) -> bool:
|
||||
if not fields:
|
||||
return False
|
||||
# More than one dependency could have the same field, it would show up as multiple
|
||||
# fields but it's the same one, so count them by name
|
||||
body_param_names_set = {field.name for field in fields}
|
||||
# A top level field has to be a single field, not multiple
|
||||
if len(body_param_names_set) > 1:
|
||||
return True
|
||||
first_field = fields[0]
|
||||
# If it explicitly specifies it is embedded, it has to be embedded
|
||||
if getattr(first_field.field_info, "embed", None):
|
||||
return True
|
||||
# If it's a Form (or File) field, it has to be a BaseModel to be top level
|
||||
# otherwise it has to be embedded, so that the key value pair can be extracted
|
||||
if isinstance(first_field.field_info, params.Form):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def _extract_form_body(
|
||||
body_fields: List[ModelField],
|
||||
received_body: FormData,
|
||||
) -> Dict[str, Any]:
|
||||
values = {}
|
||||
first_field = body_fields[0]
|
||||
first_field_info = first_field.field_info
|
||||
|
||||
for field in body_fields:
|
||||
value = _get_multidict_value(field, received_body)
|
||||
if (
|
||||
isinstance(first_field_info, params.File)
|
||||
and is_bytes_field(field)
|
||||
and isinstance(value, UploadFile)
|
||||
):
|
||||
value = await value.read()
|
||||
elif (
|
||||
is_bytes_sequence_field(field)
|
||||
and isinstance(first_field_info, params.File)
|
||||
and value_is_sequence(value)
|
||||
):
|
||||
# For types
|
||||
assert isinstance(value, sequence_types) # type: ignore[arg-type]
|
||||
results: List[Union[bytes, str]] = []
|
||||
|
||||
async def process_fn(
|
||||
fn: Callable[[], Coroutine[Any, Any, Any]],
|
||||
) -> None:
|
||||
result = await fn()
|
||||
results.append(result) # noqa: B023
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
for sub_value in value:
|
||||
tg.start_soon(process_fn, sub_value.read)
|
||||
value = serialize_sequence_value(field=field, value=results)
|
||||
values[field.name] = value
|
||||
return values
|
||||
|
||||
|
||||
async def request_body_to_args(
|
||||
body_fields: List[ModelField],
|
||||
received_body: Optional[Union[Dict[str, Any], FormData]],
|
||||
embed_body_fields: bool,
|
||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
values: Dict[str, Any] = {}
|
||||
errors: List[Dict[str, Any]] = []
|
||||
if required_params:
|
||||
field = required_params[0]
|
||||
field_info = field.field_info
|
||||
embed = getattr(field_info, "embed", None)
|
||||
field_alias_omitted = len(required_params) == 1 and not embed
|
||||
if field_alias_omitted:
|
||||
received_body = {field.alias: received_body}
|
||||
assert body_fields, "request_body_to_args() should be called with fields"
|
||||
single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields
|
||||
first_field = body_fields[0]
|
||||
body_to_process = received_body
|
||||
if isinstance(received_body, FormData):
|
||||
body_to_process = await _extract_form_body(body_fields, received_body)
|
||||
|
||||
for field in required_params:
|
||||
loc: Tuple[str, ...]
|
||||
if field_alias_omitted:
|
||||
loc = ("body",)
|
||||
else:
|
||||
loc = ("body", field.alias)
|
||||
|
||||
value: Optional[Any] = None
|
||||
if received_body is not None:
|
||||
if (is_sequence_field(field)) and isinstance(received_body, FormData):
|
||||
value = received_body.getlist(field.alias)
|
||||
else:
|
||||
try:
|
||||
value = received_body.get(field.alias)
|
||||
except AttributeError:
|
||||
errors.append(get_missing_field_error(loc))
|
||||
continue
|
||||
if (
|
||||
value is None
|
||||
or (isinstance(field_info, params.Form) and value == "")
|
||||
or (
|
||||
isinstance(field_info, params.Form)
|
||||
and is_sequence_field(field)
|
||||
and len(value) == 0
|
||||
)
|
||||
):
|
||||
if field.required:
|
||||
errors.append(get_missing_field_error(loc))
|
||||
else:
|
||||
values[field.name] = deepcopy(field.default)
|
||||
if single_not_embedded_field:
|
||||
loc: Tuple[str, ...] = ("body",)
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=first_field, value=body_to_process, values=values, loc=loc
|
||||
)
|
||||
return {first_field.name: v_}, errors_
|
||||
for field in body_fields:
|
||||
loc = ("body", field.alias)
|
||||
value: Optional[Any] = None
|
||||
if body_to_process is not None:
|
||||
try:
|
||||
value = body_to_process.get(field.alias)
|
||||
# If the received body is a list, not a dict
|
||||
except AttributeError:
|
||||
errors.append(get_missing_field_error(loc))
|
||||
continue
|
||||
if (
|
||||
isinstance(field_info, params.File)
|
||||
and is_bytes_field(field)
|
||||
and isinstance(value, UploadFile)
|
||||
):
|
||||
value = await value.read()
|
||||
elif (
|
||||
is_bytes_sequence_field(field)
|
||||
and isinstance(field_info, params.File)
|
||||
and value_is_sequence(value)
|
||||
):
|
||||
# For types
|
||||
assert isinstance(value, sequence_types) # type: ignore[arg-type]
|
||||
results: List[Union[bytes, str]] = []
|
||||
|
||||
async def process_fn(
|
||||
fn: Callable[[], Coroutine[Any, Any, Any]],
|
||||
) -> None:
|
||||
result = await fn()
|
||||
results.append(result) # noqa: B023
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
for sub_value in value:
|
||||
tg.start_soon(process_fn, sub_value.read)
|
||||
value = serialize_sequence_value(field=field, value=results)
|
||||
|
||||
v_, errors_ = field.validate(value, values, loc=loc)
|
||||
|
||||
if isinstance(errors_, list):
|
||||
errors.extend(errors_)
|
||||
elif errors_:
|
||||
errors.append(errors_)
|
||||
else:
|
||||
values[field.name] = v_
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=field, value=value, values=values, loc=loc
|
||||
)
|
||||
if errors_:
|
||||
errors.extend(errors_)
|
||||
else:
|
||||
values[field.name] = v_
|
||||
return values, errors
|
||||
|
||||
|
||||
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
|
||||
flat_dependant = get_flat_dependant(dependant)
|
||||
def get_body_field(
|
||||
*, flat_dependant: Dependant, name: str, embed_body_fields: bool
|
||||
) -> Optional[ModelField]:
|
||||
"""
|
||||
Get a ModelField representing the request body for a path operation, combining
|
||||
all body parameters into a single field if necessary.
|
||||
|
||||
Used to check if it's form data (with `isinstance(body_field, params.Form)`)
|
||||
or JSON and to generate the JSON Schema for a request body.
|
||||
|
||||
This is **not** used to validate/parse the request body, that's done with each
|
||||
individual body parameter.
|
||||
"""
|
||||
if not flat_dependant.body_params:
|
||||
return None
|
||||
first_param = flat_dependant.body_params[0]
|
||||
field_info = first_param.field_info
|
||||
embed = getattr(field_info, "embed", None)
|
||||
body_param_names_set = {param.name for param in flat_dependant.body_params}
|
||||
if len(body_param_names_set) == 1 and not embed:
|
||||
if not embed_body_fields:
|
||||
return first_param
|
||||
# If one field requires to embed, all have to be embedded
|
||||
# in case a sub-dependency is evaluated with a single unique body field
|
||||
# That is combined (embedded) with other body fields
|
||||
for param in flat_dependant.body_params:
|
||||
setattr(param.field_info, "embed", True) # noqa: B010
|
||||
model_name = "Body_" + name
|
||||
BodyModel = create_body_model(
|
||||
fields=flat_dependant.body_params, model_name=model_name
|
||||
|
|
|
|||
|
|
@ -1282,7 +1282,7 @@ def Body( # noqa: N802
|
|||
),
|
||||
] = _Unset,
|
||||
embed: Annotated[
|
||||
bool,
|
||||
Union[bool, None],
|
||||
Doc(
|
||||
"""
|
||||
When `embed` is `True`, the parameter will be expected in a JSON body as a
|
||||
|
|
@ -1294,7 +1294,7 @@ def Body( # noqa: N802
|
|||
[FastAPI docs for Body - Multiple Parameters](https://fastapi.tiangolo.com/tutorial/body-multiple-params/#embed-a-single-body-parameter).
|
||||
"""
|
||||
),
|
||||
] = False,
|
||||
] = None,
|
||||
media_type: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
|
|
|
|||
|
|
@ -479,7 +479,7 @@ class Body(FieldInfo):
|
|||
*,
|
||||
default_factory: Union[Callable[[], Any], None] = _Unset,
|
||||
annotation: Optional[Any] = None,
|
||||
embed: bool = False,
|
||||
embed: Union[bool, None] = None,
|
||||
media_type: str = "application/json",
|
||||
alias: Optional[str] = None,
|
||||
alias_priority: Union[int, None] = _Unset,
|
||||
|
|
@ -642,7 +642,6 @@ class Form(Body):
|
|||
default=default,
|
||||
default_factory=default_factory,
|
||||
annotation=annotation,
|
||||
embed=True,
|
||||
media_type=media_type,
|
||||
alias=alias,
|
||||
alias_priority=alias_priority,
|
||||
|
|
|
|||
|
|
@ -33,8 +33,10 @@ from fastapi._compat import (
|
|||
from fastapi.datastructures import Default, DefaultPlaceholder
|
||||
from fastapi.dependencies.models import Dependant
|
||||
from fastapi.dependencies.utils import (
|
||||
_should_embed_body_fields,
|
||||
get_body_field,
|
||||
get_dependant,
|
||||
get_flat_dependant,
|
||||
get_parameterless_sub_dependant,
|
||||
get_typed_return_annotation,
|
||||
solve_dependencies,
|
||||
|
|
@ -225,6 +227,7 @@ def get_request_handler(
|
|||
response_model_exclude_defaults: bool = False,
|
||||
response_model_exclude_none: bool = False,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
embed_body_fields: bool = False,
|
||||
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||
assert dependant.call is not None, "dependant.call must be a function"
|
||||
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
|
||||
|
|
@ -291,6 +294,7 @@ def get_request_handler(
|
|||
body=body,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
async_exit_stack=async_exit_stack,
|
||||
embed_body_fields=embed_body_fields,
|
||||
)
|
||||
errors = solved_result.errors
|
||||
if not errors:
|
||||
|
|
@ -354,7 +358,9 @@ def get_request_handler(
|
|||
|
||||
|
||||
def get_websocket_app(
|
||||
dependant: Dependant, dependency_overrides_provider: Optional[Any] = None
|
||||
dependant: Dependant,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
embed_body_fields: bool = False,
|
||||
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
|
||||
async def app(websocket: WebSocket) -> None:
|
||||
async with AsyncExitStack() as async_exit_stack:
|
||||
|
|
@ -367,6 +373,7 @@ def get_websocket_app(
|
|||
dependant=dependant,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
async_exit_stack=async_exit_stack,
|
||||
embed_body_fields=embed_body_fields,
|
||||
)
|
||||
if solved_result.errors:
|
||||
raise WebSocketRequestValidationError(
|
||||
|
|
@ -399,11 +406,15 @@ class APIWebSocketRoute(routing.WebSocketRoute):
|
|||
0,
|
||||
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
||||
)
|
||||
|
||||
self._flat_dependant = get_flat_dependant(self.dependant)
|
||||
self._embed_body_fields = _should_embed_body_fields(
|
||||
self._flat_dependant.body_params
|
||||
)
|
||||
self.app = websocket_session(
|
||||
get_websocket_app(
|
||||
dependant=self.dependant,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
embed_body_fields=self._embed_body_fields,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -544,7 +555,15 @@ class APIRoute(routing.Route):
|
|||
0,
|
||||
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
||||
)
|
||||
self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
|
||||
self._flat_dependant = get_flat_dependant(self.dependant)
|
||||
self._embed_body_fields = _should_embed_body_fields(
|
||||
self._flat_dependant.body_params
|
||||
)
|
||||
self.body_field = get_body_field(
|
||||
flat_dependant=self._flat_dependant,
|
||||
name=self.unique_id,
|
||||
embed_body_fields=self._embed_body_fields,
|
||||
)
|
||||
self.app = request_response(self.get_route_handler())
|
||||
|
||||
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||
|
|
@ -561,6 +580,7 @@ class APIRoute(routing.Route):
|
|||
response_model_exclude_defaults=self.response_model_exclude_defaults,
|
||||
response_model_exclude_none=self.response_model_exclude_none,
|
||||
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||
embed_body_fields=self._embed_body_fields,
|
||||
)
|
||||
|
||||
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
from typing import List, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from fastapi import FastAPI, UploadFile
|
||||
from fastapi._compat import (
|
||||
ModelField,
|
||||
Undefined,
|
||||
_get_model_config,
|
||||
get_model_fields,
|
||||
is_bytes_sequence_annotation,
|
||||
is_scalar_field,
|
||||
is_uploadfile_sequence_annotation,
|
||||
)
|
||||
from fastapi.testclient import TestClient
|
||||
|
|
@ -91,3 +93,12 @@ def test_is_uploadfile_sequence_annotation():
|
|||
# and other types, but I'm not even sure it's a good idea to support it as a first
|
||||
# class "feature"
|
||||
assert is_uploadfile_sequence_annotation(Union[List[str], List[UploadFile]])
|
||||
|
||||
|
||||
def test_is_pv1_scalar_field():
|
||||
# For coverage
|
||||
class Model(BaseModel):
|
||||
foo: Union[str, Dict[str, Any]]
|
||||
|
||||
fields = get_model_fields(Model)
|
||||
assert not is_scalar_field(fields[0])
|
||||
|
|
|
|||
|
|
@ -0,0 +1,99 @@
|
|||
from fastapi import FastAPI, Form
|
||||
from fastapi.testclient import TestClient
|
||||
from typing_extensions import Annotated
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.post("/form/")
|
||||
def post_form(username: Annotated[str, Form()]):
|
||||
return username
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_single_form_field():
|
||||
response = client.post("/form/", data={"username": "Rick"})
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == "Rick"
|
||||
|
||||
|
||||
def test_openapi_schema():
|
||||
response = client.get("/openapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "FastAPI", "version": "0.1.0"},
|
||||
"paths": {
|
||||
"/form/": {
|
||||
"post": {
|
||||
"summary": "Post Form",
|
||||
"operationId": "post_form_form__post",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/x-www-form-urlencoded": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Body_post_form_form__post"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": True,
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {"application/json": {"schema": {}}},
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
"components": {
|
||||
"schemas": {
|
||||
"Body_post_form_form__post": {
|
||||
"properties": {"username": {"type": "string", "title": "Username"}},
|
||||
"type": "object",
|
||||
"required": ["username"],
|
||||
"title": "Body_post_form_form__post",
|
||||
},
|
||||
"HTTPValidationError": {
|
||||
"properties": {
|
||||
"detail": {
|
||||
"items": {"$ref": "#/components/schemas/ValidationError"},
|
||||
"type": "array",
|
||||
"title": "Detail",
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"title": "HTTPValidationError",
|
||||
},
|
||||
"ValidationError": {
|
||||
"properties": {
|
||||
"loc": {
|
||||
"items": {
|
||||
"anyOf": [{"type": "string"}, {"type": "integer"}]
|
||||
},
|
||||
"type": "array",
|
||||
"title": "Location",
|
||||
},
|
||||
"msg": {"type": "string", "title": "Message"},
|
||||
"type": {"type": "string", "title": "Error Type"},
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["loc", "msg", "type"],
|
||||
"title": "ValidationError",
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
Loading…
Reference in New Issue