From baa5cd2ca6c67c760387c1bf47ef56037c94343c Mon Sep 17 00:00:00 2001 From: Daniyar Yeralin Date: Wed, 12 Aug 2020 14:37:07 -0400 Subject: [PATCH] Introduce mapping shapes --- fastapi/dependencies/utils.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 7c9f7e847..d09ccde8f 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -41,7 +41,7 @@ from pydantic.fields import ( SHAPE_TUPLE_ELLIPSIS, FieldInfo, ModelField, - Required, + Required, SHAPE_MAPPING, ) from pydantic.schema import get_annotation_from_field_info from pydantic.typing import ForwardRef, evaluate_forwardref @@ -69,6 +69,13 @@ sequence_shape_to_type = { SHAPE_TUPLE_ELLIPSIS: list, } +mapping_shapes = { + SHAPE_MAPPING +} +mapping_types = (dict) +mapping_shapes_to_type = { + SHAPE_MAPPING: dict +} multipart_not_installed_error = ( 'Form data requires "python-multipart" to be installed. \n' @@ -240,6 +247,20 @@ def is_scalar_sequence_field(field: ModelField) -> bool: return False +def is_scalar_mapping_field(field: ModelField) -> bool: + if (field.shape in mapping_shapes) and not lenient_issubclass( + field.type_, BaseModel + ): + if field.sub_fields is not None: + for sub_field in field.sub_fields: + if not is_scalar_field(sub_field): + return False + return True + if lenient_issubclass(field.type_, mapping_types): + return True + return False + + def get_typed_signature(call: Callable) -> inspect.Signature: signature = inspect.signature(call) globalns = getattr(call, "__globals__", {}) @@ -324,7 +345,8 @@ def get_dependant( add_param_to_fields(field=param_field, dependant=dependant) elif isinstance( param.default, (params.Query, params.Header) - ) and is_scalar_sequence_field(param_field): + ) and (is_scalar_sequence_field(param_field) + or is_scalar_mapping_field(param_field)): add_param_to_fields(field=param_field, dependant=dependant) else: field_info = param_field.field_info @@ -610,6 +632,10 @@ def request_params_to_args( received_params, (QueryParams, Headers) ): value = received_params.getlist(field.alias) or field.default + elif is_scalar_mapping_field(field) and isinstance( + received_params, (QueryParams,) + ): + value = received_params._dict else: value = received_params.get(field.alias) field_info = field.field_info