mirror of https://github.com/tiangolo/fastapi.git
✨ Add support for `lifespan` async context managers (superseding `startup` and `shutdown` events) (#2944)
Co-authored-by: Mike Shantz <mshantz@coldstorage.com> Co-authored-by: Jonathan Plasse <13716151+JonathanPlasse@users.noreply.github.com> Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
parent
66e03c816b
commit
cc9a73c3f8
|
|
@ -1,13 +1,108 @@
|
|||
# Events: startup - shutdown
|
||||
# Lifespan Events
|
||||
|
||||
You can define logic (code) that should be executed before the application **starts up**. This means that this code will be executed **once**, **before** the application **starts receiving requests**.
|
||||
|
||||
The same way, you can define logic (code) that should be executed when the application is **shutting down**. In this case, this code will be executed **once**, **after** having handled possibly **many requests**.
|
||||
|
||||
Because this code is executed before the application **starts** taking requests, and right after it **finishes** handling requests, it covers the whole application **lifespan** (the word "lifespan" will be important in a second 😉).
|
||||
|
||||
This can be very useful for setting up **resources** that you need to use for the whole app, and that are **shared** among requests, and/or that you need to **clean up** afterwards. For example, a database connection pool, or loading a shared machine learning model.
|
||||
|
||||
## Use Case
|
||||
|
||||
Let's start with an example **use case** and then see how to solve it with this.
|
||||
|
||||
Let's imagine that you have some **machine learning models** that you want to use to handle requests. 🤖
|
||||
|
||||
The same models are shared among requests, so, it's not one model per request, or one per user or something similar.
|
||||
|
||||
Let's imagine that loading the model can **take quite some time**, because it has to read a lot of **data from disk**. So you don't want to do it for every request.
|
||||
|
||||
You could load it at the top level of the module/file, but that would also mean that it would **load the model** even if you are just running a simple automated test, then that test would be **slow** because it would have to wait for the model to load before being able to run an independent part of the code.
|
||||
|
||||
That's what we'll solve, let's load the model before the requests are handled, but only right before the application starts receiving requests, not while the code is being loaded.
|
||||
|
||||
## Lifespan
|
||||
|
||||
You can define this *startup* and *shutdown* logic using the `lifespan` parameter of the `FastAPI` app, and a "context manager" (I'll show you what that is in a second).
|
||||
|
||||
Let's start with an example and then see it in detail.
|
||||
|
||||
We create an async function `lifespan()` with `yield` like this:
|
||||
|
||||
```Python hl_lines="16 19"
|
||||
{!../../../docs_src/events/tutorial003.py!}
|
||||
```
|
||||
|
||||
Here we are simulating the expensive *startup* operation of loading the model by putting the (fake) model function in the dictionary with machine learning models before the `yield`. This code will be executed **before** the application **starts taking requests**, during the *startup*.
|
||||
|
||||
And then, right after the `yield`, we unload the model. This code will be executed **after** the application **finishes handling requests**, right before the *shutdown*. This could, for example, release resources like memory or a GPU.
|
||||
|
||||
!!! tip
|
||||
The `shutdown` would happen when you are **stopping** the application.
|
||||
|
||||
Maybe you need to start a new version, or you just got tired of running it. 🤷
|
||||
|
||||
### Lifespan function
|
||||
|
||||
The first thing to notice, is that we are defining an async function with `yield`. This is very similar to Dependencies with `yield`.
|
||||
|
||||
```Python hl_lines="14-19"
|
||||
{!../../../docs_src/events/tutorial003.py!}
|
||||
```
|
||||
|
||||
The first part of the function, before the `yield`, will be executed **before** the application starts.
|
||||
|
||||
And the part after the `yield` will be executed **after** the application has finished.
|
||||
|
||||
### Async Context Manager
|
||||
|
||||
If you check, the function is decorated with an `@asynccontextmanager`.
|
||||
|
||||
That converts the function into something called an "**async context manager**".
|
||||
|
||||
```Python hl_lines="1 13"
|
||||
{!../../../docs_src/events/tutorial003.py!}
|
||||
```
|
||||
|
||||
A **context manager** in Python is something that you can use in a `with` statement, for example, `open()` can be used as a context manager:
|
||||
|
||||
```Python
|
||||
with open("file.txt") as file:
|
||||
file.read()
|
||||
```
|
||||
|
||||
In recent versions of Python, there's also an **async context manager**. You would use it with `async with`:
|
||||
|
||||
```Python
|
||||
async with lifespan(app):
|
||||
await do_stuff()
|
||||
```
|
||||
|
||||
When you create a context manager or an async context manager like above, what it does is that, before entering the `with` block, it will execute the code before the `yield`, and after exiting the `with` block, it will execute the code after the `yield`.
|
||||
|
||||
In our code example above, we don't use it directly, but we pass it to FastAPI for it to use it.
|
||||
|
||||
The `lifespan` parameter of the `FastAPI` app takes an **async context manager**, so we can pass our new `lifespan` async context manager to it.
|
||||
|
||||
```Python hl_lines="22"
|
||||
{!../../../docs_src/events/tutorial003.py!}
|
||||
```
|
||||
|
||||
## Alternative Events (deprecated)
|
||||
|
||||
!!! warning
|
||||
The recommended way to handle the *startup* and *shutdown* is using the `lifespan` parameter of the `FastAPI` app as described above.
|
||||
|
||||
You can probably skip this part.
|
||||
|
||||
There's an alternative way to define this logic to be executed during *startup* and during *shutdown*.
|
||||
|
||||
You can define event handlers (functions) that need to be executed before the application starts up, or when the application is shutting down.
|
||||
|
||||
These functions can be declared with `async def` or normal `def`.
|
||||
|
||||
!!! warning
|
||||
Only event handlers for the main application will be executed, not for [Sub Applications - Mounts](./sub-applications.md){.internal-link target=_blank}.
|
||||
|
||||
## `startup` event
|
||||
### `startup` event
|
||||
|
||||
To add a function that should be run before the application starts, declare it with the event `"startup"`:
|
||||
|
||||
|
|
@ -21,7 +116,7 @@ You can add more than one event handler function.
|
|||
|
||||
And your application won't start receiving requests until all the `startup` event handlers have completed.
|
||||
|
||||
## `shutdown` event
|
||||
### `shutdown` event
|
||||
|
||||
To add a function that should be run when the application is shutting down, declare it with the event `"shutdown"`:
|
||||
|
||||
|
|
@ -45,3 +140,21 @@ Here, the `shutdown` event handler function will write a text line `"Application
|
|||
|
||||
!!! info
|
||||
You can read more about these event handlers in <a href="https://www.starlette.io/events/" class="external-link" target="_blank">Starlette's Events' docs</a>.
|
||||
|
||||
### `startup` and `shutdown` together
|
||||
|
||||
There's a high chance that the logic for your *startup* and *shutdown* is connected, you might want to start something and then finish it, acquire a resource and then release it, etc.
|
||||
|
||||
Doing that in separated functions that don't share logic or variables together is more difficult as you would need to store values in global variables or similar tricks.
|
||||
|
||||
Because of that, it's now recommended to instead use the `lifespan` as explained above.
|
||||
|
||||
## Technical Details
|
||||
|
||||
Just a technical detail for the curious nerds. 🤓
|
||||
|
||||
Underneath, in the ASGI technical specification, this is part of the <a href="https://asgi.readthedocs.io/en/latest/specs/lifespan.html" class="external-link" target="_blank">Lifespan Protocol</a>, and it defines events called `startup` and `shutdown`.
|
||||
|
||||
## Sub Applications
|
||||
|
||||
🚨 Have in mind that these lifespan events (startup and shutdown) will only be executed for the main application, not for [Sub Applications - Mounts](./sub-applications.md){.internal-link target=_blank}.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,28 @@
|
|||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
def fake_answer_to_everything_ml_model(x: float):
|
||||
return x * 42
|
||||
|
||||
|
||||
ml_models = {}
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Load the ML model
|
||||
ml_models["answer_to_everything"] = fake_answer_to_everything_ml_model
|
||||
yield
|
||||
# Clean up the ML models and release the resources
|
||||
ml_models.clear()
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/predict")
|
||||
async def predict(x: float):
|
||||
result = ml_models["answer_to_everything"](x)
|
||||
return {"result": result}
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
|
|
@ -71,6 +72,7 @@ class FastAPI(Starlette):
|
|||
] = None,
|
||||
on_startup: Optional[Sequence[Callable[[], Any]]] = None,
|
||||
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
|
||||
lifespan: Optional[Callable[["FastAPI"], AsyncContextManager[Any]]] = None,
|
||||
terms_of_service: Optional[str] = None,
|
||||
contact: Optional[Dict[str, Union[str, Any]]] = None,
|
||||
license_info: Optional[Dict[str, Union[str, Any]]] = None,
|
||||
|
|
@ -125,6 +127,7 @@ class FastAPI(Starlette):
|
|||
dependency_overrides_provider=self,
|
||||
on_startup=on_startup,
|
||||
on_shutdown=on_shutdown,
|
||||
lifespan=lifespan,
|
||||
default_response_class=default_response_class,
|
||||
dependencies=dependencies,
|
||||
callbacks=callbacks,
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from contextlib import AsyncExitStack
|
|||
from enum import Enum, IntEnum
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
|
|
@ -492,6 +493,7 @@ class APIRouter(routing.Router):
|
|||
route_class: Type[APIRoute] = APIRoute,
|
||||
on_startup: Optional[Sequence[Callable[[], Any]]] = None,
|
||||
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
|
||||
lifespan: Optional[Callable[[Any], AsyncContextManager[Any]]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
include_in_schema: bool = True,
|
||||
generate_unique_id_function: Callable[[APIRoute], str] = Default(
|
||||
|
|
@ -504,6 +506,7 @@ class APIRouter(routing.Router):
|
|||
default=default,
|
||||
on_startup=on_startup,
|
||||
on_shutdown=on_shutdown,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
if prefix:
|
||||
assert prefix.startswith("/"), "A path prefix must start with '/'"
|
||||
|
|
|
|||
|
|
@ -1,3 +1,7 @@
|
|||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator, Dict
|
||||
|
||||
import pytest
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -12,57 +16,49 @@ class State(BaseModel):
|
|||
sub_router_shutdown: bool = False
|
||||
|
||||
|
||||
state = State()
|
||||
|
||||
app = FastAPI()
|
||||
@pytest.fixture
|
||||
def state() -> State:
|
||||
return State()
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
def app_startup():
|
||||
state.app_startup = True
|
||||
def test_router_events(state: State) -> None:
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
def app_shutdown():
|
||||
state.app_shutdown = True
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.on_event("startup")
|
||||
def router_startup():
|
||||
state.router_startup = True
|
||||
|
||||
|
||||
@router.on_event("shutdown")
|
||||
def router_shutdown():
|
||||
state.router_shutdown = True
|
||||
|
||||
|
||||
sub_router = APIRouter()
|
||||
|
||||
|
||||
@sub_router.on_event("startup")
|
||||
def sub_router_startup():
|
||||
state.sub_router_startup = True
|
||||
|
||||
|
||||
@sub_router.on_event("shutdown")
|
||||
def sub_router_shutdown():
|
||||
state.sub_router_shutdown = True
|
||||
|
||||
|
||||
@sub_router.get("/")
|
||||
def main():
|
||||
@app.get("/")
|
||||
def main() -> Dict[str, str]:
|
||||
return {"message": "Hello World"}
|
||||
|
||||
@app.on_event("startup")
|
||||
def app_startup() -> None:
|
||||
state.app_startup = True
|
||||
|
||||
router.include_router(sub_router)
|
||||
app.include_router(router)
|
||||
@app.on_event("shutdown")
|
||||
def app_shutdown() -> None:
|
||||
state.app_shutdown = True
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.on_event("startup")
|
||||
def router_startup() -> None:
|
||||
state.router_startup = True
|
||||
|
||||
@router.on_event("shutdown")
|
||||
def router_shutdown() -> None:
|
||||
state.router_shutdown = True
|
||||
|
||||
sub_router = APIRouter()
|
||||
|
||||
@sub_router.on_event("startup")
|
||||
def sub_router_startup() -> None:
|
||||
state.sub_router_startup = True
|
||||
|
||||
@sub_router.on_event("shutdown")
|
||||
def sub_router_shutdown() -> None:
|
||||
state.sub_router_shutdown = True
|
||||
|
||||
router.include_router(sub_router)
|
||||
app.include_router(router)
|
||||
|
||||
def test_router_events():
|
||||
assert state.app_startup is False
|
||||
assert state.router_startup is False
|
||||
assert state.sub_router_startup is False
|
||||
|
|
@ -85,3 +81,28 @@ def test_router_events():
|
|||
assert state.app_shutdown is True
|
||||
assert state.router_shutdown is True
|
||||
assert state.sub_router_shutdown is True
|
||||
|
||||
|
||||
def test_app_lifespan_state(state: State) -> None:
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
state.app_startup = True
|
||||
yield
|
||||
state.app_shutdown = True
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
@app.get("/")
|
||||
def main() -> Dict[str, str]:
|
||||
return {"message": "Hello World"}
|
||||
|
||||
assert state.app_startup is False
|
||||
assert state.app_shutdown is False
|
||||
with TestClient(app) as client:
|
||||
assert state.app_startup is True
|
||||
assert state.app_shutdown is False
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"message": "Hello World"}
|
||||
assert state.app_startup is True
|
||||
assert state.app_shutdown is True
|
||||
|
|
|
|||
|
|
@ -0,0 +1,86 @@
|
|||
from fastapi.testclient import TestClient
|
||||
|
||||
from docs_src.events.tutorial003 import (
|
||||
app,
|
||||
fake_answer_to_everything_ml_model,
|
||||
ml_models,
|
||||
)
|
||||
|
||||
openapi_schema = {
|
||||
"openapi": "3.0.2",
|
||||
"info": {"title": "FastAPI", "version": "0.1.0"},
|
||||
"paths": {
|
||||
"/predict": {
|
||||
"get": {
|
||||
"summary": "Predict",
|
||||
"operationId": "predict_predict_get",
|
||||
"parameters": [
|
||||
{
|
||||
"required": True,
|
||||
"schema": {"title": "X", "type": "number"},
|
||||
"name": "x",
|
||||
"in": "query",
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {"application/json": {"schema": {}}},
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
"components": {
|
||||
"schemas": {
|
||||
"HTTPValidationError": {
|
||||
"title": "HTTPValidationError",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"detail": {
|
||||
"title": "Detail",
|
||||
"type": "array",
|
||||
"items": {"$ref": "#/components/schemas/ValidationError"},
|
||||
}
|
||||
},
|
||||
},
|
||||
"ValidationError": {
|
||||
"title": "ValidationError",
|
||||
"required": ["loc", "msg", "type"],
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"loc": {
|
||||
"title": "Location",
|
||||
"type": "array",
|
||||
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
|
||||
},
|
||||
"msg": {"title": "Message", "type": "string"},
|
||||
"type": {"title": "Error Type", "type": "string"},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_events():
|
||||
assert not ml_models, "ml_models should be empty"
|
||||
with TestClient(app) as client:
|
||||
assert ml_models["answer_to_everything"] == fake_answer_to_everything_ml_model
|
||||
response = client.get("/openapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == openapi_schema
|
||||
response = client.get("/predict", params={"x": 2})
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"result": 84.0}
|
||||
assert not ml_models, "ml_models should be empty"
|
||||
Loading…
Reference in New Issue