diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index ab18ec2db..1725da95c 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -48,7 +48,7 @@ from fastapi.dependencies.models import Dependant from fastapi.exceptions import DependencyScopeError from fastapi.logger import logger from fastapi.security.oauth2 import SecurityScopes -from fastapi.types import DependencyCacheKey +from fastapi.types import DependencyCacheKey, DependencyOverridesProvider from fastapi.utils import create_model_field, get_path_param_names from pydantic import BaseModel, Json from pydantic.fields import FieldInfo @@ -569,7 +569,7 @@ async def solve_dependencies( body: dict[str, Any] | FormData | None = None, background_tasks: StarletteBackgroundTasks | None = None, response: Response | None = None, - dependency_overrides_provider: Any | None = None, + dependency_overrides_provider: DependencyOverridesProvider | None = None, dependency_cache: dict[DependencyCacheKey, Any] | None = None, # TODO: remove this parameter later, no longer used, not removing it yet as some # people might be monkey patching this function (although that's not supported) @@ -601,9 +601,9 @@ async def solve_dependencies( and dependency_overrides_provider.dependency_overrides ): original_call = sub_dependant.call - call = getattr( - dependency_overrides_provider, "dependency_overrides", {} - ).get(original_call, original_call) + call = dependency_overrides_provider.dependency_overrides.get( + original_call, original_call + ) use_path: str = sub_dependant.path # type: ignore use_sub_dependant = get_dependant( path=use_path, diff --git a/fastapi/routing.py b/fastapi/routing.py index ea82ab14a..4f0d38e45 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -53,7 +53,7 @@ from fastapi.exceptions import ( ResponseValidationError, WebSocketRequestValidationError, ) -from fastapi.types import DecoratedCallable, IncEx +from fastapi.types import DecoratedCallable, DependencyOverridesProvider, IncEx from fastapi.utils import ( create_model_field, generate_unique_id, @@ -326,7 +326,7 @@ def get_request_handler( response_model_exclude_unset: bool = False, response_model_exclude_defaults: bool = False, response_model_exclude_none: bool = False, - dependency_overrides_provider: Any | None = None, + dependency_overrides_provider: DependencyOverridesProvider | None = None, embed_body_fields: bool = False, ) -> Callable[[Request], Coroutine[Any, Any, Response]]: assert dependant.call is not None, "dependant.call must be a function" @@ -474,7 +474,7 @@ def get_request_handler( def get_websocket_app( dependant: Dependant, - dependency_overrides_provider: Any | None = None, + dependency_overrides_provider: DependencyOverridesProvider | None = None, embed_body_fields: bool = False, ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]: async def app(websocket: WebSocket) -> None: @@ -517,7 +517,7 @@ class APIWebSocketRoute(routing.WebSocketRoute): *, name: str | None = None, dependencies: Sequence[params.Depends] | None = None, - dependency_overrides_provider: Any | None = None, + dependency_overrides_provider: DependencyOverridesProvider | None = None, ) -> None: self.path = path self.endpoint = endpoint @@ -577,7 +577,7 @@ class APIRoute(routing.Route): response_model_exclude_none: bool = False, include_in_schema: bool = True, response_class: type[Response] | DefaultPlaceholder = Default(JSONResponse), - dependency_overrides_provider: Any | None = None, + dependency_overrides_provider: DependencyOverridesProvider | None = None, callbacks: list[BaseRoute] | None = None, openapi_extra: dict[str, Any] | None = None, generate_unique_id_function: Callable[["APIRoute"], str] @@ -844,7 +844,7 @@ class APIRouter(routing.Router): ), ] = None, dependency_overrides_provider: Annotated[ - Any | None, + DependencyOverridesProvider | None, Doc( """ Only used internally by FastAPI to handle dependency overrides. diff --git a/fastapi/types.py b/fastapi/types.py index 1fb86e13b..7beef6e51 100644 --- a/fastapi/types.py +++ b/fastapi/types.py @@ -1,7 +1,7 @@ import types from collections.abc import Callable from enum import Enum -from typing import Any, TypeVar, Union +from typing import Any, Protocol, TypeVar, Union from pydantic import BaseModel from pydantic.main import IncEx as IncEx @@ -10,3 +10,7 @@ DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any]) UnionType = getattr(types, "UnionType", Union) ModelNameMap = dict[type[BaseModel] | type[Enum], str] DependencyCacheKey = tuple[Callable[..., Any] | None, tuple[str, ...], str] + + +class DependencyOverridesProvider(Protocol): + dependency_overrides: dict[Callable[..., Any], Any]