Calculate `override_mode` for each field in `get_definitions`

This commit is contained in:
Yurii Motov 2025-12-04 10:43:51 +01:00
parent c57ac7bdf3
commit fcc906593b
1 changed files with 20 additions and 12 deletions

View File

@ -7,6 +7,7 @@ from typing import (
Any, Any,
Dict, Dict,
List, List,
Mapping,
Sequence, Sequence,
Set, Set,
Tuple, Tuple,
@ -208,15 +209,7 @@ def get_definitions(
Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]],
]: ]:
has_computed_fields: bool = any(
field._type_adapter.core_schema.get("schema", {}).get("computed_fields", [])
for field in fields
)
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE) schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
override_mode: Union[Literal["validation"], None] = (
None if (separate_input_output_schemas or has_computed_fields) else "validation"
)
validation_fields = [field for field in fields if field.mode == "validation"] validation_fields = [field for field in fields if field.mode == "validation"]
serialization_fields = [field for field in fields if field.mode == "serialization"] serialization_fields = [field for field in fields if field.mode == "serialization"]
flat_validation_models = get_flat_models_from_fields( flat_validation_models = get_flat_models_from_fields(
@ -246,11 +239,26 @@ def get_definitions(
unique_flat_model_fields = { unique_flat_model_fields = {
f for f in flat_model_fields if f.type_ not in input_types f for f in flat_model_fields if f.type_ not in input_types
} }
inputs: List[
Tuple[
ModelField,
Literal["validation", "serialization"],
Mapping[str, Any],
]
] = []
for field in list(fields) + list(unique_flat_model_fields):
has_computed_fields: bool = field._type_adapter.core_schema.get(
"schema", {}
).get("computed_fields", [])
override_mode: Union[Literal["validation"], None] = (
None
if (separate_input_output_schemas or has_computed_fields)
else "validation"
)
inputs.append(
(field, override_mode or field.mode, field._type_adapter.core_schema)
)
inputs = [
(field, override_mode or field.mode, field._type_adapter.core_schema)
for field in list(fields) + list(unique_flat_model_fields)
]
field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs) field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs)
for item_def in cast(Dict[str, Dict[str, Any]], definitions).values(): for item_def in cast(Dict[str, Dict[str, Any]], definitions).values():
if "description" in item_def: if "description" in item_def: