From 4f418eefe0d21a703289fd3521cf1e3b59252e81 Mon Sep 17 00:00:00 2001 From: John Mahoney Date: Mon, 21 Apr 2025 12:17:41 -0500 Subject: [PATCH] Refactor some stuff --- fastapi/applications.py | 11 ++++++ fastapi/routing.py | 36 +++++++++++-------- .../test_bigger_applications/test_main.py | 2 +- tests/test_ws_router.py | 35 +++++++++++++----- 4 files changed, 59 insertions(+), 25 deletions(-) diff --git a/fastapi/applications.py b/fastapi/applications.py index 6d427cdc2..5c2d31a5d 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -1179,12 +1179,14 @@ class FastAPI(Starlette): name: Optional[str] = None, *, dependencies: Optional[Sequence[Depends]] = None, + tags: Optional[List[Union[str, Enum]]] = None, ) -> None: self.router.add_api_websocket_route( path, endpoint, name=name, dependencies=dependencies, + tags=tags, ) def websocket( @@ -1218,6 +1220,14 @@ class FastAPI(Starlette): """ ), ] = None, + tags: Annotated[ + Optional[List[Union[str, Enum]]], + Doc( + """ + A list of tags to be applied to this WebSocket. + """ + ) + ] = None ) -> Callable[[DecoratedCallable], DecoratedCallable]: """ Decorate a WebSocket function. @@ -1247,6 +1257,7 @@ class FastAPI(Starlette): func, name=name, dependencies=dependencies, + tags=tags, ) return func diff --git a/fastapi/routing.py b/fastapi/routing.py index 9ee4f0d97..dfd206099 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -11,6 +11,7 @@ from typing import ( Callable, Coroutine, Dict, + Iterable, List, Mapping, Optional, @@ -920,9 +921,7 @@ class APIRouter(routing.Router): current_response_class = get_value_or_default( response_class, self.default_response_class ) - current_tags = self.tags.copy() - if tags: - current_tags.extend(tags) + current_tags = self.combine_tags(tags or []) current_dependencies = self.dependencies.copy() if dependencies: current_dependencies.extend(dependencies) @@ -937,7 +936,7 @@ class APIRouter(routing.Router): endpoint=endpoint, response_model=response_model, status_code=status_code, - tags=current_tags, + tags=list(current_tags), dependencies=current_dependencies, summary=summary, description=description, @@ -1037,11 +1036,13 @@ class APIRouter(routing.Router): if dependencies: current_dependencies.extend(dependencies) + current_tags = self.combine_tags(tags) + route = APIWebSocketRoute( self.prefix + path, endpoint=endpoint, name=name, - tags=tags, + tags=current_tags, dependencies=current_dependencies, dependency_overrides_provider=self.dependency_overrides_provider, ) @@ -1290,11 +1291,7 @@ class APIRouter(routing.Router): default_response_class, self.default_response_class, ) - current_tags = [] - if tags: - current_tags.extend(tags) - if route.tags: - current_tags.extend(route.tags) + current_tags = self.combine_tags(tags, route) current_dependencies: List[params.Depends] = [] if dependencies: current_dependencies.extend(dependencies) @@ -1357,11 +1354,7 @@ class APIRouter(routing.Router): if route.dependencies: current_dependencies.extend(route.dependencies) - current_tags = [] - if tags: - current_tags.extend(tags) - if route.tags: - current_tags.extend(route.tags) + current_tags = self.combine_tags(tags, route) self.add_api_websocket_route( prefix + route.path, route.endpoint, @@ -4456,3 +4449,16 @@ class APIRouter(routing.Router): return func return decorator + + def combine_tags(self, *entities): + tags = set(self.tags or []) + for entity in entities: + if entity is None: + continue + if isinstance(entity, str): + tags.add(entity) + elif isinstance(entity, Iterable): + tags.update(entity) + elif hasattr(entity, "tags"): + tags = tags.union(entity.tags) + return sorted(tags) diff --git a/tests/test_tutorial/test_bigger_applications/test_main.py b/tests/test_tutorial/test_bigger_applications/test_main.py index fe40fad7d..6deedb57e 100644 --- a/tests/test_tutorial/test_bigger_applications/test_main.py +++ b/tests/test_tutorial/test_bigger_applications/test_main.py @@ -580,7 +580,7 @@ def test_openapi_schema(client: TestClient): }, }, "put": { - "tags": ["items", "custom"], + "tags": ["custom", "items"], "summary": "Update Item", "operationId": "update_item_items__item_id__put", "parameters": [ diff --git a/tests/test_ws_router.py b/tests/test_ws_router.py index bc3988326..740a3ccae 100644 --- a/tests/test_ws_router.py +++ b/tests/test_ws_router.py @@ -13,9 +13,9 @@ from fastapi import ( from fastapi.middleware import Middleware from fastapi.testclient import TestClient -router = APIRouter() -prefix_router = APIRouter() -native_prefix_route = APIRouter(prefix="/native") +router = APIRouter(tags=['base']) +prefix_router = APIRouter(tags=['prefix']) +native_prefix_router = APIRouter(prefix="/native", tags=['native']) app = FastAPI() @@ -68,7 +68,7 @@ async def router_ws_decorator_depends( await websocket.close() -@native_prefix_route.websocket("/") +@native_prefix_router.websocket("/") async def router_native_prefix_ws(websocket: WebSocket): await websocket.accept() await websocket.send_text("Hello, router with native prefix!") @@ -104,16 +104,27 @@ async def router_ws_custom_error(websocket: WebSocket): raise CustomError() -@router.websocket("/test_tags/", name="test-tags", tags=["test"]) +@app.websocket("/test_tags", name="test-app-tags", tags=["test-app-tags"]) + +@router.websocket("/test_tags/", name="test-router-tags", tags=["test-router-tags"]) async def router_ws_test_tags(websocket: WebSocket): pass # pragma: no cover +@prefix_router.websocket("/test_tags/", name="test-prefix-router-tags", tags=["test-prefix-router-tags"]) +async def prefix_router_ws_test_tags(websocket: WebSocket): + pass # pragma: no cover + +@native_prefix_router.websocket("/test_tags/", name="test-native-prefix-router-tags", + tags=["test-native-prefix-router-tags"]) +async def native_prefix_router_ws_test_tags(websocket: WebSocket): + pass # pragma: no cover + def make_app(app=None, **kwargs): app = app or FastAPI(**kwargs) app.include_router(router) app.include_router(prefix_router, prefix="/prefix") - app.include_router(native_prefix_route) + app.include_router(native_prefix_router) return app @@ -276,9 +287,15 @@ def test_depend_err_handler(): assert "foo" in e.value.reason -def test_websocket_tags(): +@pytest.mark.parametrize("route_name,route_tags", [ + ("test-app-tags", ["test-app-tags"]), + ("test-router-tags", ["base", "test-router-tags"]), + ("test-prefix-router-tags", ["prefix", "test-prefix-router-tags"]), + ("test-native-prefix-router-tags", ["native", "test-native-prefix-router-tags"]), +]) +def test_websocket_tags(route_name, route_tags): """ Verify that it is possible to add tags to websocket routes """ - route = next(route for route in app.routes if route.name == "test-tags") - assert route.tags == ["test"] + route = next(route for route in app.routes if route.name == route_name) + assert sorted(route.tags) == sorted(route_tags)