WrapValidator pydanticv2 only - need different solution for v1

This commit is contained in:
JONEMI21 2025-11-07 14:47:08 +00:00
parent 687dd65c31
commit 1e22422ff9
6 changed files with 48 additions and 34 deletions

View File

@ -21,6 +21,7 @@ from .main import get_compat_model_name_map as get_compat_model_name_map
from .main import get_definitions as get_definitions from .main import get_definitions as get_definitions
from .main import get_missing_field_error as get_missing_field_error from .main import get_missing_field_error as get_missing_field_error
from .main import get_schema_from_model_field as get_schema_from_model_field from .main import get_schema_from_model_field as get_schema_from_model_field
from .main import ignore_invalid as ignore_invalid
from .main import is_bytes_field as is_bytes_field from .main import is_bytes_field as is_bytes_field
from .main import is_bytes_sequence_field as is_bytes_sequence_field from .main import is_bytes_sequence_field as is_bytes_sequence_field
from .main import is_scalar_field as is_scalar_field from .main import is_scalar_field as is_scalar_field

View File

@ -28,6 +28,7 @@ if PYDANTIC_V2:
from .v2 import Validator as Validator from .v2 import Validator as Validator
from .v2 import evaluate_forwardref as evaluate_forwardref from .v2 import evaluate_forwardref as evaluate_forwardref
from .v2 import get_missing_field_error as get_missing_field_error from .v2 import get_missing_field_error as get_missing_field_error
from .v2 import ignore_invalid as ignore_invalid
from .v2 import ( from .v2 import (
with_info_plain_validator_function as with_info_plain_validator_function, with_info_plain_validator_function as with_info_plain_validator_function,
) )
@ -44,6 +45,7 @@ else:
from .v1 import Validator as Validator from .v1 import Validator as Validator
from .v1 import evaluate_forwardref as evaluate_forwardref from .v1 import evaluate_forwardref as evaluate_forwardref
from .v1 import get_missing_field_error as get_missing_field_error from .v1 import get_missing_field_error as get_missing_field_error
from .v1 import ignore_invalid as ignore_invalid
from .v1 import ( # type: ignore[assignment] from .v1 import ( # type: ignore[assignment]
with_info_plain_validator_function as with_info_plain_validator_function, with_info_plain_validator_function as with_info_plain_validator_function,
) )

View File

