mirror of https://github.com/tiangolo/fastapi.git
🐛 Admit valid types for Pydantic fields as responses models (#1017)
This commit is contained in:
parent
0f9be1d2e7
commit
afad59dfbb
|
|
@ -27,7 +27,12 @@ from fastapi.dependencies.models import Dependant, SecurityRequirement
|
|||
from fastapi.security.base import SecurityBase
|
||||
from fastapi.security.oauth2 import OAuth2, SecurityScopes
|
||||
from fastapi.security.open_id_connect_url import OpenIdConnect
|
||||
from fastapi.utils import PYDANTIC_1, get_field_info, get_path_param_names
|
||||
from fastapi.utils import (
|
||||
PYDANTIC_1,
|
||||
create_response_field,
|
||||
get_field_info,
|
||||
get_path_param_names,
|
||||
)
|
||||
from pydantic import BaseConfig, BaseModel, create_model
|
||||
from pydantic.error_wrappers import ErrorWrapper
|
||||
from pydantic.errors import MissingError
|
||||
|
|
@ -362,30 +367,14 @@ def get_param_field(
|
|||
alias = param.name.replace("_", "-")
|
||||
else:
|
||||
alias = field_info.alias or param.name
|
||||
if PYDANTIC_1:
|
||||
field = ModelField(
|
||||
field = create_response_field(
|
||||
name=param.name,
|
||||
type_=annotation,
|
||||
default=None if required else default_value,
|
||||
alias=alias,
|
||||
required=required,
|
||||
model_config=BaseConfig,
|
||||
class_validators={},
|
||||
field_info=field_info,
|
||||
)
|
||||
# TODO: remove when removing support for Pydantic < 1.2.0
|
||||
field.required = required
|
||||
else: # pragma: nocover
|
||||
field = ModelField( # type: ignore
|
||||
name=param.name,
|
||||
type_=annotation,
|
||||
default=None if required else default_value,
|
||||
alias=alias,
|
||||
required=required,
|
||||
model_config=BaseConfig,
|
||||
class_validators={},
|
||||
schema=field_info,
|
||||
)
|
||||
field.required = required
|
||||
if not had_schema and not is_scalar_field(field=field):
|
||||
if PYDANTIC_1:
|
||||
|
|
@ -694,8 +683,7 @@ def get_schema_compatible_field(*, field: ModelField) -> ModelField:
|
|||
use_type: type = bytes
|
||||
if field.shape in sequence_shapes:
|
||||
use_type = List[bytes]
|
||||
if PYDANTIC_1:
|
||||
out_field = ModelField(
|
||||
out_field = create_response_field(
|
||||
name=field.name,
|
||||
type_=use_type,
|
||||
class_validators=field.class_validators,
|
||||
|
|
@ -703,18 +691,7 @@ def get_schema_compatible_field(*, field: ModelField) -> ModelField:
|
|||
default=field.default,
|
||||
required=field.required,
|
||||
alias=field.alias,
|
||||
field_info=field.field_info,
|
||||
)
|
||||
else: # pragma: nocover
|
||||
out_field = ModelField( # type: ignore
|
||||
name=field.name,
|
||||
type_=use_type,
|
||||
class_validators=field.class_validators,
|
||||
model_config=field.model_config,
|
||||
default=field.default,
|
||||
required=field.required,
|
||||
alias=field.alias,
|
||||
schema=field.schema, # type: ignore
|
||||
field_info=field.field_info if PYDANTIC_1 else field.schema, # type: ignore
|
||||
)
|
||||
|
||||
return out_field
|
||||
|
|
@ -754,26 +731,10 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
|
|||
]
|
||||
if len(set(body_param_media_types)) == 1:
|
||||
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
|
||||
if PYDANTIC_1:
|
||||
field = ModelField(
|
||||
return create_response_field(
|
||||
name="body",
|
||||
type_=BodyModel,
|
||||
default=None,
|
||||
required=required,
|
||||
model_config=BaseConfig,
|
||||
class_validators={},
|
||||
alias="body",
|
||||
field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
|
||||
)
|
||||
else: # pragma: nocover
|
||||
field = ModelField( # type: ignore
|
||||
name="body",
|
||||
type_=BodyModel,
|
||||
default=None,
|
||||
required=required,
|
||||
model_config=BaseConfig,
|
||||
class_validators={},
|
||||
alias="body",
|
||||
schema=BodyFieldInfo(**BodyFieldInfo_kwargs),
|
||||
)
|
||||
return field
|
||||
|
|
|
|||
|
|
@ -20,6 +20,12 @@ RequestErrorModel = create_model("Request")
|
|||
WebSocketErrorModel = create_model("WebSocket")
|
||||
|
||||
|
||||
class FastAPIError(RuntimeError):
|
||||
"""
|
||||
A generic, FastAPI-specific error.
|
||||
"""
|
||||
|
||||
|
||||
class RequestValidationError(ValidationError):
|
||||
def __init__(self, errors: Sequence[ErrorList], *, body: Any = None) -> None:
|
||||
self.body = body
|
||||
|
|
|
|||
|
|
@ -17,13 +17,13 @@ from fastapi.openapi.constants import STATUS_CODES_WITH_NO_BODY
|
|||
from fastapi.utils import (
|
||||
PYDANTIC_1,
|
||||
create_cloned_field,
|
||||
create_response_field,
|
||||
generate_operation_id_for_path,
|
||||
get_field_info,
|
||||
warning_response_model_skip_defaults_deprecated,
|
||||
)
|
||||
from pydantic import BaseConfig, BaseModel
|
||||
from pydantic import BaseModel
|
||||
from pydantic.error_wrappers import ErrorWrapper, ValidationError
|
||||
from pydantic.utils import lenient_issubclass
|
||||
from starlette import routing
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.exceptions import HTTPException
|
||||
|
|
@ -243,25 +243,8 @@ class APIRoute(routing.Route):
|
|||
status_code not in STATUS_CODES_WITH_NO_BODY
|
||||
), f"Status code {status_code} must not have a response body"
|
||||
response_name = "Response_" + self.unique_id
|
||||
if PYDANTIC_1:
|
||||
self.response_field: Optional[ModelField] = ModelField(
|
||||
name=response_name,
|
||||
type_=self.response_model,
|
||||
class_validators={},
|
||||
default=None,
|
||||
required=False,
|
||||
model_config=BaseConfig,
|
||||
field_info=FieldInfo(None),
|
||||
)
|
||||
else:
|
||||
self.response_field: Optional[ModelField] = ModelField( # type: ignore # pragma: nocover
|
||||
name=response_name,
|
||||
type_=self.response_model,
|
||||
class_validators={},
|
||||
default=None,
|
||||
required=False,
|
||||
model_config=BaseConfig,
|
||||
schema=FieldInfo(None),
|
||||
self.response_field = create_response_field(
|
||||
name=response_name, type_=self.response_model
|
||||
)
|
||||
# Create a clone of the field, so that a Pydantic submodel is not returned
|
||||
# as is just because it's an instance of a subclass of a more limited class
|
||||
|
|
@ -274,7 +257,7 @@ class APIRoute(routing.Route):
|
|||
ModelField
|
||||
] = create_cloned_field(self.response_field)
|
||||
else:
|
||||
self.response_field = None
|
||||
self.response_field = None # type: ignore
|
||||
self.secure_cloned_response_field = None
|
||||
self.status_code = status_code
|
||||
self.tags = tags or []
|
||||
|
|
@ -297,30 +280,8 @@ class APIRoute(routing.Route):
|
|||
assert (
|
||||
additional_status_code not in STATUS_CODES_WITH_NO_BODY
|
||||
), f"Status code {additional_status_code} must not have a response body"
|
||||
assert lenient_issubclass(
|
||||
model, BaseModel
|
||||
), "A response model must be a Pydantic model"
|
||||
response_name = f"Response_{additional_status_code}_{self.unique_id}"
|
||||
if PYDANTIC_1:
|
||||
response_field = ModelField(
|
||||
name=response_name,
|
||||
type_=model,
|
||||
class_validators=None,
|
||||
default=None,
|
||||
required=False,
|
||||
model_config=BaseConfig,
|
||||
field_info=FieldInfo(None),
|
||||
)
|
||||
else:
|
||||
response_field = ModelField( # type: ignore # pragma: nocover
|
||||
name=response_name,
|
||||
type_=model,
|
||||
class_validators=None,
|
||||
default=None,
|
||||
required=False,
|
||||
model_config=BaseConfig,
|
||||
schema=FieldInfo(None),
|
||||
)
|
||||
response_field = create_response_field(name=response_name, type_=model)
|
||||
response_fields[additional_status_code] = response_field
|
||||
if response_fields:
|
||||
self.response_fields: Dict[Union[int, str], ModelField] = response_fields
|
||||
|
|
|
|||
|
|
@ -1,17 +1,20 @@
|
|||
import functools
|
||||
import re
|
||||
from dataclasses import is_dataclass
|
||||
from typing import Any, Dict, List, Sequence, Set, Type, cast
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, Type, Union, cast
|
||||
|
||||
import fastapi
|
||||
from fastapi import routing
|
||||
from fastapi.logger import logger
|
||||
from fastapi.openapi.constants import REF_PREFIX
|
||||
from pydantic import BaseConfig, BaseModel, create_model
|
||||
from pydantic.class_validators import Validator
|
||||
from pydantic.schema import get_flat_models_from_fields, model_process_schema
|
||||
from pydantic.utils import lenient_issubclass
|
||||
from starlette.routing import BaseRoute
|
||||
|
||||
try:
|
||||
from pydantic.fields import FieldInfo, ModelField
|
||||
from pydantic.fields import FieldInfo, ModelField, UndefinedType
|
||||
|
||||
PYDANTIC_1 = True
|
||||
except ImportError: # pragma: nocover
|
||||
|
|
@ -19,6 +22,10 @@ except ImportError: # pragma: nocover
|
|||
from pydantic.fields import Field as ModelField # type: ignore
|
||||
from pydantic import Schema as FieldInfo # type: ignore
|
||||
|
||||
class UndefinedType: # type: ignore
|
||||
def __repr__(self) -> str:
|
||||
return "PydanticUndefined"
|
||||
|
||||
logger.warning(
|
||||
"Pydantic versions < 1.0.0 are deprecated in FastAPI and support will be "
|
||||
"removed soon."
|
||||
|
|
@ -86,6 +93,44 @@ def get_path_param_names(path: str) -> Set[str]:
|
|||
return {item.strip("{}") for item in re.findall("{[^}]*}", path)}
|
||||
|
||||
|
||||
def create_response_field(
|
||||
name: str,
|
||||
type_: Type[Any],
|
||||
class_validators: Optional[Dict[str, Validator]] = None,
|
||||
default: Optional[Any] = None,
|
||||
required: Union[bool, UndefinedType] = False,
|
||||
model_config: Type[BaseConfig] = BaseConfig,
|
||||
field_info: Optional[FieldInfo] = None,
|
||||
alias: Optional[str] = None,
|
||||
) -> ModelField:
|
||||
"""
|
||||
Create a new response field. Raises if type_ is invalid.
|
||||
"""
|
||||
class_validators = class_validators or {}
|
||||
field_info = field_info or FieldInfo(None)
|
||||
|
||||
response_field = functools.partial(
|
||||
ModelField,
|
||||
name=name,
|
||||
type_=type_,
|
||||
class_validators=class_validators,
|
||||
default=default,
|
||||
required=required,
|
||||
model_config=model_config,
|
||||
alias=alias,
|
||||
)
|
||||
|
||||
try:
|
||||
if PYDANTIC_1:
|
||||
return response_field(field_info=field_info)
|
||||
else: # pragma: nocover
|
||||
return response_field(schema=field_info)
|
||||
except RuntimeError:
|
||||
raise fastapi.exceptions.FastAPIError(
|
||||
f"Invalid args for response field! Hint: check that {type_} is a valid pydantic field type"
|
||||
)
|
||||
|
||||
|
||||
def create_cloned_field(field: ModelField) -> ModelField:
|
||||
original_type = field.type_
|
||||
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
|
||||
|
|
@ -96,26 +141,8 @@ def create_cloned_field(field: ModelField) -> ModelField:
|
|||
use_type = create_model(original_type.__name__, __base__=original_type)
|
||||
for f in original_type.__fields__.values():
|
||||
use_type.__fields__[f.name] = create_cloned_field(f)
|
||||
if PYDANTIC_1:
|
||||
new_field = ModelField(
|
||||
name=field.name,
|
||||
type_=use_type,
|
||||
class_validators={},
|
||||
default=None,
|
||||
required=False,
|
||||
model_config=BaseConfig,
|
||||
field_info=FieldInfo(None),
|
||||
)
|
||||
else: # pragma: nocover
|
||||
new_field = ModelField( # type: ignore
|
||||
name=field.name,
|
||||
type_=use_type,
|
||||
class_validators={},
|
||||
default=None,
|
||||
required=False,
|
||||
model_config=BaseConfig,
|
||||
schema=FieldInfo(None),
|
||||
)
|
||||
|
||||
new_field = create_response_field(name=field.name, type_=use_type)
|
||||
new_field.has_alias = field.has_alias
|
||||
new_field.alias = field.alias
|
||||
new_field.class_validators = field.class_validators
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
from typing import List
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.exceptions import FastAPIError
|
||||
|
||||
|
||||
class NonPydanticModel:
|
||||
pass
|
||||
|
||||
|
||||
def test_invalid_response_model_raises():
|
||||
with pytest.raises(FastAPIError):
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/", response_model=NonPydanticModel)
|
||||
def read_root():
|
||||
pass # pragma: nocover
|
||||
|
||||
|
||||
def test_invalid_response_model_sub_type_raises():
|
||||
with pytest.raises(FastAPIError):
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/", response_model=List[NonPydanticModel])
|
||||
def read_root():
|
||||
pass # pragma: nocover
|
||||
|
||||
|
||||
def test_invalid_response_model_in_responses_raises():
|
||||
with pytest.raises(FastAPIError):
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/", responses={"500": {"model": NonPydanticModel}})
|
||||
def read_root():
|
||||
pass # pragma: nocover
|
||||
|
||||
|
||||
def test_invalid_response_model_sub_type_in_responses_raises():
|
||||
with pytest.raises(FastAPIError):
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/", responses={"500": {"model": List[NonPydanticModel]}})
|
||||
def read_root():
|
||||
pass # pragma: nocover
|
||||
|
|
@ -0,0 +1,160 @@
|
|||
from typing import List
|
||||
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
|
||||
class Model(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.get("/valid1", responses={"500": {"model": int}})
|
||||
def valid1():
|
||||
pass
|
||||
|
||||
|
||||
@app.get("/valid2", responses={"500": {"model": List[int]}})
|
||||
def valid2():
|
||||
pass
|
||||
|
||||
|
||||
@app.get("/valid3", responses={"500": {"model": Model}})
|
||||
def valid3():
|
||||
pass
|
||||
|
||||
|
||||
@app.get("/valid4", responses={"500": {"model": List[Model]}})
|
||||
def valid4():
|
||||
pass
|
||||
|
||||
|
||||
openapi_schema = {
|
||||
"openapi": "3.0.2",
|
||||
"info": {"title": "FastAPI", "version": "0.1.0"},
|
||||
"paths": {
|
||||
"/valid1": {
|
||||
"get": {
|
||||
"summary": "Valid1",
|
||||
"operationId": "valid1_valid1_get",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {"application/json": {"schema": {}}},
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"title": "Response 500 Valid1 Valid1 Get",
|
||||
"type": "integer",
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"/valid2": {
|
||||
"get": {
|
||||
"summary": "Valid2",
|
||||
"operationId": "valid2_valid2_get",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {"application/json": {"schema": {}}},
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"title": "Response 500 Valid2 Valid2 Get",
|
||||
"type": "array",
|
||||
"items": {"type": "integer"},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"/valid3": {
|
||||
"get": {
|
||||
"summary": "Valid3",
|
||||
"operationId": "valid3_valid3_get",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {"application/json": {"schema": {}}},
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {"$ref": "#/components/schemas/Model"}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"/valid4": {
|
||||
"get": {
|
||||
"summary": "Valid4",
|
||||
"operationId": "valid4_valid4_get",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {"application/json": {"schema": {}}},
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal Server Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"title": "Response 500 Valid4 Valid4 Get",
|
||||
"type": "array",
|
||||
"items": {"$ref": "#/components/schemas/Model"},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
"components": {
|
||||
"schemas": {
|
||||
"Model": {
|
||||
"title": "Model",
|
||||
"required": ["name"],
|
||||
"type": "object",
|
||||
"properties": {"name": {"title": "Name", "type": "string"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_openapi_schema():
|
||||
response = client.get("/openapi.json")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == openapi_schema
|
||||
|
||||
|
||||
def test_path_operations():
|
||||
response = client.get("/valid1")
|
||||
assert response.status_code == 200
|
||||
response = client.get("/valid2")
|
||||
assert response.status_code == 200
|
||||
response = client.get("/valid3")
|
||||
assert response.status_code == 200
|
||||
response = client.get("/valid4")
|
||||
assert response.status_code == 200
|
||||
Loading…
Reference in New Issue