diff --git a/docs_src/query_params/tutorial007_py310.py b/docs_src/query_params/tutorial007_py310.py index 4ce0017d2..a88600d5d 100644 --- a/docs_src/query_params/tutorial007_py310.py +++ b/docs_src/query_params/tutorial007_py310.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Annotated, Dict, List, Union from fastapi import FastAPI, Query @@ -7,14 +7,16 @@ app = FastAPI() @app.get("/query/mixed-type-params") def get_mixed_mapping_mixed_type_query_params( - query: int = Query(), - mapping_query_str: Dict[str, str] = Query({}), - mapping_query_int: Dict[str, int] = Query({}), - sequence_mapping_queries: Dict[str, List[int]] = Query({}), + query: Annotated[int, Query()] = None, + mapping_query_str_or_int: Annotated[ + Union[Dict[str, str], Dict[str, int]], Query() + ] = None, + mapping_query_int: Annotated[Dict[str, int], Query()] = None, + sequence_mapping_int: Annotated[Dict[str, List[int]], Query()] = None, ): return { "query": query, - "string_mapping": mapping_query_str, + "mapping_query_str_or_int": mapping_query_str_or_int, "mapping_query_int": mapping_query_int, - "sequence_mapping_queries": sequence_mapping_queries, + "sequence_mapping_int": sequence_mapping_int, } diff --git a/fastapi/_compat/shared.py b/fastapi/_compat/shared.py index 95730b839..84b72adf4 100644 --- a/fastapi/_compat/shared.py +++ b/fastapi/_compat/shared.py @@ -173,7 +173,10 @@ def field_annotation_is_scalar_sequence_mapping( if field_annotation_is_scalar_sequence_mapping(arg): at_least_one_scalar_mapping = True continue - elif not field_annotation_is_scalar(arg): + elif not ( + field_annotation_is_scalar_sequence_mapping(arg) + or field_annotation_is_scalar_mapping(arg) + ): return False return at_least_one_scalar_mapping return lenient_issubclass(origin, Mapping) and all( diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 0764f8a93..e8e3466a8 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -12,7 +12,6 @@ from typing import ( Mapping, Optional, Sequence, - Set, Tuple, Type, Union, @@ -64,7 +63,7 @@ from fastapi.security.oauth2 import OAuth2, SecurityScopes from fastapi.security.open_id_connect_url import OpenIdConnect from fastapi.types import DependencyCacheKey from fastapi.utils import create_model_field, get_path_param_names -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError, WrapValidator from pydantic.fields import FieldInfo from starlette.background import BackgroundTasks as StarletteBackgroundTasks from starlette.concurrency import run_in_threadpool @@ -330,6 +329,25 @@ def add_non_field_param_to_dependency( return True return None +def remove_invalid(v: Any, handler: Callable[[Any], Any]) -> Any: + try: + return handler(v) + except ValidationError as exc: + if not isinstance(v, dict): + raise exc + # remove invalid values from invalid keys and revalidate + errors = may_v1._regenerate_error_with_loc(errors=[exc.errors()], loc_prefix=()) + for err in errors: + loc = err.get("loc", ()) + if len(loc) == 1: + v.pop(loc[0], None) + elif len(loc) == 2 and isinstance(v.get(loc[0]), list): + try: + v[loc[0]].pop(int(loc[1])) + except (ValueError, IndexError): + pass + return handler(v) + @dataclass class ParamDetails: @@ -497,7 +515,16 @@ def analyze_param( alias = param_name.replace("_", "-") else: alias = field_info.alias or param_name + field_info.alias = alias + + if is_scalar_sequence_field(field) or is_scalar_sequence_mapping_field(field): + # Wrap the validator to remove invalid values from scalar sequence + # and scalar sequence mapping fields instead of failing the whole validation + field_info.metadata = getattr(field_info, "metadata", []) + [ + WrapValidator(remove_invalid) + ] + field = create_model_field( name=param_name, type_=use_annotation_from_field_info, @@ -702,13 +729,6 @@ async def solve_dependencies( ) -def _extract_error_locs(errors_: Sequence[Any]) -> Set[str]: - if isinstance(errors_, list): - errors_ = may_v1._regenerate_error_with_loc(errors=errors_, loc_prefix=()) - - return {err["loc"][2] for err in errors_ if len(err["loc"]) >= 3} - - def _validate_value_with_model_field( *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...] ) -> Tuple[Any, List[Any]]: @@ -718,19 +738,6 @@ def _validate_value_with_model_field( else: return deepcopy(field.default), [] v_, errors_ = field.validate(value, values, loc=loc) - if ( - errors_ - and isinstance(field.field_info, params.Query) - and isinstance(value, Mapping) - and (is_scalar_sequence_mapping_field(field) or is_scalar_mapping_field(field)) - ): - # Remove failing keys from the dict and try to re-validate - invalid_keys = _extract_error_locs(errors_) - v_, errors_ = field.validate( - {k: v for k, v in value.items() if k not in invalid_keys}, - values, - loc=loc, - ) if _is_error_wrapper(errors_): # type: ignore[arg-type] return None, [errors_] diff --git a/tests/test_tutorial/test_query_params/test_tutorial007_py310.py b/tests/test_tutorial/test_query_params/test_tutorial007_py310.py index b275dacd8..235992399 100644 --- a/tests/test_tutorial/test_query_params/test_tutorial007_py310.py +++ b/tests/test_tutorial/test_query_params/test_tutorial007_py310.py @@ -15,7 +15,21 @@ def test_foo_needy_very(client: TestClient): assert response.status_code == 200 assert response.json() == { "query": 2, - "string_mapping": {"foo": "baz"}, - "mapping_query_int": {}, - "sequence_mapping_queries": {}, + "mapping_query_str_or_int": {"foo": "baz"}, + "mapping_query_int": None, + "sequence_mapping_int": None, } + + +def test_just_string_not_scalar_mapping(client: TestClient): + response = client.get( + "/query/mixed-type-params?&query=2&foo=1&bar=3&foo=2&foo=baz" + ) + assert response.status_code == 200 + assert response.json() == { + "query": 2, + "mapping_query_str_or_int": {"bar": "3", "foo": "baz"}, + "mapping_query_int": {"bar": 3}, + "sequence_mapping_int": {"bar": [3], "foo": [1, 2]}, + } +