From 45825d7d118a845bb991af55492d610ce9c57c04 Mon Sep 17 00:00:00 2001 From: JONEMI21 Date: Mon, 10 Nov 2025 08:47:21 +0000 Subject: [PATCH] omit by default query params --- fastapi/_compat/__init__.py | 1 + fastapi/_compat/main.py | 19 ++++++ fastapi/dependencies/utils.py | 16 ++++- tests/main.py | 28 +-------- tests/test_omit_by_default.py | 111 ---------------------------------- tests/test_query.py | 27 ++------- 6 files changed, 41 insertions(+), 161 deletions(-) delete mode 100644 tests/test_omit_by_default.py diff --git a/fastapi/_compat/__init__.py b/fastapi/_compat/__init__.py index 3df9175c4..b3438dc15 100644 --- a/fastapi/_compat/__init__.py +++ b/fastapi/_compat/__init__.py @@ -28,6 +28,7 @@ from .main import is_scalar_mapping_field as is_scalar_mapping_field from .main import is_scalar_sequence_field as is_scalar_sequence_field from .main import is_scalar_sequence_mapping_field as is_scalar_sequence_mapping_field from .main import is_sequence_field as is_sequence_field +from .main import omit_by_default as omit_by_default from .main import serialize_sequence_value as serialize_sequence_value from .main import ( with_info_plain_validator_function as with_info_plain_validator_function, diff --git a/fastapi/_compat/main.py b/fastapi/_compat/main.py index 5a36d887a..f7fdbc856 100644 --- a/fastapi/_compat/main.py +++ b/fastapi/_compat/main.py @@ -384,3 +384,22 @@ def _is_model_class(value: Any) -> bool: return lenient_issubclass(value, v2.BaseModel) # type: ignore[attr-defined] return False + + +def omit_by_default(annotation): + from typing import Union + + from pydantic import OnErrorOmit + + origin = getattr(annotation, "__origin__", None) + args = getattr(annotation, "__args__", ()) + + if origin is Union: + new_args = tuple(omit_by_default(arg) for arg in args) + return Union[new_args] + elif origin in (list, List): + return List[omit_by_default(args[0])] + elif origin in (dict, Dict): + return Dict[args[0], omit_by_default(args[1])] + else: + return OnErrorOmit[annotation] diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index b596b819f..56be94409 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -45,11 +45,16 @@ from fastapi._compat import ( is_uploadfile_sequence_annotation, lenient_issubclass, may_v1, + omit_by_default, sequence_types, serialize_sequence_value, value_is_sequence, ) -from fastapi._compat.shared import annotation_is_pydantic_v1 +from fastapi._compat.shared import ( + annotation_is_pydantic_v1, + field_annotation_is_scalar_mapping, + field_annotation_is_scalar_sequence_mapping, +) from fastapi.background import BackgroundTasks from fastapi.concurrency import ( asynccontextmanager, @@ -500,6 +505,12 @@ def analyze_param( field_info.alias = alias + if hasattr(field_info, "annotation") and ( + field_annotation_is_scalar_sequence_mapping(field_info.annotation) + or field_annotation_is_scalar_mapping(field_info.annotation) + ): + field_info.annotation = omit_by_default(field_info.annotation) + field = create_model_field( name=param_name, type_=use_annotation_from_field_info, @@ -833,7 +844,8 @@ def request_params_to_args( errors.extend(errors_) else: values[field.name] = v_ - # remove keys which were captured by a mapping query field but were otherwise specified + # remove keys which were captured by a mapping query field but were + # specified as individual fields for field in fields: if isinstance(values.get(field.name), dict) and ( is_scalar_mapping_field(field) or is_scalar_sequence_mapping_field(field) diff --git a/tests/main.py b/tests/main.py index fe042e870..25646bce3 100644 --- a/tests/main.py +++ b/tests/main.py @@ -194,13 +194,10 @@ def get_mapping_query_params(queries: Dict[str, str] = Query({})): return {"queries": queries} -from pydantic import OnErrorOmit - - @app.get("/query/mixed-params") def get_mixed_mapping_query_params( - sequence_mapping_queries: Dict[str, List[Union[str, OnErrorOmit[int]]]] = Query({}), - mapping_query: Dict[str, str] = Query(), + sequence_mapping_queries: Dict[str, List[Union[int]]] = Query({}), + mapping_query: Dict[str, int] = Query(), query: str = Query(), ): return { @@ -213,29 +210,10 @@ def get_mixed_mapping_query_params( @app.get("/query/mapping-sequence-params") -def get_sequence_mapping_query_params( - queries: Dict[str, List[OnErrorOmit[int]]] = Query({}), -): +def get_sequence_mapping_query_params(queries: Dict[str, List[int]] = Query({})): return {"queries": queries} -@app.get("/query/mixed-type-params") -def get_mixed_mapping_mixed_type_query_params( - sequence_mapping_queries: Dict[str, List[OnErrorOmit[int]]] = Query({}), - mapping_query_str: Dict[str, OnErrorOmit[str]] = Query({}), - mapping_query_int: Dict[str, OnErrorOmit[int]] = Query({}), - query: int = Query(), -): - return { - "queries": { - "query": query, - "mapping_query_str": mapping_query_str, - "mapping_query_int": mapping_query_int, - "sequence_mapping_queries": sequence_mapping_queries, - } - } - - @app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED) def get_enum_status_code(): return "foo bar" diff --git a/tests/test_omit_by_default.py b/tests/test_omit_by_default.py deleted file mode 100644 index 97be98b5e..000000000 --- a/tests/test_omit_by_default.py +++ /dev/null @@ -1,111 +0,0 @@ -import sys -from typing import Dict, List, Optional, Union - -import pytest -from pydantic import OnErrorOmit -from typing_extensions import Annotated - - -def omit_by_default(annotation): - """A simplified version of the omit_by_default function for testing purposes.""" - origin = getattr(annotation, "__origin__", None) - args = getattr(annotation, "__args__", ()) - - if origin is Annotated: - new_args = (omit_by_default(args[0]),) + args[1:] - return Annotated[new_args[0], *new_args[1:]] - elif origin is Union: - new_args = tuple(omit_by_default(arg) for arg in args) - return Union[new_args] - elif origin in (list, List): - return List[omit_by_default(args[0])] - elif origin in (dict, Dict): - return Dict[args[0], omit_by_default(args[1])] - else: - return OnErrorOmit[annotation] - - -def test_omit_by_default_simple_type(): - result = omit_by_default(int) - assert result == OnErrorOmit[int] - - -def test_omit_by_default_union(): - result = omit_by_default(Union[int, str]) - assert result == Union[OnErrorOmit[int], OnErrorOmit[str]] - - -def test_omit_by_default_optional(): - result = omit_by_default(Optional[int]) - assert result == Union[OnErrorOmit[int], OnErrorOmit[type(None)]] - - -def test_omit_by_default_annotated(): - result = omit_by_default(Annotated[int, "metadata"]) - origin = result.__origin__ if hasattr(result, "__origin__") else None - assert origin is Annotated - args = result.__args__ if hasattr(result, "__args__") else () - assert len(args) == 2 - assert args[0] == OnErrorOmit[int] - assert args[1] == "metadata" - - -def test_omit_by_default_annotated_union(): - result = omit_by_default(Annotated[Union[int, str], "metadata"]) - origin = result.__origin__ if hasattr(result, "__origin__") else None - assert origin is Annotated - args = result.__args__ if hasattr(result, "__args__") else () - assert len(args) == 2 - assert args[0] == Union[OnErrorOmit[int], OnErrorOmit[str]] - assert args[1] == "metadata" - - -def test_omit_by_default_list(): - result = omit_by_default(List[int]) - assert result == List[OnErrorOmit[int]] - - -def test_omit_by_default_dict(): - result = omit_by_default(Dict[str, int]) - assert result == Dict[str, OnErrorOmit[int]] - - -def test_omit_by_default_nested_union(): - result = omit_by_default(Union[int, Union[str, float]]) - assert result == Union[OnErrorOmit[int], OnErrorOmit[Union[str, float]]] - - -def test_omit_by_default_annotated_with_multiple_metadata(): - result = omit_by_default(Annotated[str, "meta1", "meta2"]) - origin = result.__origin__ if hasattr(result, "__origin__") else None - assert origin is Annotated - args = result.__args__ if hasattr(result, "__args__") else () - assert len(args) == 3 - assert args[0] == OnErrorOmit[str] - assert args[1] == "meta1" - assert args[2] == "meta2" - - -@pytest.mark.skipif( - sys.version_info < (3, 10), reason="Union type syntax requires Python 3.10+" -) -def test_omit_by_default_pipe_union(): - annotation = eval("int | str") - result = omit_by_default(annotation) - assert result == Union[OnErrorOmit[int], OnErrorOmit[str]] - - -def test_omit_by_default_complex_nested(): - result = omit_by_default(Annotated[Union[int, Optional[str]], "metadata"]) - origin = result.__origin__ if hasattr(result, "__origin__") else None - assert origin is Annotated - args = result.__args__ if hasattr(result, "__args__") else () - assert len(args) == 2 - expected_union = Union[OnErrorOmit[int], OnErrorOmit[Union[str, type(None)]]] - assert args[0] == expected_union - assert args[1] == "metadata" - - -def test_omit_by_default_dict_with_union_value(): - result = omit_by_default(Dict[str, Union[int, str]]) - assert result == Dict[str, Union[OnErrorOmit[int], OnErrorOmit[str]]] diff --git a/tests/test_query.py b/tests/test_query.py index 1a1fb6146..102a8a29c 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -434,34 +434,15 @@ def test_sequence_mapping_query(): def test_mapping_with_non_mapping_query(): - response = client.get("/query/mixed-params/?foo=fuzz&foo=baz&bar=buzz&query=fizz") + response = client.get("/query/mixed-params/?foo=1&foo=2&bar=3&query=fizz") assert response.status_code == 200 assert response.json() == { "queries": { "query": "fizz", - "mapping_query": {"foo": "baz", "bar": "buzz"}, + "mapping_query": {"foo": 2, "bar": 3}, "sequence_mapping_queries": { - "foo": ["fuzz", "baz"], - "bar": ["buzz"], + "foo": [1, 2], + "bar": [3], }, } } - - -def test_mapping_with_non_mapping_query_mixed_types(): - response = client.get("/query/mixed-type-params/?foo=fuzz&foo=baz&bar=buzz&query=1") - assert response.status_code == 200 - assert response.json() == { - "queries": { - "query": 1, - "mapping_query_int": {}, - "mapping_query_str": {"bar": "buzz", "foo": "baz"}, - "sequence_mapping_queries": {"bar": [], "foo": []}, - } - } - - -def test_sequence_mapping_query_drops_invalid(): - response = client.get("/query/mapping-sequence-params/?foo=fuzz&foo=buzz") - assert response.status_code == 200 - assert response.json() == {"queries": {"foo": []}}