diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index a982b071a..a881312d7 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -37,6 +37,7 @@ from pydantic.errors import MissingError from pydantic.fields import ( SHAPE_FROZENSET, SHAPE_LIST, + SHAPE_MAPPING, SHAPE_SEQUENCE, SHAPE_SET, SHAPE_SINGLETON, @@ -74,6 +75,9 @@ sequence_shape_to_type = { SHAPE_TUPLE_ELLIPSIS: list, } +mapping_shapes = {SHAPE_MAPPING} +mapping_types = dict +mapping_shapes_to_type = {SHAPE_MAPPING: dict} multipart_not_installed_error = ( 'Form data requires "python-multipart" to be installed. \n' @@ -245,6 +249,20 @@ def is_scalar_sequence_field(field: ModelField) -> bool: return False +def is_scalar_mapping_field(field: ModelField) -> bool: + if (field.shape in mapping_shapes) and not lenient_issubclass( + field.type_, BaseModel + ): + if field.sub_fields is not None: + for sub_field in field.sub_fields: + if not is_scalar_field(sub_field): + return False + return True + if lenient_issubclass(field.type_, mapping_types): + return True + return False + + def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: signature = inspect.signature(call) globalns = getattr(call, "__globals__", {}) @@ -324,9 +342,10 @@ def get_dependant( add_param_to_fields(field=param_field, dependant=dependant) elif is_scalar_field(field=param_field): add_param_to_fields(field=param_field, dependant=dependant) - elif isinstance( - param.default, (params.Query, params.Header) - ) and is_scalar_sequence_field(param_field): + elif isinstance(param.default, (params.Query, params.Header)) and ( + is_scalar_sequence_field(param_field) + or is_scalar_mapping_field(param_field) + ): add_param_to_fields(field=param_field, dependant=dependant) else: field_info = param_field.field_info @@ -603,6 +622,10 @@ def request_params_to_args( received_params, (QueryParams, Headers) ): value = received_params.getlist(field.alias) or field.default + elif is_scalar_mapping_field(field) and isinstance( + received_params, (QueryParams,) + ): + value = received_params._dict else: value = received_params.get(field.alias) field_info = field.field_info diff --git a/tests/main.py b/tests/main.py index fce665704..3bc4150cc 100644 --- a/tests/main.py +++ b/tests/main.py @@ -189,6 +189,11 @@ def get_query_param_required_type(query: int = Query()): return f"foo bar {query}" +@app.get("/query/params") +def get_query_params(queries: Dict[str, int] = Query({})): + return f"foo bar {queries}" + + @app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED) def get_enum_status_code(): return "foo bar" diff --git a/tests/test_application.py b/tests/test_application.py index b7d72f9ad..c412515db 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1078,6 +1078,41 @@ openapi_schema = { ], } }, + "/query/params": { + "get": { + "summary": "Get Query Params", + "operationId": "get_query_params_query_params_get", + "parameters": [ + { + "required": False, + "schema": { + "title": "Queries", + "type": "object", + "additionalProperties": {"type": "integer"}, + "default": {}, + }, + "name": "queries", + "in": "query", + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + } + }, "/enum-status-code": { "get": { "responses": { diff --git a/tests/test_invalid_sequence_param.py b/tests/test_invalid_sequence_param.py index 475786adb..f03d18fb5 100644 --- a/tests/test_invalid_sequence_param.py +++ b/tests/test_invalid_sequence_param.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple import pytest from fastapi import FastAPI, Query diff --git a/tests/test_query.py b/tests/test_query.py index 0c73eb665..b44e188cd 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -54,6 +54,11 @@ response_not_valid_int = { ("/query/param-required/int?query=50", 200, "foo bar 50"), ("/query/param-required/int?query=foo", 422, response_not_valid_int), ("/query/frozenset/?query=1&query=1&query=2", 200, "1,2"), + ( + "/query/params?first-query=1&second-query=2", + 200, + "foo bar {'first-query': 1, " "'second-query': 2}", + ), ], ) def test_get_path(path, expected_status, expected_response):