diff --git a/fastapi/applications.py b/fastapi/applications.py index 0e31c7a2b8..78d3064ab6 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -8,6 +8,8 @@ from typing import ( from annotated_doc import Doc from fastapi import routing +from fastapi.asyncapi.docs import get_asyncapi_html +from fastapi.asyncapi.utils import get_asyncapi from fastapi.datastructures import Default, DefaultPlaceholder from fastapi.exception_handlers import ( http_exception_handler, @@ -17,8 +19,6 @@ from fastapi.exception_handlers import ( from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError from fastapi.logger import logger from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware -from fastapi.asyncapi.docs import get_asyncapi_html -from fastapi.asyncapi.utils import get_asyncapi from fastapi.openapi.docs import ( get_redoc_html, get_swagger_ui_html, diff --git a/fastapi/asyncapi/utils.py b/fastapi/asyncapi/utils.py index 065ad72d53..7b8acefc51 100644 --- a/fastapi/asyncapi/utils.py +++ b/fastapi/asyncapi/utils.py @@ -1,11 +1,10 @@ from collections.abc import Sequence from typing import Any -from pydantic import BaseModel - from fastapi import routing from fastapi.asyncapi.constants import ASYNCAPI_VERSION, REF_PREFIX from fastapi.encoders import jsonable_encoder +from pydantic import BaseModel from starlette.routing import BaseRoute @@ -52,9 +51,8 @@ def _get_fields_from_websocket_routes( routes: Sequence[BaseRoute], ) -> list[Any]: """Collect body (ModelField) params from WebSocket routes for schema generation.""" - from fastapi.dependencies.utils import get_flat_dependant - from fastapi._compat import ModelField + from fastapi.dependencies.utils import get_flat_dependant from pydantic.fields import FieldInfo fields: list[Any] = [] @@ -69,7 +67,12 @@ def _get_fields_from_websocket_routes( getattr(route, "subscribe_schema", None), getattr(route, "publish_schema", None), ): - if model is not None and isinstance(model, type) and issubclass(model, BaseModel) and model not in seen_models: + if ( + model is not None + and isinstance(model, type) + and issubclass(model, BaseModel) + and model not in seen_models + ): seen_models.add(model) fields.append( ModelField( @@ -164,9 +167,17 @@ def get_asyncapi( # Explicit subscribe_schema / publish_schema (e.g. when route has no Body() in Depends) subscribe_model = getattr(route, "subscribe_schema", None) publish_model = getattr(route, "publish_schema", None) - if subscribe_model is not None and isinstance(subscribe_model, type) and issubclass(subscribe_model, BaseModel): + if ( + subscribe_model is not None + and isinstance(subscribe_model, type) + and issubclass(subscribe_model, BaseModel) + ): sub_schema = {"$ref": f"{REF_PREFIX}{subscribe_model.__name__}"} - if publish_model is not None and isinstance(publish_model, type) and issubclass(publish_model, BaseModel): + if ( + publish_model is not None + and isinstance(publish_model, type) + and issubclass(publish_model, BaseModel) + ): pub_schema = {"$ref": f"{REF_PREFIX}{publish_model.__name__}"} # Fall back to first body param (Depends with Body()) for both if not set if sub_schema is None or pub_schema is None: diff --git a/tests/test_asyncapi.py b/tests/test_asyncapi.py index 28cddb0374..f1496c8b1f 100644 --- a/tests/test_asyncapi.py +++ b/tests/test_asyncapi.py @@ -492,12 +492,14 @@ def test_asyncapi_components_and_message_payload(): limit: int = 10 def get_query_message( - msg: QueryMessage = Body(default=QueryMessage(text="", limit=10)) + msg: QueryMessage = Body(default=QueryMessage(text="", limit=10)), ) -> QueryMessage: return msg @app.websocket("/query") - async def query_ws(websocket: WebSocket, msg: QueryMessage = Depends(get_query_message)): + async def query_ws( + websocket: WebSocket, msg: QueryMessage = Depends(get_query_message) + ): await websocket.accept() await websocket.close()