♻️ Re-implement `on_event` in FastAPI for compatibility with the next Starlette, while keeping backwards compatibility (#14851)

This commit is contained in:
Sebastián Ramírez 2026-02-06 07:18:30 -08:00 committed by GitHub
parent 8e50c55fd9
commit f9f7992604
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 229 additions and 4 deletions

View File

@ -1,22 +1,31 @@
import contextlib
import email.message
import functools
import inspect
import json
import types
from collections.abc import (
AsyncIterator,
Awaitable,
Collection,
Coroutine,
Generator,
Mapping,
Sequence,
)
from contextlib import AsyncExitStack, asynccontextmanager
from contextlib import (
AbstractAsyncContextManager,
AbstractContextManager,
AsyncExitStack,
asynccontextmanager,
)
from enum import Enum, IntEnum
from typing import (
Annotated,
Any,
Callable,
Optional,
TypeVar,
Union,
)
@ -143,6 +152,50 @@ def websocket_session(
return app
_T = TypeVar("_T")
# Vendored from starlette.routing to avoid importing private symbols
class _AsyncLiftContextManager(AbstractAsyncContextManager[_T]):
"""
Wraps a synchronous context manager to make it async.
This is vendored from Starlette to avoid importing private symbols.
"""
def __init__(self, cm: AbstractContextManager[_T]) -> None:
self._cm = cm
async def __aenter__(self) -> _T:
return self._cm.__enter__()
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[types.TracebackType],
) -> Optional[bool]:
return self._cm.__exit__(exc_type, exc_value, traceback)
# Vendored from starlette.routing to avoid importing private symbols
def _wrap_gen_lifespan_context(
lifespan_context: Callable[[Any], Generator[Any, Any, Any]],
) -> Callable[[Any], AbstractAsyncContextManager[Any]]:
"""
Wrap a generator-based lifespan context into an async context manager.
This is vendored from Starlette to avoid importing private symbols.
"""
cmgr = contextlib.contextmanager(lifespan_context)
@functools.wraps(cmgr)
def wrapper(app: Any) -> _AsyncLiftContextManager[Any]:
return _AsyncLiftContextManager(cmgr(app))
return wrapper
def _merge_lifespan_context(
original_context: Lifespan[Any], nested_context: Lifespan[Any]
) -> Lifespan[Any]:
@ -160,6 +213,30 @@ def _merge_lifespan_context(
return merged_lifespan # type: ignore[return-value]
class _DefaultLifespan:
"""
Default lifespan context manager that runs on_startup and on_shutdown handlers.
This is a copy of the Starlette _DefaultLifespan class that was removed
in Starlette. FastAPI keeps it to maintain backward compatibility with
on_startup and on_shutdown event handlers.
Ref: https://github.com/Kludex/starlette/pull/3117
"""
def __init__(self, router: "APIRouter") -> None:
self._router = router
async def __aenter__(self) -> None:
await self._router._startup()
async def __aexit__(self, *exc_info: object) -> None:
await self._router._shutdown()
def __call__(self: _T, app: object) -> _T:
return self
# Cache for endpoint context to avoid re-extracting on every request
_endpoint_context_cache: dict[int, EndpointContext] = {}
@ -903,13 +980,33 @@ class APIRouter(routing.Router):
),
] = Default(generate_unique_id),
) -> None:
# Handle on_startup/on_shutdown locally since Starlette removed support
# Ref: https://github.com/Kludex/starlette/pull/3117
# TODO: deprecate this once the lifespan (or alternative) interface is improved
self.on_startup: list[Callable[[], Any]] = (
[] if on_startup is None else list(on_startup)
)
self.on_shutdown: list[Callable[[], Any]] = (
[] if on_shutdown is None else list(on_shutdown)
)
# Determine the lifespan context to use
if lifespan is None:
# Use the default lifespan that runs on_startup/on_shutdown handlers
lifespan_context: Lifespan[Any] = _DefaultLifespan(self)
elif inspect.isasyncgenfunction(lifespan):
lifespan_context = asynccontextmanager(lifespan)
elif inspect.isgeneratorfunction(lifespan):
lifespan_context = _wrap_gen_lifespan_context(lifespan)
else:
lifespan_context = lifespan
self.lifespan_context = lifespan_context
super().__init__(
routes=routes,
redirect_slashes=redirect_slashes,
default=default,
on_startup=on_startup,
on_shutdown=on_shutdown,
lifespan=lifespan,
lifespan=lifespan_context,
)
if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'"
@ -4473,6 +4570,58 @@ class APIRouter(routing.Router):
generate_unique_id_function=generate_unique_id_function,
)
# TODO: remove this once the lifespan (or alternative) interface is improved
async def _startup(self) -> None:
"""
Run any `.on_startup` event handlers.
This method is kept for backward compatibility after Starlette removed
support for on_startup/on_shutdown handlers.
Ref: https://github.com/Kludex/starlette/pull/3117
"""
for handler in self.on_startup:
if is_async_callable(handler):
await handler()
else:
handler()
# TODO: remove this once the lifespan (or alternative) interface is improved
async def _shutdown(self) -> None:
"""
Run any `.on_shutdown` event handlers.
This method is kept for backward compatibility after Starlette removed
support for on_startup/on_shutdown handlers.
Ref: https://github.com/Kludex/starlette/pull/3117
"""
for handler in self.on_shutdown:
if is_async_callable(handler):
await handler()
else:
handler()
# TODO: remove this once the lifespan (or alternative) interface is improved
def add_event_handler(
self,
event_type: str,
func: Callable[[], Any],
) -> None:
"""
Add an event handler function for startup or shutdown.
This method is kept for backward compatibility after Starlette removed
support for on_startup/on_shutdown handlers.
Ref: https://github.com/Kludex/starlette/pull/3117
"""
assert event_type in ("startup", "shutdown")
if event_type == "startup":
self.on_startup.append(func)
else:
self.on_shutdown.append(func)
@deprecated(
"""
on_event is deprecated, use lifespan event handlers instead.

View File

@ -241,3 +241,79 @@ def test_merged_mixed_state_lifespans() -> None:
with TestClient(app) as client:
assert client.app_state == {"router": True}
@pytest.mark.filterwarnings(
r"ignore:\s*on_event is deprecated, use lifespan event handlers instead.*:DeprecationWarning"
)
def test_router_async_shutdown_handler(state: State) -> None:
"""Test that async on_shutdown event handlers are called correctly, for coverage."""
app = FastAPI()
@app.get("/")
def main() -> dict[str, str]:
return {"message": "Hello World"}
@app.on_event("shutdown")
async def app_shutdown() -> None:
state.app_shutdown = True
assert state.app_shutdown is False
with TestClient(app) as client:
assert state.app_shutdown is False
response = client.get("/")
assert response.status_code == 200, response.text
assert state.app_shutdown is True
def test_router_sync_generator_lifespan(state: State) -> None:
"""Test that a sync generator lifespan works via _wrap_gen_lifespan_context."""
from collections.abc import Generator
def lifespan(app: FastAPI) -> Generator[None, None, None]:
state.app_startup = True
yield
state.app_shutdown = True
app = FastAPI(lifespan=lifespan) # type: ignore[arg-type]
@app.get("/")
def main() -> dict[str, str]:
return {"message": "Hello World"}
assert state.app_startup is False
assert state.app_shutdown is False
with TestClient(app) as client:
assert state.app_startup is True
assert state.app_shutdown is False
response = client.get("/")
assert response.status_code == 200, response.text
assert response.json() == {"message": "Hello World"}
assert state.app_startup is True
assert state.app_shutdown is True
def test_router_async_generator_lifespan(state: State) -> None:
"""Test that an async generator lifespan (not wrapped) works."""
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
state.app_startup = True
yield
state.app_shutdown = True
app = FastAPI(lifespan=lifespan) # type: ignore[arg-type]
@app.get("/")
def main() -> dict[str, str]:
return {"message": "Hello World"}
assert state.app_startup is False
assert state.app_shutdown is False
with TestClient(app) as client:
assert state.app_startup is True
assert state.app_shutdown is False
response = client.get("/")
assert response.status_code == 200, response.text
assert response.json() == {"message": "Hello World"}
assert state.app_startup is True
assert state.app_shutdown is True