mirror of https://github.com/tiangolo/fastapi.git
✨ Add support for `lifespan` async context managers (superseding `startup` and `shutdown` events) (#2944)
Co-authored-by: Mike Shantz <mshantz@coldstorage.com> Co-authored-by: Jonathan Plasse <13716151+JonathanPlasse@users.noreply.github.com> Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
parent
66e03c816b
commit
cc9a73c3f8
|
|
@ -1,13 +1,108 @@
|
||||||
# Events: startup - shutdown
|
# Lifespan Events
|
||||||
|
|
||||||
|
You can define logic (code) that should be executed before the application **starts up**. This means that this code will be executed **once**, **before** the application **starts receiving requests**.
|
||||||
|
|
||||||
|
The same way, you can define logic (code) that should be executed when the application is **shutting down**. In this case, this code will be executed **once**, **after** having handled possibly **many requests**.
|
||||||
|
|
||||||
|
Because this code is executed before the application **starts** taking requests, and right after it **finishes** handling requests, it covers the whole application **lifespan** (the word "lifespan" will be important in a second 😉).
|
||||||
|
|
||||||
|
This can be very useful for setting up **resources** that you need to use for the whole app, and that are **shared** among requests, and/or that you need to **clean up** afterwards. For example, a database connection pool, or loading a shared machine learning model.
|
||||||
|
|
||||||
|
## Use Case
|
||||||
|
|
||||||
|
Let's start with an example **use case** and then see how to solve it with this.
|
||||||
|
|
||||||
|
Let's imagine that you have some **machine learning models** that you want to use to handle requests. 🤖
|
||||||
|
|
||||||
|
The same models are shared among requests, so, it's not one model per request, or one per user or something similar.
|
||||||
|
|
||||||
|
Let's imagine that loading the model can **take quite some time**, because it has to read a lot of **data from disk**. So you don't want to do it for every request.
|
||||||
|
|
||||||
|
You could load it at the top level of the module/file, but that would also mean that it would **load the model** even if you are just running a simple automated test, then that test would be **slow** because it would have to wait for the model to load before being able to run an independent part of the code.
|
||||||
|
|
||||||
|
That's what we'll solve, let's load the model before the requests are handled, but only right before the application starts receiving requests, not while the code is being loaded.
|
||||||
|
|
||||||
|
## Lifespan
|
||||||
|
|
||||||
|
You can define this *startup* and *shutdown* logic using the `lifespan` parameter of the `FastAPI` app, and a "context manager" (I'll show you what that is in a second).
|
||||||
|
|
||||||
|
Let's start with an example and then see it in detail.
|
||||||
|
|
||||||
|
We create an async function `lifespan()` with `yield` like this:
|
||||||
|
|
||||||
|
```Python hl_lines="16 19"
|
||||||
|
{!../../../docs_src/events/tutorial003.py!}
|
||||||
|
```
|
||||||
|
|
||||||
|
Here we are simulating the expensive *startup* operation of loading the model by putting the (fake) model function in the dictionary with machine learning models before the `yield`. This code will be executed **before** the application **starts taking requests**, during the *startup*.
|
||||||
|
|
||||||
|
And then, right after the `yield`, we unload the model. This code will be executed **after** the application **finishes handling requests**, right before the *shutdown*. This could, for example, release resources like memory or a GPU.
|
||||||
|
|
||||||
|
!!! tip
|
||||||
|
The `shutdown` would happen when you are **stopping** the application.
|
||||||
|
|
||||||
|
Maybe you need to start a new version, or you just got tired of running it. 🤷
|
||||||
|
|
||||||
|
### Lifespan function
|
||||||
|
|
||||||
|
The first thing to notice, is that we are defining an async function with `yield`. This is very similar to Dependencies with `yield`.
|
||||||
|
|
||||||
|
```Python hl_lines="14-19"
|
||||||
|
{!../../../docs_src/events/tutorial003.py!}
|
||||||
|
```
|
||||||
|
|
||||||
|
The first part of the function, before the `yield`, will be executed **before** the application starts.
|
||||||
|
|
||||||
|
And the part after the `yield` will be executed **after** the application has finished.
|
||||||
|
|
||||||
|
### Async Context Manager
|
||||||
|
|
||||||
|
If you check, the function is decorated with an `@asynccontextmanager`.
|
||||||
|
|
||||||
|
That converts the function into something called an "**async context manager**".
|
||||||
|
|
||||||
|
```Python hl_lines="1 13"
|
||||||
|
{!../../../docs_src/events/tutorial003.py!}
|
||||||
|
```
|
||||||
|
|
||||||
|
A **context manager** in Python is something that you can use in a `with` statement, for example, `open()` can be used as a context manager:
|
||||||
|
|
||||||
|
```Python
|
||||||
|
with open("file.txt") as file:
|
||||||
|
file.read()
|
||||||
|
```
|
||||||
|
|
||||||
|
In recent versions of Python, there's also an **async context manager**. You would use it with `async with`:
|
||||||
|
|
||||||
|
```Python
|
||||||
|
async with lifespan(app):
|
||||||
|
await do_stuff()
|
||||||
|
```
|
||||||
|
|
||||||
|
When you create a context manager or an async context manager like above, what it does is that, before entering the `with` block, it will execute the code before the `yield`, and after exiting the `with` block, it will execute the code after the `yield`.
|
||||||
|
|
||||||
|
In our code example above, we don't use it directly, but we pass it to FastAPI for it to use it.
|
||||||
|
|
||||||
|
The `lifespan` parameter of the `FastAPI` app takes an **async context manager**, so we can pass our new `lifespan` async context manager to it.
|
||||||
|
|
||||||
|
```Python hl_lines="22"
|
||||||
|
{!../../../docs_src/events/tutorial003.py!}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Alternative Events (deprecated)
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
The recommended way to handle the *startup* and *shutdown* is using the `lifespan` parameter of the `FastAPI` app as described above.
|
||||||
|
|
||||||
|
You can probably skip this part.
|
||||||
|
|
||||||
|
There's an alternative way to define this logic to be executed during *startup* and during *shutdown*.
|
||||||
|
|
||||||
You can define event handlers (functions) that need to be executed before the application starts up, or when the application is shutting down.
|
You can define event handlers (functions) that need to be executed before the application starts up, or when the application is shutting down.
|
||||||
|
|
||||||
These functions can be declared with `async def` or normal `def`.
|
These functions can be declared with `async def` or normal `def`.
|
||||||
|
|
||||||
!!! warning
|
### `startup` event
|
||||||
Only event handlers for the main application will be executed, not for [Sub Applications - Mounts](./sub-applications.md){.internal-link target=_blank}.
|
|
||||||
|
|
||||||
## `startup` event
|
|
||||||
|
|
||||||
To add a function that should be run before the application starts, declare it with the event `"startup"`:
|
To add a function that should be run before the application starts, declare it with the event `"startup"`:
|
||||||
|
|
||||||
|
|
@ -21,7 +116,7 @@ You can add more than one event handler function.
|
||||||
|
|
||||||
And your application won't start receiving requests until all the `startup` event handlers have completed.
|
And your application won't start receiving requests until all the `startup` event handlers have completed.
|
||||||
|
|
||||||
## `shutdown` event
|
### `shutdown` event
|
||||||
|
|
||||||
To add a function that should be run when the application is shutting down, declare it with the event `"shutdown"`:
|
To add a function that should be run when the application is shutting down, declare it with the event `"shutdown"`:
|
||||||
|
|
||||||
|
|
@ -45,3 +140,21 @@ Here, the `shutdown` event handler function will write a text line `"Application
|
||||||
|
|
||||||
!!! info
|
!!! info
|
||||||
You can read more about these event handlers in <a href="https://www.starlette.io/events/" class="external-link" target="_blank">Starlette's Events' docs</a>.
|
You can read more about these event handlers in <a href="https://www.starlette.io/events/" class="external-link" target="_blank">Starlette's Events' docs</a>.
|
||||||
|
|
||||||
|
### `startup` and `shutdown` together
|
||||||
|
|
||||||
|
There's a high chance that the logic for your *startup* and *shutdown* is connected, you might want to start something and then finish it, acquire a resource and then release it, etc.
|
||||||
|
|
||||||
|
Doing that in separated functions that don't share logic or variables together is more difficult as you would need to store values in global variables or similar tricks.
|
||||||
|
|
||||||
|
Because of that, it's now recommended to instead use the `lifespan` as explained above.
|
||||||
|
|
||||||
|
## Technical Details
|
||||||
|
|
||||||
|
Just a technical detail for the curious nerds. 🤓
|
||||||
|
|
||||||
|
Underneath, in the ASGI technical specification, this is part of the <a href="https://asgi.readthedocs.io/en/latest/specs/lifespan.html" class="external-link" target="_blank">Lifespan Protocol</a>, and it defines events called `startup` and `shutdown`.
|
||||||
|
|
||||||
|
## Sub Applications
|
||||||
|
|
||||||
|
🚨 Have in mind that these lifespan events (startup and shutdown) will only be executed for the main application, not for [Sub Applications - Mounts](./sub-applications.md){.internal-link target=_blank}.
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,28 @@
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
|
||||||
|
def fake_answer_to_everything_ml_model(x: float):
|
||||||
|
return x * 42
|
||||||
|
|
||||||
|
|
||||||
|
ml_models = {}
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
# Load the ML model
|
||||||
|
ml_models["answer_to_everything"] = fake_answer_to_everything_ml_model
|
||||||
|
yield
|
||||||
|
# Clean up the ML models and release the resources
|
||||||
|
ml_models.clear()
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/predict")
|
||||||
|
async def predict(x: float):
|
||||||
|
result = ml_models["answer_to_everything"](x)
|
||||||
|
return {"result": result}
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
AsyncContextManager,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Coroutine,
|
Coroutine,
|
||||||
|
|
@ -71,6 +72,7 @@ class FastAPI(Starlette):
|
||||||
] = None,
|
] = None,
|
||||||
on_startup: Optional[Sequence[Callable[[], Any]]] = None,
|
on_startup: Optional[Sequence[Callable[[], Any]]] = None,
|
||||||
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
|
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
|
||||||
|
lifespan: Optional[Callable[["FastAPI"], AsyncContextManager[Any]]] = None,
|
||||||
terms_of_service: Optional[str] = None,
|
terms_of_service: Optional[str] = None,
|
||||||
contact: Optional[Dict[str, Union[str, Any]]] = None,
|
contact: Optional[Dict[str, Union[str, Any]]] = None,
|
||||||
license_info: Optional[Dict[str, Union[str, Any]]] = None,
|
license_info: Optional[Dict[str, Union[str, Any]]] = None,
|
||||||
|
|
@ -125,6 +127,7 @@ class FastAPI(Starlette):
|
||||||
dependency_overrides_provider=self,
|
dependency_overrides_provider=self,
|
||||||
on_startup=on_startup,
|
on_startup=on_startup,
|
||||||
on_shutdown=on_shutdown,
|
on_shutdown=on_shutdown,
|
||||||
|
lifespan=lifespan,
|
||||||
default_response_class=default_response_class,
|
default_response_class=default_response_class,
|
||||||
dependencies=dependencies,
|
dependencies=dependencies,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from contextlib import AsyncExitStack
|
||||||
from enum import Enum, IntEnum
|
from enum import Enum, IntEnum
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
AsyncContextManager,
|
||||||
Callable,
|
Callable,
|
||||||
Coroutine,
|
Coroutine,
|
||||||
Dict,
|
Dict,
|
||||||
|
|
@ -492,6 +493,7 @@ class APIRouter(routing.Router):
|
||||||
route_class: Type[APIRoute] = APIRoute,
|
route_class: Type[APIRoute] = APIRoute,
|
||||||
on_startup: Optional[Sequence[Callable[[], Any]]] = None,
|
on_startup: Optional[Sequence[Callable[[], Any]]] = None,
|
||||||
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
|
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
|
||||||
|
lifespan: Optional[Callable[[Any], AsyncContextManager[Any]]] = None,
|
||||||
deprecated: Optional[bool] = None,
|
deprecated: Optional[bool] = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
generate_unique_id_function: Callable[[APIRoute], str] = Default(
|
generate_unique_id_function: Callable[[APIRoute], str] = Default(
|
||||||
|
|
@ -504,6 +506,7 @@ class APIRouter(routing.Router):
|
||||||
default=default,
|
default=default,
|
||||||
on_startup=on_startup,
|
on_startup=on_startup,
|
||||||
on_shutdown=on_shutdown,
|
on_shutdown=on_shutdown,
|
||||||
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
if prefix:
|
if prefix:
|
||||||
assert prefix.startswith("/"), "A path prefix must start with '/'"
|
assert prefix.startswith("/"), "A path prefix must start with '/'"
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,7 @@
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncGenerator, Dict
|
||||||
|
|
||||||
|
import pytest
|
||||||
from fastapi import APIRouter, FastAPI
|
from fastapi import APIRouter, FastAPI
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
@ -12,57 +16,49 @@ class State(BaseModel):
|
||||||
sub_router_shutdown: bool = False
|
sub_router_shutdown: bool = False
|
||||||
|
|
||||||
|
|
||||||
state = State()
|
@pytest.fixture
|
||||||
|
def state() -> State:
|
||||||
|
return State()
|
||||||
|
|
||||||
|
|
||||||
|
def test_router_events(state: State) -> None:
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
def main() -> Dict[str, str]:
|
||||||
|
return {"message": "Hello World"}
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
def app_startup():
|
def app_startup() -> None:
|
||||||
state.app_startup = True
|
state.app_startup = True
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("shutdown")
|
@app.on_event("shutdown")
|
||||||
def app_shutdown():
|
def app_shutdown() -> None:
|
||||||
state.app_shutdown = True
|
state.app_shutdown = True
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.on_event("startup")
|
@router.on_event("startup")
|
||||||
def router_startup():
|
def router_startup() -> None:
|
||||||
state.router_startup = True
|
state.router_startup = True
|
||||||
|
|
||||||
|
|
||||||
@router.on_event("shutdown")
|
@router.on_event("shutdown")
|
||||||
def router_shutdown():
|
def router_shutdown() -> None:
|
||||||
state.router_shutdown = True
|
state.router_shutdown = True
|
||||||
|
|
||||||
|
|
||||||
sub_router = APIRouter()
|
sub_router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@sub_router.on_event("startup")
|
@sub_router.on_event("startup")
|
||||||
def sub_router_startup():
|
def sub_router_startup() -> None:
|
||||||
state.sub_router_startup = True
|
state.sub_router_startup = True
|
||||||
|
|
||||||
|
|
||||||
@sub_router.on_event("shutdown")
|
@sub_router.on_event("shutdown")
|
||||||
def sub_router_shutdown():
|
def sub_router_shutdown() -> None:
|
||||||
state.sub_router_shutdown = True
|
state.sub_router_shutdown = True
|
||||||
|
|
||||||
|
|
||||||
@sub_router.get("/")
|
|
||||||
def main():
|
|
||||||
return {"message": "Hello World"}
|
|
||||||
|
|
||||||
|
|
||||||
router.include_router(sub_router)
|
router.include_router(sub_router)
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
|
|
||||||
|
|
||||||
def test_router_events():
|
|
||||||
assert state.app_startup is False
|
assert state.app_startup is False
|
||||||
assert state.router_startup is False
|
assert state.router_startup is False
|
||||||
assert state.sub_router_startup is False
|
assert state.sub_router_startup is False
|
||||||
|
|
@ -85,3 +81,28 @@ def test_router_events():
|
||||||
assert state.app_shutdown is True
|
assert state.app_shutdown is True
|
||||||
assert state.router_shutdown is True
|
assert state.router_shutdown is True
|
||||||
assert state.sub_router_shutdown is True
|
assert state.sub_router_shutdown is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_app_lifespan_state(state: State) -> None:
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
|
state.app_startup = True
|
||||||
|
yield
|
||||||
|
state.app_shutdown = True
|
||||||
|
|
||||||
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
def main() -> Dict[str, str]:
|
||||||
|
return {"message": "Hello World"}
|
||||||
|
|
||||||
|
assert state.app_startup is False
|
||||||
|
assert state.app_shutdown is False
|
||||||
|
with TestClient(app) as client:
|
||||||
|
assert state.app_startup is True
|
||||||
|
assert state.app_shutdown is False
|
||||||
|
response = client.get("/")
|
||||||
|
assert response.status_code == 200, response.text
|
||||||
|
assert response.json() == {"message": "Hello World"}
|
||||||
|
assert state.app_startup is True
|
||||||
|
assert state.app_shutdown is True
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,86 @@
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from docs_src.events.tutorial003 import (
|
||||||
|
app,
|
||||||
|
fake_answer_to_everything_ml_model,
|
||||||
|
ml_models,
|
||||||
|
)
|
||||||
|
|
||||||
|
openapi_schema = {
|
||||||
|
"openapi": "3.0.2",
|
||||||
|
"info": {"title": "FastAPI", "version": "0.1.0"},
|
||||||
|
"paths": {
|
||||||
|
"/predict": {
|
||||||
|
"get": {
|
||||||
|
"summary": "Predict",
|
||||||
|
"operationId": "predict_predict_get",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"required": True,
|
||||||
|
"schema": {"title": "X", "type": "number"},
|
||||||
|
"name": "x",
|
||||||
|
"in": "query",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Successful Response",
|
||||||
|
"content": {"application/json": {"schema": {}}},
|
||||||
|
},
|
||||||
|
"422": {
|
||||||
|
"description": "Validation Error",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/HTTPValidationError"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"components": {
|
||||||
|
"schemas": {
|
||||||
|
"HTTPValidationError": {
|
||||||
|
"title": "HTTPValidationError",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"detail": {
|
||||||
|
"title": "Detail",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"$ref": "#/components/schemas/ValidationError"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"ValidationError": {
|
||||||
|
"title": "ValidationError",
|
||||||
|
"required": ["loc", "msg", "type"],
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"loc": {
|
||||||
|
"title": "Location",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
|
||||||
|
},
|
||||||
|
"msg": {"title": "Message", "type": "string"},
|
||||||
|
"type": {"title": "Error Type", "type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_events():
|
||||||
|
assert not ml_models, "ml_models should be empty"
|
||||||
|
with TestClient(app) as client:
|
||||||
|
assert ml_models["answer_to_everything"] == fake_answer_to_everything_ml_model
|
||||||
|
response = client.get("/openapi.json")
|
||||||
|
assert response.status_code == 200, response.text
|
||||||
|
assert response.json() == openapi_schema
|
||||||
|
response = client.get("/predict", params={"x": 2})
|
||||||
|
assert response.status_code == 200, response.text
|
||||||
|
assert response.json() == {"result": 84.0}
|
||||||
|
assert not ml_models, "ml_models should be empty"
|
||||||
Loading…
Reference in New Issue