From 30f4123ee7ef755c891e0e68d4c55c8321a7a673 Mon Sep 17 00:00:00 2001 From: philipmunch Date: Wed, 17 Dec 2025 20:19:06 +0100 Subject: [PATCH 1/2] Security dependencies now working with websockets --- fastapi/security/api_key.py | 8 +-- fastapi/security/http.py | 10 ++-- fastapi/security/oauth2.py | 8 +-- fastapi/security/open_id_connect_url.py | 4 +- .../test_security_api_key_cookie_websocket.py | 42 +++++++++++++++ .../test_security_api_key_header_websocket.py | 45 ++++++++++++++++ .../test_security_api_key_query_websocket.py | 43 +++++++++++++++ tests/test_security_http_base_websocket.py | 38 +++++++++++++ tests/test_security_http_basic_websocket.py | 50 +++++++++++++++++ tests/test_security_http_bearer_websocket.py | 46 ++++++++++++++++ tests/test_security_http_digest_websocket.py | 46 ++++++++++++++++ ...th2_authorization_code_bearer_websocket.py | 45 ++++++++++++++++ ...curity_oauth2_password_bearer_websocket.py | 41 ++++++++++++++ .../test_security_openid_connect_websocket.py | 53 +++++++++++++++++++ 14 files changed, 464 insertions(+), 15 deletions(-) create mode 100644 tests/test_security_api_key_cookie_websocket.py create mode 100644 tests/test_security_api_key_header_websocket.py create mode 100644 tests/test_security_api_key_query_websocket.py create mode 100644 tests/test_security_http_base_websocket.py create mode 100644 tests/test_security_http_basic_websocket.py create mode 100644 tests/test_security_http_bearer_websocket.py create mode 100644 tests/test_security_http_digest_websocket.py create mode 100644 tests/test_security_oauth2_authorization_code_bearer_websocket.py create mode 100644 tests/test_security_oauth2_password_bearer_websocket.py create mode 100644 tests/test_security_openid_connect_websocket.py diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py index 81c7be10d6..f64370e3e5 100644 --- a/fastapi/security/api_key.py +++ b/fastapi/security/api_key.py @@ -4,7 +4,7 @@ from annotated_doc import Doc from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.security.base import SecurityBase from starlette.exceptions import HTTPException -from starlette.requests import Request +from starlette.requests import HTTPConnection from starlette.status import HTTP_401_UNAUTHORIZED from typing_extensions import Annotated @@ -138,7 +138,7 @@ class APIKeyQuery(APIKeyBase): auto_error=auto_error, ) - async def __call__(self, request: Request) -> Optional[str]: + async def __call__(self, request: HTTPConnection) -> Optional[str]: api_key = request.query_params.get(self.model.name) return self.check_api_key(api_key) @@ -226,7 +226,7 @@ class APIKeyHeader(APIKeyBase): auto_error=auto_error, ) - async def __call__(self, request: Request) -> Optional[str]: + async def __call__(self, request: HTTPConnection) -> Optional[str]: api_key = request.headers.get(self.model.name) return self.check_api_key(api_key) @@ -314,6 +314,6 @@ class APIKeyCookie(APIKeyBase): auto_error=auto_error, ) - async def __call__(self, request: Request) -> Optional[str]: + async def __call__(self, request: HTTPConnection) -> Optional[str]: api_key = request.cookies.get(self.model.name) return self.check_api_key(api_key) diff --git a/fastapi/security/http.py b/fastapi/security/http.py index 0d1bbba3a0..a6de1ba2ba 100644 --- a/fastapi/security/http.py +++ b/fastapi/security/http.py @@ -9,7 +9,7 @@ from fastapi.openapi.models import HTTPBearer as HTTPBearerModel from fastapi.security.base import SecurityBase from fastapi.security.utils import get_authorization_scheme_param from pydantic import BaseModel -from starlette.requests import Request +from starlette.requests import HTTPConnection from starlette.status import HTTP_401_UNAUTHORIZED from typing_extensions import Annotated @@ -93,7 +93,7 @@ class HTTPBase(SecurityBase): ) async def __call__( - self, request: Request + self, request: HTTPConnection ) -> Optional[HTTPAuthorizationCredentials]: authorization = request.headers.get("Authorization") scheme, credentials = get_authorization_scheme_param(authorization) @@ -203,7 +203,7 @@ class HTTPBasic(HTTPBase): return {"WWW-Authenticate": "Basic"} async def __call__( # type: ignore - self, request: Request + self, request: HTTPConnection ) -> Optional[HTTPBasicCredentials]: authorization = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) @@ -304,7 +304,7 @@ class HTTPBearer(HTTPBase): self.auto_error = auto_error async def __call__( - self, request: Request + self, request: HTTPConnection ) -> Optional[HTTPAuthorizationCredentials]: authorization = request.headers.get("Authorization") scheme, credentials = get_authorization_scheme_param(authorization) @@ -407,7 +407,7 @@ class HTTPDigest(HTTPBase): self.auto_error = auto_error async def __call__( - self, request: Request + self, request: HTTPConnection ) -> Optional[HTTPAuthorizationCredentials]: authorization = request.headers.get("Authorization") scheme, credentials = get_authorization_scheme_param(authorization) diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py index b41b0f8778..7d666d320e 100644 --- a/fastapi/security/oauth2.py +++ b/fastapi/security/oauth2.py @@ -7,7 +7,7 @@ from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel from fastapi.param_functions import Form from fastapi.security.base import SecurityBase from fastapi.security.utils import get_authorization_scheme_param -from starlette.requests import Request +from starlette.requests import HTTPConnection from starlette.status import HTTP_401_UNAUTHORIZED # TODO: import from typing when deprecating Python 3.9 @@ -399,7 +399,7 @@ class OAuth2(SecurityBase): headers={"WWW-Authenticate": "Bearer"}, ) - async def __call__(self, request: Request) -> Optional[str]: + async def __call__(self, request: HTTPConnection) -> Optional[str]: authorization = request.headers.get("Authorization") if not authorization: if self.auto_error: @@ -506,7 +506,7 @@ class OAuth2PasswordBearer(OAuth2): auto_error=auto_error, ) - async def __call__(self, request: Request) -> Optional[str]: + async def __call__(self, request: HTTPConnection) -> Optional[str]: authorization = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) if not authorization or scheme.lower() != "bearer": @@ -612,7 +612,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2): auto_error=auto_error, ) - async def __call__(self, request: Request) -> Optional[str]: + async def __call__(self, request: HTTPConnection) -> Optional[str]: authorization = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) if not authorization or scheme.lower() != "bearer": diff --git a/fastapi/security/open_id_connect_url.py b/fastapi/security/open_id_connect_url.py index e574a56a82..4fdd7a89c7 100644 --- a/fastapi/security/open_id_connect_url.py +++ b/fastapi/security/open_id_connect_url.py @@ -4,7 +4,7 @@ from annotated_doc import Doc from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.security.base import SecurityBase from starlette.exceptions import HTTPException -from starlette.requests import Request +from starlette.requests import HTTPConnection from starlette.status import HTTP_401_UNAUTHORIZED from typing_extensions import Annotated @@ -85,7 +85,7 @@ class OpenIdConnect(SecurityBase): headers={"WWW-Authenticate": "Bearer"}, ) - async def __call__(self, request: Request) -> Optional[str]: + async def __call__(self, request: HTTPConnection) -> Optional[str]: authorization = request.headers.get("Authorization") if not authorization: if self.auto_error: diff --git a/tests/test_security_api_key_cookie_websocket.py b/tests/test_security_api_key_cookie_websocket.py new file mode 100644 index 0000000000..d4e0a4f57e --- /dev/null +++ b/tests/test_security_api_key_cookie_websocket.py @@ -0,0 +1,42 @@ +import pytest +from fastapi import Depends, FastAPI, Security +from fastapi.security import APIKeyCookie +from fastapi.testclient import TestClient +from pydantic import BaseModel +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +api_key = APIKeyCookie(name="key") + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: str = Security(api_key)): + user = User(username=oauth_header) + return user + + +@app.websocket("/ws/users/me") +async def read_current_user( + websocket: WebSocket, current_user: User = Depends(get_current_user) +): + await websocket.accept() + await websocket.send_text(current_user.username) + + +def test_security_api_key_ws(): + client = TestClient(app, cookies={"key": "secret"}) + with client.websocket_connect("/ws/users/me") as websocket: + data = websocket.receive_text() + assert data == "secret" + + +def test_security_api_key_no_key_ws(): + client = TestClient(app) + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/users/me"): + pass diff --git a/tests/test_security_api_key_header_websocket.py b/tests/test_security_api_key_header_websocket.py new file mode 100644 index 0000000000..636a03a89e --- /dev/null +++ b/tests/test_security_api_key_header_websocket.py @@ -0,0 +1,45 @@ +import pytest +from fastapi import Depends, FastAPI, Security +from fastapi.security import APIKeyHeader +from fastapi.testclient import TestClient +from pydantic import BaseModel +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +api_key = APIKeyHeader(name="key") + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: str = Security(api_key)): + user = User(username=oauth_header) + return user + + +@app.websocket("/ws/users/me") +async def read_current_user( + websocket: WebSocket, current_user: User = Depends(get_current_user) +): + await websocket.accept() + await websocket.send_text(current_user.username) + + +client = TestClient(app) + + +def test_security_api_key_ws(): + with client.websocket_connect( + "/ws/users/me", headers={"key": "secret"} + ) as websocket: + data = websocket.receive_text() + assert data == "secret" + + +def test_security_api_key_no_key_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/users/me"): + pass diff --git a/tests/test_security_api_key_query_websocket.py b/tests/test_security_api_key_query_websocket.py new file mode 100644 index 0000000000..5264538a69 --- /dev/null +++ b/tests/test_security_api_key_query_websocket.py @@ -0,0 +1,43 @@ +import pytest +from fastapi import Depends, FastAPI, Security +from fastapi.security import APIKeyQuery +from fastapi.testclient import TestClient +from pydantic import BaseModel +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +api_key = APIKeyQuery(name="key") + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: str = Security(api_key)): + user = User(username=oauth_header) + return user + + +@app.websocket("/ws/users/me") +async def read_current_user( + websocket: WebSocket, current_user: User = Depends(get_current_user) +): + await websocket.accept() + await websocket.send_text(current_user.username) + + +client = TestClient(app) + + +def test_security_api_key_query_ws(): + with client.websocket_connect("/ws/users/me?key=secret") as websocket: + data = websocket.receive_text() + assert data == "secret" + + +def test_security_api_key_query_no_key_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/users/me"): + pass diff --git a/tests/test_security_http_base_websocket.py b/tests/test_security_http_base_websocket.py new file mode 100644 index 0000000000..7041471ea5 --- /dev/null +++ b/tests/test_security_http_base_websocket.py @@ -0,0 +1,38 @@ +import pytest +from fastapi import FastAPI, Security +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase +from fastapi.testclient import TestClient +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +security = HTTPBase(scheme="Other") + + +@app.websocket("/ws/users/me") +async def read_current_user( + websocket: WebSocket, + credentials: HTTPAuthorizationCredentials = Security(security), +): + await websocket.accept() + await websocket.send_json( + {"scheme": credentials.scheme, "credentials": credentials.credentials} + ) + + +client = TestClient(app) + + +def test_security_http_base_ws(): + with client.websocket_connect( + "/ws/users/me", headers={"Authorization": "Other foobar"} + ) as websocket: + data = websocket.receive_json() + assert data == {"scheme": "Other", "credentials": "foobar"} + + +def test_security_http_base_no_credentials_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/users/me"): + pass diff --git a/tests/test_security_http_basic_websocket.py b/tests/test_security_http_basic_websocket.py new file mode 100644 index 0000000000..809487bb05 --- /dev/null +++ b/tests/test_security_http_basic_websocket.py @@ -0,0 +1,50 @@ +from base64 import b64encode + +import pytest +from fastapi import FastAPI, Security +from fastapi.security import HTTPBasic, HTTPBasicCredentials +from fastapi.testclient import TestClient +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +security = HTTPBasic(realm="simple") + + +@app.websocket("/ws/users/me") +async def read_current_user( + websocket: WebSocket, credentials: HTTPBasicCredentials = Security(security) +): + await websocket.accept() + await websocket.send_json( + {"username": credentials.username, "password": credentials.password} + ) + + +client = TestClient(app) + + +def test_security_http_basic_ws(): + # Build Basic header + payload = b64encode(b"john:secret").decode("ascii") + auth_header = f"Basic {payload}" + with client.websocket_connect( + "/ws/users/me", headers={"Authorization": auth_header} + ) as websocket: + data = websocket.receive_json() + assert data == {"username": "john", "password": "secret"} + + +def test_security_http_basic_no_credentials_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/users/me"): + pass + + +def test_security_http_basic_invalid_credentials_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect( + "/ws/users/me", headers={"Authorization": "Basic notbase64"} + ): + pass diff --git a/tests/test_security_http_bearer_websocket.py b/tests/test_security_http_bearer_websocket.py new file mode 100644 index 0000000000..ff41652a6b --- /dev/null +++ b/tests/test_security_http_bearer_websocket.py @@ -0,0 +1,46 @@ +import pytest +from fastapi import FastAPI, Security +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from fastapi.testclient import TestClient +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +security = HTTPBearer() + + +@app.websocket("/ws/users/me") +async def read_current_user( + websocket: WebSocket, + credentials: HTTPAuthorizationCredentials = Security(security), +): + await websocket.accept() + await websocket.send_json( + {"scheme": credentials.scheme, "credentials": credentials.credentials} + ) + + +client = TestClient(app) + + +def test_security_http_bearer_ws(): + with client.websocket_connect( + "/ws/users/me", headers={"Authorization": "Bearer foobar"} + ) as websocket: + data = websocket.receive_json() + assert data == {"scheme": "Bearer", "credentials": "foobar"} + + +def test_security_http_bearer_no_credentials_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/users/me"): + pass + + +def test_security_http_bearer_incorrect_scheme_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect( + "/ws/users/me", headers={"Authorization": "Basic notreally"} + ): + pass diff --git a/tests/test_security_http_digest_websocket.py b/tests/test_security_http_digest_websocket.py new file mode 100644 index 0000000000..065507a851 --- /dev/null +++ b/tests/test_security_http_digest_websocket.py @@ -0,0 +1,46 @@ +import pytest +from fastapi import FastAPI, Security +from fastapi.security import HTTPAuthorizationCredentials, HTTPDigest +from fastapi.testclient import TestClient +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +security = HTTPDigest() + + +@app.websocket("/ws/users/me") +async def read_current_user( + websocket: WebSocket, + credentials: HTTPAuthorizationCredentials = Security(security), +): + await websocket.accept() + await websocket.send_json( + {"scheme": credentials.scheme, "credentials": credentials.credentials} + ) + + +client = TestClient(app) + + +def test_security_http_digest_ws(): + with client.websocket_connect( + "/ws/users/me", headers={"Authorization": "Digest foobar"} + ) as websocket: + data = websocket.receive_json() + assert data == {"scheme": "Digest", "credentials": "foobar"} + + +def test_security_http_digest_no_credentials_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/users/me"): + pass + + +def test_security_http_digest_incorrect_scheme_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect( + "/ws/users/me", headers={"Authorization": "Basic notreally"} + ): + pass diff --git a/tests/test_security_oauth2_authorization_code_bearer_websocket.py b/tests/test_security_oauth2_authorization_code_bearer_websocket.py new file mode 100644 index 0000000000..db983607f5 --- /dev/null +++ b/tests/test_security_oauth2_authorization_code_bearer_websocket.py @@ -0,0 +1,45 @@ +import pytest +from fastapi import FastAPI, Security +from fastapi.security import OAuth2AuthorizationCodeBearer +from fastapi.testclient import TestClient +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +oauth2_scheme = OAuth2AuthorizationCodeBearer( + authorizationUrl="/api/oauth/authorize", + tokenUrl="/api/oauth/token", + scopes={"read": "Read access", "write": "Write access"}, +) + + +@app.websocket("/ws/admin") +async def read_admin(websocket: WebSocket, token: str = Security(oauth2_scheme)): + await websocket.accept() + await websocket.send_text(token) + + +client = TestClient(app) + + +def test_security_oauth2_authorization_code_bearer_ws(): + with client.websocket_connect( + "/ws/admin", headers={"Authorization": "Bearer faketoken"} + ) as websocket: + data = websocket.receive_text() + assert data == "faketoken" + + +def test_security_oauth2_authorization_code_bearer_no_header_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/admin"): + pass + + +def test_security_oauth2_authorization_code_bearer_wrong_scheme_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect( + "/ws/admin", headers={"Authorization": "Basic nope"} + ): + pass diff --git a/tests/test_security_oauth2_password_bearer_websocket.py b/tests/test_security_oauth2_password_bearer_websocket.py new file mode 100644 index 0000000000..8c71adace0 --- /dev/null +++ b/tests/test_security_oauth2_password_bearer_websocket.py @@ -0,0 +1,41 @@ +import pytest +from fastapi import FastAPI, Security +from fastapi.security import OAuth2PasswordBearer +from fastapi.testclient import TestClient +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + + +@app.websocket("/ws/token") +async def read_token(websocket: WebSocket, token: str = Security(oauth2_scheme)): + await websocket.accept() + await websocket.send_text(token) + + +client = TestClient(app) + + +def test_security_oauth2_password_bearer_ws(): + with client.websocket_connect( + "/ws/token", headers={"Authorization": "Bearer faketoken"} + ) as websocket: + data = websocket.receive_text() + assert data == "faketoken" + + +def test_security_oauth2_password_bearer_no_header_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/token"): + pass + + +def test_security_oauth2_password_bearer_wrong_scheme_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect( + "/ws/token", headers={"Authorization": "Basic nope"} + ): + pass diff --git a/tests/test_security_openid_connect_websocket.py b/tests/test_security_openid_connect_websocket.py new file mode 100644 index 0000000000..d9a0981409 --- /dev/null +++ b/tests/test_security_openid_connect_websocket.py @@ -0,0 +1,53 @@ +import pytest +from fastapi import Depends, FastAPI, Security +from fastapi.security.open_id_connect_url import OpenIdConnect +from fastapi.testclient import TestClient +from pydantic import BaseModel +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +oid = OpenIdConnect(openIdConnectUrl="/openid") + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: str = Security(oid)): + user = User(username=oauth_header) + return user + + +@app.websocket("/ws/users/me") +async def read_current_user( + websocket: WebSocket, current_user: User = Depends(get_current_user) +): + await websocket.accept() + await websocket.send_json(current_user.model_dump()) + + +client = TestClient(app) + + +def test_security_openid_connect_ws(): + with client.websocket_connect( + "/ws/users/me", headers={"Authorization": "Bearer footokenbar"} + ) as websocket: + data = websocket.receive_json() + assert data == {"username": "Bearer footokenbar"} + + +def test_security_openid_connect_other_header_ws(): + with client.websocket_connect( + "/ws/users/me", headers={"Authorization": "Other footokenbar"} + ) as websocket: + data = websocket.receive_json() + assert data == {"username": "Other footokenbar"} + + +def test_security_openid_connect_no_header_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/users/me"): + pass From bf7e2b7cabff4870976d270a6e7a3b884ed17078 Mon Sep 17 00:00:00 2001 From: philipmunch Date: Wed, 17 Dec 2025 20:19:06 +0100 Subject: [PATCH 2/2] Security dependencies now working with websockets --- fastapi/security/api_key.py | 8 +-- fastapi/security/http.py | 10 ++-- fastapi/security/oauth2.py | 8 +-- fastapi/security/open_id_connect_url.py | 4 +- .../test_security_api_key_cookie_websocket.py | 42 +++++++++++++++ .../test_security_api_key_header_websocket.py | 45 ++++++++++++++++ .../test_security_api_key_query_websocket.py | 43 +++++++++++++++ tests/test_security_http_base_websocket.py | 38 +++++++++++++ tests/test_security_http_basic_websocket.py | 50 +++++++++++++++++ tests/test_security_http_bearer_websocket.py | 46 ++++++++++++++++ tests/test_security_http_digest_websocket.py | 46 ++++++++++++++++ ...th2_authorization_code_bearer_websocket.py | 45 ++++++++++++++++ ...curity_oauth2_password_bearer_websocket.py | 41 ++++++++++++++ .../test_security_openid_connect_websocket.py | 53 +++++++++++++++++++ 14 files changed, 464 insertions(+), 15 deletions(-) create mode 100644 tests/test_security_api_key_cookie_websocket.py create mode 100644 tests/test_security_api_key_header_websocket.py create mode 100644 tests/test_security_api_key_query_websocket.py create mode 100644 tests/test_security_http_base_websocket.py create mode 100644 tests/test_security_http_basic_websocket.py create mode 100644 tests/test_security_http_bearer_websocket.py create mode 100644 tests/test_security_http_digest_websocket.py create mode 100644 tests/test_security_oauth2_authorization_code_bearer_websocket.py create mode 100644 tests/test_security_oauth2_password_bearer_websocket.py create mode 100644 tests/test_security_openid_connect_websocket.py diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py index 81c7be10d6..f64370e3e5 100644 --- a/fastapi/security/api_key.py +++ b/fastapi/security/api_key.py @@ -4,7 +4,7 @@ from annotated_doc import Doc from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.security.base import SecurityBase from starlette.exceptions import HTTPException -from starlette.requests import Request +from starlette.requests import HTTPConnection from starlette.status import HTTP_401_UNAUTHORIZED from typing_extensions import Annotated @@ -138,7 +138,7 @@ class APIKeyQuery(APIKeyBase): auto_error=auto_error, ) - async def __call__(self, request: Request) -> Optional[str]: + async def __call__(self, request: HTTPConnection) -> Optional[str]: api_key = request.query_params.get(self.model.name) return self.check_api_key(api_key) @@ -226,7 +226,7 @@ class APIKeyHeader(APIKeyBase): auto_error=auto_error, ) - async def __call__(self, request: Request) -> Optional[str]: + async def __call__(self, request: HTTPConnection) -> Optional[str]: api_key = request.headers.get(self.model.name) return self.check_api_key(api_key) @@ -314,6 +314,6 @@ class APIKeyCookie(APIKeyBase): auto_error=auto_error, ) - async def __call__(self, request: Request) -> Optional[str]: + async def __call__(self, request: HTTPConnection) -> Optional[str]: api_key = request.cookies.get(self.model.name) return self.check_api_key(api_key) diff --git a/fastapi/security/http.py b/fastapi/security/http.py index 0d1bbba3a0..a6de1ba2ba 100644 --- a/fastapi/security/http.py +++ b/fastapi/security/http.py @@ -9,7 +9,7 @@ from fastapi.openapi.models import HTTPBearer as HTTPBearerModel from fastapi.security.base import SecurityBase from fastapi.security.utils import get_authorization_scheme_param from pydantic import BaseModel -from starlette.requests import Request +from starlette.requests import HTTPConnection from starlette.status import HTTP_401_UNAUTHORIZED from typing_extensions import Annotated @@ -93,7 +93,7 @@ class HTTPBase(SecurityBase): ) async def __call__( - self, request: Request + self, request: HTTPConnection ) -> Optional[HTTPAuthorizationCredentials]: authorization = request.headers.get("Authorization") scheme, credentials = get_authorization_scheme_param(authorization) @@ -203,7 +203,7 @@ class HTTPBasic(HTTPBase): return {"WWW-Authenticate": "Basic"} async def __call__( # type: ignore - self, request: Request + self, request: HTTPConnection ) -> Optional[HTTPBasicCredentials]: authorization = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) @@ -304,7 +304,7 @@ class HTTPBearer(HTTPBase): self.auto_error = auto_error async def __call__( - self, request: Request + self, request: HTTPConnection ) -> Optional[HTTPAuthorizationCredentials]: authorization = request.headers.get("Authorization") scheme, credentials = get_authorization_scheme_param(authorization) @@ -407,7 +407,7 @@ class HTTPDigest(HTTPBase): self.auto_error = auto_error async def __call__( - self, request: Request + self, request: HTTPConnection ) -> Optional[HTTPAuthorizationCredentials]: authorization = request.headers.get("Authorization") scheme, credentials = get_authorization_scheme_param(authorization) diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py index b41b0f8778..7d666d320e 100644 --- a/fastapi/security/oauth2.py +++ b/fastapi/security/oauth2.py @@ -7,7 +7,7 @@ from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel from fastapi.param_functions import Form from fastapi.security.base import SecurityBase from fastapi.security.utils import get_authorization_scheme_param -from starlette.requests import Request +from starlette.requests import HTTPConnection from starlette.status import HTTP_401_UNAUTHORIZED # TODO: import from typing when deprecating Python 3.9 @@ -399,7 +399,7 @@ class OAuth2(SecurityBase): headers={"WWW-Authenticate": "Bearer"}, ) - async def __call__(self, request: Request) -> Optional[str]: + async def __call__(self, request: HTTPConnection) -> Optional[str]: authorization = request.headers.get("Authorization") if not authorization: if self.auto_error: @@ -506,7 +506,7 @@ class OAuth2PasswordBearer(OAuth2): auto_error=auto_error, ) - async def __call__(self, request: Request) -> Optional[str]: + async def __call__(self, request: HTTPConnection) -> Optional[str]: authorization = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) if not authorization or scheme.lower() != "bearer": @@ -612,7 +612,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2): auto_error=auto_error, ) - async def __call__(self, request: Request) -> Optional[str]: + async def __call__(self, request: HTTPConnection) -> Optional[str]: authorization = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) if not authorization or scheme.lower() != "bearer": diff --git a/fastapi/security/open_id_connect_url.py b/fastapi/security/open_id_connect_url.py index e574a56a82..4fdd7a89c7 100644 --- a/fastapi/security/open_id_connect_url.py +++ b/fastapi/security/open_id_connect_url.py @@ -4,7 +4,7 @@ from annotated_doc import Doc from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.security.base import SecurityBase from starlette.exceptions import HTTPException -from starlette.requests import Request +from starlette.requests import HTTPConnection from starlette.status import HTTP_401_UNAUTHORIZED from typing_extensions import Annotated @@ -85,7 +85,7 @@ class OpenIdConnect(SecurityBase): headers={"WWW-Authenticate": "Bearer"}, ) - async def __call__(self, request: Request) -> Optional[str]: + async def __call__(self, request: HTTPConnection) -> Optional[str]: authorization = request.headers.get("Authorization") if not authorization: if self.auto_error: diff --git a/tests/test_security_api_key_cookie_websocket.py b/tests/test_security_api_key_cookie_websocket.py new file mode 100644 index 0000000000..d4e0a4f57e --- /dev/null +++ b/tests/test_security_api_key_cookie_websocket.py @@ -0,0 +1,42 @@ +import pytest +from fastapi import Depends, FastAPI, Security +from fastapi.security import APIKeyCookie +from fastapi.testclient import TestClient +from pydantic import BaseModel +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +api_key = APIKeyCookie(name="key") + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: str = Security(api_key)): + user = User(username=oauth_header) + return user + + +@app.websocket("/ws/users/me") +async def read_current_user( + websocket: WebSocket, current_user: User = Depends(get_current_user) +): + await websocket.accept() + await websocket.send_text(current_user.username) + + +def test_security_api_key_ws(): + client = TestClient(app, cookies={"key": "secret"}) + with client.websocket_connect("/ws/users/me") as websocket: + data = websocket.receive_text() + assert data == "secret" + + +def test_security_api_key_no_key_ws(): + client = TestClient(app) + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/users/me"): + pass diff --git a/tests/test_security_api_key_header_websocket.py b/tests/test_security_api_key_header_websocket.py new file mode 100644 index 0000000000..636a03a89e --- /dev/null +++ b/tests/test_security_api_key_header_websocket.py @@ -0,0 +1,45 @@ +import pytest +from fastapi import Depends, FastAPI, Security +from fastapi.security import APIKeyHeader +from fastapi.testclient import TestClient +from pydantic import BaseModel +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +api_key = APIKeyHeader(name="key") + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: str = Security(api_key)): + user = User(username=oauth_header) + return user + + +@app.websocket("/ws/users/me") +async def read_current_user( + websocket: WebSocket, current_user: User = Depends(get_current_user) +): + await websocket.accept() + await websocket.send_text(current_user.username) + + +client = TestClient(app) + + +def test_security_api_key_ws(): + with client.websocket_connect( + "/ws/users/me", headers={"key": "secret"} + ) as websocket: + data = websocket.receive_text() + assert data == "secret" + + +def test_security_api_key_no_key_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/users/me"): + pass diff --git a/tests/test_security_api_key_query_websocket.py b/tests/test_security_api_key_query_websocket.py new file mode 100644 index 0000000000..5264538a69 --- /dev/null +++ b/tests/test_security_api_key_query_websocket.py @@ -0,0 +1,43 @@ +import pytest +from fastapi import Depends, FastAPI, Security +from fastapi.security import APIKeyQuery +from fastapi.testclient import TestClient +from pydantic import BaseModel +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +api_key = APIKeyQuery(name="key") + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: str = Security(api_key)): + user = User(username=oauth_header) + return user + + +@app.websocket("/ws/users/me") +async def read_current_user( + websocket: WebSocket, current_user: User = Depends(get_current_user) +): + await websocket.accept() + await websocket.send_text(current_user.username) + + +client = TestClient(app) + + +def test_security_api_key_query_ws(): + with client.websocket_connect("/ws/users/me?key=secret") as websocket: + data = websocket.receive_text() + assert data == "secret" + + +def test_security_api_key_query_no_key_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/users/me"): + pass diff --git a/tests/test_security_http_base_websocket.py b/tests/test_security_http_base_websocket.py new file mode 100644 index 0000000000..7041471ea5 --- /dev/null +++ b/tests/test_security_http_base_websocket.py @@ -0,0 +1,38 @@ +import pytest +from fastapi import FastAPI, Security +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase +from fastapi.testclient import TestClient +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +security = HTTPBase(scheme="Other") + + +@app.websocket("/ws/users/me") +async def read_current_user( + websocket: WebSocket, + credentials: HTTPAuthorizationCredentials = Security(security), +): + await websocket.accept() + await websocket.send_json( + {"scheme": credentials.scheme, "credentials": credentials.credentials} + ) + + +client = TestClient(app) + + +def test_security_http_base_ws(): + with client.websocket_connect( + "/ws/users/me", headers={"Authorization": "Other foobar"} + ) as websocket: + data = websocket.receive_json() + assert data == {"scheme": "Other", "credentials": "foobar"} + + +def test_security_http_base_no_credentials_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/users/me"): + pass diff --git a/tests/test_security_http_basic_websocket.py b/tests/test_security_http_basic_websocket.py new file mode 100644 index 0000000000..809487bb05 --- /dev/null +++ b/tests/test_security_http_basic_websocket.py @@ -0,0 +1,50 @@ +from base64 import b64encode + +import pytest +from fastapi import FastAPI, Security +from fastapi.security import HTTPBasic, HTTPBasicCredentials +from fastapi.testclient import TestClient +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +security = HTTPBasic(realm="simple") + + +@app.websocket("/ws/users/me") +async def read_current_user( + websocket: WebSocket, credentials: HTTPBasicCredentials = Security(security) +): + await websocket.accept() + await websocket.send_json( + {"username": credentials.username, "password": credentials.password} + ) + + +client = TestClient(app) + + +def test_security_http_basic_ws(): + # Build Basic header + payload = b64encode(b"john:secret").decode("ascii") + auth_header = f"Basic {payload}" + with client.websocket_connect( + "/ws/users/me", headers={"Authorization": auth_header} + ) as websocket: + data = websocket.receive_json() + assert data == {"username": "john", "password": "secret"} + + +def test_security_http_basic_no_credentials_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/users/me"): + pass + + +def test_security_http_basic_invalid_credentials_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect( + "/ws/users/me", headers={"Authorization": "Basic notbase64"} + ): + pass diff --git a/tests/test_security_http_bearer_websocket.py b/tests/test_security_http_bearer_websocket.py new file mode 100644 index 0000000000..ff41652a6b --- /dev/null +++ b/tests/test_security_http_bearer_websocket.py @@ -0,0 +1,46 @@ +import pytest +from fastapi import FastAPI, Security +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from fastapi.testclient import TestClient +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +security = HTTPBearer() + + +@app.websocket("/ws/users/me") +async def read_current_user( + websocket: WebSocket, + credentials: HTTPAuthorizationCredentials = Security(security), +): + await websocket.accept() + await websocket.send_json( + {"scheme": credentials.scheme, "credentials": credentials.credentials} + ) + + +client = TestClient(app) + + +def test_security_http_bearer_ws(): + with client.websocket_connect( + "/ws/users/me", headers={"Authorization": "Bearer foobar"} + ) as websocket: + data = websocket.receive_json() + assert data == {"scheme": "Bearer", "credentials": "foobar"} + + +def test_security_http_bearer_no_credentials_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/users/me"): + pass + + +def test_security_http_bearer_incorrect_scheme_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect( + "/ws/users/me", headers={"Authorization": "Basic notreally"} + ): + pass diff --git a/tests/test_security_http_digest_websocket.py b/tests/test_security_http_digest_websocket.py new file mode 100644 index 0000000000..065507a851 --- /dev/null +++ b/tests/test_security_http_digest_websocket.py @@ -0,0 +1,46 @@ +import pytest +from fastapi import FastAPI, Security +from fastapi.security import HTTPAuthorizationCredentials, HTTPDigest +from fastapi.testclient import TestClient +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +security = HTTPDigest() + + +@app.websocket("/ws/users/me") +async def read_current_user( + websocket: WebSocket, + credentials: HTTPAuthorizationCredentials = Security(security), +): + await websocket.accept() + await websocket.send_json( + {"scheme": credentials.scheme, "credentials": credentials.credentials} + ) + + +client = TestClient(app) + + +def test_security_http_digest_ws(): + with client.websocket_connect( + "/ws/users/me", headers={"Authorization": "Digest foobar"} + ) as websocket: + data = websocket.receive_json() + assert data == {"scheme": "Digest", "credentials": "foobar"} + + +def test_security_http_digest_no_credentials_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/users/me"): + pass + + +def test_security_http_digest_incorrect_scheme_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect( + "/ws/users/me", headers={"Authorization": "Basic notreally"} + ): + pass diff --git a/tests/test_security_oauth2_authorization_code_bearer_websocket.py b/tests/test_security_oauth2_authorization_code_bearer_websocket.py new file mode 100644 index 0000000000..db983607f5 --- /dev/null +++ b/tests/test_security_oauth2_authorization_code_bearer_websocket.py @@ -0,0 +1,45 @@ +import pytest +from fastapi import FastAPI, Security +from fastapi.security import OAuth2AuthorizationCodeBearer +from fastapi.testclient import TestClient +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +oauth2_scheme = OAuth2AuthorizationCodeBearer( + authorizationUrl="/api/oauth/authorize", + tokenUrl="/api/oauth/token", + scopes={"read": "Read access", "write": "Write access"}, +) + + +@app.websocket("/ws/admin") +async def read_admin(websocket: WebSocket, token: str = Security(oauth2_scheme)): + await websocket.accept() + await websocket.send_text(token) + + +client = TestClient(app) + + +def test_security_oauth2_authorization_code_bearer_ws(): + with client.websocket_connect( + "/ws/admin", headers={"Authorization": "Bearer faketoken"} + ) as websocket: + data = websocket.receive_text() + assert data == "faketoken" + + +def test_security_oauth2_authorization_code_bearer_no_header_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/admin"): + pass + + +def test_security_oauth2_authorization_code_bearer_wrong_scheme_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect( + "/ws/admin", headers={"Authorization": "Basic nope"} + ): + pass diff --git a/tests/test_security_oauth2_password_bearer_websocket.py b/tests/test_security_oauth2_password_bearer_websocket.py new file mode 100644 index 0000000000..8c71adace0 --- /dev/null +++ b/tests/test_security_oauth2_password_bearer_websocket.py @@ -0,0 +1,41 @@ +import pytest +from fastapi import FastAPI, Security +from fastapi.security import OAuth2PasswordBearer +from fastapi.testclient import TestClient +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + + +@app.websocket("/ws/token") +async def read_token(websocket: WebSocket, token: str = Security(oauth2_scheme)): + await websocket.accept() + await websocket.send_text(token) + + +client = TestClient(app) + + +def test_security_oauth2_password_bearer_ws(): + with client.websocket_connect( + "/ws/token", headers={"Authorization": "Bearer faketoken"} + ) as websocket: + data = websocket.receive_text() + assert data == "faketoken" + + +def test_security_oauth2_password_bearer_no_header_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/token"): + pass + + +def test_security_oauth2_password_bearer_wrong_scheme_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect( + "/ws/token", headers={"Authorization": "Basic nope"} + ): + pass diff --git a/tests/test_security_openid_connect_websocket.py b/tests/test_security_openid_connect_websocket.py new file mode 100644 index 0000000000..f18c030ecd --- /dev/null +++ b/tests/test_security_openid_connect_websocket.py @@ -0,0 +1,53 @@ +import pytest +from fastapi import Depends, FastAPI, Security +from fastapi.security.open_id_connect_url import OpenIdConnect +from fastapi.testclient import TestClient +from pydantic import BaseModel +from starlette.testclient import WebSocketDenialResponse +from starlette.websockets import WebSocket + +app = FastAPI() + +oid = OpenIdConnect(openIdConnectUrl="/openid") + + +class User(BaseModel): + username: str + + +def get_current_user(oauth_header: str = Security(oid)): + user = User(username=oauth_header) + return user + + +@app.websocket("/ws/users/me") +async def read_current_user( + websocket: WebSocket, current_user: User = Depends(get_current_user) +): + await websocket.accept() + await websocket.send_json({"username": current_user.username}) + + +client = TestClient(app) + + +def test_security_openid_connect_ws(): + with client.websocket_connect( + "/ws/users/me", headers={"Authorization": "Bearer footokenbar"} + ) as websocket: + data = websocket.receive_json() + assert data == {"username": "Bearer footokenbar"} + + +def test_security_openid_connect_other_header_ws(): + with client.websocket_connect( + "/ws/users/me", headers={"Authorization": "Other footokenbar"} + ) as websocket: + data = websocket.receive_json() + assert data == {"username": "Other footokenbar"} + + +def test_security_openid_connect_no_header_ws(): + with pytest.raises(WebSocketDenialResponse): + with client.websocket_connect("/ws/users/me"): + pass