mirror of https://github.com/tiangolo/fastapi.git
Merge 01152f60dd into 272204c0c7
This commit is contained in:
commit
3c308e0cba
|
|
@ -979,6 +979,7 @@ class FastAPI(Starlette):
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
responses=responses,
|
responses=responses,
|
||||||
generate_unique_id_function=generate_unique_id_function,
|
generate_unique_id_function=generate_unique_id_function,
|
||||||
|
defer_init=False,
|
||||||
)
|
)
|
||||||
self.exception_handlers: Dict[
|
self.exception_handlers: Dict[
|
||||||
Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]]
|
Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]]
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import inspect
|
||||||
import json
|
import json
|
||||||
from contextlib import AsyncExitStack, asynccontextmanager
|
from contextlib import AsyncExitStack, asynccontextmanager
|
||||||
from enum import Enum, IntEnum
|
from enum import Enum, IntEnum
|
||||||
|
from functools import cached_property
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
|
|
@ -591,6 +592,7 @@ class APIRoute(routing.Route):
|
||||||
generate_unique_id_function: Union[
|
generate_unique_id_function: Union[
|
||||||
Callable[["APIRoute"], str], DefaultPlaceholder
|
Callable[["APIRoute"], str], DefaultPlaceholder
|
||||||
] = Default(generate_unique_id),
|
] = Default(generate_unique_id),
|
||||||
|
defer_init: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.path = path
|
self.path = path
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
|
|
@ -639,31 +641,42 @@ class APIRoute(routing.Route):
|
||||||
assert is_body_allowed_for_status_code(status_code), (
|
assert is_body_allowed_for_status_code(status_code), (
|
||||||
f"Status code {status_code} must not have a response body"
|
f"Status code {status_code} must not have a response body"
|
||||||
)
|
)
|
||||||
response_name = "Response_" + self.unique_id
|
|
||||||
self.response_field = create_model_field(
|
|
||||||
name=response_name,
|
|
||||||
type_=self.response_model,
|
|
||||||
mode="serialization",
|
|
||||||
)
|
|
||||||
# Create a clone of the field, so that a Pydantic submodel is not returned
|
|
||||||
# as is just because it's an instance of a subclass of a more limited class
|
|
||||||
# e.g. UserInDB (containing hashed_password) could be a subclass of User
|
|
||||||
# that doesn't have the hashed_password. But because it's a subclass, it
|
|
||||||
# would pass the validation and be returned as is.
|
|
||||||
# By being a new field, no inheritance will be passed as is. A new model
|
|
||||||
# will always be created.
|
|
||||||
# TODO: remove when deprecating Pydantic v1
|
|
||||||
self.secure_cloned_response_field: Optional[ModelField] = (
|
|
||||||
create_cloned_field(self.response_field)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.response_field = None # type: ignore
|
|
||||||
self.secure_cloned_response_field = None
|
|
||||||
self.dependencies = list(dependencies or [])
|
self.dependencies = list(dependencies or [])
|
||||||
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
|
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
|
||||||
# if a "form feed" character (page break) is found in the description text,
|
# if a "form feed" character (page break) is found in the description text,
|
||||||
# truncate description text to the content preceding the first "form feed"
|
# truncate description text to the content preceding the first "form feed"
|
||||||
self.description = self.description.split("\f")[0].strip()
|
self.description = self.description.split("\f")[0].strip()
|
||||||
|
|
||||||
|
assert callable(endpoint), "An endpoint must be a callable"
|
||||||
|
|
||||||
|
if not defer_init:
|
||||||
|
self.init_attributes()
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def response_field(self) -> Optional[ModelField]:
|
||||||
|
if not self.response_model:
|
||||||
|
return None
|
||||||
|
response_name = "Response_" + self.unique_id
|
||||||
|
return create_model_field(
|
||||||
|
name=response_name,
|
||||||
|
type_=self.response_model,
|
||||||
|
mode="serialization",
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def secure_cloned_response_field(self) -> Optional[ModelField]:
|
||||||
|
# Create a clone of the field, so that a Pydantic submodel is not returned
|
||||||
|
# as is just because it's an instance of a subclass of a more limited class
|
||||||
|
# e.g. UserInDB (containing hashed_password) could be a subclass of User
|
||||||
|
# that doesn't have the hashed_password. But because it's a subclass, it
|
||||||
|
# would pass the validation and be returned as is.
|
||||||
|
# By being a new field, no inheritance will be passed as is. A new model
|
||||||
|
# will always be created.
|
||||||
|
# TODO: remove when deprecating Pydantic v1
|
||||||
|
return create_cloned_field(self.response_field) if self.response_field else None
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def response_fields(self) -> Dict[Union[int, str], ModelField]:
|
||||||
response_fields = {}
|
response_fields = {}
|
||||||
for additional_status_code, response in self.responses.items():
|
for additional_status_code, response in self.responses.items():
|
||||||
assert isinstance(response, dict), "An additional response must be a dict"
|
assert isinstance(response, dict), "An additional response must be a dict"
|
||||||
|
|
@ -677,30 +690,49 @@ class APIRoute(routing.Route):
|
||||||
name=response_name, type_=model, mode="serialization"
|
name=response_name, type_=model, mode="serialization"
|
||||||
)
|
)
|
||||||
response_fields[additional_status_code] = response_field
|
response_fields[additional_status_code] = response_field
|
||||||
if response_fields:
|
return response_fields
|
||||||
self.response_fields: Dict[Union[int, str], ModelField] = response_fields
|
|
||||||
else:
|
|
||||||
self.response_fields = {}
|
|
||||||
|
|
||||||
assert callable(endpoint), "An endpoint must be a callable"
|
@cached_property
|
||||||
self.dependant = get_dependant(
|
def dependant(self) -> Dependant:
|
||||||
|
dependant = get_dependant(
|
||||||
path=self.path_format, call=self.endpoint, scope="function"
|
path=self.path_format, call=self.endpoint, scope="function"
|
||||||
)
|
)
|
||||||
for depends in self.dependencies[::-1]:
|
for depends in self.dependencies[::-1]:
|
||||||
self.dependant.dependencies.insert(
|
dependant.dependencies.insert(
|
||||||
0,
|
0,
|
||||||
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
||||||
)
|
)
|
||||||
self._flat_dependant = get_flat_dependant(self.dependant)
|
return dependant
|
||||||
self._embed_body_fields = _should_embed_body_fields(
|
|
||||||
self._flat_dependant.body_params
|
@cached_property
|
||||||
)
|
def _flat_dependant(self) -> Dependant:
|
||||||
self.body_field = get_body_field(
|
return get_flat_dependant(self.dependant)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _embed_body_fields(self) -> bool:
|
||||||
|
return _should_embed_body_fields(self._flat_dependant.body_params)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def body_field(self) -> Optional[ModelField]:
|
||||||
|
return get_body_field(
|
||||||
flat_dependant=self._flat_dependant,
|
flat_dependant=self._flat_dependant,
|
||||||
name=self.unique_id,
|
name=self.unique_id,
|
||||||
embed_body_fields=self._embed_body_fields,
|
embed_body_fields=self._embed_body_fields,
|
||||||
)
|
)
|
||||||
self.app = request_response(self.get_route_handler())
|
|
||||||
|
@cached_property
|
||||||
|
def app(self) -> ASGIApp: # type: ignore
|
||||||
|
return request_response(self.get_route_handler())
|
||||||
|
|
||||||
|
def init_attributes(self) -> None:
|
||||||
|
_ = self.app
|
||||||
|
_ = self.dependant
|
||||||
|
_ = self.response_field
|
||||||
|
_ = self.response_fields
|
||||||
|
_ = self.secure_cloned_response_field
|
||||||
|
_ = self.body_field
|
||||||
|
_ = self._flat_dependant
|
||||||
|
_ = self._embed_body_fields
|
||||||
|
|
||||||
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||||
return get_request_handler(
|
return get_request_handler(
|
||||||
|
|
@ -967,6 +999,16 @@ class APIRouter(routing.Router):
|
||||||
"""
|
"""
|
||||||
),
|
),
|
||||||
] = Default(generate_unique_id),
|
] = Default(generate_unique_id),
|
||||||
|
defer_init: Annotated[
|
||||||
|
bool,
|
||||||
|
Doc(
|
||||||
|
"""
|
||||||
|
By default, every route will defer its initialization until the first call.
|
||||||
|
This flag can be used to deactivate this behavior for the routes defined in this router,
|
||||||
|
causing the routes to initialize immediately when they are defined.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
] = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
routes=routes,
|
routes=routes,
|
||||||
|
|
@ -992,6 +1034,9 @@ class APIRouter(routing.Router):
|
||||||
self.route_class = route_class
|
self.route_class = route_class
|
||||||
self.default_response_class = default_response_class
|
self.default_response_class = default_response_class
|
||||||
self.generate_unique_id_function = generate_unique_id_function
|
self.generate_unique_id_function = generate_unique_id_function
|
||||||
|
self.defer_init = defer_init
|
||||||
|
if not self.defer_init:
|
||||||
|
self.init_routes()
|
||||||
|
|
||||||
def route(
|
def route(
|
||||||
self,
|
self,
|
||||||
|
|
@ -1091,6 +1136,7 @@ class APIRouter(routing.Router):
|
||||||
callbacks=current_callbacks,
|
callbacks=current_callbacks,
|
||||||
openapi_extra=openapi_extra,
|
openapi_extra=openapi_extra,
|
||||||
generate_unique_id_function=current_generate_unique_id,
|
generate_unique_id_function=current_generate_unique_id,
|
||||||
|
defer_init=self.defer_init,
|
||||||
)
|
)
|
||||||
self.routes.append(route)
|
self.routes.append(route)
|
||||||
|
|
||||||
|
|
@ -1253,6 +1299,11 @@ class APIRouter(routing.Router):
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
def init_routes(self) -> None:
|
||||||
|
for route in self.routes:
|
||||||
|
if isinstance(route, APIRoute):
|
||||||
|
route.init_attributes()
|
||||||
|
|
||||||
def include_router(
|
def include_router(
|
||||||
self,
|
self,
|
||||||
router: Annotated["APIRouter", Doc("The `APIRouter` to include.")],
|
router: Annotated["APIRouter", Doc("The `APIRouter` to include.")],
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,117 @@
|
||||||
|
from itertools import chain
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, FastAPI
|
||||||
|
from fastapi.routing import APIRoute
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from starlette.routing import BaseRoute
|
||||||
|
|
||||||
|
deferred_keys = [
|
||||||
|
"app",
|
||||||
|
"response_fields",
|
||||||
|
"body_field",
|
||||||
|
"response_field",
|
||||||
|
"secure_cloned_response_field",
|
||||||
|
"dependant",
|
||||||
|
"_flat_dependant",
|
||||||
|
"_embed_body_fields",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def check_if_initialized(route: APIRoute, should_not: bool = False):
|
||||||
|
for key in deferred_keys:
|
||||||
|
if should_not:
|
||||||
|
assert key not in route.__dict__
|
||||||
|
else:
|
||||||
|
assert key in route.__dict__
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_router(
|
||||||
|
routes: Optional[List[BaseRoute]] = None, defer_init: bool = True
|
||||||
|
):
|
||||||
|
router = APIRouter(routes=routes or [], defer_init=defer_init)
|
||||||
|
|
||||||
|
class UserIdBody(BaseModel):
|
||||||
|
user_id: int
|
||||||
|
|
||||||
|
@router.get("/user_id", dependencies=[Depends(lambda: True)])
|
||||||
|
async def get_user_id(user_id: int = Depends(lambda: 1)) -> UserIdBody:
|
||||||
|
return {"user_id": user_id}
|
||||||
|
|
||||||
|
return router
|
||||||
|
|
||||||
|
|
||||||
|
def test_route_defers():
|
||||||
|
app = FastAPI()
|
||||||
|
router = create_test_router(routes=app.router.routes)
|
||||||
|
|
||||||
|
for route in router.routes:
|
||||||
|
if not isinstance(route, APIRoute):
|
||||||
|
continue
|
||||||
|
check_if_initialized(route, should_not=True)
|
||||||
|
|
||||||
|
app.router = router
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/user_id")
|
||||||
|
assert response.status_code == 200
|
||||||
|
response = client.get("/openapi.json")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
for route in router.routes:
|
||||||
|
if not isinstance(route, APIRoute):
|
||||||
|
continue
|
||||||
|
check_if_initialized(route)
|
||||||
|
|
||||||
|
|
||||||
|
def test_route_manual_init():
|
||||||
|
router = create_test_router()
|
||||||
|
for route in router.routes:
|
||||||
|
check_if_initialized(route, should_not=True)
|
||||||
|
route.init_attributes()
|
||||||
|
check_if_initialized(route)
|
||||||
|
|
||||||
|
router = create_test_router()
|
||||||
|
router.init_routes()
|
||||||
|
for route in router.routes:
|
||||||
|
check_if_initialized(route)
|
||||||
|
|
||||||
|
|
||||||
|
def test_router_defer_init_flag():
|
||||||
|
route = APIRoute("/test", lambda: {"test": True}, defer_init=False)
|
||||||
|
check_if_initialized(route)
|
||||||
|
|
||||||
|
deferring_router = create_test_router()
|
||||||
|
router = create_test_router(routes=deferring_router.routes, defer_init=False)
|
||||||
|
|
||||||
|
for route in router.routes:
|
||||||
|
check_if_initialized(route)
|
||||||
|
|
||||||
|
|
||||||
|
def test_root_router_always_initialized():
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/test")
|
||||||
|
async def test_get():
|
||||||
|
return {"test": 1}
|
||||||
|
|
||||||
|
router = create_test_router()
|
||||||
|
app.include_router(router)
|
||||||
|
for route in app.router.routes:
|
||||||
|
if not isinstance(route, APIRoute):
|
||||||
|
continue
|
||||||
|
check_if_initialized(route)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/test")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_include_router_no_init():
|
||||||
|
router1 = create_test_router()
|
||||||
|
|
||||||
|
router2 = create_test_router()
|
||||||
|
router2.include_router(router1)
|
||||||
|
|
||||||
|
for route in chain(router1.routes, router2.routes):
|
||||||
|
check_if_initialized(route, should_not=True)
|
||||||
Loading…
Reference in New Issue