From f9f79926047d7f5d8e0127dd585e7683ee57e9f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Fri, 6 Feb 2026 07:18:30 -0800 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Re-implement=20`on=5Fevent?= =?UTF-8?q?`=20in=20FastAPI=20for=20compatibility=20with=20the=20next=20St?= =?UTF-8?q?arlette,=20while=20keeping=20backwards=20compatibility=20(#1485?= =?UTF-8?q?1)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/routing.py | 157 +++++++++++++++++++++++++++++++++++- tests/test_router_events.py | 76 +++++++++++++++++ 2 files changed, 229 insertions(+), 4 deletions(-) diff --git a/fastapi/routing.py b/fastapi/routing.py index 9ca2f46732..c95f624bdf 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -1,22 +1,31 @@ +import contextlib import email.message import functools import inspect import json +import types from collections.abc import ( AsyncIterator, Awaitable, Collection, Coroutine, + Generator, Mapping, Sequence, ) -from contextlib import AsyncExitStack, asynccontextmanager +from contextlib import ( + AbstractAsyncContextManager, + AbstractContextManager, + AsyncExitStack, + asynccontextmanager, +) from enum import Enum, IntEnum from typing import ( Annotated, Any, Callable, Optional, + TypeVar, Union, ) @@ -143,6 +152,50 @@ def websocket_session( return app +_T = TypeVar("_T") + + +# Vendored from starlette.routing to avoid importing private symbols +class _AsyncLiftContextManager(AbstractAsyncContextManager[_T]): + """ + Wraps a synchronous context manager to make it async. + + This is vendored from Starlette to avoid importing private symbols. + """ + + def __init__(self, cm: AbstractContextManager[_T]) -> None: + self._cm = cm + + async def __aenter__(self) -> _T: + return self._cm.__enter__() + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[types.TracebackType], + ) -> Optional[bool]: + return self._cm.__exit__(exc_type, exc_value, traceback) + + +# Vendored from starlette.routing to avoid importing private symbols +def _wrap_gen_lifespan_context( + lifespan_context: Callable[[Any], Generator[Any, Any, Any]], +) -> Callable[[Any], AbstractAsyncContextManager[Any]]: + """ + Wrap a generator-based lifespan context into an async context manager. + + This is vendored from Starlette to avoid importing private symbols. + """ + cmgr = contextlib.contextmanager(lifespan_context) + + @functools.wraps(cmgr) + def wrapper(app: Any) -> _AsyncLiftContextManager[Any]: + return _AsyncLiftContextManager(cmgr(app)) + + return wrapper + + def _merge_lifespan_context( original_context: Lifespan[Any], nested_context: Lifespan[Any] ) -> Lifespan[Any]: @@ -160,6 +213,30 @@ def _merge_lifespan_context( return merged_lifespan # type: ignore[return-value] +class _DefaultLifespan: + """ + Default lifespan context manager that runs on_startup and on_shutdown handlers. + + This is a copy of the Starlette _DefaultLifespan class that was removed + in Starlette. FastAPI keeps it to maintain backward compatibility with + on_startup and on_shutdown event handlers. + + Ref: https://github.com/Kludex/starlette/pull/3117 + """ + + def __init__(self, router: "APIRouter") -> None: + self._router = router + + async def __aenter__(self) -> None: + await self._router._startup() + + async def __aexit__(self, *exc_info: object) -> None: + await self._router._shutdown() + + def __call__(self: _T, app: object) -> _T: + return self + + # Cache for endpoint context to avoid re-extracting on every request _endpoint_context_cache: dict[int, EndpointContext] = {} @@ -903,13 +980,33 @@ class APIRouter(routing.Router): ), ] = Default(generate_unique_id), ) -> None: + # Handle on_startup/on_shutdown locally since Starlette removed support + # Ref: https://github.com/Kludex/starlette/pull/3117 + # TODO: deprecate this once the lifespan (or alternative) interface is improved + self.on_startup: list[Callable[[], Any]] = ( + [] if on_startup is None else list(on_startup) + ) + self.on_shutdown: list[Callable[[], Any]] = ( + [] if on_shutdown is None else list(on_shutdown) + ) + + # Determine the lifespan context to use + if lifespan is None: + # Use the default lifespan that runs on_startup/on_shutdown handlers + lifespan_context: Lifespan[Any] = _DefaultLifespan(self) + elif inspect.isasyncgenfunction(lifespan): + lifespan_context = asynccontextmanager(lifespan) + elif inspect.isgeneratorfunction(lifespan): + lifespan_context = _wrap_gen_lifespan_context(lifespan) + else: + lifespan_context = lifespan + self.lifespan_context = lifespan_context + super().__init__( routes=routes, redirect_slashes=redirect_slashes, default=default, - on_startup=on_startup, - on_shutdown=on_shutdown, - lifespan=lifespan, + lifespan=lifespan_context, ) if prefix: assert prefix.startswith("/"), "A path prefix must start with '/'" @@ -4473,6 +4570,58 @@ class APIRouter(routing.Router): generate_unique_id_function=generate_unique_id_function, ) + # TODO: remove this once the lifespan (or alternative) interface is improved + async def _startup(self) -> None: + """ + Run any `.on_startup` event handlers. + + This method is kept for backward compatibility after Starlette removed + support for on_startup/on_shutdown handlers. + + Ref: https://github.com/Kludex/starlette/pull/3117 + """ + for handler in self.on_startup: + if is_async_callable(handler): + await handler() + else: + handler() + + # TODO: remove this once the lifespan (or alternative) interface is improved + async def _shutdown(self) -> None: + """ + Run any `.on_shutdown` event handlers. + + This method is kept for backward compatibility after Starlette removed + support for on_startup/on_shutdown handlers. + + Ref: https://github.com/Kludex/starlette/pull/3117 + """ + for handler in self.on_shutdown: + if is_async_callable(handler): + await handler() + else: + handler() + + # TODO: remove this once the lifespan (or alternative) interface is improved + def add_event_handler( + self, + event_type: str, + func: Callable[[], Any], + ) -> None: + """ + Add an event handler function for startup or shutdown. + + This method is kept for backward compatibility after Starlette removed + support for on_startup/on_shutdown handlers. + + Ref: https://github.com/Kludex/starlette/pull/3117 + """ + assert event_type in ("startup", "shutdown") + if event_type == "startup": + self.on_startup.append(func) + else: + self.on_shutdown.append(func) + @deprecated( """ on_event is deprecated, use lifespan event handlers instead. diff --git a/tests/test_router_events.py b/tests/test_router_events.py index 9df299cdaa..65f2f521c1 100644 --- a/tests/test_router_events.py +++ b/tests/test_router_events.py @@ -241,3 +241,79 @@ def test_merged_mixed_state_lifespans() -> None: with TestClient(app) as client: assert client.app_state == {"router": True} + + +@pytest.mark.filterwarnings( + r"ignore:\s*on_event is deprecated, use lifespan event handlers instead.*:DeprecationWarning" +) +def test_router_async_shutdown_handler(state: State) -> None: + """Test that async on_shutdown event handlers are called correctly, for coverage.""" + app = FastAPI() + + @app.get("/") + def main() -> dict[str, str]: + return {"message": "Hello World"} + + @app.on_event("shutdown") + async def app_shutdown() -> None: + state.app_shutdown = True + + assert state.app_shutdown is False + with TestClient(app) as client: + assert state.app_shutdown is False + response = client.get("/") + assert response.status_code == 200, response.text + assert state.app_shutdown is True + + +def test_router_sync_generator_lifespan(state: State) -> None: + """Test that a sync generator lifespan works via _wrap_gen_lifespan_context.""" + from collections.abc import Generator + + def lifespan(app: FastAPI) -> Generator[None, None, None]: + state.app_startup = True + yield + state.app_shutdown = True + + app = FastAPI(lifespan=lifespan) # type: ignore[arg-type] + + @app.get("/") + def main() -> dict[str, str]: + return {"message": "Hello World"} + + assert state.app_startup is False + assert state.app_shutdown is False + with TestClient(app) as client: + assert state.app_startup is True + assert state.app_shutdown is False + response = client.get("/") + assert response.status_code == 200, response.text + assert response.json() == {"message": "Hello World"} + assert state.app_startup is True + assert state.app_shutdown is True + + +def test_router_async_generator_lifespan(state: State) -> None: + """Test that an async generator lifespan (not wrapped) works.""" + + async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + state.app_startup = True + yield + state.app_shutdown = True + + app = FastAPI(lifespan=lifespan) # type: ignore[arg-type] + + @app.get("/") + def main() -> dict[str, str]: + return {"message": "Hello World"} + + assert state.app_startup is False + assert state.app_shutdown is False + with TestClient(app) as client: + assert state.app_startup is True + assert state.app_shutdown is False + response = client.get("/") + assert response.status_code == 200, response.text + assert response.json() == {"message": "Hello World"} + assert state.app_startup is True + assert state.app_shutdown is True