mirror of https://github.com/tiangolo/fastapi.git
✨ Add support for `dependencies` in WebSocket routes (#4534)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
parent
ee96a099d8
commit
d8b8f211e8
|
|
@ -401,15 +401,34 @@ class FastAPI(Starlette):
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def add_api_websocket_route(
|
def add_api_websocket_route(
|
||||||
self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
|
self,
|
||||||
|
path: str,
|
||||||
|
endpoint: Callable[..., Any],
|
||||||
|
name: Optional[str] = None,
|
||||||
|
*,
|
||||||
|
dependencies: Optional[Sequence[Depends]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.router.add_api_websocket_route(path, endpoint, name=name)
|
self.router.add_api_websocket_route(
|
||||||
|
path,
|
||||||
|
endpoint,
|
||||||
|
name=name,
|
||||||
|
dependencies=dependencies,
|
||||||
|
)
|
||||||
|
|
||||||
def websocket(
|
def websocket(
|
||||||
self, path: str, name: Optional[str] = None
|
self,
|
||||||
|
path: str,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
*,
|
||||||
|
dependencies: Optional[Sequence[Depends]] = None,
|
||||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||||
def decorator(func: DecoratedCallable) -> DecoratedCallable:
|
def decorator(func: DecoratedCallable) -> DecoratedCallable:
|
||||||
self.add_api_websocket_route(path, func, name=name)
|
self.add_api_websocket_route(
|
||||||
|
path,
|
||||||
|
func,
|
||||||
|
name=name,
|
||||||
|
dependencies=dependencies,
|
||||||
|
)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
|
||||||
|
|
@ -296,13 +296,21 @@ class APIWebSocketRoute(routing.WebSocketRoute):
|
||||||
endpoint: Callable[..., Any],
|
endpoint: Callable[..., Any],
|
||||||
*,
|
*,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
|
dependencies: Optional[Sequence[params.Depends]] = None,
|
||||||
dependency_overrides_provider: Optional[Any] = None,
|
dependency_overrides_provider: Optional[Any] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.path = path
|
self.path = path
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.name = get_name(endpoint) if name is None else name
|
self.name = get_name(endpoint) if name is None else name
|
||||||
|
self.dependencies = list(dependencies or [])
|
||||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||||
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
|
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
|
||||||
|
for depends in self.dependencies[::-1]:
|
||||||
|
self.dependant.dependencies.insert(
|
||||||
|
0,
|
||||||
|
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
|
||||||
|
)
|
||||||
|
|
||||||
self.app = websocket_session(
|
self.app = websocket_session(
|
||||||
get_websocket_app(
|
get_websocket_app(
|
||||||
dependant=self.dependant,
|
dependant=self.dependant,
|
||||||
|
|
@ -416,10 +424,7 @@ class APIRoute(routing.Route):
|
||||||
else:
|
else:
|
||||||
self.response_field = None # type: ignore
|
self.response_field = None # type: ignore
|
||||||
self.secure_cloned_response_field = None
|
self.secure_cloned_response_field = None
|
||||||
if dependencies:
|
self.dependencies = list(dependencies or [])
|
||||||
self.dependencies = list(dependencies)
|
|
||||||
else:
|
|
||||||
self.dependencies = []
|
|
||||||
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
|
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
|
||||||
# if a "form feed" character (page break) is found in the description text,
|
# if a "form feed" character (page break) is found in the description text,
|
||||||
# truncate description text to the content preceding the first "form feed"
|
# truncate description text to the content preceding the first "form feed"
|
||||||
|
|
@ -514,7 +519,7 @@ class APIRouter(routing.Router):
|
||||||
), "A path prefix must not end with '/', as the routes will start with '/'"
|
), "A path prefix must not end with '/', as the routes will start with '/'"
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
self.tags: List[Union[str, Enum]] = tags or []
|
self.tags: List[Union[str, Enum]] = tags or []
|
||||||
self.dependencies = list(dependencies or []) or []
|
self.dependencies = list(dependencies or [])
|
||||||
self.deprecated = deprecated
|
self.deprecated = deprecated
|
||||||
self.include_in_schema = include_in_schema
|
self.include_in_schema = include_in_schema
|
||||||
self.responses = responses or {}
|
self.responses = responses or {}
|
||||||
|
|
@ -688,21 +693,37 @@ class APIRouter(routing.Router):
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def add_api_websocket_route(
|
def add_api_websocket_route(
|
||||||
self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
|
self,
|
||||||
|
path: str,
|
||||||
|
endpoint: Callable[..., Any],
|
||||||
|
name: Optional[str] = None,
|
||||||
|
*,
|
||||||
|
dependencies: Optional[Sequence[params.Depends]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
current_dependencies = self.dependencies.copy()
|
||||||
|
if dependencies:
|
||||||
|
current_dependencies.extend(dependencies)
|
||||||
|
|
||||||
route = APIWebSocketRoute(
|
route = APIWebSocketRoute(
|
||||||
self.prefix + path,
|
self.prefix + path,
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
name=name,
|
name=name,
|
||||||
|
dependencies=current_dependencies,
|
||||||
dependency_overrides_provider=self.dependency_overrides_provider,
|
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||||
)
|
)
|
||||||
self.routes.append(route)
|
self.routes.append(route)
|
||||||
|
|
||||||
def websocket(
|
def websocket(
|
||||||
self, path: str, name: Optional[str] = None
|
self,
|
||||||
|
path: str,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
*,
|
||||||
|
dependencies: Optional[Sequence[params.Depends]] = None,
|
||||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||||
def decorator(func: DecoratedCallable) -> DecoratedCallable:
|
def decorator(func: DecoratedCallable) -> DecoratedCallable:
|
||||||
self.add_api_websocket_route(path, func, name=name)
|
self.add_api_websocket_route(
|
||||||
|
path, func, name=name, dependencies=dependencies
|
||||||
|
)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
@ -817,8 +838,16 @@ class APIRouter(routing.Router):
|
||||||
name=route.name,
|
name=route.name,
|
||||||
)
|
)
|
||||||
elif isinstance(route, APIWebSocketRoute):
|
elif isinstance(route, APIWebSocketRoute):
|
||||||
|
current_dependencies = []
|
||||||
|
if dependencies:
|
||||||
|
current_dependencies.extend(dependencies)
|
||||||
|
if route.dependencies:
|
||||||
|
current_dependencies.extend(route.dependencies)
|
||||||
self.add_api_websocket_route(
|
self.add_api_websocket_route(
|
||||||
prefix + route.path, route.endpoint, name=route.name
|
prefix + route.path,
|
||||||
|
route.endpoint,
|
||||||
|
dependencies=current_dependencies,
|
||||||
|
name=route.name,
|
||||||
)
|
)
|
||||||
elif isinstance(route, routing.WebSocketRoute):
|
elif isinstance(route, routing.WebSocketRoute):
|
||||||
self.add_websocket_route(
|
self.add_websocket_route(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,73 @@
|
||||||
|
import json
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, FastAPI, WebSocket
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
|
||||||
|
def dependency_list() -> List[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
DepList = Annotated[List[str], Depends(dependency_list)]
|
||||||
|
|
||||||
|
|
||||||
|
def create_dependency(name: str):
|
||||||
|
def fun(deps: DepList):
|
||||||
|
deps.append(name)
|
||||||
|
|
||||||
|
return Depends(fun)
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(dependencies=[create_dependency("router")])
|
||||||
|
prefix_router = APIRouter(dependencies=[create_dependency("prefix_router")])
|
||||||
|
app = FastAPI(dependencies=[create_dependency("app")])
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/", dependencies=[create_dependency("index")])
|
||||||
|
async def index(websocket: WebSocket, deps: DepList):
|
||||||
|
await websocket.accept()
|
||||||
|
await websocket.send_text(json.dumps(deps))
|
||||||
|
await websocket.close()
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/router", dependencies=[create_dependency("routerindex")])
|
||||||
|
async def routerindex(websocket: WebSocket, deps: DepList):
|
||||||
|
await websocket.accept()
|
||||||
|
await websocket.send_text(json.dumps(deps))
|
||||||
|
await websocket.close()
|
||||||
|
|
||||||
|
|
||||||
|
@prefix_router.websocket("/", dependencies=[create_dependency("routerprefixindex")])
|
||||||
|
async def routerprefixindex(websocket: WebSocket, deps: DepList):
|
||||||
|
await websocket.accept()
|
||||||
|
await websocket.send_text(json.dumps(deps))
|
||||||
|
await websocket.close()
|
||||||
|
|
||||||
|
|
||||||
|
app.include_router(router, dependencies=[create_dependency("router2")])
|
||||||
|
app.include_router(
|
||||||
|
prefix_router, prefix="/prefix", dependencies=[create_dependency("prefix_router2")]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_index():
|
||||||
|
client = TestClient(app)
|
||||||
|
with client.websocket_connect("/") as websocket:
|
||||||
|
data = json.loads(websocket.receive_text())
|
||||||
|
assert data == ["app", "index"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_routerindex():
|
||||||
|
client = TestClient(app)
|
||||||
|
with client.websocket_connect("/router") as websocket:
|
||||||
|
data = json.loads(websocket.receive_text())
|
||||||
|
assert data == ["app", "router2", "router", "routerindex"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_routerprefixindex():
|
||||||
|
client = TestClient(app)
|
||||||
|
with client.websocket_connect("/prefix/") as websocket:
|
||||||
|
data = json.loads(websocket.receive_text())
|
||||||
|
assert data == ["app", "prefix_router2", "prefix_router", "routerprefixindex"]
|
||||||
Loading…
Reference in New Issue