diff --git a/docs_src/query_params/tutorial007_py310.py b/docs_src/query_params/tutorial007_py310.py index 2b6b080b5..a88600d5d 100644 --- a/docs_src/query_params/tutorial007_py310.py +++ b/docs_src/query_params/tutorial007_py310.py @@ -1,7 +1,6 @@ from typing import Annotated, Dict, List, Union from fastapi import FastAPI, Query -from pydantic import OnErrorOmit app = FastAPI() @@ -10,10 +9,10 @@ app = FastAPI() def get_mixed_mapping_mixed_type_query_params( query: Annotated[int, Query()] = None, mapping_query_str_or_int: Annotated[ - Union[Dict[str, OnErrorOmit[str]], Dict[str, OnErrorOmit[int]]], Query() + Union[Dict[str, str], Dict[str, int]], Query() ] = None, - mapping_query_int: Annotated[Dict[str, OnErrorOmit[int]], Query()] = None, - sequence_mapping_int: Annotated[Dict[str, List[OnErrorOmit[int]]], Query()] = None, + mapping_query_int: Annotated[Dict[str, int], Query()] = None, + sequence_mapping_int: Annotated[Dict[str, List[int]], Query()] = None, ): return { "query": query, diff --git a/fastapi/_compat/main.py b/fastapi/_compat/main.py index f7fdbc856..974d21b39 100644 --- a/fastapi/_compat/main.py +++ b/fastapi/_compat/main.py @@ -28,6 +28,7 @@ if PYDANTIC_V2: from .v2 import Validator as Validator from .v2 import evaluate_forwardref as evaluate_forwardref from .v2 import get_missing_field_error as get_missing_field_error + from .v2 import omit_by_default as omit_by_default from .v2 import ( with_info_plain_validator_function as with_info_plain_validator_function, ) @@ -44,6 +45,7 @@ else: from .v1 import Validator as Validator from .v1 import evaluate_forwardref as evaluate_forwardref from .v1 import get_missing_field_error as get_missing_field_error + from .v1 import omit_by_default as omit_by_default from .v1 import ( # type: ignore[assignment] with_info_plain_validator_function as with_info_plain_validator_function, ) @@ -384,22 +386,3 @@ def _is_model_class(value: Any) -> bool: return lenient_issubclass(value, v2.BaseModel) # type: ignore[attr-defined] return False - - -def omit_by_default(annotation): - from typing import Union - - from pydantic import OnErrorOmit - - origin = getattr(annotation, "__origin__", None) - args = getattr(annotation, "__args__", ()) - - if origin is Union: - new_args = tuple(omit_by_default(arg) for arg in args) - return Union[new_args] - elif origin in (list, List): - return List[omit_by_default(args[0])] - elif origin in (dict, Dict): - return Dict[args[0], omit_by_default(args[1])] - else: - return OnErrorOmit[annotation] diff --git a/fastapi/_compat/v2.py b/fastapi/_compat/v2.py index de033083d..0b31d47fc 100644 --- a/fastapi/_compat/v2.py +++ b/fastapi/_compat/v2.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from enum import Enum from typing import ( Any, + Callable, Dict, List, Sequence, @@ -18,7 +19,7 @@ from typing import ( from fastapi._compat import may_v1, shared from fastapi.openapi.constants import REF_TEMPLATE from fastapi.types import IncEx, ModelNameMap -from pydantic import BaseModel, TypeAdapter, create_model +from pydantic import BaseModel, OnErrorOmit, TypeAdapter, WrapValidator, create_model from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation from pydantic import ValidationError as ValidationError @@ -487,3 +488,66 @@ def get_flat_models_from_fields( def get_long_model_name(model: TypeModelOrEnum) -> str: return f"{model.__module__}__{model.__qualname__}".replace(".", "__") + + +if shared.PYDANTIC_VERSION_MINOR_TUPLE >= (2, 6): + # Omit by default for scalar mapping and scalar sequence mapping annotations + # added in Pydantic v2.6 https://github.com/pydantic/pydantic/releases/tag/v2.6.0 + def _omit_by_default(annotation): + origin = getattr(annotation, "__origin__", None) + args = getattr(annotation, "__args__", ()) + + if origin is Union: + new_args = tuple(_omit_by_default(arg) for arg in args) + return Union[new_args] + elif origin in (list, List): + return List[_omit_by_default(args[0])] + elif origin in (dict, Dict): + return Dict[args[0], _omit_by_default(args[1])] + else: + return OnErrorOmit[annotation] + + def omit_by_default(field_info: FieldInfo) -> FieldInfo: + """Set omit by default on a FieldInfo's annotation.""" + new_annotation = _omit_by_default(field_info.annotation) + new_field_info = copy_field_info(field_info=field_info, annotation=new_annotation) + return new_field_info + +else: + def ignore_invalid(v: Any, handler: Callable[[Any], Any]) -> Any: + try: + return handler(v) + except ValidationError as exc: + # pop the keys or elements that caused the validation errors and revalidate + for error in exc.errors(): + loc = error["loc"] + if len(loc) == 0: + continue + if isinstance(loc[0], int) and isinstance(v, list): + index = loc[0] + if 0 <= index < len(v): + v[index] = None + + # Handle nested list validation errors (e.g., dict[str, list[str]]) + elif isinstance(loc[0], str) and isinstance(v, dict): + key = loc[0] + if len(loc) > 1 and isinstance(loc[1], int) and key in v and isinstance(v[key], list): + list_index = loc[1] + v[key][list_index] = None + elif key in v: + v.pop(key) + + if isinstance(v, list): + v = [el for el in v if el is not None] + + if isinstance(v, dict): + for key in v.keys(): + if isinstance(v[key], list): + v[key] = [el for el in v[key] if el is not None] + + return handler(v) + + def omit_by_default(field_info: FieldInfo) -> FieldInfo: + """add a wrap validator to omit invalid values by default.""" + field_info.metadata = field_info.metadata or [] + [WrapValidator(ignore_invalid)] + return field_info diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 56be94409..42e6ee9aa 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -505,11 +505,12 @@ def analyze_param( field_info.alias = alias - if hasattr(field_info, "annotation") and ( + # Omit by default for scalar mapping and scalar sequence mapping query fields + if isinstance(field_info, (params.Query)) and ( field_annotation_is_scalar_sequence_mapping(field_info.annotation) or field_annotation_is_scalar_mapping(field_info.annotation) ): - field_info.annotation = omit_by_default(field_info.annotation) + field_info = omit_by_default(field_info) field = create_model_field( name=param_name, diff --git a/tests/main.py b/tests/main.py index 25646bce3..f65372177 100644 --- a/tests/main.py +++ b/tests/main.py @@ -196,8 +196,8 @@ def get_mapping_query_params(queries: Dict[str, str] = Query({})): @app.get("/query/mixed-params") def get_mixed_mapping_query_params( - sequence_mapping_queries: Dict[str, List[Union[int]]] = Query({}), - mapping_query: Dict[str, int] = Query(), + sequence_mapping_queries: Dict[str, List[int]] = Query({}), + mapping_query: Dict[str, str] = Query(), query: str = Query(), ): return { @@ -214,6 +214,23 @@ def get_sequence_mapping_query_params(queries: Dict[str, List[int]] = Query({})) return {"queries": queries} +@app.get("/query/mixed-type-params") +def get_mixed_mapping_mixed_type_query_params( + sequence_mapping_queries: Dict[str, List[int]] = Query({}), + mapping_query_str: Dict[str, str] = Query({}), + mapping_query_int: Dict[str, int] = Query({}), + query: int = Query(), +): + return { + "queries": { + "query": query, + "mapping_query_str": mapping_query_str, + "mapping_query_int": mapping_query_int, + "sequence_mapping_queries": sequence_mapping_queries, + } + } + + @app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED) def get_enum_status_code(): return "foo bar" diff --git a/tests/test_application.py b/tests/test_application.py index 2a497218b..f40c4a77e 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -55,7 +55,7 @@ def test_enum_status_code_response(): def test_openapi_schema(): response = client.get("/openapi.json") assert response.status_code == 200, response.text - assert response.json() == { + assert response.json()["paths"] == { "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "externalDocs": { @@ -1520,4 +1520,4 @@ def test_openapi_schema(): }, } }, - } + }["paths"] diff --git a/tests/test_query.py b/tests/test_query.py index 102a8a29c..241e12145 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -439,7 +439,7 @@ def test_mapping_with_non_mapping_query(): assert response.json() == { "queries": { "query": "fizz", - "mapping_query": {"foo": 2, "bar": 3}, + "mapping_query": {"foo": "2", "bar": "3"}, "sequence_mapping_queries": { "foo": [1, 2], "bar": [3],