mirror of https://github.com/tiangolo/fastapi.git
✨ Add support for dependencies with scopes, support `scope="request"` for dependencies with `yield` that exit before the response is sent (#14262)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
425a4c5bb1
commit
ac438b9934
|
|
@ -184,6 +184,51 @@ If you raise any exception in the code from the *path operation function*, it wi
|
|||
|
||||
///
|
||||
|
||||
## Early exit and `scope` { #early-exit-and-scope }
|
||||
|
||||
Normally the exit code of dependencies with `yield` is executed **after the response** is sent to the client.
|
||||
|
||||
But if you know that you won't need to use the dependency after returning from the *path operation function*, you can use `Depends(scope="function")` to tell FastAPI that it should close the dependency after the *path operation function* returns, but **before** the **response is sent**.
|
||||
|
||||
{* ../../docs_src/dependencies/tutorial008e_an_py39.py hl[12,16] *}
|
||||
|
||||
`Depends()` receives a `scope` parameter that can be:
|
||||
|
||||
* `"function"`: start the dependency before the *path operation function* that handles the request, end the dependency after the *path operation function* ends, but **before** the response is sent back to the client. So, the dependency function will be executed **around** the *path operation **function***.
|
||||
* `"request"`: start the dependency before the *path operation function* that handles the request (similar to when using `"function"`), but end **after** the response is sent back to the client. So, the dependency function will be executed **around** the **request** and response cycle.
|
||||
|
||||
If not specified and the dependency has `yield`, it will have a `scope` of `"request"` by default.
|
||||
|
||||
### `scope` for sub-dependencies { #scope-for-sub-dependencies }
|
||||
|
||||
When you declare a dependency with a `scope="request"` (the default), any sub-dependency needs to also have a `scope` of `"request"`.
|
||||
|
||||
But a dependency with `scope` of `"function"` can have dependencies with `scope` of `"function"` and `scope` of `"request"`.
|
||||
|
||||
This is because any dependency needs to be able to run its exit code before the sub-dependencies, as it might need to still use them during its exit code.
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
|
||||
participant client as Client
|
||||
participant dep_req as Dep scope="request"
|
||||
participant dep_func as Dep scope="function"
|
||||
participant operation as Path Operation
|
||||
|
||||
client ->> dep_req: Start request
|
||||
Note over dep_req: Run code up to yield
|
||||
dep_req ->> dep_func: Pass dependency
|
||||
Note over dep_func: Run code up to yield
|
||||
dep_func ->> operation: Run path operation with dependency
|
||||
operation ->> dep_func: Return from path operation
|
||||
Note over dep_func: Run code after yield
|
||||
Note over dep_func: ✅ Dependency closed
|
||||
dep_func ->> client: Send response to client
|
||||
Note over client: Response sent
|
||||
Note over dep_req: Run code after yield
|
||||
Note over dep_req: ✅ Dependency closed
|
||||
```
|
||||
|
||||
## Dependencies with `yield`, `HTTPException`, `except` and Background Tasks { #dependencies-with-yield-httpexception-except-and-background-tasks }
|
||||
|
||||
Dependencies with `yield` have evolved over time to cover different use cases and fix some issues.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
from fastapi import Depends, FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def get_username():
|
||||
try:
|
||||
yield "Rick"
|
||||
finally:
|
||||
print("Cleanup up before response is sent")
|
||||
|
||||
|
||||
@app.get("/users/me")
|
||||
def get_user_me(username: str = Depends(get_username, scope="function")):
|
||||
return username
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
from fastapi import Depends, FastAPI
|
||||
from typing_extensions import Annotated
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def get_username():
|
||||
try:
|
||||
yield "Rick"
|
||||
finally:
|
||||
print("Cleanup up before response is sent")
|
||||
|
||||
|
||||
@app.get("/users/me")
|
||||
def get_user_me(username: Annotated[str, Depends(get_username, scope="function")]):
|
||||
return username
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def get_username():
|
||||
try:
|
||||
yield "Rick"
|
||||
finally:
|
||||
print("Cleanup up before response is sent")
|
||||
|
||||
|
||||
@app.get("/users/me")
|
||||
def get_user_me(username: Annotated[str, Depends(get_username, scope="function")]):
|
||||
return username
|
||||
|
|
@ -1,8 +1,18 @@
|
|||
import inspect
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, List, Optional, Sequence, Tuple
|
||||
from functools import cached_property
|
||||
from typing import Any, Callable, List, Optional, Sequence, Union
|
||||
|
||||
from fastapi._compat import ModelField
|
||||
from fastapi.security.base import SecurityBase
|
||||
from fastapi.types import DependencyCacheKey
|
||||
from typing_extensions import Literal
|
||||
|
||||
if sys.version_info >= (3, 13): # pragma: no cover
|
||||
from inspect import iscoroutinefunction
|
||||
else: # pragma: no cover
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -31,7 +41,43 @@ class Dependant:
|
|||
security_scopes: Optional[List[str]] = None
|
||||
use_cache: bool = True
|
||||
path: Optional[str] = None
|
||||
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
|
||||
scope: Union[Literal["function", "request"], None] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))
|
||||
@cached_property
|
||||
def cache_key(self) -> DependencyCacheKey:
|
||||
return (
|
||||
self.call,
|
||||
tuple(sorted(set(self.security_scopes or []))),
|
||||
self.computed_scope or "",
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def is_gen_callable(self) -> bool:
|
||||
if inspect.isgeneratorfunction(self.call):
|
||||
return True
|
||||
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
|
||||
return inspect.isgeneratorfunction(dunder_call)
|
||||
|
||||
@cached_property
|
||||
def is_async_gen_callable(self) -> bool:
|
||||
if inspect.isasyncgenfunction(self.call):
|
||||
return True
|
||||
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
|
||||
return inspect.isasyncgenfunction(dunder_call)
|
||||
|
||||
@cached_property
|
||||
def is_coroutine_callable(self) -> bool:
|
||||
if inspect.isroutine(self.call):
|
||||
return iscoroutinefunction(self.call)
|
||||
if inspect.isclass(self.call):
|
||||
return False
|
||||
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
|
||||
return iscoroutinefunction(dunder_call)
|
||||
|
||||
@cached_property
|
||||
def computed_scope(self) -> Union[str, None]:
|
||||
if self.scope:
|
||||
return self.scope
|
||||
if self.is_gen_callable or self.is_async_gen_callable:
|
||||
return "request"
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import inspect
|
||||
import sys
|
||||
from contextlib import AsyncExitStack, contextmanager
|
||||
from copy import copy, deepcopy
|
||||
from dataclasses import dataclass
|
||||
|
|
@ -55,10 +54,12 @@ from fastapi.concurrency import (
|
|||
contextmanager_in_threadpool,
|
||||
)
|
||||
from fastapi.dependencies.models import Dependant, SecurityRequirement
|
||||
from fastapi.exceptions import DependencyScopeError
|
||||
from fastapi.logger import logger
|
||||
from fastapi.security.base import SecurityBase
|
||||
from fastapi.security.oauth2 import OAuth2, SecurityScopes
|
||||
from fastapi.security.open_id_connect_url import OpenIdConnect
|
||||
from fastapi.types import DependencyCacheKey
|
||||
from fastapi.utils import create_model_field, get_path_param_names
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
|
|
@ -74,15 +75,10 @@ from starlette.datastructures import (
|
|||
from starlette.requests import HTTPConnection, Request
|
||||
from starlette.responses import Response
|
||||
from starlette.websockets import WebSocket
|
||||
from typing_extensions import Annotated, get_args, get_origin
|
||||
from typing_extensions import Annotated, Literal, get_args, get_origin
|
||||
|
||||
from .. import temp_pydantic_v1_params
|
||||
|
||||
if sys.version_info >= (3, 13): # pragma: no cover
|
||||
from inspect import iscoroutinefunction
|
||||
else: # pragma: no cover
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
multipart_not_installed_error = (
|
||||
'Form data requires "python-multipart" to be installed. \n'
|
||||
'You can install "python-multipart" with: \n\n'
|
||||
|
|
@ -137,14 +133,11 @@ def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> De
|
|||
)
|
||||
|
||||
|
||||
CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
|
||||
|
||||
|
||||
def get_flat_dependant(
|
||||
dependant: Dependant,
|
||||
*,
|
||||
skip_repeats: bool = False,
|
||||
visited: Optional[List[CacheKey]] = None,
|
||||
visited: Optional[List[DependencyCacheKey]] = None,
|
||||
) -> Dependant:
|
||||
if visited is None:
|
||||
visited = []
|
||||
|
|
@ -237,6 +230,7 @@ def get_dependant(
|
|||
name: Optional[str] = None,
|
||||
security_scopes: Optional[List[str]] = None,
|
||||
use_cache: bool = True,
|
||||
scope: Union[Literal["function", "request"], None] = None,
|
||||
) -> Dependant:
|
||||
dependant = Dependant(
|
||||
call=call,
|
||||
|
|
@ -244,6 +238,7 @@ def get_dependant(
|
|||
path=path,
|
||||
security_scopes=security_scopes,
|
||||
use_cache=use_cache,
|
||||
scope=scope,
|
||||
)
|
||||
path_param_names = get_path_param_names(path)
|
||||
endpoint_signature = get_typed_signature(call)
|
||||
|
|
@ -251,7 +246,7 @@ def get_dependant(
|
|||
if isinstance(call, SecurityBase):
|
||||
use_scopes: List[str] = []
|
||||
if isinstance(call, (OAuth2, OpenIdConnect)):
|
||||
use_scopes = security_scopes
|
||||
use_scopes = security_scopes or use_scopes
|
||||
security_requirement = SecurityRequirement(
|
||||
security_scheme=call, scopes=use_scopes
|
||||
)
|
||||
|
|
@ -266,6 +261,16 @@ def get_dependant(
|
|||
)
|
||||
if param_details.depends is not None:
|
||||
assert param_details.depends.dependency
|
||||
if (
|
||||
(dependant.is_gen_callable or dependant.is_async_gen_callable)
|
||||
and dependant.computed_scope == "request"
|
||||
and param_details.depends.scope == "function"
|
||||
):
|
||||
assert dependant.call
|
||||
raise DependencyScopeError(
|
||||
f'The dependency "{dependant.call.__name__}" has a scope of '
|
||||
'"request", it cannot depend on dependencies with scope "function".'
|
||||
)
|
||||
use_security_scopes = security_scopes or []
|
||||
if isinstance(param_details.depends, params.Security):
|
||||
if param_details.depends.scopes:
|
||||
|
|
@ -276,6 +281,7 @@ def get_dependant(
|
|||
name=param_name,
|
||||
security_scopes=use_security_scopes,
|
||||
use_cache=param_details.depends.use_cache,
|
||||
scope=param_details.depends.scope,
|
||||
)
|
||||
dependant.dependencies.append(sub_dependant)
|
||||
continue
|
||||
|
|
@ -532,36 +538,14 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
|
|||
dependant.cookie_params.append(field)
|
||||
|
||||
|
||||
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
|
||||
if inspect.isroutine(call):
|
||||
return iscoroutinefunction(call)
|
||||
if inspect.isclass(call):
|
||||
return False
|
||||
dunder_call = getattr(call, "__call__", None) # noqa: B004
|
||||
return iscoroutinefunction(dunder_call)
|
||||
|
||||
|
||||
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
|
||||
if inspect.isasyncgenfunction(call):
|
||||
return True
|
||||
dunder_call = getattr(call, "__call__", None) # noqa: B004
|
||||
return inspect.isasyncgenfunction(dunder_call)
|
||||
|
||||
|
||||
def is_gen_callable(call: Callable[..., Any]) -> bool:
|
||||
if inspect.isgeneratorfunction(call):
|
||||
return True
|
||||
dunder_call = getattr(call, "__call__", None) # noqa: B004
|
||||
return inspect.isgeneratorfunction(dunder_call)
|
||||
|
||||
|
||||
async def solve_generator(
|
||||
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
|
||||
async def _solve_generator(
|
||||
*, dependant: Dependant, stack: AsyncExitStack, sub_values: Dict[str, Any]
|
||||
) -> Any:
|
||||
if is_gen_callable(call):
|
||||
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
|
||||
elif is_async_gen_callable(call):
|
||||
cm = asynccontextmanager(call)(**sub_values)
|
||||
assert dependant.call
|
||||
if dependant.is_gen_callable:
|
||||
cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values))
|
||||
elif dependant.is_async_gen_callable:
|
||||
cm = asynccontextmanager(dependant.call)(**sub_values)
|
||||
return await stack.enter_async_context(cm)
|
||||
|
||||
|
||||
|
|
@ -571,7 +555,7 @@ class SolvedDependency:
|
|||
errors: List[Any]
|
||||
background_tasks: Optional[StarletteBackgroundTasks]
|
||||
response: Response
|
||||
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
|
||||
dependency_cache: Dict[DependencyCacheKey, Any]
|
||||
|
||||
|
||||
async def solve_dependencies(
|
||||
|
|
@ -582,10 +566,20 @@ async def solve_dependencies(
|
|||
background_tasks: Optional[StarletteBackgroundTasks] = None,
|
||||
response: Optional[Response] = None,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
|
||||
dependency_cache: Optional[Dict[DependencyCacheKey, Any]] = None,
|
||||
# TODO: remove this parameter later, no longer used, not removing it yet as some
|
||||
# people might be monkey patching this function (although that's not supported)
|
||||
async_exit_stack: AsyncExitStack,
|
||||
embed_body_fields: bool,
|
||||
) -> SolvedDependency:
|
||||
request_astack = request.scope.get("fastapi_inner_astack")
|
||||
assert isinstance(request_astack, AsyncExitStack), (
|
||||
"fastapi_inner_astack not found in request scope"
|
||||
)
|
||||
function_astack = request.scope.get("fastapi_function_astack")
|
||||
assert isinstance(function_astack, AsyncExitStack), (
|
||||
"fastapi_function_astack not found in request scope"
|
||||
)
|
||||
values: Dict[str, Any] = {}
|
||||
errors: List[Any] = []
|
||||
if response is None:
|
||||
|
|
@ -594,12 +588,8 @@ async def solve_dependencies(
|
|||
response.status_code = None # type: ignore
|
||||
if dependency_cache is None:
|
||||
dependency_cache = {}
|
||||
sub_dependant: Dependant
|
||||
for sub_dependant in dependant.dependencies:
|
||||
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
|
||||
sub_dependant.cache_key = cast(
|
||||
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
|
||||
)
|
||||
call = sub_dependant.call
|
||||
use_sub_dependant = sub_dependant
|
||||
if (
|
||||
|
|
@ -616,6 +606,7 @@ async def solve_dependencies(
|
|||
call=call,
|
||||
name=sub_dependant.name,
|
||||
security_scopes=sub_dependant.security_scopes,
|
||||
scope=sub_dependant.scope,
|
||||
)
|
||||
|
||||
solved_result = await solve_dependencies(
|
||||
|
|
@ -635,11 +626,18 @@ async def solve_dependencies(
|
|||
continue
|
||||
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
|
||||
solved = dependency_cache[sub_dependant.cache_key]
|
||||
elif is_gen_callable(call) or is_async_gen_callable(call):
|
||||
solved = await solve_generator(
|
||||
call=call, stack=async_exit_stack, sub_values=solved_result.values
|
||||
elif (
|
||||
use_sub_dependant.is_gen_callable or use_sub_dependant.is_async_gen_callable
|
||||
):
|
||||
use_astack = request_astack
|
||||
if sub_dependant.scope == "function":
|
||||
use_astack = function_astack
|
||||
solved = await _solve_generator(
|
||||
dependant=use_sub_dependant,
|
||||
stack=use_astack,
|
||||
sub_values=solved_result.values,
|
||||
)
|
||||
elif is_coroutine_callable(call):
|
||||
elif use_sub_dependant.is_coroutine_callable:
|
||||
solved = await call(**solved_result.values)
|
||||
else:
|
||||
solved = await run_in_threadpool(call, **solved_result.values)
|
||||
|
|
|
|||
|
|
@ -147,6 +147,13 @@ class FastAPIError(RuntimeError):
|
|||
"""
|
||||
|
||||
|
||||
class DependencyScopeError(FastAPIError):
|
||||
"""
|
||||
A dependency declared that it depends on another dependency with an invalid
|
||||
(narrower) scope.
|
||||
"""
|
||||
|
||||
|
||||
class ValidationException(Exception):
|
||||
def __init__(self, errors: Sequence[Any]) -> None:
|
||||
self._errors = errors
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from annotated_doc import Doc
|
|||
from fastapi import params
|
||||
from fastapi._compat import Undefined
|
||||
from fastapi.openapi.models import Example
|
||||
from typing_extensions import Annotated, deprecated
|
||||
from typing_extensions import Annotated, Literal, deprecated
|
||||
|
||||
_Unset: Any = Undefined
|
||||
|
||||
|
|
@ -2245,6 +2245,26 @@ def Depends( # noqa: N802
|
|||
"""
|
||||
),
|
||||
] = True,
|
||||
scope: Annotated[
|
||||
Union[Literal["function", "request"], None],
|
||||
Doc(
|
||||
"""
|
||||
Mainly for dependencies with `yield`, define when the dependency function
|
||||
should start (the code before `yield`) and when it should end (the code
|
||||
after `yield`).
|
||||
|
||||
* `"function"`: start the dependency before the *path operation function*
|
||||
that handles the request, end the dependency after the *path operation
|
||||
function* ends, but **before** the response is sent back to the client.
|
||||
So, the dependency function will be executed **around** the *path operation
|
||||
**function***.
|
||||
* `"request"`: start the dependency before the *path operation function*
|
||||
that handles the request (similar to when using `"function"`), but end
|
||||
**after** the response is sent back to the client. So, the dependency
|
||||
function will be executed **around** the **request** and response cycle.
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Declare a FastAPI dependency.
|
||||
|
|
@ -2275,7 +2295,7 @@ def Depends( # noqa: N802
|
|||
return commons
|
||||
```
|
||||
"""
|
||||
return params.Depends(dependency=dependency, use_cache=use_cache)
|
||||
return params.Depends(dependency=dependency, use_cache=use_cache, scope=scope)
|
||||
|
||||
|
||||
def Security( # noqa: N802
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
|||
|
||||
from fastapi.openapi.models import Example
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import Annotated, deprecated
|
||||
from typing_extensions import Annotated, Literal, deprecated
|
||||
|
||||
from ._compat import (
|
||||
PYDANTIC_V2,
|
||||
|
|
@ -766,6 +766,7 @@ class File(Form): # type: ignore[misc]
|
|||
class Depends:
|
||||
dependency: Optional[Callable[..., Any]] = None
|
||||
use_cache: bool = True
|
||||
scope: Union[Literal["function", "request"], None] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -104,9 +104,10 @@ def request_response(
|
|||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
# Starts customization
|
||||
response_awaited = False
|
||||
async with AsyncExitStack() as stack:
|
||||
scope["fastapi_inner_astack"] = stack
|
||||
# Same as in Starlette
|
||||
async with AsyncExitStack() as request_stack:
|
||||
scope["fastapi_inner_astack"] = request_stack
|
||||
async with AsyncExitStack() as function_stack:
|
||||
scope["fastapi_function_astack"] = function_stack
|
||||
response = await f(request)
|
||||
await response(scope, receive, send)
|
||||
# Continues customization
|
||||
|
|
@ -140,10 +141,10 @@ def websocket_session(
|
|||
session = WebSocket(scope, receive=receive, send=send)
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
# Starts customization
|
||||
async with AsyncExitStack() as stack:
|
||||
scope["fastapi_inner_astack"] = stack
|
||||
# Same as in Starlette
|
||||
async with AsyncExitStack() as request_stack:
|
||||
scope["fastapi_inner_astack"] = request_stack
|
||||
async with AsyncExitStack() as function_stack:
|
||||
scope["fastapi_function_astack"] = function_stack
|
||||
await func(session)
|
||||
|
||||
# Same as in Starlette
|
||||
|
|
@ -479,7 +480,9 @@ class APIWebSocketRoute(routing.WebSocketRoute):
|
|||
self.name = get_name(endpoint) if name is None else name
|
||||
self.dependencies = list(dependencies or [])
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
|
||||
self.dependant = get_dependant(
|
||||
path=self.path_format, call=self.endpoint, scope="function"
|
||||
)
|
||||
for depends in self.dependencies[::-1]:
|
||||
self.dependant.dependencies.insert(
|
||||
0,
|
||||
|
|
@ -630,7 +633,9 @@ class APIRoute(routing.Route):
|
|||
self.response_fields = {}
|
||||
|
||||
assert callable(endpoint), "An endpoint must be a callable"
|
||||
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
|
||||
self.dependant = get_dependant(
|
||||
path=self.path_format, call=self.endpoint, scope="function"
|
||||
)
|
||||
for depends in self.dependencies[::-1]:
|
||||
self.dependant.dependencies.insert(
|
||||
0,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import types
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, Set, Type, TypeVar, Union
|
||||
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -8,3 +8,4 @@ DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any])
|
|||
UnionType = getattr(types, "UnionType", Union)
|
||||
ModelNameMap = Dict[Union[Type[BaseModel], Type[Enum]], str]
|
||||
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
|
||||
DependencyCacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...], str]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,184 @@
|
|||
import json
|
||||
from typing import Any, Tuple
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.exceptions import FastAPIError
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.testclient import TestClient
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class Session:
|
||||
def __init__(self) -> None:
|
||||
self.open = True
|
||||
|
||||
|
||||
def dep_session() -> Any:
|
||||
s = Session()
|
||||
yield s
|
||||
s.open = False
|
||||
|
||||
|
||||
SessionFuncDep = Annotated[Session, Depends(dep_session, scope="function")]
|
||||
SessionRequestDep = Annotated[Session, Depends(dep_session, scope="request")]
|
||||
SessionDefaultDep = Annotated[Session, Depends(dep_session)]
|
||||
|
||||
|
||||
class NamedSession:
|
||||
def __init__(self, name: str = "default") -> None:
|
||||
self.name = name
|
||||
self.open = True
|
||||
|
||||
|
||||
def get_named_session(session: SessionRequestDep, session_b: SessionDefaultDep) -> Any:
|
||||
assert session is session_b
|
||||
named_session = NamedSession(name="named")
|
||||
yield named_session, session_b
|
||||
named_session.open = False
|
||||
|
||||
|
||||
NamedSessionsDep = Annotated[Tuple[NamedSession, Session], Depends(get_named_session)]
|
||||
|
||||
|
||||
def get_named_func_session(session: SessionFuncDep) -> Any:
|
||||
named_session = NamedSession(name="named")
|
||||
yield named_session, session
|
||||
named_session.open = False
|
||||
|
||||
|
||||
def get_named_regular_func_session(session: SessionFuncDep) -> Any:
|
||||
named_session = NamedSession(name="named")
|
||||
return named_session, session
|
||||
|
||||
|
||||
BrokenSessionsDep = Annotated[
|
||||
Tuple[NamedSession, Session], Depends(get_named_func_session)
|
||||
]
|
||||
NamedSessionsFuncDep = Annotated[
|
||||
Tuple[NamedSession, Session], Depends(get_named_func_session, scope="function")
|
||||
]
|
||||
|
||||
RegularSessionsDep = Annotated[
|
||||
Tuple[NamedSession, Session], Depends(get_named_regular_func_session)
|
||||
]
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.get("/function-scope")
|
||||
def function_scope(session: SessionFuncDep) -> Any:
|
||||
def iter_data():
|
||||
yield json.dumps({"is_open": session.open})
|
||||
|
||||
return StreamingResponse(iter_data())
|
||||
|
||||
|
||||
@app.get("/request-scope")
|
||||
def request_scope(session: SessionRequestDep) -> Any:
|
||||
def iter_data():
|
||||
yield json.dumps({"is_open": session.open})
|
||||
|
||||
return StreamingResponse(iter_data())
|
||||
|
||||
|
||||
@app.get("/two-scopes")
|
||||
def get_stream_session(
|
||||
function_session: SessionFuncDep, request_session: SessionRequestDep
|
||||
) -> Any:
|
||||
def iter_data():
|
||||
yield json.dumps(
|
||||
{"func_is_open": function_session.open, "req_is_open": request_session.open}
|
||||
)
|
||||
|
||||
return StreamingResponse(iter_data())
|
||||
|
||||
|
||||
@app.get("/sub")
|
||||
def get_sub(sessions: NamedSessionsDep) -> Any:
|
||||
def iter_data():
|
||||
yield json.dumps(
|
||||
{"named_session_open": sessions[0].open, "session_open": sessions[1].open}
|
||||
)
|
||||
|
||||
return StreamingResponse(iter_data())
|
||||
|
||||
|
||||
@app.get("/named-function-scope")
|
||||
def get_named_function_scope(sessions: NamedSessionsFuncDep) -> Any:
|
||||
def iter_data():
|
||||
yield json.dumps(
|
||||
{"named_session_open": sessions[0].open, "session_open": sessions[1].open}
|
||||
)
|
||||
|
||||
return StreamingResponse(iter_data())
|
||||
|
||||
|
||||
@app.get("/regular-function-scope")
|
||||
def get_regular_function_scope(sessions: RegularSessionsDep) -> Any:
|
||||
def iter_data():
|
||||
yield json.dumps(
|
||||
{"named_session_open": sessions[0].open, "session_open": sessions[1].open}
|
||||
)
|
||||
|
||||
return StreamingResponse(iter_data())
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_function_scope() -> None:
|
||||
response = client.get("/function-scope")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["is_open"] is False
|
||||
|
||||
|
||||
def test_request_scope() -> None:
|
||||
response = client.get("/request-scope")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["is_open"] is True
|
||||
|
||||
|
||||
def test_two_scopes() -> None:
|
||||
response = client.get("/two-scopes")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["func_is_open"] is False
|
||||
assert data["req_is_open"] is True
|
||||
|
||||
|
||||
def test_sub() -> None:
|
||||
response = client.get("/sub")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["named_session_open"] is True
|
||||
assert data["session_open"] is True
|
||||
|
||||
|
||||
def test_broken_scope() -> None:
|
||||
with pytest.raises(
|
||||
FastAPIError,
|
||||
match='The dependency "get_named_func_session" has a scope of "request", it cannot depend on dependencies with scope "function"',
|
||||
):
|
||||
|
||||
@app.get("/broken-scope")
|
||||
def get_broken(sessions: BrokenSessionsDep) -> Any: # pragma: no cover
|
||||
pass
|
||||
|
||||
|
||||
def test_named_function_scope() -> None:
|
||||
response = client.get("/named-function-scope")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["named_session_open"] is False
|
||||
assert data["session_open"] is False
|
||||
|
||||
|
||||
def test_regular_function_scope() -> None:
|
||||
response = client.get("/regular-function-scope")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["named_session_open"] is True
|
||||
assert data["session_open"] is False
|
||||
|
|
@ -0,0 +1,201 @@
|
|||
from contextvars import ContextVar
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI, WebSocket
|
||||
from fastapi.exceptions import FastAPIError
|
||||
from fastapi.testclient import TestClient
|
||||
from typing_extensions import Annotated
|
||||
|
||||
global_context: ContextVar[Dict[str, Any]] = ContextVar("global_context", default={}) # noqa: B039
|
||||
|
||||
|
||||
class Session:
|
||||
def __init__(self) -> None:
|
||||
self.open = True
|
||||
|
||||
|
||||
async def dep_session() -> Any:
|
||||
s = Session()
|
||||
yield s
|
||||
s.open = False
|
||||
global_state = global_context.get()
|
||||
global_state["session_closed"] = True
|
||||
|
||||
|
||||
SessionFuncDep = Annotated[Session, Depends(dep_session, scope="function")]
|
||||
SessionRequestDep = Annotated[Session, Depends(dep_session, scope="request")]
|
||||
SessionDefaultDep = Annotated[Session, Depends(dep_session)]
|
||||
|
||||
|
||||
class NamedSession:
|
||||
def __init__(self, name: str = "default") -> None:
|
||||
self.name = name
|
||||
self.open = True
|
||||
|
||||
|
||||
def get_named_session(session: SessionRequestDep, session_b: SessionDefaultDep) -> Any:
|
||||
assert session is session_b
|
||||
named_session = NamedSession(name="named")
|
||||
yield named_session, session_b
|
||||
named_session.open = False
|
||||
global_state = global_context.get()
|
||||
global_state["named_session_closed"] = True
|
||||
|
||||
|
||||
NamedSessionsDep = Annotated[Tuple[NamedSession, Session], Depends(get_named_session)]
|
||||
|
||||
|
||||
def get_named_func_session(session: SessionFuncDep) -> Any:
|
||||
named_session = NamedSession(name="named")
|
||||
yield named_session, session
|
||||
named_session.open = False
|
||||
global_state = global_context.get()
|
||||
global_state["named_func_session_closed"] = True
|
||||
|
||||
|
||||
def get_named_regular_func_session(session: SessionFuncDep) -> Any:
|
||||
named_session = NamedSession(name="named")
|
||||
return named_session, session
|
||||
|
||||
|
||||
BrokenSessionsDep = Annotated[
|
||||
Tuple[NamedSession, Session], Depends(get_named_func_session)
|
||||
]
|
||||
NamedSessionsFuncDep = Annotated[
|
||||
Tuple[NamedSession, Session], Depends(get_named_func_session, scope="function")
|
||||
]
|
||||
|
||||
RegularSessionsDep = Annotated[
|
||||
Tuple[NamedSession, Session], Depends(get_named_regular_func_session)
|
||||
]
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.websocket("/function-scope")
|
||||
async def function_scope(websocket: WebSocket, session: SessionFuncDep) -> Any:
|
||||
await websocket.accept()
|
||||
await websocket.send_json({"is_open": session.open})
|
||||
|
||||
|
||||
@app.websocket("/request-scope")
|
||||
async def request_scope(websocket: WebSocket, session: SessionRequestDep) -> Any:
|
||||
await websocket.accept()
|
||||
await websocket.send_json({"is_open": session.open})
|
||||
|
||||
|
||||
@app.websocket("/two-scopes")
|
||||
async def get_stream_session(
|
||||
websocket: WebSocket,
|
||||
function_session: SessionFuncDep,
|
||||
request_session: SessionRequestDep,
|
||||
) -> Any:
|
||||
await websocket.accept()
|
||||
await websocket.send_json(
|
||||
{"func_is_open": function_session.open, "req_is_open": request_session.open}
|
||||
)
|
||||
|
||||
|
||||
@app.websocket("/sub")
|
||||
async def get_sub(websocket: WebSocket, sessions: NamedSessionsDep) -> Any:
|
||||
await websocket.accept()
|
||||
await websocket.send_json(
|
||||
{"named_session_open": sessions[0].open, "session_open": sessions[1].open}
|
||||
)
|
||||
|
||||
|
||||
@app.websocket("/named-function-scope")
|
||||
async def get_named_function_scope(
|
||||
websocket: WebSocket, sessions: NamedSessionsFuncDep
|
||||
) -> Any:
|
||||
await websocket.accept()
|
||||
await websocket.send_json(
|
||||
{"named_session_open": sessions[0].open, "session_open": sessions[1].open}
|
||||
)
|
||||
|
||||
|
||||
@app.websocket("/regular-function-scope")
|
||||
async def get_regular_function_scope(
|
||||
websocket: WebSocket, sessions: RegularSessionsDep
|
||||
) -> Any:
|
||||
await websocket.accept()
|
||||
await websocket.send_json(
|
||||
{"named_session_open": sessions[0].open, "session_open": sessions[1].open}
|
||||
)
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_function_scope() -> None:
|
||||
global_context.set({})
|
||||
global_state = global_context.get()
|
||||
with client.websocket_connect("/function-scope") as websocket:
|
||||
data = websocket.receive_json()
|
||||
assert data["is_open"] is True
|
||||
assert global_state["session_closed"] is True
|
||||
|
||||
|
||||
def test_request_scope() -> None:
|
||||
global_context.set({})
|
||||
global_state = global_context.get()
|
||||
with client.websocket_connect("/request-scope") as websocket:
|
||||
data = websocket.receive_json()
|
||||
assert data["is_open"] is True
|
||||
assert global_state["session_closed"] is True
|
||||
|
||||
|
||||
def test_two_scopes() -> None:
|
||||
global_context.set({})
|
||||
global_state = global_context.get()
|
||||
with client.websocket_connect("/two-scopes") as websocket:
|
||||
data = websocket.receive_json()
|
||||
assert data["func_is_open"] is True
|
||||
assert data["req_is_open"] is True
|
||||
assert global_state["session_closed"] is True
|
||||
|
||||
|
||||
def test_sub() -> None:
|
||||
global_context.set({})
|
||||
global_state = global_context.get()
|
||||
with client.websocket_connect("/sub") as websocket:
|
||||
data = websocket.receive_json()
|
||||
assert data["named_session_open"] is True
|
||||
assert data["session_open"] is True
|
||||
assert global_state["session_closed"] is True
|
||||
assert global_state["named_session_closed"] is True
|
||||
|
||||
|
||||
def test_broken_scope() -> None:
|
||||
with pytest.raises(
|
||||
FastAPIError,
|
||||
match='The dependency "get_named_func_session" has a scope of "request", it cannot depend on dependencies with scope "function"',
|
||||
):
|
||||
|
||||
@app.websocket("/broken-scope")
|
||||
async def get_broken(
|
||||
websocket: WebSocket, sessions: BrokenSessionsDep
|
||||
) -> Any: # pragma: no cover
|
||||
pass
|
||||
|
||||
|
||||
def test_named_function_scope() -> None:
|
||||
global_context.set({})
|
||||
global_state = global_context.get()
|
||||
with client.websocket_connect("/named-function-scope") as websocket:
|
||||
data = websocket.receive_json()
|
||||
assert data["named_session_open"] is True
|
||||
assert data["session_open"] is True
|
||||
assert global_state["session_closed"] is True
|
||||
assert global_state["named_func_session_closed"] is True
|
||||
|
||||
|
||||
def test_regular_function_scope() -> None:
|
||||
global_context.set({})
|
||||
global_state = global_context.get()
|
||||
with client.websocket_connect("/regular-function-scope") as websocket:
|
||||
data = websocket.receive_json()
|
||||
assert data["named_session_open"] is True
|
||||
assert data["session_open"] is True
|
||||
assert global_state["session_closed"] is True
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
import importlib
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from ...utils import needs_py39
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
name="client",
|
||||
params=[
|
||||
"tutorial008e",
|
||||
"tutorial008e_an",
|
||||
pytest.param("tutorial008e_an_py39", marks=needs_py39),
|
||||
],
|
||||
)
|
||||
def get_client(request: pytest.FixtureRequest):
|
||||
mod = importlib.import_module(f"docs_src.dependencies.{request.param}")
|
||||
|
||||
client = TestClient(mod.app)
|
||||
return client
|
||||
|
||||
|
||||
def test_get_users_me(client: TestClient):
|
||||
response = client.get("/users/me")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == "Rick"
|
||||
Loading…
Reference in New Issue