mirror of https://github.com/tiangolo/fastapi.git
Merge 6215c768f2 into eb6851dd4b
This commit is contained in:
commit
c8e1e2a8d7
|
|
@ -1,4 +1,6 @@
|
||||||
|
import inspect
|
||||||
from collections.abc import Awaitable, Callable, Coroutine, Sequence
|
from collections.abc import Awaitable, Callable, Coroutine, Sequence
|
||||||
|
from contextlib import AsyncExitStack, asynccontextmanager
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Annotated, Any, TypeVar
|
from typing import Annotated, Any, TypeVar
|
||||||
|
|
||||||
|
|
@ -37,6 +39,50 @@ from typing_extensions import deprecated
|
||||||
|
|
||||||
AppType = TypeVar("AppType", bound="FastAPI")
|
AppType = TypeVar("AppType", bound="FastAPI")
|
||||||
|
|
||||||
|
# Attribute name on the router used to run lifespan-scoped dependencies at startup.
|
||||||
|
FASTAPI_LIFESPAN_DEPENDENCY_CACHE = "fastapi_lifespan_dependency_cache"
|
||||||
|
|
||||||
|
|
||||||
|
def _wrap_lifespan_with_dependency_cache(original: Any) -> Any:
|
||||||
|
"""Wrap the user's lifespan to run and cache lifespan-scoped dependencies."""
|
||||||
|
|
||||||
|
def wrapped(app: Any) -> Any:
|
||||||
|
@asynccontextmanager
|
||||||
|
async def cm() -> Any:
|
||||||
|
fastapi_app = getattr(app, "_fastapi_app", None)
|
||||||
|
if fastapi_app is None and hasattr(app, "router"):
|
||||||
|
router = getattr(app, "router", None)
|
||||||
|
if router is not None and getattr(router, "_fastapi_app", None) is app:
|
||||||
|
fastapi_app = app
|
||||||
|
router_for_deps = getattr(app, "router", app)
|
||||||
|
stack: AsyncExitStack | None = None
|
||||||
|
orig_cm = original(app)
|
||||||
|
try:
|
||||||
|
if fastapi_app is not None:
|
||||||
|
stack = AsyncExitStack()
|
||||||
|
await stack.__aenter__()
|
||||||
|
cache: dict[Any, Any] = {}
|
||||||
|
await routing._run_lifespan_dependencies(
|
||||||
|
router_for_deps, cache, stack
|
||||||
|
)
|
||||||
|
setattr(
|
||||||
|
fastapi_app.state,
|
||||||
|
FASTAPI_LIFESPAN_DEPENDENCY_CACHE,
|
||||||
|
cache,
|
||||||
|
)
|
||||||
|
yield await orig_cm.__aenter__()
|
||||||
|
finally:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
exc_type, exc_val, exc_tb = sys.exc_info()
|
||||||
|
await orig_cm.__aexit__(exc_type, exc_val, exc_tb)
|
||||||
|
if stack is not None:
|
||||||
|
await stack.__aexit__(exc_type, exc_val, exc_tb)
|
||||||
|
|
||||||
|
return cm()
|
||||||
|
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
class FastAPI(Starlette):
|
class FastAPI(Starlette):
|
||||||
"""
|
"""
|
||||||
|
|
@ -979,13 +1025,27 @@ class FastAPI(Starlette):
|
||||||
"""
|
"""
|
||||||
),
|
),
|
||||||
] = {}
|
] = {}
|
||||||
|
_inner_lifespan: Callable[[Any], Any]
|
||||||
|
if lifespan is None:
|
||||||
|
|
||||||
|
def _default_lifespan(app: Any) -> Any:
|
||||||
|
return routing._DefaultLifespan(app.router)
|
||||||
|
|
||||||
|
_inner_lifespan = _default_lifespan
|
||||||
|
elif inspect.isasyncgenfunction(lifespan):
|
||||||
|
_inner_lifespan = asynccontextmanager(lifespan)
|
||||||
|
elif inspect.isgeneratorfunction(lifespan):
|
||||||
|
_inner_lifespan = routing._wrap_gen_lifespan_context(lifespan)
|
||||||
|
else:
|
||||||
|
_inner_lifespan = lifespan
|
||||||
|
_lifespan = _wrap_lifespan_with_dependency_cache(_inner_lifespan)
|
||||||
self.router: routing.APIRouter = routing.APIRouter(
|
self.router: routing.APIRouter = routing.APIRouter(
|
||||||
routes=routes,
|
routes=routes,
|
||||||
redirect_slashes=redirect_slashes,
|
redirect_slashes=redirect_slashes,
|
||||||
dependency_overrides_provider=self,
|
dependency_overrides_provider=self,
|
||||||
on_startup=on_startup,
|
on_startup=on_startup,
|
||||||
on_shutdown=on_shutdown,
|
on_shutdown=on_shutdown,
|
||||||
lifespan=lifespan,
|
lifespan=_lifespan,
|
||||||
default_response_class=default_response_class,
|
default_response_class=default_response_class,
|
||||||
dependencies=dependencies,
|
dependencies=dependencies,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
|
|
@ -995,6 +1055,7 @@ class FastAPI(Starlette):
|
||||||
generate_unique_id_function=generate_unique_id_function,
|
generate_unique_id_function=generate_unique_id_function,
|
||||||
strict_content_type=strict_content_type,
|
strict_content_type=strict_content_type,
|
||||||
)
|
)
|
||||||
|
self.router._fastapi_app = self # type: ignore[attr-defined]
|
||||||
self.exception_handlers: dict[
|
self.exception_handlers: dict[
|
||||||
Any, Callable[[Request, Any], Response | Awaitable[Response]]
|
Any, Callable[[Request, Any], Response | Awaitable[Response]]
|
||||||
] = {} if exception_handlers is None else dict(exception_handlers)
|
] = {} if exception_handlers is None else dict(exception_handlers)
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,15 @@
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
cli_main: Callable[[], None] | None = None
|
||||||
try:
|
try:
|
||||||
from fastapi_cli.cli import main as cli_main
|
from fastapi_cli.cli import main as cli_main
|
||||||
|
|
||||||
except ImportError: # pragma: no cover
|
except ImportError: # pragma: no cover
|
||||||
cli_main = None # type: ignore
|
pass
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
if not cli_main: # type: ignore[truthy-function] # ty: ignore[unused-ignore-comment]
|
if cli_main is None:
|
||||||
message = 'To use the fastapi command, please install "fastapi[standard]":\n\n\tpip install "fastapi[standard]"\n'
|
message = 'To use the fastapi command, please install "fastapi[standard]":\n\n\tpip install "fastapi[standard]"\n'
|
||||||
print(message)
|
print(message)
|
||||||
raise RuntimeError(message) # noqa: B904
|
raise RuntimeError(message) # noqa: B904
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,7 @@ class Dependant:
|
||||||
parent_oauth_scopes: list[str] | None = None
|
parent_oauth_scopes: list[str] | None = None
|
||||||
use_cache: bool = True
|
use_cache: bool = True
|
||||||
path: str | None = None
|
path: str | None = None
|
||||||
scope: Literal["function", "request"] | None = None
|
scope: Literal["function", "request", "lifespan"] | None = None
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def oauth_scopes(self) -> list[str]:
|
def oauth_scopes(self) -> list[str]:
|
||||||
|
|
|
||||||
|
|
@ -216,7 +216,7 @@ def _get_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||||
except NameError:
|
except NameError:
|
||||||
# Handle type annotations with if TYPE_CHECKING, not used by FastAPI
|
# Handle type annotations with if TYPE_CHECKING, not used by FastAPI
|
||||||
# e.g. dependency return types
|
# e.g. dependency return types
|
||||||
if sys.version_info >= (3, 14):
|
if sys.version_info >= (3, 14): # pragma: no cover
|
||||||
from annotationlib import Format
|
from annotationlib import Format
|
||||||
|
|
||||||
signature = inspect.signature(call, annotation_format=Format.FORWARDREF)
|
signature = inspect.signature(call, annotation_format=Format.FORWARDREF)
|
||||||
|
|
@ -291,7 +291,7 @@ def get_dependant(
|
||||||
own_oauth_scopes: list[str] | None = None,
|
own_oauth_scopes: list[str] | None = None,
|
||||||
parent_oauth_scopes: list[str] | None = None,
|
parent_oauth_scopes: list[str] | None = None,
|
||||||
use_cache: bool = True,
|
use_cache: bool = True,
|
||||||
scope: Literal["function", "request"] | None = None,
|
scope: Literal["function", "request", "lifespan"] | None = None,
|
||||||
) -> Dependant:
|
) -> Dependant:
|
||||||
dependant = Dependant(
|
dependant = Dependant(
|
||||||
call=call,
|
call=call,
|
||||||
|
|
@ -327,6 +327,22 @@ def get_dependant(
|
||||||
f'The dependency "{call_name}" has a scope of '
|
f'The dependency "{call_name}" has a scope of '
|
||||||
'"request", it cannot depend on dependencies with scope "function".'
|
'"request", it cannot depend on dependencies with scope "function".'
|
||||||
)
|
)
|
||||||
|
# Lifespan-scoped dependencies can only depend on other lifespan-scoped deps.
|
||||||
|
if (
|
||||||
|
dependant.computed_scope == "lifespan"
|
||||||
|
and param_details.depends.scope
|
||||||
|
not in (
|
||||||
|
None,
|
||||||
|
"lifespan",
|
||||||
|
)
|
||||||
|
):
|
||||||
|
assert dependant.call
|
||||||
|
call_name = getattr(dependant.call, "__name__", "<unnamed_callable>")
|
||||||
|
raise DependencyScopeError(
|
||||||
|
f'The dependency "{call_name}" has a scope of '
|
||||||
|
'"lifespan", it cannot depend on dependencies with scope '
|
||||||
|
f'"{param_details.depends.scope}".'
|
||||||
|
)
|
||||||
sub_own_oauth_scopes: list[str] = []
|
sub_own_oauth_scopes: list[str] = []
|
||||||
if isinstance(param_details.depends, params.Security):
|
if isinstance(param_details.depends, params.Security):
|
||||||
if param_details.depends.scopes:
|
if param_details.depends.scopes:
|
||||||
|
|
@ -608,6 +624,7 @@ async def solve_dependencies(
|
||||||
# people might be monkey patching this function (although that's not supported)
|
# people might be monkey patching this function (although that's not supported)
|
||||||
async_exit_stack: AsyncExitStack,
|
async_exit_stack: AsyncExitStack,
|
||||||
embed_body_fields: bool,
|
embed_body_fields: bool,
|
||||||
|
solving_lifespan_deps: bool = False,
|
||||||
) -> SolvedDependency:
|
) -> SolvedDependency:
|
||||||
request_astack = request.scope.get("fastapi_inner_astack")
|
request_astack = request.scope.get("fastapi_inner_astack")
|
||||||
assert isinstance(request_astack, AsyncExitStack), (
|
assert isinstance(request_astack, AsyncExitStack), (
|
||||||
|
|
@ -656,6 +673,7 @@ async def solve_dependencies(
|
||||||
dependency_cache=dependency_cache,
|
dependency_cache=dependency_cache,
|
||||||
async_exit_stack=async_exit_stack,
|
async_exit_stack=async_exit_stack,
|
||||||
embed_body_fields=embed_body_fields,
|
embed_body_fields=embed_body_fields,
|
||||||
|
solving_lifespan_deps=solving_lifespan_deps,
|
||||||
)
|
)
|
||||||
background_tasks = solved_result.background_tasks
|
background_tasks = solved_result.background_tasks
|
||||||
if solved_result.errors:
|
if solved_result.errors:
|
||||||
|
|
@ -663,6 +681,30 @@ async def solve_dependencies(
|
||||||
continue
|
continue
|
||||||
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
|
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
|
||||||
solved = dependency_cache[sub_dependant.cache_key]
|
solved = dependency_cache[sub_dependant.cache_key]
|
||||||
|
elif sub_dependant.computed_scope == "lifespan":
|
||||||
|
# At request time, lifespan deps must come from cache (set at startup).
|
||||||
|
if sub_dependant.cache_key in dependency_cache:
|
||||||
|
solved = dependency_cache[sub_dependant.cache_key] # pragma: no cover
|
||||||
|
elif solving_lifespan_deps:
|
||||||
|
# At startup: run the lifespan dep; request_astack is the lifespan stack.
|
||||||
|
if (
|
||||||
|
use_sub_dependant.is_gen_callable
|
||||||
|
or use_sub_dependant.is_async_gen_callable
|
||||||
|
):
|
||||||
|
solved = await _solve_generator(
|
||||||
|
dependant=use_sub_dependant,
|
||||||
|
stack=request_astack,
|
||||||
|
sub_values=solved_result.values,
|
||||||
|
)
|
||||||
|
elif use_sub_dependant.is_coroutine_callable:
|
||||||
|
solved = await call(**solved_result.values)
|
||||||
|
else:
|
||||||
|
solved = await run_in_threadpool(call, **solved_result.values)
|
||||||
|
else:
|
||||||
|
raise DependencyScopeError(
|
||||||
|
"Lifespan-scoped dependency was not initialized at application startup. "
|
||||||
|
"Ensure the application lifespan runs and populates lifespan dependencies."
|
||||||
|
)
|
||||||
elif (
|
elif (
|
||||||
use_sub_dependant.is_gen_callable or use_sub_dependant.is_async_gen_callable
|
use_sub_dependant.is_gen_callable or use_sub_dependant.is_async_gen_callable
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@ -2314,7 +2314,7 @@ def Depends( # noqa: N802
|
||||||
),
|
),
|
||||||
] = True,
|
] = True,
|
||||||
scope: Annotated[
|
scope: Annotated[
|
||||||
Literal["function", "request"] | None,
|
Literal["function", "request", "lifespan"] | None,
|
||||||
Doc(
|
Doc(
|
||||||
"""
|
"""
|
||||||
Mainly for dependencies with `yield`, define when the dependency function
|
Mainly for dependencies with `yield`, define when the dependency function
|
||||||
|
|
@ -2330,6 +2330,10 @@ def Depends( # noqa: N802
|
||||||
that handles the request (similar to when using `"function"`), but end
|
that handles the request (similar to when using `"function"`), but end
|
||||||
**after** the response is sent back to the client. So, the dependency
|
**after** the response is sent back to the client. So, the dependency
|
||||||
function will be executed **around** the **request** and response cycle.
|
function will be executed **around** the **request** and response cycle.
|
||||||
|
* `"lifespan"`: the dependency is evaluated **once** when the application
|
||||||
|
starts and the same value is reused for every request. It is cleaned up
|
||||||
|
when the application shuts down. Use this for resources like database
|
||||||
|
connection pools that should live for the application lifetime.
|
||||||
|
|
||||||
Read more about it in the
|
Read more about it in the
|
||||||
[FastAPI docs for FastAPI Dependencies with yield](https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-with-yield/#early-exit-and-scope)
|
[FastAPI docs for FastAPI Dependencies with yield](https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-with-yield/#early-exit-and-scope)
|
||||||
|
|
|
||||||
|
|
@ -746,7 +746,7 @@ class File(Form): # type: ignore[misc] # ty: ignore[unused-ignore-comment]
|
||||||
class Depends:
|
class Depends:
|
||||||
dependency: Callable[..., Any] | None = None
|
dependency: Callable[..., Any] | None = None
|
||||||
use_cache: bool = True
|
use_cache: bool = True
|
||||||
scope: Literal["function", "request"] | None = None
|
scope: Literal["function", "request", "lifespan"] | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|
|
||||||
|
|
@ -206,6 +206,61 @@ def _wrap_gen_lifespan_context(
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_lifespan_dependants(router: "APIRouter") -> list[Dependant]:
|
||||||
|
"""Collect all unique lifespan-scoped dependants from router and nested routers."""
|
||||||
|
seen: dict[tuple[Any, ...], Dependant] = {}
|
||||||
|
for route in router.routes:
|
||||||
|
if isinstance(route, APIRoute):
|
||||||
|
flat = get_flat_dependant(route.dependant)
|
||||||
|
for d in flat.dependencies:
|
||||||
|
if d.computed_scope == "lifespan":
|
||||||
|
key = d.cache_key
|
||||||
|
if key not in seen:
|
||||||
|
seen[key] = d
|
||||||
|
return list(seen.values())
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_lifespan_dependencies(
|
||||||
|
router: "APIRouter",
|
||||||
|
dependency_cache: dict[tuple[Any, ...], Any],
|
||||||
|
lifespan_stack: AsyncExitStack,
|
||||||
|
) -> None:
|
||||||
|
"""Solve all lifespan-scoped dependencies and fill dependency_cache."""
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
lifespan_deps = _collect_lifespan_dependants(router)
|
||||||
|
if not lifespan_deps:
|
||||||
|
return
|
||||||
|
synthetic = Dependant(call=None, path="/", dependencies=lifespan_deps)
|
||||||
|
# Minimal scope so solve_dependencies can run; lifespan_stack used for cleanup.
|
||||||
|
scope: dict[str, Any] = {
|
||||||
|
"type": "http",
|
||||||
|
"path": "/",
|
||||||
|
"path_params": {},
|
||||||
|
"query_string": b"",
|
||||||
|
"headers": [],
|
||||||
|
"fastapi_inner_astack": lifespan_stack,
|
||||||
|
"fastapi_function_astack": lifespan_stack,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def noop_receive() -> Any:
|
||||||
|
return {"type": "http.disconnect"}
|
||||||
|
|
||||||
|
async def noop_send(message: Any) -> None: # pragma: no cover
|
||||||
|
pass # ASGI send not used by lifespan dependency resolution
|
||||||
|
|
||||||
|
request = Request(scope, noop_receive, noop_send)
|
||||||
|
await solve_dependencies(
|
||||||
|
request=request,
|
||||||
|
dependant=synthetic,
|
||||||
|
body=None,
|
||||||
|
dependency_cache=dependency_cache,
|
||||||
|
async_exit_stack=lifespan_stack,
|
||||||
|
embed_body_fields=False,
|
||||||
|
solving_lifespan_deps=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _merge_lifespan_context(
|
def _merge_lifespan_context(
|
||||||
original_context: Lifespan[Any], nested_context: Lifespan[Any]
|
original_context: Lifespan[Any], nested_context: Lifespan[Any]
|
||||||
) -> Lifespan[Any]:
|
) -> Lifespan[Any]:
|
||||||
|
|
@ -454,11 +509,16 @@ def get_request_handler(
|
||||||
assert isinstance(async_exit_stack, AsyncExitStack), (
|
assert isinstance(async_exit_stack, AsyncExitStack), (
|
||||||
"fastapi_inner_astack not found in request scope"
|
"fastapi_inner_astack not found in request scope"
|
||||||
)
|
)
|
||||||
|
lifespan_cache = getattr(
|
||||||
|
request.app.state, "fastapi_lifespan_dependency_cache", None
|
||||||
|
)
|
||||||
|
dependency_cache = dict(lifespan_cache) if lifespan_cache else None
|
||||||
solved_result = await solve_dependencies(
|
solved_result = await solve_dependencies(
|
||||||
request=request,
|
request=request,
|
||||||
dependant=dependant,
|
dependant=dependant,
|
||||||
body=cast(dict[str, Any] | FormData | bytes | None, body),
|
body=cast(dict[str, Any] | FormData | bytes | None, body),
|
||||||
dependency_overrides_provider=dependency_overrides_provider,
|
dependency_overrides_provider=dependency_overrides_provider,
|
||||||
|
dependency_cache=dependency_cache,
|
||||||
async_exit_stack=async_exit_stack,
|
async_exit_stack=async_exit_stack,
|
||||||
embed_body_fields=embed_body_fields,
|
embed_body_fields=embed_body_fields,
|
||||||
)
|
)
|
||||||
|
|
@ -748,10 +808,15 @@ def get_websocket_app(
|
||||||
assert isinstance(async_exit_stack, AsyncExitStack), (
|
assert isinstance(async_exit_stack, AsyncExitStack), (
|
||||||
"fastapi_inner_astack not found in request scope"
|
"fastapi_inner_astack not found in request scope"
|
||||||
)
|
)
|
||||||
|
lifespan_cache = getattr(
|
||||||
|
websocket.app.state, "fastapi_lifespan_dependency_cache", None
|
||||||
|
)
|
||||||
|
dependency_cache = dict(lifespan_cache) if lifespan_cache else None
|
||||||
solved_result = await solve_dependencies(
|
solved_result = await solve_dependencies(
|
||||||
request=websocket,
|
request=websocket,
|
||||||
dependant=dependant,
|
dependant=dependant,
|
||||||
dependency_overrides_provider=dependency_overrides_provider,
|
dependency_overrides_provider=dependency_overrides_provider,
|
||||||
|
dependency_cache=dependency_cache,
|
||||||
async_exit_stack=async_exit_stack,
|
async_exit_stack=async_exit_stack,
|
||||||
embed_body_fields=embed_body_fields,
|
embed_body_fields=embed_body_fields,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,282 @@
|
||||||
|
"""Tests for lifespan-scoped dependencies (Depends(..., scope="lifespan"))."""
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import Depends, FastAPI
|
||||||
|
from fastapi.exceptions import DependencyScopeError
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
|
||||||
|
def test_lifespan_dependency_single_request() -> None:
|
||||||
|
"""Lifespan-scoped dependency is created once and reused across requests."""
|
||||||
|
started: list[str] = []
|
||||||
|
stopped: list[str] = []
|
||||||
|
|
||||||
|
def get_db() -> str:
|
||||||
|
started.append("db")
|
||||||
|
yield "db_conn"
|
||||||
|
stopped.append("db")
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
def root(db: Annotated[str, Depends(get_db, scope="lifespan")]) -> dict[str, str]:
|
||||||
|
return {"db": db}
|
||||||
|
|
||||||
|
assert len(started) == 0
|
||||||
|
assert len(stopped) == 0
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
assert len(started) == 1, "lifespan dep should start once at app startup"
|
||||||
|
r1 = client.get("/")
|
||||||
|
assert r1.status_code == 200
|
||||||
|
assert r1.json() == {"db": "db_conn"}
|
||||||
|
r2 = client.get("/")
|
||||||
|
assert r2.status_code == 200
|
||||||
|
assert r2.json() == {"db": "db_conn"}
|
||||||
|
assert len(started) == 1, "lifespan dep should not restart per request"
|
||||||
|
|
||||||
|
assert len(stopped) == 1, "lifespan dep should stop once at app shutdown"
|
||||||
|
|
||||||
|
|
||||||
|
def test_lifespan_dependency_with_custom_lifespan() -> None:
|
||||||
|
"""Lifespan-scoped dependency runs inside app lifespan and is cleaned up on shutdown."""
|
||||||
|
started: list[str] = []
|
||||||
|
stopped: list[str] = []
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
started.append("lifespan")
|
||||||
|
yield
|
||||||
|
stopped.append("lifespan")
|
||||||
|
|
||||||
|
def get_pool() -> str:
|
||||||
|
started.append("pool")
|
||||||
|
yield "pool_conn"
|
||||||
|
stopped.append("pool")
|
||||||
|
|
||||||
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
def root(
|
||||||
|
pool: Annotated[str, Depends(get_pool, scope="lifespan")],
|
||||||
|
) -> dict[str, str]:
|
||||||
|
return {"pool": pool}
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
assert "lifespan" in started
|
||||||
|
assert "pool" in started
|
||||||
|
r = client.get("/")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == {"pool": "pool_conn"}
|
||||||
|
|
||||||
|
assert "pool" in stopped
|
||||||
|
assert "lifespan" in stopped
|
||||||
|
|
||||||
|
|
||||||
|
def test_lifespan_dependency_same_instance_across_requests() -> None:
|
||||||
|
"""The same instance is injected for every request when scope is lifespan."""
|
||||||
|
instances: list[object] = []
|
||||||
|
|
||||||
|
def get_singleton() -> object:
|
||||||
|
inst = object()
|
||||||
|
instances.append(inst)
|
||||||
|
yield inst
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
def root(
|
||||||
|
s: Annotated[object, Depends(get_singleton, scope="lifespan")],
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
return {"is_singleton": len(instances) == 1 and s is instances[0]}
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
r1 = client.get("/")
|
||||||
|
r2 = client.get("/")
|
||||||
|
assert r1.status_code == 200 and r2.status_code == 200
|
||||||
|
assert r1.json()["is_singleton"] is True
|
||||||
|
assert r2.json()["is_singleton"] is True
|
||||||
|
assert len(instances) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_lifespan_dependency_decorator_level_dependencies_runs_at_startup() -> None:
|
||||||
|
"""Decorator-level dependencies=[Depends(..., scope='lifespan')] run at startup once."""
|
||||||
|
started: list[str] = []
|
||||||
|
stopped: list[str] = []
|
||||||
|
|
||||||
|
def lifespan_dep() -> str:
|
||||||
|
started.append("lifespan_dep")
|
||||||
|
yield "ok"
|
||||||
|
stopped.append("lifespan_dep")
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/", dependencies=[Depends(lifespan_dep, scope="lifespan")])
|
||||||
|
def root() -> dict[str, str]:
|
||||||
|
return {"ok": "yes"}
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
assert started == ["lifespan_dep"]
|
||||||
|
r1 = client.get("/")
|
||||||
|
r2 = client.get("/")
|
||||||
|
assert r1.status_code == 200 and r2.status_code == 200
|
||||||
|
assert r1.json() == {"ok": "yes"}
|
||||||
|
assert r2.json() == {"ok": "yes"}
|
||||||
|
assert started == ["lifespan_dep"]
|
||||||
|
|
||||||
|
assert stopped == ["lifespan_dep"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_lifespan_dependency_synthetic_request_receive_send() -> None:
|
||||||
|
"""Lifespan dep that uses Request.receive covers noop_receive during startup."""
|
||||||
|
|
||||||
|
async def lifespan_dep(request: Request) -> str:
|
||||||
|
await request.receive()
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
def root(
|
||||||
|
v: Annotated[str, Depends(lifespan_dep, scope="lifespan")],
|
||||||
|
) -> dict[str, str]:
|
||||||
|
return {"v": v}
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
r = client.get("/")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == {"v": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_lifespan_dependency_sync_callable() -> None:
|
||||||
|
"""Sync (non-gen, non-coroutine) lifespan dep runs via run_in_threadpool (utils 702)."""
|
||||||
|
|
||||||
|
def sync_lifespan_dep() -> str:
|
||||||
|
return "sync_val"
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
def root(
|
||||||
|
v: Annotated[str, Depends(sync_lifespan_dep, scope="lifespan")],
|
||||||
|
) -> dict[str, str]:
|
||||||
|
return {"v": v}
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
r = client.get("/")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == {"v": "sync_val"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_lifespan_dependency_nested() -> None:
|
||||||
|
"""Lifespan dep B depending on A covers dependency_cache hit path (utils.py line 685)."""
|
||||||
|
order: list[str] = []
|
||||||
|
|
||||||
|
def lifespan_a() -> str:
|
||||||
|
order.append("a")
|
||||||
|
yield "a"
|
||||||
|
|
||||||
|
def lifespan_b(
|
||||||
|
a: Annotated[str, Depends(lifespan_a, scope="lifespan")],
|
||||||
|
) -> str:
|
||||||
|
order.append("b")
|
||||||
|
yield a + "-b"
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
def root(
|
||||||
|
b: Annotated[str, Depends(lifespan_b, scope="lifespan")],
|
||||||
|
) -> dict[str, str]:
|
||||||
|
return {"b": b}
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
r = client.get("/")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == {"b": "a-b"}
|
||||||
|
assert order == ["a", "b"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_lifespan_dependency_shared_cache_hit() -> None:
|
||||||
|
"""Two lifespan deps B and C both depend on A; second resolution hits cache (utils 687)."""
|
||||||
|
order: list[str] = []
|
||||||
|
|
||||||
|
def lifespan_a() -> str:
|
||||||
|
order.append("a")
|
||||||
|
yield "a"
|
||||||
|
|
||||||
|
def lifespan_b(
|
||||||
|
a: Annotated[str, Depends(lifespan_a, scope="lifespan")],
|
||||||
|
) -> str:
|
||||||
|
order.append("b")
|
||||||
|
yield a + "-b"
|
||||||
|
|
||||||
|
def lifespan_c(
|
||||||
|
a: Annotated[str, Depends(lifespan_a, scope="lifespan")],
|
||||||
|
) -> str:
|
||||||
|
order.append("c")
|
||||||
|
yield a + "-c"
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
def root(
|
||||||
|
b: Annotated[str, Depends(lifespan_b, scope="lifespan")],
|
||||||
|
c: Annotated[str, Depends(lifespan_c, scope="lifespan")],
|
||||||
|
) -> dict[str, str]:
|
||||||
|
return {"b": b, "c": c}
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
r = client.get("/")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == {"b": "a-b", "c": "a-c"}
|
||||||
|
assert order == ["a", "b", "c"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_lifespan_dependency_cannot_depend_on_request_scope() -> None:
|
||||||
|
"""Lifespan-scoped dependency that depends on request-scoped dep raises."""
|
||||||
|
|
||||||
|
def request_scoped() -> int:
|
||||||
|
return 1 # pragma: no cover - never run; raises at app.get("/")(root)
|
||||||
|
|
||||||
|
def lifespan_dep(
|
||||||
|
x: Annotated[int, Depends(request_scoped, scope="request")],
|
||||||
|
) -> int:
|
||||||
|
return x # pragma: no cover - never run; raises at app.get("/")(root)
|
||||||
|
|
||||||
|
def root(
|
||||||
|
y: Annotated[int, Depends(lifespan_dep, scope="lifespan")],
|
||||||
|
) -> dict[str, int]:
|
||||||
|
return {"y": y} # pragma: no cover - never run; raises at app.get("/")(root)
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
with pytest.raises(DependencyScopeError) as exc_info:
|
||||||
|
app.get("/")(root)
|
||||||
|
assert "lifespan" in str(exc_info.value) and "cannot depend" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_lifespan_dependency_not_initialized_raises() -> None:
|
||||||
|
"""Request that needs a lifespan dep which was not run (e.g. mounted sub-app) raises."""
|
||||||
|
|
||||||
|
def lifespan_dep() -> str:
|
||||||
|
yield "conn" # pragma: no cover - never run; request raises before dep runs
|
||||||
|
|
||||||
|
sub_app = FastAPI()
|
||||||
|
|
||||||
|
@sub_app.get("/sub")
|
||||||
|
def sub_root(
|
||||||
|
x: Annotated[str, Depends(lifespan_dep, scope="lifespan")],
|
||||||
|
) -> dict[str, str]:
|
||||||
|
return {"x": x} # pragma: no cover - never run; request raises before handler
|
||||||
|
|
||||||
|
main_app = FastAPI()
|
||||||
|
main_app.mount("/mounted", sub_app)
|
||||||
|
|
||||||
|
with TestClient(main_app) as client:
|
||||||
|
with pytest.raises(DependencyScopeError) as exc_info:
|
||||||
|
client.get("/mounted/sub")
|
||||||
|
assert "lifespan" in str(exc_info.value).lower()
|
||||||
|
|
@ -318,6 +318,53 @@ def test_router_async_generator_lifespan(state: State) -> None:
|
||||||
assert state.app_shutdown is True
|
assert state.app_shutdown is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_router_apirouter_raw_async_gen_lifespan(state: State) -> None:
|
||||||
|
"""APIRouter(lifespan=raw_async_gen) normalizes via asynccontextmanager (routing 1344)."""
|
||||||
|
|
||||||
|
async def router_lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
|
state.router_startup = True
|
||||||
|
yield
|
||||||
|
state.router_shutdown = True
|
||||||
|
|
||||||
|
router = APIRouter(lifespan=router_lifespan)
|
||||||
|
|
||||||
|
@router.get("/")
|
||||||
|
def main() -> dict[str, str]:
|
||||||
|
return {"message": "ok"}
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
assert state.router_startup is True
|
||||||
|
assert client.get("/").json() == {"message": "ok"}
|
||||||
|
assert state.router_shutdown is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_router_apirouter_raw_sync_gen_lifespan(state: State) -> None:
|
||||||
|
"""APIRouter(lifespan=raw_sync_gen) normalizes via _wrap_gen_lifespan_context (routing 1346)."""
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
def router_lifespan(app: FastAPI) -> Generator[None, None, None]:
|
||||||
|
state.router_startup = True
|
||||||
|
yield
|
||||||
|
state.router_shutdown = True
|
||||||
|
|
||||||
|
router = APIRouter(lifespan=router_lifespan)
|
||||||
|
|
||||||
|
@router.get("/")
|
||||||
|
def main() -> dict[str, str]:
|
||||||
|
return {"message": "ok"}
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
assert state.router_startup is True
|
||||||
|
assert client.get("/").json() == {"message": "ok"}
|
||||||
|
assert state.router_shutdown is True
|
||||||
|
|
||||||
|
|
||||||
def test_startup_shutdown_handlers_as_parameters(state: State) -> None:
|
def test_startup_shutdown_handlers_as_parameters(state: State) -> None:
|
||||||
"""Test that startup/shutdown handlers passed as parameters to FastAPI are called correctly."""
|
"""Test that startup/shutdown handlers passed as parameters to FastAPI are called correctly."""
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue