ignore invalid for pydantic v1

This commit is contained in:
JONEMI21 2025-11-10 12:23:26 +00:00
parent 4a7f704a07
commit 88d2ace66a
7 changed files with 72 additions and 32 deletions

View File

@ -1,4 +1,4 @@
from typing import Annotated, Dict, List, Union from typing import Annotated, Dict, List
from fastapi import FastAPI, Query from fastapi import FastAPI, Query
@ -8,15 +8,13 @@ app = FastAPI()
@app.get("/query/mixed-type-params") @app.get("/query/mixed-type-params")
def get_mixed_mapping_mixed_type_query_params( def get_mixed_mapping_mixed_type_query_params(
query: Annotated[int, Query()] = None, query: Annotated[int, Query()] = None,
mapping_query_str_or_int: Annotated[ mapping_query_str: Annotated[Dict[str, str], Query()] = None,
Union[Dict[str, str], Dict[str, int]], Query()
] = None,
mapping_query_int: Annotated[Dict[str, int], Query()] = None, mapping_query_int: Annotated[Dict[str, int], Query()] = None,
sequence_mapping_int: Annotated[Dict[str, List[int]], Query()] = None, sequence_mapping_int: Annotated[Dict[str, List[int]], Query()] = None,
): ):
return { return {
"query": query, "query": query,
"mapping_query_str_or_int": mapping_query_str_or_int, "mapping_query_str": mapping_query_str,
"mapping_query_int": mapping_query_int, "mapping_query_int": mapping_query_int,
"sequence_mapping_int": sequence_mapping_int, "sequence_mapping_int": sequence_mapping_int,
} }

View File

@ -148,6 +148,8 @@ def field_annotation_is_scalar_mapping(
annotation: Union[Type[Any], None], annotation: Union[Type[Any], None],
) -> bool: ) -> bool:
origin = get_origin(annotation) origin = get_origin(annotation)
if origin is Annotated:
return field_annotation_is_scalar_mapping(get_args(annotation)[0])
if origin is Union or origin is UnionType: if origin is Union or origin is UnionType:
at_least_one_scalar_mapping = False at_least_one_scalar_mapping = False
for arg in get_args(annotation): for arg in get_args(annotation):
@ -167,6 +169,8 @@ def field_annotation_is_scalar_sequence_mapping(
annotation: Union[Type[Any], None], annotation: Union[Type[Any], None],
) -> bool: ) -> bool:
origin = get_origin(annotation) origin = get_origin(annotation)
if origin is Annotated:
return field_annotation_is_scalar_sequence_mapping(get_args(annotation)[0])
if origin is Union or origin is UnionType: if origin is Union or origin is UnionType:
at_least_one_scalar_mapping = False at_least_one_scalar_mapping = False
for arg in get_args(annotation): for arg in get_args(annotation):

View File

