Add support for dependencies with scopes, support `scope="request"` for dependencies with `yield` that exit before the response is sent (#14262)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Sebastián Ramírez 2025-11-03 11:12:49 +01:00 committed by GitHub
parent 425a4c5bb1
commit ac438b9934
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 653 additions and 70 deletions

View File

@ -184,6 +184,51 @@ If you raise any exception in the code from the *path operation function*, it wi
/// ///
## Early exit and `scope` { #early-exit-and-scope }
Normally the exit code of dependencies with `yield` is executed **after the response** is sent to the client.
But if you know that you won't need to use the dependency after returning from the *path operation function*, you can use `Depends(scope="function")` to tell FastAPI that it should close the dependency after the *path operation function* returns, but **before** the **response is sent**.
{* ../../docs_src/dependencies/tutorial008e_an_py39.py hl[12,16] *}
`Depends()` receives a `scope` parameter that can be:
* `"function"`: start the dependency before the *path operation function* that handles the request, end the dependency after the *path operation function* ends, but **before** the response is sent back to the client. So, the dependency function will be executed **around** the *path operation **function***.
* `"request"`: start the dependency before the *path operation function* that handles the request (similar to when using `"function"`), but end **after** the response is sent back to the client. So, the dependency function will be executed **around** the **request** and response cycle.
If not specified and the dependency has `yield`, it will have a `scope` of `"request"` by default.
### `scope` for sub-dependencies { #scope-for-sub-dependencies }
When you declare a dependency with a `scope="request"` (the default), any sub-dependency needs to also have a `scope` of `"request"`.
But a dependency with `scope` of `"function"` can have dependencies with `scope` of `"function"` and `scope` of `"request"`.
This is because any dependency needs to be able to run its exit code before the sub-dependencies, as it might need to still use them during its exit code.
```mermaid
sequenceDiagram
participant client as Client
participant dep_req as Dep scope="request"
participant dep_func as Dep scope="function"
participant operation as Path Operation
client ->> dep_req: Start request
Note over dep_req: Run code up to yield
dep_req ->> dep_func: Pass dependency
Note over dep_func: Run code up to yield
dep_func ->> operation: Run path operation with dependency
operation ->> dep_func: Return from path operation
Note over dep_func: Run code after yield
Note over dep_func: ✅ Dependency closed
dep_func ->> client: Send response to client
Note over client: Response sent
Note over dep_req: Run code after yield
Note over dep_req: ✅ Dependency closed
```
## Dependencies with `yield`, `HTTPException`, `except` and Background Tasks { #dependencies-with-yield-httpexception-except-and-background-tasks } ## Dependencies with `yield`, `HTTPException`, `except` and Background Tasks { #dependencies-with-yield-httpexception-except-and-background-tasks }
Dependencies with `yield` have evolved over time to cover different use cases and fix some issues. Dependencies with `yield` have evolved over time to cover different use cases and fix some issues.

View File

@ -0,0 +1,15 @@
from fastapi import Depends, FastAPI
app = FastAPI()
def get_username():
try:
yield "Rick"
finally:
print("Cleanup up before response is sent")
@app.get("/users/me")
def get_user_me(username: str = Depends(get_username, scope="function")):
return username

View File

@ -0,0 +1,16 @@
from fastapi import Depends, FastAPI
from typing_extensions import Annotated
app = FastAPI()
def get_username():
try:
yield "Rick"
finally:
print("Cleanup up before response is sent")
@app.get("/users/me")
def get_user_me(username: Annotated[str, Depends(get_username, scope="function")]):
return username

View File

@ -0,0 +1,17 @@
from typing import Annotated
from fastapi import Depends, FastAPI
app = FastAPI()
def get_username():
try:
yield "Rick"
finally:
print("Cleanup up before response is sent")
@app.get("/users/me")
def get_user_me(username: Annotated[str, Depends(get_username, scope="function")]):
return username

View File

@ -1,8 +1,18 @@
import inspect
import sys
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Sequence, Tuple from functools import cached_property
from typing import Any, Callable, List, Optional, Sequence, Union
from fastapi._compat import ModelField from fastapi._compat import ModelField
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.types import DependencyCacheKey
from typing_extensions import Literal
if sys.version_info >= (3, 13): # pragma: no cover
from inspect import iscoroutinefunction
else: # pragma: no cover
from asyncio import iscoroutinefunction
@dataclass @dataclass
@ -31,7 +41,43 @@ class Dependant:
security_scopes: Optional[List[str]] = None security_scopes: Optional[List[str]] = None
use_cache: bool = True use_cache: bool = True
path: Optional[str] = None path: Optional[str] = None
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False) scope: Union[Literal["function", "request"], None] = None
def __post_init__(self) -> None: @cached_property
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or [])))) def cache_key(self) -> DependencyCacheKey:
return (
self.call,
tuple(sorted(set(self.security_scopes or []))),
self.computed_scope or "",
)
@cached_property
def is_gen_callable(self) -> bool:
if inspect.isgeneratorfunction(self.call):
return True
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
return inspect.isgeneratorfunction(dunder_call)
@cached_property
def is_async_gen_callable(self) -> bool:
if inspect.isasyncgenfunction(self.call):
return True
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
return inspect.isasyncgenfunction(dunder_call)
@cached_property
def is_coroutine_callable(self) -> bool:
if inspect.isroutine(self.call):
return iscoroutinefunction(self.call)
if inspect.isclass(self.call):
return False
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
return iscoroutinefunction(dunder_call)
@cached_property
def computed_scope(self) -> Union[str, None]:
if self.scope:
return self.scope
if self.is_gen_callable or self.is_async_gen_callable:
return "request"
return None

