diff --git a/docs_src/query_params/tutorial007_py310.py b/docs_src/query_params/tutorial007_py310.py index a88600d5d..44939b620 100644 --- a/docs_src/query_params/tutorial007_py310.py +++ b/docs_src/query_params/tutorial007_py310.py @@ -1,4 +1,4 @@ -from typing import Annotated, Dict, List, Union +from typing import Annotated, Dict, List from fastapi import FastAPI, Query @@ -8,15 +8,13 @@ app = FastAPI() @app.get("/query/mixed-type-params") def get_mixed_mapping_mixed_type_query_params( query: Annotated[int, Query()] = None, - mapping_query_str_or_int: Annotated[ - Union[Dict[str, str], Dict[str, int]], Query() - ] = None, + mapping_query_str: Annotated[Dict[str, str], Query()] = None, mapping_query_int: Annotated[Dict[str, int], Query()] = None, sequence_mapping_int: Annotated[Dict[str, List[int]], Query()] = None, ): return { "query": query, - "mapping_query_str_or_int": mapping_query_str_or_int, + "mapping_query_str": mapping_query_str, "mapping_query_int": mapping_query_int, "sequence_mapping_int": sequence_mapping_int, } diff --git a/fastapi/_compat/shared.py b/fastapi/_compat/shared.py index 84b72adf4..7643b09c3 100644 --- a/fastapi/_compat/shared.py +++ b/fastapi/_compat/shared.py @@ -148,6 +148,8 @@ def field_annotation_is_scalar_mapping( annotation: Union[Type[Any], None], ) -> bool: origin = get_origin(annotation) + if origin is Annotated: + return field_annotation_is_scalar_mapping(get_args(annotation)[0]) if origin is Union or origin is UnionType: at_least_one_scalar_mapping = False for arg in get_args(annotation): @@ -167,6 +169,8 @@ def field_annotation_is_scalar_sequence_mapping( annotation: Union[Type[Any], None], ) -> bool: origin = get_origin(annotation) + if origin is Annotated: + return field_annotation_is_scalar_sequence_mapping(get_args(annotation)[0]) if origin is Union or origin is UnionType: at_least_one_scalar_mapping = False for arg in get_args(annotation): diff --git a/fastapi/_compat/v1.py b/fastapi/_compat/v1.py index c080f9104..fbbec552c 100644 --- a/fastapi/_compat/v1.py +++ b/fastapi/_compat/v1.py @@ -1,4 +1,4 @@ -from copy import copy +from copy import copy, deepcopy from dataclasses import dataclass, is_dataclass from enum import Enum from typing import ( @@ -350,6 +350,54 @@ def get_model_fields(model: Type[BaseModel]) -> List[ModelField]: return list(model.__fields__.values()) # type: ignore[attr-defined] -def omit_by_default(field_info: FieldInfo) -> FieldInfo: +def ignore_invalid(cls, v, values, field, **kwargs) -> Any: + from .may_v1 import _regenerate_error_with_loc + + field_copy = deepcopy(field) + field_copy.pre_validators = [ + validator + for validator in field_copy.pre_validators + if getattr(validator, "__name__", "") != "ignore_invalid" + ] + v, errors = field_copy.validate(v, values, loc=field.name) + if not errors: + return v + + # pop the keys or elements that caused the validation errors and revalidate + for error in _regenerate_error_with_loc(errors=errors, loc_prefix=()): + loc = error["loc"][1:] + if len(loc) == 0: + continue + if isinstance(loc[0], int) and isinstance(v, list): + index = loc[0] + if 0 <= index < len(v): + v[index] = None + + # Handle nested list validation errors (e.g., dict[str, list[str]]) + elif isinstance(loc[0], str) and isinstance(v, dict): + key = loc[0] + if ( + len(loc) > 1 + and isinstance(loc[1], int) + and key in v + and isinstance(v[key], list) + ): + list_index = loc[1] + v[key][list_index] = None + elif key in v: + v.pop(key) + + if isinstance(v, list): + v = [el for el in v if el is not None] + + if isinstance(v, dict): + for key in v.keys(): + if isinstance(v[key], list): + v[key] = [el for el in v[key] if el is not None] + + return v + + +def omit_by_default(field_info: FieldInfo) -> tuple[FieldInfo, dict]: """add a wrap validator to omit invalid values by default.""" - raise NotImplementedError("This function is a placeholder in Pydantic v1.") + return field_info, {"ignore_invalid": Validator(ignore_invalid, pre=True)} diff --git a/fastapi/_compat/v2.py b/fastapi/_compat/v2.py index 2a9a82df1..5f2815ac7 100644 --- a/fastapi/_compat/v2.py +++ b/fastapi/_compat/v2.py @@ -507,13 +507,12 @@ if shared.PYDANTIC_VERSION_MINOR_TUPLE >= (2, 6): else: return OnErrorOmit[annotation] - def omit_by_default(field_info: FieldInfo) -> FieldInfo: - """Set omit by default on a FieldInfo's annotation.""" + def omit_by_default(field_info: FieldInfo) -> tuple[FieldInfo, dict]: new_annotation = _omit_by_default(field_info.annotation) new_field_info = copy_field_info( field_info=field_info, annotation=new_annotation ) - return new_field_info + return new_field_info, {} else: @@ -555,9 +554,9 @@ else: return handler(v) - def omit_by_default(field_info: FieldInfo) -> FieldInfo: + def omit_by_default(field_info: FieldInfo) -> tuple[FieldInfo, dict]: """add a wrap validator to omit invalid values by default.""" field_info.metadata = field_info.metadata or [] + [ WrapValidator(ignore_invalid) ] - return field_info + return field_info, {} diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 42e6ee9aa..b49998ba7 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -506,11 +506,12 @@ def analyze_param( field_info.alias = alias # Omit by default for scalar mapping and scalar sequence mapping query fields - if isinstance(field_info, (params.Query)) and ( - field_annotation_is_scalar_sequence_mapping(field_info.annotation) - or field_annotation_is_scalar_mapping(field_info.annotation) + class_validators: dict[str, list[Any]] = {} + if isinstance(field_info, (params.Query, temp_pydantic_v1_params.Query)) and ( + field_annotation_is_scalar_sequence_mapping(use_annotation_from_field_info) + or field_annotation_is_scalar_mapping(use_annotation_from_field_info) ): - field_info = omit_by_default(field_info) + field_info, class_validators = omit_by_default(field_info) field = create_model_field( name=param_name, @@ -520,6 +521,7 @@ def analyze_param( required=field_info.default in (RequiredParam, may_v1.RequiredParam, Undefined), field_info=field_info, + class_validators=class_validators, ) if is_path_param: assert is_scalar_field(field=field), ( @@ -529,8 +531,8 @@ def analyze_param( assert ( is_scalar_field(field) or is_scalar_sequence_field(field) - or is_scalar_sequence_mapping_field(field) or is_scalar_mapping_field(field) + or is_scalar_sequence_mapping_field(field) or ( _is_model_class(field.type_) # For Pydantic v1 diff --git a/tests/test_application.py b/tests/test_application.py index cd1713ae1..e1ce9413f 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1219,14 +1219,7 @@ def test_openapi_schema(): "schema": { "additionalProperties": { "items": { - "anyOf": [ - { - "type": "string", - }, - { - "type": "integer", - }, - ], + "type": "integer", }, "type": "array", }, diff --git a/tests/test_tutorial/test_query_params/test_tutorial007_py310.py b/tests/test_tutorial/test_query_params/test_tutorial007_py310.py index 44cf39ab7..7a39bddc3 100644 --- a/tests/test_tutorial/test_query_params/test_tutorial007_py310.py +++ b/tests/test_tutorial/test_query_params/test_tutorial007_py310.py @@ -1,10 +1,6 @@ import pytest -from fastapi._compat import PYDANTIC_V2 from fastapi.testclient import TestClient -if not PYDANTIC_V2: - pytest.skip("This test is only for Pydantic v2", allow_module_level=True) - @pytest.fixture(name="client") def get_client(): @@ -19,7 +15,7 @@ def test_foo_needy_very(client: TestClient): assert response.status_code == 200 assert response.json() == { "query": 2, - "mapping_query_str_or_int": {"foo": "baz"}, + "mapping_query_str": {"foo": "baz"}, "mapping_query_int": {}, "sequence_mapping_int": {"foo": []}, } @@ -30,7 +26,7 @@ def test_just_string_not_scalar_mapping(client: TestClient): assert response.status_code == 200 assert response.json() == { "query": 2, - "mapping_query_str_or_int": {"bar": "3", "foo": "baz"}, + "mapping_query_str": {"bar": "3", "foo": "baz"}, "mapping_query_int": {"bar": 3}, "sequence_mapping_int": {"bar": [3], "foo": [1, 2]}, }