This commit is contained in:
Jan Vollmer 2026-02-17 10:01:25 +00:00 committed by GitHub
commit 8d59133f9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 188 additions and 22 deletions

View File

@ -974,6 +974,7 @@ class FastAPI(Starlette):
include_in_schema=include_in_schema,
responses=responses,
generate_unique_id_function=generate_unique_id_function,
defer_init=False,
)
self.exception_handlers: dict[
Any, Callable[[Request, Any], Response | Awaitable[Response]]

View File

@ -21,6 +21,7 @@ from contextlib import (
asynccontextmanager,
)
from enum import Enum, IntEnum
from functools import cached_property
from typing import (
Annotated,
Any,
@ -582,6 +583,7 @@ class APIRoute(routing.Route):
openapi_extra: dict[str, Any] | None = None,
generate_unique_id_function: Callable[["APIRoute"], str]
| DefaultPlaceholder = Default(generate_unique_id),
defer_init: bool = True,
) -> None:
self.path = path
self.endpoint = endpoint
@ -630,20 +632,32 @@ class APIRoute(routing.Route):
assert is_body_allowed_for_status_code(status_code), (
f"Status code {status_code} must not have a response body"
)
response_name = "Response_" + self.unique_id
self.response_field = create_model_field(
name=response_name,
type_=self.response_model,
mode="serialization",
)
else:
self.response_field = None # type: ignore
self.dependencies = list(dependencies or [])
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
# if a "form feed" character (page break) is found in the description text,
# truncate description text to the content preceding the first "form feed"
self.description = self.description.split("\f")[0].strip()
response_fields = {}
assert callable(endpoint), "An endpoint must be a callable"
if not defer_init:
self.init_attributes()
@cached_property
def response_field(self) -> ModelField | None:
if not self.response_model:
return None
response_name = "Response_" + self.unique_id
return create_model_field(
name=response_name,
type_=self.response_model,
mode="serialization",
)
@cached_property
def response_fields(self) -> dict[int | str, ModelField]:
response_fields: dict[int | str, ModelField] = {}
for additional_status_code, response in self.responses.items():
assert isinstance(response, dict), "An additional response must be a dict"
model = response.get("model")
@ -656,30 +670,49 @@ class APIRoute(routing.Route):
name=response_name, type_=model, mode="serialization"
)
response_fields[additional_status_code] = response_field
if response_fields:
self.response_fields: dict[int | str, ModelField] = response_fields
else:
self.response_fields = {}
assert callable(endpoint), "An endpoint must be a callable"
self.dependant = get_dependant(
return response_fields
@cached_property
def dependant(self) -> Dependant:
dependant = get_dependant(
path=self.path_format, call=self.endpoint, scope="function"
)
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
)
self._flat_dependant = get_flat_dependant(self.dependant)
self._embed_body_fields = _should_embed_body_fields(
self._flat_dependant.body_params
)
self.body_field = get_body_field(
return dependant
@cached_property
def _flat_dependant(self) -> Dependant:
return get_flat_dependant(self.dependant)
@cached_property
def _embed_body_fields(self) -> bool:
return _should_embed_body_fields(self._flat_dependant.body_params)
@cached_property
def body_field(self) -> ModelField | None:
return get_body_field(
flat_dependant=self._flat_dependant,
name=self.unique_id,
embed_body_fields=self._embed_body_fields,
)
self.app = request_response(self.get_route_handler())
@cached_property
def app(self) -> ASGIApp: # type: ignore
return request_response(self.get_route_handler())
def init_attributes(self) -> None:
_ = self.app
_ = self.dependant
_ = self.response_field
_ = self.response_fields
_ = self.body_field
_ = self._flat_dependant
_ = self._embed_body_fields
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
return get_request_handler(
@ -946,6 +979,16 @@ class APIRouter(routing.Router):
"""
),
] = Default(generate_unique_id),
defer_init: Annotated[
bool,
Doc(
"""
By default, every route will defer its initialization until the first call.
This flag can be used to deactivate this behavior for the routes defined in this router,
causing the routes to initialize immediately when they are defined.
"""
),
] = True,
) -> None:
# Determine the lifespan context to use
if lifespan is None:
@ -992,6 +1035,9 @@ class APIRouter(routing.Router):
self.route_class = route_class
self.default_response_class = default_response_class
self.generate_unique_id_function = generate_unique_id_function
self.defer_init = defer_init
if not self.defer_init:
self.init_routes()
def route(
self,
@ -1088,6 +1134,7 @@ class APIRouter(routing.Router):
callbacks=current_callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=current_generate_unique_id,
defer_init=self.defer_init,
)
self.routes.append(route)
@ -1250,6 +1297,11 @@ class APIRouter(routing.Router):
return decorator
def init_routes(self) -> None:
for route in self.routes:
if isinstance(route, APIRoute):
route.init_attributes()
def include_router(
self,
router: Annotated["APIRouter", Doc("The `APIRouter` to include.")],

View File

@ -0,0 +1,113 @@
from itertools import chain
from fastapi import APIRouter, Depends, FastAPI
from fastapi.routing import APIRoute
from fastapi.testclient import TestClient
from pydantic import BaseModel
from starlette.routing import BaseRoute
deferred_keys = [
"app",
"response_fields",
"body_field",
"response_field",
"dependant",
"_flat_dependant",
"_embed_body_fields",
]
def check_if_initialized(route: APIRoute, should_not: bool = False):
for key in deferred_keys:
if should_not:
assert key not in route.__dict__
else:
assert key in route.__dict__
def create_test_router(routes: list[BaseRoute] | None = None, defer_init: bool = True):
router = APIRouter(routes=routes or [], defer_init=defer_init)
class UserIdBody(BaseModel):
user_id: int
@router.get("/user_id", dependencies=[Depends(lambda: True)])
async def get_user_id(user_id: int = Depends(lambda: 1)) -> UserIdBody:
return {"user_id": user_id}
return router
def test_route_defers():
app = FastAPI()
router = create_test_router(routes=app.router.routes)
for route in router.routes:
if not isinstance(route, APIRoute):
continue
check_if_initialized(route, should_not=True)
app.router = router
client = TestClient(app)
response = client.get("/user_id")
assert response.status_code == 200
response = client.get("/openapi.json")
assert response.status_code == 200
for route in router.routes:
if not isinstance(route, APIRoute):
continue
check_if_initialized(route)
def test_route_manual_init():
router = create_test_router()
for route in router.routes:
check_if_initialized(route, should_not=True)
route.init_attributes()
check_if_initialized(route)
router = create_test_router()
router.init_routes()
for route in router.routes:
check_if_initialized(route)
def test_router_defer_init_flag():
route = APIRoute("/test", lambda: {"test": True}, defer_init=False)
check_if_initialized(route)
deferring_router = create_test_router()
router = create_test_router(routes=deferring_router.routes, defer_init=False)
for route in router.routes:
check_if_initialized(route)
def test_root_router_always_initialized():
app = FastAPI()
@app.get("/test")
async def test_get():
return {"test": 1}
router = create_test_router()
app.include_router(router)
for route in app.router.routes:
if not isinstance(route, APIRoute):
continue
check_if_initialized(route)
client = TestClient(app)
response = client.get("/test")
assert response.status_code == 200
def test_include_router_no_init():
router1 = create_test_router()
router2 = create_test_router()
router2.include_router(router1)
for route in chain(router1.routes, router2.routes):
check_if_initialized(route, should_not=True)