♻️ Simplify internals, remove Pydantic v1 only logic, no longer needed (#14857)

This commit is contained in:
Sebastián Ramírez 2026-02-06 11:04:24 -08:00 committed by GitHub
parent ac8362c447
commit cf55bade7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 19 additions and 112 deletions

View File

@@ -25,7 +25,6 @@ from .v2 import copy_field_info as copy_field_info
from .v2 import create_body_model as create_body_model from .v2 import create_body_model as create_body_model
from .v2 import evaluate_forwardref as evaluate_forwardref from .v2 import evaluate_forwardref as evaluate_forwardref
from .v2 import get_cached_model_fields as get_cached_model_fields from .v2 import get_cached_model_fields as get_cached_model_fields
from .v2 import get_compat_model_name_map as get_compat_model_name_map
from .v2 import get_definitions as get_definitions from .v2 import get_definitions as get_definitions
from .v2 import get_missing_field_error as get_missing_field_error from .v2 import get_missing_field_error as get_missing_field_error
from .v2 import get_schema_from_model_field as get_schema_from_model_field from .v2 import get_schema_from_model_field as get_schema_from_model_field

View File

@@ -1,7 +1,7 @@
import re import re
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from copy import copy, deepcopy from copy import copy
from dataclasses import dataclass, is_dataclass from dataclasses import dataclass, is_dataclass
from enum import Enum from enum import Enum
from functools import lru_cache from functools import lru_cache
@@ -169,11 +169,11 @@ class ModelField:
values: dict[str, Any] = {}, # noqa: B006 values: dict[str, Any] = {}, # noqa: B006
*, *,
loc: tuple[Union[int, str], ...] = (), loc: tuple[Union[int, str], ...] = (),
) -> tuple[Any, Union[list[dict[str, Any]], None]]: ) -> tuple[Any, list[dict[str, Any]]]:
try: try:
return ( return (
self._type_adapter.validate_python(value, from_attributes=True), self._type_adapter.validate_python(value, from_attributes=True),
None, [],
) )
except ValidationError as exc: except ValidationError as exc:
return None, _regenerate_error_with_loc( return None, _regenerate_error_with_loc(
@@ -305,94 +305,12 @@ def get_definitions(
if "description" in item_def: if "description" in item_def:
item_description = cast(str, item_def["description"]).split("\f")[0] item_description = cast(str, item_def["description"]).split("\f")[0]
item_def["description"] = item_description item_def["description"] = item_description
new_mapping, new_definitions = _remap_definitions_and_field_mappings( # definitions: dict[DefsRef, dict[str, Any]]
model_name_map=model_name_map, # but mypy complains about general str in other places that are not declared as
definitions=definitions, # type: ignore[arg-type] # DefsRef, although DefsRef is just str:
field_mapping=field_mapping, # DefsRef = NewType('DefsRef', str)
) # So, a cast to simplify the types here
return new_mapping, new_definitions return field_mapping, cast(dict[str, dict[str, Any]], definitions)
def _replace_refs(
*,
schema: dict[str, Any],
old_name_to_new_name_map: dict[str, str],
) -> dict[str, Any]:
new_schema = deepcopy(schema)
for key, value in new_schema.items():
if key == "$ref":
value = schema["$ref"]
if isinstance(value, str):
ref_name = schema["$ref"].split("/")[-1]
if ref_name in old_name_to_new_name_map:
new_name = old_name_to_new_name_map[ref_name]
new_schema["$ref"] = REF_TEMPLATE.format(model=new_name)
continue
if isinstance(value, dict):
new_schema[key] = _replace_refs(
schema=value,
old_name_to_new_name_map=old_name_to_new_name_map,
)
elif isinstance(value, list):
new_value = []
for item in value:
if isinstance(item, dict):
new_item = _replace_refs(
schema=item,
old_name_to_new_name_map=old_name_to_new_name_map,
)
new_value.append(new_item)
else:
new_value.append(item)
new_schema[key] = new_value
return new_schema
def _remap_definitions_and_field_mappings(
    *,
    model_name_map: ModelNameMap,
    definitions: dict[str, Any],
    field_mapping: dict[
        tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
) -> tuple[
    dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
    dict[str, Any],
]:
    """Rename generated JSON Schema definitions to the names chosen in
    *model_name_map*, rewriting every ``$ref`` accordingly.

    Returns the rewritten field-to-schema mapping and the rewritten
    definitions dict; the inputs are not mutated.
    """
    # First pass: decide which definition names must change.
    rename_map: dict[str, str] = {}
    for (field, _mode), schema in field_mapping.items():
        if field.type_ not in model_name_map or "$ref" not in schema:
            continue
        preferred_name = model_name_map[field.type_]
        current_name = schema["$ref"].split("/")[-1]
        # Keep the "-Input"/"-Output" suffixed variants generated for models
        # whose validation and serialization schemas differ.
        if current_name not in (f"{preferred_name}-Input", f"{preferred_name}-Output"):
            rename_map[current_name] = preferred_name
    # Second pass: rewrite refs in every field schema.
    remapped_field_mapping: dict[
        tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ] = {
        field_key: _replace_refs(schema=schema, old_name_to_new_name_map=rename_map)
        for field_key, schema in field_mapping.items()
    }
    # Third pass: rename definition keys and rewrite refs inside their bodies.
    remapped_definitions: dict[str, Any] = {
        rename_map.get(def_name, def_name): _replace_refs(
            schema=def_schema, old_name_to_new_name_map=rename_map
        )
        for def_name, def_schema in definitions.items()
    }
    return remapped_field_mapping, remapped_definitions
def is_scalar_field(field: ModelField) -> bool: def is_scalar_field(field: ModelField) -> bool:
@@ -441,7 +359,7 @@ def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
return shared.sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return,index] return shared.sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return,index]
def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]: def get_missing_field_error(loc: tuple[Union[int, str], ...]) -> dict[str, Any]:
error = ValidationError.from_exception_data( error = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}] "Field required", [{"type": "missing", "loc": loc, "input": {}}]
).errors(include_url=False)[0] ).errors(include_url=False)[0]
@@ -499,11 +417,6 @@ def get_model_name_map(unique_models: TypeModelSet) -> dict[TypeModelOrEnum, str
return {v: k for k, v in name_model_map.items()} return {v: k for k, v in name_model_map.items()}
def get_compat_model_name_map(fields: list[ModelField]) -> ModelNameMap:
    """Build the model -> schema-name map covering every model reachable
    from *fields* (nested models included, discovered with an empty
    known-models set)."""
    return get_model_name_map(get_flat_models_from_fields(fields, known_models=set()))
def get_flat_models_from_model( def get_flat_models_from_model(
model: type["BaseModel"], known_models: Union[TypeModelSet, None] = None model: type["BaseModel"], known_models: Union[TypeModelSet, None] = None
) -> TypeModelSet: ) -> TypeModelSet:

View File

@@ -21,7 +21,6 @@ from fastapi._compat import (
ModelField, ModelField,
RequiredParam, RequiredParam,
Undefined, Undefined,
_regenerate_error_with_loc,
copy_field_info, copy_field_info,
create_body_model, create_body_model,
evaluate_forwardref, evaluate_forwardref,
@@ -718,12 +717,7 @@ def _validate_value_with_model_field(
return None, [get_missing_field_error(loc=loc)] return None, [get_missing_field_error(loc=loc)]
else: else:
return deepcopy(field.default), [] return deepcopy(field.default), []
v_, errors_ = field.validate(value, values, loc=loc) return field.validate(value, values, loc=loc)
if isinstance(errors_, list):
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
return None, new_errors
else:
return v_, []
def _is_json_field(field: ModelField) -> bool: def _is_json_field(field: ModelField) -> bool:

View File

@@ -9,11 +9,14 @@ from fastapi import routing
from fastapi._compat import ( from fastapi._compat import (
ModelField, ModelField,
Undefined, Undefined,
get_compat_model_name_map,
get_definitions, get_definitions,
get_schema_from_model_field, get_schema_from_model_field,
lenient_issubclass, lenient_issubclass,
) )
from fastapi._compat.v2 import (
get_flat_models_from_fields,
get_model_name_map,
)
from fastapi.datastructures import DefaultPlaceholder from fastapi.datastructures import DefaultPlaceholder
from fastapi.dependencies.models import Dependant from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import ( from fastapi.dependencies.utils import (
@@ -512,7 +515,8 @@ def get_openapi(
webhook_paths: dict[str, dict[str, Any]] = {} webhook_paths: dict[str, dict[str, Any]] = {}
operation_ids: set[str] = set() operation_ids: set[str] = set()
all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or [])) all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
model_name_map = get_compat_model_name_map(all_fields) flat_models = get_flat_models_from_fields(all_fields, known_models=set())
model_name_map = get_model_name_map(flat_models)
field_mapping, definitions = get_definitions( field_mapping, definitions = get_definitions(
fields=all_fields, fields=all_fields,
model_name_map=model_name_map, model_name_map=model_name_map,

View File

@@ -277,15 +277,12 @@ async def serialize_response(
endpoint_ctx: Optional[EndpointContext] = None, endpoint_ctx: Optional[EndpointContext] = None,
) -> Any: ) -> Any:
if field: if field:
errors = []
if is_coroutine: if is_coroutine:
value, errors_ = field.validate(response_content, {}, loc=("response",)) value, errors = field.validate(response_content, {}, loc=("response",))
else: else:
value, errors_ = await run_in_threadpool( value, errors = await run_in_threadpool(
field.validate, response_content, {}, loc=("response",) field.validate, response_content, {}, loc=("response",)
) )
if isinstance(errors_, list):
errors.extend(errors_)
if errors: if errors:
ctx = endpoint_ctx or EndpointContext() ctx = endpoint_ctx or EndpointContext()
raise ResponseValidationError( raise ResponseValidationError(