mirror of https://github.com/tiangolo/fastapi.git
Add tests for websocket with authorization
This commit is contained in:
parent
7d071df243
commit
f40bbafd8d
|
|
@ -29,7 +29,8 @@ def handle_exc_for_ws(func: _SecurityDepFunc) -> _SecurityDepFunc:
|
|||
except HTTPException as e:
|
||||
if not isinstance(request, WebSocket):
|
||||
raise e
|
||||
await request.accept()
|
||||
# close before accepted with result a HTTP 403 so the exception argument is ignored
|
||||
# ref: https://asgi.readthedocs.io/en/latest/specs/www.html#close-send-event
|
||||
raise WebSocketException(
|
||||
code=WS_1008_POLICY_VIOLATION, reason=e.detail
|
||||
) from None
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from fastapi import FastAPI, Security
|
||||
import pytest
|
||||
from fastapi import FastAPI, Security, WebSocket
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
|
@ -12,6 +14,16 @@ def read_current_user(credentials: HTTPAuthorizationCredentials = Security(secur
|
|||
return {"scheme": credentials.scheme, "credentials": credentials.credentials}
|
||||
|
||||
|
||||
@app.websocket("/users/timeline")
|
||||
async def read_user_timeline(
|
||||
websocket: WebSocket, credentials: HTTPAuthorizationCredentials = Security(security)
|
||||
):
|
||||
await websocket.accept()
|
||||
await websocket.send_json(
|
||||
{"scheme": credentials.scheme, "credentials": credentials.credentials}
|
||||
)
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
|
|
@ -27,6 +39,21 @@ def test_security_http_base_no_credentials():
|
|||
assert response.json() == {"detail": "Not authenticated"}
|
||||
|
||||
|
||||
def test_security_http_base_with_ws():
|
||||
with client.websocket_connect(
|
||||
"/users/timeline", headers={"Authorization": "Other foobar"}
|
||||
) as websocket:
|
||||
data = websocket.receive_json()
|
||||
assert data == {"scheme": "Other", "credentials": "foobar"}
|
||||
|
||||
|
||||
def test_security_http_base_with_ws_no_credentials():
|
||||
with pytest.raises(WebSocketDisconnect) as e:
|
||||
with client.websocket_connect("/users/timeline"):
|
||||
pass
|
||||
assert e.value.reason == "Not authenticated"
|
||||
|
||||
|
||||
def test_openapi_schema():
|
||||
response = client.get("/openapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
|
|
|
|||
Loading…
Reference in New Issue