This commit is contained in:
Jan Vollmer 2025-12-16 21:07:31 +00:00 committed by GitHub
commit 3c308e0cba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 202 additions and 33 deletions

View File

@ -979,6 +979,7 @@ class FastAPI(Starlette):
include_in_schema=include_in_schema, include_in_schema=include_in_schema,
responses=responses, responses=responses,
generate_unique_id_function=generate_unique_id_function, generate_unique_id_function=generate_unique_id_function,
defer_init=False,
) )
self.exception_handlers: Dict[ self.exception_handlers: Dict[
Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]] Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]]

View File

@ -5,6 +5,7 @@ import inspect
import json import json
from contextlib import AsyncExitStack, asynccontextmanager from contextlib import AsyncExitStack, asynccontextmanager
from enum import Enum, IntEnum from enum import Enum, IntEnum
from functools import cached_property
from typing import ( from typing import (
Any, Any,
AsyncIterator, AsyncIterator,
@ -591,6 +592,7 @@ class APIRoute(routing.Route):
generate_unique_id_function: Union[ generate_unique_id_function: Union[
Callable[["APIRoute"], str], DefaultPlaceholder Callable[["APIRoute"], str], DefaultPlaceholder
] = Default(generate_unique_id), ] = Default(generate_unique_id),
defer_init: bool = True,
) -> None: ) -> None:
self.path = path self.path = path
self.endpoint = endpoint self.endpoint = endpoint
@ -639,31 +641,42 @@ class APIRoute(routing.Route):
assert is_body_allowed_for_status_code(status_code), ( assert is_body_allowed_for_status_code(status_code), (
f"Status code {status_code} must not have a response body" f"Status code {status_code} must not have a response body"
) )
response_name = "Response_" + self.unique_id
self.response_field = create_model_field(
name=response_name,
type_=self.response_model,
mode="serialization",
)
# Create a clone of the field, so that a Pydantic submodel is not returned
# as is just because it's an instance of a subclass of a more limited class
# e.g. UserInDB (containing hashed_password) could be a subclass of User
# that doesn't have the hashed_password. But because it's a subclass, it
# would pass the validation and be returned as is.
# By being a new field, no inheritance will be passed as is. A new model
# will always be created.
# TODO: remove when deprecating Pydantic v1
self.secure_cloned_response_field: Optional[ModelField] = (
create_cloned_field(self.response_field)
)
else:
self.response_field = None # type: ignore
self.secure_cloned_response_field = None
self.dependencies = list(dependencies or []) self.dependencies = list(dependencies or [])
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
# if a "form feed" character (page break) is found in the description text, # if a "form feed" character (page break) is found in the description text,
# truncate description text to the content preceding the first "form feed" # truncate description text to the content preceding the first "form feed"
self.description = self.description.split("\f")[0].strip() self.description = self.description.split("\f")[0].strip()
assert callable(endpoint), "An endpoint must be a callable"
if not defer_init:
self.init_attributes()
@cached_property
def response_field(self) -> Optional[ModelField]:
if not self.response_model:
return None
response_name = "Response_" + self.unique_id
return create_model_field(
name=response_name,
type_=self.response_model,
mode="serialization",
)
@cached_property
def secure_cloned_response_field(self) -> Optional[ModelField]:
# Create a clone of the field, so that a Pydantic submodel is not returned
# as is just because it's an instance of a subclass of a more limited class
# e.g. UserInDB (containing hashed_password) could be a subclass of User
# that doesn't have the hashed_password. But because it's a subclass, it
# would pass the validation and be returned as is.
# By being a new field, no inheritance will be passed as is. A new model
# will always be created.
# TODO: remove when deprecating Pydantic v1
return create_cloned_field(self.response_field) if self.response_field else None
@cached_property
def response_fields(self) -> Dict[Union[int, str], ModelField]:
response_fields = {} response_fields = {}
for additional_status_code, response in self.responses.items(): for additional_status_code, response in self.responses.items():
assert isinstance(response, dict), "An additional response must be a dict" assert isinstance(response, dict), "An additional response must be a dict"
@ -677,30 +690,49 @@ class APIRoute(routing.Route):
name=response_name, type_=model, mode="serialization" name=response_name, type_=model, mode="serialization"
) )
response_fields[additional_status_code] = response_field response_fields[additional_status_code] = response_field
if response_fields: return response_fields
self.response_fields: Dict[Union[int, str], ModelField] = response_fields
else:
self.response_fields = {}
assert callable(endpoint), "An endpoint must be a callable" @cached_property
self.dependant = get_dependant( def dependant(self) -> Dependant:
dependant = get_dependant(
path=self.path_format, call=self.endpoint, scope="function" path=self.path_format, call=self.endpoint, scope="function"
) )
for depends in self.dependencies[::-1]: for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert( dependant.dependencies.insert(
0, 0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format), get_parameterless_sub_dependant(depends=depends, path=self.path_format),
) )
self._flat_dependant = get_flat_dependant(self.dependant) return dependant
self._embed_body_fields = _should_embed_body_fields(
self._flat_dependant.body_params @cached_property
) def _flat_dependant(self) -> Dependant:
self.body_field = get_body_field( return get_flat_dependant(self.dependant)
@cached_property
def _embed_body_fields(self) -> bool:
return _should_embed_body_fields(self._flat_dependant.body_params)
@cached_property
def body_field(self) -> Optional[ModelField]:
return get_body_field(
flat_dependant=self._flat_dependant, flat_dependant=self._flat_dependant,
name=self.unique_id, name=self.unique_id,
embed_body_fields=self._embed_body_fields, embed_body_fields=self._embed_body_fields,
) )
self.app = request_response(self.get_route_handler())
@cached_property
def app(self) -> ASGIApp: # type: ignore
return request_response(self.get_route_handler())
def init_attributes(self) -> None:
_ = self.app
_ = self.dependant
_ = self.response_field
_ = self.response_fields
_ = self.secure_cloned_response_field
_ = self.body_field
_ = self._flat_dependant
_ = self._embed_body_fields
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
return get_request_handler( return get_request_handler(
@ -967,6 +999,16 @@ class APIRouter(routing.Router):
""" """
), ),
] = Default(generate_unique_id), ] = Default(generate_unique_id),
defer_init: Annotated[
bool,
Doc(
"""
By default, every route will defer its initialization until the first call.
This flag can be used to deactivate this behavior for the routes defined in this router,
causing the routes to initialize immediately when they are defined.
"""
),
] = True,
) -> None: ) -> None:
super().__init__( super().__init__(
routes=routes, routes=routes,
@ -992,6 +1034,9 @@ class APIRouter(routing.Router):
self.route_class = route_class self.route_class = route_class
self.default_response_class = default_response_class self.default_response_class = default_response_class
self.generate_unique_id_function = generate_unique_id_function self.generate_unique_id_function = generate_unique_id_function
self.defer_init = defer_init
if not self.defer_init:
self.init_routes()
def route( def route(
self, self,
@ -1091,6 +1136,7 @@ class APIRouter(routing.Router):
callbacks=current_callbacks, callbacks=current_callbacks,
openapi_extra=openapi_extra, openapi_extra=openapi_extra,
generate_unique_id_function=current_generate_unique_id, generate_unique_id_function=current_generate_unique_id,
defer_init=self.defer_init,
) )
self.routes.append(route) self.routes.append(route)
@ -1253,6 +1299,11 @@ class APIRouter(routing.Router):
return decorator return decorator
def init_routes(self) -> None:
for route in self.routes:
if isinstance(route, APIRoute):
route.init_attributes()
def include_router( def include_router(
self, self,
router: Annotated["APIRouter", Doc("The `APIRouter` to include.")], router: Annotated["APIRouter", Doc("The `APIRouter` to include.")],

View File

@ -0,0 +1,117 @@
from itertools import chain
from typing import List, Optional
from fastapi import APIRouter, Depends, FastAPI
from fastapi.routing import APIRoute
from fastapi.testclient import TestClient
from pydantic import BaseModel
from starlette.routing import BaseRoute
deferred_keys = [
"app",
"response_fields",
"body_field",
"response_field",
"secure_cloned_response_field",
"dependant",
"_flat_dependant",
"_embed_body_fields",
]
def check_if_initialized(route: APIRoute, should_not: bool = False):
for key in deferred_keys:
if should_not:
assert key not in route.__dict__
else:
assert key in route.__dict__
def create_test_router(
routes: Optional[List[BaseRoute]] = None, defer_init: bool = True
):
router = APIRouter(routes=routes or [], defer_init=defer_init)
class UserIdBody(BaseModel):
user_id: int
@router.get("/user_id", dependencies=[Depends(lambda: True)])
async def get_user_id(user_id: int = Depends(lambda: 1)) -> UserIdBody:
return {"user_id": user_id}
return router
def test_route_defers():
app = FastAPI()
router = create_test_router(routes=app.router.routes)
for route in router.routes:
if not isinstance(route, APIRoute):
continue
check_if_initialized(route, should_not=True)
app.router = router
client = TestClient(app)
response = client.get("/user_id")
assert response.status_code == 200
response = client.get("/openapi.json")
assert response.status_code == 200
for route in router.routes:
if not isinstance(route, APIRoute):
continue
check_if_initialized(route)
def test_route_manual_init():
router = create_test_router()
for route in router.routes:
check_if_initialized(route, should_not=True)
route.init_attributes()
check_if_initialized(route)
router = create_test_router()
router.init_routes()
for route in router.routes:
check_if_initialized(route)
def test_router_defer_init_flag():
route = APIRoute("/test", lambda: {"test": True}, defer_init=False)
check_if_initialized(route)
deferring_router = create_test_router()
router = create_test_router(routes=deferring_router.routes, defer_init=False)
for route in router.routes:
check_if_initialized(route)
def test_root_router_always_initialized():
app = FastAPI()
@app.get("/test")
async def test_get():
return {"test": 1}
router = create_test_router()
app.include_router(router)
for route in app.router.routes:
if not isinstance(route, APIRoute):
continue
check_if_initialized(route)
client = TestClient(app)
response = client.get("/test")
assert response.status_code == 200
def test_include_router_no_init():
router1 = create_test_router()
router2 = create_test_router()
router2.include_router(router1)
for route in chain(router1.routes, router2.routes):
check_if_initialized(route, should_not=True)