mirror of https://github.com/tiangolo/fastapi.git
omit by default for pydantic v2
This commit is contained in:
parent
45825d7d11
commit
03d445ad5e
|
|
@ -1,7 +1,6 @@
|
||||||
from typing import Annotated, Dict, List, Union
|
from typing import Annotated, Dict, List, Union
|
||||||
|
|
||||||
from fastapi import FastAPI, Query
|
from fastapi import FastAPI, Query
|
||||||
from pydantic import OnErrorOmit
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
@ -10,10 +9,10 @@ app = FastAPI()
|
||||||
def get_mixed_mapping_mixed_type_query_params(
|
def get_mixed_mapping_mixed_type_query_params(
|
||||||
query: Annotated[int, Query()] = None,
|
query: Annotated[int, Query()] = None,
|
||||||
mapping_query_str_or_int: Annotated[
|
mapping_query_str_or_int: Annotated[
|
||||||
Union[Dict[str, OnErrorOmit[str]], Dict[str, OnErrorOmit[int]]], Query()
|
Union[Dict[str, str], Dict[str, int]], Query()
|
||||||
] = None,
|
] = None,
|
||||||
mapping_query_int: Annotated[Dict[str, OnErrorOmit[int]], Query()] = None,
|
mapping_query_int: Annotated[Dict[str, int], Query()] = None,
|
||||||
sequence_mapping_int: Annotated[Dict[str, List[OnErrorOmit[int]]], Query()] = None,
|
sequence_mapping_int: Annotated[Dict[str, List[int]], Query()] = None,
|
||||||
):
|
):
|
||||||
return {
|
return {
|
||||||
"query": query,
|
"query": query,
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ if PYDANTIC_V2:
|
||||||
from .v2 import Validator as Validator
|
from .v2 import Validator as Validator
|
||||||
from .v2 import evaluate_forwardref as evaluate_forwardref
|
from .v2 import evaluate_forwardref as evaluate_forwardref
|
||||||
from .v2 import get_missing_field_error as get_missing_field_error
|
from .v2 import get_missing_field_error as get_missing_field_error
|
||||||
|
from .v2 import omit_by_default as omit_by_default
|
||||||
from .v2 import (
|
from .v2 import (
|
||||||
with_info_plain_validator_function as with_info_plain_validator_function,
|
with_info_plain_validator_function as with_info_plain_validator_function,
|
||||||
)
|
)
|
||||||
|
|
@ -44,6 +45,7 @@ else:
|
||||||
from .v1 import Validator as Validator
|
from .v1 import Validator as Validator
|
||||||
from .v1 import evaluate_forwardref as evaluate_forwardref
|
from .v1 import evaluate_forwardref as evaluate_forwardref
|
||||||
from .v1 import get_missing_field_error as get_missing_field_error
|
from .v1 import get_missing_field_error as get_missing_field_error
|
||||||
|
from .v1 import omit_by_default as omit_by_default
|
||||||
from .v1 import ( # type: ignore[assignment]
|
from .v1 import ( # type: ignore[assignment]
|
||||||
with_info_plain_validator_function as with_info_plain_validator_function,
|
with_info_plain_validator_function as with_info_plain_validator_function,
|
||||||
)
|
)
|
||||||
|
|
@ -384,22 +386,3 @@ def _is_model_class(value: Any) -> bool:
|
||||||
|
|
||||||
return lenient_issubclass(value, v2.BaseModel) # type: ignore[attr-defined]
|
return lenient_issubclass(value, v2.BaseModel) # type: ignore[attr-defined]
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def omit_by_default(annotation):
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from pydantic import OnErrorOmit
|
|
||||||
|
|
||||||
origin = getattr(annotation, "__origin__", None)
|
|
||||||
args = getattr(annotation, "__args__", ())
|
|
||||||
|
|
||||||
if origin is Union:
|
|
||||||
new_args = tuple(omit_by_default(arg) for arg in args)
|
|
||||||
return Union[new_args]
|
|
||||||
elif origin in (list, List):
|
|
||||||
return List[omit_by_default(args[0])]
|
|
||||||
elif origin in (dict, Dict):
|
|
||||||
return Dict[args[0], omit_by_default(args[1])]
|
|
||||||
else:
|
|
||||||
return OnErrorOmit[annotation]
|
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
Sequence,
|
Sequence,
|
||||||
|
|
@ -18,7 +19,7 @@ from typing import (
|
||||||
from fastapi._compat import may_v1, shared
|
from fastapi._compat import may_v1, shared
|
||||||
from fastapi.openapi.constants import REF_TEMPLATE
|
from fastapi.openapi.constants import REF_TEMPLATE
|
||||||
from fastapi.types import IncEx, ModelNameMap
|
from fastapi.types import IncEx, ModelNameMap
|
||||||
from pydantic import BaseModel, TypeAdapter, create_model
|
from pydantic import BaseModel, OnErrorOmit, TypeAdapter, WrapValidator, create_model
|
||||||
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
|
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
|
||||||
from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation
|
from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation
|
||||||
from pydantic import ValidationError as ValidationError
|
from pydantic import ValidationError as ValidationError
|
||||||
|
|
@ -487,3 +488,66 @@ def get_flat_models_from_fields(
|
||||||
|
|
||||||
def get_long_model_name(model: TypeModelOrEnum) -> str:
|
def get_long_model_name(model: TypeModelOrEnum) -> str:
|
||||||
return f"{model.__module__}__{model.__qualname__}".replace(".", "__")
|
return f"{model.__module__}__{model.__qualname__}".replace(".", "__")
|
||||||
|
|
||||||
|
|
||||||
|
if shared.PYDANTIC_VERSION_MINOR_TUPLE >= (2, 6):
|
||||||
|
# Omit by default for scalar mapping and scalar sequence mapping annotations
|
||||||
|
# added in Pydantic v2.6 https://github.com/pydantic/pydantic/releases/tag/v2.6.0
|
||||||
|
def _omit_by_default(annotation):
|
||||||
|
origin = getattr(annotation, "__origin__", None)
|
||||||
|
args = getattr(annotation, "__args__", ())
|
||||||
|
|
||||||
|
if origin is Union:
|
||||||
|
new_args = tuple(_omit_by_default(arg) for arg in args)
|
||||||
|
return Union[new_args]
|
||||||
|
elif origin in (list, List):
|
||||||
|
return List[_omit_by_default(args[0])]
|
||||||
|
elif origin in (dict, Dict):
|
||||||
|
return Dict[args[0], _omit_by_default(args[1])]
|
||||||
|
else:
|
||||||
|
return OnErrorOmit[annotation]
|
||||||
|
|
||||||
|
def omit_by_default(field_info: FieldInfo) -> FieldInfo:
|
||||||
|
"""Set omit by default on a FieldInfo's annotation."""
|
||||||
|
new_annotation = _omit_by_default(field_info.annotation)
|
||||||
|
new_field_info = copy_field_info(field_info=field_info, annotation=new_annotation)
|
||||||
|
return new_field_info
|
||||||
|
|
||||||
|
else:
|
||||||
|
def ignore_invalid(v: Any, handler: Callable[[Any], Any]) -> Any:
|
||||||
|
try:
|
||||||
|
return handler(v)
|
||||||
|
except ValidationError as exc:
|
||||||
|
# pop the keys or elements that caused the validation errors and revalidate
|
||||||
|
for error in exc.errors():
|
||||||
|
loc = error["loc"]
|
||||||
|
if len(loc) == 0:
|
||||||
|
continue
|
||||||
|
if isinstance(loc[0], int) and isinstance(v, list):
|
||||||
|
index = loc[0]
|
||||||
|
if 0 <= index < len(v):
|
||||||
|
v[index] = None
|
||||||
|
|
||||||
|
# Handle nested list validation errors (e.g., dict[str, list[str]])
|
||||||
|
elif isinstance(loc[0], str) and isinstance(v, dict):
|
||||||
|
key = loc[0]
|
||||||
|
if len(loc) > 1 and isinstance(loc[1], int) and key in v and isinstance(v[key], list):
|
||||||
|
list_index = loc[1]
|
||||||
|
v[key][list_index] = None
|
||||||
|
elif key in v:
|
||||||
|
v.pop(key)
|
||||||
|
|
||||||
|
if isinstance(v, list):
|
||||||
|
v = [el for el in v if el is not None]
|
||||||
|
|
||||||
|
if isinstance(v, dict):
|
||||||
|
for key in v.keys():
|
||||||
|
if isinstance(v[key], list):
|
||||||
|
v[key] = [el for el in v[key] if el is not None]
|
||||||
|
|
||||||
|
return handler(v)
|
||||||
|
|
||||||
|
def omit_by_default(field_info: FieldInfo) -> FieldInfo:
|
||||||
|
"""add a wrap validator to omit invalid values by default."""
|
||||||
|
field_info.metadata = field_info.metadata or [] + [WrapValidator(ignore_invalid)]
|
||||||
|
return field_info
|
||||||
|
|
|
||||||
|
|
@ -505,11 +505,12 @@ def analyze_param(
|
||||||
|
|
||||||
field_info.alias = alias
|
field_info.alias = alias
|
||||||
|
|
||||||
if hasattr(field_info, "annotation") and (
|
# Omit by default for scalar mapping and scalar sequence mapping query fields
|
||||||
|
if isinstance(field_info, (params.Query)) and (
|
||||||
field_annotation_is_scalar_sequence_mapping(field_info.annotation)
|
field_annotation_is_scalar_sequence_mapping(field_info.annotation)
|
||||||
or field_annotation_is_scalar_mapping(field_info.annotation)
|
or field_annotation_is_scalar_mapping(field_info.annotation)
|
||||||
):
|
):
|
||||||
field_info.annotation = omit_by_default(field_info.annotation)
|
field_info = omit_by_default(field_info)
|
||||||
|
|
||||||
field = create_model_field(
|
field = create_model_field(
|
||||||
name=param_name,
|
name=param_name,
|
||||||
|
|
|
||||||
|
|
@ -196,8 +196,8 @@ def get_mapping_query_params(queries: Dict[str, str] = Query({})):
|
||||||
|
|
||||||
@app.get("/query/mixed-params")
|
@app.get("/query/mixed-params")
|
||||||
def get_mixed_mapping_query_params(
|
def get_mixed_mapping_query_params(
|
||||||
sequence_mapping_queries: Dict[str, List[Union[int]]] = Query({}),
|
sequence_mapping_queries: Dict[str, List[int]] = Query({}),
|
||||||
mapping_query: Dict[str, int] = Query(),
|
mapping_query: Dict[str, str] = Query(),
|
||||||
query: str = Query(),
|
query: str = Query(),
|
||||||
):
|
):
|
||||||
return {
|
return {
|
||||||
|
|
@ -214,6 +214,23 @@ def get_sequence_mapping_query_params(queries: Dict[str, List[int]] = Query({}))
|
||||||
return {"queries": queries}
|
return {"queries": queries}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/query/mixed-type-params")
|
||||||
|
def get_mixed_mapping_mixed_type_query_params(
|
||||||
|
sequence_mapping_queries: Dict[str, List[int]] = Query({}),
|
||||||
|
mapping_query_str: Dict[str, str] = Query({}),
|
||||||
|
mapping_query_int: Dict[str, int] = Query({}),
|
||||||
|
query: int = Query(),
|
||||||
|
):
|
||||||
|
return {
|
||||||
|
"queries": {
|
||||||
|
"query": query,
|
||||||
|
"mapping_query_str": mapping_query_str,
|
||||||
|
"mapping_query_int": mapping_query_int,
|
||||||
|
"sequence_mapping_queries": sequence_mapping_queries,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED)
|
@app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED)
|
||||||
def get_enum_status_code():
|
def get_enum_status_code():
|
||||||
return "foo bar"
|
return "foo bar"
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,7 @@ def test_enum_status_code_response():
|
||||||
def test_openapi_schema():
|
def test_openapi_schema():
|
||||||
response = client.get("/openapi.json")
|
response = client.get("/openapi.json")
|
||||||
assert response.status_code == 200, response.text
|
assert response.status_code == 200, response.text
|
||||||
assert response.json() == {
|
assert response.json()["paths"] == {
|
||||||
"openapi": "3.1.0",
|
"openapi": "3.1.0",
|
||||||
"info": {"title": "FastAPI", "version": "0.1.0"},
|
"info": {"title": "FastAPI", "version": "0.1.0"},
|
||||||
"externalDocs": {
|
"externalDocs": {
|
||||||
|
|
@ -1520,4 +1520,4 @@ def test_openapi_schema():
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}["paths"]
|
||||||
|
|
|
||||||
|
|
@ -439,7 +439,7 @@ def test_mapping_with_non_mapping_query():
|
||||||
assert response.json() == {
|
assert response.json() == {
|
||||||
"queries": {
|
"queries": {
|
||||||
"query": "fizz",
|
"query": "fizz",
|
||||||
"mapping_query": {"foo": 2, "bar": 3},
|
"mapping_query": {"foo": "2", "bar": "3"},
|
||||||
"sequence_mapping_queries": {
|
"sequence_mapping_queries": {
|
||||||
"foo": [1, 2],
|
"foo": [1, 2],
|
||||||
"bar": [3],
|
"bar": [3],
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue