This commit is contained in:
Evgeny Bokshitsky 2026-02-04 17:36:50 +00:00 committed by GitHub
commit e33e400b37
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 16 additions and 12 deletions

View File

@ -49,7 +49,7 @@ from fastapi.dependencies.models import Dependant
from fastapi.exceptions import DependencyScopeError
from fastapi.logger import logger
from fastapi.security.oauth2 import SecurityScopes
from fastapi.types import DependencyCacheKey
from fastapi.types import DependencyCacheKey, DependencyOverridesProvider
from fastapi.utils import create_model_field, get_path_param_names
from pydantic import BaseModel
from pydantic.fields import FieldInfo
@ -572,7 +572,7 @@ async def solve_dependencies(
body: Optional[Union[dict[str, Any], FormData]] = None,
background_tasks: Optional[StarletteBackgroundTasks] = None,
response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None,
dependency_overrides_provider: Optional[DependencyOverridesProvider] = None,
dependency_cache: Optional[dict[DependencyCacheKey, Any]] = None,
# TODO: remove this parameter later, no longer used, not removing it yet as some
# people might be monkey patching this function (although that's not supported)
@ -604,9 +604,9 @@ async def solve_dependencies(
and dependency_overrides_provider.dependency_overrides
):
original_call = sub_dependant.call
call = getattr(
dependency_overrides_provider, "dependency_overrides", {}
).get(original_call, original_call)
call = dependency_overrides_provider.dependency_overrides.get(
original_call, original_call
)
use_path: str = sub_dependant.path # type: ignore
use_sub_dependant = get_dependant(
path=use_path,

View File

@ -48,7 +48,7 @@ from fastapi.exceptions import (
ResponseValidationError,
WebSocketRequestValidationError,
)
from fastapi.types import DecoratedCallable, IncEx
from fastapi.types import DecoratedCallable, DependencyOverridesProvider, IncEx
from fastapi.utils import (
create_cloned_field,
create_model_field,
@ -257,7 +257,7 @@ def get_request_handler(
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
dependency_overrides_provider: Optional[Any] = None,
dependency_overrides_provider: Optional[DependencyOverridesProvider] = None,
embed_body_fields: bool = False,
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
assert dependant.call is not None, "dependant.call must be a function"
@ -405,7 +405,7 @@ def get_request_handler(
def get_websocket_app(
dependant: Dependant,
dependency_overrides_provider: Optional[Any] = None,
dependency_overrides_provider: Optional[DependencyOverridesProvider] = None,
embed_body_fields: bool = False,
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
async def app(websocket: WebSocket) -> None:
@ -448,7 +448,7 @@ class APIWebSocketRoute(routing.WebSocketRoute):
*,
name: Optional[str] = None,
dependencies: Optional[Sequence[params.Depends]] = None,
dependency_overrides_provider: Optional[Any] = None,
dependency_overrides_provider: Optional[DependencyOverridesProvider] = None,
) -> None:
self.path = path
self.endpoint = endpoint
@ -510,7 +510,7 @@ class APIRoute(routing.Route):
response_class: Union[type[Response], DefaultPlaceholder] = Default(
JSONResponse
),
dependency_overrides_provider: Optional[Any] = None,
dependency_overrides_provider: Optional[DependencyOverridesProvider] = None,
callbacks: Optional[list[BaseRoute]] = None,
openapi_extra: Optional[dict[str, Any]] = None,
generate_unique_id_function: Union[
@ -800,7 +800,7 @@ class APIRouter(routing.Router):
),
] = None,
dependency_overrides_provider: Annotated[
Optional[Any],
Optional[DependencyOverridesProvider],
Doc(
"""
Only used internally by FastAPI to handle dependency overrides.

View File

@ -1,6 +1,6 @@
import types
from enum import Enum
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Any, Callable, Optional, Protocol, TypeVar, Union
from pydantic import BaseModel
from pydantic.main import IncEx as IncEx
@ -9,3 +9,7 @@ DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any])
UnionType = getattr(types, "UnionType", Union)
ModelNameMap = dict[Union[type[BaseModel], type[Enum]], str]
DependencyCacheKey = tuple[Optional[Callable[..., Any]], tuple[str, ...], str]
class DependencyOverridesProvider(Protocol):
dependency_overrides: dict[Callable[..., Any], Any]