mirror of https://github.com/tiangolo/fastapi.git
✨ Include route in scope to allow middleware and other tools to extract its information (#4603)
This commit is contained in:
parent
1ce16c2f40
commit
f5d7df3c6c
|
|
@ -13,6 +13,7 @@ from typing import (
|
|||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
|
@ -44,7 +45,7 @@ from starlette.concurrency import run_in_threadpool
|
|||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.routing import BaseRoute
|
||||
from starlette.routing import BaseRoute, Match
|
||||
from starlette.routing import Mount as Mount # noqa
|
||||
from starlette.routing import (
|
||||
compile_path,
|
||||
|
|
@ -53,7 +54,7 @@ from starlette.routing import (
|
|||
websocket_session,
|
||||
)
|
||||
from starlette.status import WS_1008_POLICY_VIOLATION
|
||||
from starlette.types import ASGIApp
|
||||
from starlette.types import ASGIApp, Scope
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
|
||||
|
|
@ -296,6 +297,12 @@ class APIWebSocketRoute(routing.WebSocketRoute):
|
|||
)
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
|
||||
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
|
||||
match, child_scope = super().matches(scope)
|
||||
if match != Match.NONE:
|
||||
child_scope["route"] = self
|
||||
return match, child_scope
|
||||
|
||||
|
||||
class APIRoute(routing.Route):
|
||||
def __init__(
|
||||
|
|
@ -432,6 +439,12 @@ class APIRoute(routing.Route):
|
|||
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||
)
|
||||
|
||||
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
|
||||
match, child_scope = super().matches(scope)
|
||||
if match != Match.NONE:
|
||||
child_scope["route"] = self
|
||||
return match, child_scope
|
||||
|
||||
|
||||
class APIRouter(routing.Router):
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,50 @@
|
|||
import pytest
|
||||
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
|
||||
from fastapi.routing import APIRoute, APIWebSocketRoute
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.get("/users/{user_id}")
|
||||
async def get_user(user_id: str, request: Request):
|
||||
route: APIRoute = request.scope["route"]
|
||||
return {"user_id": user_id, "path": route.path}
|
||||
|
||||
|
||||
@app.websocket("/items/{item_id}")
|
||||
async def websocket_item(item_id: str, websocket: WebSocket):
|
||||
route: APIWebSocketRoute = websocket.scope["route"]
|
||||
await websocket.accept()
|
||||
await websocket.send_json({"item_id": item_id, "path": route.path})
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_get():
|
||||
response = client.get("/users/rick")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"user_id": "rick", "path": "/users/{user_id}"}
|
||||
|
||||
|
||||
def test_invalid_method_doesnt_match():
|
||||
response = client.post("/users/rick")
|
||||
assert response.status_code == 405, response.text
|
||||
|
||||
|
||||
def test_invalid_path_doesnt_match():
|
||||
response = client.post("/usersx/rick")
|
||||
assert response.status_code == 404, response.text
|
||||
|
||||
|
||||
def test_websocket():
|
||||
with client.websocket_connect("/items/portal-gun") as websocket:
|
||||
data = websocket.receive_json()
|
||||
assert data == {"item_id": "portal-gun", "path": "/items/{item_id}"}
|
||||
|
||||
|
||||
def test_websocket_invalid_path_doesnt_match():
|
||||
with pytest.raises(WebSocketDisconnect):
|
||||
with client.websocket_connect("/itemsx/portal-gun") as websocket:
|
||||
websocket.receive_json()
|
||||
Loading…
Reference in New Issue