Add WebSocket handling support for HTTP security dependencies

This commit is contained in:
Mix 2023-08-26 14:58:47 +08:00 committed by HexMix
parent 7132a69046
commit 70a48e59a2
5 changed files with 58 additions and 20 deletions

View File

@ -3,8 +3,9 @@ from typing import Optional
from annotated_doc import Doc from annotated_doc import Doc
from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import handle_exc_for_ws
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_403_FORBIDDEN
from typing_extensions import Annotated from typing_extensions import Annotated
@ -108,7 +109,8 @@ class APIKeyQuery(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]: @handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.query_params.get(self.model.name) api_key = request.query_params.get(self.model.name)
return self.check_api_key(api_key, self.auto_error) return self.check_api_key(api_key, self.auto_error)
@ -196,7 +198,8 @@ class APIKeyHeader(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]: @handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.headers.get(self.model.name) api_key = request.headers.get(self.model.name)
return self.check_api_key(api_key, self.auto_error) return self.check_api_key(api_key, self.auto_error)
@ -284,6 +287,7 @@ class APIKeyCookie(APIKeyBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]: @handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.cookies.get(self.model.name) api_key = request.cookies.get(self.model.name)
return self.check_api_key(api_key, self.auto_error) return self.check_api_key(api_key, self.auto_error)

View File

@ -7,9 +7,9 @@ from fastapi.exceptions import HTTPException
from fastapi.openapi.models import HTTPBase as HTTPBaseModel from fastapi.openapi.models import HTTPBase as HTTPBaseModel
from fastapi.openapi.models import HTTPBearer as HTTPBearerModel from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param, handle_exc_for_ws
from pydantic import BaseModel from pydantic import BaseModel
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from typing_extensions import Annotated from typing_extensions import Annotated
@ -80,8 +80,9 @@ class HTTPBase(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@handle_exc_for_ws
async def __call__( async def __call__(
self, request: Request self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]: ) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
@ -185,9 +186,8 @@ class HTTPBasic(HTTPBase):
self.realm = realm self.realm = realm
self.auto_error = auto_error self.auto_error = auto_error
async def __call__( # type: ignore @handle_exc_for_ws
self, request: Request async def __call__(self, request: HTTPConnection) -> Optional[HTTPBasicCredentials]:
) -> Optional[HTTPBasicCredentials]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if self.realm: if self.realm:
@ -299,8 +299,9 @@ class HTTPBearer(HTTPBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@handle_exc_for_ws
async def __call__( async def __call__(
self, request: Request self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]: ) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)
@ -401,8 +402,9 @@ class HTTPDigest(HTTPBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@handle_exc_for_ws
async def __call__( async def __call__(
self, request: Request self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]: ) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization) scheme, credentials = get_authorization_scheme_param(authorization)

View File

@ -6,8 +6,8 @@ from fastapi.openapi.models import OAuth2 as OAuth2Model
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from fastapi.param_functions import Form from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param, handle_exc_for_ws
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
# TODO: import from typing when deprecating Python 3.9 # TODO: import from typing when deprecating Python 3.9
@ -377,7 +377,8 @@ class OAuth2(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]: @handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:
@ -486,7 +487,8 @@ class OAuth2PasswordBearer(OAuth2):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> Optional[str]: @handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer": if not authorization or scheme.lower() != "bearer":
@ -596,7 +598,8 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
auto_error=auto_error, auto_error=auto_error,
) )
async def __call__(self, request: Request) -> Optional[str]: @handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization) scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer": if not authorization or scheme.lower() != "bearer":

View File

@ -3,8 +3,9 @@ from typing import Optional
from annotated_doc import Doc from annotated_doc import Doc
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.utils import handle_exc_for_ws
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import HTTPConnection
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_403_FORBIDDEN
from typing_extensions import Annotated from typing_extensions import Annotated
@ -73,7 +74,8 @@ class OpenIdConnect(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]: @handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization") authorization = request.headers.get("Authorization")
if not authorization: if not authorization:
if self.auto_error: if self.auto_error:

View File

@ -1,4 +1,10 @@
from typing import Optional, Tuple from functools import wraps
from typing import Any, Awaitable, Callable, Optional, Tuple, TypeVar
from fastapi.exceptions import HTTPException, WebSocketException
from starlette.requests import HTTPConnection
from starlette.status import WS_1008_POLICY_VIOLATION
from starlette.websockets import WebSocket
def get_authorization_scheme_param( def get_authorization_scheme_param(
@ -8,3 +14,24 @@ def get_authorization_scheme_param(
return "", "" return "", ""
scheme, _, param = authorization_header_value.partition(" ") scheme, _, param = authorization_header_value.partition(" ")
return scheme, param return scheme, param
_SecurityDepFunc = TypeVar(
"_SecurityDepFunc", bound=Callable[[Any, HTTPConnection], Awaitable]
)
def handle_exc_for_ws(func: _SecurityDepFunc) -> _SecurityDepFunc:
@wraps(func)
async def wrapper(self, request: HTTPConnection, *args, **kwargs):
try:
return await func(self, request, *args, **kwargs)
except HTTPException as e:
if not isinstance(request, WebSocket):
raise e
await request.accept()
raise WebSocketException(
code=WS_1008_POLICY_VIOLATION, reason=e.detail
) from None
return wrapper # type: ignore