🐛 Fix dependency overrides in WebSockets (#1122)

* add tests to test_ws_router to test dependencies and dependency overrides.

* supply dependency_overrides_provider to APIWebSocketRoute upon creation
This commit is contained in:
amitlissack 2020-03-30 14:45:05 -04:00 committed by GitHub
parent 210af1fd3d
commit 02441ff031
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 2 deletions

View File

@ -498,7 +498,12 @@ class APIRouter(routing.Router):
def add_api_websocket_route(
self, path: str, endpoint: Callable, name: str = None
) -> None:
route = APIWebSocketRoute(path, endpoint=endpoint, name=name)
route = APIWebSocketRoute(
path,
endpoint=endpoint,
name=name,
dependency_overrides_provider=self.dependency_overrides_provider,
)
self.routes.append(route)
def websocket(self, path: str, name: str = None) -> Callable:

View File

@ -1,4 +1,4 @@
from fastapi import APIRouter, FastAPI, WebSocket
from fastapi import APIRouter, Depends, FastAPI, WebSocket
from fastapi.testclient import TestClient
router = APIRouter()
@ -34,6 +34,19 @@ async def routerindex(websocket: WebSocket):
await websocket.close()
async def ws_dependency():
return "Socket Dependency"
@router.websocket("/router-ws-depends/")
async def router_ws_decorator_depends(
websocket: WebSocket, data=Depends(ws_dependency)
):
await websocket.accept()
await websocket.send_text(data)
await websocket.close()
app.include_router(router)
app.include_router(prefix_router, prefix="/prefix")
@ -64,3 +77,16 @@ def test_router2():
with client.websocket_connect("/router2") as websocket:
data = websocket.receive_text()
assert data == "Hello, router!"
def test_router_ws_depends():
client = TestClient(app)
with client.websocket_connect("/router-ws-depends/") as websocket:
assert websocket.receive_text() == "Socket Dependency"
def test_router_ws_depends_with_override():
client = TestClient(app)
app.dependency_overrides[ws_dependency] = lambda: "Override"
with client.websocket_connect("/router-ws-depends/") as websocket:
assert websocket.receive_text() == "Override"