From 0e835695f1d89f72f8e04ec896590e4898b9a006 Mon Sep 17 00:00:00 2001 From: JONEMI19 Date: Fri, 7 Jul 2023 20:07:03 +0000 Subject: [PATCH] add mapping types --- fastapi/_compat.py | 54 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/fastapi/_compat.py b/fastapi/_compat.py index 8fadd6ed3..2fbef2d15 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -43,6 +43,14 @@ sequence_annotation_to_type = { sequence_types = tuple(sequence_annotation_to_type.keys()) +mapping_annotation_to_type = { + Mapping: list, + List: list, +} + +mapping_types = tuple(mapping_annotation_to_type.keys()) + + if PYDANTIC_V2: from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError from pydantic import TypeAdapter @@ -229,10 +237,10 @@ if PYDANTIC_V2: return field_annotation_is_scalar_sequence(field.field_info.annotation) def is_scalar_sequence_mapping_field(field: ModelField) -> bool: - return field_annotation_is_scalar_sequence(field.field_info.annotation) + return field_annotation_is_scalar_sequence_mapping(field.field_info.annotation) def is_scalar_mapping_field(field: ModelField) -> bool: - return field_annotation_is_scalar_sequence(field.field_info.annotation) + return field_annotation_is_scalar_mapping(field.field_info.annotation) def is_bytes_field(field: ModelField) -> bool: return is_bytes_or_nonable_bytes_annotation(field.type_) @@ -559,6 +567,15 @@ def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: get_origin(annotation) ) +def _annotation_is_mapping(annotation: Union[Type[Any], None]) -> bool: + if lenient_issubclass(annotation, (str, bytes)): + return False + return lenient_issubclass(annotation, mapping_types) + +def field_annotation_is_mapping(annotation: Union[Type[Any], None]) -> bool: + return _annotation_is_mapping(annotation) or _annotation_is_sequence( + get_origin(annotation) + ) def value_is_sequence(value: Any) -> bool: return isinstance(value, sequence_types) and not isinstance(value, (str, bytes)) # type: ignore[arg-type] @@ -566,7 +583,7 @@ def value_is_sequence(value: Any) -> bool: def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: return ( - lenient_issubclass(annotation, (BaseModel, Mapping, UploadFile)) + lenient_issubclass(annotation, (BaseModel, UploadFile)) or _annotation_is_sequence(annotation) or is_dataclass(annotation) ) @@ -606,6 +623,37 @@ def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> b for sub_annotation in get_args(annotation) ) +def field_annotation_is_scalar_mapping(annotation: Union[Type[Any], None]) -> bool: + origin = get_origin(annotation) + if origin is Union or origin is UnionType: + at_least_one_scalar_mapping = False + for arg in get_args(annotation): + if field_annotation_is_scalar_mapping(arg): + at_least_one_scalar_mapping = True + continue + elif not field_annotation_is_scalar(arg): + return False + return at_least_one_scalar_mapping + return field_annotation_is_mapping(annotation) and all( + field_annotation_is_scalar(sub_annotation) + for sub_annotation in get_args(annotation) + ) + +def field_annotation_is_scalar_sequence_mapping(annotation: Union[Type[Any], None]) -> bool: + origin = get_origin(annotation) + if origin is Union or origin is UnionType: + at_least_one_scalar_mapping = False + for arg in get_args(annotation): + if field_annotation_is_scalar_mapping(arg): + at_least_one_scalar_mapping = True + continue + elif not field_annotation_is_scalar(arg): + return False + return at_least_one_scalar_mapping + return field_annotation_is_mapping(annotation) and all( + (field_annotation_is_scalar_sequence(sub_annotation) or field_annotation_is_scalar(sub_annotation)) + for sub_annotation in get_args(annotation) + ) def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool: if lenient_issubclass(annotation, bytes):