This commit is contained in:
Mix 2025-12-16 21:07:31 +00:00 committed by GitHub
commit 344698a0eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 47 additions and 20 deletions

View File

@ -4,7 +4,7 @@ from annotated_doc import Doc
from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated from typing_extensions import Annotated
@ -138,7 +138,7 @@ class APIKeyQuery(APIKeyBase):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.query_params.get(self.model.name) api_key = request.query_params.get(self.model.name)
return self.check_api_key(api_key) return self.check_api_key(api_key)
@ -226,7 +226,7 @@ class APIKeyHeader(APIKeyBase):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.headers.get(self.model.name) api_key = request.headers.get(self.model.name)
return self.check_api_key(api_key) return self.check_api_key(api_key)
@ -314,6 +314,6 @@ class APIKeyCookie(APIKeyBase):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.cookies.get(self.model.name) api_key = request.cookies.get(self.model.name)
return self.check_api_key(api_key) return self.check_api_key(api_key)

View File

@ -9,7 +9,7 @@ from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel from pydantic import BaseModel
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated from typing_extensions import Annotated
@ -93,9 +93,9 @@ class HTTPBase(SecurityBase):
) )
async def __call__( async def __call__(
self, request: Request self, conn: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]: ) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization") authorization = conn.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials): if not (authorization and scheme and credentials):
if self.auto_error: if self.auto_error:
@ -203,9 +203,9 @@ class HTTPBasic(HTTPBase):
return {"WWW-Authenticate": "Basic"} return {"WWW-Authenticate": "Basic"}
async def __call__( # type: ignore async def __call__( # type: ignore
self, request: Request self, conn: HTTPConnection
) -> Optional[HTTPBasicCredentials]: ) -> Optional[HTTPBasicCredentials]:
authorization = request.headers.get("Authorization") authorization = conn.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "basic": if not authorization or scheme.lower() != "basic":
if self.auto_error: if self.auto_error:
@ -304,9 +304,9 @@ class HTTPBearer(HTTPBase):
self.auto_error = auto_error self.auto_error = auto_error
async def __call__( async def __call__(
self, request: Request self, conn: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]: ) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization") authorization = conn.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials): if not (authorization and scheme and credentials):
if self.auto_error: if self.auto_error:
@ -407,9 +407,9 @@ class HTTPDigest(HTTPBase):
self.auto_error = auto_error self.auto_error = auto_error
async def __call__( async def __call__(
self, request: Request self, conn: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]: ) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization") authorization = conn.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials): if not (authorization and scheme and credentials):
if self.auto_error: if self.auto_error:

View File

@ -7,7 +7,7 @@ from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from fastapi.param_functions import Form from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
# TODO: import from typing when deprecating Python 3.9 # TODO: import from typing when deprecating Python 3.9
@ -399,7 +399,7 @@ class OAuth2(SecurityBase):
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:
@ -506,7 +506,7 @@ class OAuth2PasswordBearer(OAuth2):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer": if not authorization or scheme.lower() != "bearer":
@ -612,7 +612,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer": if not authorization or scheme.lower() != "bearer":

View File

@ -4,7 +4,7 @@ from annotated_doc import Doc
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated from typing_extensions import Annotated
@ -85,7 +85,7 @@ class OpenIdConnect(SecurityBase):
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
async def __call__(self, request: Request) -> Optional[str]: async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:

View File

@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from fastapi import FastAPI, Security from fastapi import FastAPI, Security, WebSocket
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -18,6 +18,19 @@ def read_current_user(
return {"scheme": credentials.scheme, "credentials": credentials.credentials} return {"scheme": credentials.scheme, "credentials": credentials.credentials}
@app.websocket("/users/timeline")
async def read_user_timeline(
websocket: WebSocket,
credentials: Optional[HTTPAuthorizationCredentials] = Security(security),
):
await websocket.accept()
await websocket.send_json(
{"scheme": credentials.scheme, "credentials": credentials.credentials}
if credentials
else {"msg": "Create an account first"}
)
client = TestClient(app) client = TestClient(app)
@ -33,6 +46,20 @@ def test_security_http_base_no_credentials():
assert response.json() == {"msg": "Create an account first"} assert response.json() == {"msg": "Create an account first"}
def test_security_http_base_with_ws():
with client.websocket_connect(
"/users/timeline", headers={"Authorization": "Other foobar"}
) as websocket:
data = websocket.receive_json()
assert data == {"scheme": "Other", "credentials": "foobar"}
def test_security_http_base_with_ws_no_credentials():
with client.websocket_connect("/users/timeline") as websocket:
data = websocket.receive_json()
assert data == {"msg": "Create an account first"}
def test_openapi_schema(): def test_openapi_schema():
response = client.get("/openapi.json") response = client.get("/openapi.json")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text