From 378d30d18b36c20ac39b6e920caeb078e538b9c1 Mon Sep 17 00:00:00 2001 From: JONEMI19 Date: Tue, 14 Mar 2023 07:13:26 +0000 Subject: [PATCH] handle Map[scalar, List[scalar]] and Map[scalar, scalar] sepaeratly --- fastapi/dependencies/utils.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 01fe85c2b..5002e1a05 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -257,9 +257,22 @@ def is_scalar_mapping_field(field: ModelField) -> bool: if field.sub_fields is None: return True for sub_field in field.sub_fields: - if (not is_scalar_field(sub_field)) and ( - not is_scalar_sequence_field(sub_field) - ): + if not is_scalar_field(sub_field): + return False + return True + if lenient_issubclass(field.type_, mapping_types): + return True + return False + + +def is_scalar_sequence_mapping_field(field: ModelField) -> bool: + if (field.shape in mapping_shapes) and not lenient_issubclass( + field.type_, BaseModel + ): + if field.sub_fields is None: + return True + for sub_field in field.sub_fields: + if not is_scalar_sequence_field(sub_field): return False return True if lenient_issubclass(field.type_, mapping_types): @@ -349,6 +362,7 @@ def get_dependant( elif isinstance(param.default, (params.Query, params.Header)) and ( is_scalar_sequence_field(param_field) or is_scalar_mapping_field(param_field) + or is_scalar_sequence_mapping_field(param_field) ): add_param_to_fields(field=param_field, dependant=dependant) else: @@ -628,10 +642,15 @@ def request_params_to_args( value = received_params.getlist(field.alias) or field.default elif is_scalar_mapping_field(field) and isinstance( received_params, (QueryParams,) + ): + value = dict(received_params.multi_items()) + elif is_scalar_sequence_mapping_field(field) and isinstance( + received_params, (QueryParams,) ): value = defaultdict(list) for key, field_value in received_params.multi_items(): value[key].append(field_value) + value = dict(value) else: value = received_params.get(field.alias) field_info = field.field_info