This commit is contained in:
Mix 2026-03-16 10:16:58 +00:00 committed by GitHub
commit d52e84afcb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 60 additions and 27 deletions

View File

@ -4,7 +4,7 @@ from annotated_doc import Doc
from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
@ -139,8 +139,8 @@ class APIKeyQuery(APIKeyBase):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, conn: HTTPConnection) -> str | None:
api_key = request.query_params.get(self.model.name) api_key = conn.query_params.get(self.model.name)
return self.check_api_key(api_key) return self.check_api_key(api_key)
@ -227,8 +227,8 @@ class APIKeyHeader(APIKeyBase):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, conn: HTTPConnection) -> str | None:
api_key = request.headers.get(self.model.name) api_key = conn.headers.get(self.model.name)
return self.check_api_key(api_key) return self.check_api_key(api_key)
@ -315,6 +315,6 @@ class APIKeyCookie(APIKeyBase):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, conn: HTTPConnection) -> str | None:
api_key = request.cookies.get(self.model.name) api_key = conn.cookies.get(self.model.name)
return self.check_api_key(api_key) return self.check_api_key(api_key)

View File

@ -9,7 +9,7 @@ from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel from pydantic import BaseModel
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
@ -91,8 +91,10 @@ class HTTPBase(SecurityBase):
headers=self.make_authenticate_headers(), headers=self.make_authenticate_headers(),
) )
async def __call__(self, request: Request) -> HTTPAuthorizationCredentials | None: async def __call__(
authorization = request.headers.get("Authorization") self, conn: HTTPConnection
) -> HTTPAuthorizationCredentials | None:
authorization = conn.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials): if not (authorization and scheme and credentials):
if self.auto_error: if self.auto_error:
@ -200,9 +202,9 @@ class HTTPBasic(HTTPBase):
return {"WWW-Authenticate": "Basic"} return {"WWW-Authenticate": "Basic"}
async def __call__( # type: ignore async def __call__( # type: ignore
self, request: Request self, conn: HTTPConnection
) -> HTTPBasicCredentials | None: ) -> HTTPBasicCredentials | None:
authorization = request.headers.get("Authorization") authorization = conn.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "basic": if not authorization or scheme.lower() != "basic":
if self.auto_error: if self.auto_error:
@ -300,8 +302,10 @@ class HTTPBearer(HTTPBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> HTTPAuthorizationCredentials | None: async def __call__(
authorization = request.headers.get("Authorization") self, conn: HTTPConnection
) -> HTTPAuthorizationCredentials | None:
authorization = conn.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials): if not (authorization and scheme and credentials):
if self.auto_error: if self.auto_error:
@ -401,8 +405,10 @@ class HTTPDigest(HTTPBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> HTTPAuthorizationCredentials | None: async def __call__(
authorization = request.headers.get("Authorization") self, conn: HTTPConnection
) -> HTTPAuthorizationCredentials | None:
authorization = conn.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials): if not (authorization and scheme and credentials):
if self.auto_error: if self.auto_error:

View File

@ -7,7 +7,7 @@ from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from fastapi.param_functions import Form from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
@ -420,8 +420,8 @@ class OAuth2(SecurityBase):
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, conn: HTTPConnection) -> str | None:
authorization = request.headers.get("Authorization") authorization = conn.headers.get("Authorization")
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:
raise self.make_not_authenticated_error() raise self.make_not_authenticated_error()
@ -533,8 +533,8 @@ class OAuth2PasswordBearer(OAuth2):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, conn: HTTPConnection) -> str | None:
authorization = request.headers.get("Authorization") authorization = conn.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer": if not authorization or scheme.lower() != "bearer":
if self.auto_error: if self.auto_error:
@ -639,8 +639,8 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, conn: HTTPConnection) -> str | None:
authorization = request.headers.get("Authorization") authorization = conn.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer": if not authorization or scheme.lower() != "bearer":
if self.auto_error: if self.auto_error:

View File

@ -4,7 +4,7 @@ from annotated_doc import Doc
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
@ -84,8 +84,8 @@ class OpenIdConnect(SecurityBase):
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
async def __call__(self, request: Request) -> str | None: async def __call__(self, conn: HTTPConnection) -> str | None:
authorization = request.headers.get("Authorization") authorization = conn.headers.get("Authorization")
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:
raise self.make_not_authenticated_error() raise self.make_not_authenticated_error()

View File

@ -1,4 +1,4 @@
from fastapi import FastAPI, Security from fastapi import FastAPI, Security, WebSocket
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from inline_snapshot import snapshot from inline_snapshot import snapshot
@ -17,6 +17,19 @@ def read_current_user(
return {"scheme": credentials.scheme, "credentials": credentials.credentials} return {"scheme": credentials.scheme, "credentials": credentials.credentials}
@app.websocket("/users/timeline")
async def read_user_timeline(
websocket: WebSocket,
credentials: HTTPAuthorizationCredentials | None = Security(security),
):
await websocket.accept()
await websocket.send_json(
{"scheme": credentials.scheme, "credentials": credentials.credentials}
if credentials
else {"msg": "Create an account first"}
)
client = TestClient(app) client = TestClient(app)
@ -32,6 +45,20 @@ def test_security_http_base_no_credentials():
assert response.json() == {"msg": "Create an account first"} assert response.json() == {"msg": "Create an account first"}
def test_security_http_base_with_ws():
with client.websocket_connect(
"/users/timeline", headers={"Authorization": "Other foobar"}
) as websocket:
data = websocket.receive_json()
assert data == {"scheme": "Other", "credentials": "foobar"}
def test_security_http_base_with_ws_no_credentials():
with client.websocket_connect("/users/timeline") as websocket:
data = websocket.receive_json()
assert data == {"msg": "Create an account first"}
def test_openapi_schema(): def test_openapi_schema():
response = client.get("/openapi.json") response = client.get("/openapi.json")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text