mirror of https://github.com/tiangolo/fastapi.git
✨ Add support for dependencies with scopes, support `scope="request"` for dependencies with `yield` that exit before the response is sent (#14262)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
425a4c5bb1
commit
ac438b9934
|
|
@ -184,6 +184,51 @@ If you raise any exception in the code from the *path operation function*, it wi
|
||||||
|
|
||||||
///
|
///
|
||||||
|
|
||||||
|
## Early exit and `scope` { #early-exit-and-scope }
|
||||||
|
|
||||||
|
Normally the exit code of dependencies with `yield` is executed **after the response** is sent to the client.
|
||||||
|
|
||||||
|
But if you know that you won't need to use the dependency after returning from the *path operation function*, you can use `Depends(scope="function")` to tell FastAPI that it should close the dependency after the *path operation function* returns, but **before** the **response is sent**.
|
||||||
|
|
||||||
|
{* ../../docs_src/dependencies/tutorial008e_an_py39.py hl[12,16] *}
|
||||||
|
|
||||||
|
`Depends()` receives a `scope` parameter that can be:
|
||||||
|
|
||||||
|
* `"function"`: start the dependency before the *path operation function* that handles the request, end the dependency after the *path operation function* ends, but **before** the response is sent back to the client. So, the dependency function will be executed **around** the *path operation **function***.
|
||||||
|
* `"request"`: start the dependency before the *path operation function* that handles the request (similar to when using `"function"`), but end **after** the response is sent back to the client. So, the dependency function will be executed **around** the **request** and response cycle.
|
||||||
|
|
||||||
|
If not specified and the dependency has `yield`, it will have a `scope` of `"request"` by default.
|
||||||
|
|
||||||
|
### `scope` for sub-dependencies { #scope-for-sub-dependencies }
|
||||||
|
|
||||||
|
When you declare a dependency with a `scope="request"` (the default), any sub-dependency needs to also have a `scope` of `"request"`.
|
||||||
|
|
||||||
|
But a dependency with `scope` of `"function"` can have dependencies with `scope` of `"function"` and `scope` of `"request"`.
|
||||||
|
|
||||||
|
This is because any dependency needs to be able to run its exit code before the sub-dependencies, as it might need to still use them during its exit code.
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
sequenceDiagram
|
||||||
|
|
||||||
|
participant client as Client
|
||||||
|
participant dep_req as Dep scope="request"
|
||||||
|
participant dep_func as Dep scope="function"
|
||||||
|
participant operation as Path Operation
|
||||||
|
|
||||||
|
client ->> dep_req: Start request
|
||||||
|
Note over dep_req: Run code up to yield
|
||||||
|
dep_req ->> dep_func: Pass dependency
|
||||||
|
Note over dep_func: Run code up to yield
|
||||||
|
dep_func ->> operation: Run path operation with dependency
|
||||||
|
operation ->> dep_func: Return from path operation
|
||||||
|
Note over dep_func: Run code after yield
|
||||||
|
Note over dep_func: ✅ Dependency closed
|
||||||
|
dep_func ->> client: Send response to client
|
||||||
|
Note over client: Response sent
|
||||||
|
Note over dep_req: Run code after yield
|
||||||
|
Note over dep_req: ✅ Dependency closed
|
||||||
|
```
|
||||||
|
|
||||||
## Dependencies with `yield`, `HTTPException`, `except` and Background Tasks { #dependencies-with-yield-httpexception-except-and-background-tasks }
|
## Dependencies with `yield`, `HTTPException`, `except` and Background Tasks { #dependencies-with-yield-httpexception-except-and-background-tasks }
|
||||||
|
|
||||||
Dependencies with `yield` have evolved over time to cover different use cases and fix some issues.
|
Dependencies with `yield` have evolved over time to cover different use cases and fix some issues.
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
from fastapi import Depends, FastAPI
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
def get_username():
|
||||||
|
try:
|
||||||
|
yield "Rick"
|
||||||
|
finally:
|
||||||
|
print("Cleanup up before response is sent")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/users/me")
|
||||||
|
def get_user_me(username: str = Depends(get_username, scope="function")):
|
||||||
|
return username
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
from fastapi import Depends, FastAPI
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
def get_username():
|
||||||
|
try:
|
||||||
|
yield "Rick"
|
||||||
|
finally:
|
||||||
|
print("Cleanup up before response is sent")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/users/me")
|
||||||
|
def get_user_me(username: Annotated[str, Depends(get_username, scope="function")]):
|
||||||
|
return username
|
||||||
|
|
@ -0,0 +1,17 @@
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends, FastAPI
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
def get_username():
|
||||||
|
try:
|
||||||
|
yield "Rick"
|
||||||
|
finally:
|
||||||
|
print("Cleanup up before response is sent")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/users/me")
|
||||||
|
def get_user_me(username: Annotated[str, Depends(get_username, scope="function")]):
|
||||||
|
return username
|
||||||
|
|
@ -1,8 +1,18 @@
|
||||||
|
import inspect
|
||||||
|
import sys
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Callable, List, Optional, Sequence, Tuple
|
from functools import cached_property
|
||||||
|
from typing import Any, Callable, List, Optional, Sequence, Union
|
||||||
|
|
||||||
from fastapi._compat import ModelField
|
from fastapi._compat import ModelField
|
||||||
from fastapi.security.base import SecurityBase
|
from fastapi.security.base import SecurityBase
|
||||||
|
from fastapi.types import DependencyCacheKey
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 13): # pragma: no cover
|
||||||
|
from inspect import iscoroutinefunction
|
||||||
|
else: # pragma: no cover
|
||||||
|
from asyncio import iscoroutinefunction
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -31,7 +41,43 @@ class Dependant:
|
||||||
security_scopes: Optional[List[str]] = None
|
security_scopes: Optional[List[str]] = None
|
||||||
use_cache: bool = True
|
use_cache: bool = True
|
||||||
path: Optional[str] = None
|
path: Optional[str] = None
|
||||||
cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
|
scope: Union[Literal["function", "request"], None] = None
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
@cached_property
|
||||||
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))
|
def cache_key(self) -> DependencyCacheKey:
|
||||||
|
return (
|
||||||
|
self.call,
|
||||||
|
tuple(sorted(set(self.security_scopes or []))),
|
||||||
|
self.computed_scope or "",
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def is_gen_callable(self) -> bool:
|
||||||
|
if inspect.isgeneratorfunction(self.call):
|
||||||
|
return True
|
||||||
|
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
|
||||||
|
return inspect.isgeneratorfunction(dunder_call)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def is_async_gen_callable(self) -> bool:
|
||||||
|
if inspect.isasyncgenfunction(self.call):
|
||||||
|
return True
|
||||||
|
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
|
||||||
|
return inspect.isasyncgenfunction(dunder_call)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def is_coroutine_callable(self) -> bool:
|
||||||
|
if inspect.isroutine(self.call):
|
||||||
|
return iscoroutinefunction(self.call)
|
||||||
|
if inspect.isclass(self.call):
|
||||||
|
return False
|
||||||
|
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
|
||||||
|
return iscoroutinefunction(dunder_call)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def computed_scope(self) -> Union[str, None]:
|
||||||
|
if self.scope:
|
||||||
|
return self.scope
|
||||||
|
if self.is_gen_callable or self.is_async_gen_callable:
|
||||||
|
return "request"
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import inspect
|
import inspect
|
||||||
import sys
|
|
||||||
from contextlib import AsyncExitStack, contextmanager
|
from contextlib import AsyncExitStack, contextmanager
|
||||||
from copy import copy, deepcopy
|
from copy import copy, deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
@ -55,10 +54,12 @@ from fastapi.concurrency import (
|
||||||
contextmanager_in_threadpool,
|
contextmanager_in_threadpool,
|
||||||
)
|
)
|
||||||
from fastapi.dependencies.models import Dependant, SecurityRequirement
|
from fastapi.dependencies.models import Dependant, SecurityRequirement
|
||||||
|
from fastapi.exceptions import DependencyScopeError
|
||||||
from fastapi.logger import logger
|
from fastapi.logger import logger
|
||||||
from fastapi.security.base import SecurityBase
|
from fastapi.security.base import SecurityBase
|
||||||
from fastapi.security.oauth2 import OAuth2, SecurityScopes
|
from fastapi.security.oauth2 import OAuth2, SecurityScopes
|
||||||
from fastapi.security.open_id_connect_url import OpenIdConnect
|
from fastapi.security.open_id_connect_url import OpenIdConnect
|
||||||
|
from fastapi.types import DependencyCacheKey
|
||||||
from fastapi.utils import create_model_field, get_path_param_names
|
from fastapi.utils import create_model_field, get_path_param_names
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.fields import FieldInfo
|
from pydantic.fields import FieldInfo
|
||||||
|
|
@ -74,15 +75,10 @@ from starlette.datastructures import (
|
||||||
from starlette.requests import HTTPConnection, Request
|
from starlette.requests import HTTPConnection, Request
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
from starlette.websockets import WebSocket
|
from starlette.websockets import WebSocket
|
||||||
from typing_extensions import Annotated, get_args, get_origin
|
from typing_extensions import Annotated, Literal, get_args, get_origin
|
||||||
|
|
||||||
from .. import temp_pydantic_v1_params
|
from .. import temp_pydantic_v1_params
|
||||||
|
|
||||||
if sys.version_info >= (3, 13): # pragma: no cover
|
|
||||||
from inspect import iscoroutinefunction
|
|
||||||
else: # pragma: no cover
|
|
||||||
from asyncio import iscoroutinefunction
|
|
||||||
|
|
||||||
multipart_not_installed_error = (
|
multipart_not_installed_error = (
|
||||||
'Form data requires "python-multipart" to be installed. \n'
|
'Form data requires "python-multipart" to be installed. \n'
|
||||||
'You can install "python-multipart" with: \n\n'
|
'You can install "python-multipart" with: \n\n'
|
||||||
|
|
@ -137,14 +133,11 @@ def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> De
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
|
|
||||||
|
|
||||||
|
|
||||||
def get_flat_dependant(
|
def get_flat_dependant(
|
||||||
dependant: Dependant,
|
dependant: Dependant,
|
||||||
*,
|
*,
|
||||||
skip_repeats: bool = False,
|
skip_repeats: bool = False,
|
||||||
visited: Optional[List[CacheKey]] = None,
|
visited: Optional[List[DependencyCacheKey]] = None,
|
||||||
) -> Dependant:
|
) -> Dependant:
|
||||||
if visited is None:
|
if visited is None:
|
||||||
visited = []
|
visited = []
|
||||||
|
|
@ -237,6 +230,7 @@ def get_dependant(
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
security_scopes: Optional[List[str]] = None,
|
security_scopes: Optional[List[str]] = None,
|
||||||
use_cache: bool = True,
|
use_cache: bool = True,
|
||||||
|
scope: Union[Literal["function", "request"], None] = None,
|
||||||
) -> Dependant:
|
) -> Dependant:
|
||||||
dependant = Dependant(
|
dependant = Dependant(
|
||||||
call=call,
|
call=call,
|
||||||
|
|
@ -244,6 +238,7 @@ def get_dependant(
|
||||||
path=path,
|
path=path,
|
||||||
security_scopes=security_scopes,
|
security_scopes=security_scopes,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
scope=scope,
|
||||||
)
|
)
|
||||||
path_param_names = get_path_param_names(path)
|
path_param_names = get_path_param_names(path)
|
||||||
endpoint_signature = get_typed_signature(call)
|
endpoint_signature = get_typed_signature(call)
|
||||||
|
|
@ -251,7 +246,7 @@ def get_dependant(
|
||||||
if isinstance(call, SecurityBase):
|
if isinstance(call, SecurityBase):
|
||||||
use_scopes: List[str] = []
|
use_scopes: List[str] = []
|
||||||
if isinstance(call, (OAuth2, OpenIdConnect)):
|
if isinstance(call, (OAuth2, OpenIdConnect)):
|
||||||
use_scopes = security_scopes
|
use_scopes = security_scopes or use_scopes
|
||||||
security_requirement = SecurityRequirement(
|
security_requirement = SecurityRequirement(
|
||||||
security_scheme=call, scopes=use_scopes
|
security_scheme=call, scopes=use_scopes
|
||||||
)
|
)
|
||||||
|
|
@ -266,6 +261,16 @@ def get_dependant(
|
||||||
)
|
)
|
||||||
if param_details.depends is not None:
|
if param_details.depends is not None:
|
||||||
assert param_details.depends.dependency
|
assert param_details.depends.dependency
|
||||||
|
if (
|
||||||
|
(dependant.is_gen_callable or dependant.is_async_gen_callable)
|
||||||
|
and dependant.computed_scope == "request"
|
||||||
|
and param_details.depends.scope == "function"
|
||||||
|
):
|
||||||
|
assert dependant.call
|
||||||
|
raise DependencyScopeError(
|
||||||
|
f'The dependency "{dependant.call.__name__}" has a scope of '
|
||||||
|
'"request", it cannot depend on dependencies with scope "function".'
|
||||||
|
)
|
||||||
use_security_scopes = security_scopes or []
|
use_security_scopes = security_scopes or []
|
||||||
if isinstance(param_details.depends, params.Security):
|
if isinstance(param_details.depends, params.Security):
|
||||||
if param_details.depends.scopes:
|
if param_details.depends.scopes:
|
||||||
|
|
@ -276,6 +281,7 @@ def get_dependant(
|
||||||
name=param_name,
|
name=param_name,
|
||||||
security_scopes=use_security_scopes,
|
security_scopes=use_security_scopes,
|
||||||
use_cache=param_details.depends.use_cache,
|
use_cache=param_details.depends.use_cache,
|
||||||
|
scope=param_details.depends.scope,
|
||||||
)
|
)
|
||||||
dependant.dependencies.append(sub_dependant)
|
dependant.dependencies.append(sub_dependant)
|
||||||
continue
|
continue
|
||||||
|
|
@ -532,36 +538,14 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
|
||||||
dependant.cookie_params.append(field)
|
dependant.cookie_params.append(field)
|
||||||
|
|
||||||
|
|
||||||
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
|
async def _solve_generator(
|
||||||
if inspect.isroutine(call):
|
*, dependant: Dependant, stack: AsyncExitStack, sub_values: Dict[str, Any]
|
||||||
return iscoroutinefunction(call)
|
|
||||||
if inspect.isclass(call):
|
|
||||||
return False
|
|
||||||
dunder_call = getattr(call, "__call__", None) # noqa: B004
|
|
||||||
return iscoroutinefunction(dunder_call)
|
|
||||||
|
|
||||||
|
|
||||||
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
|
|
||||||
if inspect.isasyncgenfunction(call):
|
|
||||||
return True
|
|
||||||
dunder_call = getattr(call, "__call__", None) # noqa: B004
|
|
||||||
return inspect.isasyncgenfunction(dunder_call)
|
|
||||||
|
|
||||||
|
|
||||||
def is_gen_callable(call: Callable[..., Any]) -> bool:
|
|
||||||
if inspect.isgeneratorfunction(call):
|
|
||||||
return True
|
|
||||||
dunder_call = getattr(call, "__call__", None) # noqa: B004
|
|
||||||
return inspect.isgeneratorfunction(dunder_call)
|
|
||||||
|
|
||||||
|
|
||||||
async def solve_generator(
|
|
||||||
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
|
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if is_gen_callable(call):
|
assert dependant.call
|
||||||
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
|
if dependant.is_gen_callable:
|
||||||
elif is_async_gen_callable(call):
|
cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values))
|
||||||
cm = asynccontextmanager(call)(**sub_values)
|
elif dependant.is_async_gen_callable:
|
||||||
|
cm = asynccontextmanager(dependant.call)(**sub_values)
|
||||||
return await stack.enter_async_context(cm)
|
return await stack.enter_async_context(cm)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -571,7 +555,7 @@ class SolvedDependency:
|
||||||
errors: List[Any]
|
errors: List[Any]
|
||||||
background_tasks: Optional[StarletteBackgroundTasks]
|
background_tasks: Optional[StarletteBackgroundTasks]
|
||||||
response: Response
|
response: Response
|
||||||
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
|
dependency_cache: Dict[DependencyCacheKey, Any]
|
||||||
|
|
||||||
|
|
||||||
async def solve_dependencies(
|
async def solve_dependencies(
|
||||||
|
|
@ -582,10 +566,20 @@ async def solve_dependencies(
|
||||||
background_tasks: Optional[StarletteBackgroundTasks] = None,
|
background_tasks: Optional[StarletteBackgroundTasks] = None,
|
||||||
response: Optional[Response] = None,
|
response: Optional[Response] = None,
|
||||||
dependency_overrides_provider: Optional[Any] = None,
|
dependency_overrides_provider: Optional[Any] = None,
|
||||||
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
|
dependency_cache: Optional[Dict[DependencyCacheKey, Any]] = None,
|
||||||
|
# TODO: remove this parameter later, no longer used, not removing it yet as some
|
||||||
|
# people might be monkey patching this function (although that's not supported)
|
||||||
async_exit_stack: AsyncExitStack,
|
async_exit_stack: AsyncExitStack,
|
||||||
embed_body_fields: bool,
|
embed_body_fields: bool,
|
||||||
) -> SolvedDependency:
|
) -> SolvedDependency:
|
||||||
|
request_astack = request.scope.get("fastapi_inner_astack")
|
||||||
|
assert isinstance(request_astack, AsyncExitStack), (
|
||||||
|
"fastapi_inner_astack not found in request scope"
|
||||||
|
)
|
||||||
|
function_astack = request.scope.get("fastapi_function_astack")
|
||||||
|
assert isinstance(function_astack, AsyncExitStack), (
|
||||||
|
"fastapi_function_astack not found in request scope"
|
||||||
|
)
|
||||||
values: Dict[str, Any] = {}
|
values: Dict[str, Any] = {}
|
||||||
errors: List[Any] = []
|
errors: List[Any] = []
|
||||||
if response is None:
|
if response is None:
|
||||||
|
|
@ -594,12 +588,8 @@ async def solve_dependencies(
|
||||||
response.status_code = None # type: ignore
|
response.status_code = None # type: ignore
|
||||||
if dependency_cache is None:
|
if dependency_cache is None:
|
||||||
dependency_cache = {}
|
dependency_cache = {}
|
||||||
sub_dependant: Dependant
|
|
||||||
for sub_dependant in dependant.dependencies:
|
for sub_dependant in dependant.dependencies:
|
||||||
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
|
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
|
||||||
sub_dependant.cache_key = cast(
|
|
||||||
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
|
|
||||||
)
|
|
||||||
call = sub_dependant.call
|
call = sub_dependant.call
|
||||||
use_sub_dependant = sub_dependant
|
use_sub_dependant = sub_dependant
|
||||||
if (
|
if (
|
||||||
|
|
@ -616,6 +606,7 @@ async def solve_dependencies(
|
||||||
call=call,
|
call=call,
|
||||||
name=sub_dependant.name,
|
name=sub_dependant.name,
|
||||||
security_scopes=sub_dependant.security_scopes,
|
security_scopes=sub_dependant.security_scopes,
|
||||||
|
scope=sub_dependant.scope,
|
||||||
)
|
)
|
||||||
|
|
||||||
solved_result = await solve_dependencies(
|
solved_result = await solve_dependencies(
|
||||||
|
|
@ -635,11 +626,18 @@ async def solve_dependencies(
|
||||||
continue
|
continue
|
||||||
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
|
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
|
||||||
solved = dependency_cache[sub_dependant.cache_key]
|
solved = dependency_cache[sub_dependant.cache_key]
|
||||||
elif is_gen_callable(call) or is_async_gen_callable(call):
|
elif (
|
||||||
solved = await solve_generator(
|
use_sub_dependant.is_gen_callable or use_sub_dependant.is_async_gen_callable
|
||||||
call=call, stack=async_exit_stack, sub_values=solved_result.values
|
):
|
||||||
|
use_astack = request_astack
|
||||||
|
if sub_dependant.scope == "function":
|
||||||
|
use_astack = function_astack
|
||||||
|
solved = await _solve_generator(
|
||||||
|
dependant=use_sub_dependant,
|
||||||
|
stack=use_astack,
|
||||||
|
sub_values=solved_result.values,
|
||||||
)
|
)
|
||||||
elif is_coroutine_callable(call):
|
elif use_sub_dependant.is_coroutine_callable:
|
||||||
solved = await call(**solved_result.values)
|
solved = await call(**solved_result.values)
|
||||||
else:
|
else:
|
||||||
solved = await run_in_threadpool(call, **solved_result.values)
|
solved = await run_in_threadpool(call, **solved_result.values)
|
||||||
|
|
|
||||||
|
|
@ -147,6 +147,13 @@ class FastAPIError(RuntimeError):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class DependencyScopeError(FastAPIError):
|
||||||
|
"""
|
||||||
|
A dependency declared that it depends on another dependency with an invalid
|
||||||
|
(narrower) scope.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class ValidationException(Exception):
|
class ValidationException(Exception):
|
||||||
def __init__(self, errors: Sequence[Any]) -> None:
|
def __init__(self, errors: Sequence[Any]) -> None:
|
||||||
self._errors = errors
|
self._errors = errors
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from annotated_doc import Doc
|
||||||
from fastapi import params
|
from fastapi import params
|
||||||
from fastapi._compat import Undefined
|
from fastapi._compat import Undefined
|
||||||
from fastapi.openapi.models import Example
|
from fastapi.openapi.models import Example
|
||||||
from typing_extensions import Annotated, deprecated
|
from typing_extensions import Annotated, Literal, deprecated
|
||||||
|
|
||||||
_Unset: Any = Undefined
|
_Unset: Any = Undefined
|
||||||
|
|
||||||
|
|
@ -2245,6 +2245,26 @@ def Depends( # noqa: N802
|
||||||
"""
|
"""
|
||||||
),
|
),
|
||||||
] = True,
|
] = True,
|
||||||
|
scope: Annotated[
|
||||||
|
Union[Literal["function", "request"], None],
|
||||||
|
Doc(
|
||||||
|
"""
|
||||||
|
Mainly for dependencies with `yield`, define when the dependency function
|
||||||
|
should start (the code before `yield`) and when it should end (the code
|
||||||
|
after `yield`).
|
||||||
|
|
||||||
|
* `"function"`: start the dependency before the *path operation function*
|
||||||
|
that handles the request, end the dependency after the *path operation
|
||||||
|
function* ends, but **before** the response is sent back to the client.
|
||||||
|
So, the dependency function will be executed **around** the *path operation
|
||||||
|
**function***.
|
||||||
|
* `"request"`: start the dependency before the *path operation function*
|
||||||
|
that handles the request (similar to when using `"function"`), but end
|
||||||
|
**after** the response is sent back to the client. So, the dependency
|
||||||
|
function will be executed **around** the **request** and response cycle.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Declare a FastAPI dependency.
|
Declare a FastAPI dependency.
|
||||||
|
|
@ -2275,7 +2295,7 @@ def Depends( # noqa: N802
|
||||||
return commons
|
return commons
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
return params.Depends(dependency=dependency, use_cache=use_cache)
|
return params.Depends(dependency=dependency, use_cache=use_cache, scope=scope)
|
||||||
|
|
||||||
|
|
||||||
def Security( # noqa: N802
|
def Security( # noqa: N802
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
from fastapi.openapi.models import Example
|
from fastapi.openapi.models import Example
|
||||||
from pydantic.fields import FieldInfo
|
from pydantic.fields import FieldInfo
|
||||||
from typing_extensions import Annotated, deprecated
|
from typing_extensions import Annotated, Literal, deprecated
|
||||||
|
|
||||||
from ._compat import (
|
from ._compat import (
|
||||||
PYDANTIC_V2,
|
PYDANTIC_V2,
|
||||||
|
|
@ -766,6 +766,7 @@ class File(Form): # type: ignore[misc]
|
||||||
class Depends:
|
class Depends:
|
||||||
dependency: Optional[Callable[..., Any]] = None
|
dependency: Optional[Callable[..., Any]] = None
|
||||||
use_cache: bool = True
|
use_cache: bool = True
|
||||||
|
scope: Union[Literal["function", "request"], None] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
||||||
|
|
@ -104,10 +104,11 @@ def request_response(
|
||||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
# Starts customization
|
# Starts customization
|
||||||
response_awaited = False
|
response_awaited = False
|
||||||
async with AsyncExitStack() as stack:
|
async with AsyncExitStack() as request_stack:
|
||||||
scope["fastapi_inner_astack"] = stack
|
scope["fastapi_inner_astack"] = request_stack
|
||||||
# Same as in Starlette
|
async with AsyncExitStack() as function_stack:
|
||||||
response = await f(request)
|
scope["fastapi_function_astack"] = function_stack
|
||||||
|
response = await f(request)
|
||||||
await response(scope, receive, send)
|
await response(scope, receive, send)
|
||||||
# Continues customization
|
# Continues customization
|
||||||
response_awaited = True
|
response_awaited = True
|
||||||
|
|
@ -140,11 +141,11 @@ def websocket_session(
|
||||||
session = WebSocket(scope, receive=receive, send=send)
|
session = WebSocket(scope, receive=receive, send=send)
|
||||||
|
|
||||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
# Starts customization
|
async with AsyncExitStack() as request_stack:
|
||||||
async with AsyncExitStack() as stack:
|
scope["fastapi_inner_astack"] = request_stack
|
||||||
scope["fastapi_inner_astack"] = stack
|
async with AsyncExitStack() as function_stack:
|
||||||
# Same as in Starlette
|
scope["fastapi_function_astack"] = function_stack
|
||||||
await func(session)
|
await func(session)
|
||||||
|
|
||||||
# Same as in Starlette
|
# Same as in Starlette
|
||||||
await wrap_app_handling_exceptions(app, session)(scope, receive, send)
|
await wrap_app_handling_exceptions(app, session)(scope, receive, send)
|
||||||
|
|
@ -479,7 +480,9 @@ class APIWebSocketRoute(routing.WebSocketRoute):
|
||||||
self.name = get_name(endpoint) if name is None else name
|
self.name = get_name(endpoint) if name is None else name
|
||||||
self.dependencies = list(dependencies or [])
|
self.dependencies = list(dependencies or [])
|
||||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||||
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
|
self.dependant = get_dependant(
|
||||||
|
path=self.path_format, call=self.endpoint, scope="function"
|
||||||
|
)
|
||||||
for depends in self.dependencies[::-1]:
|
for depends in self.dependencies[::-1]:
|
||||||
self.dependant.dependencies.insert(
|
self.dependant.dependencies.insert(
|
||||||
0,
|
0,
|
||||||
|
|
@ -630,7 +633,9 @@ class APIRoute(routing.Route):
|
||||||
self.response_fields = {}
|
self.response_fields = {}
|
||||||
|
|
||||||
assert callable(endpoint), "An endpoint must be a callable"
|
assert callable(endpoint), "An endpoint must be a callable"
|
||||||
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
|
self.dependant = get_dependant(
|
||||||
|
path=self.path_format, call=self.endpoint, scope="function"
|
||||||
|
)
|
||||||
for depends in self.dependencies[::-1]:
|
for depends in self.dependencies[::-1]:
|
||||||
self.dependant.dependencies.insert(
|
self.dependant.dependencies.insert(
|
||||||
0,
|
0,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import types
|
import types
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable, Dict, Set, Type, TypeVar, Union
|
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
@ -8,3 +8,4 @@ DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any])
|
||||||
UnionType = getattr(types, "UnionType", Union)
|
UnionType = getattr(types, "UnionType", Union)
|
||||||
ModelNameMap = Dict[Union[Type[BaseModel], Type[Enum]], str]
|
ModelNameMap = Dict[Union[Type[BaseModel], Type[Enum]], str]
|
||||||
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
|
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
|
||||||
|
DependencyCacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...], str]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,184 @@
|
||||||
|
import json
|
||||||
|
from typing import Any, Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import Depends, FastAPI
|
||||||
|
from fastapi.exceptions import FastAPIError
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
|
||||||
|
class Session:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.open = True
|
||||||
|
|
||||||
|
|
||||||
|
def dep_session() -> Any:
|
||||||
|
s = Session()
|
||||||
|
yield s
|
||||||
|
s.open = False
|
||||||
|
|
||||||
|
|
||||||
|
SessionFuncDep = Annotated[Session, Depends(dep_session, scope="function")]
|
||||||
|
SessionRequestDep = Annotated[Session, Depends(dep_session, scope="request")]
|
||||||
|
SessionDefaultDep = Annotated[Session, Depends(dep_session)]
|
||||||
|
|
||||||
|
|
||||||
|
class NamedSession:
|
||||||
|
def __init__(self, name: str = "default") -> None:
|
||||||
|
self.name = name
|
||||||
|
self.open = True
|
||||||
|
|
||||||
|
|
||||||
|
def get_named_session(session: SessionRequestDep, session_b: SessionDefaultDep) -> Any:
|
||||||
|
assert session is session_b
|
||||||
|
named_session = NamedSession(name="named")
|
||||||
|
yield named_session, session_b
|
||||||
|
named_session.open = False
|
||||||
|
|
||||||
|
|
||||||
|
NamedSessionsDep = Annotated[Tuple[NamedSession, Session], Depends(get_named_session)]
|
||||||
|
|
||||||
|
|
||||||
|
def get_named_func_session(session: SessionFuncDep) -> Any:
|
||||||
|
named_session = NamedSession(name="named")
|
||||||
|
yield named_session, session
|
||||||
|
named_session.open = False
|
||||||
|
|
||||||
|
|
||||||
|
def get_named_regular_func_session(session: SessionFuncDep) -> Any:
|
||||||
|
named_session = NamedSession(name="named")
|
||||||
|
return named_session, session
|
||||||
|
|
||||||
|
|
||||||
|
BrokenSessionsDep = Annotated[
|
||||||
|
Tuple[NamedSession, Session], Depends(get_named_func_session)
|
||||||
|
]
|
||||||
|
NamedSessionsFuncDep = Annotated[
|
||||||
|
Tuple[NamedSession, Session], Depends(get_named_func_session, scope="function")
|
||||||
|
]
|
||||||
|
|
||||||
|
RegularSessionsDep = Annotated[
|
||||||
|
Tuple[NamedSession, Session], Depends(get_named_regular_func_session)
|
||||||
|
]
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/function-scope")
|
||||||
|
def function_scope(session: SessionFuncDep) -> Any:
|
||||||
|
def iter_data():
|
||||||
|
yield json.dumps({"is_open": session.open})
|
||||||
|
|
||||||
|
return StreamingResponse(iter_data())
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/request-scope")
|
||||||
|
def request_scope(session: SessionRequestDep) -> Any:
|
||||||
|
def iter_data():
|
||||||
|
yield json.dumps({"is_open": session.open})
|
||||||
|
|
||||||
|
return StreamingResponse(iter_data())
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/two-scopes")
|
||||||
|
def get_stream_session(
|
||||||
|
function_session: SessionFuncDep, request_session: SessionRequestDep
|
||||||
|
) -> Any:
|
||||||
|
def iter_data():
|
||||||
|
yield json.dumps(
|
||||||
|
{"func_is_open": function_session.open, "req_is_open": request_session.open}
|
||||||
|
)
|
||||||
|
|
||||||
|
return StreamingResponse(iter_data())
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/sub")
|
||||||
|
def get_sub(sessions: NamedSessionsDep) -> Any:
|
||||||
|
def iter_data():
|
||||||
|
yield json.dumps(
|
||||||
|
{"named_session_open": sessions[0].open, "session_open": sessions[1].open}
|
||||||
|
)
|
||||||
|
|
||||||
|
return StreamingResponse(iter_data())
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/named-function-scope")
|
||||||
|
def get_named_function_scope(sessions: NamedSessionsFuncDep) -> Any:
|
||||||
|
def iter_data():
|
||||||
|
yield json.dumps(
|
||||||
|
{"named_session_open": sessions[0].open, "session_open": sessions[1].open}
|
||||||
|
)
|
||||||
|
|
||||||
|
return StreamingResponse(iter_data())
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/regular-function-scope")
|
||||||
|
def get_regular_function_scope(sessions: RegularSessionsDep) -> Any:
|
||||||
|
def iter_data():
|
||||||
|
yield json.dumps(
|
||||||
|
{"named_session_open": sessions[0].open, "session_open": sessions[1].open}
|
||||||
|
)
|
||||||
|
|
||||||
|
return StreamingResponse(iter_data())
|
||||||
|
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_function_scope() -> None:
|
||||||
|
response = client.get("/function-scope")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["is_open"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_scope() -> None:
|
||||||
|
response = client.get("/request-scope")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["is_open"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_two_scopes() -> None:
|
||||||
|
response = client.get("/two-scopes")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["func_is_open"] is False
|
||||||
|
assert data["req_is_open"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_sub() -> None:
|
||||||
|
response = client.get("/sub")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["named_session_open"] is True
|
||||||
|
assert data["session_open"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_broken_scope() -> None:
|
||||||
|
with pytest.raises(
|
||||||
|
FastAPIError,
|
||||||
|
match='The dependency "get_named_func_session" has a scope of "request", it cannot depend on dependencies with scope "function"',
|
||||||
|
):
|
||||||
|
|
||||||
|
@app.get("/broken-scope")
|
||||||
|
def get_broken(sessions: BrokenSessionsDep) -> Any: # pragma: no cover
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_named_function_scope() -> None:
|
||||||
|
response = client.get("/named-function-scope")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["named_session_open"] is False
|
||||||
|
assert data["session_open"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_regular_function_scope() -> None:
|
||||||
|
response = client.get("/regular-function-scope")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["named_session_open"] is True
|
||||||
|
assert data["session_open"] is False
|
||||||
|
|
@ -0,0 +1,201 @@
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Any, Dict, Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import Depends, FastAPI, WebSocket
|
||||||
|
from fastapi.exceptions import FastAPIError
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
global_context: ContextVar[Dict[str, Any]] = ContextVar("global_context", default={}) # noqa: B039
|
||||||
|
|
||||||
|
|
||||||
|
class Session:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.open = True
|
||||||
|
|
||||||
|
|
||||||
|
async def dep_session() -> Any:
|
||||||
|
s = Session()
|
||||||
|
yield s
|
||||||
|
s.open = False
|
||||||
|
global_state = global_context.get()
|
||||||
|
global_state["session_closed"] = True
|
||||||
|
|
||||||
|
|
||||||
|
SessionFuncDep = Annotated[Session, Depends(dep_session, scope="function")]
|
||||||
|
SessionRequestDep = Annotated[Session, Depends(dep_session, scope="request")]
|
||||||
|
SessionDefaultDep = Annotated[Session, Depends(dep_session)]
|
||||||
|
|
||||||
|
|
||||||
|
class NamedSession:
|
||||||
|
def __init__(self, name: str = "default") -> None:
|
||||||
|
self.name = name
|
||||||
|
self.open = True
|
||||||
|
|
||||||
|
|
||||||
|
def get_named_session(session: SessionRequestDep, session_b: SessionDefaultDep) -> Any:
|
||||||
|
assert session is session_b
|
||||||
|
named_session = NamedSession(name="named")
|
||||||
|
yield named_session, session_b
|
||||||
|
named_session.open = False
|
||||||
|
global_state = global_context.get()
|
||||||
|
global_state["named_session_closed"] = True
|
||||||
|
|
||||||
|
|
||||||
|
NamedSessionsDep = Annotated[Tuple[NamedSession, Session], Depends(get_named_session)]
|
||||||
|
|
||||||
|
|
||||||
|
def get_named_func_session(session: SessionFuncDep) -> Any:
|
||||||
|
named_session = NamedSession(name="named")
|
||||||
|
yield named_session, session
|
||||||
|
named_session.open = False
|
||||||
|
global_state = global_context.get()
|
||||||
|
global_state["named_func_session_closed"] = True
|
||||||
|
|
||||||
|
|
||||||
|
def get_named_regular_func_session(session: SessionFuncDep) -> Any:
|
||||||
|
named_session = NamedSession(name="named")
|
||||||
|
return named_session, session
|
||||||
|
|
||||||
|
|
||||||
|
BrokenSessionsDep = Annotated[
|
||||||
|
Tuple[NamedSession, Session], Depends(get_named_func_session)
|
||||||
|
]
|
||||||
|
NamedSessionsFuncDep = Annotated[
|
||||||
|
Tuple[NamedSession, Session], Depends(get_named_func_session, scope="function")
|
||||||
|
]
|
||||||
|
|
||||||
|
RegularSessionsDep = Annotated[
|
||||||
|
Tuple[NamedSession, Session], Depends(get_named_regular_func_session)
|
||||||
|
]
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/function-scope")
|
||||||
|
async def function_scope(websocket: WebSocket, session: SessionFuncDep) -> Any:
|
||||||
|
await websocket.accept()
|
||||||
|
await websocket.send_json({"is_open": session.open})
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/request-scope")
|
||||||
|
async def request_scope(websocket: WebSocket, session: SessionRequestDep) -> Any:
|
||||||
|
await websocket.accept()
|
||||||
|
await websocket.send_json({"is_open": session.open})
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/two-scopes")
|
||||||
|
async def get_stream_session(
|
||||||
|
websocket: WebSocket,
|
||||||
|
function_session: SessionFuncDep,
|
||||||
|
request_session: SessionRequestDep,
|
||||||
|
) -> Any:
|
||||||
|
await websocket.accept()
|
||||||
|
await websocket.send_json(
|
||||||
|
{"func_is_open": function_session.open, "req_is_open": request_session.open}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/sub")
|
||||||
|
async def get_sub(websocket: WebSocket, sessions: NamedSessionsDep) -> Any:
|
||||||
|
await websocket.accept()
|
||||||
|
await websocket.send_json(
|
||||||
|
{"named_session_open": sessions[0].open, "session_open": sessions[1].open}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/named-function-scope")
|
||||||
|
async def get_named_function_scope(
|
||||||
|
websocket: WebSocket, sessions: NamedSessionsFuncDep
|
||||||
|
) -> Any:
|
||||||
|
await websocket.accept()
|
||||||
|
await websocket.send_json(
|
||||||
|
{"named_session_open": sessions[0].open, "session_open": sessions[1].open}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/regular-function-scope")
|
||||||
|
async def get_regular_function_scope(
|
||||||
|
websocket: WebSocket, sessions: RegularSessionsDep
|
||||||
|
) -> Any:
|
||||||
|
await websocket.accept()
|
||||||
|
await websocket.send_json(
|
||||||
|
{"named_session_open": sessions[0].open, "session_open": sessions[1].open}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_function_scope() -> None:
|
||||||
|
global_context.set({})
|
||||||
|
global_state = global_context.get()
|
||||||
|
with client.websocket_connect("/function-scope") as websocket:
|
||||||
|
data = websocket.receive_json()
|
||||||
|
assert data["is_open"] is True
|
||||||
|
assert global_state["session_closed"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_scope() -> None:
|
||||||
|
global_context.set({})
|
||||||
|
global_state = global_context.get()
|
||||||
|
with client.websocket_connect("/request-scope") as websocket:
|
||||||
|
data = websocket.receive_json()
|
||||||
|
assert data["is_open"] is True
|
||||||
|
assert global_state["session_closed"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_two_scopes() -> None:
|
||||||
|
global_context.set({})
|
||||||
|
global_state = global_context.get()
|
||||||
|
with client.websocket_connect("/two-scopes") as websocket:
|
||||||
|
data = websocket.receive_json()
|
||||||
|
assert data["func_is_open"] is True
|
||||||
|
assert data["req_is_open"] is True
|
||||||
|
assert global_state["session_closed"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_sub() -> None:
|
||||||
|
global_context.set({})
|
||||||
|
global_state = global_context.get()
|
||||||
|
with client.websocket_connect("/sub") as websocket:
|
||||||
|
data = websocket.receive_json()
|
||||||
|
assert data["named_session_open"] is True
|
||||||
|
assert data["session_open"] is True
|
||||||
|
assert global_state["session_closed"] is True
|
||||||
|
assert global_state["named_session_closed"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_broken_scope() -> None:
|
||||||
|
with pytest.raises(
|
||||||
|
FastAPIError,
|
||||||
|
match='The dependency "get_named_func_session" has a scope of "request", it cannot depend on dependencies with scope "function"',
|
||||||
|
):
|
||||||
|
|
||||||
|
@app.websocket("/broken-scope")
|
||||||
|
async def get_broken(
|
||||||
|
websocket: WebSocket, sessions: BrokenSessionsDep
|
||||||
|
) -> Any: # pragma: no cover
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_named_function_scope() -> None:
|
||||||
|
global_context.set({})
|
||||||
|
global_state = global_context.get()
|
||||||
|
with client.websocket_connect("/named-function-scope") as websocket:
|
||||||
|
data = websocket.receive_json()
|
||||||
|
assert data["named_session_open"] is True
|
||||||
|
assert data["session_open"] is True
|
||||||
|
assert global_state["session_closed"] is True
|
||||||
|
assert global_state["named_func_session_closed"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_regular_function_scope() -> None:
|
||||||
|
global_context.set({})
|
||||||
|
global_state = global_context.get()
|
||||||
|
with client.websocket_connect("/regular-function-scope") as websocket:
|
||||||
|
data = websocket.receive_json()
|
||||||
|
assert data["named_session_open"] is True
|
||||||
|
assert data["session_open"] is True
|
||||||
|
assert global_state["session_closed"] is True
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from ...utils import needs_py39
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(
|
||||||
|
name="client",
|
||||||
|
params=[
|
||||||
|
"tutorial008e",
|
||||||
|
"tutorial008e_an",
|
||||||
|
pytest.param("tutorial008e_an_py39", marks=needs_py39),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def get_client(request: pytest.FixtureRequest):
|
||||||
|
mod = importlib.import_module(f"docs_src.dependencies.{request.param}")
|
||||||
|
|
||||||
|
client = TestClient(mod.app)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_users_me(client: TestClient):
|
||||||
|
response = client.get("/users/me")
|
||||||
|
assert response.status_code == 200, response.text
|
||||||
|
assert response.json() == "Rick"
|
||||||
Loading…
Reference in New Issue