use wrap validator to remove failing values

JONEMI21 2025-11-07 13:56:50 +00:00
parent 103d46921b
commit c9bfba04dd
4 changed files with 59 additions and 33 deletions

View File

@@ -1,4 +1,4 @@
-from typing import Dict, List
+from typing import Annotated, Dict, List, Union

 from fastapi import FastAPI, Query
@@ -7,14 +7,16 @@ app = FastAPI()
 @app.get("/query/mixed-type-params")
 def get_mixed_mapping_mixed_type_query_params(
-    query: int = Query(),
-    mapping_query_str: Dict[str, str] = Query({}),
-    mapping_query_int: Dict[str, int] = Query({}),
-    sequence_mapping_queries: Dict[str, List[int]] = Query({}),
+    query: Annotated[int, Query()] = None,
+    mapping_query_str_or_int: Annotated[
+        Union[Dict[str, str], Dict[str, int]], Query()
+    ] = None,
+    mapping_query_int: Annotated[Dict[str, int], Query()] = None,
+    sequence_mapping_int: Annotated[Dict[str, List[int]], Query()] = None,
 ):
     return {
         "query": query,
-        "string_mapping": mapping_query_str,
+        "mapping_query_str_or_int": mapping_query_str_or_int,
         "mapping_query_int": mapping_query_int,
-        "sequence_mapping_queries": sequence_mapping_queries,
+        "sequence_mapping_int": sequence_mapping_int,
     }
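A note on the new Union[Dict[str, str], Dict[str, int]] parameter: raw query values arrive as strings, and under Pydantic v2's default smart-union handling the all-string mapping validates them strictly and wins, so int-like values such as "3" stay strings. A quick standalone check of that assumption (TypeAdapter is used here purely for illustration, it is not part of this change):

from typing import Dict, Union

from pydantic import TypeAdapter

# Query values are all strings on the wire; the Dict[str, str] member of the
# union validates them without coercion, so it is the one that gets picked.
adapter = TypeAdapter(Union[Dict[str, str], Dict[str, int]])
print(adapter.validate_python({"foo": "baz", "bar": "3"}))
# -> {'foo': 'baz', 'bar': '3'}

This is why the updated tests further down expect string values for mapping_query_str_or_int.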

View File

@@ -173,7 +173,10 @@ def field_annotation_is_scalar_sequence_mapping(
             if field_annotation_is_scalar_sequence_mapping(arg):
                 at_least_one_scalar_mapping = True
                 continue
-            elif not field_annotation_is_scalar(arg):
+            elif not (
+                field_annotation_is_scalar_sequence_mapping(arg)
+                or field_annotation_is_scalar_mapping(arg)
+            ):
                 return False
         return at_least_one_scalar_mapping
     return lenient_issubclass(origin, Mapping) and all(
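For context, these helpers classify type annotations by inspecting them with typing.get_origin and get_args; the changed Union branch now accepts members that are plain scalar mappings (e.g. Dict[str, int]) alongside scalar sequence mappings, rather than falling back to the plain-scalar check. The sketch below is not the library's implementation, only a rough, self-contained illustration of what a "scalar sequence mapping" check of this shape looks like (the names looks_like_scalar_sequence_mapping and SCALARS are made up here):

from collections.abc import Mapping
from typing import Dict, List, Union, get_args, get_origin

SCALARS = (int, float, str, bool)


def looks_like_scalar_sequence_mapping(annotation) -> bool:
    # Rough check: Mapping[scalar, list/tuple of scalars], or a Union that
    # contains at least one such mapping.
    origin = get_origin(annotation)
    if origin is Union:
        return any(
            looks_like_scalar_sequence_mapping(arg) for arg in get_args(annotation)
        )
    if not isinstance(origin, type) or not issubclass(origin, Mapping):
        return False
    args = get_args(annotation)
    if len(args) != 2:
        return False
    key_type, value_type = args
    return (
        key_type in SCALARS
        and get_origin(value_type) in (list, tuple)
        and all(arg in SCALARS for arg in get_args(value_type))
    )


print(looks_like_scalar_sequence_mapping(Dict[str, List[int]]))  # True
print(looks_like_scalar_sequence_mapping(Dict[str, int]))        # False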

View File

@@ -12,7 +12,6 @@ from typing import (
     Mapping,
     Optional,
     Sequence,
-    Set,
     Tuple,
     Type,
     Union,
@@ -64,7 +63,7 @@ from fastapi.security.oauth2 import OAuth2, SecurityScopes
 from fastapi.security.open_id_connect_url import OpenIdConnect
 from fastapi.types import DependencyCacheKey
 from fastapi.utils import create_model_field, get_path_param_names
-from pydantic import BaseModel
+from pydantic import BaseModel, ValidationError, WrapValidator
 from pydantic.fields import FieldInfo
 from starlette.background import BackgroundTasks as StarletteBackgroundTasks
 from starlette.concurrency import run_in_threadpool
@@ -330,6 +329,25 @@ def add_non_field_param_to_dependency(
         return True
     return None


+def remove_invalid(v: Any, handler: Callable[[Any], Any]) -> Any:
+    try:
+        return handler(v)
+    except ValidationError as exc:
+        if not isinstance(v, dict):
+            raise exc
+        # remove invalid values from invalid keys and revalidate
+        errors = may_v1._regenerate_error_with_loc(errors=exc.errors(), loc_prefix=())
+        for err in errors:
+            loc = err.get("loc", ())
+            if len(loc) == 1:
+                v.pop(loc[0], None)
+            elif len(loc) == 2 and isinstance(v.get(loc[0]), list):
+                try:
+                    v[loc[0]].pop(int(loc[1]))
+                except (ValueError, IndexError):
+                    pass
+        return handler(v)
+
+
 @dataclass
 class ParamDetails:
@@ -497,7 +515,16 @@ def analyze_param(
             alias = param_name.replace("_", "-")
         else:
             alias = field_info.alias or param_name
         field_info.alias = alias
+        if is_scalar_sequence_field(field) or is_scalar_sequence_mapping_field(field):
+            # Wrap the validator to remove invalid values from scalar sequence
+            # and scalar sequence mapping fields instead of failing the whole validation
+            field_info.metadata = getattr(field_info, "metadata", []) + [
+                WrapValidator(remove_invalid)
+            ]
         field = create_model_field(
             name=param_name,
             type_=use_annotation_from_field_info,
@@ -702,13 +729,6 @@ async def solve_dependencies(
     )


-def _extract_error_locs(errors_: Sequence[Any]) -> Set[str]:
-    if isinstance(errors_, list):
-        errors_ = may_v1._regenerate_error_with_loc(errors=errors_, loc_prefix=())
-    return {err["loc"][2] for err in errors_ if len(err["loc"]) >= 3}
-
-
 def _validate_value_with_model_field(
     *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
 ) -> Tuple[Any, List[Any]]:
@@ -718,19 +738,6 @@ def _validate_value_with_model_field(
     else:
         return deepcopy(field.default), []
     v_, errors_ = field.validate(value, values, loc=loc)
-    if (
-        errors_
-        and isinstance(field.field_info, params.Query)
-        and isinstance(value, Mapping)
-        and (is_scalar_sequence_mapping_field(field) or is_scalar_mapping_field(field))
-    ):
-        # Remove failing keys from the dict and try to re-validate
-        invalid_keys = _extract_error_locs(errors_)
-        v_, errors_ = field.validate(
-            {k: v for k, v in value.items() if k not in invalid_keys},
-            values,
-            loc=loc,
-        )
     if _is_error_wrapper(errors_):  # type: ignore[arg-type]
         return None, [errors_]
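The remove_invalid helper added above follows Pydantic v2's documented wrap-validator pattern: call the inner handler, and if validation fails on a dict, prune the offending keys or list items identified by the error locations, then validate once more; analyze_param only attaches it, via the field's metadata, to scalar-sequence and scalar-sequence-mapping fields. Below is a self-contained sketch of the same pattern outside FastAPI; the names drop_failing_entries and LenientIntListMap are illustrative, and TypeAdapter is used only to exercise the annotation:

from typing import Annotated, Any, Callable, Dict, List

from pydantic import TypeAdapter, ValidationError, WrapValidator


def drop_failing_entries(v: Any, handler: Callable[[Any], Any]) -> Any:
    # First pass: accept the value unchanged if it validates.
    try:
        return handler(v)
    except ValidationError as exc:
        if not isinstance(v, dict):
            raise
        # loc ("key",) -> drop the whole key; loc ("key", i) -> drop list item i.
        # Walk the errors in reverse so removing one list item does not shift
        # the indices reported for earlier items in the same list.
        for err in reversed(exc.errors()):
            loc = err["loc"]
            if len(loc) == 1:
                v.pop(loc[0], None)
            elif len(loc) == 2 and isinstance(v.get(loc[0]), list):
                try:
                    v[loc[0]].pop(int(loc[1]))
                except (ValueError, IndexError):
                    pass
        return handler(v)


LenientIntListMap = Annotated[Dict[str, List[int]], WrapValidator(drop_failing_entries)]

adapter = TypeAdapter(LenientIntListMap)
print(adapter.validate_python({"foo": ["1", "x", "2"], "bar": "y"}))
# -> {'foo': [1, 2]}  ("x" is dropped from the list, the "bar" entry is dropped entirely)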

View File

@@ -15,7 +15,21 @@ def test_foo_needy_very(client: TestClient):
     assert response.status_code == 200
     assert response.json() == {
         "query": 2,
-        "string_mapping": {"foo": "baz"},
-        "mapping_query_int": {},
-        "sequence_mapping_queries": {},
+        "mapping_query_str_or_int": {"foo": "baz"},
+        "mapping_query_int": None,
+        "sequence_mapping_int": None,
     }
+
+
+def test_just_string_not_scalar_mapping(client: TestClient):
+    response = client.get(
+        "/query/mixed-type-params?&query=2&foo=1&bar=3&foo=2&foo=baz"
+    )
+    assert response.status_code == 200
+    assert response.json() == {
+        "query": 2,
+        "mapping_query_str_or_int": {"bar": "3", "foo": "baz"},
+        "mapping_query_int": {"bar": 3},
+        "sequence_mapping_int": {"bar": [3], "foo": [1, 2]},
+    }
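For reference, the second test's query string parses to the multi-value mapping below before any validation; the expected JSON above is what survives for each parameter once failing values are dropped (for instance, "baz" never fits the int-valued mappings). A quick check with the standard library:

from urllib.parse import parse_qs

print(parse_qs("query=2&foo=1&bar=3&foo=2&foo=baz"))
# -> {'query': ['2'], 'foo': ['1', '2', 'baz'], 'bar': ['3']}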