Refactor dependency scope type

DRY up references to dependency scope by adding DependencyScope type.
This commit is contained in:
Alan Potter 2026-02-05 06:56:52 -05:00
parent 79406a4b04
commit 95bceaa2fb
5 changed files with 53 additions and 136 deletions

View File

@ -6,8 +6,7 @@ from typing import Any, Callable, Optional, Union
from fastapi._compat import ModelField
from fastapi.security.base import SecurityBase
from fastapi.types import DependencyCacheKey
from typing_extensions import Literal
from fastapi.types import DependencyCacheKey, DependencyScope
if sys.version_info >= (3, 13): # pragma: no cover
from inspect import iscoroutinefunction
@ -48,7 +47,7 @@ class Dependant:
parent_oauth_scopes: Optional[list[str]] = None
use_cache: bool = True
path: Optional[str] = None
scope: Union[Literal["function", "request"], None] = None
scope: DependencyScope = None
@cached_property
def oauth_scopes(self) -> list[str]:

View File

@ -49,7 +49,7 @@ from fastapi.dependencies.models import Dependant
from fastapi.exceptions import DependencyScopeError
from fastapi.logger import logger
from fastapi.security.oauth2 import SecurityScopes
from fastapi.types import DependencyCacheKey
from fastapi.types import DependencyCacheKey, DependencyScope
from fastapi.utils import create_model_field, get_path_param_names
from pydantic import BaseModel, Json
from pydantic.fields import FieldInfo
@ -65,7 +65,7 @@ from starlette.datastructures import (
from starlette.requests import HTTPConnection, Request
from starlette.responses import Response
from starlette.websockets import WebSocket
from typing_extensions import Literal, get_args, get_origin
from typing_extensions import get_args, get_origin
from typing_inspection.typing_objects import is_typealiastype
multipart_not_installed_error = (
@ -111,9 +111,7 @@ def ensure_multipart_is_installed() -> None:
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
assert callable(depends.dependency), (
"A parameter-less dependency must have a callable dependency"
)
assert callable(depends.dependency), "A parameter-less dependency must have a callable dependency"
own_oauth_scopes: list[str] = []
if isinstance(depends, params.Security) and depends.scopes:
own_oauth_scopes.extend(depends.scopes)
@ -135,9 +133,7 @@ def get_flat_dependant(
if visited is None:
visited = []
visited.append(dependant.cache_key)
use_parent_oauth_scopes = (parent_oauth_scopes or []) + (
dependant.oauth_scopes or []
)
use_parent_oauth_scopes = (parent_oauth_scopes or []) + (dependant.oauth_scopes or [])
flat_dependant = Dependant(
path_params=dependant.path_params.copy(),
@ -262,7 +258,7 @@ def get_dependant(
own_oauth_scopes: Optional[list[str]] = None,
parent_oauth_scopes: Optional[list[str]] = None,
use_cache: bool = True,
scope: Union[Literal["function", "request"], None] = None,
scope: DependencyScope = None,
) -> Dependant:
dependant = Dependant(
call=call,
@ -317,9 +313,7 @@ def get_dependant(
type_annotation=param_details.type_annotation,
dependant=dependant,
):
assert param_details.field is None, (
f"Cannot specify multiple FastAPI annotations for {param_name!r}"
)
assert param_details.field is None, f"Cannot specify multiple FastAPI annotations for {param_name!r}"
continue
assert param_details.field is not None
if isinstance(param_details.field.field_info, params.Body):
@ -329,9 +323,7 @@ def get_dependant(
return dependant
def add_non_field_param_to_dependency(
*, param_name: str, type_annotation: Any, dependant: Dependant
) -> Optional[bool]:
def add_non_field_param_to_dependency(*, param_name: str, type_annotation: Any, dependant: Dependant) -> Optional[bool]:
if lenient_issubclass(type_annotation, Request):
dependant.request_param_name = param_name
return True
@ -381,11 +373,7 @@ def analyze_param(
if get_origin(use_annotation) is Annotated:
annotated_args = get_args(annotation)
type_annotation = annotated_args[0]
fastapi_annotations = [
arg
for arg in annotated_args[1:]
if isinstance(arg, (FieldInfo, params.Depends))
]
fastapi_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, (FieldInfo, params.Depends))]
fastapi_specific_annotations = [
arg
for arg in fastapi_annotations
@ -399,9 +387,7 @@ def analyze_param(
)
]
if fastapi_specific_annotations:
fastapi_annotation: Union[FieldInfo, params.Depends, None] = (
fastapi_specific_annotations[-1]
)
fastapi_annotation: Union[FieldInfo, params.Depends, None] = fastapi_specific_annotations[-1]
else:
fastapi_annotation = None
# Set default for Annotated FieldInfo
@ -411,9 +397,7 @@ def analyze_param(
field_info=fastapi_annotation,
annotation=use_annotation,
)
assert (
field_info.default == Undefined or field_info.default == RequiredParam
), (
assert field_info.default == Undefined or field_info.default == RequiredParam, (
f"`{field_info.__class__.__name__}` default value cannot be set in"
f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
)
@ -427,10 +411,7 @@ def analyze_param(
depends = fastapi_annotation
# Get Depends from default value
if isinstance(value, params.Depends):
assert depends is None, (
"Cannot specify `Depends` in `Annotated` and default value"
f" together for {param_name!r}"
)
assert depends is None, f"Cannot specify `Depends` in `Annotated` and default value together for {param_name!r}"
assert field_info is None, (
"Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a"
f" default value together for {param_name!r}"
@ -439,8 +420,7 @@ def analyze_param(
# Get FieldInfo from default value
elif isinstance(value, FieldInfo):
assert field_info is None, (
"Cannot specify FastAPI annotations in `Annotated` and default value"
f" together for {param_name!r}"
f"Cannot specify FastAPI annotations in `Annotated` and default value together for {param_name!r}"
)
field_info = value
if isinstance(field_info, FieldInfo):
@ -466,9 +446,7 @@ def analyze_param(
SecurityScopes,
),
):
assert field_info is None, (
f"Cannot specify FastAPI annotation for type {type_annotation!r}"
)
assert field_info is None, f"Cannot specify FastAPI annotation for type {type_annotation!r}"
# Handle default assignations, neither field_info nor depends was not found in Annotated nor default value
elif field_info is None and depends is None:
default_value = value if value is not inspect.Signature.empty else RequiredParam
@ -477,9 +455,9 @@ def analyze_param(
# parameter might sometimes be a path parameter and sometimes not. See
# `tests/test_infer_param_optionality.py` for an example.
field_info = params.Path(annotation=use_annotation)
elif is_uploadfile_or_nonable_uploadfile_annotation(
elif is_uploadfile_or_nonable_uploadfile_annotation(type_annotation) or is_uploadfile_sequence_annotation(
type_annotation
) or is_uploadfile_sequence_annotation(type_annotation):
):
field_info = params.File(annotation=use_annotation, default=default_value)
elif not field_annotation_is_scalar(annotation=type_annotation):
field_info = params.Body(annotation=use_annotation, default=default_value)
@ -492,13 +470,9 @@ def analyze_param(
# Handle field_info.in_
if is_path_param:
assert isinstance(field_info, params.Path), (
f"Cannot use `{field_info.__class__.__name__}` for path param"
f" {param_name!r}"
f"Cannot use `{field_info.__class__.__name__}` for path param {param_name!r}"
)
elif (
isinstance(field_info, params.Param)
and getattr(field_info, "in_", None) is None
):
elif isinstance(field_info, params.Param) and getattr(field_info, "in_", None) is None:
field_info.in_ = params.ParamTypes.query
use_annotation_from_field_info = use_annotation
if isinstance(field_info, params.Form):
@ -517,9 +491,7 @@ def analyze_param(
field_info=field_info,
)
if is_path_param:
assert is_scalar_field(field=field), (
"Path params must be of one of the supported types"
)
assert is_scalar_field(field=field), "Path params must be of one of the supported types"
elif isinstance(field_info, params.Query):
assert (
is_scalar_field(field)
@ -550,9 +522,7 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
dependant.cookie_params.append(field)
async def _solve_generator(
*, dependant: Dependant, stack: AsyncExitStack, sub_values: dict[str, Any]
) -> Any:
async def _solve_generator(*, dependant: Dependant, stack: AsyncExitStack, sub_values: dict[str, Any]) -> Any:
assert dependant.call
if dependant.is_async_gen_callable:
cm = asynccontextmanager(dependant.call)(**sub_values)
@ -585,13 +555,9 @@ async def solve_dependencies(
embed_body_fields: bool,
) -> SolvedDependency:
request_astack = request.scope.get("fastapi_inner_astack")
assert isinstance(request_astack, AsyncExitStack), (
"fastapi_inner_astack not found in request scope"
)
assert isinstance(request_astack, AsyncExitStack), "fastapi_inner_astack not found in request scope"
function_astack = request.scope.get("fastapi_function_astack")
assert isinstance(function_astack, AsyncExitStack), (
"fastapi_function_astack not found in request scope"
)
assert isinstance(function_astack, AsyncExitStack), "fastapi_function_astack not found in request scope"
values: dict[str, Any] = {}
errors: list[Any] = []
if response is None:
@ -604,14 +570,9 @@ async def solve_dependencies(
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
call = sub_dependant.call
use_sub_dependant = sub_dependant
if (
dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides
):
if dependency_overrides_provider and dependency_overrides_provider.dependency_overrides:
original_call = sub_dependant.call
call = getattr(
dependency_overrides_provider, "dependency_overrides", {}
).get(original_call, original_call)
call = getattr(dependency_overrides_provider, "dependency_overrides", {}).get(original_call, original_call)
use_path: str = sub_dependant.path # type: ignore
use_sub_dependant = get_dependant(
path=use_path,
@ -638,9 +599,7 @@ async def solve_dependencies(
continue
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key]
elif (
use_sub_dependant.is_gen_callable or use_sub_dependant.is_async_gen_callable
):
elif use_sub_dependant.is_gen_callable or use_sub_dependant.is_async_gen_callable:
use_astack = request_astack
if sub_dependant.scope == "function":
use_astack = function_astack
@ -657,18 +616,10 @@ async def solve_dependencies(
values[sub_dependant.name] = solved
if sub_dependant.cache_key not in dependency_cache:
dependency_cache[sub_dependant.cache_key] = solved
path_values, path_errors = request_params_to_args(
dependant.path_params, request.path_params
)
query_values, query_errors = request_params_to_args(
dependant.query_params, request.query_params
)
header_values, header_errors = request_params_to_args(
dependant.header_params, request.headers
)
cookie_values, cookie_errors = request_params_to_args(
dependant.cookie_params, request.cookies
)
path_values, path_errors = request_params_to_args(dependant.path_params, request.path_params)
query_values, query_errors = request_params_to_args(dependant.query_params, request.query_params)
header_values, header_errors = request_params_to_args(dependant.header_params, request.headers)
cookie_values, cookie_errors = request_params_to_args(dependant.cookie_params, request.cookies)
values.update(path_values)
values.update(query_values)
values.update(header_values)
@ -698,9 +649,7 @@ async def solve_dependencies(
if dependant.response_param_name:
values[dependant.response_param_name] = response
if dependant.security_scopes_param_name:
values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.oauth_scopes
)
values[dependant.security_scopes_param_name] = SecurityScopes(scopes=dependant.oauth_scopes)
return SolvedDependency(
values=values,
errors=errors,
@ -730,15 +679,9 @@ def _is_json_field(field: ModelField) -> bool:
return any(type(item) is Json for item in field.field_info.metadata)
def _get_multidict_value(
field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
) -> Any:
def _get_multidict_value(field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None) -> Any:
alias = alias or get_validation_alias(field)
if (
(not _is_json_field(field))
and is_sequence_field(field)
and isinstance(values, (ImmutableMultiDict, Headers))
):
if (not _is_json_field(field)) and is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
value = values.getlist(alias)
else:
value = values.get(alias, None)
@ -777,9 +720,7 @@ def request_params_to_args(
single_not_embedded_field = True
# If headers are in a Pydantic model, the way to disable convert_underscores
# would be with Header(convert_underscores=False) at the Pydantic model level
default_convert_underscores = getattr(
first_field.field_info, "convert_underscores", True
)
default_convert_underscores = getattr(first_field.field_info, "convert_underscores", True)
params_to_process: dict[str, Any] = {}
@ -790,9 +731,7 @@ def request_params_to_args(
if isinstance(received_params, Headers):
# Handle fields extracted from a Pydantic Model for a header, each field
# doesn't have a FieldInfo of type Header with the default convert_underscores=True
convert_underscores = getattr(
field.field_info, "convert_underscores", default_convert_underscores
)
convert_underscores = getattr(field.field_info, "convert_underscores", default_convert_underscores)
if convert_underscores:
alias = get_validation_alias(field)
if alias == field.name:
@ -815,9 +754,7 @@ def request_params_to_args(
if single_not_embedded_field:
field_info = first_field.field_info
assert isinstance(field_info, params.Param), (
"Params must be subclasses of Param"
)
assert isinstance(field_info, params.Param), "Params must be subclasses of Param"
loc: tuple[str, ...] = (field_info.in_.value,)
v_, errors_ = _validate_value_with_model_field(
field=first_field, value=params_to_process, values=values, loc=loc
@ -827,13 +764,9 @@ def request_params_to_args(
for field in fields:
value = _get_multidict_value(field, received_params)
field_info = field.field_info
assert isinstance(field_info, params.Param), (
"Params must be subclasses of Param"
)
assert isinstance(field_info, params.Param), "Params must be subclasses of Param"
loc = (field_info.in_.value, get_validation_alias(field))
v_, errors_ = _validate_value_with_model_field(
field=field, value=value, values=values, loc=loc
)
v_, errors_ = _validate_value_with_model_field(field=field, value=value, values=values, loc=loc)
if errors_:
errors.extend(errors_)
else:
@ -893,17 +826,9 @@ async def _extract_form_body(
for field in body_fields:
value = _get_multidict_value(field, received_body)
field_info = field.field_info
if (
isinstance(field_info, params.File)
and is_bytes_field(field)
and isinstance(value, UploadFile)
):
if isinstance(field_info, params.File) and is_bytes_field(field) and isinstance(value, UploadFile):
value = await value.read()
elif (
is_bytes_sequence_field(field)
and isinstance(field_info, params.File)
and value_is_sequence(value)
):
elif is_bytes_sequence_field(field) and isinstance(field_info, params.File) and value_is_sequence(value):
# For types
assert isinstance(value, sequence_types)
results: list[Union[bytes, str]] = []
@ -957,9 +882,7 @@ async def request_body_to_args(
if single_not_embedded_field:
loc: tuple[str, ...] = ("body",)
v_, errors_ = _validate_value_with_model_field(
field=first_field, value=body_to_process, values=values, loc=loc
)
v_, errors_ = _validate_value_with_model_field(field=first_field, value=body_to_process, values=values, loc=loc)
return {first_field.name: v_}, errors_
for field in body_fields:
loc = ("body", get_validation_alias(field))
@ -971,9 +894,7 @@ async def request_body_to_args(
except AttributeError:
errors.append(get_missing_field_error(loc))
continue
v_, errors_ = _validate_value_with_model_field(
field=field, value=value, values=values, loc=loc
)
v_, errors_ = _validate_value_with_model_field(field=field, value=value, values=values, loc=loc)
if errors_:
errors.extend(errors_)
else:
@ -981,9 +902,7 @@ async def request_body_to_args(
return values, errors
def get_body_field(
*, flat_dependant: Dependant, name: str, embed_body_fields: bool
) -> Optional[ModelField]:
def get_body_field(*, flat_dependant: Dependant, name: str, embed_body_fields: bool) -> Optional[ModelField]:
"""
Get a ModelField representing the request body for a path operation, combining
all body parameters into a single field if necessary.
@ -1000,9 +919,7 @@ def get_body_field(
if not embed_body_fields:
return first_param
model_name = "Body_" + name
BodyModel = create_body_model(
fields=flat_dependant.body_params, model_name=model_name
)
BodyModel = create_body_model(fields=flat_dependant.body_params, model_name=model_name)
required = any(True for f in flat_dependant.body_params if f.required)
BodyFieldInfo_kwargs: dict[str, Any] = {
"annotation": BodyModel,
@ -1018,9 +935,7 @@ def get_body_field(
BodyFieldInfo = params.Body
body_param_media_types = [
f.field_info.media_type
for f in flat_dependant.body_params
if isinstance(f.field_info, params.Body)
f.field_info.media_type for f in flat_dependant.body_params if isinstance(f.field_info, params.Body)
]
if len(set(body_param_media_types)) == 1:
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]

View File

@ -5,8 +5,9 @@ from annotated_doc import Doc
from fastapi import params
from fastapi._compat import Undefined
from fastapi.openapi.models import Example
from fastapi.types import DependencyScope
from pydantic import AliasChoices, AliasPath
from typing_extensions import Literal, deprecated
from typing_extensions import deprecated
_Unset: Any = Undefined
@ -2315,7 +2316,7 @@ def Depends( # noqa: N802
),
] = True,
scope: Annotated[
Union[Literal["function", "request"], None],
DependencyScope,
Doc(
"""
Mainly for dependencies with `yield`, define when the dependency function

View File

@ -6,9 +6,10 @@ from typing import Annotated, Any, Callable, Optional, Union
from fastapi.exceptions import FastAPIDeprecationWarning
from fastapi.openapi.models import Example
from fastapi.types import DependencyScope
from pydantic import AliasChoices, AliasPath
from pydantic.fields import FieldInfo
from typing_extensions import Literal, deprecated
from typing_extensions import deprecated
from ._compat import (
Undefined,
@ -747,7 +748,7 @@ class File(Form): # type: ignore[misc]
class Depends:
dependency: Optional[Callable[..., Any]] = None
use_cache: bool = True
scope: Union[Literal["function", "request"], None] = None
scope: DependencyScope = None
@dataclass(frozen=True)

View File

@ -1,6 +1,6 @@
import types
from enum import Enum
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Any, Callable, Literal, Optional, TypeVar, Union
from pydantic import BaseModel
from pydantic.main import IncEx as IncEx
@ -9,3 +9,4 @@ DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any])
UnionType = getattr(types, "UnionType", Union)
ModelNameMap = dict[Union[type[BaseModel], type[Enum]], str]
DependencyCacheKey = tuple[Optional[Callable[..., Any]], tuple[str, ...], str]
DependencyScope = Union[Literal["function", "request"], None]