From 95bceaa2fb47e695fe124f145d6856e94a042b00 Mon Sep 17 00:00:00 2001 From: Alan Potter Date: Thu, 5 Feb 2026 06:56:52 -0500 Subject: [PATCH] Refactor dependency scope type DRY up references to dependency scope by adding DependencyScope type. --- fastapi/dependencies/models.py | 5 +- fastapi/dependencies/utils.py | 171 +++++++++------------------------ fastapi/param_functions.py | 5 +- fastapi/params.py | 5 +- fastapi/types.py | 3 +- 5 files changed, 53 insertions(+), 136 deletions(-) diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 58392326d6..98ea415f39 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -6,8 +6,7 @@ from typing import Any, Callable, Optional, Union from fastapi._compat import ModelField from fastapi.security.base import SecurityBase -from fastapi.types import DependencyCacheKey -from typing_extensions import Literal +from fastapi.types import DependencyCacheKey, DependencyScope if sys.version_info >= (3, 13): # pragma: no cover from inspect import iscoroutinefunction @@ -48,7 +47,7 @@ class Dependant: parent_oauth_scopes: Optional[list[str]] = None use_cache: bool = True path: Optional[str] = None - scope: Union[Literal["function", "request"], None] = None + scope: DependencyScope = None @cached_property def oauth_scopes(self) -> list[str]: diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index b8f7f948c6..ccd5ae78af 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -49,7 +49,7 @@ from fastapi.dependencies.models import Dependant from fastapi.exceptions import DependencyScopeError from fastapi.logger import logger from fastapi.security.oauth2 import SecurityScopes -from fastapi.types import DependencyCacheKey +from fastapi.types import DependencyCacheKey, DependencyScope from fastapi.utils import create_model_field, get_path_param_names from pydantic import BaseModel, Json from pydantic.fields import FieldInfo @@ -65,7 +65,7 @@ from starlette.datastructures import ( from starlette.requests import HTTPConnection, Request from starlette.responses import Response from starlette.websockets import WebSocket -from typing_extensions import Literal, get_args, get_origin +from typing_extensions import get_args, get_origin from typing_inspection.typing_objects import is_typealiastype multipart_not_installed_error = ( @@ -111,9 +111,7 @@ def ensure_multipart_is_installed() -> None: def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: - assert callable(depends.dependency), ( - "A parameter-less dependency must have a callable dependency" - ) + assert callable(depends.dependency), "A parameter-less dependency must have a callable dependency" own_oauth_scopes: list[str] = [] if isinstance(depends, params.Security) and depends.scopes: own_oauth_scopes.extend(depends.scopes) @@ -135,9 +133,7 @@ def get_flat_dependant( if visited is None: visited = [] visited.append(dependant.cache_key) - use_parent_oauth_scopes = (parent_oauth_scopes or []) + ( - dependant.oauth_scopes or [] - ) + use_parent_oauth_scopes = (parent_oauth_scopes or []) + (dependant.oauth_scopes or []) flat_dependant = Dependant( path_params=dependant.path_params.copy(), @@ -262,7 +258,7 @@ def get_dependant( own_oauth_scopes: Optional[list[str]] = None, parent_oauth_scopes: Optional[list[str]] = None, use_cache: bool = True, - scope: Union[Literal["function", "request"], None] = None, + scope: DependencyScope = None, ) -> Dependant: dependant = Dependant( call=call, @@ -317,9 +313,7 @@ def get_dependant( type_annotation=param_details.type_annotation, dependant=dependant, ): - assert param_details.field is None, ( - f"Cannot specify multiple FastAPI annotations for {param_name!r}" - ) + assert param_details.field is None, f"Cannot specify multiple FastAPI annotations for {param_name!r}" continue assert param_details.field is not None if isinstance(param_details.field.field_info, params.Body): @@ -329,9 +323,7 @@ def get_dependant( return dependant -def add_non_field_param_to_dependency( - *, param_name: str, type_annotation: Any, dependant: Dependant -) -> Optional[bool]: +def add_non_field_param_to_dependency(*, param_name: str, type_annotation: Any, dependant: Dependant) -> Optional[bool]: if lenient_issubclass(type_annotation, Request): dependant.request_param_name = param_name return True @@ -381,11 +373,7 @@ def analyze_param( if get_origin(use_annotation) is Annotated: annotated_args = get_args(annotation) type_annotation = annotated_args[0] - fastapi_annotations = [ - arg - for arg in annotated_args[1:] - if isinstance(arg, (FieldInfo, params.Depends)) - ] + fastapi_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, (FieldInfo, params.Depends))] fastapi_specific_annotations = [ arg for arg in fastapi_annotations @@ -399,9 +387,7 @@ def analyze_param( ) ] if fastapi_specific_annotations: - fastapi_annotation: Union[FieldInfo, params.Depends, None] = ( - fastapi_specific_annotations[-1] - ) + fastapi_annotation: Union[FieldInfo, params.Depends, None] = fastapi_specific_annotations[-1] else: fastapi_annotation = None # Set default for Annotated FieldInfo @@ -411,9 +397,7 @@ def analyze_param( field_info=fastapi_annotation, annotation=use_annotation, ) - assert ( - field_info.default == Undefined or field_info.default == RequiredParam - ), ( + assert field_info.default == Undefined or field_info.default == RequiredParam, ( f"`{field_info.__class__.__name__}` default value cannot be set in" f" `Annotated` for {param_name!r}. Set the default value with `=` instead." ) @@ -427,10 +411,7 @@ def analyze_param( depends = fastapi_annotation # Get Depends from default value if isinstance(value, params.Depends): - assert depends is None, ( - "Cannot specify `Depends` in `Annotated` and default value" - f" together for {param_name!r}" - ) + assert depends is None, f"Cannot specify `Depends` in `Annotated` and default value together for {param_name!r}" assert field_info is None, ( "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a" f" default value together for {param_name!r}" @@ -439,8 +420,7 @@ def analyze_param( # Get FieldInfo from default value elif isinstance(value, FieldInfo): assert field_info is None, ( - "Cannot specify FastAPI annotations in `Annotated` and default value" - f" together for {param_name!r}" + f"Cannot specify FastAPI annotations in `Annotated` and default value together for {param_name!r}" ) field_info = value if isinstance(field_info, FieldInfo): @@ -466,9 +446,7 @@ def analyze_param( SecurityScopes, ), ): - assert field_info is None, ( - f"Cannot specify FastAPI annotation for type {type_annotation!r}" - ) + assert field_info is None, f"Cannot specify FastAPI annotation for type {type_annotation!r}" # Handle default assignations, neither field_info nor depends was not found in Annotated nor default value elif field_info is None and depends is None: default_value = value if value is not inspect.Signature.empty else RequiredParam @@ -477,9 +455,9 @@ def analyze_param( # parameter might sometimes be a path parameter and sometimes not. See # `tests/test_infer_param_optionality.py` for an example. field_info = params.Path(annotation=use_annotation) - elif is_uploadfile_or_nonable_uploadfile_annotation( + elif is_uploadfile_or_nonable_uploadfile_annotation(type_annotation) or is_uploadfile_sequence_annotation( type_annotation - ) or is_uploadfile_sequence_annotation(type_annotation): + ): field_info = params.File(annotation=use_annotation, default=default_value) elif not field_annotation_is_scalar(annotation=type_annotation): field_info = params.Body(annotation=use_annotation, default=default_value) @@ -492,13 +470,9 @@ def analyze_param( # Handle field_info.in_ if is_path_param: assert isinstance(field_info, params.Path), ( - f"Cannot use `{field_info.__class__.__name__}` for path param" - f" {param_name!r}" + f"Cannot use `{field_info.__class__.__name__}` for path param {param_name!r}" ) - elif ( - isinstance(field_info, params.Param) - and getattr(field_info, "in_", None) is None - ): + elif isinstance(field_info, params.Param) and getattr(field_info, "in_", None) is None: field_info.in_ = params.ParamTypes.query use_annotation_from_field_info = use_annotation if isinstance(field_info, params.Form): @@ -517,9 +491,7 @@ def analyze_param( field_info=field_info, ) if is_path_param: - assert is_scalar_field(field=field), ( - "Path params must be of one of the supported types" - ) + assert is_scalar_field(field=field), "Path params must be of one of the supported types" elif isinstance(field_info, params.Query): assert ( is_scalar_field(field) @@ -550,9 +522,7 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: dependant.cookie_params.append(field) -async def _solve_generator( - *, dependant: Dependant, stack: AsyncExitStack, sub_values: dict[str, Any] -) -> Any: +async def _solve_generator(*, dependant: Dependant, stack: AsyncExitStack, sub_values: dict[str, Any]) -> Any: assert dependant.call if dependant.is_async_gen_callable: cm = asynccontextmanager(dependant.call)(**sub_values) @@ -585,13 +555,9 @@ async def solve_dependencies( embed_body_fields: bool, ) -> SolvedDependency: request_astack = request.scope.get("fastapi_inner_astack") - assert isinstance(request_astack, AsyncExitStack), ( - "fastapi_inner_astack not found in request scope" - ) + assert isinstance(request_astack, AsyncExitStack), "fastapi_inner_astack not found in request scope" function_astack = request.scope.get("fastapi_function_astack") - assert isinstance(function_astack, AsyncExitStack), ( - "fastapi_function_astack not found in request scope" - ) + assert isinstance(function_astack, AsyncExitStack), "fastapi_function_astack not found in request scope" values: dict[str, Any] = {} errors: list[Any] = [] if response is None: @@ -604,14 +570,9 @@ async def solve_dependencies( sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) call = sub_dependant.call use_sub_dependant = sub_dependant - if ( - dependency_overrides_provider - and dependency_overrides_provider.dependency_overrides - ): + if dependency_overrides_provider and dependency_overrides_provider.dependency_overrides: original_call = sub_dependant.call - call = getattr( - dependency_overrides_provider, "dependency_overrides", {} - ).get(original_call, original_call) + call = getattr(dependency_overrides_provider, "dependency_overrides", {}).get(original_call, original_call) use_path: str = sub_dependant.path # type: ignore use_sub_dependant = get_dependant( path=use_path, @@ -638,9 +599,7 @@ async def solve_dependencies( continue if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: solved = dependency_cache[sub_dependant.cache_key] - elif ( - use_sub_dependant.is_gen_callable or use_sub_dependant.is_async_gen_callable - ): + elif use_sub_dependant.is_gen_callable or use_sub_dependant.is_async_gen_callable: use_astack = request_astack if sub_dependant.scope == "function": use_astack = function_astack @@ -657,18 +616,10 @@ async def solve_dependencies( values[sub_dependant.name] = solved if sub_dependant.cache_key not in dependency_cache: dependency_cache[sub_dependant.cache_key] = solved - path_values, path_errors = request_params_to_args( - dependant.path_params, request.path_params - ) - query_values, query_errors = request_params_to_args( - dependant.query_params, request.query_params - ) - header_values, header_errors = request_params_to_args( - dependant.header_params, request.headers - ) - cookie_values, cookie_errors = request_params_to_args( - dependant.cookie_params, request.cookies - ) + path_values, path_errors = request_params_to_args(dependant.path_params, request.path_params) + query_values, query_errors = request_params_to_args(dependant.query_params, request.query_params) + header_values, header_errors = request_params_to_args(dependant.header_params, request.headers) + cookie_values, cookie_errors = request_params_to_args(dependant.cookie_params, request.cookies) values.update(path_values) values.update(query_values) values.update(header_values) @@ -698,9 +649,7 @@ async def solve_dependencies( if dependant.response_param_name: values[dependant.response_param_name] = response if dependant.security_scopes_param_name: - values[dependant.security_scopes_param_name] = SecurityScopes( - scopes=dependant.oauth_scopes - ) + values[dependant.security_scopes_param_name] = SecurityScopes(scopes=dependant.oauth_scopes) return SolvedDependency( values=values, errors=errors, @@ -730,15 +679,9 @@ def _is_json_field(field: ModelField) -> bool: return any(type(item) is Json for item in field.field_info.metadata) -def _get_multidict_value( - field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None -) -> Any: +def _get_multidict_value(field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None) -> Any: alias = alias or get_validation_alias(field) - if ( - (not _is_json_field(field)) - and is_sequence_field(field) - and isinstance(values, (ImmutableMultiDict, Headers)) - ): + if (not _is_json_field(field)) and is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)): value = values.getlist(alias) else: value = values.get(alias, None) @@ -777,9 +720,7 @@ def request_params_to_args( single_not_embedded_field = True # If headers are in a Pydantic model, the way to disable convert_underscores # would be with Header(convert_underscores=False) at the Pydantic model level - default_convert_underscores = getattr( - first_field.field_info, "convert_underscores", True - ) + default_convert_underscores = getattr(first_field.field_info, "convert_underscores", True) params_to_process: dict[str, Any] = {} @@ -790,9 +731,7 @@ def request_params_to_args( if isinstance(received_params, Headers): # Handle fields extracted from a Pydantic Model for a header, each field # doesn't have a FieldInfo of type Header with the default convert_underscores=True - convert_underscores = getattr( - field.field_info, "convert_underscores", default_convert_underscores - ) + convert_underscores = getattr(field.field_info, "convert_underscores", default_convert_underscores) if convert_underscores: alias = get_validation_alias(field) if alias == field.name: @@ -815,9 +754,7 @@ def request_params_to_args( if single_not_embedded_field: field_info = first_field.field_info - assert isinstance(field_info, params.Param), ( - "Params must be subclasses of Param" - ) + assert isinstance(field_info, params.Param), "Params must be subclasses of Param" loc: tuple[str, ...] = (field_info.in_.value,) v_, errors_ = _validate_value_with_model_field( field=first_field, value=params_to_process, values=values, loc=loc @@ -827,13 +764,9 @@ def request_params_to_args( for field in fields: value = _get_multidict_value(field, received_params) field_info = field.field_info - assert isinstance(field_info, params.Param), ( - "Params must be subclasses of Param" - ) + assert isinstance(field_info, params.Param), "Params must be subclasses of Param" loc = (field_info.in_.value, get_validation_alias(field)) - v_, errors_ = _validate_value_with_model_field( - field=field, value=value, values=values, loc=loc - ) + v_, errors_ = _validate_value_with_model_field(field=field, value=value, values=values, loc=loc) if errors_: errors.extend(errors_) else: @@ -893,17 +826,9 @@ async def _extract_form_body( for field in body_fields: value = _get_multidict_value(field, received_body) field_info = field.field_info - if ( - isinstance(field_info, params.File) - and is_bytes_field(field) - and isinstance(value, UploadFile) - ): + if isinstance(field_info, params.File) and is_bytes_field(field) and isinstance(value, UploadFile): value = await value.read() - elif ( - is_bytes_sequence_field(field) - and isinstance(field_info, params.File) - and value_is_sequence(value) - ): + elif is_bytes_sequence_field(field) and isinstance(field_info, params.File) and value_is_sequence(value): # For types assert isinstance(value, sequence_types) results: list[Union[bytes, str]] = [] @@ -957,9 +882,7 @@ async def request_body_to_args( if single_not_embedded_field: loc: tuple[str, ...] = ("body",) - v_, errors_ = _validate_value_with_model_field( - field=first_field, value=body_to_process, values=values, loc=loc - ) + v_, errors_ = _validate_value_with_model_field(field=first_field, value=body_to_process, values=values, loc=loc) return {first_field.name: v_}, errors_ for field in body_fields: loc = ("body", get_validation_alias(field)) @@ -971,9 +894,7 @@ async def request_body_to_args( except AttributeError: errors.append(get_missing_field_error(loc)) continue - v_, errors_ = _validate_value_with_model_field( - field=field, value=value, values=values, loc=loc - ) + v_, errors_ = _validate_value_with_model_field(field=field, value=value, values=values, loc=loc) if errors_: errors.extend(errors_) else: @@ -981,9 +902,7 @@ async def request_body_to_args( return values, errors -def get_body_field( - *, flat_dependant: Dependant, name: str, embed_body_fields: bool -) -> Optional[ModelField]: +def get_body_field(*, flat_dependant: Dependant, name: str, embed_body_fields: bool) -> Optional[ModelField]: """ Get a ModelField representing the request body for a path operation, combining all body parameters into a single field if necessary. @@ -1000,9 +919,7 @@ def get_body_field( if not embed_body_fields: return first_param model_name = "Body_" + name - BodyModel = create_body_model( - fields=flat_dependant.body_params, model_name=model_name - ) + BodyModel = create_body_model(fields=flat_dependant.body_params, model_name=model_name) required = any(True for f in flat_dependant.body_params if f.required) BodyFieldInfo_kwargs: dict[str, Any] = { "annotation": BodyModel, @@ -1018,9 +935,7 @@ def get_body_field( BodyFieldInfo = params.Body body_param_media_types = [ - f.field_info.media_type - for f in flat_dependant.body_params - if isinstance(f.field_info, params.Body) + f.field_info.media_type for f in flat_dependant.body_params if isinstance(f.field_info, params.Body) ] if len(set(body_param_media_types)) == 1: BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0] diff --git a/fastapi/param_functions.py b/fastapi/param_functions.py index 9bd92be4c7..7ddb5246ae 100644 --- a/fastapi/param_functions.py +++ b/fastapi/param_functions.py @@ -5,8 +5,9 @@ from annotated_doc import Doc from fastapi import params from fastapi._compat import Undefined from fastapi.openapi.models import Example +from fastapi.types import DependencyScope from pydantic import AliasChoices, AliasPath -from typing_extensions import Literal, deprecated +from typing_extensions import deprecated _Unset: Any = Undefined @@ -2315,7 +2316,7 @@ def Depends( # noqa: N802 ), ] = True, scope: Annotated[ - Union[Literal["function", "request"], None], + DependencyScope, Doc( """ Mainly for dependencies with `yield`, define when the dependency function diff --git a/fastapi/params.py b/fastapi/params.py index 72e797f833..b602cd63e2 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -6,9 +6,10 @@ from typing import Annotated, Any, Callable, Optional, Union from fastapi.exceptions import FastAPIDeprecationWarning from fastapi.openapi.models import Example +from fastapi.types import DependencyScope from pydantic import AliasChoices, AliasPath from pydantic.fields import FieldInfo -from typing_extensions import Literal, deprecated +from typing_extensions import deprecated from ._compat import ( Undefined, @@ -747,7 +748,7 @@ class File(Form): # type: ignore[misc] class Depends: dependency: Optional[Callable[..., Any]] = None use_cache: bool = True - scope: Union[Literal["function", "request"], None] = None + scope: DependencyScope = None @dataclass(frozen=True) diff --git a/fastapi/types.py b/fastapi/types.py index 1c3a6de749..8c9b53442d 100644 --- a/fastapi/types.py +++ b/fastapi/types.py @@ -1,6 +1,6 @@ import types from enum import Enum -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Callable, Literal, Optional, TypeVar, Union from pydantic import BaseModel from pydantic.main import IncEx as IncEx @@ -9,3 +9,4 @@ DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any]) UnionType = getattr(types, "UnionType", Union) ModelNameMap = dict[Union[type[BaseModel], type[Enum]], str] DependencyCacheKey = tuple[Optional[Callable[..., Any]], tuple[str, ...], str] +DependencyScope = Union[Literal["function", "request"], None]