mirror of https://github.com/tiangolo/fastapi.git
♻️ Use new Pydantic v2 JSON Schema generator (#9813)
Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com>
This commit is contained in:
parent
a65281fe09
commit
d4e3dcfa3a
|
|
@ -79,6 +79,7 @@ if PYDANTIC_V2:
|
||||||
class ModelField:
|
class ModelField:
|
||||||
field_info: FieldInfo
|
field_info: FieldInfo
|
||||||
name: str
|
name: str
|
||||||
|
mode: Literal["validation", "serialization"] = "validation"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def alias(self) -> str:
|
def alias(self) -> str:
|
||||||
|
|
@ -178,9 +179,12 @@ if PYDANTIC_V2:
|
||||||
field: ModelField,
|
field: ModelField,
|
||||||
schema_generator: GenerateJsonSchema,
|
schema_generator: GenerateJsonSchema,
|
||||||
model_name_map: ModelNameMap,
|
model_name_map: ModelNameMap,
|
||||||
|
field_mapping: Dict[
|
||||||
|
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||||
|
],
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
# This expects that GenerateJsonSchema was already used to generate the definitions
|
# This expects that GenerateJsonSchema was already used to generate the definitions
|
||||||
json_schema = schema_generator.generate_inner(field._type_adapter.core_schema)
|
json_schema = field_mapping[(field, field.mode)]
|
||||||
if "$ref" not in json_schema:
|
if "$ref" not in json_schema:
|
||||||
# TODO remove when deprecating Pydantic v1
|
# TODO remove when deprecating Pydantic v1
|
||||||
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
|
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
|
||||||
|
|
@ -197,12 +201,12 @@ if PYDANTIC_V2:
|
||||||
fields: List[ModelField],
|
fields: List[ModelField],
|
||||||
schema_generator: GenerateJsonSchema,
|
schema_generator: GenerateJsonSchema,
|
||||||
model_name_map: ModelNameMap,
|
model_name_map: ModelNameMap,
|
||||||
) -> Dict[str, Dict[str, Any]]:
|
) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:
|
||||||
inputs = [
|
inputs = [
|
||||||
(field, "validation", field._type_adapter.core_schema) for field in fields
|
(field, field.mode, field._type_adapter.core_schema) for field in fields
|
||||||
]
|
]
|
||||||
_, definitions = schema_generator.generate_definitions(inputs=inputs) # type: ignore[arg-type]
|
field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs) # type: ignore[arg-type]
|
||||||
return definitions # type: ignore[return-value]
|
return field_mapping, definitions # type: ignore[return-value]
|
||||||
|
|
||||||
def is_scalar_field(field: ModelField) -> bool:
|
def is_scalar_field(field: ModelField) -> bool:
|
||||||
from fastapi import params
|
from fastapi import params
|
||||||
|
|
@ -419,6 +423,9 @@ else:
|
||||||
field: ModelField,
|
field: ModelField,
|
||||||
schema_generator: GenerateJsonSchema,
|
schema_generator: GenerateJsonSchema,
|
||||||
model_name_map: ModelNameMap,
|
model_name_map: ModelNameMap,
|
||||||
|
field_mapping: Dict[
|
||||||
|
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||||
|
],
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
# This expects that GenerateJsonSchema was already used to generate the definitions
|
# This expects that GenerateJsonSchema was already used to generate the definitions
|
||||||
return field_schema( # type: ignore[no-any-return]
|
return field_schema( # type: ignore[no-any-return]
|
||||||
|
|
@ -434,9 +441,11 @@ else:
|
||||||
fields: List[ModelField],
|
fields: List[ModelField],
|
||||||
schema_generator: GenerateJsonSchema,
|
schema_generator: GenerateJsonSchema,
|
||||||
model_name_map: ModelNameMap,
|
model_name_map: ModelNameMap,
|
||||||
) -> Dict[str, Dict[str, Any]]:
|
) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:
|
||||||
models = get_flat_models_from_fields(fields, known_models=set())
|
models = get_flat_models_from_fields(fields, known_models=set())
|
||||||
return get_model_definitions(flat_models=models, model_name_map=model_name_map)
|
return {}, get_model_definitions(
|
||||||
|
flat_models=models, model_name_map=model_name_map
|
||||||
|
)
|
||||||
|
|
||||||
def is_scalar_field(field: ModelField) -> bool:
|
def is_scalar_field(field: ModelField) -> bool:
|
||||||
return is_pv1_scalar_field(field)
|
return is_pv1_scalar_field(field)
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union,
|
||||||
from fastapi import routing
|
from fastapi import routing
|
||||||
from fastapi._compat import (
|
from fastapi._compat import (
|
||||||
GenerateJsonSchema,
|
GenerateJsonSchema,
|
||||||
|
JsonSchemaValue,
|
||||||
ModelField,
|
ModelField,
|
||||||
Undefined,
|
Undefined,
|
||||||
get_compat_model_name_map,
|
get_compat_model_name_map,
|
||||||
|
|
@ -30,6 +31,7 @@ from fastapi.utils import (
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
from starlette.routing import BaseRoute
|
from starlette.routing import BaseRoute
|
||||||
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
|
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
validation_error_definition = {
|
validation_error_definition = {
|
||||||
"title": "ValidationError",
|
"title": "ValidationError",
|
||||||
|
|
@ -90,6 +92,9 @@ def get_openapi_operation_parameters(
|
||||||
all_route_params: Sequence[ModelField],
|
all_route_params: Sequence[ModelField],
|
||||||
schema_generator: GenerateJsonSchema,
|
schema_generator: GenerateJsonSchema,
|
||||||
model_name_map: ModelNameMap,
|
model_name_map: ModelNameMap,
|
||||||
|
field_mapping: Dict[
|
||||||
|
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||||
|
],
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
parameters = []
|
parameters = []
|
||||||
for param in all_route_params:
|
for param in all_route_params:
|
||||||
|
|
@ -101,6 +106,7 @@ def get_openapi_operation_parameters(
|
||||||
field=param,
|
field=param,
|
||||||
schema_generator=schema_generator,
|
schema_generator=schema_generator,
|
||||||
model_name_map=model_name_map,
|
model_name_map=model_name_map,
|
||||||
|
field_mapping=field_mapping,
|
||||||
)
|
)
|
||||||
parameter = {
|
parameter = {
|
||||||
"name": param.alias,
|
"name": param.alias,
|
||||||
|
|
@ -123,6 +129,9 @@ def get_openapi_operation_request_body(
|
||||||
body_field: Optional[ModelField],
|
body_field: Optional[ModelField],
|
||||||
schema_generator: GenerateJsonSchema,
|
schema_generator: GenerateJsonSchema,
|
||||||
model_name_map: ModelNameMap,
|
model_name_map: ModelNameMap,
|
||||||
|
field_mapping: Dict[
|
||||||
|
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||||
|
],
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
if not body_field:
|
if not body_field:
|
||||||
return None
|
return None
|
||||||
|
|
@ -131,6 +140,7 @@ def get_openapi_operation_request_body(
|
||||||
field=body_field,
|
field=body_field,
|
||||||
schema_generator=schema_generator,
|
schema_generator=schema_generator,
|
||||||
model_name_map=model_name_map,
|
model_name_map=model_name_map,
|
||||||
|
field_mapping=field_mapping,
|
||||||
)
|
)
|
||||||
field_info = cast(Body, body_field.field_info)
|
field_info = cast(Body, body_field.field_info)
|
||||||
request_media_type = field_info.media_type
|
request_media_type = field_info.media_type
|
||||||
|
|
@ -198,6 +208,9 @@ def get_openapi_path(
|
||||||
operation_ids: Set[str],
|
operation_ids: Set[str],
|
||||||
schema_generator: GenerateJsonSchema,
|
schema_generator: GenerateJsonSchema,
|
||||||
model_name_map: ModelNameMap,
|
model_name_map: ModelNameMap,
|
||||||
|
field_mapping: Dict[
|
||||||
|
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||||
|
],
|
||||||
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||||
path = {}
|
path = {}
|
||||||
security_schemes: Dict[str, Any] = {}
|
security_schemes: Dict[str, Any] = {}
|
||||||
|
|
@ -228,6 +241,7 @@ def get_openapi_path(
|
||||||
all_route_params=all_route_params,
|
all_route_params=all_route_params,
|
||||||
schema_generator=schema_generator,
|
schema_generator=schema_generator,
|
||||||
model_name_map=model_name_map,
|
model_name_map=model_name_map,
|
||||||
|
field_mapping=field_mapping,
|
||||||
)
|
)
|
||||||
parameters.extend(operation_parameters)
|
parameters.extend(operation_parameters)
|
||||||
if parameters:
|
if parameters:
|
||||||
|
|
@ -248,6 +262,7 @@ def get_openapi_path(
|
||||||
body_field=route.body_field,
|
body_field=route.body_field,
|
||||||
schema_generator=schema_generator,
|
schema_generator=schema_generator,
|
||||||
model_name_map=model_name_map,
|
model_name_map=model_name_map,
|
||||||
|
field_mapping=field_mapping,
|
||||||
)
|
)
|
||||||
if request_body_oai:
|
if request_body_oai:
|
||||||
operation["requestBody"] = request_body_oai
|
operation["requestBody"] = request_body_oai
|
||||||
|
|
@ -264,6 +279,7 @@ def get_openapi_path(
|
||||||
operation_ids=operation_ids,
|
operation_ids=operation_ids,
|
||||||
schema_generator=schema_generator,
|
schema_generator=schema_generator,
|
||||||
model_name_map=model_name_map,
|
model_name_map=model_name_map,
|
||||||
|
field_mapping=field_mapping,
|
||||||
)
|
)
|
||||||
callbacks[callback.name] = {callback.path: cb_path}
|
callbacks[callback.name] = {callback.path: cb_path}
|
||||||
operation["callbacks"] = callbacks
|
operation["callbacks"] = callbacks
|
||||||
|
|
@ -293,6 +309,7 @@ def get_openapi_path(
|
||||||
field=route.response_field,
|
field=route.response_field,
|
||||||
schema_generator=schema_generator,
|
schema_generator=schema_generator,
|
||||||
model_name_map=model_name_map,
|
model_name_map=model_name_map,
|
||||||
|
field_mapping=field_mapping,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response_schema = {}
|
response_schema = {}
|
||||||
|
|
@ -325,6 +342,7 @@ def get_openapi_path(
|
||||||
field=field,
|
field=field,
|
||||||
schema_generator=schema_generator,
|
schema_generator=schema_generator,
|
||||||
model_name_map=model_name_map,
|
model_name_map=model_name_map,
|
||||||
|
field_mapping=field_mapping,
|
||||||
)
|
)
|
||||||
media_type = route_response_media_type or "application/json"
|
media_type = route_response_media_type or "application/json"
|
||||||
additional_schema = (
|
additional_schema = (
|
||||||
|
|
@ -437,7 +455,7 @@ def get_openapi(
|
||||||
all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
|
all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
|
||||||
model_name_map = get_compat_model_name_map(all_fields)
|
model_name_map = get_compat_model_name_map(all_fields)
|
||||||
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
|
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
|
||||||
definitions = get_definitions(
|
field_mapping, definitions = get_definitions(
|
||||||
fields=all_fields,
|
fields=all_fields,
|
||||||
schema_generator=schema_generator,
|
schema_generator=schema_generator,
|
||||||
model_name_map=model_name_map,
|
model_name_map=model_name_map,
|
||||||
|
|
@ -449,6 +467,7 @@ def get_openapi(
|
||||||
operation_ids=operation_ids,
|
operation_ids=operation_ids,
|
||||||
schema_generator=schema_generator,
|
schema_generator=schema_generator,
|
||||||
model_name_map=model_name_map,
|
model_name_map=model_name_map,
|
||||||
|
field_mapping=field_mapping,
|
||||||
)
|
)
|
||||||
if result:
|
if result:
|
||||||
path, security_schemes, path_definitions = result
|
path, security_schemes, path_definitions = result
|
||||||
|
|
@ -467,6 +486,7 @@ def get_openapi(
|
||||||
operation_ids=operation_ids,
|
operation_ids=operation_ids,
|
||||||
schema_generator=schema_generator,
|
schema_generator=schema_generator,
|
||||||
model_name_map=model_name_map,
|
model_name_map=model_name_map,
|
||||||
|
field_mapping=field_mapping,
|
||||||
)
|
)
|
||||||
if result:
|
if result:
|
||||||
path, security_schemes, path_definitions = result
|
path, security_schemes, path_definitions = result
|
||||||
|
|
|
||||||
|
|
@ -446,7 +446,11 @@ class APIRoute(routing.Route):
|
||||||
), f"Status code {status_code} must not have a response body"
|
), f"Status code {status_code} must not have a response body"
|
||||||
response_name = "Response_" + self.unique_id
|
response_name = "Response_" + self.unique_id
|
||||||
self.response_field = create_response_field(
|
self.response_field = create_response_field(
|
||||||
name=response_name, type_=self.response_model
|
name=response_name,
|
||||||
|
type_=self.response_model,
|
||||||
|
# TODO: This should actually set mode='serialization', just, that changes the schemas
|
||||||
|
# mode="serialization",
|
||||||
|
mode="validation",
|
||||||
)
|
)
|
||||||
# Create a clone of the field, so that a Pydantic submodel is not returned
|
# Create a clone of the field, so that a Pydantic submodel is not returned
|
||||||
# as is just because it's an instance of a subclass of a more limited class
|
# as is just because it's an instance of a subclass of a more limited class
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ from fastapi._compat import (
|
||||||
from fastapi.datastructures import DefaultPlaceholder, DefaultType
|
from fastapi.datastructures import DefaultPlaceholder, DefaultType
|
||||||
from pydantic import BaseModel, create_model
|
from pydantic import BaseModel, create_model
|
||||||
from pydantic.fields import FieldInfo
|
from pydantic.fields import FieldInfo
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma: nocover
|
if TYPE_CHECKING: # pragma: nocover
|
||||||
from .routing import APIRoute
|
from .routing import APIRoute
|
||||||
|
|
@ -68,6 +69,7 @@ def create_response_field(
|
||||||
model_config: Type[BaseConfig] = BaseConfig,
|
model_config: Type[BaseConfig] = BaseConfig,
|
||||||
field_info: Optional[FieldInfo] = None,
|
field_info: Optional[FieldInfo] = None,
|
||||||
alias: Optional[str] = None,
|
alias: Optional[str] = None,
|
||||||
|
mode: Literal["validation", "serialization"] = "validation",
|
||||||
) -> ModelField:
|
) -> ModelField:
|
||||||
"""
|
"""
|
||||||
Create a new response field. Raises if type_ is invalid.
|
Create a new response field. Raises if type_ is invalid.
|
||||||
|
|
@ -80,7 +82,9 @@ def create_response_field(
|
||||||
else:
|
else:
|
||||||
field_info = field_info or FieldInfo()
|
field_info = field_info or FieldInfo()
|
||||||
kwargs = {"name": name, "field_info": field_info}
|
kwargs = {"name": name, "field_info": field_info}
|
||||||
if not PYDANTIC_V2:
|
if PYDANTIC_V2:
|
||||||
|
kwargs.update({"mode": mode})
|
||||||
|
else:
|
||||||
kwargs.update(
|
kwargs.update(
|
||||||
{
|
{
|
||||||
"type_": type_,
|
"type_": type_,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue