mirror of https://github.com/tiangolo/fastapi.git
🐛 Fix Enum handling with their own schema definitions (#1463)
* 🐛 Fix extra support for enum with its own schema * ✅ Fix/update test for enum with its own schema * 🐛 Fix type declarations * 🔧 Update format and lint scripts to support locally installed Pydantic and Starlette * 🐛 Add temporary type ignores while enum schemas are merged
This commit is contained in:
parent
98bb9f13da
commit
5984233223
|
|
@ -188,6 +188,16 @@ def get_flat_dependant(
|
|||
return flat_dependant
|
||||
|
||||
|
||||
def get_flat_params(dependant: Dependant) -> List[ModelField]:
|
||||
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
|
||||
return (
|
||||
flat_dependant.path_params
|
||||
+ flat_dependant.query_params
|
||||
+ flat_dependant.header_params
|
||||
+ flat_dependant.cookie_params
|
||||
)
|
||||
|
||||
|
||||
def is_scalar_field(field: ModelField) -> bool:
|
||||
field_info = get_field_info(field)
|
||||
if not (
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
import http.client
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, cast
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast
|
||||
|
||||
from fastapi import routing
|
||||
from fastapi.dependencies.models import Dependant
|
||||
from fastapi.dependencies.utils import get_flat_dependant
|
||||
from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.openapi.constants import (
|
||||
METHODS_WITH_BODY,
|
||||
|
|
@ -15,11 +16,14 @@ from fastapi.params import Body, Param
|
|||
from fastapi.utils import (
|
||||
generate_operation_id_for_path,
|
||||
get_field_info,
|
||||
get_flat_models_from_routes,
|
||||
get_model_definitions,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from pydantic.schema import field_schema, get_model_name_map
|
||||
from pydantic.schema import (
|
||||
field_schema,
|
||||
get_flat_models_from_fields,
|
||||
get_model_name_map,
|
||||
)
|
||||
from pydantic.utils import lenient_issubclass
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import BaseRoute
|
||||
|
|
@ -64,16 +68,6 @@ status_code_ranges: Dict[str, str] = {
|
|||
}
|
||||
|
||||
|
||||
def get_openapi_params(dependant: Dependant) -> List[ModelField]:
|
||||
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
|
||||
return (
|
||||
flat_dependant.path_params
|
||||
+ flat_dependant.query_params
|
||||
+ flat_dependant.header_params
|
||||
+ flat_dependant.cookie_params
|
||||
)
|
||||
|
||||
|
||||
def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]:
|
||||
security_definitions = {}
|
||||
operation_security = []
|
||||
|
|
@ -90,17 +84,22 @@ def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, L
|
|||
|
||||
|
||||
def get_openapi_operation_parameters(
|
||||
*,
|
||||
all_route_params: Sequence[ModelField],
|
||||
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
parameters = []
|
||||
for param in all_route_params:
|
||||
field_info = get_field_info(param)
|
||||
field_info = cast(Param, field_info)
|
||||
# ignore mypy error until enum schemas are released
|
||||
parameter = {
|
||||
"name": param.alias,
|
||||
"in": field_info.in_.value,
|
||||
"required": param.required,
|
||||
"schema": field_schema(param, model_name_map={})[0],
|
||||
"schema": field_schema(
|
||||
param, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
|
||||
)[0],
|
||||
}
|
||||
if field_info.description:
|
||||
parameter["description"] = field_info.description
|
||||
|
|
@ -111,13 +110,16 @@ def get_openapi_operation_parameters(
|
|||
|
||||
|
||||
def get_openapi_operation_request_body(
|
||||
*, body_field: Optional[ModelField], model_name_map: Dict[Type[BaseModel], str]
|
||||
*,
|
||||
body_field: Optional[ModelField],
|
||||
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str]
|
||||
) -> Optional[Dict]:
|
||||
if not body_field:
|
||||
return None
|
||||
assert isinstance(body_field, ModelField)
|
||||
# ignore mypy error until enum schemas are released
|
||||
body_schema, _, _ = field_schema(
|
||||
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
||||
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
|
||||
)
|
||||
field_info = cast(Body, get_field_info(body_field))
|
||||
request_media_type = field_info.media_type
|
||||
|
|
@ -176,8 +178,10 @@ def get_openapi_path(
|
|||
operation.setdefault("security", []).extend(operation_security)
|
||||
if security_definitions:
|
||||
security_schemes.update(security_definitions)
|
||||
all_route_params = get_openapi_params(route.dependant)
|
||||
operation_parameters = get_openapi_operation_parameters(all_route_params)
|
||||
all_route_params = get_flat_params(route.dependant)
|
||||
operation_parameters = get_openapi_operation_parameters(
|
||||
all_route_params=all_route_params, model_name_map=model_name_map
|
||||
)
|
||||
parameters.extend(operation_parameters)
|
||||
if parameters:
|
||||
operation["parameters"] = list(
|
||||
|
|
@ -270,6 +274,38 @@ def get_openapi_path(
|
|||
return path, security_schemes, definitions
|
||||
|
||||
|
||||
def get_flat_models_from_routes(
|
||||
routes: Sequence[BaseRoute],
|
||||
) -> Set[Union[Type[BaseModel], Type[Enum]]]:
|
||||
body_fields_from_routes: List[ModelField] = []
|
||||
responses_from_routes: List[ModelField] = []
|
||||
request_fields_from_routes: List[ModelField] = []
|
||||
callback_flat_models: Set[Union[Type[BaseModel], Type[Enum]]] = set()
|
||||
for route in routes:
|
||||
if getattr(route, "include_in_schema", None) and isinstance(
|
||||
route, routing.APIRoute
|
||||
):
|
||||
if route.body_field:
|
||||
assert isinstance(
|
||||
route.body_field, ModelField
|
||||
), "A request body must be a Pydantic Field"
|
||||
body_fields_from_routes.append(route.body_field)
|
||||
if route.response_field:
|
||||
responses_from_routes.append(route.response_field)
|
||||
if route.response_fields:
|
||||
responses_from_routes.extend(route.response_fields.values())
|
||||
if route.callbacks:
|
||||
callback_flat_models |= get_flat_models_from_routes(route.callbacks)
|
||||
params = get_flat_params(route.dependant)
|
||||
request_fields_from_routes.extend(params)
|
||||
|
||||
flat_models = callback_flat_models | get_flat_models_from_fields(
|
||||
body_fields_from_routes + responses_from_routes + request_fields_from_routes,
|
||||
known_models=set(),
|
||||
)
|
||||
return flat_models
|
||||
|
||||
|
||||
def get_openapi(
|
||||
*,
|
||||
title: str,
|
||||
|
|
@ -286,9 +322,11 @@ def get_openapi(
|
|||
components: Dict[str, Dict] = {}
|
||||
paths: Dict[str, Dict] = {}
|
||||
flat_models = get_flat_models_from_routes(routes)
|
||||
model_name_map = get_model_name_map(flat_models)
|
||||
# ignore mypy error until enum schemas are released
|
||||
model_name_map = get_model_name_map(flat_models) # type: ignore
|
||||
# ignore mypy error until enum schemas are released
|
||||
definitions = get_model_definitions(
|
||||
flat_models=flat_models, model_name_map=model_name_map
|
||||
flat_models=flat_models, model_name_map=model_name_map # type: ignore
|
||||
)
|
||||
for route in routes:
|
||||
if isinstance(route, routing.APIRoute):
|
||||
|
|
|
|||
|
|
@ -1,17 +1,16 @@
|
|||
import functools
|
||||
import re
|
||||
from dataclasses import is_dataclass
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, Type, Union, cast
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional, Set, Type, Union, cast
|
||||
|
||||
import fastapi
|
||||
from fastapi import routing
|
||||
from fastapi.logger import logger
|
||||
from fastapi.openapi.constants import REF_PREFIX
|
||||
from pydantic import BaseConfig, BaseModel, create_model
|
||||
from pydantic.class_validators import Validator
|
||||
from pydantic.schema import get_flat_models_from_fields, model_process_schema
|
||||
from pydantic.schema import model_process_schema
|
||||
from pydantic.utils import lenient_issubclass
|
||||
from starlette.routing import BaseRoute
|
||||
|
||||
try:
|
||||
from pydantic.fields import FieldInfo, ModelField, UndefinedType
|
||||
|
|
@ -50,38 +49,16 @@ def warning_response_model_skip_defaults_deprecated() -> None:
|
|||
)
|
||||
|
||||
|
||||
def get_flat_models_from_routes(routes: Sequence[BaseRoute]) -> Set[Type[BaseModel]]:
|
||||
body_fields_from_routes: List[ModelField] = []
|
||||
responses_from_routes: List[ModelField] = []
|
||||
callback_flat_models: Set[Type[BaseModel]] = set()
|
||||
for route in routes:
|
||||
if getattr(route, "include_in_schema", None) and isinstance(
|
||||
route, routing.APIRoute
|
||||
):
|
||||
if route.body_field:
|
||||
assert isinstance(
|
||||
route.body_field, ModelField
|
||||
), "A request body must be a Pydantic Field"
|
||||
body_fields_from_routes.append(route.body_field)
|
||||
if route.response_field:
|
||||
responses_from_routes.append(route.response_field)
|
||||
if route.response_fields:
|
||||
responses_from_routes.extend(route.response_fields.values())
|
||||
if route.callbacks:
|
||||
callback_flat_models |= get_flat_models_from_routes(route.callbacks)
|
||||
flat_models = callback_flat_models | get_flat_models_from_fields(
|
||||
body_fields_from_routes + responses_from_routes, known_models=set()
|
||||
)
|
||||
return flat_models
|
||||
|
||||
|
||||
def get_model_definitions(
|
||||
*, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str]
|
||||
*,
|
||||
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
|
||||
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
|
||||
) -> Dict[str, Any]:
|
||||
definitions: Dict[str, Dict] = {}
|
||||
for model in flat_models:
|
||||
# ignore mypy error until enum schemas are released
|
||||
m_schema, m_definitions, m_nested_models = model_process_schema(
|
||||
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
||||
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
|
||||
)
|
||||
definitions.update(m_definitions)
|
||||
model_name = model_name_map[model]
|
||||
|
|
|
|||
|
|
@ -3,4 +3,4 @@ set -x
|
|||
|
||||
autoflake --remove-all-unused-imports --recursive --remove-unused-variables --in-place docs_src fastapi tests scripts --exclude=__init__.py
|
||||
black fastapi tests docs_src scripts
|
||||
isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --thirdparty fastapi --apply fastapi tests docs_src scripts
|
||||
isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --thirdparty fastapi --thirdparty pydantic --thirdparty starlette --apply fastapi tests docs_src scripts
|
||||
|
|
|
|||
|
|
@ -5,4 +5,4 @@ set -x
|
|||
|
||||
mypy fastapi
|
||||
black fastapi tests --check
|
||||
isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --check-only --thirdparty fastapi fastapi tests
|
||||
isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --check-only --thirdparty fastapi --thirdparty fastapi --thirdparty pydantic --thirdparty starlette fastapi tests
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ openapi_schema2 = {
|
|||
"parameters": [
|
||||
{
|
||||
"required": True,
|
||||
"schema": {"$ref": "#/definitions/ModelName"},
|
||||
"schema": {"$ref": "#/components/schemas/ModelName"},
|
||||
"name": "model_name",
|
||||
"in": "path",
|
||||
}
|
||||
|
|
@ -124,6 +124,12 @@ openapi_schema2 = {
|
|||
}
|
||||
},
|
||||
},
|
||||
"ModelName": {
|
||||
"title": "ModelName",
|
||||
"enum": ["alexnet", "resnet", "lenet"],
|
||||
"type": "string",
|
||||
"description": "An enumeration.",
|
||||
},
|
||||
"ValidationError": {
|
||||
"title": "ValidationError",
|
||||
"required": ["loc", "msg", "type"],
|
||||
|
|
|
|||
Loading…
Reference in New Issue