mirror of https://github.com/tiangolo/fastapi.git
Merge dde7ef0ee5 into 272204c0c7
This commit is contained in:
commit
d1161dec96
|
|
@ -2,7 +2,7 @@ import inspect
|
|||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property, partial
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
from typing import Any, Callable, List, Optional, Set, Union
|
||||
|
||||
from fastapi._compat import ModelField
|
||||
from fastapi.security.base import SecurityBase
|
||||
|
|
@ -38,12 +38,12 @@ class Dependant:
|
|||
dependencies: List["Dependant"] = field(default_factory=list)
|
||||
name: Optional[str] = None
|
||||
call: Optional[Callable[..., Any]] = None
|
||||
request_param_name: Optional[str] = None
|
||||
websocket_param_name: Optional[str] = None
|
||||
http_connection_param_name: Optional[str] = None
|
||||
response_param_name: Optional[str] = None
|
||||
background_tasks_param_name: Optional[str] = None
|
||||
security_scopes_param_name: Optional[str] = None
|
||||
request_param_names: Set[str] = field(default_factory=set)
|
||||
websocket_param_names: Set[str] = field(default_factory=set)
|
||||
http_connection_param_names: Set[str] = field(default_factory=set)
|
||||
response_param_names: Set[str] = field(default_factory=set)
|
||||
background_tasks_param_names: Set[str] = field(default_factory=set)
|
||||
security_scopes_param_names: Set[str] = field(default_factory=set)
|
||||
own_oauth_scopes: Optional[List[str]] = None
|
||||
parent_oauth_scopes: Optional[List[str]] = None
|
||||
use_cache: bool = True
|
||||
|
|
@ -74,7 +74,7 @@ class Dependant:
|
|||
def _uses_scopes(self) -> bool:
|
||||
if self.own_oauth_scopes:
|
||||
return True
|
||||
if self.security_scopes_param_name is not None:
|
||||
if self.security_scopes_param_names:
|
||||
return True
|
||||
if self._is_security_scheme:
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -158,12 +158,12 @@ def get_flat_dependant(
|
|||
body_params=dependant.body_params.copy(),
|
||||
name=dependant.name,
|
||||
call=dependant.call,
|
||||
request_param_name=dependant.request_param_name,
|
||||
websocket_param_name=dependant.websocket_param_name,
|
||||
http_connection_param_name=dependant.http_connection_param_name,
|
||||
response_param_name=dependant.response_param_name,
|
||||
background_tasks_param_name=dependant.background_tasks_param_name,
|
||||
security_scopes_param_name=dependant.security_scopes_param_name,
|
||||
request_param_names=dependant.request_param_names.copy(),
|
||||
websocket_param_names=dependant.websocket_param_names.copy(),
|
||||
http_connection_param_names=dependant.http_connection_param_names.copy(),
|
||||
response_param_names=dependant.response_param_names.copy(),
|
||||
background_tasks_param_names=dependant.background_tasks_param_names.copy(),
|
||||
security_scopes_param_names=dependant.security_scopes_param_names.copy(),
|
||||
own_oauth_scopes=dependant.own_oauth_scopes,
|
||||
parent_oauth_scopes=use_parent_oauth_scopes,
|
||||
use_cache=dependant.use_cache,
|
||||
|
|
@ -341,22 +341,22 @@ def add_non_field_param_to_dependency(
|
|||
*, param_name: str, type_annotation: Any, dependant: Dependant
|
||||
) -> Optional[bool]:
|
||||
if lenient_issubclass(type_annotation, Request):
|
||||
dependant.request_param_name = param_name
|
||||
dependant.request_param_names.add(param_name)
|
||||
return True
|
||||
elif lenient_issubclass(type_annotation, WebSocket):
|
||||
dependant.websocket_param_name = param_name
|
||||
dependant.websocket_param_names.add(param_name)
|
||||
return True
|
||||
elif lenient_issubclass(type_annotation, HTTPConnection):
|
||||
dependant.http_connection_param_name = param_name
|
||||
dependant.http_connection_param_names.add(param_name)
|
||||
return True
|
||||
elif lenient_issubclass(type_annotation, Response):
|
||||
dependant.response_param_name = param_name
|
||||
dependant.response_param_names.add(param_name)
|
||||
return True
|
||||
elif lenient_issubclass(type_annotation, StarletteBackgroundTasks):
|
||||
dependant.background_tasks_param_name = param_name
|
||||
dependant.background_tasks_param_names.add(param_name)
|
||||
return True
|
||||
elif lenient_issubclass(type_annotation, SecurityScopes):
|
||||
dependant.security_scopes_param_name = param_name
|
||||
dependant.security_scopes_param_names.add(param_name)
|
||||
return True
|
||||
return None
|
||||
|
||||
|
|
@ -706,22 +706,25 @@ async def solve_dependencies(
|
|||
)
|
||||
values.update(body_values)
|
||||
errors.extend(body_errors)
|
||||
if dependant.http_connection_param_name:
|
||||
values[dependant.http_connection_param_name] = request
|
||||
if dependant.request_param_name and isinstance(request, Request):
|
||||
values[dependant.request_param_name] = request
|
||||
elif dependant.websocket_param_name and isinstance(request, WebSocket):
|
||||
values[dependant.websocket_param_name] = request
|
||||
if dependant.background_tasks_param_name:
|
||||
for name in dependant.http_connection_param_names:
|
||||
values[name] = request
|
||||
if isinstance(request, Request):
|
||||
for name in dependant.request_param_names:
|
||||
values[name] = request
|
||||
elif isinstance(request, WebSocket):
|
||||
for name in dependant.websocket_param_names:
|
||||
values[name] = request
|
||||
if dependant.background_tasks_param_names:
|
||||
if background_tasks is None:
|
||||
background_tasks = BackgroundTasks()
|
||||
values[dependant.background_tasks_param_name] = background_tasks
|
||||
if dependant.response_param_name:
|
||||
values[dependant.response_param_name] = response
|
||||
if dependant.security_scopes_param_name:
|
||||
values[dependant.security_scopes_param_name] = SecurityScopes(
|
||||
scopes=dependant.oauth_scopes
|
||||
)
|
||||
for name in dependant.background_tasks_param_names:
|
||||
values[name] = background_tasks
|
||||
for name in dependant.response_param_names:
|
||||
values[name] = response
|
||||
if dependant.security_scopes_param_names:
|
||||
security_scope = SecurityScopes(scopes=dependant.oauth_scopes)
|
||||
for name in dependant.security_scopes_param_names:
|
||||
values[name] = security_scope
|
||||
return SolvedDependency(
|
||||
values=values,
|
||||
errors=errors,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,61 @@
|
|||
import pytest
|
||||
from fastapi import BackgroundTasks, FastAPI, Request, Response, WebSocket
|
||||
from fastapi.security import SecurityScopes
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.get("/request")
|
||||
def request(r1: Request, r2: Request) -> str:
|
||||
assert r1 is not None
|
||||
assert r1 is r2
|
||||
return "success"
|
||||
|
||||
|
||||
@app.get("/response")
|
||||
def response(r1: Response, r2: Response) -> str:
|
||||
assert r1 is not None
|
||||
assert r1 is r2
|
||||
return "success"
|
||||
|
||||
|
||||
@app.get("/background-tasks")
|
||||
def background_tasks(t1: BackgroundTasks, t2: BackgroundTasks) -> str:
|
||||
assert t1 is not None
|
||||
assert t1 is t2
|
||||
return "success"
|
||||
|
||||
|
||||
@app.get("/security-scopes")
|
||||
def security_scopes(sc1: SecurityScopes, sc2: SecurityScopes) -> str:
|
||||
assert sc1 is not None
|
||||
assert sc1 is sc2
|
||||
return "success"
|
||||
|
||||
|
||||
@app.websocket("/websocket")
|
||||
async def websocket(ws1: WebSocket, ws2: WebSocket) -> str:
|
||||
assert ws1 is ws2
|
||||
await ws1.accept()
|
||||
await ws1.send_text("success")
|
||||
await ws1.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url",
|
||||
(
|
||||
"/request",
|
||||
"/response",
|
||||
"/background-tasks",
|
||||
"/security-scopes",
|
||||
),
|
||||
)
|
||||
def test_duplicate_special_dependency(url: str) -> None:
|
||||
assert TestClient(app).get(url).text == '"success"'
|
||||
|
||||
|
||||
def test_duplicate_websocket_dependency() -> None:
|
||||
with TestClient(app).websocket_connect("/websocket") as ws:
|
||||
text = ws.receive_text()
|
||||
assert text == "success"
|
||||
Loading…
Reference in New Issue