View File

@ -1,5 +1,4 @@
import inspect import inspect
import sys
from contextlib import AsyncExitStack, contextmanager from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy from copy import copy, deepcopy
from dataclasses import dataclass from dataclasses import dataclass
@ -55,10 +54,12 @@ from fastapi.concurrency import (
contextmanager_in_threadpool, contextmanager_in_threadpool,
) )
from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.exceptions import DependencyScopeError
from fastapi.logger import logger from fastapi.logger import logger
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes from fastapi.security.oauth2 import OAuth2, SecurityScopes
from fastapi.security.open_id_connect_url import OpenIdConnect from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.types import DependencyCacheKey
from fastapi.utils import create_model_field, get_path_param_names from fastapi.utils import create_model_field, get_path_param_names
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
@ -74,15 +75,10 @@ from starlette.datastructures import (
from starlette.requests import HTTPConnection, Request from starlette.requests import HTTPConnection, Request
from starlette.responses import Response from starlette.responses import Response
from starlette.websockets import WebSocket from starlette.websockets import WebSocket
from typing_extensions import Annotated, get_args, get_origin from typing_extensions import Annotated, Literal, get_args, get_origin
from .. import temp_pydantic_v1_params from .. import temp_pydantic_v1_params
if sys.version_info >= (3, 13): # pragma: no cover
from inspect import iscoroutinefunction
else: # pragma: no cover
from asyncio import iscoroutinefunction
multipart_not_installed_error = ( multipart_not_installed_error = (
'Form data requires "python-multipart" to be installed. \n' 'Form data requires "python-multipart" to be installed. \n'
'You can install "python-multipart" with: \n\n' 'You can install "python-multipart" with: \n\n'
@ -137,14 +133,11 @@ def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> De
) )
CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
def get_flat_dependant( def get_flat_dependant(
dependant: Dependant, dependant: Dependant,
*, *,
skip_repeats: bool = False, skip_repeats: bool = False,
visited: Optional[List[CacheKey]] = None, visited: Optional[List[DependencyCacheKey]] = None,
) -> Dependant: ) -> Dependant:
if visited is None: if visited is None:
visited = [] visited = []
@ -237,6 +230,7 @@ def get_dependant(
name: Optional[str] = None, name: Optional[str] = None,
security_scopes: Optional[List[str]] = None, security_scopes: Optional[List[str]] = None,
use_cache: bool = True, use_cache: bool = True,
scope: Union[Literal["function", "request"], None] = None,
) -> Dependant: ) -> Dependant:
dependant = Dependant( dependant = Dependant(
call=call, call=call,
@ -244,6 +238,7 @@ def get_dependant(
path=path, path=path,
security_scopes=security_scopes, security_scopes=security_scopes,
use_cache=use_cache, use_cache=use_cache,
scope=scope,
) )
path_param_names = get_path_param_names(path) path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call) endpoint_signature = get_typed_signature(call)
@ -251,7 +246,7 @@ def get_dependant(
if isinstance(call, SecurityBase): if isinstance(call, SecurityBase):
use_scopes: List[str] = [] use_scopes: List[str] = []
if isinstance(call, (OAuth2, OpenIdConnect)): if isinstance(call, (OAuth2, OpenIdConnect)):
use_scopes = security_scopes use_scopes = security_scopes or use_scopes
security_requirement = SecurityRequirement( security_requirement = SecurityRequirement(
security_scheme=call, scopes=use_scopes security_scheme=call, scopes=use_scopes
) )
@ -266,6 +261,16 @@ def get_dependant(
) )
if param_details.depends is not None: if param_details.depends is not None:
assert param_details.depends.dependency assert param_details.depends.dependency
if (
(dependant.is_gen_callable or dependant.is_async_gen_callable)
and dependant.computed_scope == "request"
and param_details.depends.scope == "function"
):
assert dependant.call
raise DependencyScopeError(
f'The dependency "{dependant.call.__name__}" has a scope of '
'"request", it cannot depend on dependencies with scope "function".'
)
use_security_scopes = security_scopes or [] use_security_scopes = security_scopes or []
if isinstance(param_details.depends, params.Security): if isinstance(param_details.depends, params.Security):
if param_details.depends.scopes: if param_details.depends.scopes:
@ -276,6 +281,7 @@ def get_dependant(
name=param_name, name=param_name,
security_scopes=use_security_scopes, security_scopes=use_security_scopes,
use_cache=param_details.depends.use_cache, use_cache=param_details.depends.use_cache,
scope=param_details.depends.scope,
) )
dependant.dependencies.append(sub_dependant) dependant.dependencies.append(sub_dependant)
continue continue
@ -532,36 +538,14 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
dependant.cookie_params.append(field) dependant.cookie_params.append(field)
def is_coroutine_callable(call: Callable[..., Any]) -> bool: async def _solve_generator(
if inspect.isroutine(call): *, dependant: Dependant, stack: AsyncExitStack, sub_values: Dict[str, Any]
return iscoroutinefunction(call)
if inspect.isclass(call):
return False
dunder_call = getattr(call, "__call__", None) # noqa: B004
return iscoroutinefunction(dunder_call)
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isasyncgenfunction(call):
return True
dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.isasyncgenfunction(dunder_call)
def is_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isgeneratorfunction(call):
return True
dunder_call = getattr(call, "__call__", None) # noqa: B004
return inspect.isgeneratorfunction(dunder_call)
async def solve_generator(
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
) -> Any: ) -> Any:
if is_gen_callable(call): assert dependant.call
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) if dependant.is_gen_callable:
elif is_async_gen_callable(call): cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values))
cm = asynccontextmanager(call)(**sub_values) elif dependant.is_async_gen_callable:
cm = asynccontextmanager(dependant.call)(**sub_values)
return await stack.enter_async_context(cm) return await stack.enter_async_context(cm)
@ -571,7 +555,7 @@ class SolvedDependency:
errors: List[Any] errors: List[Any]
background_tasks: Optional[StarletteBackgroundTasks] background_tasks: Optional[StarletteBackgroundTasks]
response: Response response: Response
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any] dependency_cache: Dict[DependencyCacheKey, Any]
async def solve_dependencies( async def solve_dependencies(
@ -582,10 +566,20 @@ async def solve_dependencies(
background_tasks: Optional[StarletteBackgroundTasks] = None, background_tasks: Optional[StarletteBackgroundTasks] = None,
response: Optional[Response] = None, response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None, dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None, dependency_cache: Optional[Dict[DependencyCacheKey, Any]] = None,
# TODO: remove this parameter later, no longer used, not removing it yet as some
# people might be monkey patching this function (although that's not supported)
async_exit_stack: AsyncExitStack, async_exit_stack: AsyncExitStack,
embed_body_fields: bool, embed_body_fields: bool,
) -> SolvedDependency: ) -> SolvedDependency:
request_astack = request.scope.get("fastapi_inner_astack")
assert isinstance(request_astack, AsyncExitStack), (
"fastapi_inner_astack not found in request scope"
)
function_astack = request.scope.get("fastapi_function_astack")
assert isinstance(function_astack, AsyncExitStack), (
"fastapi_function_astack not found in request scope"
)
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
errors: List[Any] = [] errors: List[Any] = []
if response is None: if response is None:
@ -594,12 +588,8 @@ async def solve_dependencies(
response.status_code = None # type: ignore response.status_code = None # type: ignore
if dependency_cache is None: if dependency_cache is None:
dependency_cache = {} dependency_cache = {}
sub_dependant: Dependant
for sub_dependant in dependant.dependencies: for sub_dependant in dependant.dependencies:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
sub_dependant.cache_key = cast(
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
)
call = sub_dependant.call call = sub_dependant.call
use_sub_dependant = sub_dependant use_sub_dependant = sub_dependant
if ( if (
@ -616,6 +606,7 @@ async def solve_dependencies(
call=call, call=call,
name=sub_dependant.name, name=sub_dependant.name,
security_scopes=sub_dependant.security_scopes, security_scopes=sub_dependant.security_scopes,
scope=sub_dependant.scope,
) )
solved_result = await solve_dependencies( solved_result = await solve_dependencies(
@ -635,11 +626,18 @@ async def solve_dependencies(
continue continue
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key] solved = dependency_cache[sub_dependant.cache_key]
elif is_gen_callable(call) or is_async_gen_callable(call): elif (
solved = await solve_generator( use_sub_dependant.is_gen_callable or use_sub_dependant.is_async_gen_callable
call=call, stack=async_exit_stack, sub_values=solved_result.values ):
use_astack = request_astack
if sub_dependant.scope == "function":
use_astack = function_astack
solved = await _solve_generator(
dependant=use_sub_dependant,
stack=use_astack,
sub_values=solved_result.values,
) )
elif is_coroutine_callable(call): elif use_sub_dependant.is_coroutine_callable:
solved = await call(**solved_result.values) solved = await call(**solved_result.values)
else: else:
solved = await run_in_threadpool(call, **solved_result.values) solved = await run_in_threadpool(call, **solved_result.values)

View File

@ -147,6 +147,13 @@ class FastAPIError(RuntimeError):
""" """
class DependencyScopeError(FastAPIError):
"""
A dependency declared that it depends on another dependency with an invalid
(narrower) scope.
"""
class ValidationException(Exception): class ValidationException(Exception):
def __init__(self, errors: Sequence[Any]) -> None: def __init__(self, errors: Sequence[Any]) -> None:
self._errors = errors self._errors = errors

View File

@ -4,7 +4,7 @@ from annotated_doc import Doc
from fastapi import params from fastapi import params
from fastapi._compat import Undefined from fastapi._compat import Undefined
from fastapi.openapi.models import Example from fastapi.openapi.models import Example
from typing_extensions import Annotated, deprecated from typing_extensions import Annotated, Literal, deprecated
_Unset: Any = Undefined _Unset: Any = Undefined
@ -2245,6 +2245,26 @@ def Depends( # noqa: N802
""" """
), ),
] = True, ] = True,
scope: Annotated[
Union[Literal["function", "request"], None],
Doc(
"""
Mainly for dependencies with `yield`, define when the dependency function
should start (the code before `yield`) and when it should end (the code
after `yield`).
* `"function"`: start the dependency before the *path operation function*
that handles the request, end the dependency after the *path operation
function* ends, but **before** the response is sent back to the client.
So, the dependency function will be executed **around** the *path operation
**function***.
* `"request"`: start the dependency before the *path operation function*
that handles the request (similar to when using `"function"`), but end
**after** the response is sent back to the client. So, the dependency
function will be executed **around** the **request** and response cycle.
"""
),
] = None,
) -> Any: ) -> Any:
""" """
Declare a FastAPI dependency. Declare a FastAPI dependency.
@ -2275,7 +2295,7 @@ def Depends( # noqa: N802
return commons return commons
``` ```
""" """
return params.Depends(dependency=dependency, use_cache=use_cache) return params.Depends(dependency=dependency, use_cache=use_cache, scope=scope)
def Security( # noqa: N802 def Security( # noqa: N802

View File

@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from fastapi.openapi.models import Example from fastapi.openapi.models import Example
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from typing_extensions import Annotated, deprecated from typing_extensions import Annotated, Literal, deprecated
from ._compat import ( from ._compat import (
PYDANTIC_V2, PYDANTIC_V2,
@ -766,6 +766,7 @@ class File(Form): # type: ignore[misc]
class Depends: class Depends:
dependency: Optional[Callable[..., Any]] = None dependency: Optional[Callable[..., Any]] = None
use_cache: bool = True use_cache: bool = True
scope: Union[Literal["function", "request"], None] = None
@dataclass @dataclass

View File

@ -104,10 +104,11 @@ def request_response(
async def app(scope: Scope, receive: Receive, send: Send) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None:
# Starts customization # Starts customization
response_awaited = False response_awaited = False
async with AsyncExitStack() as stack: async with AsyncExitStack() as request_stack:
scope["fastapi_inner_astack"] = stack scope["fastapi_inner_astack"] = request_stack
# Same as in Starlette async with AsyncExitStack() as function_stack:
response = await f(request) scope["fastapi_function_astack"] = function_stack
response = await f(request)
await response(scope, receive, send) await response(scope, receive, send)
# Continues customization # Continues customization
response_awaited = True response_awaited = True
@ -140,11 +141,11 @@ def websocket_session(
session = WebSocket(scope, receive=receive, send=send) session = WebSocket(scope, receive=receive, send=send)
async def app(scope: Scope, receive: Receive, send: Send) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None:
# Starts customization async with AsyncExitStack() as request_stack:
async with AsyncExitStack() as stack: scope["fastapi_inner_astack"] = request_stack
scope["fastapi_inner_astack"] = stack async with AsyncExitStack() as function_stack:
# Same as in Starlette scope["fastapi_function_astack"] = function_stack
await func(session) await func(session)
# Same as in Starlette # Same as in Starlette
await wrap_app_handling_exceptions(app, session)(scope, receive, send) await wrap_app_handling_exceptions(app, session)(scope, receive, send)
@ -479,7 +480,9 @@ class APIWebSocketRoute(routing.WebSocketRoute):
self.name = get_name(endpoint) if name is None else name self.name = get_name(endpoint) if name is None else name
self.dependencies = list(dependencies or []) self.dependencies = list(dependencies or [])
self.path_regex, self.path_format, self.param_convertors = compile_path(path) self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self.dependant = get_dependant(path=self.path_format, call=self.endpoint) self.dependant = get_dependant(
path=self.path_format, call=self.endpoint, scope="function"
)
for depends in self.dependencies[::-1]: for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert( self.dependant.dependencies.insert(
0, 0,
@ -630,7 +633,9 @@ class APIRoute(routing.Route):
self.response_fields = {} self.response_fields = {}
assert callable(endpoint), "An endpoint must be a callable" assert callable(endpoint), "An endpoint must be a callable"
self.dependant = get_dependant(path=self.path_format, call=self.endpoint) self.dependant = get_dependant(
path=self.path_format, call=self.endpoint, scope="function"
)
for depends in self.dependencies[::-1]: for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert( self.dependant.dependencies.insert(
0, 0,

View File

@ -1,6 +1,6 @@
import types import types
from enum import Enum from enum import Enum
from typing import Any, Callable, Dict, Set, Type, TypeVar, Union from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TypeVar, Union
from pydantic import BaseModel from pydantic import BaseModel
@ -8,3 +8,4 @@ DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any])
UnionType = getattr(types, "UnionType", Union) UnionType = getattr(types, "UnionType", Union)
ModelNameMap = Dict[Union[Type[BaseModel], Type[Enum]], str] ModelNameMap = Dict[Union[Type[BaseModel], Type[Enum]], str]
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]] IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
DependencyCacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...], str]

View File

@ -0,0 +1,184 @@
import json
from typing import Any, Tuple
import pytest
from fastapi import Depends, FastAPI
from fastapi.exceptions import FastAPIError
from fastapi.responses import StreamingResponse
from fastapi.testclient import TestClient
from typing_extensions import Annotated
class Session:
def __init__(self) -> None:
self.open = True
def dep_session() -> Any:
s = Session()
yield s
s.open = False
SessionFuncDep = Annotated[Session, Depends(dep_session, scope="function")]
SessionRequestDep = Annotated[Session, Depends(dep_session, scope="request")]
SessionDefaultDep = Annotated[Session, Depends(dep_session)]
class NamedSession:
def __init__(self, name: str = "default") -> None:
self.name = name
self.open = True
def get_named_session(session: SessionRequestDep, session_b: SessionDefaultDep) -> Any:
assert session is session_b
named_session = NamedSession(name="named")
yield named_session, session_b
named_session.open = False
NamedSessionsDep = Annotated[Tuple[NamedSession, Session], Depends(get_named_session)]
def get_named_func_session(session: SessionFuncDep) -> Any:
named_session = NamedSession(name="named")
yield named_session, session
named_session.open = False
def get_named_regular_func_session(session: SessionFuncDep) -> Any:
named_session = NamedSession(name="named")
return named_session, session
BrokenSessionsDep = Annotated[
Tuple[NamedSession, Session], Depends(get_named_func_session)
]
NamedSessionsFuncDep = Annotated[
Tuple[NamedSession, Session], Depends(get_named_func_session, scope="function")
]
RegularSessionsDep = Annotated[
Tuple[NamedSession, Session], Depends(get_named_regular_func_session)
]
app = FastAPI()
@app.get("/function-scope")
def function_scope(session: SessionFuncDep) -> Any:
def iter_data():
yield json.dumps({"is_open": session.open})
return StreamingResponse(iter_data())
@app.get("/request-scope")
def request_scope(session: SessionRequestDep) -> Any:
def iter_data():
yield json.dumps({"is_open": session.open})
return StreamingResponse(iter_data())
@app.get("/two-scopes")
def get_stream_session(
function_session: SessionFuncDep, request_session: SessionRequestDep
) -> Any:
def iter_data():
yield json.dumps(
{"func_is_open": function_session.open, "req_is_open": request_session.open}
)
return StreamingResponse(iter_data())
@app.get("/sub")
def get_sub(sessions: NamedSessionsDep) -> Any:
def iter_data():
yield json.dumps(
{"named_session_open": sessions[0].open, "session_open": sessions[1].open}
)
return StreamingResponse(iter_data())
@app.get("/named-function-scope")
def get_named_function_scope(sessions: NamedSessionsFuncDep) -> Any:
def iter_data():
yield json.dumps(
{"named_session_open": sessions[0].open, "session_open": sessions[1].open}
)
return StreamingResponse(iter_data())
@app.get("/regular-function-scope")
def get_regular_function_scope(sessions: RegularSessionsDep) -> Any:
def iter_data():
yield json.dumps(
{"named_session_open": sessions[0].open, "session_open": sessions[1].open}
)
return StreamingResponse(iter_data())
client = TestClient(app)
def test_function_scope() -> None:
response = client.get("/function-scope")
assert response.status_code == 200
data = response.json()
assert data["is_open"] is False
def test_request_scope() -> None:
response = client.get("/request-scope")
assert response.status_code == 200
data = response.json()
assert data["is_open"] is True
def test_two_scopes() -> None:
response = client.get("/two-scopes")
assert response.status_code == 200
data = response.json()
assert data["func_is_open"] is False
assert data["req_is_open"] is True
def test_sub() -> None:
response = client.get("/sub")
assert response.status_code == 200
data = response.json()
assert data["named_session_open"] is True
assert data["session_open"] is True
def test_broken_scope() -> None:
with pytest.raises(
FastAPIError,
match='The dependency "get_named_func_session" has a scope of "request", it cannot depend on dependencies with scope "function"',
):
@app.get("/broken-scope")
def get_broken(sessions: BrokenSessionsDep) -> Any: # pragma: no cover
pass
def test_named_function_scope() -> None:
response = client.get("/named-function-scope")
assert response.status_code == 200
data = response.json()
assert data["named_session_open"] is False
assert data["session_open"] is False
def test_regular_function_scope() -> None:
response = client.get("/regular-function-scope")
assert response.status_code == 200
data = response.json()
assert data["named_session_open"] is True
assert data["session_open"] is False

View File

@ -0,0 +1,201 @@
from contextvars import ContextVar
from typing import Any, Dict, Tuple
import pytest
from fastapi import Depends, FastAPI, WebSocket
from fastapi.exceptions import FastAPIError
from fastapi.testclient import TestClient
from typing_extensions import Annotated
global_context: ContextVar[Dict[str, Any]] = ContextVar("global_context", default={}) # noqa: B039
class Session:
def __init__(self) -> None:
self.open = True
async def dep_session() -> Any:
s = Session()
yield s
s.open = False
global_state = global_context.get()
global_state["session_closed"] = True
SessionFuncDep = Annotated[Session, Depends(dep_session, scope="function")]
SessionRequestDep = Annotated[Session, Depends(dep_session, scope="request")]
SessionDefaultDep = Annotated[Session, Depends(dep_session)]
class NamedSession:
def __init__(self, name: str = "default") -> None:
self.name = name
self.open = True
def get_named_session(session: SessionRequestDep, session_b: SessionDefaultDep) -> Any:
assert session is session_b
named_session = NamedSession(name="named")
yield named_session, session_b
named_session.open = False
global_state = global_context.get()
global_state["named_session_closed"] = True
NamedSessionsDep = Annotated[Tuple[NamedSession, Session], Depends(get_named_session)]
def get_named_func_session(session: SessionFuncDep) -> Any:
named_session = NamedSession(name="named")
yield named_session, session
named_session.open = False
global_state = global_context.get()
global_state["named_func_session_closed"] = True
def get_named_regular_func_session(session: SessionFuncDep) -> Any:
named_session = NamedSession(name="named")
return named_session, session
BrokenSessionsDep = Annotated[
Tuple[NamedSession, Session], Depends(get_named_func_session)
]
NamedSessionsFuncDep = Annotated[
Tuple[NamedSession, Session], Depends(get_named_func_session, scope="function")
]
RegularSessionsDep = Annotated[
Tuple[NamedSession, Session], Depends(get_named_regular_func_session)
]
app = FastAPI()
@app.websocket("/function-scope")
async def function_scope(websocket: WebSocket, session: SessionFuncDep) -> Any:
await websocket.accept()
await websocket.send_json({"is_open": session.open})
@app.websocket("/request-scope")
async def request_scope(websocket: WebSocket, session: SessionRequestDep) -> Any:
await websocket.accept()
await websocket.send_json({"is_open": session.open})
@app.websocket("/two-scopes")
async def get_stream_session(
websocket: WebSocket,
function_session: SessionFuncDep,
request_session: SessionRequestDep,
) -> Any:
await websocket.accept()
await websocket.send_json(
{"func_is_open": function_session.open, "req_is_open": request_session.open}
)
@app.websocket("/sub")
async def get_sub(websocket: WebSocket, sessions: NamedSessionsDep) -> Any:
await websocket.accept()
await websocket.send_json(
{"named_session_open": sessions[0].open, "session_open": sessions[1].open}
)
@app.websocket("/named-function-scope")
async def get_named_function_scope(
websocket: WebSocket, sessions: NamedSessionsFuncDep
) -> Any:
await websocket.accept()
await websocket.send_json(
{"named_session_open": sessions[0].open, "session_open": sessions[1].open}
)
@app.websocket("/regular-function-scope")
async def get_regular_function_scope(
websocket: WebSocket, sessions: RegularSessionsDep
) -> Any:
await websocket.accept()
await websocket.send_json(
{"named_session_open": sessions[0].open, "session_open": sessions[1].open}
)
client = TestClient(app)
def test_function_scope() -> None:
global_context.set({})
global_state = global_context.get()
with client.websocket_connect("/function-scope") as websocket:
data = websocket.receive_json()
assert data["is_open"] is True
assert global_state["session_closed"] is True
def test_request_scope() -> None:
global_context.set({})
global_state = global_context.get()
with client.websocket_connect("/request-scope") as websocket:
data = websocket.receive_json()
assert data["is_open"] is True
assert global_state["session_closed"] is True
def test_two_scopes() -> None:
global_context.set({})
global_state = global_context.get()
with client.websocket_connect("/two-scopes") as websocket:
data = websocket.receive_json()
assert data["func_is_open"] is True
assert data["req_is_open"] is True
assert global_state["session_closed"] is True
def test_sub() -> None:
global_context.set({})
global_state = global_context.get()
with client.websocket_connect("/sub") as websocket:
data = websocket.receive_json()
assert data["named_session_open"] is True
assert data["session_open"] is True
assert global_state["session_closed"] is True
assert global_state["named_session_closed"] is True
def test_broken_scope() -> None:
with pytest.raises(
FastAPIError,
match='The dependency "get_named_func_session" has a scope of "request", it cannot depend on dependencies with scope "function"',
):
@app.websocket("/broken-scope")
async def get_broken(
websocket: WebSocket, sessions: BrokenSessionsDep
) -> Any: # pragma: no cover
pass
def test_named_function_scope() -> None:
global_context.set({})
global_state = global_context.get()
with client.websocket_connect("/named-function-scope") as websocket:
data = websocket.receive_json()
assert data["named_session_open"] is True
assert data["session_open"] is True
assert global_state["session_closed"] is True
assert global_state["named_func_session_closed"] is True
def test_regular_function_scope() -> None:
global_context.set({})
global_state = global_context.get()
with client.websocket_connect("/regular-function-scope") as websocket:
data = websocket.receive_json()
assert data["named_session_open"] is True
assert data["session_open"] is True
assert global_state["session_closed"] is True

View File

@ -0,0 +1,27 @@
import importlib
import pytest
from fastapi.testclient import TestClient
from ...utils import needs_py39
@pytest.fixture(
name="client",
params=[
"tutorial008e",
"tutorial008e_an",
pytest.param("tutorial008e_an_py39", marks=needs_py39),
],
)
def get_client(request: pytest.FixtureRequest):
mod = importlib.import_module(f"docs_src.dependencies.{request.param}")
client = TestClient(mod.app)
return client
def test_get_users_me(client: TestClient):
response = client.get("/users/me")
assert response.status_code == 200, response.text
assert response.json() == "Rick"