mirror of https://github.com/tiangolo/fastapi.git
use wrap validator to remove failing values
This commit is contained in:
parent
103d46921b
commit
c9bfba04dd
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Dict, List
|
||||
from typing import Annotated, Dict, List, Union
|
||||
|
||||
from fastapi import FastAPI, Query
|
||||
|
||||
|
|
@ -7,14 +7,16 @@ app = FastAPI()
|
|||
|
||||
@app.get("/query/mixed-type-params")
|
||||
def get_mixed_mapping_mixed_type_query_params(
|
||||
query: int = Query(),
|
||||
mapping_query_str: Dict[str, str] = Query({}),
|
||||
mapping_query_int: Dict[str, int] = Query({}),
|
||||
sequence_mapping_queries: Dict[str, List[int]] = Query({}),
|
||||
query: Annotated[int, Query()] = None,
|
||||
mapping_query_str_or_int: Annotated[
|
||||
Union[Dict[str, str], Dict[str, int]], Query()
|
||||
] = None,
|
||||
mapping_query_int: Annotated[Dict[str, int], Query()] = None,
|
||||
sequence_mapping_int: Annotated[Dict[str, List[int]], Query()] = None,
|
||||
):
|
||||
return {
|
||||
"query": query,
|
||||
"string_mapping": mapping_query_str,
|
||||
"mapping_query_str_or_int": mapping_query_str_or_int,
|
||||
"mapping_query_int": mapping_query_int,
|
||||
"sequence_mapping_queries": sequence_mapping_queries,
|
||||
"sequence_mapping_int": sequence_mapping_int,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -173,7 +173,10 @@ def field_annotation_is_scalar_sequence_mapping(
|
|||
if field_annotation_is_scalar_sequence_mapping(arg):
|
||||
at_least_one_scalar_mapping = True
|
||||
continue
|
||||
elif not field_annotation_is_scalar(arg):
|
||||
elif not (
|
||||
field_annotation_is_scalar_sequence_mapping(arg)
|
||||
or field_annotation_is_scalar_mapping(arg)
|
||||
):
|
||||
return False
|
||||
return at_least_one_scalar_mapping
|
||||
return lenient_issubclass(origin, Mapping) and all(
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ from typing import (
|
|||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
|
|
@ -64,7 +63,7 @@ from fastapi.security.oauth2 import OAuth2, SecurityScopes
|
|||
from fastapi.security.open_id_connect_url import OpenIdConnect
|
||||
from fastapi.types import DependencyCacheKey
|
||||
from fastapi.utils import create_model_field, get_path_param_names
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ValidationError, WrapValidator
|
||||
from pydantic.fields import FieldInfo
|
||||
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
|
@ -330,6 +329,25 @@ def add_non_field_param_to_dependency(
|
|||
return True
|
||||
return None
|
||||
|
||||
def remove_invalid(v: Any, handler: Callable[[Any], Any]) -> Any:
|
||||
try:
|
||||
return handler(v)
|
||||
except ValidationError as exc:
|
||||
if not isinstance(v, dict):
|
||||
raise exc
|
||||
# remove invalid values from invalid keys and revalidate
|
||||
errors = may_v1._regenerate_error_with_loc(errors=[exc.errors()], loc_prefix=())
|
||||
for err in errors:
|
||||
loc = err.get("loc", ())
|
||||
if len(loc) == 1:
|
||||
v.pop(loc[0], None)
|
||||
elif len(loc) == 2 and isinstance(v.get(loc[0]), list):
|
||||
try:
|
||||
v[loc[0]].pop(int(loc[1]))
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
return handler(v)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParamDetails:
|
||||
|
|
@ -497,7 +515,16 @@ def analyze_param(
|
|||
alias = param_name.replace("_", "-")
|
||||
else:
|
||||
alias = field_info.alias or param_name
|
||||
|
||||
field_info.alias = alias
|
||||
|
||||
if is_scalar_sequence_field(field) or is_scalar_sequence_mapping_field(field):
|
||||
# Wrap the validator to remove invalid values from scalar sequence
|
||||
# and scalar sequence mapping fields instead of failing the whole validation
|
||||
field_info.metadata = getattr(field_info, "metadata", []) + [
|
||||
WrapValidator(remove_invalid)
|
||||
]
|
||||
|
||||
field = create_model_field(
|
||||
name=param_name,
|
||||
type_=use_annotation_from_field_info,
|
||||
|
|
@ -702,13 +729,6 @@ async def solve_dependencies(
|
|||
)
|
||||
|
||||
|
||||
def _extract_error_locs(errors_: Sequence[Any]) -> Set[str]:
|
||||
if isinstance(errors_, list):
|
||||
errors_ = may_v1._regenerate_error_with_loc(errors=errors_, loc_prefix=())
|
||||
|
||||
return {err["loc"][2] for err in errors_ if len(err["loc"]) >= 3}
|
||||
|
||||
|
||||
def _validate_value_with_model_field(
|
||||
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
|
||||
) -> Tuple[Any, List[Any]]:
|
||||
|
|
@ -718,19 +738,6 @@ def _validate_value_with_model_field(
|
|||
else:
|
||||
return deepcopy(field.default), []
|
||||
v_, errors_ = field.validate(value, values, loc=loc)
|
||||
if (
|
||||
errors_
|
||||
and isinstance(field.field_info, params.Query)
|
||||
and isinstance(value, Mapping)
|
||||
and (is_scalar_sequence_mapping_field(field) or is_scalar_mapping_field(field))
|
||||
):
|
||||
# Remove failing keys from the dict and try to re-validate
|
||||
invalid_keys = _extract_error_locs(errors_)
|
||||
v_, errors_ = field.validate(
|
||||
{k: v for k, v in value.items() if k not in invalid_keys},
|
||||
values,
|
||||
loc=loc,
|
||||
)
|
||||
|
||||
if _is_error_wrapper(errors_): # type: ignore[arg-type]
|
||||
return None, [errors_]
|
||||
|
|
|
|||
|
|
@ -15,7 +15,21 @@ def test_foo_needy_very(client: TestClient):
|
|||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"query": 2,
|
||||
"string_mapping": {"foo": "baz"},
|
||||
"mapping_query_int": {},
|
||||
"sequence_mapping_queries": {},
|
||||
"mapping_query_str_or_int": {"foo": "baz"},
|
||||
"mapping_query_int": None,
|
||||
"sequence_mapping_int": None,
|
||||
}
|
||||
|
||||
|
||||
def test_just_string_not_scalar_mapping(client: TestClient):
|
||||
response = client.get(
|
||||
"/query/mixed-type-params?&query=2&foo=1&bar=3&foo=2&foo=baz"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"query": 2,
|
||||
"mapping_query_str_or_int": {"bar": "3", "foo": "baz"},
|
||||
"mapping_query_int": {"bar": 3},
|
||||
"sequence_mapping_int": {"bar": [3], "foo": [1, 2]},
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue