add mapping types

This commit is contained in:
JONEMI19 2023-07-07 20:07:03 +00:00
parent 168d839114
commit 0e835695f1
1 changed files with 51 additions and 3 deletions

View File

@ -43,6 +43,14 @@ sequence_annotation_to_type = {
sequence_types = tuple(sequence_annotation_to_type.keys())
mapping_annotation_to_type = {
Mapping: list,
List: list,
}
mapping_types = tuple(mapping_annotation_to_type.keys())
if PYDANTIC_V2:
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from pydantic import TypeAdapter
@ -229,10 +237,10 @@ if PYDANTIC_V2:
return field_annotation_is_scalar_sequence(field.field_info.annotation)
def is_scalar_sequence_mapping_field(field: ModelField) -> bool:
return field_annotation_is_scalar_sequence(field.field_info.annotation)
return field_annotation_is_scalar_sequence_mapping(field.field_info.annotation)
def is_scalar_mapping_field(field: ModelField) -> bool:
return field_annotation_is_scalar_sequence(field.field_info.annotation)
return field_annotation_is_scalar_mapping(field.field_info.annotation)
def is_bytes_field(field: ModelField) -> bool:
return is_bytes_or_nonable_bytes_annotation(field.type_)
@ -559,6 +567,15 @@ def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
get_origin(annotation)
)
def _annotation_is_mapping(annotation: Union[Type[Any], None]) -> bool:
if lenient_issubclass(annotation, (str, bytes)):
return False
return lenient_issubclass(annotation, mapping_types)
def field_annotation_is_mapping(annotation: Union[Type[Any], None]) -> bool:
return _annotation_is_mapping(annotation) or _annotation_is_sequence(
get_origin(annotation)
)
def value_is_sequence(value: Any) -> bool:
return isinstance(value, sequence_types) and not isinstance(value, (str, bytes)) # type: ignore[arg-type]
@ -566,7 +583,7 @@ def value_is_sequence(value: Any) -> bool:
def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
return (
lenient_issubclass(annotation, (BaseModel, Mapping, UploadFile))
lenient_issubclass(annotation, (BaseModel, UploadFile))
or _annotation_is_sequence(annotation)
or is_dataclass(annotation)
)
@ -606,6 +623,37 @@ def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> b
for sub_annotation in get_args(annotation)
)
def field_annotation_is_scalar_mapping(annotation: Union[Type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one_scalar_mapping = False
for arg in get_args(annotation):
if field_annotation_is_scalar_mapping(arg):
at_least_one_scalar_mapping = True
continue
elif not field_annotation_is_scalar(arg):
return False
return at_least_one_scalar_mapping
return field_annotation_is_mapping(annotation) and all(
field_annotation_is_scalar(sub_annotation)
for sub_annotation in get_args(annotation)
)
def field_annotation_is_scalar_sequence_mapping(annotation: Union[Type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one_scalar_mapping = False
for arg in get_args(annotation):
if field_annotation_is_scalar_mapping(arg):
at_least_one_scalar_mapping = True
continue
elif not field_annotation_is_scalar(arg):
return False
return at_least_one_scalar_mapping
return field_annotation_is_mapping(annotation) and all(
(field_annotation_is_scalar_sequence(sub_annotation) or field_annotation_is_scalar(sub_annotation))
for sub_annotation in get_args(annotation)
)
def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
if lenient_issubclass(annotation, bytes):