mirror of https://github.com/tiangolo/fastapi.git
ignore invalid for pydantic v1
This commit is contained in:
parent
4a7f704a07
commit
88d2ace66a
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Annotated, Dict, List, Union
|
from typing import Annotated, Dict, List
|
||||||
|
|
||||||
from fastapi import FastAPI, Query
|
from fastapi import FastAPI, Query
|
||||||
|
|
||||||
|
|
@ -8,15 +8,13 @@ app = FastAPI()
|
||||||
@app.get("/query/mixed-type-params")
|
@app.get("/query/mixed-type-params")
|
||||||
def get_mixed_mapping_mixed_type_query_params(
|
def get_mixed_mapping_mixed_type_query_params(
|
||||||
query: Annotated[int, Query()] = None,
|
query: Annotated[int, Query()] = None,
|
||||||
mapping_query_str_or_int: Annotated[
|
mapping_query_str: Annotated[Dict[str, str], Query()] = None,
|
||||||
Union[Dict[str, str], Dict[str, int]], Query()
|
|
||||||
] = None,
|
|
||||||
mapping_query_int: Annotated[Dict[str, int], Query()] = None,
|
mapping_query_int: Annotated[Dict[str, int], Query()] = None,
|
||||||
sequence_mapping_int: Annotated[Dict[str, List[int]], Query()] = None,
|
sequence_mapping_int: Annotated[Dict[str, List[int]], Query()] = None,
|
||||||
):
|
):
|
||||||
return {
|
return {
|
||||||
"query": query,
|
"query": query,
|
||||||
"mapping_query_str_or_int": mapping_query_str_or_int,
|
"mapping_query_str": mapping_query_str,
|
||||||
"mapping_query_int": mapping_query_int,
|
"mapping_query_int": mapping_query_int,
|
||||||
"sequence_mapping_int": sequence_mapping_int,
|
"sequence_mapping_int": sequence_mapping_int,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -148,6 +148,8 @@ def field_annotation_is_scalar_mapping(
|
||||||
annotation: Union[Type[Any], None],
|
annotation: Union[Type[Any], None],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
origin = get_origin(annotation)
|
origin = get_origin(annotation)
|
||||||
|
if origin is Annotated:
|
||||||
|
return field_annotation_is_scalar_mapping(get_args(annotation)[0])
|
||||||
if origin is Union or origin is UnionType:
|
if origin is Union or origin is UnionType:
|
||||||
at_least_one_scalar_mapping = False
|
at_least_one_scalar_mapping = False
|
||||||
for arg in get_args(annotation):
|
for arg in get_args(annotation):
|
||||||
|
|
@ -167,6 +169,8 @@ def field_annotation_is_scalar_sequence_mapping(
|
||||||
annotation: Union[Type[Any], None],
|
annotation: Union[Type[Any], None],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
origin = get_origin(annotation)
|
origin = get_origin(annotation)
|
||||||
|
if origin is Annotated:
|
||||||
|
return field_annotation_is_scalar_sequence_mapping(get_args(annotation)[0])
|
||||||
if origin is Union or origin is UnionType:
|
if origin is Union or origin is UnionType:
|
||||||
at_least_one_scalar_mapping = False
|
at_least_one_scalar_mapping = False
|
||||||
for arg in get_args(annotation):
|
for arg in get_args(annotation):
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from copy import copy
|
from copy import copy, deepcopy
|
||||||
from dataclasses import dataclass, is_dataclass
|
from dataclasses import dataclass, is_dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import (
|
from typing import (
|
||||||
|
|
@ -350,6 +350,54 @@ def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
|
||||||
return list(model.__fields__.values()) # type: ignore[attr-defined]
|
return list(model.__fields__.values()) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
def omit_by_default(field_info: FieldInfo) -> FieldInfo:
|
def ignore_invalid(cls, v, values, field, **kwargs) -> Any:
|
||||||
|
from .may_v1 import _regenerate_error_with_loc
|
||||||
|
|
||||||
|
field_copy = deepcopy(field)
|
||||||
|
field_copy.pre_validators = [
|
||||||
|
validator
|
||||||
|
for validator in field_copy.pre_validators
|
||||||
|
if getattr(validator, "__name__", "") != "ignore_invalid"
|
||||||
|
]
|
||||||
|
v, errors = field_copy.validate(v, values, loc=field.name)
|
||||||
|
if not errors:
|
||||||
|
return v
|
||||||
|
|
||||||
|
# pop the keys or elements that caused the validation errors and revalidate
|
||||||
|
for error in _regenerate_error_with_loc(errors=errors, loc_prefix=()):
|
||||||
|
loc = error["loc"][1:]
|
||||||
|
if len(loc) == 0:
|
||||||
|
continue
|
||||||
|
if isinstance(loc[0], int) and isinstance(v, list):
|
||||||
|
index = loc[0]
|
||||||
|
if 0 <= index < len(v):
|
||||||
|
v[index] = None
|
||||||
|
|
||||||
|
# Handle nested list validation errors (e.g., dict[str, list[str]])
|
||||||
|
elif isinstance(loc[0], str) and isinstance(v, dict):
|
||||||
|
key = loc[0]
|
||||||
|
if (
|
||||||
|
len(loc) > 1
|
||||||
|
and isinstance(loc[1], int)
|
||||||
|
and key in v
|
||||||
|
and isinstance(v[key], list)
|
||||||
|
):
|
||||||
|
list_index = loc[1]
|
||||||
|
v[key][list_index] = None
|
||||||
|
elif key in v:
|
||||||
|
v.pop(key)
|
||||||
|
|
||||||
|
if isinstance(v, list):
|
||||||
|
v = [el for el in v if el is not None]
|
||||||
|
|
||||||
|
if isinstance(v, dict):
|
||||||
|
for key in v.keys():
|
||||||
|
if isinstance(v[key], list):
|
||||||
|
v[key] = [el for el in v[key] if el is not None]
|
||||||
|
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
def omit_by_default(field_info: FieldInfo) -> tuple[FieldInfo, dict]:
|
||||||
"""add a wrap validator to omit invalid values by default."""
|
"""add a wrap validator to omit invalid values by default."""
|
||||||
raise NotImplementedError("This function is a placeholder in Pydantic v1.")
|
return field_info, {"ignore_invalid": Validator(ignore_invalid, pre=True)}
|
||||||
|
|
|
||||||
|
|
@ -507,13 +507,12 @@ if shared.PYDANTIC_VERSION_MINOR_TUPLE >= (2, 6):
|
||||||
else:
|
else:
|
||||||
return OnErrorOmit[annotation]
|
return OnErrorOmit[annotation]
|
||||||
|
|
||||||
def omit_by_default(field_info: FieldInfo) -> FieldInfo:
|
def omit_by_default(field_info: FieldInfo) -> tuple[FieldInfo, dict]:
|
||||||
"""Set omit by default on a FieldInfo's annotation."""
|
|
||||||
new_annotation = _omit_by_default(field_info.annotation)
|
new_annotation = _omit_by_default(field_info.annotation)
|
||||||
new_field_info = copy_field_info(
|
new_field_info = copy_field_info(
|
||||||
field_info=field_info, annotation=new_annotation
|
field_info=field_info, annotation=new_annotation
|
||||||
)
|
)
|
||||||
return new_field_info
|
return new_field_info, {}
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
|
|
@ -555,9 +554,9 @@ else:
|
||||||
|
|
||||||
return handler(v)
|
return handler(v)
|
||||||
|
|
||||||
def omit_by_default(field_info: FieldInfo) -> FieldInfo:
|
def omit_by_default(field_info: FieldInfo) -> tuple[FieldInfo, dict]:
|
||||||
"""add a wrap validator to omit invalid values by default."""
|
"""add a wrap validator to omit invalid values by default."""
|
||||||
field_info.metadata = field_info.metadata or [] + [
|
field_info.metadata = field_info.metadata or [] + [
|
||||||
WrapValidator(ignore_invalid)
|
WrapValidator(ignore_invalid)
|
||||||
]
|
]
|
||||||
return field_info
|
return field_info, {}
|
||||||
|
|
|
||||||
|
|
@ -506,11 +506,12 @@ def analyze_param(
|
||||||
field_info.alias = alias
|
field_info.alias = alias
|
||||||
|
|
||||||
# Omit by default for scalar mapping and scalar sequence mapping query fields
|
# Omit by default for scalar mapping and scalar sequence mapping query fields
|
||||||
if isinstance(field_info, (params.Query)) and (
|
class_validators: dict[str, list[Any]] = {}
|
||||||
field_annotation_is_scalar_sequence_mapping(field_info.annotation)
|
if isinstance(field_info, (params.Query, temp_pydantic_v1_params.Query)) and (
|
||||||
or field_annotation_is_scalar_mapping(field_info.annotation)
|
field_annotation_is_scalar_sequence_mapping(use_annotation_from_field_info)
|
||||||
|
or field_annotation_is_scalar_mapping(use_annotation_from_field_info)
|
||||||
):
|
):
|
||||||
field_info = omit_by_default(field_info)
|
field_info, class_validators = omit_by_default(field_info)
|
||||||
|
|
||||||
field = create_model_field(
|
field = create_model_field(
|
||||||
name=param_name,
|
name=param_name,
|
||||||
|
|
@ -520,6 +521,7 @@ def analyze_param(
|
||||||
required=field_info.default
|
required=field_info.default
|
||||||
in (RequiredParam, may_v1.RequiredParam, Undefined),
|
in (RequiredParam, may_v1.RequiredParam, Undefined),
|
||||||
field_info=field_info,
|
field_info=field_info,
|
||||||
|
class_validators=class_validators,
|
||||||
)
|
)
|
||||||
if is_path_param:
|
if is_path_param:
|
||||||
assert is_scalar_field(field=field), (
|
assert is_scalar_field(field=field), (
|
||||||
|
|
@ -529,8 +531,8 @@ def analyze_param(
|
||||||
assert (
|
assert (
|
||||||
is_scalar_field(field)
|
is_scalar_field(field)
|
||||||
or is_scalar_sequence_field(field)
|
or is_scalar_sequence_field(field)
|
||||||
or is_scalar_sequence_mapping_field(field)
|
|
||||||
or is_scalar_mapping_field(field)
|
or is_scalar_mapping_field(field)
|
||||||
|
or is_scalar_sequence_mapping_field(field)
|
||||||
or (
|
or (
|
||||||
_is_model_class(field.type_)
|
_is_model_class(field.type_)
|
||||||
# For Pydantic v1
|
# For Pydantic v1
|
||||||
|
|
|
||||||
|
|
@ -1219,14 +1219,7 @@ def test_openapi_schema():
|
||||||
"schema": {
|
"schema": {
|
||||||
"additionalProperties": {
|
"additionalProperties": {
|
||||||
"items": {
|
"items": {
|
||||||
"anyOf": [
|
"type": "integer",
|
||||||
{
|
|
||||||
"type": "string",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "integer",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
},
|
||||||
"type": "array",
|
"type": "array",
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,6 @@
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi._compat import PYDANTIC_V2
|
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
if not PYDANTIC_V2:
|
|
||||||
pytest.skip("This test is only for Pydantic v2", allow_module_level=True)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="client")
|
@pytest.fixture(name="client")
|
||||||
def get_client():
|
def get_client():
|
||||||
|
|
@ -19,7 +15,7 @@ def test_foo_needy_very(client: TestClient):
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {
|
assert response.json() == {
|
||||||
"query": 2,
|
"query": 2,
|
||||||
"mapping_query_str_or_int": {"foo": "baz"},
|
"mapping_query_str": {"foo": "baz"},
|
||||||
"mapping_query_int": {},
|
"mapping_query_int": {},
|
||||||
"sequence_mapping_int": {"foo": []},
|
"sequence_mapping_int": {"foo": []},
|
||||||
}
|
}
|
||||||
|
|
@ -30,7 +26,7 @@ def test_just_string_not_scalar_mapping(client: TestClient):
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {
|
assert response.json() == {
|
||||||
"query": 2,
|
"query": 2,
|
||||||
"mapping_query_str_or_int": {"bar": "3", "foo": "baz"},
|
"mapping_query_str": {"bar": "3", "foo": "baz"},
|
||||||
"mapping_query_int": {"bar": 3},
|
"mapping_query_int": {"bar": 3},
|
||||||
"sequence_mapping_int": {"bar": [3], "foo": [1, 2]},
|
"sequence_mapping_int": {"bar": [3], "foo": [1, 2]},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue