mirror of https://github.com/tiangolo/fastapi.git
Merge b2855f0326 into 272204c0c7
This commit is contained in:
commit
f99f5ce28d
|
|
@ -1264,12 +1264,14 @@ class FastAPI(Starlette):
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
*,
|
*,
|
||||||
dependencies: Optional[Sequence[Depends]] = None,
|
dependencies: Optional[Sequence[Depends]] = None,
|
||||||
|
tags: Optional[List[Union[str, Enum]]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.router.add_api_websocket_route(
|
self.router.add_api_websocket_route(
|
||||||
path,
|
path,
|
||||||
endpoint,
|
endpoint,
|
||||||
name=name,
|
name=name,
|
||||||
dependencies=dependencies,
|
dependencies=dependencies,
|
||||||
|
tags=tags,
|
||||||
)
|
)
|
||||||
|
|
||||||
def websocket(
|
def websocket(
|
||||||
|
|
@ -1303,6 +1305,14 @@ class FastAPI(Starlette):
|
||||||
"""
|
"""
|
||||||
),
|
),
|
||||||
] = None,
|
] = None,
|
||||||
|
tags: Annotated[
|
||||||
|
Optional[List[Union[str, Enum]]],
|
||||||
|
Doc(
|
||||||
|
"""
|
||||||
|
A list of tags to be applied to this WebSocket.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||||
"""
|
"""
|
||||||
Decorate a WebSocket function.
|
Decorate a WebSocket function.
|
||||||
|
|
@ -1332,6 +1342,7 @@ class FastAPI(Starlette):
|
||||||
func,
|
func,
|
||||||
name=name,
|
name=name,
|
||||||
dependencies=dependencies,
|
dependencies=dependencies,
|
||||||
|
tags=tags,
|
||||||
)
|
)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from typing import (
|
||||||
Collection,
|
Collection,
|
||||||
Coroutine,
|
Coroutine,
|
||||||
Dict,
|
Dict,
|
||||||
|
Iterable,
|
||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
|
|
@ -522,12 +523,14 @@ class APIWebSocketRoute(routing.WebSocketRoute):
|
||||||
endpoint: Callable[..., Any],
|
endpoint: Callable[..., Any],
|
||||||
*,
|
*,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
|
tags: Optional[List[Union[str, Enum]]] = None,
|
||||||
dependencies: Optional[Sequence[params.Depends]] = None,
|
dependencies: Optional[Sequence[params.Depends]] = None,
|
||||||
dependency_overrides_provider: Optional[Any] = None,
|
dependency_overrides_provider: Optional[Any] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.path = path
|
self.path = path
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.name = get_name(endpoint) if name is None else name
|
self.name = get_name(endpoint) if name is None else name
|
||||||
|
self.tags: List[Union[str, Enum]] = tags or []
|
||||||
self.dependencies = list(dependencies or [])
|
self.dependencies = list(dependencies or [])
|
||||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||||
self.dependant = get_dependant(
|
self.dependant = get_dependant(
|
||||||
|
|
@ -1052,9 +1055,7 @@ class APIRouter(routing.Router):
|
||||||
current_response_class = get_value_or_default(
|
current_response_class = get_value_or_default(
|
||||||
response_class, self.default_response_class
|
response_class, self.default_response_class
|
||||||
)
|
)
|
||||||
current_tags = self.tags.copy()
|
current_tags = self.combine_tags(tags or [])
|
||||||
if tags:
|
|
||||||
current_tags.extend(tags)
|
|
||||||
current_dependencies = self.dependencies.copy()
|
current_dependencies = self.dependencies.copy()
|
||||||
if dependencies:
|
if dependencies:
|
||||||
current_dependencies.extend(dependencies)
|
current_dependencies.extend(dependencies)
|
||||||
|
|
@ -1069,7 +1070,7 @@ class APIRouter(routing.Router):
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
status_code=status_code,
|
status_code=status_code,
|
||||||
tags=current_tags,
|
tags=list(current_tags),
|
||||||
dependencies=current_dependencies,
|
dependencies=current_dependencies,
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
|
|
@ -1162,16 +1163,20 @@ class APIRouter(routing.Router):
|
||||||
endpoint: Callable[..., Any],
|
endpoint: Callable[..., Any],
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
*,
|
*,
|
||||||
|
tags: Optional[List[Union[str, Enum]]] = None,
|
||||||
dependencies: Optional[Sequence[params.Depends]] = None,
|
dependencies: Optional[Sequence[params.Depends]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
current_dependencies = self.dependencies.copy()
|
current_dependencies = self.dependencies.copy()
|
||||||
if dependencies:
|
if dependencies:
|
||||||
current_dependencies.extend(dependencies)
|
current_dependencies.extend(dependencies)
|
||||||
|
|
||||||
|
current_tags = self.combine_tags(tags)
|
||||||
|
|
||||||
route = APIWebSocketRoute(
|
route = APIWebSocketRoute(
|
||||||
self.prefix + path,
|
self.prefix + path,
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
name=name,
|
name=name,
|
||||||
|
tags=current_tags,
|
||||||
dependencies=current_dependencies,
|
dependencies=current_dependencies,
|
||||||
dependency_overrides_provider=self.dependency_overrides_provider,
|
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||||
)
|
)
|
||||||
|
|
@ -1196,6 +1201,14 @@ class APIRouter(routing.Router):
|
||||||
),
|
),
|
||||||
] = None,
|
] = None,
|
||||||
*,
|
*,
|
||||||
|
tags: Annotated[
|
||||||
|
Optional[List[Union[str, Enum]]],
|
||||||
|
Doc(
|
||||||
|
"""
|
||||||
|
A list of tags to be applied to this WebSocket.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
dependencies: Annotated[
|
dependencies: Annotated[
|
||||||
Optional[Sequence[params.Depends]],
|
Optional[Sequence[params.Depends]],
|
||||||
Doc(
|
Doc(
|
||||||
|
|
@ -1238,7 +1251,7 @@ class APIRouter(routing.Router):
|
||||||
|
|
||||||
def decorator(func: DecoratedCallable) -> DecoratedCallable:
|
def decorator(func: DecoratedCallable) -> DecoratedCallable:
|
||||||
self.add_api_websocket_route(
|
self.add_api_websocket_route(
|
||||||
path, func, name=name, dependencies=dependencies
|
path, func, name=name, tags=tags, dependencies=dependencies
|
||||||
)
|
)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
@ -1412,11 +1425,7 @@ class APIRouter(routing.Router):
|
||||||
default_response_class,
|
default_response_class,
|
||||||
self.default_response_class,
|
self.default_response_class,
|
||||||
)
|
)
|
||||||
current_tags = []
|
current_tags = self.combine_tags(tags, route)
|
||||||
if tags:
|
|
||||||
current_tags.extend(tags)
|
|
||||||
if route.tags:
|
|
||||||
current_tags.extend(route.tags)
|
|
||||||
current_dependencies: List[params.Depends] = []
|
current_dependencies: List[params.Depends] = []
|
||||||
if dependencies:
|
if dependencies:
|
||||||
current_dependencies.extend(dependencies)
|
current_dependencies.extend(dependencies)
|
||||||
|
|
@ -1478,11 +1487,14 @@ class APIRouter(routing.Router):
|
||||||
current_dependencies.extend(dependencies)
|
current_dependencies.extend(dependencies)
|
||||||
if route.dependencies:
|
if route.dependencies:
|
||||||
current_dependencies.extend(route.dependencies)
|
current_dependencies.extend(route.dependencies)
|
||||||
|
|
||||||
|
current_tags = self.combine_tags(tags, route)
|
||||||
self.add_api_websocket_route(
|
self.add_api_websocket_route(
|
||||||
prefix + route.path,
|
prefix + route.path,
|
||||||
route.endpoint,
|
route.endpoint,
|
||||||
dependencies=current_dependencies,
|
dependencies=current_dependencies,
|
||||||
name=route.name,
|
name=route.name,
|
||||||
|
tags=current_tags,
|
||||||
)
|
)
|
||||||
elif isinstance(route, routing.WebSocketRoute):
|
elif isinstance(route, routing.WebSocketRoute):
|
||||||
self.add_websocket_route(
|
self.add_websocket_route(
|
||||||
|
|
@ -4571,3 +4583,27 @@ class APIRouter(routing.Router):
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
def combine_tags(
|
||||||
|
self,
|
||||||
|
*entities: Annotated[
|
||||||
|
Union[None, str, routing.Route, Sequence],
|
||||||
|
Doc(
|
||||||
|
"""
|
||||||
|
Combine the router's current tags with those of the provided entities.
|
||||||
|
Supports None, strings, iterables, and Route objects with a `tags` attribute.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
],
|
||||||
|
) -> List[str]:
|
||||||
|
tags = set(self.tags or [])
|
||||||
|
for entity in entities:
|
||||||
|
if entity is None:
|
||||||
|
continue
|
||||||
|
if isinstance(entity, str):
|
||||||
|
tags.add(entity)
|
||||||
|
elif isinstance(entity, Iterable):
|
||||||
|
tags.update(entity)
|
||||||
|
elif hasattr(entity, "tags"):
|
||||||
|
tags = tags.union(entity.tags)
|
||||||
|
return sorted(tags)
|
||||||
|
|
|
||||||
|
|
@ -580,7 +580,7 @@ def test_openapi_schema(client: TestClient):
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"put": {
|
"put": {
|
||||||
"tags": ["items", "custom"],
|
"tags": ["custom", "items"],
|
||||||
"summary": "Update Item",
|
"summary": "Update Item",
|
||||||
"operationId": "update_item_items__item_id__put",
|
"operationId": "update_item_items__item_id__put",
|
||||||
"parameters": [
|
"parameters": [
|
||||||
|
|
|
||||||
|
|
@ -13,9 +13,9 @@ from fastapi import (
|
||||||
from fastapi.middleware import Middleware
|
from fastapi.middleware import Middleware
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter(tags=["base"])
|
||||||
prefix_router = APIRouter()
|
prefix_router = APIRouter(tags=["prefix"])
|
||||||
native_prefix_route = APIRouter(prefix="/native")
|
native_prefix_router = APIRouter(prefix="/native", tags=["native"])
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -68,7 +68,7 @@ async def router_ws_decorator_depends(
|
||||||
await websocket.close()
|
await websocket.close()
|
||||||
|
|
||||||
|
|
||||||
@native_prefix_route.websocket("/")
|
@native_prefix_router.websocket("/")
|
||||||
async def router_native_prefix_ws(websocket: WebSocket):
|
async def router_native_prefix_ws(websocket: WebSocket):
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
await websocket.send_text("Hello, router with native prefix!")
|
await websocket.send_text("Hello, router with native prefix!")
|
||||||
|
|
@ -104,11 +104,33 @@ async def router_ws_custom_error(websocket: WebSocket):
|
||||||
raise CustomError()
|
raise CustomError()
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/test_tags", name="test-app-tags", tags=["test-app-tags"])
|
||||||
|
@router.websocket("/test_tags/", name="test-router-tags", tags=["test-router-tags"])
|
||||||
|
async def router_ws_test_tags(websocket: WebSocket):
|
||||||
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
@prefix_router.websocket(
|
||||||
|
"/test_tags/", name="test-prefix-router-tags", tags=["test-prefix-router-tags"]
|
||||||
|
)
|
||||||
|
async def prefix_router_ws_test_tags(websocket: WebSocket):
|
||||||
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
@native_prefix_router.websocket(
|
||||||
|
"/test_tags/",
|
||||||
|
name="test-native-prefix-router-tags",
|
||||||
|
tags=["test-native-prefix-router-tags"],
|
||||||
|
)
|
||||||
|
async def native_prefix_router_ws_test_tags(websocket: WebSocket):
|
||||||
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
def make_app(app=None, **kwargs):
|
def make_app(app=None, **kwargs):
|
||||||
app = app or FastAPI(**kwargs)
|
app = app or FastAPI(**kwargs)
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
app.include_router(prefix_router, prefix="/prefix")
|
app.include_router(prefix_router, prefix="/prefix")
|
||||||
app.include_router(native_prefix_route)
|
app.include_router(native_prefix_router)
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -269,3 +291,23 @@ def test_depend_err_handler():
|
||||||
pass # pragma: no cover
|
pass # pragma: no cover
|
||||||
assert e.value.code == 1002
|
assert e.value.code == 1002
|
||||||
assert "foo" in e.value.reason
|
assert "foo" in e.value.reason
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"route_name,route_tags",
|
||||||
|
[
|
||||||
|
("test-app-tags", ["test-app-tags"]),
|
||||||
|
("test-router-tags", ["base", "test-router-tags"]),
|
||||||
|
("test-prefix-router-tags", ["prefix", "test-prefix-router-tags"]),
|
||||||
|
(
|
||||||
|
"test-native-prefix-router-tags",
|
||||||
|
["native", "test-native-prefix-router-tags"],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_websocket_tags(route_name, route_tags):
|
||||||
|
"""
|
||||||
|
Verify that it is possible to add tags to websocket routes
|
||||||
|
"""
|
||||||
|
route = next(route for route in app.routes if route.name == route_name)
|
||||||
|
assert sorted(route.tags) == sorted(route_tags)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue