diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index b3d013241d..aa9ed56bed 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -40,7 +40,6 @@ from pydantic.fields import ( SHAPE_LIST, SHAPE_MAPPING, SHAPE_SEQUENCE, - SHAPE_MAPPING, SHAPE_SET, SHAPE_SINGLETON, SHAPE_TUPLE, @@ -251,6 +250,7 @@ def is_scalar_sequence_field(field: ModelField) -> bool: return True return False + def is_scalar_mapping_field(field: ModelField) -> bool: if (field.shape in mapping_shapes) and not lenient_issubclass( field.type_, BaseModel @@ -276,6 +276,7 @@ def is_scalar_sequence_mapping_field(field: ModelField) -> bool: return True return False + def is_scalar_mapping_field(field: ModelField) -> bool: if ( (field.shape in mapping_shapes) @@ -544,9 +545,10 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: param_field.field_info, (params.Query, params.Header) ) and is_scalar_sequence_field(param_field): return False - elif isinstance( - param_field.field_info, (params.Query, params.Header) - ) and (is_scalar_sequence_mapping_field(param_field) or is_scalar_mapping_field(param_field)): + elif isinstance(param_field.field_info, (params.Query, params.Header)) and ( + is_scalar_sequence_mapping_field(param_field) + or is_scalar_mapping_field(param_field) + ): return False else: assert isinstance( @@ -751,7 +753,7 @@ def request_params_to_args( ): if not len(received_params.multi_items()): value = field.default - else: + else: value = defaultdict(list) for k, v in received_params.multi_items(): value[k].append(v) diff --git a/tests/test_invalid_mapping_param.py b/tests/test_invalid_mapping_param.py index ef49a3218d..249a759d50 100644 --- a/tests/test_invalid_mapping_param.py +++ b/tests/test_invalid_mapping_param.py @@ -1,15 +1,13 @@ -from typing import Mapping, List +from typing import List, Mapping import pytest from fastapi import FastAPI, Query -from pydantic import BaseModel def test_invalid_sequence(): with pytest.raises(AssertionError): app = FastAPI() - @app.get("/items/") def read_items(q: Mapping[str, List[List[str]]] = Query(default=None)): pass # pragma: no cover