mirror of https://github.com/tiangolo/fastapi.git
Added tags to websockets.
This commit is contained in:
parent
7c75b55580
commit
3da3827227
|
|
@ -392,12 +392,14 @@ class APIWebSocketRoute(routing.WebSocketRoute):
|
||||||
endpoint: Callable[..., Any],
|
endpoint: Callable[..., Any],
|
||||||
*,
|
*,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
|
tags: Optional[List[Union[str, Enum]]] = None,
|
||||||
dependencies: Optional[Sequence[params.Depends]] = None,
|
dependencies: Optional[Sequence[params.Depends]] = None,
|
||||||
dependency_overrides_provider: Optional[Any] = None,
|
dependency_overrides_provider: Optional[Any] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.path = path
|
self.path = path
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.name = get_name(endpoint) if name is None else name
|
self.name = get_name(endpoint) if name is None else name
|
||||||
|
self.tags: List[Union[str, Enum]] = tags or []
|
||||||
self.dependencies = list(dependencies or [])
|
self.dependencies = list(dependencies or [])
|
||||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||||
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
|
self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
|
||||||
|
|
@ -1028,6 +1030,7 @@ class APIRouter(routing.Router):
|
||||||
endpoint: Callable[..., Any],
|
endpoint: Callable[..., Any],
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
*,
|
*,
|
||||||
|
tags: Optional[List[Union[str, Enum]]] = None,
|
||||||
dependencies: Optional[Sequence[params.Depends]] = None,
|
dependencies: Optional[Sequence[params.Depends]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
current_dependencies = self.dependencies.copy()
|
current_dependencies = self.dependencies.copy()
|
||||||
|
|
@ -1038,6 +1041,7 @@ class APIRouter(routing.Router):
|
||||||
self.prefix + path,
|
self.prefix + path,
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
name=name,
|
name=name,
|
||||||
|
tags=tags,
|
||||||
dependencies=current_dependencies,
|
dependencies=current_dependencies,
|
||||||
dependency_overrides_provider=self.dependency_overrides_provider,
|
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||||
)
|
)
|
||||||
|
|
@ -1062,6 +1066,14 @@ class APIRouter(routing.Router):
|
||||||
),
|
),
|
||||||
] = None,
|
] = None,
|
||||||
*,
|
*,
|
||||||
|
tags: Annotated[
|
||||||
|
Optional[List[Union[str, Enum]]],
|
||||||
|
Doc(
|
||||||
|
"""
|
||||||
|
A list of tags to be applied to this WebSocket.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
dependencies: Annotated[
|
dependencies: Annotated[
|
||||||
Optional[Sequence[params.Depends]],
|
Optional[Sequence[params.Depends]],
|
||||||
Doc(
|
Doc(
|
||||||
|
|
@ -1104,7 +1116,7 @@ class APIRouter(routing.Router):
|
||||||
|
|
||||||
def decorator(func: DecoratedCallable) -> DecoratedCallable:
|
def decorator(func: DecoratedCallable) -> DecoratedCallable:
|
||||||
self.add_api_websocket_route(
|
self.add_api_websocket_route(
|
||||||
path, func, name=name, dependencies=dependencies
|
path, func, name=name, tags=tags, dependencies=dependencies
|
||||||
)
|
)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
@ -1344,11 +1356,18 @@ class APIRouter(routing.Router):
|
||||||
current_dependencies.extend(dependencies)
|
current_dependencies.extend(dependencies)
|
||||||
if route.dependencies:
|
if route.dependencies:
|
||||||
current_dependencies.extend(route.dependencies)
|
current_dependencies.extend(route.dependencies)
|
||||||
|
|
||||||
|
current_tags = []
|
||||||
|
if tags:
|
||||||
|
current_tags.extend(tags)
|
||||||
|
if route.tags:
|
||||||
|
current_tags.extend(route.tags)
|
||||||
self.add_api_websocket_route(
|
self.add_api_websocket_route(
|
||||||
prefix + route.path,
|
prefix + route.path,
|
||||||
route.endpoint,
|
route.endpoint,
|
||||||
dependencies=current_dependencies,
|
dependencies=current_dependencies,
|
||||||
name=route.name,
|
name=route.name,
|
||||||
|
tags=current_tags,
|
||||||
)
|
)
|
||||||
elif isinstance(route, routing.WebSocketRoute):
|
elif isinstance(route, routing.WebSocketRoute):
|
||||||
self.add_websocket_route(
|
self.add_websocket_route(
|
||||||
|
|
|
||||||
|
|
@ -103,6 +103,12 @@ class CustomError(Exception):
|
||||||
async def router_ws_custom_error(websocket: WebSocket):
|
async def router_ws_custom_error(websocket: WebSocket):
|
||||||
raise CustomError()
|
raise CustomError()
|
||||||
|
|
||||||
|
@router.websocket("/test_tags/", name='test-tags', tags=["test"])
|
||||||
|
async def router_ws_test_tags(websocket: WebSocket):
|
||||||
|
await websocket.accept()
|
||||||
|
await websocket.send_text("Hello, router with tags!")
|
||||||
|
await websocket.close()
|
||||||
|
|
||||||
|
|
||||||
def make_app(app=None, **kwargs):
|
def make_app(app=None, **kwargs):
|
||||||
app = app or FastAPI(**kwargs)
|
app = app or FastAPI(**kwargs)
|
||||||
|
|
@ -269,3 +275,11 @@ def test_depend_err_handler():
|
||||||
pass # pragma: no cover
|
pass # pragma: no cover
|
||||||
assert e.value.code == 1002
|
assert e.value.code == 1002
|
||||||
assert "foo" in e.value.reason
|
assert "foo" in e.value.reason
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket_tags():
|
||||||
|
"""
|
||||||
|
Verify that it is possible to add tags to websocket routes
|
||||||
|
"""
|
||||||
|
route = next(route for route in app.routes if route.name == 'test-tags')
|
||||||
|
assert route.tags == ["test"]
|
||||||
Loading…
Reference in New Issue