mirror of https://github.com/tiangolo/fastapi.git
🐛 Ensure that `app.include_router` merges nested lifespans (#9630)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com> Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
parent
22bf988dfb
commit
3a4ac24675
|
|
@ -3,14 +3,16 @@ import dataclasses
|
|||
import email.message
|
||||
import inspect
|
||||
import json
|
||||
from contextlib import AsyncExitStack
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from enum import Enum, IntEnum
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
|
|
@ -67,7 +69,7 @@ from starlette.routing import (
|
|||
websocket_session,
|
||||
)
|
||||
from starlette.routing import Mount as Mount # noqa
|
||||
from starlette.types import ASGIApp, Lifespan, Scope
|
||||
from starlette.types import AppType, ASGIApp, Lifespan, Scope
|
||||
from starlette.websockets import WebSocket
|
||||
from typing_extensions import Annotated, Doc, deprecated
|
||||
|
||||
|
|
@ -119,6 +121,23 @@ def _prepare_response_content(
|
|||
return res
|
||||
|
||||
|
||||
def _merge_lifespan_context(
|
||||
original_context: Lifespan[Any], nested_context: Lifespan[Any]
|
||||
) -> Lifespan[Any]:
|
||||
@asynccontextmanager
|
||||
async def merged_lifespan(
|
||||
app: AppType,
|
||||
) -> AsyncIterator[Optional[Mapping[str, Any]]]:
|
||||
async with original_context(app) as maybe_original_state:
|
||||
async with nested_context(app) as maybe_nested_state:
|
||||
if maybe_nested_state is None and maybe_original_state is None:
|
||||
yield None # old ASGI compatibility
|
||||
else:
|
||||
yield {**(maybe_nested_state or {}), **(maybe_original_state or {})}
|
||||
|
||||
return merged_lifespan # type: ignore[return-value]
|
||||
|
||||
|
||||
async def serialize_response(
|
||||
*,
|
||||
field: Optional[ModelField] = None,
|
||||
|
|
@ -1308,6 +1327,10 @@ class APIRouter(routing.Router):
|
|||
self.add_event_handler("startup", handler)
|
||||
for handler in router.on_shutdown:
|
||||
self.add_event_handler("shutdown", handler)
|
||||
self.lifespan_context = _merge_lifespan_context(
|
||||
self.lifespan_context,
|
||||
router.lifespan_context,
|
||||
)
|
||||
|
||||
def get(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator, Dict
|
||||
from typing import AsyncGenerator, Dict, Union
|
||||
|
||||
import pytest
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from fastapi import APIRouter, FastAPI, Request
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -109,3 +109,134 @@ def test_app_lifespan_state(state: State) -> None:
|
|||
assert response.json() == {"message": "Hello World"}
|
||||
assert state.app_startup is True
|
||||
assert state.app_shutdown is True
|
||||
|
||||
|
||||
def test_router_nested_lifespan_state(state: State) -> None:
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, bool], None]:
|
||||
state.app_startup = True
|
||||
yield {"app": True}
|
||||
state.app_shutdown = True
|
||||
|
||||
@asynccontextmanager
|
||||
async def router_lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, bool], None]:
|
||||
state.router_startup = True
|
||||
yield {"router": True}
|
||||
state.router_shutdown = True
|
||||
|
||||
@asynccontextmanager
|
||||
async def subrouter_lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, bool], None]:
|
||||
state.sub_router_startup = True
|
||||
yield {"sub_router": True}
|
||||
state.sub_router_shutdown = True
|
||||
|
||||
sub_router = APIRouter(lifespan=subrouter_lifespan)
|
||||
|
||||
router = APIRouter(lifespan=router_lifespan)
|
||||
router.include_router(sub_router)
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.include_router(router)
|
||||
|
||||
@app.get("/")
|
||||
def main(request: Request) -> Dict[str, str]:
|
||||
assert request.state.app
|
||||
assert request.state.router
|
||||
assert request.state.sub_router
|
||||
return {"message": "Hello World"}
|
||||
|
||||
assert state.app_startup is False
|
||||
assert state.router_startup is False
|
||||
assert state.sub_router_startup is False
|
||||
assert state.app_shutdown is False
|
||||
assert state.router_shutdown is False
|
||||
assert state.sub_router_shutdown is False
|
||||
|
||||
with TestClient(app) as client:
|
||||
assert state.app_startup is True
|
||||
assert state.router_startup is True
|
||||
assert state.sub_router_startup is True
|
||||
assert state.app_shutdown is False
|
||||
assert state.router_shutdown is False
|
||||
assert state.sub_router_shutdown is False
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"message": "Hello World"}
|
||||
|
||||
assert state.app_startup is True
|
||||
assert state.router_startup is True
|
||||
assert state.sub_router_startup is True
|
||||
assert state.app_shutdown is True
|
||||
assert state.router_shutdown is True
|
||||
assert state.sub_router_shutdown is True
|
||||
|
||||
|
||||
def test_router_nested_lifespan_state_overriding_by_parent() -> None:
|
||||
@asynccontextmanager
|
||||
async def lifespan(
|
||||
app: FastAPI,
|
||||
) -> AsyncGenerator[Dict[str, Union[str, bool]], None]:
|
||||
yield {
|
||||
"app_specific": True,
|
||||
"overridden": "app",
|
||||
}
|
||||
|
||||
@asynccontextmanager
|
||||
async def router_lifespan(
|
||||
app: FastAPI,
|
||||
) -> AsyncGenerator[Dict[str, Union[str, bool]], None]:
|
||||
yield {
|
||||
"router_specific": True,
|
||||
"overridden": "router", # should override parent
|
||||
}
|
||||
|
||||
router = APIRouter(lifespan=router_lifespan)
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.include_router(router)
|
||||
|
||||
with TestClient(app) as client:
|
||||
assert client.app_state == {
|
||||
"app_specific": True,
|
||||
"router_specific": True,
|
||||
"overridden": "app",
|
||||
}
|
||||
|
||||
|
||||
def test_merged_no_return_lifespans_return_none() -> None:
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
yield
|
||||
|
||||
@asynccontextmanager
|
||||
async def router_lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
yield
|
||||
|
||||
router = APIRouter(lifespan=router_lifespan)
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.include_router(router)
|
||||
|
||||
with TestClient(app) as client:
|
||||
assert not client.app_state
|
||||
|
||||
|
||||
def test_merged_mixed_state_lifespans() -> None:
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
yield
|
||||
|
||||
@asynccontextmanager
|
||||
async def router_lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, bool], None]:
|
||||
yield {"router": True}
|
||||
|
||||
@asynccontextmanager
|
||||
async def sub_router_lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
yield
|
||||
|
||||
sub_router = APIRouter(lifespan=sub_router_lifespan)
|
||||
router = APIRouter(lifespan=router_lifespan)
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
router.include_router(sub_router)
|
||||
app.include_router(router)
|
||||
|
||||
with TestClient(app) as client:
|
||||
assert client.app_state == {"router": True}
|
||||
|
|
|
|||
Loading…
Reference in New Issue