diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index fc5dfed85a..83429f0993 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -49,7 +49,7 @@ from fastapi.dependencies.models import Dependant from fastapi.exceptions import DependencyScopeError from fastapi.logger import logger from fastapi.security.oauth2 import SecurityScopes -from fastapi.types import DependencyCacheKey +from fastapi.types import DependencyCacheKey, DependencyOverridesProvider from fastapi.utils import create_model_field, get_path_param_names from pydantic import BaseModel from pydantic.fields import FieldInfo @@ -572,7 +572,7 @@ async def solve_dependencies( body: Optional[Union[dict[str, Any], FormData]] = None, background_tasks: Optional[StarletteBackgroundTasks] = None, response: Optional[Response] = None, - dependency_overrides_provider: Optional[Any] = None, + dependency_overrides_provider: Optional[DependencyOverridesProvider] = None, dependency_cache: Optional[dict[DependencyCacheKey, Any]] = None, # TODO: remove this parameter later, no longer used, not removing it yet as some # people might be monkey patching this function (although that's not supported) @@ -604,9 +604,9 @@ async def solve_dependencies( and dependency_overrides_provider.dependency_overrides ): original_call = sub_dependant.call - call = getattr( - dependency_overrides_provider, "dependency_overrides", {} - ).get(original_call, original_call) + call = dependency_overrides_provider.dependency_overrides.get( + original_call, original_call + ) use_path: str = sub_dependant.path # type: ignore use_sub_dependant = get_dependant( path=use_path, diff --git a/fastapi/routing.py b/fastapi/routing.py index 9ca2f46732..c39f95154d 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -48,7 +48,7 @@ from fastapi.exceptions import ( ResponseValidationError, WebSocketRequestValidationError, ) -from fastapi.types import DecoratedCallable, IncEx +from fastapi.types import DecoratedCallable, DependencyOverridesProvider, IncEx from fastapi.utils import ( create_cloned_field, create_model_field, @@ -257,7 +257,7 @@ def get_request_handler( response_model_exclude_unset: bool = False, response_model_exclude_defaults: bool = False, response_model_exclude_none: bool = False, - dependency_overrides_provider: Optional[Any] = None, + dependency_overrides_provider: Optional[DependencyOverridesProvider] = None, embed_body_fields: bool = False, ) -> Callable[[Request], Coroutine[Any, Any, Response]]: assert dependant.call is not None, "dependant.call must be a function" @@ -405,7 +405,7 @@ def get_request_handler( def get_websocket_app( dependant: Dependant, - dependency_overrides_provider: Optional[Any] = None, + dependency_overrides_provider: Optional[DependencyOverridesProvider] = None, embed_body_fields: bool = False, ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]: async def app(websocket: WebSocket) -> None: @@ -448,7 +448,7 @@ class APIWebSocketRoute(routing.WebSocketRoute): *, name: Optional[str] = None, dependencies: Optional[Sequence[params.Depends]] = None, - dependency_overrides_provider: Optional[Any] = None, + dependency_overrides_provider: Optional[DependencyOverridesProvider] = None, ) -> None: self.path = path self.endpoint = endpoint @@ -510,7 +510,7 @@ class APIRoute(routing.Route): response_class: Union[type[Response], DefaultPlaceholder] = Default( JSONResponse ), - dependency_overrides_provider: Optional[Any] = None, + dependency_overrides_provider: Optional[DependencyOverridesProvider] = None, callbacks: Optional[list[BaseRoute]] = None, openapi_extra: Optional[dict[str, Any]] = None, generate_unique_id_function: Union[ @@ -800,7 +800,7 @@ class APIRouter(routing.Router): ), ] = None, dependency_overrides_provider: Annotated[ - Optional[Any], + Optional[DependencyOverridesProvider], Doc( """ Only used internally by FastAPI to handle dependency overrides. diff --git a/fastapi/types.py b/fastapi/types.py index 1c3a6de749..9eaaaa435a 100644 --- a/fastapi/types.py +++ b/fastapi/types.py @@ -1,6 +1,6 @@ import types from enum import Enum -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Callable, Optional, Protocol, TypeVar, Union from pydantic import BaseModel from pydantic.main import IncEx as IncEx @@ -9,3 +9,7 @@ DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any]) UnionType = getattr(types, "UnionType", Union) ModelNameMap = dict[Union[type[BaseModel], type[Enum]], str] DependencyCacheKey = tuple[Optional[Callable[..., Any]], tuple[str, ...], str] + + +class DependencyOverridesProvider(Protocol): + dependency_overrides: dict[Callable[..., Any], Any]