mirror of https://github.com/tiangolo/fastapi.git
match patterns from openapi
This commit is contained in:
parent
5040c2986c
commit
4f88800ace
|
|
@ -17,9 +17,9 @@ from fastapi.exception_handlers import (
|
|||
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
|
||||
from fastapi.logger import logger
|
||||
from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
|
||||
from fastapi.openapi.asyncapi_utils import get_asyncapi
|
||||
from fastapi.asyncapi.docs import get_asyncapi_html
|
||||
from fastapi.asyncapi.utils import get_asyncapi
|
||||
from fastapi.openapi.docs import (
|
||||
get_asyncapi_html,
|
||||
get_redoc_html,
|
||||
get_swagger_ui_html,
|
||||
get_swagger_ui_oauth2_redirect_html,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,2 @@
|
|||
ASYNCAPI_VERSION = "2.6.0"
|
||||
REF_PREFIX = "#/components/schemas/"
|
||||
|
|
@ -0,0 +1,127 @@
|
|||
from typing import Annotated
|
||||
|
||||
from annotated_doc import Doc
|
||||
from starlette.responses import HTMLResponse
|
||||
|
||||
|
||||
def get_asyncapi_html(
|
||||
*,
|
||||
asyncapi_url: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
"""
|
||||
The AsyncAPI URL that AsyncAPI Studio should load and use.
|
||||
|
||||
This is normally done automatically by FastAPI using the default URL
|
||||
`/asyncapi.json`.
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for AsyncAPI](https://fastapi.tiangolo.com/advanced/asyncapi/).
|
||||
"""
|
||||
),
|
||||
],
|
||||
title: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
"""
|
||||
The HTML `<title>` content, normally shown in the browser tab.
|
||||
"""
|
||||
),
|
||||
],
|
||||
asyncapi_js_url: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
"""
|
||||
The URL to use to load the AsyncAPI Studio JavaScript.
|
||||
|
||||
It is normally set to a CDN URL.
|
||||
"""
|
||||
),
|
||||
] = "https://unpkg.com/@asyncapi/react-component@latest/browser/standalone/index.js",
|
||||
asyncapi_favicon_url: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
"""
|
||||
The URL of the favicon to use. It is normally shown in the browser tab.
|
||||
"""
|
||||
),
|
||||
] = "https://fastapi.tiangolo.com/img/favicon.png",
|
||||
docs_url: Annotated[
|
||||
str | None,
|
||||
Doc(
|
||||
"""
|
||||
The URL to the OpenAPI docs (Swagger UI) for navigation link.
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
) -> HTMLResponse:
|
||||
"""
|
||||
Generate and return the HTML that loads AsyncAPI Studio for the interactive
|
||||
WebSocket API docs (normally served at `/asyncapi-docs`).
|
||||
|
||||
You would only call this function yourself if you needed to override some parts,
|
||||
for example the URLs to use to load AsyncAPI Studio's JavaScript.
|
||||
"""
|
||||
navigation_html = ""
|
||||
if docs_url:
|
||||
navigation_html = f"""
|
||||
<div style="padding: 10px; background-color: #f5f5f5; border-bottom: 1px solid #ddd;">
|
||||
<a href="{docs_url}" style="color: #007bff; text-decoration: none; margin-right: 20px;">
|
||||
📄 OpenAPI Docs (REST API)
|
||||
</a>
|
||||
<span style="color: #666;">WebSocket API Documentation</span>
|
||||
</div>
|
||||
"""
|
||||
|
||||
html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<link rel="shortcut icon" href="{asyncapi_favicon_url}">
|
||||
<title>{title}</title>
|
||||
<style>
|
||||
body {{
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}}
|
||||
#asyncapi {{
|
||||
height: 100vh;
|
||||
width: 100%;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
{navigation_html}
|
||||
<div id="asyncapi"></div>
|
||||
<script src="{asyncapi_js_url}"></script>
|
||||
<script>
|
||||
(async function() {{
|
||||
const asyncapiSpec = await fetch('{asyncapi_url}').then(res => res.json());
|
||||
const AsyncApiStandalone = window.AsyncApiStandalone || window.AsyncAPIStandalone;
|
||||
if (AsyncApiStandalone) {{
|
||||
AsyncApiStandalone.render({{
|
||||
schema: asyncapiSpec,
|
||||
config: {{
|
||||
show: {{
|
||||
sidebar: true,
|
||||
info: true,
|
||||
servers: true,
|
||||
operations: true,
|
||||
messages: true,
|
||||
}},
|
||||
}},
|
||||
}}, document.getElementById('asyncapi'));
|
||||
}} else {{
|
||||
document.getElementById('asyncapi').innerHTML =
|
||||
'<div style="padding: 20px; text-align: center;">' +
|
||||
'<h2>Failed to load AsyncAPI Studio</h2>' +
|
||||
'<p>Please check your internet connection and try again.</p>' +
|
||||
'</div>';
|
||||
}}
|
||||
}})();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return HTMLResponse(html)
|
||||
|
|
@ -0,0 +1,224 @@
|
|||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from fastapi import routing
|
||||
from fastapi.asyncapi.constants import ASYNCAPI_VERSION, REF_PREFIX
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from starlette.routing import BaseRoute
|
||||
|
||||
|
||||
def get_asyncapi_channel(
|
||||
*,
|
||||
route: routing.APIWebSocketRoute,
|
||||
subscribe_payload_schema: dict[str, Any] | None = None,
|
||||
publish_payload_schema: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Generate AsyncAPI channel definition for a WebSocket route."""
|
||||
channel: dict[str, Any] = {}
|
||||
|
||||
# WebSocket channels typically have subscribe operation
|
||||
# (client subscribes to receive messages from server)
|
||||
operation: dict[str, Any] = {
|
||||
"operationId": route.name or f"websocket_{route.path_format}",
|
||||
}
|
||||
|
||||
# Message schema: contentType and optional payload (schema for message body)
|
||||
subscribe_message: dict[str, Any] = {
|
||||
"contentType": "application/json",
|
||||
}
|
||||
if subscribe_payload_schema:
|
||||
subscribe_message["payload"] = subscribe_payload_schema
|
||||
|
||||
operation["message"] = subscribe_message
|
||||
channel["subscribe"] = operation
|
||||
|
||||
# WebSockets are bidirectional, so we also include publish
|
||||
# (client can publish messages to server)
|
||||
publish_operation: dict[str, Any] = {
|
||||
"operationId": f"{route.name or f'websocket_{route.path_format}'}_publish",
|
||||
"message": {
|
||||
"contentType": "application/json",
|
||||
**({"payload": publish_payload_schema} if publish_payload_schema else {}),
|
||||
},
|
||||
}
|
||||
channel["publish"] = publish_operation
|
||||
|
||||
return channel
|
||||
|
||||
|
||||
def _get_fields_from_websocket_routes(
|
||||
routes: Sequence[BaseRoute],
|
||||
) -> list[Any]:
|
||||
"""Collect body (ModelField) params from WebSocket routes for schema generation."""
|
||||
from fastapi.dependencies.utils import get_flat_dependant
|
||||
|
||||
from fastapi._compat import ModelField
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
fields: list[Any] = []
|
||||
seen_models: set[type[BaseModel]] = set()
|
||||
for route in routes or []:
|
||||
if not isinstance(route, routing.APIWebSocketRoute):
|
||||
continue
|
||||
flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
|
||||
fields.extend(flat_dependant.body_params)
|
||||
# Add explicit subscribe_schema / publish_schema as ModelFields so they get definitions
|
||||
for model in (
|
||||
getattr(route, "subscribe_schema", None),
|
||||
getattr(route, "publish_schema", None),
|
||||
):
|
||||
if model is not None and isinstance(model, type) and issubclass(model, BaseModel) and model not in seen_models:
|
||||
seen_models.add(model)
|
||||
fields.append(
|
||||
ModelField(
|
||||
field_info=FieldInfo(annotation=model),
|
||||
name=model.__name__,
|
||||
mode="validation",
|
||||
)
|
||||
)
|
||||
return fields
|
||||
|
||||
|
||||
def get_asyncapi(
|
||||
*,
|
||||
title: str,
|
||||
version: str,
|
||||
asyncapi_version: str = ASYNCAPI_VERSION,
|
||||
summary: str | None = None,
|
||||
description: str | None = None,
|
||||
routes: Sequence[BaseRoute],
|
||||
servers: list[dict[str, str | Any]] | None = None,
|
||||
terms_of_service: str | None = None,
|
||||
contact: dict[str, str | Any] | None = None,
|
||||
license_info: dict[str, str | Any] | None = None,
|
||||
external_docs: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate AsyncAPI schema from FastAPI application routes.
|
||||
|
||||
Filters for WebSocket routes and generates AsyncAPI 2.6.0 compliant schema.
|
||||
Includes components/schemas for message payloads when WebSocket routes use
|
||||
Pydantic models (e.g. via Body() in dependencies).
|
||||
"""
|
||||
from fastapi._compat import (
|
||||
ModelField,
|
||||
get_definitions,
|
||||
get_flat_models_from_fields,
|
||||
get_model_name_map,
|
||||
get_schema_from_model_field,
|
||||
)
|
||||
|
||||
info: dict[str, Any] = {"title": title, "version": version}
|
||||
if summary:
|
||||
info["summary"] = summary
|
||||
if description:
|
||||
info["description"] = description
|
||||
if terms_of_service:
|
||||
info["termsOfService"] = terms_of_service
|
||||
if contact:
|
||||
info["contact"] = contact
|
||||
if license_info:
|
||||
info["license"] = license_info
|
||||
|
||||
output: dict[str, Any] = {"asyncapi": asyncapi_version, "info": info}
|
||||
|
||||
# Add default WebSocket server if no servers provided and we have WebSocket routes
|
||||
websocket_routes = [
|
||||
route for route in routes or [] if isinstance(route, routing.APIWebSocketRoute)
|
||||
]
|
||||
if websocket_routes and not servers:
|
||||
# Default WebSocket server - can be overridden by providing servers parameter
|
||||
output["servers"] = [
|
||||
{
|
||||
"url": "ws://localhost:8000",
|
||||
"protocol": "ws",
|
||||
"description": "WebSocket server",
|
||||
}
|
||||
]
|
||||
elif servers:
|
||||
output["servers"] = servers
|
||||
|
||||
# Build components/schemas from WebSocket body params and explicit subscribe/publish_schema
|
||||
ws_fields = _get_fields_from_websocket_routes(routes or [])
|
||||
components: dict[str, Any] = {}
|
||||
route_subscribe_schemas: dict[str, dict[str, Any] | None] = {}
|
||||
route_publish_schemas: dict[str, dict[str, Any] | None] = {}
|
||||
if ws_fields:
|
||||
flat_models = get_flat_models_from_fields(ws_fields, known_models=set())
|
||||
model_name_map = get_model_name_map(flat_models)
|
||||
field_mapping, definitions = get_definitions(
|
||||
fields=ws_fields,
|
||||
model_name_map=model_name_map,
|
||||
separate_input_output_schemas=True,
|
||||
)
|
||||
if definitions:
|
||||
components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
|
||||
# For each WebSocket route, resolve subscribe and publish payload schemas
|
||||
for route in routes or []:
|
||||
if not isinstance(route, routing.APIWebSocketRoute):
|
||||
continue
|
||||
sub_schema: dict[str, Any] | None = None
|
||||
pub_schema: dict[str, Any] | None = None
|
||||
# Explicit subscribe_schema / publish_schema (e.g. when route has no Body() in Depends)
|
||||
subscribe_model = getattr(route, "subscribe_schema", None)
|
||||
publish_model = getattr(route, "publish_schema", None)
|
||||
if subscribe_model is not None and isinstance(subscribe_model, type) and issubclass(subscribe_model, BaseModel):
|
||||
sub_schema = {"$ref": f"{REF_PREFIX}{subscribe_model.__name__}"}
|
||||
if publish_model is not None and isinstance(publish_model, type) and issubclass(publish_model, BaseModel):
|
||||
pub_schema = {"$ref": f"{REF_PREFIX}{publish_model.__name__}"}
|
||||
# Fall back to first body param (Depends with Body()) for both if not set
|
||||
if sub_schema is None or pub_schema is None:
|
||||
flat_dependant = route._flat_dependant
|
||||
if flat_dependant.body_params:
|
||||
first_body = flat_dependant.body_params[0]
|
||||
if isinstance(first_body, ModelField):
|
||||
body_schema = get_schema_from_model_field(
|
||||
field=first_body,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=True,
|
||||
)
|
||||
# Use only $ref for channel payload when schema is in components
|
||||
if "$ref" in body_schema and body_schema["$ref"].startswith(
|
||||
REF_PREFIX
|
||||
):
|
||||
body_schema = {"$ref": body_schema["$ref"]}
|
||||
if sub_schema is None:
|
||||
sub_schema = body_schema
|
||||
if pub_schema is None:
|
||||
pub_schema = body_schema
|
||||
route_subscribe_schemas[route.path_format] = sub_schema
|
||||
route_publish_schemas[route.path_format] = pub_schema
|
||||
else:
|
||||
for route in routes or []:
|
||||
if not isinstance(route, routing.APIWebSocketRoute):
|
||||
continue
|
||||
route_subscribe_schemas[route.path_format] = None
|
||||
route_publish_schemas[route.path_format] = None
|
||||
|
||||
channels: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# Filter routes to only include WebSocket routes
|
||||
for route in routes or []:
|
||||
if isinstance(route, routing.APIWebSocketRoute):
|
||||
sub_schema = route_subscribe_schemas.get(route.path_format)
|
||||
pub_schema = route_publish_schemas.get(route.path_format)
|
||||
channel = get_asyncapi_channel(
|
||||
route=route,
|
||||
subscribe_payload_schema=sub_schema,
|
||||
publish_payload_schema=pub_schema,
|
||||
)
|
||||
if channel:
|
||||
channels[route.path_format] = channel
|
||||
|
||||
output["channels"] = channels
|
||||
|
||||
if components:
|
||||
output["components"] = components
|
||||
|
||||
if external_docs:
|
||||
output["externalDocs"] = external_docs
|
||||
|
||||
return jsonable_encoder(output, by_alias=True, exclude_none=True) # type: ignore
|
||||
|
|
@ -1,105 +0,0 @@
|
|||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from fastapi import routing
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from starlette.routing import BaseRoute
|
||||
|
||||
|
||||
def get_asyncapi_channel(
|
||||
*,
|
||||
route: routing.APIWebSocketRoute,
|
||||
) -> dict[str, Any]:
|
||||
"""Generate AsyncAPI channel definition for a WebSocket route."""
|
||||
channel: dict[str, Any] = {}
|
||||
|
||||
# WebSocket channels typically have subscribe operation
|
||||
# (client subscribes to receive messages from server)
|
||||
operation: dict[str, Any] = {
|
||||
"operationId": route.name or f"websocket_{route.path_format}",
|
||||
}
|
||||
|
||||
# Basic message schema - can be enhanced later with actual message types
|
||||
# For WebSockets, messages can be sent in both directions
|
||||
message: dict[str, Any] = {
|
||||
"contentType": "application/json",
|
||||
}
|
||||
|
||||
operation["message"] = message
|
||||
channel["subscribe"] = operation
|
||||
|
||||
# WebSockets are bidirectional, so we also include publish
|
||||
# (client can publish messages to server)
|
||||
publish_operation: dict[str, Any] = {
|
||||
"operationId": f"{route.name or f'websocket_{route.path_format}'}_publish",
|
||||
"message": message,
|
||||
}
|
||||
channel["publish"] = publish_operation
|
||||
|
||||
return channel
|
||||
|
||||
|
||||
def get_asyncapi(
|
||||
*,
|
||||
title: str,
|
||||
version: str,
|
||||
asyncapi_version: str = "2.6.0",
|
||||
summary: str | None = None,
|
||||
description: str | None = None,
|
||||
routes: Sequence[BaseRoute],
|
||||
servers: list[dict[str, str | Any]] | None = None,
|
||||
terms_of_service: str | None = None,
|
||||
contact: dict[str, str | Any] | None = None,
|
||||
license_info: dict[str, str | Any] | None = None,
|
||||
external_docs: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate AsyncAPI schema from FastAPI application routes.
|
||||
|
||||
Filters for WebSocket routes and generates AsyncAPI 2.6.0 compliant schema.
|
||||
"""
|
||||
info: dict[str, Any] = {"title": title, "version": version}
|
||||
if summary:
|
||||
info["summary"] = summary
|
||||
if description:
|
||||
info["description"] = description
|
||||
if terms_of_service:
|
||||
info["termsOfService"] = terms_of_service
|
||||
if contact:
|
||||
info["contact"] = contact
|
||||
if license_info:
|
||||
info["license"] = license_info
|
||||
|
||||
output: dict[str, Any] = {"asyncapi": asyncapi_version, "info": info}
|
||||
|
||||
# Add default WebSocket server if no servers provided and we have WebSocket routes
|
||||
websocket_routes = [
|
||||
route for route in routes or [] if isinstance(route, routing.APIWebSocketRoute)
|
||||
]
|
||||
if websocket_routes and not servers:
|
||||
# Default WebSocket server - can be overridden by providing servers parameter
|
||||
output["servers"] = [
|
||||
{
|
||||
"url": "ws://localhost:8000",
|
||||
"protocol": "ws",
|
||||
"description": "WebSocket server",
|
||||
}
|
||||
]
|
||||
elif servers:
|
||||
output["servers"] = servers
|
||||
|
||||
channels: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# Filter routes to only include WebSocket routes
|
||||
for route in routes or []:
|
||||
if isinstance(route, routing.APIWebSocketRoute):
|
||||
channel = get_asyncapi_channel(route=route)
|
||||
if channel:
|
||||
channels[route.path_format] = channel
|
||||
|
||||
output["channels"] = channels
|
||||
|
||||
if external_docs:
|
||||
output["externalDocs"] = external_docs
|
||||
|
||||
return jsonable_encoder(output, by_alias=True, exclude_none=True) # type: ignore
|
||||
|
|
@ -214,129 +214,6 @@ def get_swagger_ui_html(
|
|||
return HTMLResponse(html)
|
||||
|
||||
|
||||
def get_asyncapi_html(
|
||||
*,
|
||||
asyncapi_url: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
"""
|
||||
The AsyncAPI URL that AsyncAPI Studio should load and use.
|
||||
|
||||
This is normally done automatically by FastAPI using the default URL
|
||||
`/asyncapi.json`.
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for AsyncAPI](https://fastapi.tiangolo.com/advanced/asyncapi/).
|
||||
"""
|
||||
),
|
||||
],
|
||||
title: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
"""
|
||||
The HTML `<title>` content, normally shown in the browser tab.
|
||||
"""
|
||||
),
|
||||
],
|
||||
asyncapi_js_url: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
"""
|
||||
The URL to use to load the AsyncAPI Studio JavaScript.
|
||||
|
||||
It is normally set to a CDN URL.
|
||||
"""
|
||||
),
|
||||
] = "https://unpkg.com/@asyncapi/react-component@latest/browser/standalone/index.js",
|
||||
asyncapi_favicon_url: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
"""
|
||||
The URL of the favicon to use. It is normally shown in the browser tab.
|
||||
"""
|
||||
),
|
||||
] = "https://fastapi.tiangolo.com/img/favicon.png",
|
||||
docs_url: Annotated[
|
||||
str | None,
|
||||
Doc(
|
||||
"""
|
||||
The URL to the OpenAPI docs (Swagger UI) for navigation link.
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
) -> HTMLResponse:
|
||||
"""
|
||||
Generate and return the HTML that loads AsyncAPI Studio for the interactive
|
||||
WebSocket API docs (normally served at `/asyncapi-docs`).
|
||||
|
||||
You would only call this function yourself if you needed to override some parts,
|
||||
for example the URLs to use to load AsyncAPI Studio's JavaScript.
|
||||
"""
|
||||
navigation_html = ""
|
||||
if docs_url:
|
||||
navigation_html = f"""
|
||||
<div style="padding: 10px; background-color: #f5f5f5; border-bottom: 1px solid #ddd;">
|
||||
<a href="{docs_url}" style="color: #007bff; text-decoration: none; margin-right: 20px;">
|
||||
📄 OpenAPI Docs (REST API)
|
||||
</a>
|
||||
<span style="color: #666;">WebSocket API Documentation</span>
|
||||
</div>
|
||||
"""
|
||||
|
||||
html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<link rel="shortcut icon" href="{asyncapi_favicon_url}">
|
||||
<title>{title}</title>
|
||||
<style>
|
||||
body {{
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}}
|
||||
#asyncapi {{
|
||||
height: 100vh;
|
||||
width: 100%;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
{navigation_html}
|
||||
<div id="asyncapi"></div>
|
||||
<script src="{asyncapi_js_url}"></script>
|
||||
<script>
|
||||
(async function() {{
|
||||
const asyncapiSpec = await fetch('{asyncapi_url}').then(res => res.json());
|
||||
const AsyncApiStandalone = window.AsyncApiStandalone || window.AsyncAPIStandalone;
|
||||
if (AsyncApiStandalone) {{
|
||||
AsyncApiStandalone.render({{
|
||||
schema: asyncapiSpec,
|
||||
config: {{
|
||||
show: {{
|
||||
sidebar: true,
|
||||
info: true,
|
||||
servers: true,
|
||||
operations: true,
|
||||
messages: true,
|
||||
}},
|
||||
}},
|
||||
}}, document.getElementById('asyncapi'));
|
||||
}} else {{
|
||||
document.getElementById('asyncapi').innerHTML =
|
||||
'<div style="padding: 20px; text-align: center;">' +
|
||||
'<h2>Failed to load AsyncAPI Studio</h2>' +
|
||||
'<p>Please check your internet connection and try again.</p>' +
|
||||
'</div>';
|
||||
}}
|
||||
}})();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return HTMLResponse(html)
|
||||
|
||||
|
||||
def get_redoc_html(
|
||||
*,
|
||||
openapi_url: Annotated[
|
||||
|
|
|
|||
|
|
@ -541,11 +541,15 @@ class APIWebSocketRoute(routing.WebSocketRoute):
|
|||
name: str | None = None,
|
||||
dependencies: Sequence[params.Depends] | None = None,
|
||||
dependency_overrides_provider: Any | None = None,
|
||||
subscribe_schema: type[Any] | None = None,
|
||||
publish_schema: type[Any] | None = None,
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.endpoint = endpoint
|
||||
self.name = get_name(endpoint) if name is None else name
|
||||
self.dependencies = list(dependencies or [])
|
||||
self.subscribe_schema = subscribe_schema
|
||||
self.publish_schema = publish_schema
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
self.dependant = get_dependant(
|
||||
path=self.path_format, call=self.endpoint, scope="function"
|
||||
|
|
@ -1214,6 +1218,8 @@ class APIRouter(routing.Router):
|
|||
name: str | None = None,
|
||||
*,
|
||||
dependencies: Sequence[params.Depends] | None = None,
|
||||
subscribe_schema: type[Any] | None = None,
|
||||
publish_schema: type[Any] | None = None,
|
||||
) -> None:
|
||||
current_dependencies = self.dependencies.copy()
|
||||
if dependencies:
|
||||
|
|
@ -1225,6 +1231,8 @@ class APIRouter(routing.Router):
|
|||
name=name,
|
||||
dependencies=current_dependencies,
|
||||
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||
subscribe_schema=subscribe_schema,
|
||||
publish_schema=publish_schema,
|
||||
)
|
||||
self.routes.append(route)
|
||||
|
||||
|
|
@ -1259,6 +1267,25 @@ class APIRouter(routing.Router):
|
|||
"""
|
||||
),
|
||||
] = None,
|
||||
subscribe_schema: Annotated[
|
||||
type[Any] | None,
|
||||
Doc(
|
||||
"""
|
||||
Pydantic model for messages the client sends (subscribe operation).
|
||||
Used to generate AsyncAPI message payload schema when the route
|
||||
does not use Body() in dependencies.
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
publish_schema: Annotated[
|
||||
type[Any] | None,
|
||||
Doc(
|
||||
"""
|
||||
Pydantic model for messages the server sends (publish operation).
|
||||
Used to generate AsyncAPI message payload schema.
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
"""
|
||||
Decorate a WebSocket function.
|
||||
|
|
@ -1289,7 +1316,12 @@ class APIRouter(routing.Router):
|
|||
|
||||
def decorator(func: DecoratedCallable) -> DecoratedCallable:
|
||||
self.add_api_websocket_route(
|
||||
path, func, name=name, dependencies=dependencies
|
||||
path,
|
||||
func,
|
||||
name=name,
|
||||
dependencies=dependencies,
|
||||
subscribe_schema=subscribe_schema,
|
||||
publish_schema=publish_schema,
|
||||
)
|
||||
return func
|
||||
|
||||
|
|
@ -1543,6 +1575,8 @@ class APIRouter(routing.Router):
|
|||
route.endpoint,
|
||||
dependencies=current_dependencies,
|
||||
name=route.name,
|
||||
subscribe_schema=route.subscribe_schema,
|
||||
publish_schema=route.publish_schema,
|
||||
)
|
||||
elif isinstance(route, routing.WebSocketRoute):
|
||||
self.add_websocket_route(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from fastapi import FastAPI, WebSocket
|
||||
from fastapi.openapi.asyncapi_utils import get_asyncapi, get_asyncapi_channel
|
||||
from fastapi import APIRouter, Body, Depends, FastAPI, WebSocket
|
||||
from fastapi.asyncapi.utils import get_asyncapi, get_asyncapi_channel
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def test_asyncapi_schema():
|
||||
|
|
@ -480,6 +481,115 @@ def test_asyncapi_url_none_no_link_in_swagger():
|
|||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_asyncapi_components_and_message_payload():
|
||||
"""Test AsyncAPI schema includes components/schemas and message payload when models are used."""
|
||||
app = FastAPI(title="Test API", version="1.0.0")
|
||||
|
||||
class QueryMessage(BaseModel):
|
||||
"""Message sent on /query channel."""
|
||||
|
||||
text: str
|
||||
limit: int = 10
|
||||
|
||||
def get_query_message(
|
||||
msg: QueryMessage = Body(default=QueryMessage(text="", limit=10))
|
||||
) -> QueryMessage:
|
||||
return msg
|
||||
|
||||
@app.websocket("/query")
|
||||
async def query_ws(websocket: WebSocket, msg: QueryMessage = Depends(get_query_message)):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
# Connect to websocket so handler and dependency are covered (body default used)
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/query"):
|
||||
pass
|
||||
|
||||
# Generate schema and assert components
|
||||
response = client.get("/asyncapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
schema = response.json()
|
||||
|
||||
# Should have components with schemas (reusable model definitions)
|
||||
assert "components" in schema
|
||||
assert "schemas" in schema["components"]
|
||||
assert "QueryMessage" in schema["components"]["schemas"]
|
||||
query_schema = schema["components"]["schemas"]["QueryMessage"]
|
||||
assert query_schema.get("title") == "QueryMessage"
|
||||
assert "text" in query_schema.get("properties", {})
|
||||
assert "limit" in query_schema.get("properties", {})
|
||||
|
||||
# Channel messages should reference the payload schema
|
||||
channel = schema["channels"]["/query"]
|
||||
for operation_key in ("subscribe", "publish"):
|
||||
msg_spec = channel[operation_key]["message"]
|
||||
assert msg_spec["contentType"] == "application/json"
|
||||
assert "payload" in msg_spec
|
||||
assert msg_spec["payload"] == {"$ref": "#/components/schemas/QueryMessage"}
|
||||
|
||||
|
||||
def test_asyncapi_explicit_subscribe_publish_schema():
|
||||
"""Test AsyncAPI schema when websocket uses subscribe_schema and publish_schema (no Body in deps).
|
||||
|
||||
Covers: components/schemas built from explicit subscribe_schema/publish_schema ModelFields,
|
||||
and channel message payloads set from explicit subscribe_model/publish_model $refs.
|
||||
"""
|
||||
app = FastAPI(title="Test API", version="1.0.0")
|
||||
router = APIRouter()
|
||||
|
||||
class ClientMessage(BaseModel):
|
||||
"""Message the client sends."""
|
||||
|
||||
action: str
|
||||
payload: str = ""
|
||||
|
||||
class ServerMessage(BaseModel):
|
||||
"""Message the server sends."""
|
||||
|
||||
event: str
|
||||
data: dict = {}
|
||||
|
||||
@router.websocket(
|
||||
"/chat",
|
||||
subscribe_schema=ClientMessage,
|
||||
publish_schema=ServerMessage,
|
||||
)
|
||||
async def chat_ws(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
app.include_router(router)
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/chat"):
|
||||
pass
|
||||
|
||||
response = client.get("/asyncapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
schema = response.json()
|
||||
|
||||
# Components should include both models (from explicit subscribe_schema/publish_schema ModelFields)
|
||||
assert "components" in schema
|
||||
assert "schemas" in schema["components"]
|
||||
assert "ClientMessage" in schema["components"]["schemas"]
|
||||
assert "ServerMessage" in schema["components"]["schemas"]
|
||||
client_schema = schema["components"]["schemas"]["ClientMessage"]
|
||||
server_schema = schema["components"]["schemas"]["ServerMessage"]
|
||||
assert client_schema.get("title") == "ClientMessage"
|
||||
assert "action" in client_schema.get("properties", {})
|
||||
assert server_schema.get("title") == "ServerMessage"
|
||||
assert "event" in server_schema.get("properties", {})
|
||||
|
||||
# Channel subscribe/publish should use explicit $refs (subscribe_model / publish_model path)
|
||||
channel = schema["channels"]["/chat"]
|
||||
sub_msg = channel["subscribe"]["message"]
|
||||
pub_msg = channel["publish"]["message"]
|
||||
assert sub_msg["contentType"] == "application/json"
|
||||
assert sub_msg["payload"] == {"$ref": "#/components/schemas/ClientMessage"}
|
||||
assert pub_msg["contentType"] == "application/json"
|
||||
assert pub_msg["payload"] == {"$ref": "#/components/schemas/ServerMessage"}
|
||||
|
||||
|
||||
def test_asyncapi_with_root_path_in_servers():
|
||||
"""Test AsyncAPI schema includes root_path in servers when root_path_in_servers is True."""
|
||||
app = FastAPI(
|
||||
|
|
|
|||
Loading…
Reference in New Issue