mirror of https://github.com/tiangolo/fastapi.git
♻️ Re-implement `on_event` in FastAPI for compatibility with the next Starlette, while keeping backwards compatibility (#14851)
This commit is contained in:
parent
8e50c55fd9
commit
f9f7992604
|
|
@ -1,22 +1,31 @@
|
|||
import contextlib
|
||||
import email.message
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import types
|
||||
from collections.abc import (
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Collection,
|
||||
Coroutine,
|
||||
Generator,
|
||||
Mapping,
|
||||
Sequence,
|
||||
)
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from contextlib import (
|
||||
AbstractAsyncContextManager,
|
||||
AbstractContextManager,
|
||||
AsyncExitStack,
|
||||
asynccontextmanager,
|
||||
)
|
||||
from enum import Enum, IntEnum
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
|
|
@ -143,6 +152,50 @@ def websocket_session(
|
|||
return app
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
# Vendored from starlette.routing to avoid importing private symbols
|
||||
class _AsyncLiftContextManager(AbstractAsyncContextManager[_T]):
|
||||
"""
|
||||
Wraps a synchronous context manager to make it async.
|
||||
|
||||
This is vendored from Starlette to avoid importing private symbols.
|
||||
"""
|
||||
|
||||
def __init__(self, cm: AbstractContextManager[_T]) -> None:
|
||||
self._cm = cm
|
||||
|
||||
async def __aenter__(self) -> _T:
|
||||
return self._cm.__enter__()
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[types.TracebackType],
|
||||
) -> Optional[bool]:
|
||||
return self._cm.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
|
||||
# Vendored from starlette.routing to avoid importing private symbols
|
||||
def _wrap_gen_lifespan_context(
|
||||
lifespan_context: Callable[[Any], Generator[Any, Any, Any]],
|
||||
) -> Callable[[Any], AbstractAsyncContextManager[Any]]:
|
||||
"""
|
||||
Wrap a generator-based lifespan context into an async context manager.
|
||||
|
||||
This is vendored from Starlette to avoid importing private symbols.
|
||||
"""
|
||||
cmgr = contextlib.contextmanager(lifespan_context)
|
||||
|
||||
@functools.wraps(cmgr)
|
||||
def wrapper(app: Any) -> _AsyncLiftContextManager[Any]:
|
||||
return _AsyncLiftContextManager(cmgr(app))
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _merge_lifespan_context(
|
||||
original_context: Lifespan[Any], nested_context: Lifespan[Any]
|
||||
) -> Lifespan[Any]:
|
||||
|
|
@ -160,6 +213,30 @@ def _merge_lifespan_context(
|
|||
return merged_lifespan # type: ignore[return-value]
|
||||
|
||||
|
||||
class _DefaultLifespan:
|
||||
"""
|
||||
Default lifespan context manager that runs on_startup and on_shutdown handlers.
|
||||
|
||||
This is a copy of the Starlette _DefaultLifespan class that was removed
|
||||
in Starlette. FastAPI keeps it to maintain backward compatibility with
|
||||
on_startup and on_shutdown event handlers.
|
||||
|
||||
Ref: https://github.com/Kludex/starlette/pull/3117
|
||||
"""
|
||||
|
||||
def __init__(self, router: "APIRouter") -> None:
|
||||
self._router = router
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
await self._router._startup()
|
||||
|
||||
async def __aexit__(self, *exc_info: object) -> None:
|
||||
await self._router._shutdown()
|
||||
|
||||
def __call__(self: _T, app: object) -> _T:
|
||||
return self
|
||||
|
||||
|
||||
# Cache for endpoint context to avoid re-extracting on every request
|
||||
_endpoint_context_cache: dict[int, EndpointContext] = {}
|
||||
|
||||
|
|
@ -903,13 +980,33 @@ class APIRouter(routing.Router):
|
|||
),
|
||||
] = Default(generate_unique_id),
|
||||
) -> None:
|
||||
# Handle on_startup/on_shutdown locally since Starlette removed support
|
||||
# Ref: https://github.com/Kludex/starlette/pull/3117
|
||||
# TODO: deprecate this once the lifespan (or alternative) interface is improved
|
||||
self.on_startup: list[Callable[[], Any]] = (
|
||||
[] if on_startup is None else list(on_startup)
|
||||
)
|
||||
self.on_shutdown: list[Callable[[], Any]] = (
|
||||
[] if on_shutdown is None else list(on_shutdown)
|
||||
)
|
||||
|
||||
# Determine the lifespan context to use
|
||||
if lifespan is None:
|
||||
# Use the default lifespan that runs on_startup/on_shutdown handlers
|
||||
lifespan_context: Lifespan[Any] = _DefaultLifespan(self)
|
||||
elif inspect.isasyncgenfunction(lifespan):
|
||||
lifespan_context = asynccontextmanager(lifespan)
|
||||
elif inspect.isgeneratorfunction(lifespan):
|
||||
lifespan_context = _wrap_gen_lifespan_context(lifespan)
|
||||
else:
|
||||
lifespan_context = lifespan
|
||||
self.lifespan_context = lifespan_context
|
||||
|
||||
super().__init__(
|
||||
routes=routes,
|
||||
redirect_slashes=redirect_slashes,
|
||||
default=default,
|
||||
on_startup=on_startup,
|
||||
on_shutdown=on_shutdown,
|
||||
lifespan=lifespan,
|
||||
lifespan=lifespan_context,
|
||||
)
|
||||
if prefix:
|
||||
assert prefix.startswith("/"), "A path prefix must start with '/'"
|
||||
|
|
@ -4473,6 +4570,58 @@ class APIRouter(routing.Router):
|
|||
generate_unique_id_function=generate_unique_id_function,
|
||||
)
|
||||
|
||||
# TODO: remove this once the lifespan (or alternative) interface is improved
|
||||
async def _startup(self) -> None:
|
||||
"""
|
||||
Run any `.on_startup` event handlers.
|
||||
|
||||
This method is kept for backward compatibility after Starlette removed
|
||||
support for on_startup/on_shutdown handlers.
|
||||
|
||||
Ref: https://github.com/Kludex/starlette/pull/3117
|
||||
"""
|
||||
for handler in self.on_startup:
|
||||
if is_async_callable(handler):
|
||||
await handler()
|
||||
else:
|
||||
handler()
|
||||
|
||||
# TODO: remove this once the lifespan (or alternative) interface is improved
|
||||
async def _shutdown(self) -> None:
|
||||
"""
|
||||
Run any `.on_shutdown` event handlers.
|
||||
|
||||
This method is kept for backward compatibility after Starlette removed
|
||||
support for on_startup/on_shutdown handlers.
|
||||
|
||||
Ref: https://github.com/Kludex/starlette/pull/3117
|
||||
"""
|
||||
for handler in self.on_shutdown:
|
||||
if is_async_callable(handler):
|
||||
await handler()
|
||||
else:
|
||||
handler()
|
||||
|
||||
# TODO: remove this once the lifespan (or alternative) interface is improved
|
||||
def add_event_handler(
|
||||
self,
|
||||
event_type: str,
|
||||
func: Callable[[], Any],
|
||||
) -> None:
|
||||
"""
|
||||
Add an event handler function for startup or shutdown.
|
||||
|
||||
This method is kept for backward compatibility after Starlette removed
|
||||
support for on_startup/on_shutdown handlers.
|
||||
|
||||
Ref: https://github.com/Kludex/starlette/pull/3117
|
||||
"""
|
||||
assert event_type in ("startup", "shutdown")
|
||||
if event_type == "startup":
|
||||
self.on_startup.append(func)
|
||||
else:
|
||||
self.on_shutdown.append(func)
|
||||
|
||||
@deprecated(
|
||||
"""
|
||||
on_event is deprecated, use lifespan event handlers instead.
|
||||
|
|
|
|||
|
|
@ -241,3 +241,79 @@ def test_merged_mixed_state_lifespans() -> None:
|
|||
|
||||
with TestClient(app) as client:
|
||||
assert client.app_state == {"router": True}
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings(
|
||||
r"ignore:\s*on_event is deprecated, use lifespan event handlers instead.*:DeprecationWarning"
|
||||
)
|
||||
def test_router_async_shutdown_handler(state: State) -> None:
|
||||
"""Test that async on_shutdown event handlers are called correctly, for coverage."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/")
|
||||
def main() -> dict[str, str]:
|
||||
return {"message": "Hello World"}
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def app_shutdown() -> None:
|
||||
state.app_shutdown = True
|
||||
|
||||
assert state.app_shutdown is False
|
||||
with TestClient(app) as client:
|
||||
assert state.app_shutdown is False
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200, response.text
|
||||
assert state.app_shutdown is True
|
||||
|
||||
|
||||
def test_router_sync_generator_lifespan(state: State) -> None:
|
||||
"""Test that a sync generator lifespan works via _wrap_gen_lifespan_context."""
|
||||
from collections.abc import Generator
|
||||
|
||||
def lifespan(app: FastAPI) -> Generator[None, None, None]:
|
||||
state.app_startup = True
|
||||
yield
|
||||
state.app_shutdown = True
|
||||
|
||||
app = FastAPI(lifespan=lifespan) # type: ignore[arg-type]
|
||||
|
||||
@app.get("/")
|
||||
def main() -> dict[str, str]:
|
||||
return {"message": "Hello World"}
|
||||
|
||||
assert state.app_startup is False
|
||||
assert state.app_shutdown is False
|
||||
with TestClient(app) as client:
|
||||
assert state.app_startup is True
|
||||
assert state.app_shutdown is False
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"message": "Hello World"}
|
||||
assert state.app_startup is True
|
||||
assert state.app_shutdown is True
|
||||
|
||||
|
||||
def test_router_async_generator_lifespan(state: State) -> None:
|
||||
"""Test that an async generator lifespan (not wrapped) works."""
|
||||
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
state.app_startup = True
|
||||
yield
|
||||
state.app_shutdown = True
|
||||
|
||||
app = FastAPI(lifespan=lifespan) # type: ignore[arg-type]
|
||||
|
||||
@app.get("/")
|
||||
def main() -> dict[str, str]:
|
||||
return {"message": "Hello World"}
|
||||
|
||||
assert state.app_startup is False
|
||||
assert state.app_shutdown is False
|
||||
with TestClient(app) as client:
|
||||
assert state.app_startup is True
|
||||
assert state.app_shutdown is False
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"message": "Hello World"}
|
||||
assert state.app_startup is True
|
||||
assert state.app_shutdown is True
|
||||
|
|
|
|||
Loading…
Reference in New Issue