@ -5,6 +5,7 @@ from collections import deque
from dataclasses import is_dataclass from dataclasses import is_dataclass
from typing import ( from typing import (
Any, Any,
Callable,
Deque, Deque,
FrozenSet, FrozenSet,
List, List,
@ -18,7 +19,7 @@ from typing import (
from fastapi._compat import may_v1 from fastapi._compat import may_v1
from fastapi.types import UnionType from fastapi.types import UnionType
from pydantic import BaseModel from pydantic import BaseModel, ValidationError
from pydantic.version import VERSION as PYDANTIC_VERSION from pydantic.version import VERSION as PYDANTIC_VERSION
from starlette.datastructures import UploadFile from starlette.datastructures import UploadFile
from typing_extensions import Annotated, get_args, get_origin from typing_extensions import Annotated, get_args, get_origin
@ -251,3 +252,30 @@ def annotation_is_pydantic_v1(annotation: Any) -> bool:
if annotation_is_pydantic_v1(sub_annotation): if annotation_is_pydantic_v1(sub_annotation):
return True return True
return False return False
def remove_invalid(v: Any, handler: Callable[[Any], Any]) -> Any:
try:
return handler(v)
except ValidationError as exc:
if not isinstance(v, dict):
raise exc
# remove invalid values from invalid keys and revalidate
errors = may_v1._regenerate_error_with_loc(errors=[exc.errors()], loc_prefix=())
for err in errors:
loc = err.get("loc", ())
if len(loc) == 1:
v.pop(loc[0], None)
elif len(loc) == 2 and isinstance(v.get(loc[0]), list):
try:
v[loc[0]][loc[1]] = None
except (ValueError, IndexError):
pass
# remove the None values from lists
for key in list(v.keys()):
if isinstance(v[key], list):
v[key] = [item for item in v[key] if item is not None]
# remove empty lists
if v[key] == []:
v.pop(key)
return handler(v)

View File

@ -348,3 +348,7 @@ def create_body_model(
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]: def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return list(model.__fields__.values()) # type: ignore[attr-defined] return list(model.__fields__.values()) # type: ignore[attr-defined]
def ignore_invalid(field_info: FieldInfo) -> FieldInfo:
return field_info

View File

@ -18,7 +18,7 @@ from typing import (
from fastapi._compat import may_v1, shared from fastapi._compat import may_v1, shared
from fastapi.openapi.constants import REF_TEMPLATE from fastapi.openapi.constants import REF_TEMPLATE
from fastapi.types import IncEx, ModelNameMap from fastapi.types import IncEx, ModelNameMap
from pydantic import BaseModel, TypeAdapter, create_model from pydantic import BaseModel, TypeAdapter, WrapValidator, create_model
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation
from pydantic import ValidationError as ValidationError from pydantic import ValidationError as ValidationError
@ -487,3 +487,11 @@ def get_flat_models_from_fields(
def get_long_model_name(model: TypeModelOrEnum) -> str: def get_long_model_name(model: TypeModelOrEnum) -> str:
return f"{model.__module__}__{model.__qualname__}".replace(".", "__") return f"{model.__module__}__{model.__qualname__}".replace(".", "__")
def ignore_invalid(field_info: FieldInfo) -> FieldInfo:
new_field_info = copy(field_info)
new_field_info.metadata = getattr(field_info, "metadata", []) + [
WrapValidator(shared.remove_invalid)
]
return new_field_info

View File

@ -34,6 +34,7 @@ from fastapi._compat import (
get_annotation_from_field_info, get_annotation_from_field_info,
get_cached_model_fields, get_cached_model_fields,
get_missing_field_error, get_missing_field_error,
ignore_invalid,
is_bytes_field, is_bytes_field,
is_bytes_sequence_field, is_bytes_sequence_field,
is_scalar_field, is_scalar_field,
@ -63,7 +64,7 @@ from fastapi.security.oauth2 import OAuth2, SecurityScopes
from fastapi.security.open_id_connect_url import OpenIdConnect from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.types import DependencyCacheKey from fastapi.types import DependencyCacheKey
from fastapi.utils import create_model_field, get_path_param_names from fastapi.utils import create_model_field, get_path_param_names
from pydantic import BaseModel, ValidationError, WrapValidator from pydantic import BaseModel
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from starlette.background import BackgroundTasks as StarletteBackgroundTasks from starlette.background import BackgroundTasks as StarletteBackgroundTasks
from starlette.concurrency import run_in_threadpool from starlette.concurrency import run_in_threadpool
@ -330,33 +331,6 @@ def add_non_field_param_to_dependency(
return None return None
def remove_invalid(v: Any, handler: Callable[[Any], Any]) -> Any:
try:
return handler(v)
except ValidationError as exc:
if not isinstance(v, dict):
raise exc
# remove invalid values from invalid keys and revalidate
errors = may_v1._regenerate_error_with_loc(errors=[exc.errors()], loc_prefix=())
for err in errors:
loc = err.get("loc", ())
if len(loc) == 1:
v.pop(loc[0], None)
elif len(loc) == 2 and isinstance(v.get(loc[0]), list):
try:
v[loc[0]][loc[1]] = None
except (ValueError, IndexError):
pass
# remove the None values from lists
for key in list(v.keys()):
if isinstance(v[key], list):
v[key] = [item for item in v[key] if item is not None]
# remove empty lists
if v[key] == []:
v.pop(key)
return handler(v)
@dataclass @dataclass
class ParamDetails: class ParamDetails:
type_annotation: Any type_annotation: Any
@ -553,9 +527,6 @@ def analyze_param(
if is_scalar_sequence_field(field) or is_scalar_sequence_mapping_field( if is_scalar_sequence_field(field) or is_scalar_sequence_mapping_field(
field field
): ):
field_info.metadata = getattr(field_info, "metadata", []) + [
WrapValidator(remove_invalid)
]
field = create_model_field( field = create_model_field(
name=param_name, name=param_name,
type_=use_annotation_from_field_info, type_=use_annotation_from_field_info,
@ -563,7 +534,7 @@ def analyze_param(
alias=alias, alias=alias,
required=field_info.default required=field_info.default
in (RequiredParam, may_v1.RequiredParam, Undefined), in (RequiredParam, may_v1.RequiredParam, Undefined),
field_info=field_info, field_info=ignore_invalid(field_info),
) )
return ParamDetails(type_annotation=type_annotation, depends=depends, field=field) return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)