fix for pydantic v1

This commit is contained in:
JONEMI21 2025-11-07 09:28:03 +00:00
parent 78cffc8ad7
commit 61b683e562
2 changed files with 14 additions and 9 deletions

View File

@ -6,7 +6,6 @@ from typing import (
Callable,
Dict,
List,
Mapping,
Sequence,
Set,
Tuple,
@ -36,9 +35,9 @@ if not PYDANTIC_V2:
from pydantic.error_wrappers import ErrorWrapper as ErrorWrapper
from pydantic.errors import MissingError
from pydantic.fields import ( # type: ignore[attr-defined]
MAPPING_LIKE_SHAPES,
SHAPE_FROZENSET,
SHAPE_LIST,
SHAPE_MAPPING,
SHAPE_SEQUENCE,
SHAPE_SET,
SHAPE_SINGLETON,
@ -84,9 +83,9 @@ else:
from pydantic.v1.error_wrappers import ErrorWrapper as ErrorWrapper
from pydantic.v1.errors import MissingError
from pydantic.v1.fields import (
MAPPING_LIKE_SHAPES,
SHAPE_FROZENSET,
SHAPE_LIST,
SHAPE_MAPPING,
SHAPE_SEQUENCE,
SHAPE_SET,
SHAPE_SINGLETON,
@ -147,10 +146,7 @@ sequence_shape_to_type = {
SHAPE_TUPLE_ELLIPSIS: list,
}
mapping_shapes = {
SHAPE_MAPPING,
}
mapping_shapes_to_type = {SHAPE_MAPPING: Mapping}
mapping_shapes = MAPPING_LIKE_SHAPES
@dataclass
@ -233,7 +229,9 @@ def is_pv1_scalar_sequence_mapping_field(field: ModelField) -> bool:
):
if field.sub_fields is not None:
for sub_field in field.sub_fields:
if not is_scalar_sequence_field(sub_field):
if not (
is_scalar_sequence_field(sub_field) or is_scalar_field(sub_field)
):
return False
return True
return False

View File

@ -701,6 +701,13 @@ async def solve_dependencies(
)
def _extract_error_locs(errors_):
if isinstance(errors_, list):
errors_ = may_v1._regenerate_error_with_loc(errors=errors_, loc_prefix=())
return {err["loc"][2] for err in errors_ if len(err["loc"]) >= 3}
def _validate_value_with_model_field(
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
) -> Tuple[Any, List[Any]]:
@ -717,7 +724,7 @@ def _validate_value_with_model_field(
and (is_scalar_sequence_mapping_field(field) or is_scalar_mapping_field(field))
):
# Remove failing keys from the dict and try to re-validate
invalid_keys = {err["loc"][2] for err in errors_ if len(err["loc"]) >= 3}
invalid_keys = _extract_error_locs(errors_)
v_, errors_ = field.validate(
{k: v for k, v in value.items() if k not in invalid_keys},
values,