diff --git a/fastapi/security/utils.py b/fastapi/security/utils.py index c40d526a7..3ddf56f1f 100644 --- a/fastapi/security/utils.py +++ b/fastapi/security/utils.py @@ -29,7 +29,8 @@ def handle_exc_for_ws(func: _SecurityDepFunc) -> _SecurityDepFunc: except HTTPException as e: if not isinstance(request, WebSocket): raise e - await request.accept() + # close before accepted with result a HTTP 403 so the exception argument is ignored + # ref: https://asgi.readthedocs.io/en/latest/specs/www.html#close-send-event raise WebSocketException( code=WS_1008_POLICY_VIOLATION, reason=e.detail ) from None diff --git a/tests/test_security_http_base.py b/tests/test_security_http_base.py index 51928bafd..5c097939b 100644 --- a/tests/test_security_http_base.py +++ b/tests/test_security_http_base.py @@ -1,6 +1,8 @@ -from fastapi import FastAPI, Security +import pytest +from fastapi import FastAPI, Security, WebSocket from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase from fastapi.testclient import TestClient +from starlette.websockets import WebSocketDisconnect app = FastAPI() @@ -12,6 +14,16 @@ def read_current_user(credentials: HTTPAuthorizationCredentials = Security(secur return {"scheme": credentials.scheme, "credentials": credentials.credentials} +@app.websocket("/users/timeline") +async def read_user_timeline( + websocket: WebSocket, credentials: HTTPAuthorizationCredentials = Security(security) +): + await websocket.accept() + await websocket.send_json( + {"scheme": credentials.scheme, "credentials": credentials.credentials} + ) + + client = TestClient(app) @@ -27,6 +39,21 @@ def test_security_http_base_no_credentials(): assert response.json() == {"detail": "Not authenticated"} +def test_security_http_base_with_ws(): + with client.websocket_connect( + "/users/timeline", headers={"Authorization": "Other foobar"} + ) as websocket: + data = websocket.receive_json() + assert data == {"scheme": "Other", "credentials": "foobar"} + + +def test_security_http_base_with_ws_no_credentials(): + with pytest.raises(WebSocketDisconnect) as e: + with client.websocket_connect("/users/timeline"): + pass + assert e.value.reason == "Not authenticated" + + def test_openapi_schema(): response = client.get("/openapi.json") assert response.status_code == 200, response.text