mirror of https://github.com/tiangolo/fastapi.git
Merge 01152f60dd into 272204c0c7
This commit is contained in:
commit
3c308e0cba
|
|
@ -979,6 +979,7 @@ class FastAPI(Starlette):
|
|||
include_in_schema=include_in_schema,
|
||||
responses=responses,
|
||||
generate_unique_id_function=generate_unique_id_function,
|
||||
defer_init=False,
|
||||
)
|
||||
self.exception_handlers: Dict[
|
||||
Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]]
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import inspect
|
|||
import json
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from enum import Enum, IntEnum
|
||||
from functools import cached_property
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
|
|
@ -591,6 +592,7 @@ class APIRoute(routing.Route):
|
|||
generate_unique_id_function: Union[
|
||||
Callable[["APIRoute"], str], DefaultPlaceholder
|
||||
] = Default(generate_unique_id),
|
||||
defer_init: bool = True,
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.endpoint = endpoint
|
||||
|
|
@ -639,31 +641,42 @@ class APIRoute(routing.Route):
|
|||
assert is_body_allowed_for_status_code(status_code), (
|
||||
f"Status code {status_code} must not have a response body"
|
||||
)
|
||||
response_name = "Response_" + self.unique_id
|
||||
self.response_field = create_model_field(
|
||||
name=response_name,
|
||||
type_=self.response_model,
|
||||
mode="serialization",
|
||||
)
|
||||
# Create a clone of the field, so that a Pydantic submodel is not returned
|
||||
# as is just because it's an instance of a subclass of a more limited class
|
||||
# e.g. UserInDB (containing hashed_password) could be a subclass of User
|
||||
# that doesn't have the hashed_password. But because it's a subclass, it
|
||||
# would pass the validation and be returned as is.
|
||||
# By being a new field, no inheritance will be passed as is. A new model
|
||||
# will always be created.
|
||||
# TODO: remove when deprecating Pydantic v1
|
||||
self.secure_cloned_response_field: Optional[ModelField] = (
|
||||
create_cloned_field(self.response_field)
|
||||
)
|
||||
else:
|
||||
self.response_field = None # type: ignore
|
||||
self.secure_cloned_response_field = None
|
||||
self.dependencies = list(dependencies or [])
|
||||
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
|
||||
# if a "form feed" character (page break) is found in the description text,
|
||||
# truncate description text to the content preceding the first "form feed"
|
||||
self.description = self.description.split("\f")[0].strip()
|
||||
|
||||
assert callable(endpoint), "An endpoint must be a callable"
|
||||
|
||||
if not defer_init:
|
||||
self.init_attributes()
|
||||
|
||||
@cached_property
|
||||
def response_field(self) -> Optional[ModelField]:
|
||||
if not self.response_model:
|
||||
return None
|
||||
response_name = "Response_" + self.unique_id
|
||||
return create_model_field(
|
||||
name=response_name,
|
||||
type_=self.response_model,
|
||||
mode="serialization",
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def secure_cloned_response_field(self) -> Optional[ModelField]:
|
||||
# Create a clone of the field, so that a Pydantic submodel is not returned
|
||||
# as is just because it's an instance of a subclass of a more limited class
|
||||
# e.g. UserInDB (containing hashed_password) could be a subclass of User
|
||||
# that doesn't have the hashed_password. But because it's a subclass, it
|
||||
# would pass the validation and be returned as is.
|
||||
# By being a new field, no inheritance will be passed as is. A new model
|
||||
# will always be created.
|
||||
# TODO: remove when deprecating Pydantic v1
|
||||
return create_cloned_field(self.response_field) if self.response_field else None
|
||||
|
||||
@cached_property
|
||||
def response_fields(self) -> Dict[Union[int, str], ModelField]:
|
||||
response_fields = {}
|
||||
for additional_status_code, response in self.responses.items():
|
||||
assert isinstance(response, dict), "An additional response must be a dict"
|
||||
|
|
@ -677,30 +690,49 @@ class APIRoute(routing.Route):
|
|||
name=response_name, type_=model, mode="serialization"
|
||||
)
|
||||
response_fields[additional_status_code] = response_field
|
||||
if response_fields:
|
||||
self.response_fields: Dict[Union[int, str], ModelField] = response_fields
|
||||
else:
|
||||
self.response_fields = {}
|
||||
return response_fields
|
||||
|
||||
assert callable(endpoint), "An endpoint must be a callable"
|
||||
self.dependant = get_dependant(
|
||||
@cached_property
|
||||
def dependant(self) -> Dependant:
|
||||
dependant = get_dependant(
|
||||
path=self.path_format, call=self.endpoint, scope="function"
|
||||
)
|
||||
for depends in self.dependencies[::-1]:
|
||||
self.dependant.dependencies.insert(
|
||||
dependant.dependencies.insert(
|
||||
0,
|
||||
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
||||
)
|
||||
self._flat_dependant = get_flat_dependant(self.dependant)
|
||||
self._embed_body_fields = _should_embed_body_fields(
|
||||
self._flat_dependant.body_params
|
||||
)
|
||||
self.body_field = get_body_field(
|
||||
return dependant
|
||||
|
||||
@cached_property
|
||||
def _flat_dependant(self) -> Dependant:
|
||||
return get_flat_dependant(self.dependant)
|
||||
|
||||
@cached_property
|
||||
def _embed_body_fields(self) -> bool:
|
||||
return _should_embed_body_fields(self._flat_dependant.body_params)
|
||||
|
||||
@cached_property
|
||||
def body_field(self) -> Optional[ModelField]:
|
||||
return get_body_field(
|
||||
flat_dependant=self._flat_dependant,
|
||||
name=self.unique_id,
|
||||
embed_body_fields=self._embed_body_fields,
|
||||
)
|
||||
self.app = request_response(self.get_route_handler())
|
||||
|
||||
@cached_property
|
||||
def app(self) -> ASGIApp: # type: ignore
|
||||
return request_response(self.get_route_handler())
|
||||
|
||||
def init_attributes(self) -> None:
|
||||
_ = self.app
|
||||
_ = self.dependant
|
||||
_ = self.response_field
|
||||
_ = self.response_fields
|
||||
_ = self.secure_cloned_response_field
|
||||
_ = self.body_field
|
||||
_ = self._flat_dependant
|
||||
_ = self._embed_body_fields
|
||||
|
||||
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||
return get_request_handler(
|
||||
|
|
@ -967,6 +999,16 @@ class APIRouter(routing.Router):
|
|||
"""
|
||||
),
|
||||
] = Default(generate_unique_id),
|
||||
defer_init: Annotated[
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, every route will defer its initialization until the first call.
|
||||
This flag can be used to deactivate this behavior for the routes defined in this router,
|
||||
causing the routes to initialize immediately when they are defined.
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
routes=routes,
|
||||
|
|
@ -992,6 +1034,9 @@ class APIRouter(routing.Router):
|
|||
self.route_class = route_class
|
||||
self.default_response_class = default_response_class
|
||||
self.generate_unique_id_function = generate_unique_id_function
|
||||
self.defer_init = defer_init
|
||||
if not self.defer_init:
|
||||
self.init_routes()
|
||||
|
||||
def route(
|
||||
self,
|
||||
|
|
@ -1091,6 +1136,7 @@ class APIRouter(routing.Router):
|
|||
callbacks=current_callbacks,
|
||||
openapi_extra=openapi_extra,
|
||||
generate_unique_id_function=current_generate_unique_id,
|
||||
defer_init=self.defer_init,
|
||||
)
|
||||
self.routes.append(route)
|
||||
|
||||
|
|
@ -1253,6 +1299,11 @@ class APIRouter(routing.Router):
|
|||
|
||||
return decorator
|
||||
|
||||
def init_routes(self) -> None:
|
||||
for route in self.routes:
|
||||
if isinstance(route, APIRoute):
|
||||
route.init_attributes()
|
||||
|
||||
def include_router(
|
||||
self,
|
||||
router: Annotated["APIRouter", Doc("The `APIRouter` to include.")],
|
||||
|
|
|
|||
|
|
@ -0,0 +1,117 @@
|
|||
from itertools import chain
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI
|
||||
from fastapi.routing import APIRoute
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import BaseModel
|
||||
from starlette.routing import BaseRoute
|
||||
|
||||
deferred_keys = [
|
||||
"app",
|
||||
"response_fields",
|
||||
"body_field",
|
||||
"response_field",
|
||||
"secure_cloned_response_field",
|
||||
"dependant",
|
||||
"_flat_dependant",
|
||||
"_embed_body_fields",
|
||||
]
|
||||
|
||||
|
||||
def check_if_initialized(route: APIRoute, should_not: bool = False):
|
||||
for key in deferred_keys:
|
||||
if should_not:
|
||||
assert key not in route.__dict__
|
||||
else:
|
||||
assert key in route.__dict__
|
||||
|
||||
|
||||
def create_test_router(
|
||||
routes: Optional[List[BaseRoute]] = None, defer_init: bool = True
|
||||
):
|
||||
router = APIRouter(routes=routes or [], defer_init=defer_init)
|
||||
|
||||
class UserIdBody(BaseModel):
|
||||
user_id: int
|
||||
|
||||
@router.get("/user_id", dependencies=[Depends(lambda: True)])
|
||||
async def get_user_id(user_id: int = Depends(lambda: 1)) -> UserIdBody:
|
||||
return {"user_id": user_id}
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def test_route_defers():
|
||||
app = FastAPI()
|
||||
router = create_test_router(routes=app.router.routes)
|
||||
|
||||
for route in router.routes:
|
||||
if not isinstance(route, APIRoute):
|
||||
continue
|
||||
check_if_initialized(route, should_not=True)
|
||||
|
||||
app.router = router
|
||||
client = TestClient(app)
|
||||
response = client.get("/user_id")
|
||||
assert response.status_code == 200
|
||||
response = client.get("/openapi.json")
|
||||
assert response.status_code == 200
|
||||
|
||||
for route in router.routes:
|
||||
if not isinstance(route, APIRoute):
|
||||
continue
|
||||
check_if_initialized(route)
|
||||
|
||||
|
||||
def test_route_manual_init():
|
||||
router = create_test_router()
|
||||
for route in router.routes:
|
||||
check_if_initialized(route, should_not=True)
|
||||
route.init_attributes()
|
||||
check_if_initialized(route)
|
||||
|
||||
router = create_test_router()
|
||||
router.init_routes()
|
||||
for route in router.routes:
|
||||
check_if_initialized(route)
|
||||
|
||||
|
||||
def test_router_defer_init_flag():
|
||||
route = APIRoute("/test", lambda: {"test": True}, defer_init=False)
|
||||
check_if_initialized(route)
|
||||
|
||||
deferring_router = create_test_router()
|
||||
router = create_test_router(routes=deferring_router.routes, defer_init=False)
|
||||
|
||||
for route in router.routes:
|
||||
check_if_initialized(route)
|
||||
|
||||
|
||||
def test_root_router_always_initialized():
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
async def test_get():
|
||||
return {"test": 1}
|
||||
|
||||
router = create_test_router()
|
||||
app.include_router(router)
|
||||
for route in app.router.routes:
|
||||
if not isinstance(route, APIRoute):
|
||||
continue
|
||||
check_if_initialized(route)
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_include_router_no_init():
|
||||
router1 = create_test_router()
|
||||
|
||||
router2 = create_test_router()
|
||||
router2.include_router(router1)
|
||||
|
||||
for route in chain(router1.routes, router2.routes):
|
||||
check_if_initialized(route, should_not=True)
|
||||
Loading…
Reference in New Issue