mirror of https://github.com/tiangolo/fastapi.git
✨ Add support for injecting HTTPConnection (#1827)
This commit is contained in:
parent
5ed48ccdc8
commit
b9a0179a03
|
|
@ -34,6 +34,7 @@ class Dependant:
|
|||
call: Optional[Callable] = None,
|
||||
request_param_name: Optional[str] = None,
|
||||
websocket_param_name: Optional[str] = None,
|
||||
http_connection_param_name: Optional[str] = None,
|
||||
response_param_name: Optional[str] = None,
|
||||
background_tasks_param_name: Optional[str] = None,
|
||||
security_scopes_param_name: Optional[str] = None,
|
||||
|
|
@ -50,6 +51,7 @@ class Dependant:
|
|||
self.security_requirements = security_schemes or []
|
||||
self.request_param_name = request_param_name
|
||||
self.websocket_param_name = websocket_param_name
|
||||
self.http_connection_param_name = http_connection_param_name
|
||||
self.response_param_name = response_param_name
|
||||
self.background_tasks_param_name = background_tasks_param_name
|
||||
self.security_scopes = security_scopes
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ from pydantic.utils import lenient_issubclass
|
|||
from starlette.background import BackgroundTasks
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
|
||||
from starlette.requests import Request
|
||||
from starlette.requests import HTTPConnection, Request
|
||||
from starlette.responses import Response
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
|
|
@ -371,6 +371,9 @@ def add_non_field_param_to_dependency(
|
|||
elif lenient_issubclass(param.annotation, WebSocket):
|
||||
dependant.websocket_param_name = param.name
|
||||
return True
|
||||
elif lenient_issubclass(param.annotation, HTTPConnection):
|
||||
dependant.http_connection_param_name = param.name
|
||||
return True
|
||||
elif lenient_issubclass(param.annotation, Response):
|
||||
dependant.response_param_name = param.name
|
||||
return True
|
||||
|
|
@ -607,6 +610,8 @@ async def solve_dependencies(
|
|||
)
|
||||
values.update(body_values)
|
||||
errors.extend(body_errors)
|
||||
if dependant.http_connection_param_name:
|
||||
values[dependant.http_connection_param_name] = request
|
||||
if dependant.request_param_name and isinstance(request, Request):
|
||||
values[dependant.request_param_name] = request
|
||||
elif dependant.websocket_param_name and isinstance(request, WebSocket):
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from starlette.requests import Request # noqa
|
||||
from starlette.requests import HTTPConnection, Request # noqa
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
from fastapi import Depends, FastAPI
|
||||
from fastapi.requests import HTTPConnection
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
app = FastAPI()
|
||||
app.state.value = 42
|
||||
|
||||
|
||||
async def extract_value_from_http_connection(conn: HTTPConnection):
|
||||
return conn.app.state.value
|
||||
|
||||
|
||||
@app.get("/http")
|
||||
async def get_value_by_http(value: int = Depends(extract_value_from_http_connection)):
|
||||
return value
|
||||
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def get_value_by_ws(
|
||||
websocket: WebSocket, value: int = Depends(extract_value_from_http_connection)
|
||||
):
|
||||
await websocket.accept()
|
||||
await websocket.send_json(value)
|
||||
await websocket.close()
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_value_extracting_by_http():
|
||||
response = client.get("/http")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == 42
|
||||
|
||||
|
||||
def test_value_extracting_by_ws():
|
||||
with client.websocket_connect("/ws") as websocket:
|
||||
assert websocket.receive_json() == 42
|
||||
Loading…
Reference in New Issue