🎨 [pre-commit.ci] Auto format from pre-commit.com hooks

This commit is contained in:
pre-commit-ci[bot] 2024-05-01 10:22:21 +00:00
parent 4c5c0f60d2
commit f644f72306
1 changed files with 16 additions and 8 deletions

View File

@ -1,11 +1,13 @@
from typing import Any, Callable
from functools import update_wrapper
from typing import Any, Callable
from fastapi import Depends, FastAPI
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.routing import APIRoute
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.testclient import TestClient
from starlette.exceptions import HTTPException
class EndpointWrapper(Callable[..., Any]):
def __init__(self, endpoint: Callable[..., Any]):
self.endpoint = endpoint
@ -14,39 +16,45 @@ class EndpointWrapper(Callable[..., Any]):
async def __call__(self, *args, **kwargs):
return await self.endpoint(*args, **kwargs)
def dummy_secruity_check(token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
if token.credentials != "fake-token":
raise HTTPException(status_code=401, detail="Unauthorized")
def protect(endpoint: Callable[..., Any]):
if not isinstance(endpoint, EndpointWrapper):
endpoint = EndpointWrapper(endpoint)
endpoint.protected = True
return endpoint
class CustomAPIRoute(APIRoute):
def __init__(self, path: str, endpoint: Callable[..., Any], dependencies=None, **kwargs) -> None:
def __init__(
self, path: str, endpoint: Callable[..., Any], dependencies=None, **kwargs
) -> None:
if dependencies is None:
dependencies = []
if (
isinstance(endpoint, EndpointWrapper)
and endpoint.protected
):
if isinstance(endpoint, EndpointWrapper) and endpoint.protected:
dependencies.append(Depends(dummy_secruity_check))
super().__init__(path, endpoint, dependencies=dependencies, **kwargs)
app = FastAPI()
app.router.route_class = CustomAPIRoute
@app.get("/protected")
@protect
async def protected_route():
return {"message": "This is a protected route"}
client = TestClient(app)
def test_protected_route():
response = client.get("/protected")
assert response.status_code == 403