mirror of https://github.com/tiangolo/fastapi.git
✨ Implement dependency value cache per request (#292)
* ✨ Add dependency cache, with support for disabling it * ✅ Add tests for dependency cache * 📝 Add docs about dependency value caching
This commit is contained in:
parent
09cd7c47a1
commit
bff5dbbf5d
|
|
@ -17,14 +17,12 @@ This is very useful when you need to:
|
|||
|
||||
All these, while minimizing code repetition.
|
||||
|
||||
|
||||
## First Steps
|
||||
|
||||
Let's see a very simple example. It will be so simple that it is not very useful, for now.
|
||||
|
||||
But this way we can focus on how the **Dependency Injection** system works.
|
||||
|
||||
|
||||
### Create a dependency, or "dependable"
|
||||
|
||||
Let's first focus on the dependency.
|
||||
|
|
@ -151,7 +149,6 @@ The simplicity of the dependency injection system makes **FastAPI** compatible w
|
|||
* response data injection systems
|
||||
* etc.
|
||||
|
||||
|
||||
## Simple and Powerful
|
||||
|
||||
Although the hierarchical dependency injection system is very simple to define and use, it's still very powerful.
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ You could create a first dependency ("dependable") like:
|
|||
```Python hl_lines="6 7"
|
||||
{!./src/dependencies/tutorial005.py!}
|
||||
```
|
||||
|
||||
It declares an optional query parameter `q` as a `str`, and then it just returns it.
|
||||
|
||||
This is quite simple (not very useful), but will help us focus on how the sub-dependencies work.
|
||||
|
|
@ -43,6 +44,18 @@ Then we can use the dependency with:
|
|||
|
||||
But **FastAPI** will know that it has to solve `query_extractor` first, to pass the results of that to `query_or_cookie_extractor` while calling it.
|
||||
|
||||
## Using the same dependency multiple times
|
||||
|
||||
If one of your dependencies is declared multiple times for the same *path operation*, for example, multiple dependencies have a common sub-dependency, **FastAPI** will know to call that sub-dependency only once per request.
|
||||
|
||||
And it will save the returned value in a <abbr title="A utility/system to store computed/generated values, to re-use them instead of computing them again.">"cache"</abbr> and pass it to all the "dependants" that need it in that specific request, instead of calling the dependency multiple times for the same request.
|
||||
|
||||
In an advanced scenario where you know you need the dependency to be called at every step (possibly multiple times) in the same request instead of using the "cached" value, you can set the parameter `use_cache=False` when using `Depends`:
|
||||
|
||||
```Python hl_lines="1"
|
||||
async def needy_dependency(fresh_value: str = Depends(get_value, use_cache=False)):
|
||||
return {"fresh_value": fresh_value}
|
||||
```
|
||||
|
||||
## Recap
|
||||
|
||||
|
|
@ -54,7 +67,7 @@ But still, it is very powerful, and allows you to declare arbitrarily deeply nes
|
|||
|
||||
!!! tip
|
||||
All this might not seem as useful with these simple examples.
|
||||
|
||||
|
||||
But you will see how useful it is in the chapters about **security**.
|
||||
|
||||
And you will also see the amounts of code it will save you.
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ class Dependant:
|
|||
background_tasks_param_name: str = None,
|
||||
security_scopes_param_name: str = None,
|
||||
security_scopes: List[str] = None,
|
||||
use_cache: bool = True,
|
||||
path: str = None,
|
||||
) -> None:
|
||||
self.path_params = path_params or []
|
||||
|
|
@ -46,5 +47,8 @@ class Dependant:
|
|||
self.security_scopes_param_name = security_scopes_param_name
|
||||
self.name = name
|
||||
self.call = call
|
||||
self.use_cache = use_cache
|
||||
# Store the path to be able to re-generate a dependable from it in overrides
|
||||
self.path = path
|
||||
# Save the cache key at creation to optimize performance
|
||||
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))
|
||||
|
|
|
|||
|
|
@ -95,7 +95,11 @@ def get_sub_dependant(
|
|||
security_scheme=dependency, scopes=use_scopes
|
||||
)
|
||||
sub_dependant = get_dependant(
|
||||
path=path, call=dependency, name=name, security_scopes=security_scopes
|
||||
path=path,
|
||||
call=dependency,
|
||||
name=name,
|
||||
security_scopes=security_scopes,
|
||||
use_cache=depends.use_cache,
|
||||
)
|
||||
if security_requirement:
|
||||
sub_dependant.security_requirements.append(security_requirement)
|
||||
|
|
@ -111,6 +115,7 @@ def get_flat_dependant(dependant: Dependant) -> Dependant:
|
|||
cookie_params=dependant.cookie_params.copy(),
|
||||
body_params=dependant.body_params.copy(),
|
||||
security_schemes=dependant.security_requirements.copy(),
|
||||
use_cache=dependant.use_cache,
|
||||
path=dependant.path,
|
||||
)
|
||||
for sub_dependant in dependant.dependencies:
|
||||
|
|
@ -148,12 +153,17 @@ def is_scalar_sequence_field(field: Field) -> bool:
|
|||
|
||||
|
||||
def get_dependant(
|
||||
*, path: str, call: Callable, name: str = None, security_scopes: List[str] = None
|
||||
*,
|
||||
path: str,
|
||||
call: Callable,
|
||||
name: str = None,
|
||||
security_scopes: List[str] = None,
|
||||
use_cache: bool = True,
|
||||
) -> Dependant:
|
||||
path_param_names = get_path_param_names(path)
|
||||
endpoint_signature = inspect.signature(call)
|
||||
signature_params = endpoint_signature.parameters
|
||||
dependant = Dependant(call=call, name=name, path=path)
|
||||
dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
|
||||
for param_name, param in signature_params.items():
|
||||
if isinstance(param.default, params.Depends):
|
||||
sub_dependant = get_param_sub_dependant(
|
||||
|
|
@ -286,18 +296,29 @@ async def solve_dependencies(
|
|||
body: Dict[str, Any] = None,
|
||||
background_tasks: BackgroundTasks = None,
|
||||
dependency_overrides_provider: Any = None,
|
||||
) -> Tuple[Dict[str, Any], List[ErrorWrapper], Optional[BackgroundTasks]]:
|
||||
dependency_cache: Dict[Tuple[Callable, Tuple[str]], Any] = None,
|
||||
) -> Tuple[
|
||||
Dict[str, Any],
|
||||
List[ErrorWrapper],
|
||||
Optional[BackgroundTasks],
|
||||
Dict[Tuple[Callable, Tuple[str]], Any],
|
||||
]:
|
||||
values: Dict[str, Any] = {}
|
||||
errors: List[ErrorWrapper] = []
|
||||
dependency_cache = dependency_cache or {}
|
||||
sub_dependant: Dependant
|
||||
for sub_dependant in dependant.dependencies:
|
||||
call: Callable = sub_dependant.call # type: ignore
|
||||
sub_dependant.call = cast(Callable, sub_dependant.call)
|
||||
sub_dependant.cache_key = cast(
|
||||
Tuple[Callable, Tuple[str]], sub_dependant.cache_key
|
||||
)
|
||||
call = sub_dependant.call
|
||||
use_sub_dependant = sub_dependant
|
||||
if (
|
||||
dependency_overrides_provider
|
||||
and dependency_overrides_provider.dependency_overrides
|
||||
):
|
||||
original_call: Callable = sub_dependant.call # type: ignore
|
||||
original_call = sub_dependant.call
|
||||
call = getattr(
|
||||
dependency_overrides_provider, "dependency_overrides", {}
|
||||
).get(original_call, original_call)
|
||||
|
|
@ -309,22 +330,28 @@ async def solve_dependencies(
|
|||
security_scopes=sub_dependant.security_scopes,
|
||||
)
|
||||
|
||||
sub_values, sub_errors, background_tasks = await solve_dependencies(
|
||||
sub_values, sub_errors, background_tasks, sub_dependency_cache = await solve_dependencies(
|
||||
request=request,
|
||||
dependant=use_sub_dependant,
|
||||
body=body,
|
||||
background_tasks=background_tasks,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
dependency_cache=dependency_cache,
|
||||
)
|
||||
dependency_cache.update(sub_dependency_cache)
|
||||
if sub_errors:
|
||||
errors.extend(sub_errors)
|
||||
continue
|
||||
if is_coroutine_callable(call):
|
||||
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
|
||||
solved = dependency_cache[sub_dependant.cache_key]
|
||||
elif is_coroutine_callable(call):
|
||||
solved = await call(**sub_values)
|
||||
else:
|
||||
solved = await run_in_threadpool(call, **sub_values)
|
||||
if use_sub_dependant.name is not None:
|
||||
values[use_sub_dependant.name] = solved
|
||||
if sub_dependant.name is not None:
|
||||
values[sub_dependant.name] = solved
|
||||
if sub_dependant.cache_key not in dependency_cache:
|
||||
dependency_cache[sub_dependant.cache_key] = solved
|
||||
path_values, path_errors = request_params_to_args(
|
||||
dependant.path_params, request.path_params
|
||||
)
|
||||
|
|
@ -360,7 +387,7 @@ async def solve_dependencies(
|
|||
values[dependant.security_scopes_param_name] = SecurityScopes(
|
||||
scopes=dependant.security_scopes
|
||||
)
|
||||
return values, errors, background_tasks
|
||||
return values, errors, background_tasks, dependency_cache
|
||||
|
||||
|
||||
def request_params_to_args(
|
||||
|
|
|
|||
|
|
@ -238,11 +238,13 @@ def File( # noqa: N802
|
|||
)
|
||||
|
||||
|
||||
def Depends(dependency: Callable = None) -> Any: # noqa: N802
|
||||
return params.Depends(dependency=dependency)
|
||||
def Depends( # noqa: N802
|
||||
dependency: Callable = None, *, use_cache: bool = True
|
||||
) -> Any:
|
||||
return params.Depends(dependency=dependency, use_cache=use_cache)
|
||||
|
||||
|
||||
def Security( # noqa: N802
|
||||
dependency: Callable = None, scopes: Sequence[str] = None
|
||||
dependency: Callable = None, *, scopes: Sequence[str] = None, use_cache: bool = True
|
||||
) -> Any:
|
||||
return params.Security(dependency=dependency, scopes=scopes)
|
||||
return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache)
|
||||
|
|
|
|||
|
|
@ -308,11 +308,18 @@ class File(Form):
|
|||
|
||||
|
||||
class Depends:
|
||||
def __init__(self, dependency: Callable = None):
|
||||
def __init__(self, dependency: Callable = None, *, use_cache: bool = True):
|
||||
self.dependency = dependency
|
||||
self.use_cache = use_cache
|
||||
|
||||
|
||||
class Security(Depends):
|
||||
def __init__(self, dependency: Callable = None, scopes: Sequence[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
dependency: Callable = None,
|
||||
*,
|
||||
scopes: Sequence[str] = None,
|
||||
use_cache: bool = True,
|
||||
):
|
||||
super().__init__(dependency=dependency, use_cache=use_cache)
|
||||
self.scopes = scopes or []
|
||||
super().__init__(dependency=dependency)
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ def get_app(
|
|||
raise HTTPException(
|
||||
status_code=400, detail="There was an error parsing the body"
|
||||
) from e
|
||||
values, errors, background_tasks = await solve_dependencies(
|
||||
values, errors, background_tasks, _ = await solve_dependencies(
|
||||
request=request,
|
||||
dependant=dependant,
|
||||
body=body,
|
||||
|
|
@ -141,7 +141,7 @@ def get_websocket_app(
|
|||
dependant: Dependant, dependency_overrides_provider: Any = None
|
||||
) -> Callable:
|
||||
async def app(websocket: WebSocket) -> None:
|
||||
values, errors, _ = await solve_dependencies(
|
||||
values, errors, _, _2 = await solve_dependencies(
|
||||
request=websocket,
|
||||
dependant=dependant,
|
||||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,68 @@
|
|||
from fastapi import Depends, FastAPI
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
counter_holder = {"counter": 0}
|
||||
|
||||
|
||||
async def dep_counter():
|
||||
counter_holder["counter"] += 1
|
||||
return counter_holder["counter"]
|
||||
|
||||
|
||||
async def super_dep(count: int = Depends(dep_counter)):
|
||||
return count
|
||||
|
||||
|
||||
@app.get("/counter/")
|
||||
async def get_counter(count: int = Depends(dep_counter)):
|
||||
return {"counter": count}
|
||||
|
||||
|
||||
@app.get("/sub-counter/")
|
||||
async def get_sub_counter(
|
||||
subcount: int = Depends(super_dep), count: int = Depends(dep_counter)
|
||||
):
|
||||
return {"counter": count, "subcounter": subcount}
|
||||
|
||||
|
||||
@app.get("/sub-counter-no-cache/")
|
||||
async def get_sub_counter_no_cache(
|
||||
subcount: int = Depends(super_dep),
|
||||
count: int = Depends(dep_counter, use_cache=False),
|
||||
):
|
||||
return {"counter": count, "subcounter": subcount}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_normal_counter():
|
||||
counter_holder["counter"] = 0
|
||||
response = client.get("/counter/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"counter": 1}
|
||||
response = client.get("/counter/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"counter": 2}
|
||||
|
||||
|
||||
def test_sub_counter():
|
||||
counter_holder["counter"] = 0
|
||||
response = client.get("/sub-counter/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"counter": 1, "subcounter": 1}
|
||||
response = client.get("/sub-counter/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"counter": 2, "subcounter": 2}
|
||||
|
||||
|
||||
def test_sub_counter_no_cache():
|
||||
counter_holder["counter"] = 0
|
||||
response = client.get("/sub-counter-no-cache/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"counter": 2, "subcounter": 1}
|
||||
response = client.get("/sub-counter-no-cache/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"counter": 4, "subcounter": 3}
|
||||
Loading…
Reference in New Issue