mirror of https://github.com/tiangolo/fastapi.git
♻️ Refactor internals for test coverage and performance (#9691)
* ♻️ Tweak import of Annotated from typing_extensions, they are installed anyway * ♻️ Refactor _compat to define functions for Pydantic v1 or v2 once instead of checking inside * ✅ Add test for UploadFile for Pydantic v2 * ♻️ Refactor types and remove logic for impossible cases * ✅ Add missing tests from test refactor for path params * ✅ Add tests for new decimal encoder * 💡 Add TODO comment for decimals in encoders * 🔥 Remove unneeded dummy function * 🔥 Remove section of code in field_annotation_is_scalar covered by sub-call to field_annotation_is_complex * ♻️ Refactor and tweak variables and types in _compat * ✅ Add tests for corner cases and compat with Pydantic v1 and v2 * ♻️ Refactor type annotations
This commit is contained in:
parent
c58e2b2d1e
commit
cfb00b2119
|
|
@ -113,22 +113,16 @@ if PYDANTIC_V2:
|
|||
value: Any,
|
||||
values: Dict[str, Any] = {}, # noqa: B006
|
||||
*,
|
||||
loc: Union[Tuple[Union[int, str], ...], str] = "",
|
||||
) -> Tuple[Any, Union[List[ValidationError], None]]:
|
||||
loc: Tuple[Union[int, str], ...] = (),
|
||||
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
|
||||
try:
|
||||
return (
|
||||
self._type_adapter.validate_python(value, from_attributes=True),
|
||||
None,
|
||||
)
|
||||
except ValidationError as exc:
|
||||
if isinstance(loc, tuple):
|
||||
use_loc = loc
|
||||
elif loc == "":
|
||||
use_loc = ()
|
||||
else:
|
||||
use_loc = (loc,)
|
||||
return None, _regenerate_error_with_loc(
|
||||
errors=exc.errors(), loc_prefix=use_loc
|
||||
errors=exc.errors(), loc_prefix=loc
|
||||
)
|
||||
|
||||
def serialize(
|
||||
|
|
@ -161,13 +155,6 @@ if PYDANTIC_V2:
|
|||
# ModelField to its JSON Schema.
|
||||
return id(self)
|
||||
|
||||
def get_model_definitions(
|
||||
*,
|
||||
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
|
||||
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
|
||||
) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def get_annotation_from_field_info(
|
||||
annotation: Any, field_info: FieldInfo, field_name: str
|
||||
) -> Any:
|
||||
|
|
@ -176,6 +163,91 @@ if PYDANTIC_V2:
|
|||
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
|
||||
return errors # type: ignore[return-value]
|
||||
|
||||
def _model_rebuild(model: Type[BaseModel]) -> None:
|
||||
model.model_rebuild()
|
||||
|
||||
def _model_dump(
|
||||
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
|
||||
) -> Any:
|
||||
return model.model_dump(mode=mode, **kwargs)
|
||||
|
||||
def _get_model_config(model: BaseModel) -> Any:
|
||||
return model.model_config
|
||||
|
||||
def get_schema_from_model_field(
|
||||
*,
|
||||
field: ModelField,
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
) -> Dict[str, Any]:
|
||||
# This expects that GenerateJsonSchema was already used to generate the definitions
|
||||
json_schema = schema_generator.generate_inner(field._type_adapter.core_schema)
|
||||
if "$ref" not in json_schema:
|
||||
# TODO remove when deprecating Pydantic v1
|
||||
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
|
||||
json_schema[
|
||||
"title"
|
||||
] = field.field_info.title or field.alias.title().replace("_", " ")
|
||||
return json_schema
|
||||
|
||||
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
|
||||
return {}
|
||||
|
||||
def get_definitions(
|
||||
*,
|
||||
fields: List[ModelField],
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
inputs = [
|
||||
(field, "validation", field._type_adapter.core_schema) for field in fields
|
||||
]
|
||||
_, definitions = schema_generator.generate_definitions(inputs=inputs) # type: ignore[arg-type]
|
||||
return definitions # type: ignore[return-value]
|
||||
|
||||
def is_scalar_field(field: ModelField) -> bool:
|
||||
from fastapi import params
|
||||
|
||||
return field_annotation_is_scalar(
|
||||
field.field_info.annotation
|
||||
) and not isinstance(field.field_info, params.Body)
|
||||
|
||||
def is_sequence_field(field: ModelField) -> bool:
|
||||
return field_annotation_is_sequence(field.field_info.annotation)
|
||||
|
||||
def is_scalar_sequence_field(field: ModelField) -> bool:
|
||||
return field_annotation_is_scalar_sequence(field.field_info.annotation)
|
||||
|
||||
def is_bytes_field(field: ModelField) -> bool:
|
||||
return is_bytes_or_nonable_bytes_annotation(field.type_)
|
||||
|
||||
def is_bytes_sequence_field(field: ModelField) -> bool:
|
||||
return is_bytes_sequence_annotation(field.type_)
|
||||
|
||||
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
|
||||
return type(field_info).from_annotation(annotation)
|
||||
|
||||
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
|
||||
origin_type = (
|
||||
get_origin(field.field_info.annotation) or field.field_info.annotation
|
||||
)
|
||||
assert issubclass(origin_type, sequence_types) # type: ignore[arg-type]
|
||||
return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]
|
||||
|
||||
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
|
||||
error = ValidationError.from_exception_data(
|
||||
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
|
||||
).errors()[0]
|
||||
error["input"] = None
|
||||
return error # type: ignore[return-value]
|
||||
|
||||
def create_body_model(
|
||||
*, fields: Sequence[ModelField], model_name: str
|
||||
) -> Type[BaseModel]:
|
||||
field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
|
||||
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
|
||||
return BodyModel
|
||||
|
||||
else:
|
||||
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
|
||||
from pydantic import AnyUrl as Url # noqa: F401
|
||||
|
|
@ -333,10 +405,79 @@ else:
|
|||
use_errors.append(error)
|
||||
return use_errors
|
||||
|
||||
def _model_rebuild(model: Type[BaseModel]) -> None:
|
||||
model.update_forward_refs()
|
||||
|
||||
def _model_dump(
|
||||
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
|
||||
) -> Any:
|
||||
return model.dict(**kwargs)
|
||||
|
||||
def _get_model_config(model: BaseModel) -> Any:
|
||||
return model.__config__ # type: ignore[attr-defined]
|
||||
|
||||
def get_schema_from_model_field(
|
||||
*,
|
||||
field: ModelField,
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
) -> Dict[str, Any]:
|
||||
# This expects that GenerateJsonSchema was already used to generate the definitions
|
||||
return field_schema( # type: ignore[no-any-return]
|
||||
field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
||||
)[0]
|
||||
|
||||
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
|
||||
models = get_flat_models_from_fields(fields, known_models=set())
|
||||
return get_model_name_map(models) # type: ignore[no-any-return]
|
||||
|
||||
def get_definitions(
|
||||
*,
|
||||
fields: List[ModelField],
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
models = get_flat_models_from_fields(fields, known_models=set())
|
||||
return get_model_definitions(flat_models=models, model_name_map=model_name_map)
|
||||
|
||||
def is_scalar_field(field: ModelField) -> bool:
|
||||
return is_pv1_scalar_field(field)
|
||||
|
||||
def is_sequence_field(field: ModelField) -> bool:
|
||||
return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined]
|
||||
|
||||
def is_scalar_sequence_field(field: ModelField) -> bool:
|
||||
return is_pv1_scalar_sequence_field(field)
|
||||
|
||||
def is_bytes_field(field: ModelField) -> bool:
|
||||
return lenient_issubclass(field.type_, bytes)
|
||||
|
||||
def is_bytes_sequence_field(field: ModelField) -> bool:
|
||||
return field.shape in sequence_shapes and lenient_issubclass(field.type_, bytes) # type: ignore[attr-defined]
|
||||
|
||||
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
|
||||
return copy(field_info)
|
||||
|
||||
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
|
||||
return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return,attr-defined]
|
||||
|
||||
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
|
||||
missing_field_error = ErrorWrapper(MissingError(), loc=loc) # type: ignore[call-arg]
|
||||
new_error = ValidationError([missing_field_error], RequestErrorModel)
|
||||
return new_error.errors()[0] # type: ignore[return-value]
|
||||
|
||||
def create_body_model(
|
||||
*, fields: Sequence[ModelField], model_name: str
|
||||
) -> Type[BaseModel]:
|
||||
BodyModel = create_model(model_name)
|
||||
for f in fields:
|
||||
BodyModel.__fields__[f.name] = f # type: ignore[index]
|
||||
return BodyModel
|
||||
|
||||
|
||||
def _regenerate_error_with_loc(
|
||||
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
|
||||
) -> List[ValidationError]:
|
||||
) -> List[Dict[str, Any]]:
|
||||
updated_loc_errors: List[Any] = [
|
||||
{**err, "loc": loc_prefix + err.get("loc", ())}
|
||||
for err in _normalize_errors(errors)
|
||||
|
|
@ -345,76 +486,6 @@ def _regenerate_error_with_loc(
|
|||
return updated_loc_errors
|
||||
|
||||
|
||||
def _model_rebuild(model: Type[BaseModel]) -> None:
|
||||
if PYDANTIC_V2:
|
||||
model.model_rebuild()
|
||||
else:
|
||||
model.update_forward_refs()
|
||||
|
||||
|
||||
def _model_dump(
|
||||
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
|
||||
) -> Any:
|
||||
if PYDANTIC_V2:
|
||||
return model.model_dump(mode=mode, **kwargs)
|
||||
else:
|
||||
return model.dict(**kwargs)
|
||||
|
||||
|
||||
def _get_model_config(model: BaseModel) -> Any:
|
||||
if PYDANTIC_V2:
|
||||
return model.model_config
|
||||
else:
|
||||
return model.__config__ # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def get_schema_from_model_field(
|
||||
*,
|
||||
field: ModelField,
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
) -> Dict[str, Any]:
|
||||
# This expects that GenerateJsonSchema was already used to generate the definitions
|
||||
if PYDANTIC_V2:
|
||||
json_schema = schema_generator.generate_inner(field._type_adapter.core_schema)
|
||||
if "$ref" not in json_schema:
|
||||
# TODO remove when deprecating Pydantic v1
|
||||
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
|
||||
json_schema[
|
||||
"title"
|
||||
] = field.field_info.title or field.alias.title().replace("_", " ")
|
||||
return json_schema
|
||||
else:
|
||||
return field_schema( # type: ignore[no-any-return]
|
||||
field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
||||
)[0]
|
||||
|
||||
|
||||
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
|
||||
if PYDANTIC_V2:
|
||||
return {}
|
||||
else:
|
||||
models = get_flat_models_from_fields(fields, known_models=set())
|
||||
return get_model_name_map(models) # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def get_definitions(
|
||||
*,
|
||||
fields: List[ModelField],
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
if PYDANTIC_V2:
|
||||
inputs = [
|
||||
(field, "validation", field._type_adapter.core_schema) for field in fields
|
||||
]
|
||||
_, definitions = schema_generator.generate_definitions(inputs=inputs) # type: ignore[arg-type]
|
||||
return definitions # type: ignore[return-value]
|
||||
else:
|
||||
models = get_flat_models_from_fields(fields, known_models=set())
|
||||
return get_model_definitions(flat_models=models, model_name_map=model_name_map)
|
||||
|
||||
|
||||
def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
|
||||
if lenient_issubclass(annotation, (str, bytes)):
|
||||
return False
|
||||
|
|
@ -453,10 +524,6 @@ def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
|
|||
|
||||
|
||||
def field_annotation_is_scalar(annotation: Any) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
return all(field_annotation_is_scalar(arg) for arg in get_args(annotation))
|
||||
|
||||
# handle Ellipsis here to make tuple[int, ...] work nicely
|
||||
return annotation is Ellipsis or not field_annotation_is_complex(annotation)
|
||||
|
||||
|
|
@ -478,31 +545,6 @@ def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> b
|
|||
)
|
||||
|
||||
|
||||
def is_scalar_field(field: ModelField) -> bool:
|
||||
from fastapi import params
|
||||
|
||||
if PYDANTIC_V2:
|
||||
return field_annotation_is_scalar(
|
||||
field.field_info.annotation
|
||||
) and not isinstance(field.field_info, params.Body)
|
||||
else:
|
||||
return is_pv1_scalar_field(field)
|
||||
|
||||
|
||||
def is_sequence_field(field: ModelField) -> bool:
|
||||
if PYDANTIC_V2:
|
||||
return field_annotation_is_sequence(field.field_info.annotation)
|
||||
else:
|
||||
return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def is_scalar_sequence_field(field: ModelField) -> bool:
|
||||
if PYDANTIC_V2:
|
||||
return field_annotation_is_scalar_sequence(field.field_info.annotation)
|
||||
else:
|
||||
return is_pv1_scalar_sequence_field(field)
|
||||
|
||||
|
||||
def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
|
||||
if lenient_issubclass(annotation, bytes):
|
||||
return True
|
||||
|
|
@ -525,90 +567,31 @@ def is_uploadfile_or_nonable_uploadfile_annotation(annotation: Any) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def is_bytes_sequence_annotation(annotation: Union[Type[Any], None]) -> bool:
|
||||
def is_bytes_sequence_annotation(annotation: Any) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
at_least_one_bytes_sequence = False
|
||||
at_least_one = False
|
||||
for arg in get_args(annotation):
|
||||
if is_bytes_sequence_annotation(arg):
|
||||
at_least_one_bytes_sequence = True
|
||||
at_least_one = True
|
||||
continue
|
||||
return at_least_one_bytes_sequence
|
||||
return at_least_one
|
||||
return field_annotation_is_sequence(annotation) and all(
|
||||
is_bytes_or_nonable_bytes_annotation(sub_annotation)
|
||||
for sub_annotation in get_args(annotation)
|
||||
)
|
||||
|
||||
|
||||
def is_uploadfile_sequence_annotation(annotation: Union[Type[Any], None]) -> bool:
|
||||
def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
at_least_one_bytes_sequence = False
|
||||
at_least_one = False
|
||||
for arg in get_args(annotation):
|
||||
if is_uploadfile_sequence_annotation(arg):
|
||||
at_least_one_bytes_sequence = True
|
||||
at_least_one = True
|
||||
continue
|
||||
return at_least_one_bytes_sequence
|
||||
return at_least_one
|
||||
return field_annotation_is_sequence(annotation) and all(
|
||||
is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation)
|
||||
for sub_annotation in get_args(annotation)
|
||||
)
|
||||
|
||||
|
||||
def is_bytes_field(field: ModelField) -> bool:
|
||||
if PYDANTIC_V2:
|
||||
return is_bytes_or_nonable_bytes_annotation(field.type_)
|
||||
else:
|
||||
return lenient_issubclass(field.type_, bytes)
|
||||
|
||||
|
||||
def is_bytes_sequence_field(field: ModelField) -> bool:
|
||||
if PYDANTIC_V2:
|
||||
return is_bytes_sequence_annotation(field.type_)
|
||||
else:
|
||||
return field.shape in sequence_shapes and lenient_issubclass(field.type_, bytes) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
|
||||
if PYDANTIC_V2:
|
||||
return type(field_info).from_annotation(annotation)
|
||||
else:
|
||||
return copy(field_info)
|
||||
|
||||
|
||||
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
|
||||
if PYDANTIC_V2:
|
||||
origin_type = (
|
||||
get_origin(field.field_info.annotation) or field.field_info.annotation
|
||||
)
|
||||
assert issubclass(origin_type, sequence_types) # type: ignore[arg-type]
|
||||
return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]
|
||||
else:
|
||||
return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return,attr-defined]
|
||||
|
||||
|
||||
def get_missing_field_error(loc: Tuple[str, ...]) -> ValidationError:
|
||||
if PYDANTIC_V2:
|
||||
error = ValidationError.from_exception_data(
|
||||
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
|
||||
).errors()[0]
|
||||
error["input"] = None
|
||||
return error # type: ignore[return-value]
|
||||
else:
|
||||
missing_field_error = ErrorWrapper(MissingError(), loc=loc) # type: ignore[call-arg]
|
||||
new_error = ValidationError([missing_field_error], RequestErrorModel)
|
||||
return new_error.errors()[0] # type: ignore[return-value]
|
||||
|
||||
|
||||
def create_body_model(
|
||||
*, fields: Sequence[ModelField], model_name: str
|
||||
) -> Type[BaseModel]:
|
||||
if PYDANTIC_V2:
|
||||
field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
|
||||
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
|
||||
return BodyModel
|
||||
else:
|
||||
BodyModel = create_model(model_name)
|
||||
for f in fields:
|
||||
BodyModel.__fields__[f.name] = f # type: ignore[index]
|
||||
return BodyModel
|
||||
|
|
|
|||
|
|
@ -25,7 +25,6 @@ from fastapi._compat import (
|
|||
ModelField,
|
||||
Required,
|
||||
Undefined,
|
||||
ValidationError,
|
||||
_regenerate_error_with_loc,
|
||||
copy_field_info,
|
||||
create_body_model,
|
||||
|
|
@ -659,13 +658,7 @@ def request_params_to_args(
|
|||
values[field.name] = deepcopy(field.default)
|
||||
continue
|
||||
v_, errors_ = field.validate(value, values, loc=loc)
|
||||
if isinstance(errors_, ValidationError):
|
||||
new_errors = _regenerate_error_with_loc(
|
||||
errors=errors_.errors(), loc_prefix=loc
|
||||
)
|
||||
new_error = ValidationError(title=errors_.title, errors=new_errors)
|
||||
errors.append(new_error)
|
||||
elif isinstance(errors_, ErrorWrapper):
|
||||
if isinstance(errors_, ErrorWrapper):
|
||||
errors.append(errors_)
|
||||
elif isinstance(errors_, list):
|
||||
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
|
||||
|
|
@ -678,9 +671,9 @@ def request_params_to_args(
|
|||
async def request_body_to_args(
|
||||
required_params: List[ModelField],
|
||||
received_body: Optional[Union[Dict[str, Any], FormData]],
|
||||
) -> Tuple[Dict[str, Any], List[ValidationError]]:
|
||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
values = {}
|
||||
errors: List[ValidationError] = []
|
||||
errors: List[Dict[str, Any]] = []
|
||||
if required_params:
|
||||
field = required_params[0]
|
||||
field_info = field.field_info
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ def isoformat(o: Union[datetime.date, datetime.time]) -> str:
|
|||
|
||||
|
||||
# Taken from Pydantic v1 as is
|
||||
# TODO: pv2 should this return strings instead?
|
||||
def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
|
||||
"""
|
||||
Encodes a Decimal as int of there's no exponent, otherwise float
|
||||
|
|
|
|||
|
|
@ -1,11 +1,5 @@
|
|||
import sys
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
from typing_extensions import Annotated
|
||||
else:
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.openapi.models import OAuth2 as OAuth2Model
|
||||
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
|
||||
|
|
@ -15,6 +9,9 @@ from fastapi.security.utils import get_authorization_scheme_param
|
|||
from starlette.requests import Request
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
|
||||
|
||||
# TODO: import from typing when deprecating Python 3.9
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class OAuth2PasswordRequestForm:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,93 @@
|
|||
from typing import List, Union
|
||||
|
||||
from fastapi import FastAPI, UploadFile
|
||||
from fastapi._compat import (
|
||||
ModelField,
|
||||
Undefined,
|
||||
_get_model_config,
|
||||
is_bytes_sequence_annotation,
|
||||
is_uploadfile_sequence_annotation,
|
||||
)
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import BaseConfig, BaseModel, ConfigDict
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
from .utils import needs_pydanticv1, needs_pydanticv2
|
||||
|
||||
|
||||
@needs_pydanticv2
|
||||
def test_model_field_default_required():
|
||||
# For coverage
|
||||
field_info = FieldInfo(annotation=str)
|
||||
field = ModelField(name="foo", field_info=field_info)
|
||||
assert field.default is Undefined
|
||||
|
||||
|
||||
@needs_pydanticv1
|
||||
def test_upload_file_dummy_general_plain_validator_function():
|
||||
# For coverage
|
||||
assert UploadFile.__get_pydantic_core_schema__(str, lambda x: None) == {}
|
||||
|
||||
|
||||
@needs_pydanticv1
|
||||
def test_union_scalar_list():
|
||||
# For coverage
|
||||
# TODO: there might not be a current valid code path that uses this, it would
|
||||
# potentially enable query parameters defined as both a scalar and a list
|
||||
# but that would require more refactors, also not sure it's really useful
|
||||
from fastapi._compat import is_pv1_scalar_field
|
||||
|
||||
field_info = FieldInfo()
|
||||
field = ModelField(
|
||||
name="foo",
|
||||
field_info=field_info,
|
||||
type_=Union[str, List[int]],
|
||||
class_validators={},
|
||||
model_config=BaseConfig,
|
||||
)
|
||||
assert not is_pv1_scalar_field(field)
|
||||
|
||||
|
||||
@needs_pydanticv2
|
||||
def test_get_model_config():
|
||||
# For coverage in Pydantic v2
|
||||
class Foo(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
foo = Foo()
|
||||
config = _get_model_config(foo)
|
||||
assert config == {"from_attributes": True}
|
||||
|
||||
|
||||
def test_complex():
|
||||
app = FastAPI()
|
||||
|
||||
@app.post("/")
|
||||
def foo(foo: Union[str, List[int]]):
|
||||
return foo
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/", json="bar")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == "bar"
|
||||
|
||||
response2 = client.post("/", json=[1, 2])
|
||||
assert response2.status_code == 200, response2.text
|
||||
assert response2.json() == [1, 2]
|
||||
|
||||
|
||||
def test_is_bytes_sequence_annotation_union():
|
||||
# For coverage
|
||||
# TODO: in theory this would allow declaring types that could be lists of bytes
|
||||
# to be read from files and other types, but I'm not even sure it's a good idea
|
||||
# to support it as a first class "feature"
|
||||
assert is_bytes_sequence_annotation(Union[List[str], List[bytes]])
|
||||
|
||||
|
||||
def test_is_uploadfile_sequence_annotation():
|
||||
# For coverage
|
||||
# TODO: in theory this would allow declaring types that could be lists of UploadFile
|
||||
# and other types, but I'm not even sure it's a good idea to support it as a first
|
||||
# class "feature"
|
||||
assert is_uploadfile_sequence_annotation(Union[List[str], List[UploadFile]])
|
||||
|
|
@ -7,11 +7,17 @@ from fastapi.datastructures import Default
|
|||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# TODO: remove when deprecating Pydantic v1
|
||||
def test_upload_file_invalid():
|
||||
with pytest.raises(ValueError):
|
||||
UploadFile.validate("not a Starlette UploadFile")
|
||||
|
||||
|
||||
def test_upload_file_invalid_pydantic_v2():
|
||||
with pytest.raises(ValueError):
|
||||
UploadFile._validate("not a Starlette UploadFile", {})
|
||||
|
||||
|
||||
def test_default_placeholder_equals():
|
||||
placeholder_1 = Default("a")
|
||||
placeholder_2 = Default("a")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from pathlib import PurePath, PurePosixPath, PureWindowsPath
|
||||
from typing import Optional
|
||||
|
|
@ -286,3 +287,15 @@ def test_encode_root():
|
|||
|
||||
model = ModelWithRoot(__root__="Foo")
|
||||
assert jsonable_encoder(model) == "Foo"
|
||||
|
||||
|
||||
@needs_pydanticv2
|
||||
def test_decimal_encoder_float():
|
||||
data = {"value": Decimal(1.23)}
|
||||
assert jsonable_encoder(data) == {"value": 1.23}
|
||||
|
||||
|
||||
@needs_pydanticv2
|
||||
def test_decimal_encoder_int():
|
||||
data = {"value": Decimal(2)}
|
||||
assert jsonable_encoder(data) == {"value": 2}
|
||||
|
|
|
|||
|
|
@ -409,6 +409,73 @@ def test_path_param_maxlength_foobar():
|
|||
)
|
||||
|
||||
|
||||
def test_path_param_min_maxlength_foo():
|
||||
response = client.get("/path/param-min_maxlength/foo")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == "foo"
|
||||
|
||||
|
||||
def test_path_param_min_maxlength_foobar():
|
||||
response = client.get("/path/param-min_maxlength/foobar")
|
||||
assert response.status_code == 422
|
||||
assert response.json() == IsDict(
|
||||
{
|
||||
"detail": [
|
||||
{
|
||||
"type": "string_too_long",
|
||||
"loc": ["path", "item_id"],
|
||||
"msg": "String should have at most 3 characters",
|
||||
"input": "foobar",
|
||||
"ctx": {"max_length": 3},
|
||||
"url": match_pydantic_error_url("string_too_long"),
|
||||
}
|
||||
]
|
||||
}
|
||||
) | IsDict(
|
||||
# TODO: remove when deprecating Pydantic v1
|
||||
{
|
||||
"detail": [
|
||||
{
|
||||
"loc": ["path", "item_id"],
|
||||
"msg": "ensure this value has at most 3 characters",
|
||||
"type": "value_error.any_str.max_length",
|
||||
"ctx": {"limit_value": 3},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_path_param_min_maxlength_f():
|
||||
response = client.get("/path/param-min_maxlength/f")
|
||||
assert response.status_code == 422
|
||||
assert response.json() == IsDict(
|
||||
{
|
||||
"detail": [
|
||||
{
|
||||
"type": "string_too_short",
|
||||
"loc": ["path", "item_id"],
|
||||
"msg": "String should have at least 2 characters",
|
||||
"input": "f",
|
||||
"ctx": {"min_length": 2},
|
||||
"url": match_pydantic_error_url("string_too_short"),
|
||||
}
|
||||
]
|
||||
}
|
||||
) | IsDict(
|
||||
{
|
||||
"detail": [
|
||||
{
|
||||
"loc": ["path", "item_id"],
|
||||
"msg": "ensure this value has at least 2 characters",
|
||||
"type": "value_error.any_str.min_length",
|
||||
"ctx": {"limit_value": 2},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_path_param_gt_42():
|
||||
response = client.get("/path/param-gt/42")
|
||||
assert response.status_code == 200
|
||||
|
|
|
|||
Loading…
Reference in New Issue