@ -1,4 +1,4 @@
from copy import copy from copy import copy, deepcopy
from dataclasses import dataclass, is_dataclass from dataclasses import dataclass, is_dataclass
from enum import Enum from enum import Enum
from typing import ( from typing import (
@ -350,6 +350,54 @@ def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return list(model.__fields__.values()) # type: ignore[attr-defined] return list(model.__fields__.values()) # type: ignore[attr-defined]
def omit_by_default(field_info: FieldInfo) -> FieldInfo: def ignore_invalid(cls, v, values, field, **kwargs) -> Any:
from .may_v1 import _regenerate_error_with_loc
field_copy = deepcopy(field)
field_copy.pre_validators = [
validator
for validator in field_copy.pre_validators
if getattr(validator, "__name__", "") != "ignore_invalid"
]
v, errors = field_copy.validate(v, values, loc=field.name)
if not errors:
return v
# pop the keys or elements that caused the validation errors and revalidate
for error in _regenerate_error_with_loc(errors=errors, loc_prefix=()):
loc = error["loc"][1:]
if len(loc) == 0:
continue
if isinstance(loc[0], int) and isinstance(v, list):
index = loc[0]
if 0 <= index < len(v):
v[index] = None
# Handle nested list validation errors (e.g., dict[str, list[str]])
elif isinstance(loc[0], str) and isinstance(v, dict):
key = loc[0]
if (
len(loc) > 1
and isinstance(loc[1], int)
and key in v
and isinstance(v[key], list)
):
list_index = loc[1]
v[key][list_index] = None
elif key in v:
v.pop(key)
if isinstance(v, list):
v = [el for el in v if el is not None]
if isinstance(v, dict):
for key in v.keys():
if isinstance(v[key], list):
v[key] = [el for el in v[key] if el is not None]
return v
def omit_by_default(field_info: FieldInfo) -> tuple[FieldInfo, dict]:
"""add a wrap validator to omit invalid values by default.""" """add a wrap validator to omit invalid values by default."""
raise NotImplementedError("This function is a placeholder in Pydantic v1.") return field_info, {"ignore_invalid": Validator(ignore_invalid, pre=True)}

View File

@ -507,13 +507,12 @@ if shared.PYDANTIC_VERSION_MINOR_TUPLE >= (2, 6):
else: else:
return OnErrorOmit[annotation] return OnErrorOmit[annotation]
def omit_by_default(field_info: FieldInfo) -> FieldInfo: def omit_by_default(field_info: FieldInfo) -> tuple[FieldInfo, dict]:
"""Set omit by default on a FieldInfo's annotation."""
new_annotation = _omit_by_default(field_info.annotation) new_annotation = _omit_by_default(field_info.annotation)
new_field_info = copy_field_info( new_field_info = copy_field_info(
field_info=field_info, annotation=new_annotation field_info=field_info, annotation=new_annotation
) )
return new_field_info return new_field_info, {}
else: else:
@ -555,9 +554,9 @@ else:
return handler(v) return handler(v)
def omit_by_default(field_info: FieldInfo) -> FieldInfo: def omit_by_default(field_info: FieldInfo) -> tuple[FieldInfo, dict]:
"""add a wrap validator to omit invalid values by default.""" """add a wrap validator to omit invalid values by default."""
field_info.metadata = field_info.metadata or [] + [ field_info.metadata = field_info.metadata or [] + [
WrapValidator(ignore_invalid) WrapValidator(ignore_invalid)
] ]
return field_info return field_info, {}

View File

@ -506,11 +506,12 @@ def analyze_param(
field_info.alias = alias field_info.alias = alias
# Omit by default for scalar mapping and scalar sequence mapping query fields # Omit by default for scalar mapping and scalar sequence mapping query fields
if isinstance(field_info, (params.Query)) and ( class_validators: dict[str, list[Any]] = {}
field_annotation_is_scalar_sequence_mapping(field_info.annotation) if isinstance(field_info, (params.Query, temp_pydantic_v1_params.Query)) and (
or field_annotation_is_scalar_mapping(field_info.annotation) field_annotation_is_scalar_sequence_mapping(use_annotation_from_field_info)
or field_annotation_is_scalar_mapping(use_annotation_from_field_info)
): ):
field_info = omit_by_default(field_info) field_info, class_validators = omit_by_default(field_info)
field = create_model_field( field = create_model_field(
name=param_name, name=param_name,
@ -520,6 +521,7 @@ def analyze_param(
required=field_info.default required=field_info.default
in (RequiredParam, may_v1.RequiredParam, Undefined), in (RequiredParam, may_v1.RequiredParam, Undefined),
field_info=field_info, field_info=field_info,
class_validators=class_validators,
) )
if is_path_param: if is_path_param:
assert is_scalar_field(field=field), ( assert is_scalar_field(field=field), (
@ -529,8 +531,8 @@ def analyze_param(
assert ( assert (
is_scalar_field(field) is_scalar_field(field)
or is_scalar_sequence_field(field) or is_scalar_sequence_field(field)
or is_scalar_sequence_mapping_field(field)
or is_scalar_mapping_field(field) or is_scalar_mapping_field(field)
or is_scalar_sequence_mapping_field(field)
or ( or (
_is_model_class(field.type_) _is_model_class(field.type_)
# For Pydantic v1 # For Pydantic v1

View File

@ -1219,15 +1219,8 @@ def test_openapi_schema():
"schema": { "schema": {
"additionalProperties": { "additionalProperties": {
"items": { "items": {
"anyOf": [
{
"type": "string",
},
{
"type": "integer", "type": "integer",
}, },
],
},
"type": "array", "type": "array",
}, },
"default": {}, "default": {},

View File

@ -1,10 +1,6 @@
import pytest import pytest
from fastapi._compat import PYDANTIC_V2
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
if not PYDANTIC_V2:
pytest.skip("This test is only for Pydantic v2", allow_module_level=True)
@pytest.fixture(name="client") @pytest.fixture(name="client")
def get_client(): def get_client():
@ -19,7 +15,7 @@ def test_foo_needy_very(client: TestClient):
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == { assert response.json() == {
"query": 2, "query": 2,
"mapping_query_str_or_int": {"foo": "baz"}, "mapping_query_str": {"foo": "baz"},
"mapping_query_int": {}, "mapping_query_int": {},
"sequence_mapping_int": {"foo": []}, "sequence_mapping_int": {"foo": []},
} }
@ -30,7 +26,7 @@ def test_just_string_not_scalar_mapping(client: TestClient):
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == { assert response.json() == {
"query": 2, "query": 2,
"mapping_query_str_or_int": {"bar": "3", "foo": "baz"}, "mapping_query_str": {"bar": "3", "foo": "baz"},
"mapping_query_int": {"bar": 3}, "mapping_query_int": {"bar": 3},
"sequence_mapping_int": {"bar": [3], "foo": [1, 2]}, "sequence_mapping_int": {"bar": [3], "foo": [1, 2]},
} }