mirror of https://github.com/tiangolo/fastapi.git
✨ Add support for injecting HTTPConnection (#1827)
This commit is contained in:
parent
5ed48ccdc8
commit
b9a0179a03
|
|
@ -34,6 +34,7 @@ class Dependant:
|
||||||
call: Optional[Callable] = None,
|
call: Optional[Callable] = None,
|
||||||
request_param_name: Optional[str] = None,
|
request_param_name: Optional[str] = None,
|
||||||
websocket_param_name: Optional[str] = None,
|
websocket_param_name: Optional[str] = None,
|
||||||
|
http_connection_param_name: Optional[str] = None,
|
||||||
response_param_name: Optional[str] = None,
|
response_param_name: Optional[str] = None,
|
||||||
background_tasks_param_name: Optional[str] = None,
|
background_tasks_param_name: Optional[str] = None,
|
||||||
security_scopes_param_name: Optional[str] = None,
|
security_scopes_param_name: Optional[str] = None,
|
||||||
|
|
@ -50,6 +51,7 @@ class Dependant:
|
||||||
self.security_requirements = security_schemes or []
|
self.security_requirements = security_schemes or []
|
||||||
self.request_param_name = request_param_name
|
self.request_param_name = request_param_name
|
||||||
self.websocket_param_name = websocket_param_name
|
self.websocket_param_name = websocket_param_name
|
||||||
|
self.http_connection_param_name = http_connection_param_name
|
||||||
self.response_param_name = response_param_name
|
self.response_param_name = response_param_name
|
||||||
self.background_tasks_param_name = background_tasks_param_name
|
self.background_tasks_param_name = background_tasks_param_name
|
||||||
self.security_scopes = security_scopes
|
self.security_scopes = security_scopes
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ from pydantic.utils import lenient_issubclass
|
||||||
from starlette.background import BackgroundTasks
|
from starlette.background import BackgroundTasks
|
||||||
from starlette.concurrency import run_in_threadpool
|
from starlette.concurrency import run_in_threadpool
|
||||||
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
|
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
|
||||||
from starlette.requests import Request
|
from starlette.requests import HTTPConnection, Request
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
from starlette.websockets import WebSocket
|
from starlette.websockets import WebSocket
|
||||||
|
|
||||||
|
|
@ -371,6 +371,9 @@ def add_non_field_param_to_dependency(
|
||||||
elif lenient_issubclass(param.annotation, WebSocket):
|
elif lenient_issubclass(param.annotation, WebSocket):
|
||||||
dependant.websocket_param_name = param.name
|
dependant.websocket_param_name = param.name
|
||||||
return True
|
return True
|
||||||
|
elif lenient_issubclass(param.annotation, HTTPConnection):
|
||||||
|
dependant.http_connection_param_name = param.name
|
||||||
|
return True
|
||||||
elif lenient_issubclass(param.annotation, Response):
|
elif lenient_issubclass(param.annotation, Response):
|
||||||
dependant.response_param_name = param.name
|
dependant.response_param_name = param.name
|
||||||
return True
|
return True
|
||||||
|
|
@ -607,6 +610,8 @@ async def solve_dependencies(
|
||||||
)
|
)
|
||||||
values.update(body_values)
|
values.update(body_values)
|
||||||
errors.extend(body_errors)
|
errors.extend(body_errors)
|
||||||
|
if dependant.http_connection_param_name:
|
||||||
|
values[dependant.http_connection_param_name] = request
|
||||||
if dependant.request_param_name and isinstance(request, Request):
|
if dependant.request_param_name and isinstance(request, Request):
|
||||||
values[dependant.request_param_name] = request
|
values[dependant.request_param_name] = request
|
||||||
elif dependant.websocket_param_name and isinstance(request, WebSocket):
|
elif dependant.websocket_param_name and isinstance(request, WebSocket):
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
from starlette.requests import Request # noqa
|
from starlette.requests import HTTPConnection, Request # noqa
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
from fastapi import Depends, FastAPI
|
||||||
|
from fastapi.requests import HTTPConnection
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from starlette.websockets import WebSocket
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.state.value = 42
|
||||||
|
|
||||||
|
|
||||||
|
async def extract_value_from_http_connection(conn: HTTPConnection):
|
||||||
|
return conn.app.state.value
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/http")
|
||||||
|
async def get_value_by_http(value: int = Depends(extract_value_from_http_connection)):
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/ws")
|
||||||
|
async def get_value_by_ws(
|
||||||
|
websocket: WebSocket, value: int = Depends(extract_value_from_http_connection)
|
||||||
|
):
|
||||||
|
await websocket.accept()
|
||||||
|
await websocket.send_json(value)
|
||||||
|
await websocket.close()
|
||||||
|
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_value_extracting_by_http():
|
||||||
|
response = client.get("/http")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == 42
|
||||||
|
|
||||||
|
|
||||||
|
def test_value_extracting_by_ws():
|
||||||
|
with client.websocket_connect("/ws") as websocket:
|
||||||
|
assert websocket.receive_json() == 42
|
||||||
Loading…
Reference in New Issue