omit by default for pydantic v2

This commit is contained in:
JONEMI21 2025-11-10 10:29:17 +00:00
parent 45825d7d11
commit 03d445ad5e
7 changed files with 95 additions and 31 deletions

View File

@ -1,7 +1,6 @@
from typing import Annotated, Dict, List, Union from typing import Annotated, Dict, List, Union
from fastapi import FastAPI, Query from fastapi import FastAPI, Query
from pydantic import OnErrorOmit
app = FastAPI() app = FastAPI()
@ -10,10 +9,10 @@ app = FastAPI()
def get_mixed_mapping_mixed_type_query_params( def get_mixed_mapping_mixed_type_query_params(
query: Annotated[int, Query()] = None, query: Annotated[int, Query()] = None,
mapping_query_str_or_int: Annotated[ mapping_query_str_or_int: Annotated[
Union[Dict[str, OnErrorOmit[str]], Dict[str, OnErrorOmit[int]]], Query() Union[Dict[str, str], Dict[str, int]], Query()
] = None, ] = None,
mapping_query_int: Annotated[Dict[str, OnErrorOmit[int]], Query()] = None, mapping_query_int: Annotated[Dict[str, int], Query()] = None,
sequence_mapping_int: Annotated[Dict[str, List[OnErrorOmit[int]]], Query()] = None, sequence_mapping_int: Annotated[Dict[str, List[int]], Query()] = None,
): ):
return { return {
"query": query, "query": query,

View File

@ -28,6 +28,7 @@ if PYDANTIC_V2:
from .v2 import Validator as Validator from .v2 import Validator as Validator
from .v2 import evaluate_forwardref as evaluate_forwardref from .v2 import evaluate_forwardref as evaluate_forwardref
from .v2 import get_missing_field_error as get_missing_field_error from .v2 import get_missing_field_error as get_missing_field_error
from .v2 import omit_by_default as omit_by_default
from .v2 import ( from .v2 import (
with_info_plain_validator_function as with_info_plain_validator_function, with_info_plain_validator_function as with_info_plain_validator_function,
) )
@ -44,6 +45,7 @@ else:
from .v1 import Validator as Validator from .v1 import Validator as Validator
from .v1 import evaluate_forwardref as evaluate_forwardref from .v1 import evaluate_forwardref as evaluate_forwardref
from .v1 import get_missing_field_error as get_missing_field_error from .v1 import get_missing_field_error as get_missing_field_error
from .v1 import omit_by_default as omit_by_default
from .v1 import ( # type: ignore[assignment] from .v1 import ( # type: ignore[assignment]
with_info_plain_validator_function as with_info_plain_validator_function, with_info_plain_validator_function as with_info_plain_validator_function,
) )
@ -384,22 +386,3 @@ def _is_model_class(value: Any) -> bool:
return lenient_issubclass(value, v2.BaseModel) # type: ignore[attr-defined] return lenient_issubclass(value, v2.BaseModel) # type: ignore[attr-defined]
return False return False
def omit_by_default(annotation):
from typing import Union
from pydantic import OnErrorOmit
origin = getattr(annotation, "__origin__", None)
args = getattr(annotation, "__args__", ())
if origin is Union:
new_args = tuple(omit_by_default(arg) for arg in args)
return Union[new_args]
elif origin in (list, List):
return List[omit_by_default(args[0])]
elif origin in (dict, Dict):
return Dict[args[0], omit_by_default(args[1])]
else:
return OnErrorOmit[annotation]

View File

@ -5,6 +5,7 @@ from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import ( from typing import (
Any, Any,
Callable,
Dict, Dict,
List, List,
Sequence, Sequence,
@ -18,7 +19,7 @@ from typing import (
from fastapi._compat import may_v1, shared from fastapi._compat import may_v1, shared
from fastapi.openapi.constants import REF_TEMPLATE from fastapi.openapi.constants import REF_TEMPLATE
from fastapi.types import IncEx, ModelNameMap from fastapi.types import IncEx, ModelNameMap
from pydantic import BaseModel, TypeAdapter, create_model from pydantic import BaseModel, OnErrorOmit, TypeAdapter, WrapValidator, create_model
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation
from pydantic import ValidationError as ValidationError from pydantic import ValidationError as ValidationError
@ -487,3 +488,66 @@ def get_flat_models_from_fields(
def get_long_model_name(model: TypeModelOrEnum) -> str: def get_long_model_name(model: TypeModelOrEnum) -> str:
return f"{model.__module__}__{model.__qualname__}".replace(".", "__") return f"{model.__module__}__{model.__qualname__}".replace(".", "__")
if shared.PYDANTIC_VERSION_MINOR_TUPLE >= (2, 6):
# Omit by default for scalar mapping and scalar sequence mapping annotations
# added in Pydantic v2.6 https://github.com/pydantic/pydantic/releases/tag/v2.6.0
def _omit_by_default(annotation):
origin = getattr(annotation, "__origin__", None)
args = getattr(annotation, "__args__", ())
if origin is Union:
new_args = tuple(_omit_by_default(arg) for arg in args)
return Union[new_args]
elif origin in (list, List):
return List[_omit_by_default(args[0])]
elif origin in (dict, Dict):
return Dict[args[0], _omit_by_default(args[1])]
else:
return OnErrorOmit[annotation]
def omit_by_default(field_info: FieldInfo) -> FieldInfo:
"""Set omit by default on a FieldInfo's annotation."""
new_annotation = _omit_by_default(field_info.annotation)
new_field_info = copy_field_info(field_info=field_info, annotation=new_annotation)
return new_field_info
else:
def ignore_invalid(v: Any, handler: Callable[[Any], Any]) -> Any:
try:
return handler(v)
except ValidationError as exc:
# pop the keys or elements that caused the validation errors and revalidate
for error in exc.errors():
loc = error["loc"]
if len(loc) == 0:
continue
if isinstance(loc[0], int) and isinstance(v, list):
index = loc[0]
if 0 <= index < len(v):
v[index] = None
# Handle nested list validation errors (e.g., dict[str, list[str]])
elif isinstance(loc[0], str) and isinstance(v, dict):
key = loc[0]
if len(loc) > 1 and isinstance(loc[1], int) and key in v and isinstance(v[key], list):
list_index = loc[1]
v[key][list_index] = None
elif key in v:
v.pop(key)
if isinstance(v, list):
v = [el for el in v if el is not None]
if isinstance(v, dict):
for key in v.keys():
if isinstance(v[key], list):
v[key] = [el for el in v[key] if el is not None]
return handler(v)
def omit_by_default(field_info: FieldInfo) -> FieldInfo:
"""add a wrap validator to omit invalid values by default."""
field_info.metadata = field_info.metadata or [] + [WrapValidator(ignore_invalid)]
return field_info

View File

@ -505,11 +505,12 @@ def analyze_param(
field_info.alias = alias field_info.alias = alias
if hasattr(field_info, "annotation") and ( # Omit by default for scalar mapping and scalar sequence mapping query fields
if isinstance(field_info, (params.Query)) and (
field_annotation_is_scalar_sequence_mapping(field_info.annotation) field_annotation_is_scalar_sequence_mapping(field_info.annotation)
or field_annotation_is_scalar_mapping(field_info.annotation) or field_annotation_is_scalar_mapping(field_info.annotation)
): ):
field_info.annotation = omit_by_default(field_info.annotation) field_info = omit_by_default(field_info)
field = create_model_field( field = create_model_field(
name=param_name, name=param_name,

View File

@ -196,8 +196,8 @@ def get_mapping_query_params(queries: Dict[str, str] = Query({})):
@app.get("/query/mixed-params") @app.get("/query/mixed-params")
def get_mixed_mapping_query_params( def get_mixed_mapping_query_params(
sequence_mapping_queries: Dict[str, List[Union[int]]] = Query({}), sequence_mapping_queries: Dict[str, List[int]] = Query({}),
mapping_query: Dict[str, int] = Query(), mapping_query: Dict[str, str] = Query(),
query: str = Query(), query: str = Query(),
): ):
return { return {
@ -214,6 +214,23 @@ def get_sequence_mapping_query_params(queries: Dict[str, List[int]] = Query({}))
return {"queries": queries} return {"queries": queries}
@app.get("/query/mixed-type-params")
def get_mixed_mapping_mixed_type_query_params(
sequence_mapping_queries: Dict[str, List[int]] = Query({}),
mapping_query_str: Dict[str, str] = Query({}),
mapping_query_int: Dict[str, int] = Query({}),
query: int = Query(),
):
return {
"queries": {
"query": query,
"mapping_query_str": mapping_query_str,
"mapping_query_int": mapping_query_int,
"sequence_mapping_queries": sequence_mapping_queries,
}
}
@app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED) @app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED)
def get_enum_status_code(): def get_enum_status_code():
return "foo bar" return "foo bar"

View File

@ -55,7 +55,7 @@ def test_enum_status_code_response():
def test_openapi_schema(): def test_openapi_schema():
response = client.get("/openapi.json") response = client.get("/openapi.json")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == { assert response.json()["paths"] == {
"openapi": "3.1.0", "openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"}, "info": {"title": "FastAPI", "version": "0.1.0"},
"externalDocs": { "externalDocs": {
@ -1520,4 +1520,4 @@ def test_openapi_schema():
}, },
} }
}, },
} }["paths"]

View File

@ -439,7 +439,7 @@ def test_mapping_with_non_mapping_query():
assert response.json() == { assert response.json() == {
"queries": { "queries": {
"query": "fizz", "query": "fizz",
"mapping_query": {"foo": 2, "bar": 3}, "mapping_query": {"foo": "2", "bar": "3"},
"sequence_mapping_queries": { "sequence_mapping_queries": {
"foo": [1, 2], "foo": [1, 2],
"bar": [3], "bar": [3],