diff --git a/fastapi/_compat/__init__.py b/fastapi/_compat/__init__.py index 3561176fe..b7287ea92 100644 --- a/fastapi/_compat/__init__.py +++ b/fastapi/_compat/__init__.py @@ -21,6 +21,7 @@ from .main import get_compat_model_name_map as get_compat_model_name_map from .main import get_definitions as get_definitions from .main import get_missing_field_error as get_missing_field_error from .main import get_schema_from_model_field as get_schema_from_model_field +from .main import ignore_invalid as ignore_invalid from .main import is_bytes_field as is_bytes_field from .main import is_bytes_sequence_field as is_bytes_sequence_field from .main import is_scalar_field as is_scalar_field diff --git a/fastapi/_compat/main.py b/fastapi/_compat/main.py index 5a36d887a..bac992d4e 100644 --- a/fastapi/_compat/main.py +++ b/fastapi/_compat/main.py @@ -28,6 +28,7 @@ if PYDANTIC_V2: from .v2 import Validator as Validator from .v2 import evaluate_forwardref as evaluate_forwardref from .v2 import get_missing_field_error as get_missing_field_error + from .v2 import ignore_invalid as ignore_invalid from .v2 import ( with_info_plain_validator_function as with_info_plain_validator_function, ) @@ -44,6 +45,7 @@ else: from .v1 import Validator as Validator from .v1 import evaluate_forwardref as evaluate_forwardref from .v1 import get_missing_field_error as get_missing_field_error + from .v1 import ignore_invalid as ignore_invalid from .v1 import ( # type: ignore[assignment] with_info_plain_validator_function as with_info_plain_validator_function, ) diff --git a/fastapi/_compat/shared.py b/fastapi/_compat/shared.py index 84b72adf4..30344b630 100644 --- a/fastapi/_compat/shared.py +++ b/fastapi/_compat/shared.py @@ -5,6 +5,7 @@ from collections import deque from dataclasses import is_dataclass from typing import ( Any, + Callable, Deque, FrozenSet, List, @@ -18,7 +19,7 @@ from typing import ( from fastapi._compat import may_v1 from fastapi.types import UnionType -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from pydantic.version import VERSION as PYDANTIC_VERSION from starlette.datastructures import UploadFile from typing_extensions import Annotated, get_args, get_origin @@ -251,3 +252,30 @@ def annotation_is_pydantic_v1(annotation: Any) -> bool: if annotation_is_pydantic_v1(sub_annotation): return True return False + + +def remove_invalid(v: Any, handler: Callable[[Any], Any]) -> Any: + try: + return handler(v) + except ValidationError as exc: + if not isinstance(v, dict): + raise exc + # remove invalid values from invalid keys and revalidate + errors = may_v1._regenerate_error_with_loc(errors=[exc.errors()], loc_prefix=()) + for err in errors: + loc = err.get("loc", ()) + if len(loc) == 1: + v.pop(loc[0], None) + elif len(loc) == 2 and isinstance(v.get(loc[0]), list): + try: + v[loc[0]][loc[1]] = None + except (ValueError, IndexError): + pass + # remove the None values from lists + for key in list(v.keys()): + if isinstance(v[key], list): + v[key] = [item for item in v[key] if item is not None] + # remove empty lists + if v[key] == []: + v.pop(key) + return handler(v) diff --git a/fastapi/_compat/v1.py b/fastapi/_compat/v1.py index 1d5c83fa7..50c929943 100644 --- a/fastapi/_compat/v1.py +++ b/fastapi/_compat/v1.py @@ -348,3 +348,7 @@ def create_body_model( def get_model_fields(model: Type[BaseModel]) -> List[ModelField]: return list(model.__fields__.values()) # type: ignore[attr-defined] + + +def ignore_invalid(field_info: FieldInfo) -> FieldInfo: + return field_info diff --git a/fastapi/_compat/v2.py b/fastapi/_compat/v2.py index de033083d..11f88a1c1 100644 --- a/fastapi/_compat/v2.py +++ b/fastapi/_compat/v2.py @@ -18,7 +18,7 @@ from typing import ( from fastapi._compat import may_v1, shared from fastapi.openapi.constants import REF_TEMPLATE from fastapi.types import IncEx, ModelNameMap -from pydantic import BaseModel, TypeAdapter, create_model +from pydantic import BaseModel, TypeAdapter, WrapValidator, create_model from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation from pydantic import ValidationError as ValidationError @@ -487,3 +487,11 @@ def get_flat_models_from_fields( def get_long_model_name(model: TypeModelOrEnum) -> str: return f"{model.__module__}__{model.__qualname__}".replace(".", "__") + + +def ignore_invalid(field_info: FieldInfo) -> FieldInfo: + new_field_info = copy(field_info) + new_field_info.metadata = getattr(field_info, "metadata", []) + [ + WrapValidator(shared.remove_invalid) + ] + return new_field_info diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 4e0ecbc52..04282c29c 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -34,6 +34,7 @@ from fastapi._compat import ( get_annotation_from_field_info, get_cached_model_fields, get_missing_field_error, + ignore_invalid, is_bytes_field, is_bytes_sequence_field, is_scalar_field, @@ -63,7 +64,7 @@ from fastapi.security.oauth2 import OAuth2, SecurityScopes from fastapi.security.open_id_connect_url import OpenIdConnect from fastapi.types import DependencyCacheKey from fastapi.utils import create_model_field, get_path_param_names -from pydantic import BaseModel, ValidationError, WrapValidator +from pydantic import BaseModel from pydantic.fields import FieldInfo from starlette.background import BackgroundTasks as StarletteBackgroundTasks from starlette.concurrency import run_in_threadpool @@ -330,33 +331,6 @@ def add_non_field_param_to_dependency( return None -def remove_invalid(v: Any, handler: Callable[[Any], Any]) -> Any: - try: - return handler(v) - except ValidationError as exc: - if not isinstance(v, dict): - raise exc - # remove invalid values from invalid keys and revalidate - errors = may_v1._regenerate_error_with_loc(errors=[exc.errors()], loc_prefix=()) - for err in errors: - loc = err.get("loc", ()) - if len(loc) == 1: - v.pop(loc[0], None) - elif len(loc) == 2 and isinstance(v.get(loc[0]), list): - try: - v[loc[0]][loc[1]] = None - except (ValueError, IndexError): - pass - # remove the None values from lists - for key in list(v.keys()): - if isinstance(v[key], list): - v[key] = [item for item in v[key] if item is not None] - # remove empty lists - if v[key] == []: - v.pop(key) - return handler(v) - - @dataclass class ParamDetails: type_annotation: Any @@ -553,9 +527,6 @@ def analyze_param( if is_scalar_sequence_field(field) or is_scalar_sequence_mapping_field( field ): - field_info.metadata = getattr(field_info, "metadata", []) + [ - WrapValidator(remove_invalid) - ] field = create_model_field( name=param_name, type_=use_annotation_from_field_info, @@ -563,7 +534,7 @@ def analyze_param( alias=alias, required=field_info.default in (RequiredParam, may_v1.RequiredParam, Undefined), - field_info=field_info, + field_info=ignore_invalid(field_info), ) return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)