diff --git a/docs_src/route_middleware/tutorial001.py b/docs_src/route_middleware/tutorial001.py new file mode 100644 index 0000000000..1dca98dae4 --- /dev/null +++ b/docs_src/route_middleware/tutorial001.py @@ -0,0 +1,15 @@ +from fastapi import FastAPI, Request +from fastapi.route_middleware import log_route, route_middleware, verify_jwt + +app = FastAPI() + + +@app.post("/secure") +@route_middleware(verify_jwt, log_route) +async def secure_route(req: Request, is_true: bool): + return {"status": "ok", "is_true": is_true, "user": req.user} + + +@app.post("/open") +async def open_route(is_true: bool): + return {"status": "open", "is_true": is_true} diff --git a/fastapi/route_middleware.py b/fastapi/route_middleware.py new file mode 100644 index 0000000000..cba0f1260a --- /dev/null +++ b/fastapi/route_middleware.py @@ -0,0 +1,38 @@ +from functools import wraps +from typing import Callable + +from fastapi import Request + + +def route_middleware(*middlewares: Callable): + def decorator(route_func: Callable): + @wraps(route_func) + async def wrapper(*args, **kwargs): + req = kwargs.get("req") + if req is None: + raise ValueError("Route must have 'request: Request' parameter") + + for middleware in middlewares: + result = middleware(req) + if callable(getattr(result, "__await__", None)): + await result + + return await route_func(*args, **kwargs) + + return wrapper + + return decorator + + +# Example middlewares +async def verify_jwt(req: Request): + # just a mock + if not (req.query_params.get("is_true") == "true"): + req.user = {"name": "xyz", "admin": True} + from fastapi import HTTPException, status + + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid JWT") + + +def log_route(req: Request): + print(f"[LOG] Path: {req.url.path}") diff --git a/tests/test_route_middleware.py b/tests/test_route_middleware.py new file mode 100644 index 0000000000..3df3d10d1a --- /dev/null +++ b/tests/test_route_middleware.py @@ -0,0 +1,34 @@ +from fastapi import FastAPI, Request +from fastapi.route_middleware import log_route, route_middleware, verify_jwt +from fastapi.testclient import TestClient + +app = FastAPI() + + +@app.post("/secure") +@route_middleware(verify_jwt, log_route) +async def secure_route(req: Request, is_true: bool): + return {"status": "ok", "is_true": is_true, "user": req.user} + + +@app.post("/open") +async def open_route(is_true: bool): + return {"status": "open", "is_true": is_true} + + +client = TestClient(app) + + +def test_secure_route_pass(): + response = client.post("/secure?is_true=true") + assert response.status_code == 200 + + +def test_secure_route_fail(): + response = client.post("/secure?is_true=false") + assert response.status_code == 403 + + +def test_open_route(): + response = client.post("/open?is_true=false") + assert response.status_code == 200