From 30f4123ee7ef755c891e0e68d4c55c8321a7a673 Mon Sep 17 00:00:00 2001
From: philipmunch
Date: Wed, 17 Dec 2025 20:19:06 +0100
Subject: [PATCH 1/2] Security dependencies now working with websockets
---
fastapi/security/api_key.py | 8 +--
fastapi/security/http.py | 10 ++--
fastapi/security/oauth2.py | 8 +--
fastapi/security/open_id_connect_url.py | 4 +-
.../test_security_api_key_cookie_websocket.py | 42 +++++++++++++++
.../test_security_api_key_header_websocket.py | 45 ++++++++++++++++
.../test_security_api_key_query_websocket.py | 43 +++++++++++++++
tests/test_security_http_base_websocket.py | 38 +++++++++++++
tests/test_security_http_basic_websocket.py | 50 +++++++++++++++++
tests/test_security_http_bearer_websocket.py | 46 ++++++++++++++++
tests/test_security_http_digest_websocket.py | 46 ++++++++++++++++
...th2_authorization_code_bearer_websocket.py | 45 ++++++++++++++++
...curity_oauth2_password_bearer_websocket.py | 41 ++++++++++++++
.../test_security_openid_connect_websocket.py | 53 +++++++++++++++++++
14 files changed, 464 insertions(+), 15 deletions(-)
create mode 100644 tests/test_security_api_key_cookie_websocket.py
create mode 100644 tests/test_security_api_key_header_websocket.py
create mode 100644 tests/test_security_api_key_query_websocket.py
create mode 100644 tests/test_security_http_base_websocket.py
create mode 100644 tests/test_security_http_basic_websocket.py
create mode 100644 tests/test_security_http_bearer_websocket.py
create mode 100644 tests/test_security_http_digest_websocket.py
create mode 100644 tests/test_security_oauth2_authorization_code_bearer_websocket.py
create mode 100644 tests/test_security_oauth2_password_bearer_websocket.py
create mode 100644 tests/test_security_openid_connect_websocket.py
diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py
index 81c7be10d6..f64370e3e5 100644
--- a/fastapi/security/api_key.py
+++ b/fastapi/security/api_key.py
@@ -4,7 +4,7 @@ from annotated_doc import Doc
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
-from starlette.requests import Request
+from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated
@@ -138,7 +138,7 @@ class APIKeyQuery(APIKeyBase):
auto_error=auto_error,
)
- async def __call__(self, request: Request) -> Optional[str]:
+ async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.query_params.get(self.model.name)
return self.check_api_key(api_key)
@@ -226,7 +226,7 @@ class APIKeyHeader(APIKeyBase):
auto_error=auto_error,
)
- async def __call__(self, request: Request) -> Optional[str]:
+ async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.headers.get(self.model.name)
return self.check_api_key(api_key)
@@ -314,6 +314,6 @@ class APIKeyCookie(APIKeyBase):
auto_error=auto_error,
)
- async def __call__(self, request: Request) -> Optional[str]:
+ async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.cookies.get(self.model.name)
return self.check_api_key(api_key)
diff --git a/fastapi/security/http.py b/fastapi/security/http.py
index 0d1bbba3a0..a6de1ba2ba 100644
--- a/fastapi/security/http.py
+++ b/fastapi/security/http.py
@@ -9,7 +9,7 @@ from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel
-from starlette.requests import Request
+from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated
@@ -93,7 +93,7 @@ class HTTPBase(SecurityBase):
)
async def __call__(
- self, request: Request
+ self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
@@ -203,7 +203,7 @@ class HTTPBasic(HTTPBase):
return {"WWW-Authenticate": "Basic"}
async def __call__( # type: ignore
- self, request: Request
+ self, request: HTTPConnection
) -> Optional[HTTPBasicCredentials]:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
@@ -304,7 +304,7 @@ class HTTPBearer(HTTPBase):
self.auto_error = auto_error
async def __call__(
- self, request: Request
+ self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
@@ -407,7 +407,7 @@ class HTTPDigest(HTTPBase):
self.auto_error = auto_error
async def __call__(
- self, request: Request
+ self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py
index b41b0f8778..7d666d320e 100644
--- a/fastapi/security/oauth2.py
+++ b/fastapi/security/oauth2.py
@@ -7,7 +7,7 @@ from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
-from starlette.requests import Request
+from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED
# TODO: import from typing when deprecating Python 3.9
@@ -399,7 +399,7 @@ class OAuth2(SecurityBase):
headers={"WWW-Authenticate": "Bearer"},
)
- async def __call__(self, request: Request) -> Optional[str]:
+ async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
@@ -506,7 +506,7 @@ class OAuth2PasswordBearer(OAuth2):
auto_error=auto_error,
)
- async def __call__(self, request: Request) -> Optional[str]:
+ async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
@@ -612,7 +612,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
auto_error=auto_error,
)
- async def __call__(self, request: Request) -> Optional[str]:
+ async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
diff --git a/fastapi/security/open_id_connect_url.py b/fastapi/security/open_id_connect_url.py
index e574a56a82..4fdd7a89c7 100644
--- a/fastapi/security/open_id_connect_url.py
+++ b/fastapi/security/open_id_connect_url.py
@@ -4,7 +4,7 @@ from annotated_doc import Doc
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
-from starlette.requests import Request
+from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated
@@ -85,7 +85,7 @@ class OpenIdConnect(SecurityBase):
headers={"WWW-Authenticate": "Bearer"},
)
- async def __call__(self, request: Request) -> Optional[str]:
+ async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
diff --git a/tests/test_security_api_key_cookie_websocket.py b/tests/test_security_api_key_cookie_websocket.py
new file mode 100644
index 0000000000..d4e0a4f57e
--- /dev/null
+++ b/tests/test_security_api_key_cookie_websocket.py
@@ -0,0 +1,42 @@
+import pytest
+from fastapi import Depends, FastAPI, Security
+from fastapi.security import APIKeyCookie
+from fastapi.testclient import TestClient
+from pydantic import BaseModel
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+api_key = APIKeyCookie(name="key")
+
+
+class User(BaseModel):
+ username: str
+
+
+def get_current_user(oauth_header: str = Security(api_key)):
+ user = User(username=oauth_header)
+ return user
+
+
+@app.websocket("/ws/users/me")
+async def read_current_user(
+ websocket: WebSocket, current_user: User = Depends(get_current_user)
+):
+ await websocket.accept()
+ await websocket.send_text(current_user.username)
+
+
+def test_security_api_key_ws():
+ client = TestClient(app, cookies={"key": "secret"})
+ with client.websocket_connect("/ws/users/me") as websocket:
+ data = websocket.receive_text()
+ assert data == "secret"
+
+
+def test_security_api_key_no_key_ws():
+ client = TestClient(app)
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/users/me"):
+ pass
diff --git a/tests/test_security_api_key_header_websocket.py b/tests/test_security_api_key_header_websocket.py
new file mode 100644
index 0000000000..636a03a89e
--- /dev/null
+++ b/tests/test_security_api_key_header_websocket.py
@@ -0,0 +1,45 @@
+import pytest
+from fastapi import Depends, FastAPI, Security
+from fastapi.security import APIKeyHeader
+from fastapi.testclient import TestClient
+from pydantic import BaseModel
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+api_key = APIKeyHeader(name="key")
+
+
+class User(BaseModel):
+ username: str
+
+
+def get_current_user(oauth_header: str = Security(api_key)):
+ user = User(username=oauth_header)
+ return user
+
+
+@app.websocket("/ws/users/me")
+async def read_current_user(
+ websocket: WebSocket, current_user: User = Depends(get_current_user)
+):
+ await websocket.accept()
+ await websocket.send_text(current_user.username)
+
+
+client = TestClient(app)
+
+
+def test_security_api_key_ws():
+ with client.websocket_connect(
+ "/ws/users/me", headers={"key": "secret"}
+ ) as websocket:
+ data = websocket.receive_text()
+ assert data == "secret"
+
+
+def test_security_api_key_no_key_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/users/me"):
+ pass
diff --git a/tests/test_security_api_key_query_websocket.py b/tests/test_security_api_key_query_websocket.py
new file mode 100644
index 0000000000..5264538a69
--- /dev/null
+++ b/tests/test_security_api_key_query_websocket.py
@@ -0,0 +1,43 @@
+import pytest
+from fastapi import Depends, FastAPI, Security
+from fastapi.security import APIKeyQuery
+from fastapi.testclient import TestClient
+from pydantic import BaseModel
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+api_key = APIKeyQuery(name="key")
+
+
+class User(BaseModel):
+ username: str
+
+
+def get_current_user(oauth_header: str = Security(api_key)):
+ user = User(username=oauth_header)
+ return user
+
+
+@app.websocket("/ws/users/me")
+async def read_current_user(
+ websocket: WebSocket, current_user: User = Depends(get_current_user)
+):
+ await websocket.accept()
+ await websocket.send_text(current_user.username)
+
+
+client = TestClient(app)
+
+
+def test_security_api_key_query_ws():
+ with client.websocket_connect("/ws/users/me?key=secret") as websocket:
+ data = websocket.receive_text()
+ assert data == "secret"
+
+
+def test_security_api_key_query_no_key_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/users/me"):
+ pass
diff --git a/tests/test_security_http_base_websocket.py b/tests/test_security_http_base_websocket.py
new file mode 100644
index 0000000000..7041471ea5
--- /dev/null
+++ b/tests/test_security_http_base_websocket.py
@@ -0,0 +1,38 @@
+import pytest
+from fastapi import FastAPI, Security
+from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase
+from fastapi.testclient import TestClient
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+security = HTTPBase(scheme="Other")
+
+
+@app.websocket("/ws/users/me")
+async def read_current_user(
+ websocket: WebSocket,
+ credentials: HTTPAuthorizationCredentials = Security(security),
+):
+ await websocket.accept()
+ await websocket.send_json(
+ {"scheme": credentials.scheme, "credentials": credentials.credentials}
+ )
+
+
+client = TestClient(app)
+
+
+def test_security_http_base_ws():
+ with client.websocket_connect(
+ "/ws/users/me", headers={"Authorization": "Other foobar"}
+ ) as websocket:
+ data = websocket.receive_json()
+ assert data == {"scheme": "Other", "credentials": "foobar"}
+
+
+def test_security_http_base_no_credentials_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/users/me"):
+ pass
diff --git a/tests/test_security_http_basic_websocket.py b/tests/test_security_http_basic_websocket.py
new file mode 100644
index 0000000000..809487bb05
--- /dev/null
+++ b/tests/test_security_http_basic_websocket.py
@@ -0,0 +1,50 @@
+from base64 import b64encode
+
+import pytest
+from fastapi import FastAPI, Security
+from fastapi.security import HTTPBasic, HTTPBasicCredentials
+from fastapi.testclient import TestClient
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+security = HTTPBasic(realm="simple")
+
+
+@app.websocket("/ws/users/me")
+async def read_current_user(
+ websocket: WebSocket, credentials: HTTPBasicCredentials = Security(security)
+):
+ await websocket.accept()
+ await websocket.send_json(
+ {"username": credentials.username, "password": credentials.password}
+ )
+
+
+client = TestClient(app)
+
+
+def test_security_http_basic_ws():
+ # Build Basic header
+ payload = b64encode(b"john:secret").decode("ascii")
+ auth_header = f"Basic {payload}"
+ with client.websocket_connect(
+ "/ws/users/me", headers={"Authorization": auth_header}
+ ) as websocket:
+ data = websocket.receive_json()
+ assert data == {"username": "john", "password": "secret"}
+
+
+def test_security_http_basic_no_credentials_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/users/me"):
+ pass
+
+
+def test_security_http_basic_invalid_credentials_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect(
+ "/ws/users/me", headers={"Authorization": "Basic notbase64"}
+ ):
+ pass
diff --git a/tests/test_security_http_bearer_websocket.py b/tests/test_security_http_bearer_websocket.py
new file mode 100644
index 0000000000..ff41652a6b
--- /dev/null
+++ b/tests/test_security_http_bearer_websocket.py
@@ -0,0 +1,46 @@
+import pytest
+from fastapi import FastAPI, Security
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+from fastapi.testclient import TestClient
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+security = HTTPBearer()
+
+
+@app.websocket("/ws/users/me")
+async def read_current_user(
+ websocket: WebSocket,
+ credentials: HTTPAuthorizationCredentials = Security(security),
+):
+ await websocket.accept()
+ await websocket.send_json(
+ {"scheme": credentials.scheme, "credentials": credentials.credentials}
+ )
+
+
+client = TestClient(app)
+
+
+def test_security_http_bearer_ws():
+ with client.websocket_connect(
+ "/ws/users/me", headers={"Authorization": "Bearer foobar"}
+ ) as websocket:
+ data = websocket.receive_json()
+ assert data == {"scheme": "Bearer", "credentials": "foobar"}
+
+
+def test_security_http_bearer_no_credentials_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/users/me"):
+ pass
+
+
+def test_security_http_bearer_incorrect_scheme_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect(
+ "/ws/users/me", headers={"Authorization": "Basic notreally"}
+ ):
+ pass
diff --git a/tests/test_security_http_digest_websocket.py b/tests/test_security_http_digest_websocket.py
new file mode 100644
index 0000000000..065507a851
--- /dev/null
+++ b/tests/test_security_http_digest_websocket.py
@@ -0,0 +1,46 @@
+import pytest
+from fastapi import FastAPI, Security
+from fastapi.security import HTTPAuthorizationCredentials, HTTPDigest
+from fastapi.testclient import TestClient
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+security = HTTPDigest()
+
+
+@app.websocket("/ws/users/me")
+async def read_current_user(
+ websocket: WebSocket,
+ credentials: HTTPAuthorizationCredentials = Security(security),
+):
+ await websocket.accept()
+ await websocket.send_json(
+ {"scheme": credentials.scheme, "credentials": credentials.credentials}
+ )
+
+
+client = TestClient(app)
+
+
+def test_security_http_digest_ws():
+ with client.websocket_connect(
+ "/ws/users/me", headers={"Authorization": "Digest foobar"}
+ ) as websocket:
+ data = websocket.receive_json()
+ assert data == {"scheme": "Digest", "credentials": "foobar"}
+
+
+def test_security_http_digest_no_credentials_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/users/me"):
+ pass
+
+
+def test_security_http_digest_incorrect_scheme_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect(
+ "/ws/users/me", headers={"Authorization": "Basic notreally"}
+ ):
+ pass
diff --git a/tests/test_security_oauth2_authorization_code_bearer_websocket.py b/tests/test_security_oauth2_authorization_code_bearer_websocket.py
new file mode 100644
index 0000000000..db983607f5
--- /dev/null
+++ b/tests/test_security_oauth2_authorization_code_bearer_websocket.py
@@ -0,0 +1,45 @@
+import pytest
+from fastapi import FastAPI, Security
+from fastapi.security import OAuth2AuthorizationCodeBearer
+from fastapi.testclient import TestClient
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+oauth2_scheme = OAuth2AuthorizationCodeBearer(
+ authorizationUrl="/api/oauth/authorize",
+ tokenUrl="/api/oauth/token",
+ scopes={"read": "Read access", "write": "Write access"},
+)
+
+
+@app.websocket("/ws/admin")
+async def read_admin(websocket: WebSocket, token: str = Security(oauth2_scheme)):
+ await websocket.accept()
+ await websocket.send_text(token)
+
+
+client = TestClient(app)
+
+
+def test_security_oauth2_authorization_code_bearer_ws():
+ with client.websocket_connect(
+ "/ws/admin", headers={"Authorization": "Bearer faketoken"}
+ ) as websocket:
+ data = websocket.receive_text()
+ assert data == "faketoken"
+
+
+def test_security_oauth2_authorization_code_bearer_no_header_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/admin"):
+ pass
+
+
+def test_security_oauth2_authorization_code_bearer_wrong_scheme_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect(
+ "/ws/admin", headers={"Authorization": "Basic nope"}
+ ):
+ pass
diff --git a/tests/test_security_oauth2_password_bearer_websocket.py b/tests/test_security_oauth2_password_bearer_websocket.py
new file mode 100644
index 0000000000..8c71adace0
--- /dev/null
+++ b/tests/test_security_oauth2_password_bearer_websocket.py
@@ -0,0 +1,41 @@
+import pytest
+from fastapi import FastAPI, Security
+from fastapi.security import OAuth2PasswordBearer
+from fastapi.testclient import TestClient
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+
+@app.websocket("/ws/token")
+async def read_token(websocket: WebSocket, token: str = Security(oauth2_scheme)):
+ await websocket.accept()
+ await websocket.send_text(token)
+
+
+client = TestClient(app)
+
+
+def test_security_oauth2_password_bearer_ws():
+ with client.websocket_connect(
+ "/ws/token", headers={"Authorization": "Bearer faketoken"}
+ ) as websocket:
+ data = websocket.receive_text()
+ assert data == "faketoken"
+
+
+def test_security_oauth2_password_bearer_no_header_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/token"):
+ pass
+
+
+def test_security_oauth2_password_bearer_wrong_scheme_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect(
+ "/ws/token", headers={"Authorization": "Basic nope"}
+ ):
+ pass
diff --git a/tests/test_security_openid_connect_websocket.py b/tests/test_security_openid_connect_websocket.py
new file mode 100644
index 0000000000..d9a0981409
--- /dev/null
+++ b/tests/test_security_openid_connect_websocket.py
@@ -0,0 +1,53 @@
+import pytest
+from fastapi import Depends, FastAPI, Security
+from fastapi.security.open_id_connect_url import OpenIdConnect
+from fastapi.testclient import TestClient
+from pydantic import BaseModel
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+oid = OpenIdConnect(openIdConnectUrl="/openid")
+
+
+class User(BaseModel):
+ username: str
+
+
+def get_current_user(oauth_header: str = Security(oid)):
+ user = User(username=oauth_header)
+ return user
+
+
+@app.websocket("/ws/users/me")
+async def read_current_user(
+ websocket: WebSocket, current_user: User = Depends(get_current_user)
+):
+ await websocket.accept()
+ await websocket.send_json(current_user.model_dump())
+
+
+client = TestClient(app)
+
+
+def test_security_openid_connect_ws():
+ with client.websocket_connect(
+ "/ws/users/me", headers={"Authorization": "Bearer footokenbar"}
+ ) as websocket:
+ data = websocket.receive_json()
+ assert data == {"username": "Bearer footokenbar"}
+
+
+def test_security_openid_connect_other_header_ws():
+ with client.websocket_connect(
+ "/ws/users/me", headers={"Authorization": "Other footokenbar"}
+ ) as websocket:
+ data = websocket.receive_json()
+ assert data == {"username": "Other footokenbar"}
+
+
+def test_security_openid_connect_no_header_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/users/me"):
+ pass
From bf7e2b7cabff4870976d270a6e7a3b884ed17078 Mon Sep 17 00:00:00 2001
From: philipmunch
Date: Wed, 17 Dec 2025 20:19:06 +0100
Subject: [PATCH 2/2] Security dependencies now working with websockets
---
fastapi/security/api_key.py | 8 +--
fastapi/security/http.py | 10 ++--
fastapi/security/oauth2.py | 8 +--
fastapi/security/open_id_connect_url.py | 4 +-
.../test_security_api_key_cookie_websocket.py | 42 +++++++++++++++
.../test_security_api_key_header_websocket.py | 45 ++++++++++++++++
.../test_security_api_key_query_websocket.py | 43 +++++++++++++++
tests/test_security_http_base_websocket.py | 38 +++++++++++++
tests/test_security_http_basic_websocket.py | 50 +++++++++++++++++
tests/test_security_http_bearer_websocket.py | 46 ++++++++++++++++
tests/test_security_http_digest_websocket.py | 46 ++++++++++++++++
...th2_authorization_code_bearer_websocket.py | 45 ++++++++++++++++
...curity_oauth2_password_bearer_websocket.py | 41 ++++++++++++++
.../test_security_openid_connect_websocket.py | 53 +++++++++++++++++++
14 files changed, 464 insertions(+), 15 deletions(-)
create mode 100644 tests/test_security_api_key_cookie_websocket.py
create mode 100644 tests/test_security_api_key_header_websocket.py
create mode 100644 tests/test_security_api_key_query_websocket.py
create mode 100644 tests/test_security_http_base_websocket.py
create mode 100644 tests/test_security_http_basic_websocket.py
create mode 100644 tests/test_security_http_bearer_websocket.py
create mode 100644 tests/test_security_http_digest_websocket.py
create mode 100644 tests/test_security_oauth2_authorization_code_bearer_websocket.py
create mode 100644 tests/test_security_oauth2_password_bearer_websocket.py
create mode 100644 tests/test_security_openid_connect_websocket.py
diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py
index 81c7be10d6..f64370e3e5 100644
--- a/fastapi/security/api_key.py
+++ b/fastapi/security/api_key.py
@@ -4,7 +4,7 @@ from annotated_doc import Doc
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
-from starlette.requests import Request
+from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated
@@ -138,7 +138,7 @@ class APIKeyQuery(APIKeyBase):
auto_error=auto_error,
)
- async def __call__(self, request: Request) -> Optional[str]:
+ async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.query_params.get(self.model.name)
return self.check_api_key(api_key)
@@ -226,7 +226,7 @@ class APIKeyHeader(APIKeyBase):
auto_error=auto_error,
)
- async def __call__(self, request: Request) -> Optional[str]:
+ async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.headers.get(self.model.name)
return self.check_api_key(api_key)
@@ -314,6 +314,6 @@ class APIKeyCookie(APIKeyBase):
auto_error=auto_error,
)
- async def __call__(self, request: Request) -> Optional[str]:
+ async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.cookies.get(self.model.name)
return self.check_api_key(api_key)
diff --git a/fastapi/security/http.py b/fastapi/security/http.py
index 0d1bbba3a0..a6de1ba2ba 100644
--- a/fastapi/security/http.py
+++ b/fastapi/security/http.py
@@ -9,7 +9,7 @@ from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel
-from starlette.requests import Request
+from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated
@@ -93,7 +93,7 @@ class HTTPBase(SecurityBase):
)
async def __call__(
- self, request: Request
+ self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
@@ -203,7 +203,7 @@ class HTTPBasic(HTTPBase):
return {"WWW-Authenticate": "Basic"}
async def __call__( # type: ignore
- self, request: Request
+ self, request: HTTPConnection
) -> Optional[HTTPBasicCredentials]:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
@@ -304,7 +304,7 @@ class HTTPBearer(HTTPBase):
self.auto_error = auto_error
async def __call__(
- self, request: Request
+ self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
@@ -407,7 +407,7 @@ class HTTPDigest(HTTPBase):
self.auto_error = auto_error
async def __call__(
- self, request: Request
+ self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py
index b41b0f8778..7d666d320e 100644
--- a/fastapi/security/oauth2.py
+++ b/fastapi/security/oauth2.py
@@ -7,7 +7,7 @@ from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
-from starlette.requests import Request
+from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED
# TODO: import from typing when deprecating Python 3.9
@@ -399,7 +399,7 @@ class OAuth2(SecurityBase):
headers={"WWW-Authenticate": "Bearer"},
)
- async def __call__(self, request: Request) -> Optional[str]:
+ async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
@@ -506,7 +506,7 @@ class OAuth2PasswordBearer(OAuth2):
auto_error=auto_error,
)
- async def __call__(self, request: Request) -> Optional[str]:
+ async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
@@ -612,7 +612,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
auto_error=auto_error,
)
- async def __call__(self, request: Request) -> Optional[str]:
+ async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
diff --git a/fastapi/security/open_id_connect_url.py b/fastapi/security/open_id_connect_url.py
index e574a56a82..4fdd7a89c7 100644
--- a/fastapi/security/open_id_connect_url.py
+++ b/fastapi/security/open_id_connect_url.py
@@ -4,7 +4,7 @@ from annotated_doc import Doc
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
-from starlette.requests import Request
+from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated
@@ -85,7 +85,7 @@ class OpenIdConnect(SecurityBase):
headers={"WWW-Authenticate": "Bearer"},
)
- async def __call__(self, request: Request) -> Optional[str]:
+ async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
diff --git a/tests/test_security_api_key_cookie_websocket.py b/tests/test_security_api_key_cookie_websocket.py
new file mode 100644
index 0000000000..d4e0a4f57e
--- /dev/null
+++ b/tests/test_security_api_key_cookie_websocket.py
@@ -0,0 +1,42 @@
+import pytest
+from fastapi import Depends, FastAPI, Security
+from fastapi.security import APIKeyCookie
+from fastapi.testclient import TestClient
+from pydantic import BaseModel
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+api_key = APIKeyCookie(name="key")
+
+
+class User(BaseModel):
+ username: str
+
+
+def get_current_user(oauth_header: str = Security(api_key)):
+ user = User(username=oauth_header)
+ return user
+
+
+@app.websocket("/ws/users/me")
+async def read_current_user(
+ websocket: WebSocket, current_user: User = Depends(get_current_user)
+):
+ await websocket.accept()
+ await websocket.send_text(current_user.username)
+
+
+def test_security_api_key_ws():
+ client = TestClient(app, cookies={"key": "secret"})
+ with client.websocket_connect("/ws/users/me") as websocket:
+ data = websocket.receive_text()
+ assert data == "secret"
+
+
+def test_security_api_key_no_key_ws():
+ client = TestClient(app)
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/users/me"):
+ pass
diff --git a/tests/test_security_api_key_header_websocket.py b/tests/test_security_api_key_header_websocket.py
new file mode 100644
index 0000000000..636a03a89e
--- /dev/null
+++ b/tests/test_security_api_key_header_websocket.py
@@ -0,0 +1,45 @@
+import pytest
+from fastapi import Depends, FastAPI, Security
+from fastapi.security import APIKeyHeader
+from fastapi.testclient import TestClient
+from pydantic import BaseModel
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+api_key = APIKeyHeader(name="key")
+
+
+class User(BaseModel):
+ username: str
+
+
+def get_current_user(oauth_header: str = Security(api_key)):
+ user = User(username=oauth_header)
+ return user
+
+
+@app.websocket("/ws/users/me")
+async def read_current_user(
+ websocket: WebSocket, current_user: User = Depends(get_current_user)
+):
+ await websocket.accept()
+ await websocket.send_text(current_user.username)
+
+
+client = TestClient(app)
+
+
+def test_security_api_key_ws():
+ with client.websocket_connect(
+ "/ws/users/me", headers={"key": "secret"}
+ ) as websocket:
+ data = websocket.receive_text()
+ assert data == "secret"
+
+
+def test_security_api_key_no_key_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/users/me"):
+ pass
diff --git a/tests/test_security_api_key_query_websocket.py b/tests/test_security_api_key_query_websocket.py
new file mode 100644
index 0000000000..5264538a69
--- /dev/null
+++ b/tests/test_security_api_key_query_websocket.py
@@ -0,0 +1,43 @@
+import pytest
+from fastapi import Depends, FastAPI, Security
+from fastapi.security import APIKeyQuery
+from fastapi.testclient import TestClient
+from pydantic import BaseModel
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+api_key = APIKeyQuery(name="key")
+
+
+class User(BaseModel):
+ username: str
+
+
+def get_current_user(oauth_header: str = Security(api_key)):
+ user = User(username=oauth_header)
+ return user
+
+
+@app.websocket("/ws/users/me")
+async def read_current_user(
+ websocket: WebSocket, current_user: User = Depends(get_current_user)
+):
+ await websocket.accept()
+ await websocket.send_text(current_user.username)
+
+
+client = TestClient(app)
+
+
+def test_security_api_key_query_ws():
+ with client.websocket_connect("/ws/users/me?key=secret") as websocket:
+ data = websocket.receive_text()
+ assert data == "secret"
+
+
+def test_security_api_key_query_no_key_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/users/me"):
+ pass
diff --git a/tests/test_security_http_base_websocket.py b/tests/test_security_http_base_websocket.py
new file mode 100644
index 0000000000..7041471ea5
--- /dev/null
+++ b/tests/test_security_http_base_websocket.py
@@ -0,0 +1,38 @@
+import pytest
+from fastapi import FastAPI, Security
+from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase
+from fastapi.testclient import TestClient
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+security = HTTPBase(scheme="Other")
+
+
+@app.websocket("/ws/users/me")
+async def read_current_user(
+ websocket: WebSocket,
+ credentials: HTTPAuthorizationCredentials = Security(security),
+):
+ await websocket.accept()
+ await websocket.send_json(
+ {"scheme": credentials.scheme, "credentials": credentials.credentials}
+ )
+
+
+client = TestClient(app)
+
+
+def test_security_http_base_ws():
+ with client.websocket_connect(
+ "/ws/users/me", headers={"Authorization": "Other foobar"}
+ ) as websocket:
+ data = websocket.receive_json()
+ assert data == {"scheme": "Other", "credentials": "foobar"}
+
+
+def test_security_http_base_no_credentials_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/users/me"):
+ pass
diff --git a/tests/test_security_http_basic_websocket.py b/tests/test_security_http_basic_websocket.py
new file mode 100644
index 0000000000..809487bb05
--- /dev/null
+++ b/tests/test_security_http_basic_websocket.py
@@ -0,0 +1,50 @@
+from base64 import b64encode
+
+import pytest
+from fastapi import FastAPI, Security
+from fastapi.security import HTTPBasic, HTTPBasicCredentials
+from fastapi.testclient import TestClient
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+security = HTTPBasic(realm="simple")
+
+
+@app.websocket("/ws/users/me")
+async def read_current_user(
+ websocket: WebSocket, credentials: HTTPBasicCredentials = Security(security)
+):
+ await websocket.accept()
+ await websocket.send_json(
+ {"username": credentials.username, "password": credentials.password}
+ )
+
+
+client = TestClient(app)
+
+
+def test_security_http_basic_ws():
+ # Build Basic header
+ payload = b64encode(b"john:secret").decode("ascii")
+ auth_header = f"Basic {payload}"
+ with client.websocket_connect(
+ "/ws/users/me", headers={"Authorization": auth_header}
+ ) as websocket:
+ data = websocket.receive_json()
+ assert data == {"username": "john", "password": "secret"}
+
+
+def test_security_http_basic_no_credentials_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/users/me"):
+ pass
+
+
+def test_security_http_basic_invalid_credentials_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect(
+ "/ws/users/me", headers={"Authorization": "Basic notbase64"}
+ ):
+ pass
diff --git a/tests/test_security_http_bearer_websocket.py b/tests/test_security_http_bearer_websocket.py
new file mode 100644
index 0000000000..ff41652a6b
--- /dev/null
+++ b/tests/test_security_http_bearer_websocket.py
@@ -0,0 +1,46 @@
+import pytest
+from fastapi import FastAPI, Security
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+from fastapi.testclient import TestClient
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+security = HTTPBearer()
+
+
+@app.websocket("/ws/users/me")
+async def read_current_user(
+ websocket: WebSocket,
+ credentials: HTTPAuthorizationCredentials = Security(security),
+):
+ await websocket.accept()
+ await websocket.send_json(
+ {"scheme": credentials.scheme, "credentials": credentials.credentials}
+ )
+
+
+client = TestClient(app)
+
+
+def test_security_http_bearer_ws():
+ with client.websocket_connect(
+ "/ws/users/me", headers={"Authorization": "Bearer foobar"}
+ ) as websocket:
+ data = websocket.receive_json()
+ assert data == {"scheme": "Bearer", "credentials": "foobar"}
+
+
+def test_security_http_bearer_no_credentials_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/users/me"):
+ pass
+
+
+def test_security_http_bearer_incorrect_scheme_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect(
+ "/ws/users/me", headers={"Authorization": "Basic notreally"}
+ ):
+ pass
diff --git a/tests/test_security_http_digest_websocket.py b/tests/test_security_http_digest_websocket.py
new file mode 100644
index 0000000000..065507a851
--- /dev/null
+++ b/tests/test_security_http_digest_websocket.py
@@ -0,0 +1,46 @@
+import pytest
+from fastapi import FastAPI, Security
+from fastapi.security import HTTPAuthorizationCredentials, HTTPDigest
+from fastapi.testclient import TestClient
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+security = HTTPDigest()
+
+
+@app.websocket("/ws/users/me")
+async def read_current_user(
+ websocket: WebSocket,
+ credentials: HTTPAuthorizationCredentials = Security(security),
+):
+ await websocket.accept()
+ await websocket.send_json(
+ {"scheme": credentials.scheme, "credentials": credentials.credentials}
+ )
+
+
+client = TestClient(app)
+
+
+def test_security_http_digest_ws():
+ with client.websocket_connect(
+ "/ws/users/me", headers={"Authorization": "Digest foobar"}
+ ) as websocket:
+ data = websocket.receive_json()
+ assert data == {"scheme": "Digest", "credentials": "foobar"}
+
+
+def test_security_http_digest_no_credentials_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/users/me"):
+ pass
+
+
+def test_security_http_digest_incorrect_scheme_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect(
+ "/ws/users/me", headers={"Authorization": "Basic notreally"}
+ ):
+ pass
diff --git a/tests/test_security_oauth2_authorization_code_bearer_websocket.py b/tests/test_security_oauth2_authorization_code_bearer_websocket.py
new file mode 100644
index 0000000000..db983607f5
--- /dev/null
+++ b/tests/test_security_oauth2_authorization_code_bearer_websocket.py
@@ -0,0 +1,45 @@
+import pytest
+from fastapi import FastAPI, Security
+from fastapi.security import OAuth2AuthorizationCodeBearer
+from fastapi.testclient import TestClient
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+oauth2_scheme = OAuth2AuthorizationCodeBearer(
+ authorizationUrl="/api/oauth/authorize",
+ tokenUrl="/api/oauth/token",
+ scopes={"read": "Read access", "write": "Write access"},
+)
+
+
+@app.websocket("/ws/admin")
+async def read_admin(websocket: WebSocket, token: str = Security(oauth2_scheme)):
+ await websocket.accept()
+ await websocket.send_text(token)
+
+
+client = TestClient(app)
+
+
+def test_security_oauth2_authorization_code_bearer_ws():
+ with client.websocket_connect(
+ "/ws/admin", headers={"Authorization": "Bearer faketoken"}
+ ) as websocket:
+ data = websocket.receive_text()
+ assert data == "faketoken"
+
+
+def test_security_oauth2_authorization_code_bearer_no_header_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/admin"):
+ pass
+
+
+def test_security_oauth2_authorization_code_bearer_wrong_scheme_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect(
+ "/ws/admin", headers={"Authorization": "Basic nope"}
+ ):
+ pass
diff --git a/tests/test_security_oauth2_password_bearer_websocket.py b/tests/test_security_oauth2_password_bearer_websocket.py
new file mode 100644
index 0000000000..8c71adace0
--- /dev/null
+++ b/tests/test_security_oauth2_password_bearer_websocket.py
@@ -0,0 +1,41 @@
+import pytest
+from fastapi import FastAPI, Security
+from fastapi.security import OAuth2PasswordBearer
+from fastapi.testclient import TestClient
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+
+@app.websocket("/ws/token")
+async def read_token(websocket: WebSocket, token: str = Security(oauth2_scheme)):
+ await websocket.accept()
+ await websocket.send_text(token)
+
+
+client = TestClient(app)
+
+
+def test_security_oauth2_password_bearer_ws():
+ with client.websocket_connect(
+ "/ws/token", headers={"Authorization": "Bearer faketoken"}
+ ) as websocket:
+ data = websocket.receive_text()
+ assert data == "faketoken"
+
+
+def test_security_oauth2_password_bearer_no_header_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/token"):
+ pass
+
+
+def test_security_oauth2_password_bearer_wrong_scheme_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect(
+ "/ws/token", headers={"Authorization": "Basic nope"}
+ ):
+ pass
diff --git a/tests/test_security_openid_connect_websocket.py b/tests/test_security_openid_connect_websocket.py
new file mode 100644
index 0000000000..f18c030ecd
--- /dev/null
+++ b/tests/test_security_openid_connect_websocket.py
@@ -0,0 +1,53 @@
+import pytest
+from fastapi import Depends, FastAPI, Security
+from fastapi.security.open_id_connect_url import OpenIdConnect
+from fastapi.testclient import TestClient
+from pydantic import BaseModel
+from starlette.testclient import WebSocketDenialResponse
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+oid = OpenIdConnect(openIdConnectUrl="/openid")
+
+
+class User(BaseModel):
+ username: str
+
+
+def get_current_user(oauth_header: str = Security(oid)):
+ user = User(username=oauth_header)
+ return user
+
+
+@app.websocket("/ws/users/me")
+async def read_current_user(
+ websocket: WebSocket, current_user: User = Depends(get_current_user)
+):
+ await websocket.accept()
+ await websocket.send_json({"username": current_user.username})
+
+
+client = TestClient(app)
+
+
+def test_security_openid_connect_ws():
+ with client.websocket_connect(
+ "/ws/users/me", headers={"Authorization": "Bearer footokenbar"}
+ ) as websocket:
+ data = websocket.receive_json()
+ assert data == {"username": "Bearer footokenbar"}
+
+
+def test_security_openid_connect_other_header_ws():
+ with client.websocket_connect(
+ "/ws/users/me", headers={"Authorization": "Other footokenbar"}
+ ) as websocket:
+ data = websocket.receive_json()
+ assert data == {"username": "Other footokenbar"}
+
+
+def test_security_openid_connect_no_header_ws():
+ with pytest.raises(WebSocketDenialResponse):
+ with client.websocket_connect("/ws/users/me"):
+ pass