diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py index 496c815a7..416049970 100644 --- a/fastapi/security/api_key.py +++ b/fastapi/security/api_key.py @@ -3,8 +3,9 @@ from typing import Optional from annotated_doc import Doc from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.security.base import SecurityBase +from fastapi.security.utils import handle_exc_for_ws from starlette.exceptions import HTTPException -from starlette.requests import Request +from starlette.requests import HTTPConnection from starlette.status import HTTP_403_FORBIDDEN from typing_extensions import Annotated @@ -108,7 +109,8 @@ class APIKeyQuery(APIKeyBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error - async def __call__(self, request: Request) -> Optional[str]: + @handle_exc_for_ws + async def __call__(self, request: HTTPConnection) -> Optional[str]: api_key = request.query_params.get(self.model.name) return self.check_api_key(api_key, self.auto_error) @@ -196,7 +198,8 @@ class APIKeyHeader(APIKeyBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error - async def __call__(self, request: Request) -> Optional[str]: + @handle_exc_for_ws + async def __call__(self, request: HTTPConnection) -> Optional[str]: api_key = request.headers.get(self.model.name) return self.check_api_key(api_key, self.auto_error) @@ -284,6 +287,7 @@ class APIKeyCookie(APIKeyBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error - async def __call__(self, request: Request) -> Optional[str]: + @handle_exc_for_ws + async def __call__(self, request: HTTPConnection) -> Optional[str]: api_key = request.cookies.get(self.model.name) return self.check_api_key(api_key, self.auto_error) diff --git a/fastapi/security/http.py b/fastapi/security/http.py index 3a5985650..df2cb6b43 100644 --- a/fastapi/security/http.py +++ b/fastapi/security/http.py @@ -7,9 +7,9 @@ from fastapi.exceptions import HTTPException from fastapi.openapi.models import HTTPBase as HTTPBaseModel from fastapi.openapi.models import HTTPBearer as HTTPBearerModel from fastapi.security.base import SecurityBase -from fastapi.security.utils import get_authorization_scheme_param +from fastapi.security.utils import get_authorization_scheme_param, handle_exc_for_ws from pydantic import BaseModel -from starlette.requests import Request +from starlette.requests import HTTPConnection from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from typing_extensions import Annotated @@ -80,8 +80,9 @@ class HTTPBase(SecurityBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error + @handle_exc_for_ws async def __call__( - self, request: Request + self, request: HTTPConnection ) -> Optional[HTTPAuthorizationCredentials]: authorization = request.headers.get("Authorization") scheme, credentials = get_authorization_scheme_param(authorization) @@ -185,9 +186,8 @@ class HTTPBasic(HTTPBase): self.realm = realm self.auto_error = auto_error - async def __call__( # type: ignore - self, request: Request - ) -> Optional[HTTPBasicCredentials]: + @handle_exc_for_ws + async def __call__(self, request: HTTPConnection) -> Optional[HTTPBasicCredentials]: authorization = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) if self.realm: @@ -299,8 +299,9 @@ class HTTPBearer(HTTPBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error + @handle_exc_for_ws async def __call__( - self, request: Request + self, request: HTTPConnection ) -> Optional[HTTPAuthorizationCredentials]: authorization = request.headers.get("Authorization") scheme, credentials = get_authorization_scheme_param(authorization) @@ -401,8 +402,9 @@ class HTTPDigest(HTTPBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error + @handle_exc_for_ws async def __call__( - self, request: Request + self, request: HTTPConnection ) -> Optional[HTTPAuthorizationCredentials]: authorization = request.headers.get("Authorization") scheme, credentials = get_authorization_scheme_param(authorization) diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py index f8d97d762..8c15e905f 100644 --- a/fastapi/security/oauth2.py +++ b/fastapi/security/oauth2.py @@ -6,8 +6,8 @@ from fastapi.openapi.models import OAuth2 as OAuth2Model from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel from fastapi.param_functions import Form from fastapi.security.base import SecurityBase -from fastapi.security.utils import get_authorization_scheme_param -from starlette.requests import Request +from fastapi.security.utils import get_authorization_scheme_param, handle_exc_for_ws +from starlette.requests import HTTPConnection from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN # TODO: import from typing when deprecating Python 3.9 @@ -377,7 +377,8 @@ class OAuth2(SecurityBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error - async def __call__(self, request: Request) -> Optional[str]: + @handle_exc_for_ws + async def __call__(self, request: HTTPConnection) -> Optional[str]: authorization = request.headers.get("Authorization") if not authorization: if self.auto_error: @@ -486,7 +487,8 @@ class OAuth2PasswordBearer(OAuth2): auto_error=auto_error, ) - async def __call__(self, request: Request) -> Optional[str]: + @handle_exc_for_ws + async def __call__(self, request: HTTPConnection) -> Optional[str]: authorization = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) if not authorization or scheme.lower() != "bearer": @@ -596,7 +598,8 @@ class OAuth2AuthorizationCodeBearer(OAuth2): auto_error=auto_error, ) - async def __call__(self, request: Request) -> Optional[str]: + @handle_exc_for_ws + async def __call__(self, request: HTTPConnection) -> Optional[str]: authorization = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) if not authorization or scheme.lower() != "bearer": diff --git a/fastapi/security/open_id_connect_url.py b/fastapi/security/open_id_connect_url.py index 5e99798e6..291160193 100644 --- a/fastapi/security/open_id_connect_url.py +++ b/fastapi/security/open_id_connect_url.py @@ -3,8 +3,9 @@ from typing import Optional from annotated_doc import Doc from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel from fastapi.security.base import SecurityBase +from fastapi.security.utils import handle_exc_for_ws from starlette.exceptions import HTTPException -from starlette.requests import Request +from starlette.requests import HTTPConnection from starlette.status import HTTP_403_FORBIDDEN from typing_extensions import Annotated @@ -73,7 +74,8 @@ class OpenIdConnect(SecurityBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error - async def __call__(self, request: Request) -> Optional[str]: + @handle_exc_for_ws + async def __call__(self, request: HTTPConnection) -> Optional[str]: authorization = request.headers.get("Authorization") if not authorization: if self.auto_error: diff --git a/fastapi/security/utils.py b/fastapi/security/utils.py index fa7a450b7..2a0849303 100644 --- a/fastapi/security/utils.py +++ b/fastapi/security/utils.py @@ -1,4 +1,10 @@ -from typing import Optional, Tuple +from functools import wraps +from typing import Any, Awaitable, Callable, Optional, Tuple, TypeVar + +from fastapi.exceptions import HTTPException, WebSocketException +from starlette.requests import HTTPConnection +from starlette.status import WS_1008_POLICY_VIOLATION +from starlette.websockets import WebSocket def get_authorization_scheme_param( @@ -8,3 +14,24 @@ def get_authorization_scheme_param( return "", "" scheme, _, param = authorization_header_value.partition(" ") return scheme, param + + +_SecurityDepFunc = TypeVar( + "_SecurityDepFunc", bound=Callable[[Any, HTTPConnection], Awaitable] +) + + +def handle_exc_for_ws(func: _SecurityDepFunc) -> _SecurityDepFunc: + @wraps(func) + async def wrapper(self, request: HTTPConnection, *args, **kwargs): + try: + return await func(self, request, *args, **kwargs) + except HTTPException as e: + if not isinstance(request, WebSocket): + raise e + await request.accept() + raise WebSocketException( + code=WS_1008_POLICY_VIOLATION, reason=e.detail + ) from None + + return wrapper # type: ignore