From 1a57459edaeddcf8e96830a02355696f5ab2fad1 Mon Sep 17 00:00:00 2001 From: JONEMI21 Date: Thu, 6 Nov 2025 22:04:36 +0000 Subject: [PATCH] update and remove otherwise captured query params --- fastapi/_compat/__init__.py | 2 + fastapi/_compat/main.py | 24 ++++++++++++ fastapi/_compat/shared.py | 39 +++++++++++++++++++ fastapi/_compat/v1.py | 39 +++++++++++++++++++ fastapi/_compat/v2.py | 10 +++++ fastapi/dependencies/utils.py | 37 +++++++++++++++++- tests/main.py | 24 ++++++++---- tests/test_query.py | 29 ++++++++++---- .../test_tutorial007_py310.py | 6 +-- 9 files changed, 190 insertions(+), 20 deletions(-) diff --git a/fastapi/_compat/__init__.py b/fastapi/_compat/__init__.py index 0aadd68de..3561176fe 100644 --- a/fastapi/_compat/__init__.py +++ b/fastapi/_compat/__init__.py @@ -24,7 +24,9 @@ from .main import get_schema_from_model_field as get_schema_from_model_field from .main import is_bytes_field as is_bytes_field from .main import is_bytes_sequence_field as is_bytes_sequence_field from .main import is_scalar_field as is_scalar_field +from .main import is_scalar_mapping_field as is_scalar_mapping_field from .main import is_scalar_sequence_field as is_scalar_sequence_field +from .main import is_scalar_sequence_mapping_field as is_scalar_sequence_mapping_field from .main import is_sequence_field as is_sequence_field from .main import serialize_sequence_value as serialize_sequence_value from .main import ( diff --git a/fastapi/_compat/main.py b/fastapi/_compat/main.py index e5275950e..5a36d887a 100644 --- a/fastapi/_compat/main.py +++ b/fastapi/_compat/main.py @@ -209,6 +209,30 @@ def is_sequence_field(field: ModelField) -> bool: return v2.is_sequence_field(field) # type: ignore[arg-type] +def is_scalar_mapping_field(field: ModelField) -> bool: + if isinstance(field, may_v1.ModelField): + from fastapi._compat import v1 + + return v1.is_scalar_mapping_field(field) + else: + assert PYDANTIC_V2 + from . import v2 + + return v2.is_scalar_mapping_field(field) # type: ignore[arg-type] + + +def is_scalar_sequence_mapping_field(field: ModelField) -> bool: + if isinstance(field, may_v1.ModelField): + from fastapi._compat import v1 + + return v1.is_scalar_sequence_mapping_field(field) + else: + assert PYDANTIC_V2 + from . import v2 + + return v2.is_scalar_sequence_mapping_field(field) # type: ignore[arg-type] + + def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]: if isinstance(field, may_v1.ModelField): from fastapi._compat import v1 diff --git a/fastapi/_compat/shared.py b/fastapi/_compat/shared.py index cabf48228..95730b839 100644 --- a/fastapi/_compat/shared.py +++ b/fastapi/_compat/shared.py @@ -144,6 +144,45 @@ def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> b ) +def field_annotation_is_scalar_mapping( + annotation: Union[Type[Any], None], +) -> bool: + origin = get_origin(annotation) + if origin is Union or origin is UnionType: + at_least_one_scalar_mapping = False + for arg in get_args(annotation): + if field_annotation_is_scalar_mapping(arg): + at_least_one_scalar_mapping = True + continue + elif not field_annotation_is_scalar(arg): + return False + return at_least_one_scalar_mapping + return lenient_issubclass(origin, Mapping) and all( + field_annotation_is_scalar(sub_annotation) + for sub_annotation in get_args(annotation) + ) + + +def field_annotation_is_scalar_sequence_mapping( + annotation: Union[Type[Any], None], +) -> bool: + origin = get_origin(annotation) + if origin is Union or origin is UnionType: + at_least_one_scalar_mapping = False + for arg in get_args(annotation): + if field_annotation_is_scalar_sequence_mapping(arg): + at_least_one_scalar_mapping = True + continue + elif not field_annotation_is_scalar(arg): + return False + return at_least_one_scalar_mapping + return lenient_issubclass(origin, Mapping) and all( + field_annotation_is_scalar_sequence(sub_annotation) + or field_annotation_is_scalar(sub_annotation) + for sub_annotation in get_args(annotation) + ) + + def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool: if lenient_issubclass(annotation, bytes): return True diff --git a/fastapi/_compat/v1.py b/fastapi/_compat/v1.py index e17ce8bea..31e3fe259 100644 --- a/fastapi/_compat/v1.py +++ b/fastapi/_compat/v1.py @@ -6,6 +6,7 @@ from typing import ( Callable, Dict, List, + Mapping, Sequence, Set, Tuple, @@ -84,6 +85,7 @@ else: from pydantic.v1.fields import ( SHAPE_FROZENSET, SHAPE_LIST, + SHAPE_MAPPING, SHAPE_SEQUENCE, SHAPE_SET, SHAPE_SINGLETON, @@ -144,6 +146,11 @@ sequence_shape_to_type = { SHAPE_TUPLE_ELLIPSIS: list, } +mapping_shapes = { + SHAPE_MAPPING, +} +mapping_shapes_to_type = {SHAPE_MAPPING: Mapping} + @dataclass class GenerateJsonSchema: @@ -219,6 +226,30 @@ def is_pv1_scalar_sequence_field(field: ModelField) -> bool: return False +def is_pv1_scalar_sequence_mapping_field(field: ModelField) -> bool: + if (field.shape in mapping_shapes) and not lenient_issubclass( # type: ignore[attr-defined] + field.type_, BaseModel + ): + if field.sub_fields is not None: # type: ignore[attr-defined] + for sub_field in field.sub_fields: # type: ignore[attr-defined] + if not is_scalar_sequence_field(sub_field): + return False + return True + return False + + +def is_pv1_scalar_mapping_field(field: ModelField) -> bool: + if (field.shape in mapping_shapes) and not lenient_issubclass( # type: ignore[attr-defined] + field.type_, BaseModel + ): + if field.sub_fields is not None: # type: ignore[attr-defined] + for sub_field in field.sub_fields: # type: ignore[attr-defined] + if not is_scalar_field(sub_field): + return False + return True + return False + + def _model_rebuild(model: Type[BaseModel]) -> None: model.update_forward_refs() @@ -277,6 +308,14 @@ def is_scalar_sequence_field(field: ModelField) -> bool: return is_pv1_scalar_sequence_field(field) +def is_scalar_mapping_field(field: ModelField) -> bool: + return is_pv1_scalar_mapping_field(field) + + +def is_scalar_sequence_mapping_field(field: ModelField) -> bool: + return is_pv1_scalar_sequence_mapping_field(field) + + def is_bytes_field(field: ModelField) -> bool: return lenient_issubclass(field.type_, bytes) # type: ignore[no-any-return] diff --git a/fastapi/_compat/v2.py b/fastapi/_compat/v2.py index 6a87b9ae9..de033083d 100644 --- a/fastapi/_compat/v2.py +++ b/fastapi/_compat/v2.py @@ -352,6 +352,16 @@ def is_scalar_sequence_field(field: ModelField) -> bool: return shared.field_annotation_is_scalar_sequence(field.field_info.annotation) +def is_scalar_mapping_field(field: ModelField) -> bool: + return shared.field_annotation_is_scalar_mapping(field.field_info.annotation) + + +def is_scalar_sequence_mapping_field(field: ModelField) -> bool: + return shared.field_annotation_is_scalar_sequence_mapping( + field.field_info.annotation + ) + + def is_bytes_field(field: ModelField) -> bool: return shared.is_bytes_or_nonable_bytes_annotation(field.type_) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 4a8b5cf60..a6ade4f59 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -37,7 +37,9 @@ from fastapi._compat import ( is_bytes_field, is_bytes_sequence_field, is_scalar_field, + is_scalar_mapping_field, is_scalar_sequence_field, + is_scalar_sequence_mapping_field, is_sequence_field, is_uploadfile_or_nonable_uploadfile_annotation, is_uploadfile_sequence_annotation, @@ -512,6 +514,7 @@ def analyze_param( assert ( is_scalar_field(field) or is_scalar_sequence_field(field) + or is_scalar_sequence_mapping_field(field) or ( _is_model_class(field.type_) # For Pydantic v1 @@ -707,6 +710,20 @@ def _validate_value_with_model_field( else: return deepcopy(field.default), [] v_, errors_ = field.validate(value, values, loc=loc) + if ( + errors_ + and isinstance(field.field_info, params.Query) + and isinstance(value, Mapping) + and (is_scalar_sequence_mapping_field(field) or is_scalar_mapping_field(field)) + ): + # Remove failing keys from the dict and try to re-validate + invalid_keys = {err["loc"][2] for err in errors_ if len(err["loc"]) >= 3} + v_, errors_ = field.validate( + {k: v for k, v in value.items() if k not in invalid_keys}, + values, + loc=loc, + ) + if _is_error_wrapper(errors_): # type: ignore[arg-type] return None, [errors_] elif isinstance(errors_, list): @@ -720,10 +737,19 @@ def _get_multidict_value( field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None ) -> Any: alias = alias or field.alias + value: Any = None if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)): value = values.getlist(alias) - else: - value = values.get(alias, None) + elif alias in values: + value = values[alias] + elif values and is_scalar_mapping_field(field) and isinstance(values, QueryParams): + value = dict(values) + elif ( + values + and is_scalar_sequence_mapping_field(field) + and isinstance(values, QueryParams) + ): + value = {key: values.getlist(key) for key in values.keys()} if ( value is None or ( @@ -820,6 +846,13 @@ def request_params_to_args( errors.extend(errors_) else: values[field.name] = v_ + # remove keys which were captured by a mapping query field but were otherwise specified + for field in fields: + if isinstance(values.get(field.name), dict) and ( + is_scalar_mapping_field(field) or is_scalar_sequence_mapping_field(field) + ): + for f_ in fields: + values[field.name].pop(f_.alias, None) return values, errors diff --git a/tests/main.py b/tests/main.py index 03a4de252..f0daad9ac 100644 --- a/tests/main.py +++ b/tests/main.py @@ -191,12 +191,12 @@ def get_query_param_required_type(query: int = Query()): @app.get("/query/mapping-params") def get_mapping_query_params(queries: Dict[str, str] = Query({})): - return f"foo bar {queries['foo']} {queries['bar']}" + return {"queries": queries} @app.get("/query/mapping-sequence-params") def get_sequence_mapping_query_params(queries: Dict[str, List[int]] = Query({})): - return f"foo bar {dict(queries)}" + return {"queries": queries} @app.get("/query/mixed-params") @@ -205,10 +205,13 @@ def get_mixed_mapping_query_params( mapping_query: Dict[str, str] = Query(), query: str = Query(), ): - return ( - f"foo bar {sequence_mapping_queries['foo'][0]} {sequence_mapping_queries['foo'][1]} " - f"{mapping_query['foo']} {mapping_query['bar']} {query}" - ) + return { + "queries": { + "query": query, + "mapping_query": mapping_query, + "sequence_mapping_queries": sequence_mapping_queries, + } + } @app.get("/query/mixed-type-params") @@ -218,7 +221,14 @@ def get_mixed_mapping_mixed_type_query_params( mapping_query_int: Dict[str, int] = Query({}), query: int = Query(), ): - return f"foo bar {query} {mapping_query_str} {mapping_query_int} {dict(sequence_mapping_queries)}" + return { + "queries": { + "query": query, + "mapping_query_str": mapping_query_str, + "mapping_query_int": mapping_query_int, + "sequence_mapping_queries": sequence_mapping_queries, + } + } @app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED) diff --git a/tests/test_query.py b/tests/test_query.py index 40f64b0d0..5b483f17f 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -424,31 +424,44 @@ def test_query_frozenset_query_1_query_1_query_2(): def test_mapping_query(): response = client.get("/query/mapping-params/?foo=fuzz&bar=buzz") assert response.status_code == 200 - assert response.json() == "foo bar fuzz buzz" + assert response.json() == {"queries": {"bar": "buzz", "foo": "fuzz"}} def test_mapping_with_non_mapping_query(): response = client.get("/query/mixed-params/?foo=fuzz&foo=baz&bar=buzz&query=fizz") assert response.status_code == 200 - assert response.json() == "foo bar fuzz baz baz buzz fizz" + assert response.json() == { + "queries": { + "query": "fizz", + "mapping_query": {"foo": "baz", "bar": "buzz"}, + "sequence_mapping_queries": { + "foo": ["fuzz", "baz"], + "bar": ["buzz"], + }, + } + } def test_mapping_with_non_mapping_query_mixed_types(): response = client.get("/query/mixed-type-params/?foo=fuzz&foo=baz&bar=buzz&query=1") assert response.status_code == 200 - assert ( - response.json() - == "foo bar 1 {'foo': 'baz', 'bar': 'buzz', 'query': '1'} {'query': 1} {'foo': [], 'bar': [], 'query': [1]}" - ) + assert response.json() == { + "queries": { + "query": 1, + "mapping_query_str": {"foo": "baz", "bar": "buzz"}, + "mapping_query_int": {}, + "sequence_mapping_queries": {}, + } + } def test_sequence_mapping_query(): response = client.get("/query/mapping-sequence-params/?foo=1&foo=2") assert response.status_code == 200 - assert response.json() == "foo bar {'foo': [1, 2]}" + assert response.json() == {"queries": {"foo": [1, 2]}} def test_sequence_mapping_query_drops_invalid(): response = client.get("/query/mapping-sequence-params/?foo=fuzz&foo=buzz") assert response.status_code == 200 - assert response.json() == "foo bar {'foo': []}" + assert response.json() == {"queries": {}} diff --git a/tests/test_tutorial/test_query_params/test_tutorial007_py310.py b/tests/test_tutorial/test_query_params/test_tutorial007_py310.py index a6dfc140f..b275dacd8 100644 --- a/tests/test_tutorial/test_query_params/test_tutorial007_py310.py +++ b/tests/test_tutorial/test_query_params/test_tutorial007_py310.py @@ -15,7 +15,7 @@ def test_foo_needy_very(client: TestClient): assert response.status_code == 200 assert response.json() == { "query": 2, - "string_mapping": {"query": "2", "foo": "baz"}, - "mapping_query_int": {"query": 2}, - "sequence_mapping_queries": {"query": [1, 2], "foo": []}, + "string_mapping": {"foo": "baz"}, + "mapping_query_int": {}, + "sequence_mapping_queries": {}, }