This commit is contained in:
Peter Volf 2026-02-06 19:06:38 +00:00 committed by GitHub
commit 4280a9f6b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 97 additions and 33 deletions

View File

@ -38,12 +38,12 @@ class Dependant:
dependencies: list["Dependant"] = field(default_factory=list)
name: Optional[str] = None
call: Optional[Callable[..., Any]] = None
request_param_name: Optional[str] = None
websocket_param_name: Optional[str] = None
http_connection_param_name: Optional[str] = None
response_param_name: Optional[str] = None
background_tasks_param_name: Optional[str] = None
security_scopes_param_name: Optional[str] = None
request_param_names: set[str] = field(default_factory=set)
websocket_param_names: set[str] = field(default_factory=set)
http_connection_param_names: set[str] = field(default_factory=set)
response_param_names: set[str] = field(default_factory=set)
background_tasks_param_names: set[str] = field(default_factory=set)
security_scopes_param_names: set[str] = field(default_factory=set)
own_oauth_scopes: Optional[list[str]] = None
parent_oauth_scopes: Optional[list[str]] = None
use_cache: bool = True
@ -74,7 +74,7 @@ class Dependant:
def _uses_scopes(self) -> bool:
if self.own_oauth_scopes:
return True
if self.security_scopes_param_name is not None:
if self.security_scopes_param_names:
return True
if self._is_security_scheme:
return True

View File

@ -146,12 +146,12 @@ def get_flat_dependant(
body_params=dependant.body_params.copy(),
name=dependant.name,
call=dependant.call,
request_param_name=dependant.request_param_name,
websocket_param_name=dependant.websocket_param_name,
http_connection_param_name=dependant.http_connection_param_name,
response_param_name=dependant.response_param_name,
background_tasks_param_name=dependant.background_tasks_param_name,
security_scopes_param_name=dependant.security_scopes_param_name,
request_param_names=dependant.request_param_names.copy(),
websocket_param_names=dependant.websocket_param_names.copy(),
http_connection_param_names=dependant.http_connection_param_names.copy(),
response_param_names=dependant.response_param_names.copy(),
background_tasks_param_names=dependant.background_tasks_param_names.copy(),
security_scopes_param_names=dependant.security_scopes_param_names.copy(),
own_oauth_scopes=dependant.own_oauth_scopes,
parent_oauth_scopes=use_parent_oauth_scopes,
use_cache=dependant.use_cache,
@ -332,22 +332,22 @@ def add_non_field_param_to_dependency(
*, param_name: str, type_annotation: Any, dependant: Dependant
) -> Optional[bool]:
if lenient_issubclass(type_annotation, Request):
dependant.request_param_name = param_name
dependant.request_param_names.add(param_name)
return True
elif lenient_issubclass(type_annotation, WebSocket):
dependant.websocket_param_name = param_name
dependant.websocket_param_names.add(param_name)
return True
elif lenient_issubclass(type_annotation, HTTPConnection):
dependant.http_connection_param_name = param_name
dependant.http_connection_param_names.add(param_name)
return True
elif lenient_issubclass(type_annotation, Response):
dependant.response_param_name = param_name
dependant.response_param_names.add(param_name)
return True
elif lenient_issubclass(type_annotation, StarletteBackgroundTasks):
dependant.background_tasks_param_name = param_name
dependant.background_tasks_param_names.add(param_name)
return True
elif lenient_issubclass(type_annotation, SecurityScopes):
dependant.security_scopes_param_name = param_name
dependant.security_scopes_param_names.add(param_name)
return True
return None
@ -684,22 +684,25 @@ async def solve_dependencies(
)
values.update(body_values)
errors.extend(body_errors)
if dependant.http_connection_param_name:
values[dependant.http_connection_param_name] = request
if dependant.request_param_name and isinstance(request, Request):
values[dependant.request_param_name] = request
elif dependant.websocket_param_name and isinstance(request, WebSocket):
values[dependant.websocket_param_name] = request
if dependant.background_tasks_param_name:
for name in dependant.http_connection_param_names:
values[name] = request
if isinstance(request, Request):
for name in dependant.request_param_names:
values[name] = request
elif isinstance(request, WebSocket):
for name in dependant.websocket_param_names:
values[name] = request
if dependant.background_tasks_param_names:
if background_tasks is None:
background_tasks = BackgroundTasks()
values[dependant.background_tasks_param_name] = background_tasks
if dependant.response_param_name:
values[dependant.response_param_name] = response
if dependant.security_scopes_param_name:
values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.oauth_scopes
)
for name in dependant.background_tasks_param_names:
values[name] = background_tasks
for name in dependant.response_param_names:
values[name] = response
if dependant.security_scopes_param_names:
security_scope = SecurityScopes(scopes=dependant.oauth_scopes)
for name in dependant.security_scopes_param_names:
values[name] = security_scope
return SolvedDependency(
values=values,
errors=errors,

View File

@ -0,0 +1,61 @@
import pytest
from fastapi import BackgroundTasks, FastAPI, Request, Response, WebSocket
from fastapi.security import SecurityScopes
from fastapi.testclient import TestClient
app = FastAPI()
@app.get("/request")
def request(r1: Request, r2: Request) -> str:
assert r1 is not None
assert r1 is r2
return "success"
@app.get("/response")
def response(r1: Response, r2: Response) -> str:
assert r1 is not None
assert r1 is r2
return "success"
@app.get("/background-tasks")
def background_tasks(t1: BackgroundTasks, t2: BackgroundTasks) -> str:
assert t1 is not None
assert t1 is t2
return "success"
@app.get("/security-scopes")
def security_scopes(sc1: SecurityScopes, sc2: SecurityScopes) -> str:
assert sc1 is not None
assert sc1 is sc2
return "success"
@app.websocket("/websocket")
async def websocket(ws1: WebSocket, ws2: WebSocket) -> str:
assert ws1 is ws2
await ws1.accept()
await ws1.send_text("success")
await ws1.close()
@pytest.mark.parametrize(
"url",
(
"/request",
"/response",
"/background-tasks",
"/security-scopes",
),
)
def test_duplicate_special_dependency(url: str) -> None:
assert TestClient(app).get(url).text == '"success"'
def test_duplicate_websocket_dependency() -> None:
with TestClient(app).websocket_connect("/websocket") as ws:
text = ws.receive_text()
assert text == "success"