mirror of https://github.com/tiangolo/fastapi.git
WrapValidator pydanticv2 only - need different solution for v1
This commit is contained in:
parent
687dd65c31
commit
1e22422ff9
|
|
@ -21,6 +21,7 @@ from .main import get_compat_model_name_map as get_compat_model_name_map
|
||||||
from .main import get_definitions as get_definitions
|
from .main import get_definitions as get_definitions
|
||||||
from .main import get_missing_field_error as get_missing_field_error
|
from .main import get_missing_field_error as get_missing_field_error
|
||||||
from .main import get_schema_from_model_field as get_schema_from_model_field
|
from .main import get_schema_from_model_field as get_schema_from_model_field
|
||||||
|
from .main import ignore_invalid as ignore_invalid
|
||||||
from .main import is_bytes_field as is_bytes_field
|
from .main import is_bytes_field as is_bytes_field
|
||||||
from .main import is_bytes_sequence_field as is_bytes_sequence_field
|
from .main import is_bytes_sequence_field as is_bytes_sequence_field
|
||||||
from .main import is_scalar_field as is_scalar_field
|
from .main import is_scalar_field as is_scalar_field
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ if PYDANTIC_V2:
|
||||||
from .v2 import Validator as Validator
|
from .v2 import Validator as Validator
|
||||||
from .v2 import evaluate_forwardref as evaluate_forwardref
|
from .v2 import evaluate_forwardref as evaluate_forwardref
|
||||||
from .v2 import get_missing_field_error as get_missing_field_error
|
from .v2 import get_missing_field_error as get_missing_field_error
|
||||||
|
from .v2 import ignore_invalid as ignore_invalid
|
||||||
from .v2 import (
|
from .v2 import (
|
||||||
with_info_plain_validator_function as with_info_plain_validator_function,
|
with_info_plain_validator_function as with_info_plain_validator_function,
|
||||||
)
|
)
|
||||||
|
|
@ -44,6 +45,7 @@ else:
|
||||||
from .v1 import Validator as Validator
|
from .v1 import Validator as Validator
|
||||||
from .v1 import evaluate_forwardref as evaluate_forwardref
|
from .v1 import evaluate_forwardref as evaluate_forwardref
|
||||||
from .v1 import get_missing_field_error as get_missing_field_error
|
from .v1 import get_missing_field_error as get_missing_field_error
|
||||||
|
from .v1 import ignore_invalid as ignore_invalid
|
||||||
from .v1 import ( # type: ignore[assignment]
|
from .v1 import ( # type: ignore[assignment]
|
||||||
with_info_plain_validator_function as with_info_plain_validator_function,
|
with_info_plain_validator_function as with_info_plain_validator_function,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from collections import deque
|
||||||
from dataclasses import is_dataclass
|
from dataclasses import is_dataclass
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
Callable,
|
||||||
Deque,
|
Deque,
|
||||||
FrozenSet,
|
FrozenSet,
|
||||||
List,
|
List,
|
||||||
|
|
@ -18,7 +19,7 @@ from typing import (
|
||||||
|
|
||||||
from fastapi._compat import may_v1
|
from fastapi._compat import may_v1
|
||||||
from fastapi.types import UnionType
|
from fastapi.types import UnionType
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, ValidationError
|
||||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||||
from starlette.datastructures import UploadFile
|
from starlette.datastructures import UploadFile
|
||||||
from typing_extensions import Annotated, get_args, get_origin
|
from typing_extensions import Annotated, get_args, get_origin
|
||||||
|
|
@ -251,3 +252,30 @@ def annotation_is_pydantic_v1(annotation: Any) -> bool:
|
||||||
if annotation_is_pydantic_v1(sub_annotation):
|
if annotation_is_pydantic_v1(sub_annotation):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def remove_invalid(v: Any, handler: Callable[[Any], Any]) -> Any:
|
||||||
|
try:
|
||||||
|
return handler(v)
|
||||||
|
except ValidationError as exc:
|
||||||
|
if not isinstance(v, dict):
|
||||||
|
raise exc
|
||||||
|
# remove invalid values from invalid keys and revalidate
|
||||||
|
errors = may_v1._regenerate_error_with_loc(errors=[exc.errors()], loc_prefix=())
|
||||||
|
for err in errors:
|
||||||
|
loc = err.get("loc", ())
|
||||||
|
if len(loc) == 1:
|
||||||
|
v.pop(loc[0], None)
|
||||||
|
elif len(loc) == 2 and isinstance(v.get(loc[0]), list):
|
||||||
|
try:
|
||||||
|
v[loc[0]][loc[1]] = None
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
pass
|
||||||
|
# remove the None values from lists
|
||||||
|
for key in list(v.keys()):
|
||||||
|
if isinstance(v[key], list):
|
||||||
|
v[key] = [item for item in v[key] if item is not None]
|
||||||
|
# remove empty lists
|
||||||
|
if v[key] == []:
|
||||||
|
v.pop(key)
|
||||||
|
return handler(v)
|
||||||
|
|
|
||||||
|
|
@ -348,3 +348,7 @@ def create_body_model(
|
||||||
|
|
||||||
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
|
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
|
||||||
return list(model.__fields__.values()) # type: ignore[attr-defined]
|
return list(model.__fields__.values()) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
def ignore_invalid(field_info: FieldInfo) -> FieldInfo:
|
||||||
|
return field_info
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ from typing import (
|
||||||
from fastapi._compat import may_v1, shared
|
from fastapi._compat import may_v1, shared
|
||||||
from fastapi.openapi.constants import REF_TEMPLATE
|
from fastapi.openapi.constants import REF_TEMPLATE
|
||||||
from fastapi.types import IncEx, ModelNameMap
|
from fastapi.types import IncEx, ModelNameMap
|
||||||
from pydantic import BaseModel, TypeAdapter, create_model
|
from pydantic import BaseModel, TypeAdapter, WrapValidator, create_model
|
||||||
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
|
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
|
||||||
from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation
|
from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation
|
||||||
from pydantic import ValidationError as ValidationError
|
from pydantic import ValidationError as ValidationError
|
||||||
|
|
@ -487,3 +487,11 @@ def get_flat_models_from_fields(
|
||||||
|
|
||||||
def get_long_model_name(model: TypeModelOrEnum) -> str:
|
def get_long_model_name(model: TypeModelOrEnum) -> str:
|
||||||
return f"{model.__module__}__{model.__qualname__}".replace(".", "__")
|
return f"{model.__module__}__{model.__qualname__}".replace(".", "__")
|
||||||
|
|
||||||
|
|
||||||
|
def ignore_invalid(field_info: FieldInfo) -> FieldInfo:
|
||||||
|
new_field_info = copy(field_info)
|
||||||
|
new_field_info.metadata = getattr(field_info, "metadata", []) + [
|
||||||
|
WrapValidator(shared.remove_invalid)
|
||||||
|
]
|
||||||
|
return new_field_info
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ from fastapi._compat import (
|
||||||
get_annotation_from_field_info,
|
get_annotation_from_field_info,
|
||||||
get_cached_model_fields,
|
get_cached_model_fields,
|
||||||
get_missing_field_error,
|
get_missing_field_error,
|
||||||
|
ignore_invalid,
|
||||||
is_bytes_field,
|
is_bytes_field,
|
||||||
is_bytes_sequence_field,
|
is_bytes_sequence_field,
|
||||||
is_scalar_field,
|
is_scalar_field,
|
||||||
|
|
@ -63,7 +64,7 @@ from fastapi.security.oauth2 import OAuth2, SecurityScopes
|
||||||
from fastapi.security.open_id_connect_url import OpenIdConnect
|
from fastapi.security.open_id_connect_url import OpenIdConnect
|
||||||
from fastapi.types import DependencyCacheKey
|
from fastapi.types import DependencyCacheKey
|
||||||
from fastapi.utils import create_model_field, get_path_param_names
|
from fastapi.utils import create_model_field, get_path_param_names
|
||||||
from pydantic import BaseModel, ValidationError, WrapValidator
|
from pydantic import BaseModel
|
||||||
from pydantic.fields import FieldInfo
|
from pydantic.fields import FieldInfo
|
||||||
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
|
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
|
||||||
from starlette.concurrency import run_in_threadpool
|
from starlette.concurrency import run_in_threadpool
|
||||||
|
|
@ -330,33 +331,6 @@ def add_non_field_param_to_dependency(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def remove_invalid(v: Any, handler: Callable[[Any], Any]) -> Any:
|
|
||||||
try:
|
|
||||||
return handler(v)
|
|
||||||
except ValidationError as exc:
|
|
||||||
if not isinstance(v, dict):
|
|
||||||
raise exc
|
|
||||||
# remove invalid values from invalid keys and revalidate
|
|
||||||
errors = may_v1._regenerate_error_with_loc(errors=[exc.errors()], loc_prefix=())
|
|
||||||
for err in errors:
|
|
||||||
loc = err.get("loc", ())
|
|
||||||
if len(loc) == 1:
|
|
||||||
v.pop(loc[0], None)
|
|
||||||
elif len(loc) == 2 and isinstance(v.get(loc[0]), list):
|
|
||||||
try:
|
|
||||||
v[loc[0]][loc[1]] = None
|
|
||||||
except (ValueError, IndexError):
|
|
||||||
pass
|
|
||||||
# remove the None values from lists
|
|
||||||
for key in list(v.keys()):
|
|
||||||
if isinstance(v[key], list):
|
|
||||||
v[key] = [item for item in v[key] if item is not None]
|
|
||||||
# remove empty lists
|
|
||||||
if v[key] == []:
|
|
||||||
v.pop(key)
|
|
||||||
return handler(v)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ParamDetails:
|
class ParamDetails:
|
||||||
type_annotation: Any
|
type_annotation: Any
|
||||||
|
|
@ -553,9 +527,6 @@ def analyze_param(
|
||||||
if is_scalar_sequence_field(field) or is_scalar_sequence_mapping_field(
|
if is_scalar_sequence_field(field) or is_scalar_sequence_mapping_field(
|
||||||
field
|
field
|
||||||
):
|
):
|
||||||
field_info.metadata = getattr(field_info, "metadata", []) + [
|
|
||||||
WrapValidator(remove_invalid)
|
|
||||||
]
|
|
||||||
field = create_model_field(
|
field = create_model_field(
|
||||||
name=param_name,
|
name=param_name,
|
||||||
type_=use_annotation_from_field_info,
|
type_=use_annotation_from_field_info,
|
||||||
|
|
@ -563,7 +534,7 @@ def analyze_param(
|
||||||
alias=alias,
|
alias=alias,
|
||||||
required=field_info.default
|
required=field_info.default
|
||||||
in (RequiredParam, may_v1.RequiredParam, Undefined),
|
in (RequiredParam, may_v1.RequiredParam, Undefined),
|
||||||
field_info=field_info,
|
field_info=ignore_invalid(field_info),
|
||||||
)
|
)
|
||||||
|
|
||||||
return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
|
return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue