ignore invalid for pydantic v1

This commit is contained in:
JONEMI21 2025-11-10 12:23:26 +00:00
parent 4a7f704a07
commit 88d2ace66a
7 changed files with 72 additions and 32 deletions

View File

@ -1,4 +1,4 @@
from typing import Annotated, Dict, List, Union
from typing import Annotated, Dict, List
from fastapi import FastAPI, Query
@ -8,15 +8,13 @@ app = FastAPI()
@app.get("/query/mixed-type-params")
def get_mixed_mapping_mixed_type_query_params(
query: Annotated[int, Query()] = None,
mapping_query_str_or_int: Annotated[
Union[Dict[str, str], Dict[str, int]], Query()
] = None,
mapping_query_str: Annotated[Dict[str, str], Query()] = None,
mapping_query_int: Annotated[Dict[str, int], Query()] = None,
sequence_mapping_int: Annotated[Dict[str, List[int]], Query()] = None,
):
return {
"query": query,
"mapping_query_str_or_int": mapping_query_str_or_int,
"mapping_query_str": mapping_query_str,
"mapping_query_int": mapping_query_int,
"sequence_mapping_int": sequence_mapping_int,
}

View File

@ -148,6 +148,8 @@ def field_annotation_is_scalar_mapping(
annotation: Union[Type[Any], None],
) -> bool:
origin = get_origin(annotation)
if origin is Annotated:
return field_annotation_is_scalar_mapping(get_args(annotation)[0])
if origin is Union or origin is UnionType:
at_least_one_scalar_mapping = False
for arg in get_args(annotation):
@ -167,6 +169,8 @@ def field_annotation_is_scalar_sequence_mapping(
annotation: Union[Type[Any], None],
) -> bool:
origin = get_origin(annotation)
if origin is Annotated:
return field_annotation_is_scalar_sequence_mapping(get_args(annotation)[0])
if origin is Union or origin is UnionType:
at_least_one_scalar_mapping = False
for arg in get_args(annotation):

View File

@ -1,4 +1,4 @@
from copy import copy
from copy import copy, deepcopy
from dataclasses import dataclass, is_dataclass
from enum import Enum
from typing import (
@ -350,6 +350,54 @@ def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return list(model.__fields__.values()) # type: ignore[attr-defined]
def omit_by_default(field_info: FieldInfo) -> FieldInfo:
def ignore_invalid(cls, v, values, field, **kwargs) -> Any:
from .may_v1 import _regenerate_error_with_loc
field_copy = deepcopy(field)
field_copy.pre_validators = [
validator
for validator in field_copy.pre_validators
if getattr(validator, "__name__", "") != "ignore_invalid"
]
v, errors = field_copy.validate(v, values, loc=field.name)
if not errors:
return v
# pop the keys or elements that caused the validation errors and revalidate
for error in _regenerate_error_with_loc(errors=errors, loc_prefix=()):
loc = error["loc"][1:]
if len(loc) == 0:
continue
if isinstance(loc[0], int) and isinstance(v, list):
index = loc[0]
if 0 <= index < len(v):
v[index] = None
# Handle nested list validation errors (e.g., dict[str, list[str]])
elif isinstance(loc[0], str) and isinstance(v, dict):
key = loc[0]
if (
len(loc) > 1
and isinstance(loc[1], int)
and key in v
and isinstance(v[key], list)
):
list_index = loc[1]
v[key][list_index] = None
elif key in v:
v.pop(key)
if isinstance(v, list):
v = [el for el in v if el is not None]
if isinstance(v, dict):
for key in v.keys():
if isinstance(v[key], list):
v[key] = [el for el in v[key] if el is not None]
return v
def omit_by_default(field_info: FieldInfo) -> tuple[FieldInfo, dict]:
"""add a wrap validator to omit invalid values by default."""
raise NotImplementedError("This function is a placeholder in Pydantic v1.")
return field_info, {"ignore_invalid": Validator(ignore_invalid, pre=True)}

View File

@ -507,13 +507,12 @@ if shared.PYDANTIC_VERSION_MINOR_TUPLE >= (2, 6):
else:
return OnErrorOmit[annotation]
def omit_by_default(field_info: FieldInfo) -> FieldInfo:
"""Set omit by default on a FieldInfo's annotation."""
def omit_by_default(field_info: FieldInfo) -> tuple[FieldInfo, dict]:
new_annotation = _omit_by_default(field_info.annotation)
new_field_info = copy_field_info(
field_info=field_info, annotation=new_annotation
)
return new_field_info
return new_field_info, {}
else:
@ -555,9 +554,9 @@ else:
return handler(v)
def omit_by_default(field_info: FieldInfo) -> FieldInfo:
def omit_by_default(field_info: FieldInfo) -> tuple[FieldInfo, dict]:
"""add a wrap validator to omit invalid values by default."""
field_info.metadata = field_info.metadata or [] + [
WrapValidator(ignore_invalid)
]
return field_info
return field_info, {}

View File

@ -506,11 +506,12 @@ def analyze_param(
field_info.alias = alias
# Omit by default for scalar mapping and scalar sequence mapping query fields
if isinstance(field_info, (params.Query)) and (
field_annotation_is_scalar_sequence_mapping(field_info.annotation)
or field_annotation_is_scalar_mapping(field_info.annotation)
class_validators: dict[str, list[Any]] = {}
if isinstance(field_info, (params.Query, temp_pydantic_v1_params.Query)) and (
field_annotation_is_scalar_sequence_mapping(use_annotation_from_field_info)
or field_annotation_is_scalar_mapping(use_annotation_from_field_info)
):
field_info = omit_by_default(field_info)
field_info, class_validators = omit_by_default(field_info)
field = create_model_field(
name=param_name,
@ -520,6 +521,7 @@ def analyze_param(
required=field_info.default
in (RequiredParam, may_v1.RequiredParam, Undefined),
field_info=field_info,
class_validators=class_validators,
)
if is_path_param:
assert is_scalar_field(field=field), (
@ -529,8 +531,8 @@ def analyze_param(
assert (
is_scalar_field(field)
or is_scalar_sequence_field(field)
or is_scalar_sequence_mapping_field(field)
or is_scalar_mapping_field(field)
or is_scalar_sequence_mapping_field(field)
or (
_is_model_class(field.type_)
# For Pydantic v1

View File

@ -1219,14 +1219,7 @@ def test_openapi_schema():
"schema": {
"additionalProperties": {
"items": {
"anyOf": [
{
"type": "string",
},
{
"type": "integer",
},
],
"type": "integer",
},
"type": "array",
},

View File

@ -1,10 +1,6 @@
import pytest
from fastapi._compat import PYDANTIC_V2
from fastapi.testclient import TestClient
if not PYDANTIC_V2:
pytest.skip("This test is only for Pydantic v2", allow_module_level=True)
@pytest.fixture(name="client")
def get_client():
@ -19,7 +15,7 @@ def test_foo_needy_very(client: TestClient):
assert response.status_code == 200
assert response.json() == {
"query": 2,
"mapping_query_str_or_int": {"foo": "baz"},
"mapping_query_str": {"foo": "baz"},
"mapping_query_int": {},
"sequence_mapping_int": {"foo": []},
}
@ -30,7 +26,7 @@ def test_just_string_not_scalar_mapping(client: TestClient):
assert response.status_code == 200
assert response.json() == {
"query": 2,
"mapping_query_str_or_int": {"bar": "3", "foo": "baz"},
"mapping_query_str": {"bar": "3", "foo": "baz"},
"mapping_query_int": {"bar": 3},
"sequence_mapping_int": {"bar": [3], "foo": [1, 2]},
}