♻️ Refactor and simplify Pydantic v2 (and v1) compatibility internal utils (#14862)

This commit is contained in:
Sebastián Ramírez 2026-02-07 00:34:32 -08:00 committed by GitHub
parent 8eac94bd91
commit 2e7d3754cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 44 additions and 55 deletions

View File

@ -1,6 +1,14 @@
from .shared import PYDANTIC_VERSION_MINOR_TUPLE as PYDANTIC_VERSION_MINOR_TUPLE
from .shared import annotation_is_pydantic_v1 as annotation_is_pydantic_v1
from .shared import field_annotation_is_scalar as field_annotation_is_scalar
from .shared import (
field_annotation_is_scalar_sequence as field_annotation_is_scalar_sequence,
)
from .shared import field_annotation_is_sequence as field_annotation_is_sequence
from .shared import (
is_bytes_or_nonable_bytes_annotation as is_bytes_or_nonable_bytes_annotation,
)
from .shared import is_bytes_sequence_annotation as is_bytes_sequence_annotation
from .shared import is_pydantic_v1_model_instance as is_pydantic_v1_model_instance
from .shared import (
is_uploadfile_or_nonable_uploadfile_annotation as is_uploadfile_or_nonable_uploadfile_annotation,
@ -25,11 +33,7 @@ from .v2 import get_flat_models_from_fields as get_flat_models_from_fields
from .v2 import get_missing_field_error as get_missing_field_error
from .v2 import get_model_name_map as get_model_name_map
from .v2 import get_schema_from_model_field as get_schema_from_model_field
from .v2 import is_bytes_field as is_bytes_field
from .v2 import is_bytes_sequence_field as is_bytes_sequence_field
from .v2 import is_scalar_field as is_scalar_field
from .v2 import is_scalar_sequence_field as is_scalar_sequence_field
from .v2 import is_sequence_field as is_sequence_field
from .v2 import serialize_sequence_value as serialize_sequence_value
from .v2 import (
with_info_plain_validator_function as with_info_plain_validator_function,

View File

@ -102,18 +102,10 @@ class ModelField:
sa = self.field_info.serialization_alias
return sa or None
@property
def required(self) -> bool:
return self.field_info.is_required()
@property
def default(self) -> Any:
return self.get_default()
@property
def type_(self) -> Any:
return self.field_info.annotation
def __post_init__(self) -> None:
with warnings.catch_warnings():
# Pydantic >= 2.12.0 warns about field specific metadata that is unused
@ -267,9 +259,9 @@ def get_definitions(
for model in flat_serialization_models
]
flat_model_fields = flat_validation_model_fields + flat_serialization_model_fields
input_types = {f.type_ for f in fields}
input_types = {f.field_info.annotation for f in fields}
unique_flat_model_fields = {
f for f in flat_model_fields if f.type_ not in input_types
f for f in flat_model_fields if f.field_info.annotation not in input_types
}
inputs = [
(
@ -304,22 +296,6 @@ def is_scalar_field(field: ModelField) -> bool:
) and not isinstance(field.field_info, params.Body)
def is_sequence_field(field: ModelField) -> bool:
return shared.field_annotation_is_sequence(field.field_info.annotation)
def is_scalar_sequence_field(field: ModelField) -> bool:
return shared.field_annotation_is_scalar_sequence(field.field_info.annotation)
def is_bytes_field(field: ModelField) -> bool:
return shared.is_bytes_or_nonable_bytes_annotation(field.type_)
def is_bytes_sequence_field(field: ModelField) -> bool:
return shared.is_bytes_sequence_annotation(field.type_)
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
cls = type(field_info)
merged_field_info = cls.from_annotation(annotation)
@ -428,7 +404,7 @@ def get_flat_models_from_annotation(
def get_flat_models_from_field(
field: ModelField, known_models: TypeModelSet
) -> TypeModelSet:
field_type = field.type_
field_type = field.field_info.annotation
if lenient_issubclass(field_type, BaseModel):
if field_type in known_models:
return known_models

View File

@ -25,13 +25,13 @@ from fastapi._compat import (
create_body_model,
evaluate_forwardref,
field_annotation_is_scalar,
field_annotation_is_scalar_sequence,
field_annotation_is_sequence,
get_cached_model_fields,
get_missing_field_error,
is_bytes_field,
is_bytes_sequence_field,
is_bytes_or_nonable_bytes_annotation,
is_bytes_sequence_annotation,
is_scalar_field,
is_scalar_sequence_field,
is_sequence_field,
is_uploadfile_or_nonable_uploadfile_annotation,
is_uploadfile_sequence_annotation,
lenient_issubclass,
@ -182,8 +182,10 @@ def _get_flat_fields_from_params(fields: list[ModelField]) -> list[ModelField]:
if not fields:
return fields
first_field = fields[0]
if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
fields_to_extract = get_cached_model_fields(first_field.type_)
if len(fields) == 1 and lenient_issubclass(
first_field.field_info.annotation, BaseModel
):
fields_to_extract = get_cached_model_fields(first_field.field_info.annotation)
return fields_to_extract
return fields
@ -521,8 +523,8 @@ def analyze_param(
elif isinstance(field_info, params.Query):
assert (
is_scalar_field(field)
or is_scalar_sequence_field(field)
or lenient_issubclass(field.type_, BaseModel)
or field_annotation_is_scalar_sequence(field.field_info.annotation)
or lenient_issubclass(field.field_info.annotation, BaseModel)
), f"Query parameter {param_name!r} must be one of the supported types"
return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
@ -708,7 +710,7 @@ def _validate_value_with_model_field(
*, field: ModelField, value: Any, values: dict[str, Any], loc: tuple[str, ...]
) -> tuple[Any, list[Any]]:
if value is None:
if field.required:
if field.field_info.is_required():
return None, [get_missing_field_error(loc=loc)]
else:
return deepcopy(field.default), []
@ -725,7 +727,7 @@ def _get_multidict_value(
alias = alias or get_validation_alias(field)
if (
(not _is_json_field(field))
and is_sequence_field(field)
and field_annotation_is_sequence(field.field_info.annotation)
and isinstance(values, (ImmutableMultiDict, Headers))
):
value = values.getlist(alias)
@ -738,9 +740,12 @@ def _get_multidict_value(
and isinstance(value, str) # For type checks
and value == ""
)
or (is_sequence_field(field) and len(value) == 0)
or (
field_annotation_is_sequence(field.field_info.annotation)
and len(value) == 0
)
):
if field.required:
if field.field_info.is_required():
return
else:
return deepcopy(field.default)
@ -761,8 +766,10 @@ def request_params_to_args(
fields_to_extract = fields
single_not_embedded_field = False
default_convert_underscores = True
if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
fields_to_extract = get_cached_model_fields(first_field.type_)
if len(fields) == 1 and lenient_issubclass(
first_field.field_info.annotation, BaseModel
):
fields_to_extract = get_cached_model_fields(first_field.field_info.annotation)
single_not_embedded_field = True
# If headers are in a Pydantic model, the way to disable convert_underscores
# would be with Header(convert_underscores=False) at the Pydantic model level
@ -866,8 +873,8 @@ def _should_embed_body_fields(fields: list[ModelField]) -> bool:
# otherwise it has to be embedded, so that the key value pair can be extracted
if (
isinstance(first_field.field_info, params.Form)
and not lenient_issubclass(first_field.type_, BaseModel)
and not is_union_of_base_models(first_field.type_)
and not lenient_issubclass(first_field.field_info.annotation, BaseModel)
and not is_union_of_base_models(first_field.field_info.annotation)
):
return True
return False
@ -884,12 +891,12 @@ async def _extract_form_body(
field_info = field.field_info
if (
isinstance(field_info, params.File)
and is_bytes_field(field)
and is_bytes_or_nonable_bytes_annotation(field.field_info.annotation)
and isinstance(value, UploadFile)
):
value = await value.read()
elif (
is_bytes_sequence_field(field)
is_bytes_sequence_annotation(field.field_info.annotation)
and isinstance(field_info, params.File)
and value_is_sequence(value)
):
@ -936,10 +943,10 @@ async def request_body_to_args(
if (
single_not_embedded_field
and lenient_issubclass(first_field.type_, BaseModel)
and lenient_issubclass(first_field.field_info.annotation, BaseModel)
and isinstance(received_body, FormData)
):
fields_to_extract = get_cached_model_fields(first_field.type_)
fields_to_extract = get_cached_model_fields(first_field.field_info.annotation)
if isinstance(received_body, FormData):
body_to_process = await _extract_form_body(fields_to_extract, received_body)
@ -992,7 +999,9 @@ def get_body_field(
BodyModel = create_body_model(
fields=flat_dependant.body_params, model_name=model_name
)
required = any(True for f in flat_dependant.body_params if f.required)
required = any(
True for f in flat_dependant.body_params if f.field_info.is_required()
)
BodyFieldInfo_kwargs: dict[str, Any] = {
"annotation": BodyModel,
"alias": "body",

View File

@ -129,7 +129,7 @@ def _get_openapi_operation_parameters(
default_convert_underscores = True
if len(flat_dependant.header_params) == 1:
first_field = flat_dependant.header_params[0]
if lenient_issubclass(first_field.type_, BaseModel):
if lenient_issubclass(first_field.field_info.annotation, BaseModel):
default_convert_underscores = getattr(
first_field.field_info, "convert_underscores", True
)
@ -161,7 +161,7 @@ def _get_openapi_operation_parameters(
parameter = {
"name": name,
"in": param_type.value,
"required": param.required,
"required": param.field_info.is_required(),
"schema": param_schema,
}
if field_info.description:
@ -198,7 +198,7 @@ def get_openapi_operation_request_body(
)
field_info = cast(Body, body_field.field_info)
request_media_type = field_info.media_type
required = body_field.required
required = body_field.field_info.is_required()
request_body_oai: dict[str, Any] = {}
if required:
request_body_oai["required"] = required