add typing for dependency_overrides_provider

This commit is contained in:
Evgeny Bokshitsky 2025-12-28 01:21:44 +04:00
parent 47391ea8fb
commit eb5b08b83d
3 changed files with 13 additions and 12 deletions

View File

@ -49,7 +49,7 @@ from fastapi.dependencies.models import Dependant
from fastapi.exceptions import DependencyScopeError
from fastapi.logger import logger
from fastapi.security.oauth2 import SecurityScopes
from fastapi.types import DependencyCacheKey
from fastapi.types import DependencyCacheKey, DependencyOverridesProvider
from fastapi.utils import create_model_field, get_path_param_names
from pydantic import BaseModel
from pydantic.fields import FieldInfo
@ -567,7 +567,7 @@ async def solve_dependencies(
body: Optional[Union[dict[str, Any], FormData]] = None,
background_tasks: Optional[StarletteBackgroundTasks] = None,
response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None,
dependency_overrides_provider: Optional[DependencyOverridesProvider] = None,
dependency_cache: Optional[dict[DependencyCacheKey, Any]] = None,
# TODO: remove this parameter later, no longer used, not removing it yet as some
# people might be monkey patching this function (although that's not supported)
@ -599,9 +599,7 @@ async def solve_dependencies(
and dependency_overrides_provider.dependency_overrides
):
original_call = sub_dependant.call
call = getattr(
dependency_overrides_provider, "dependency_overrides", {}
).get(original_call, original_call)
call = dependency_overrides_provider.dependency_overrides.get(original_call, original_call)
use_path: str = sub_dependant.path # type: ignore
use_sub_dependant = get_dependant(
path=use_path,

View File

@ -48,7 +48,7 @@ from fastapi.exceptions import (
ResponseValidationError,
WebSocketRequestValidationError,
)
from fastapi.types import DecoratedCallable, IncEx
from fastapi.types import DecoratedCallable, IncEx, DependencyOverridesProvider
from fastapi.utils import (
create_cloned_field,
create_model_field,
@ -257,7 +257,7 @@ def get_request_handler(
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
dependency_overrides_provider: Optional[Any] = None,
dependency_overrides_provider: Optional[DependencyOverridesProvider] = None,
embed_body_fields: bool = False,
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
assert dependant.call is not None, "dependant.call must be a function"
@ -405,7 +405,7 @@ def get_request_handler(
def get_websocket_app(
dependant: Dependant,
dependency_overrides_provider: Optional[Any] = None,
dependency_overrides_provider: Optional[DependencyOverridesProvider] = None,
embed_body_fields: bool = False,
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
async def app(websocket: WebSocket) -> None:
@ -448,7 +448,7 @@ class APIWebSocketRoute(routing.WebSocketRoute):
*,
name: Optional[str] = None,
dependencies: Optional[Sequence[params.Depends]] = None,
dependency_overrides_provider: Optional[Any] = None,
dependency_overrides_provider: Optional[DependencyOverridesProvider] = None,
) -> None:
self.path = path
self.endpoint = endpoint
@ -510,7 +510,7 @@ class APIRoute(routing.Route):
response_class: Union[type[Response], DefaultPlaceholder] = Default(
JSONResponse
),
dependency_overrides_provider: Optional[Any] = None,
dependency_overrides_provider: Optional[DependencyOverridesProvider] = None,
callbacks: Optional[list[BaseRoute]] = None,
openapi_extra: Optional[dict[str, Any]] = None,
generate_unique_id_function: Union[
@ -800,7 +800,7 @@ class APIRouter(routing.Router):
),
] = None,
dependency_overrides_provider: Annotated[
Optional[Any],
Optional[DependencyOverridesProvider],
Doc(
"""
Only used internally by FastAPI to handle dependency overrides.

View File

@ -1,6 +1,6 @@
import types
from enum import Enum
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Any, Callable, Optional, TypeVar, Union, Protocol
from pydantic import BaseModel
@ -9,3 +9,6 @@ UnionType = getattr(types, "UnionType", Union)
ModelNameMap = dict[Union[type[BaseModel], type[Enum]], str]
IncEx = Union[set[int], set[str], dict[int, Any], dict[str, Any]]
DependencyCacheKey = tuple[Optional[Callable[..., Any]], tuple[str, ...], str]
class DependencyOverridesProvider(Protocol):
dependency_overrides: dict[Callable[..., Any], Any]