diff --git a/fastapi/applications.py b/fastapi/applications.py index 41d86143e..84f01d7a7 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -974,6 +974,7 @@ class FastAPI(Starlette): include_in_schema=include_in_schema, responses=responses, generate_unique_id_function=generate_unique_id_function, + defer_init=False, ) self.exception_handlers: dict[ Any, Callable[[Request, Any], Response | Awaitable[Response]] diff --git a/fastapi/routing.py b/fastapi/routing.py index ea82ab14a..3a525d9d2 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -21,6 +21,7 @@ from contextlib import ( asynccontextmanager, ) from enum import Enum, IntEnum +from functools import cached_property from typing import ( Annotated, Any, @@ -582,6 +583,7 @@ class APIRoute(routing.Route): openapi_extra: dict[str, Any] | None = None, generate_unique_id_function: Callable[["APIRoute"], str] | DefaultPlaceholder = Default(generate_unique_id), + defer_init: bool = True, ) -> None: self.path = path self.endpoint = endpoint @@ -630,20 +632,32 @@ class APIRoute(routing.Route): assert is_body_allowed_for_status_code(status_code), ( f"Status code {status_code} must not have a response body" ) - response_name = "Response_" + self.unique_id - self.response_field = create_model_field( - name=response_name, - type_=self.response_model, - mode="serialization", - ) - else: - self.response_field = None # type: ignore + self.dependencies = list(dependencies or []) self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") # if a "form feed" character (page break) is found in the description text, # truncate description text to the content preceding the first "form feed" self.description = self.description.split("\f")[0].strip() - response_fields = {} + + assert callable(endpoint), "An endpoint must be a callable" + + if not defer_init: + self.init_attributes() + + @cached_property + def response_field(self) -> ModelField | None: + if not self.response_model: + return None + response_name = "Response_" + self.unique_id + return create_model_field( + name=response_name, + type_=self.response_model, + mode="serialization", + ) + + @cached_property + def response_fields(self) -> dict[int | str, ModelField]: + response_fields: dict[int | str, ModelField] = {} for additional_status_code, response in self.responses.items(): assert isinstance(response, dict), "An additional response must be a dict" model = response.get("model") @@ -656,30 +670,49 @@ class APIRoute(routing.Route): name=response_name, type_=model, mode="serialization" ) response_fields[additional_status_code] = response_field - if response_fields: - self.response_fields: dict[int | str, ModelField] = response_fields - else: - self.response_fields = {} - assert callable(endpoint), "An endpoint must be a callable" - self.dependant = get_dependant( + return response_fields + + @cached_property + def dependant(self) -> Dependant: + dependant = get_dependant( path=self.path_format, call=self.endpoint, scope="function" ) for depends in self.dependencies[::-1]: - self.dependant.dependencies.insert( + dependant.dependencies.insert( 0, get_parameterless_sub_dependant(depends=depends, path=self.path_format), ) - self._flat_dependant = get_flat_dependant(self.dependant) - self._embed_body_fields = _should_embed_body_fields( - self._flat_dependant.body_params - ) - self.body_field = get_body_field( + return dependant + + @cached_property + def _flat_dependant(self) -> Dependant: + return get_flat_dependant(self.dependant) + + @cached_property + def _embed_body_fields(self) -> bool: + return _should_embed_body_fields(self._flat_dependant.body_params) + + @cached_property + def body_field(self) -> ModelField | None: + return get_body_field( flat_dependant=self._flat_dependant, name=self.unique_id, embed_body_fields=self._embed_body_fields, ) - self.app = request_response(self.get_route_handler()) + + @cached_property + def app(self) -> ASGIApp: # type: ignore + return request_response(self.get_route_handler()) + + def init_attributes(self) -> None: + _ = self.app + _ = self.dependant + _ = self.response_field + _ = self.response_fields + _ = self.body_field + _ = self._flat_dependant + _ = self._embed_body_fields def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: return get_request_handler( @@ -946,6 +979,16 @@ class APIRouter(routing.Router): """ ), ] = Default(generate_unique_id), + defer_init: Annotated[ + bool, + Doc( + """ + By default, every route will defer its initialization until the first call. + This flag can be used to deactivate this behavior for the routes defined in this router, + causing the routes to initialize immediately when they are defined. + """ + ), + ] = True, ) -> None: # Determine the lifespan context to use if lifespan is None: @@ -992,6 +1035,9 @@ class APIRouter(routing.Router): self.route_class = route_class self.default_response_class = default_response_class self.generate_unique_id_function = generate_unique_id_function + self.defer_init = defer_init + if not self.defer_init: + self.init_routes() def route( self, @@ -1088,6 +1134,7 @@ class APIRouter(routing.Router): callbacks=current_callbacks, openapi_extra=openapi_extra, generate_unique_id_function=current_generate_unique_id, + defer_init=self.defer_init, ) self.routes.append(route) @@ -1250,6 +1297,11 @@ class APIRouter(routing.Router): return decorator + def init_routes(self) -> None: + for route in self.routes: + if isinstance(route, APIRoute): + route.init_attributes() + def include_router( self, router: Annotated["APIRouter", Doc("The `APIRouter` to include.")], diff --git a/tests/test_route_deferred_init.py b/tests/test_route_deferred_init.py new file mode 100644 index 000000000..17e927425 --- /dev/null +++ b/tests/test_route_deferred_init.py @@ -0,0 +1,113 @@ +from itertools import chain + +from fastapi import APIRouter, Depends, FastAPI +from fastapi.routing import APIRoute +from fastapi.testclient import TestClient +from pydantic import BaseModel +from starlette.routing import BaseRoute + +deferred_keys = [ + "app", + "response_fields", + "body_field", + "response_field", + "dependant", + "_flat_dependant", + "_embed_body_fields", +] + + +def check_if_initialized(route: APIRoute, should_not: bool = False): + for key in deferred_keys: + if should_not: + assert key not in route.__dict__ + else: + assert key in route.__dict__ + + +def create_test_router(routes: list[BaseRoute] | None = None, defer_init: bool = True): + router = APIRouter(routes=routes or [], defer_init=defer_init) + + class UserIdBody(BaseModel): + user_id: int + + @router.get("/user_id", dependencies=[Depends(lambda: True)]) + async def get_user_id(user_id: int = Depends(lambda: 1)) -> UserIdBody: + return {"user_id": user_id} + + return router + + +def test_route_defers(): + app = FastAPI() + router = create_test_router(routes=app.router.routes) + + for route in router.routes: + if not isinstance(route, APIRoute): + continue + check_if_initialized(route, should_not=True) + + app.router = router + client = TestClient(app) + response = client.get("/user_id") + assert response.status_code == 200 + response = client.get("/openapi.json") + assert response.status_code == 200 + + for route in router.routes: + if not isinstance(route, APIRoute): + continue + check_if_initialized(route) + + +def test_route_manual_init(): + router = create_test_router() + for route in router.routes: + check_if_initialized(route, should_not=True) + route.init_attributes() + check_if_initialized(route) + + router = create_test_router() + router.init_routes() + for route in router.routes: + check_if_initialized(route) + + +def test_router_defer_init_flag(): + route = APIRoute("/test", lambda: {"test": True}, defer_init=False) + check_if_initialized(route) + + deferring_router = create_test_router() + router = create_test_router(routes=deferring_router.routes, defer_init=False) + + for route in router.routes: + check_if_initialized(route) + + +def test_root_router_always_initialized(): + app = FastAPI() + + @app.get("/test") + async def test_get(): + return {"test": 1} + + router = create_test_router() + app.include_router(router) + for route in app.router.routes: + if not isinstance(route, APIRoute): + continue + check_if_initialized(route) + + client = TestClient(app) + response = client.get("/test") + assert response.status_code == 200 + + +def test_include_router_no_init(): + router1 = create_test_router() + + router2 = create_test_router() + router2.include_router(router1) + + for route in chain(router1.routes, router2.routes): + check_if_initialized(route, should_not=True)