diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 58392326d6..f18f2b39ee 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -38,12 +38,12 @@ class Dependant: dependencies: list["Dependant"] = field(default_factory=list) name: Optional[str] = None call: Optional[Callable[..., Any]] = None - request_param_name: Optional[str] = None - websocket_param_name: Optional[str] = None - http_connection_param_name: Optional[str] = None - response_param_name: Optional[str] = None - background_tasks_param_name: Optional[str] = None - security_scopes_param_name: Optional[str] = None + request_param_names: set[str] = field(default_factory=set) + websocket_param_names: set[str] = field(default_factory=set) + http_connection_param_names: set[str] = field(default_factory=set) + response_param_names: set[str] = field(default_factory=set) + background_tasks_param_names: set[str] = field(default_factory=set) + security_scopes_param_names: set[str] = field(default_factory=set) own_oauth_scopes: Optional[list[str]] = None parent_oauth_scopes: Optional[list[str]] = None use_cache: bool = True @@ -74,7 +74,7 @@ class Dependant: def _uses_scopes(self) -> bool: if self.own_oauth_scopes: return True - if self.security_scopes_param_name is not None: + if self.security_scopes_param_names: return True if self._is_security_scheme: return True diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index dd42371ecc..42e6179531 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -146,12 +146,12 @@ def get_flat_dependant( body_params=dependant.body_params.copy(), name=dependant.name, call=dependant.call, - request_param_name=dependant.request_param_name, - websocket_param_name=dependant.websocket_param_name, - http_connection_param_name=dependant.http_connection_param_name, - response_param_name=dependant.response_param_name, - background_tasks_param_name=dependant.background_tasks_param_name, - security_scopes_param_name=dependant.security_scopes_param_name, + request_param_names=dependant.request_param_names.copy(), + websocket_param_names=dependant.websocket_param_names.copy(), + http_connection_param_names=dependant.http_connection_param_names.copy(), + response_param_names=dependant.response_param_names.copy(), + background_tasks_param_names=dependant.background_tasks_param_names.copy(), + security_scopes_param_names=dependant.security_scopes_param_names.copy(), own_oauth_scopes=dependant.own_oauth_scopes, parent_oauth_scopes=use_parent_oauth_scopes, use_cache=dependant.use_cache, @@ -332,22 +332,22 @@ def add_non_field_param_to_dependency( *, param_name: str, type_annotation: Any, dependant: Dependant ) -> Optional[bool]: if lenient_issubclass(type_annotation, Request): - dependant.request_param_name = param_name + dependant.request_param_names.add(param_name) return True elif lenient_issubclass(type_annotation, WebSocket): - dependant.websocket_param_name = param_name + dependant.websocket_param_names.add(param_name) return True elif lenient_issubclass(type_annotation, HTTPConnection): - dependant.http_connection_param_name = param_name + dependant.http_connection_param_names.add(param_name) return True elif lenient_issubclass(type_annotation, Response): - dependant.response_param_name = param_name + dependant.response_param_names.add(param_name) return True elif lenient_issubclass(type_annotation, StarletteBackgroundTasks): - dependant.background_tasks_param_name = param_name + dependant.background_tasks_param_names.add(param_name) return True elif lenient_issubclass(type_annotation, SecurityScopes): - dependant.security_scopes_param_name = param_name + dependant.security_scopes_param_names.add(param_name) return True return None @@ -684,22 +684,25 @@ async def solve_dependencies( ) values.update(body_values) errors.extend(body_errors) - if dependant.http_connection_param_name: - values[dependant.http_connection_param_name] = request - if dependant.request_param_name and isinstance(request, Request): - values[dependant.request_param_name] = request - elif dependant.websocket_param_name and isinstance(request, WebSocket): - values[dependant.websocket_param_name] = request - if dependant.background_tasks_param_name: + for name in dependant.http_connection_param_names: + values[name] = request + if isinstance(request, Request): + for name in dependant.request_param_names: + values[name] = request + elif isinstance(request, WebSocket): + for name in dependant.websocket_param_names: + values[name] = request + if dependant.background_tasks_param_names: if background_tasks is None: background_tasks = BackgroundTasks() - values[dependant.background_tasks_param_name] = background_tasks - if dependant.response_param_name: - values[dependant.response_param_name] = response - if dependant.security_scopes_param_name: - values[dependant.security_scopes_param_name] = SecurityScopes( - scopes=dependant.oauth_scopes - ) + for name in dependant.background_tasks_param_names: + values[name] = background_tasks + for name in dependant.response_param_names: + values[name] = response + if dependant.security_scopes_param_names: + security_scope = SecurityScopes(scopes=dependant.oauth_scopes) + for name in dependant.security_scopes_param_names: + values[name] = security_scope return SolvedDependency( values=values, errors=errors, diff --git a/tests/test_duplicate_special_dependencies.py b/tests/test_duplicate_special_dependencies.py new file mode 100644 index 0000000000..8dfb57cb24 --- /dev/null +++ b/tests/test_duplicate_special_dependencies.py @@ -0,0 +1,61 @@ +import pytest +from fastapi import BackgroundTasks, FastAPI, Request, Response, WebSocket +from fastapi.security import SecurityScopes +from fastapi.testclient import TestClient + +app = FastAPI() + + +@app.get("/request") +def request(r1: Request, r2: Request) -> str: + assert r1 is not None + assert r1 is r2 + return "success" + + +@app.get("/response") +def response(r1: Response, r2: Response) -> str: + assert r1 is not None + assert r1 is r2 + return "success" + + +@app.get("/background-tasks") +def background_tasks(t1: BackgroundTasks, t2: BackgroundTasks) -> str: + assert t1 is not None + assert t1 is t2 + return "success" + + +@app.get("/security-scopes") +def security_scopes(sc1: SecurityScopes, sc2: SecurityScopes) -> str: + assert sc1 is not None + assert sc1 is sc2 + return "success" + + +@app.websocket("/websocket") +async def websocket(ws1: WebSocket, ws2: WebSocket) -> str: + assert ws1 is ws2 + await ws1.accept() + await ws1.send_text("success") + await ws1.close() + + +@pytest.mark.parametrize( + "url", + ( + "/request", + "/response", + "/background-tasks", + "/security-scopes", + ), +) +def test_duplicate_special_dependency(url: str) -> None: + assert TestClient(app).get(url).text == '"success"' + + +def test_duplicate_websocket_dependency() -> None: + with TestClient(app).websocket_connect("/websocket") as ws: + text = ws.receive_text() + assert text == "success"