mirror of https://github.com/tiangolo/fastapi.git
test: add test for fix/allow-callable_get_request_handler
This commit is contained in:
parent
f8074c72d9
commit
4c5c0f60d2
|
|
@ -0,0 +1,59 @@
|
|||
from typing import Any, Callable
|
||||
from functools import update_wrapper
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from fastapi.routing import APIRoute
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
class EndpointWrapper(Callable[..., Any]):
|
||||
def __init__(self, endpoint: Callable[..., Any]):
|
||||
self.endpoint = endpoint
|
||||
self.protected = False
|
||||
update_wrapper(self, endpoint)
|
||||
|
||||
async def __call__(self, *args, **kwargs):
|
||||
return await self.endpoint(*args, **kwargs)
|
||||
|
||||
def dummy_secruity_check(token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
|
||||
if token.credentials != "fake-token":
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
def protect(endpoint: Callable[..., Any]):
|
||||
if not isinstance(endpoint, EndpointWrapper):
|
||||
endpoint = EndpointWrapper(endpoint)
|
||||
endpoint.protected = True
|
||||
return endpoint
|
||||
|
||||
class CustomAPIRoute(APIRoute):
|
||||
def __init__(self, path: str, endpoint: Callable[..., Any], dependencies=None, **kwargs) -> None:
|
||||
if dependencies is None:
|
||||
dependencies = []
|
||||
if (
|
||||
isinstance(endpoint, EndpointWrapper)
|
||||
and endpoint.protected
|
||||
):
|
||||
dependencies.append(Depends(dummy_secruity_check))
|
||||
super().__init__(path, endpoint, dependencies=dependencies, **kwargs)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app.router.route_class = CustomAPIRoute
|
||||
|
||||
@app.get("/protected")
|
||||
@protect
|
||||
async def protected_route():
|
||||
return {"message": "This is a protected route"}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
def test_protected_route():
|
||||
response = client.get("/protected")
|
||||
assert response.status_code == 403
|
||||
|
||||
response = client.get("/protected", headers={"Authorization": "Bearer some-token"})
|
||||
assert response.status_code == 401
|
||||
|
||||
response = client.get("/protected", headers={"Authorization": "Bearer fake-token"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "This is a protected route"}
|
||||
Loading…
Reference in New Issue