Simplify code

This commit is contained in:
Yurii Motov 2025-12-04 14:03:19 +01:00
parent 8d24dfb938
commit fd59ae8b5d
1 changed files with 19 additions and 24 deletions

View File

@ -7,7 +7,6 @@ from typing import (
Any,
Dict,
List,
Mapping,
Sequence,
Set,
Tuple,
@ -172,6 +171,13 @@ def _get_model_config(model: BaseModel) -> Any:
return model.model_config
def _has_computed_fields(field: ModelField) -> bool:
computed_fields = field._type_adapter.core_schema.get("schema", {}).get(
"computed_fields", []
)
return len(computed_fields) > 0
def get_schema_from_model_field(
*,
field: ModelField,
@ -181,12 +187,9 @@ def get_schema_from_model_field(
],
separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
computed_fields = field._type_adapter.core_schema.get("schema", {}).get(
"computed_fields", []
)
override_mode: Union[Literal["validation"], None] = (
None
if (separate_input_output_schemas or len(computed_fields) > 0)
if (separate_input_output_schemas or _has_computed_fields(field))
else "validation"
)
# This expects that GenerateJsonSchema was already used to generate the definitions
@ -239,26 +242,18 @@ def get_definitions(
unique_flat_model_fields = {
f for f in flat_model_fields if f.type_ not in input_types
}
inputs: List[
Tuple[
ModelField,
Literal["validation", "serialization"],
Mapping[str, Any],
]
] = []
for field in list(fields) + list(unique_flat_model_fields):
has_computed_fields: bool = field._type_adapter.core_schema.get(
"schema", {}
).get("computed_fields", [])
override_mode: Union[Literal["validation"], None] = (
None
if (separate_input_output_schemas or has_computed_fields)
else "validation"
inputs = [
(
field,
(
field.mode
if (separate_input_output_schemas or _has_computed_fields(field))
else "validation"
),
field._type_adapter.core_schema,
)
inputs.append(
(field, override_mode or field.mode, field._type_adapter.core_schema)
)
for field in list(fields) + list(unique_flat_model_fields)
]
field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs)
for item_def in cast(Dict[str, Dict[str, Any]], definitions).values():
if "description" in item_def: