match patterns from openapi

This commit is contained in:
rechain 2026-02-27 10:31:53 -05:00
parent 5040c2986c
commit 4f88800ace
9 changed files with 502 additions and 233 deletions

View File

@ -17,9 +17,9 @@ from fastapi.exception_handlers import (
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.logger import logger
from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
from fastapi.openapi.asyncapi_utils import get_asyncapi
from fastapi.asyncapi.docs import get_asyncapi_html
from fastapi.asyncapi.utils import get_asyncapi
from fastapi.openapi.docs import (
get_asyncapi_html,
get_redoc_html,
get_swagger_ui_html,
get_swagger_ui_oauth2_redirect_html,

View File

View File

@ -0,0 +1,2 @@
ASYNCAPI_VERSION = "2.6.0"
REF_PREFIX = "#/components/schemas/"

127
fastapi/asyncapi/docs.py Normal file
View File

@ -0,0 +1,127 @@
from typing import Annotated
from annotated_doc import Doc
from starlette.responses import HTMLResponse
def get_asyncapi_html(
*,
asyncapi_url: Annotated[
str,
Doc(
"""
The AsyncAPI URL that AsyncAPI Studio should load and use.
This is normally done automatically by FastAPI using the default URL
`/asyncapi.json`.
Read more about it in the
[FastAPI docs for AsyncAPI](https://fastapi.tiangolo.com/advanced/asyncapi/).
"""
),
],
title: Annotated[
str,
Doc(
"""
The HTML `<title>` content, normally shown in the browser tab.
"""
),
],
asyncapi_js_url: Annotated[
str,
Doc(
"""
The URL to use to load the AsyncAPI Studio JavaScript.
It is normally set to a CDN URL.
"""
),
] = "https://unpkg.com/@asyncapi/react-component@latest/browser/standalone/index.js",
asyncapi_favicon_url: Annotated[
str,
Doc(
"""
The URL of the favicon to use. It is normally shown in the browser tab.
"""
),
] = "https://fastapi.tiangolo.com/img/favicon.png",
docs_url: Annotated[
str | None,
Doc(
"""
The URL to the OpenAPI docs (Swagger UI) for navigation link.
"""
),
] = None,
) -> HTMLResponse:
"""
Generate and return the HTML that loads AsyncAPI Studio for the interactive
WebSocket API docs (normally served at `/asyncapi-docs`).
You would only call this function yourself if you needed to override some parts,
for example the URLs to use to load AsyncAPI Studio's JavaScript.
"""
navigation_html = ""
if docs_url:
navigation_html = f"""
<div style="padding: 10px; background-color: #f5f5f5; border-bottom: 1px solid #ddd;">
<a href="{docs_url}" style="color: #007bff; text-decoration: none; margin-right: 20px;">
📄 OpenAPI Docs (REST API)
</a>
<span style="color: #666;">WebSocket API Documentation</span>
</div>
"""
html = f"""
<!DOCTYPE html>
<html>
<head>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="shortcut icon" href="{asyncapi_favicon_url}">
<title>{title}</title>
<style>
body {{
margin: 0;
padding: 0;
}}
#asyncapi {{
height: 100vh;
width: 100%;
}}
</style>
</head>
<body>
{navigation_html}
<div id="asyncapi"></div>
<script src="{asyncapi_js_url}"></script>
<script>
(async function() {{
const asyncapiSpec = await fetch('{asyncapi_url}').then(res => res.json());
const AsyncApiStandalone = window.AsyncApiStandalone || window.AsyncAPIStandalone;
if (AsyncApiStandalone) {{
AsyncApiStandalone.render({{
schema: asyncapiSpec,
config: {{
show: {{
sidebar: true,
info: true,
servers: true,
operations: true,
messages: true,
}},
}},
}}, document.getElementById('asyncapi'));
}} else {{
document.getElementById('asyncapi').innerHTML =
'<div style="padding: 20px; text-align: center;">' +
'<h2>Failed to load AsyncAPI Studio</h2>' +
'<p>Please check your internet connection and try again.</p>' +
'</div>';
}}
}})();
</script>
</body>
</html>
"""
return HTMLResponse(html)

224
fastapi/asyncapi/utils.py Normal file
View File

@ -0,0 +1,224 @@
from collections.abc import Sequence
from typing import Any
from pydantic import BaseModel
from fastapi import routing
from fastapi.asyncapi.constants import ASYNCAPI_VERSION, REF_PREFIX
from fastapi.encoders import jsonable_encoder
from starlette.routing import BaseRoute
def get_asyncapi_channel(
*,
route: routing.APIWebSocketRoute,
subscribe_payload_schema: dict[str, Any] | None = None,
publish_payload_schema: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Generate AsyncAPI channel definition for a WebSocket route."""
channel: dict[str, Any] = {}
# WebSocket channels typically have subscribe operation
# (client subscribes to receive messages from server)
operation: dict[str, Any] = {
"operationId": route.name or f"websocket_{route.path_format}",
}
# Message schema: contentType and optional payload (schema for message body)
subscribe_message: dict[str, Any] = {
"contentType": "application/json",
}
if subscribe_payload_schema:
subscribe_message["payload"] = subscribe_payload_schema
operation["message"] = subscribe_message
channel["subscribe"] = operation
# WebSockets are bidirectional, so we also include publish
# (client can publish messages to server)
publish_operation: dict[str, Any] = {
"operationId": f"{route.name or f'websocket_{route.path_format}'}_publish",
"message": {
"contentType": "application/json",
**({"payload": publish_payload_schema} if publish_payload_schema else {}),
},
}
channel["publish"] = publish_operation
return channel
def _get_fields_from_websocket_routes(
routes: Sequence[BaseRoute],
) -> list[Any]:
"""Collect body (ModelField) params from WebSocket routes for schema generation."""
from fastapi.dependencies.utils import get_flat_dependant
from fastapi._compat import ModelField
from pydantic.fields import FieldInfo
fields: list[Any] = []
seen_models: set[type[BaseModel]] = set()
for route in routes or []:
if not isinstance(route, routing.APIWebSocketRoute):
continue
flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
fields.extend(flat_dependant.body_params)
# Add explicit subscribe_schema / publish_schema as ModelFields so they get definitions
for model in (
getattr(route, "subscribe_schema", None),
getattr(route, "publish_schema", None),
):
if model is not None and isinstance(model, type) and issubclass(model, BaseModel) and model not in seen_models:
seen_models.add(model)
fields.append(
ModelField(
field_info=FieldInfo(annotation=model),
name=model.__name__,
mode="validation",
)
)
return fields
def get_asyncapi(
*,
title: str,
version: str,
asyncapi_version: str = ASYNCAPI_VERSION,
summary: str | None = None,
description: str | None = None,
routes: Sequence[BaseRoute],
servers: list[dict[str, str | Any]] | None = None,
terms_of_service: str | None = None,
contact: dict[str, str | Any] | None = None,
license_info: dict[str, str | Any] | None = None,
external_docs: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
Generate AsyncAPI schema from FastAPI application routes.
Filters for WebSocket routes and generates AsyncAPI 2.6.0 compliant schema.
Includes components/schemas for message payloads when WebSocket routes use
Pydantic models (e.g. via Body() in dependencies).
"""
from fastapi._compat import (
ModelField,
get_definitions,
get_flat_models_from_fields,
get_model_name_map,
get_schema_from_model_field,
)
info: dict[str, Any] = {"title": title, "version": version}
if summary:
info["summary"] = summary
if description:
info["description"] = description
if terms_of_service:
info["termsOfService"] = terms_of_service
if contact:
info["contact"] = contact
if license_info:
info["license"] = license_info
output: dict[str, Any] = {"asyncapi": asyncapi_version, "info": info}
# Add default WebSocket server if no servers provided and we have WebSocket routes
websocket_routes = [
route for route in routes or [] if isinstance(route, routing.APIWebSocketRoute)
]
if websocket_routes and not servers:
# Default WebSocket server - can be overridden by providing servers parameter
output["servers"] = [
{
"url": "ws://localhost:8000",
"protocol": "ws",
"description": "WebSocket server",
}
]
elif servers:
output["servers"] = servers
# Build components/schemas from WebSocket body params and explicit subscribe/publish_schema
ws_fields = _get_fields_from_websocket_routes(routes or [])
components: dict[str, Any] = {}
route_subscribe_schemas: dict[str, dict[str, Any] | None] = {}
route_publish_schemas: dict[str, dict[str, Any] | None] = {}
if ws_fields:
flat_models = get_flat_models_from_fields(ws_fields, known_models=set())
model_name_map = get_model_name_map(flat_models)
field_mapping, definitions = get_definitions(
fields=ws_fields,
model_name_map=model_name_map,
separate_input_output_schemas=True,
)
if definitions:
components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
# For each WebSocket route, resolve subscribe and publish payload schemas
for route in routes or []:
if not isinstance(route, routing.APIWebSocketRoute):
continue
sub_schema: dict[str, Any] | None = None
pub_schema: dict[str, Any] | None = None
# Explicit subscribe_schema / publish_schema (e.g. when route has no Body() in Depends)
subscribe_model = getattr(route, "subscribe_schema", None)
publish_model = getattr(route, "publish_schema", None)
if subscribe_model is not None and isinstance(subscribe_model, type) and issubclass(subscribe_model, BaseModel):
sub_schema = {"$ref": f"{REF_PREFIX}{subscribe_model.__name__}"}
if publish_model is not None and isinstance(publish_model, type) and issubclass(publish_model, BaseModel):
pub_schema = {"$ref": f"{REF_PREFIX}{publish_model.__name__}"}
# Fall back to first body param (Depends with Body()) for both if not set
if sub_schema is None or pub_schema is None:
flat_dependant = route._flat_dependant
if flat_dependant.body_params:
first_body = flat_dependant.body_params[0]
if isinstance(first_body, ModelField):
body_schema = get_schema_from_model_field(
field=first_body,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=True,
)
# Use only $ref for channel payload when schema is in components
if "$ref" in body_schema and body_schema["$ref"].startswith(
REF_PREFIX
):
body_schema = {"$ref": body_schema["$ref"]}
if sub_schema is None:
sub_schema = body_schema
if pub_schema is None:
pub_schema = body_schema
route_subscribe_schemas[route.path_format] = sub_schema
route_publish_schemas[route.path_format] = pub_schema
else:
for route in routes or []:
if not isinstance(route, routing.APIWebSocketRoute):
continue
route_subscribe_schemas[route.path_format] = None
route_publish_schemas[route.path_format] = None
channels: dict[str, dict[str, Any]] = {}
# Filter routes to only include WebSocket routes
for route in routes or []:
if isinstance(route, routing.APIWebSocketRoute):
sub_schema = route_subscribe_schemas.get(route.path_format)
pub_schema = route_publish_schemas.get(route.path_format)
channel = get_asyncapi_channel(
route=route,
subscribe_payload_schema=sub_schema,
publish_payload_schema=pub_schema,
)
if channel:
channels[route.path_format] = channel
output["channels"] = channels
if components:
output["components"] = components
if external_docs:
output["externalDocs"] = external_docs
return jsonable_encoder(output, by_alias=True, exclude_none=True) # type: ignore

View File

@ -1,105 +0,0 @@
from collections.abc import Sequence
from typing import Any
from fastapi import routing
from fastapi.encoders import jsonable_encoder
from starlette.routing import BaseRoute
def get_asyncapi_channel(
*,
route: routing.APIWebSocketRoute,
) -> dict[str, Any]:
"""Generate AsyncAPI channel definition for a WebSocket route."""
channel: dict[str, Any] = {}
# WebSocket channels typically have subscribe operation
# (client subscribes to receive messages from server)
operation: dict[str, Any] = {
"operationId": route.name or f"websocket_{route.path_format}",
}
# Basic message schema - can be enhanced later with actual message types
# For WebSockets, messages can be sent in both directions
message: dict[str, Any] = {
"contentType": "application/json",
}
operation["message"] = message
channel["subscribe"] = operation
# WebSockets are bidirectional, so we also include publish
# (client can publish messages to server)
publish_operation: dict[str, Any] = {
"operationId": f"{route.name or f'websocket_{route.path_format}'}_publish",
"message": message,
}
channel["publish"] = publish_operation
return channel
def get_asyncapi(
*,
title: str,
version: str,
asyncapi_version: str = "2.6.0",
summary: str | None = None,
description: str | None = None,
routes: Sequence[BaseRoute],
servers: list[dict[str, str | Any]] | None = None,
terms_of_service: str | None = None,
contact: dict[str, str | Any] | None = None,
license_info: dict[str, str | Any] | None = None,
external_docs: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
Generate AsyncAPI schema from FastAPI application routes.
Filters for WebSocket routes and generates AsyncAPI 2.6.0 compliant schema.
"""
info: dict[str, Any] = {"title": title, "version": version}
if summary:
info["summary"] = summary
if description:
info["description"] = description
if terms_of_service:
info["termsOfService"] = terms_of_service
if contact:
info["contact"] = contact
if license_info:
info["license"] = license_info
output: dict[str, Any] = {"asyncapi": asyncapi_version, "info": info}
# Add default WebSocket server if no servers provided and we have WebSocket routes
websocket_routes = [
route for route in routes or [] if isinstance(route, routing.APIWebSocketRoute)
]
if websocket_routes and not servers:
# Default WebSocket server - can be overridden by providing servers parameter
output["servers"] = [
{
"url": "ws://localhost:8000",
"protocol": "ws",
"description": "WebSocket server",
}
]
elif servers:
output["servers"] = servers
channels: dict[str, dict[str, Any]] = {}
# Filter routes to only include WebSocket routes
for route in routes or []:
if isinstance(route, routing.APIWebSocketRoute):
channel = get_asyncapi_channel(route=route)
if channel:
channels[route.path_format] = channel
output["channels"] = channels
if external_docs:
output["externalDocs"] = external_docs
return jsonable_encoder(output, by_alias=True, exclude_none=True) # type: ignore

View File

@ -214,129 +214,6 @@ def get_swagger_ui_html(
return HTMLResponse(html)
def get_asyncapi_html(
*,
asyncapi_url: Annotated[
str,
Doc(
"""
The AsyncAPI URL that AsyncAPI Studio should load and use.
This is normally done automatically by FastAPI using the default URL
`/asyncapi.json`.
Read more about it in the
[FastAPI docs for AsyncAPI](https://fastapi.tiangolo.com/advanced/asyncapi/).
"""
),
],
title: Annotated[
str,
Doc(
"""
The HTML `<title>` content, normally shown in the browser tab.
"""
),
],
asyncapi_js_url: Annotated[
str,
Doc(
"""
The URL to use to load the AsyncAPI Studio JavaScript.
It is normally set to a CDN URL.
"""
),
] = "https://unpkg.com/@asyncapi/react-component@latest/browser/standalone/index.js",
asyncapi_favicon_url: Annotated[
str,
Doc(
"""
The URL of the favicon to use. It is normally shown in the browser tab.
"""
),
] = "https://fastapi.tiangolo.com/img/favicon.png",
docs_url: Annotated[
str | None,
Doc(
"""
The URL to the OpenAPI docs (Swagger UI) for navigation link.
"""
),
] = None,
) -> HTMLResponse:
"""
Generate and return the HTML that loads AsyncAPI Studio for the interactive
WebSocket API docs (normally served at `/asyncapi-docs`).
You would only call this function yourself if you needed to override some parts,
for example the URLs to use to load AsyncAPI Studio's JavaScript.
"""
navigation_html = ""
if docs_url:
navigation_html = f"""
<div style="padding: 10px; background-color: #f5f5f5; border-bottom: 1px solid #ddd;">
<a href="{docs_url}" style="color: #007bff; text-decoration: none; margin-right: 20px;">
📄 OpenAPI Docs (REST API)
</a>
<span style="color: #666;">WebSocket API Documentation</span>
</div>
"""
html = f"""
<!DOCTYPE html>
<html>
<head>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="shortcut icon" href="{asyncapi_favicon_url}">
<title>{title}</title>
<style>
body {{
margin: 0;
padding: 0;
}}
#asyncapi {{
height: 100vh;
width: 100%;
}}
</style>
</head>
<body>
{navigation_html}
<div id="asyncapi"></div>
<script src="{asyncapi_js_url}"></script>
<script>
(async function() {{
const asyncapiSpec = await fetch('{asyncapi_url}').then(res => res.json());
const AsyncApiStandalone = window.AsyncApiStandalone || window.AsyncAPIStandalone;
if (AsyncApiStandalone) {{
AsyncApiStandalone.render({{
schema: asyncapiSpec,
config: {{
show: {{
sidebar: true,
info: true,
servers: true,
operations: true,
messages: true,
}},
}},
}}, document.getElementById('asyncapi'));
}} else {{
document.getElementById('asyncapi').innerHTML =
'<div style="padding: 20px; text-align: center;">' +
'<h2>Failed to load AsyncAPI Studio</h2>' +
'<p>Please check your internet connection and try again.</p>' +
'</div>';
}}
}})();
</script>
</body>
</html>
"""
return HTMLResponse(html)
def get_redoc_html(
*,
openapi_url: Annotated[

View File

@ -541,11 +541,15 @@ class APIWebSocketRoute(routing.WebSocketRoute):
name: str | None = None,
dependencies: Sequence[params.Depends] | None = None,
dependency_overrides_provider: Any | None = None,
subscribe_schema: type[Any] | None = None,
publish_schema: type[Any] | None = None,
) -> None:
self.path = path
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
self.dependencies = list(dependencies or [])
self.subscribe_schema = subscribe_schema
self.publish_schema = publish_schema
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self.dependant = get_dependant(
path=self.path_format, call=self.endpoint, scope="function"
@ -1214,6 +1218,8 @@ class APIRouter(routing.Router):
name: str | None = None,
*,
dependencies: Sequence[params.Depends] | None = None,
subscribe_schema: type[Any] | None = None,
publish_schema: type[Any] | None = None,
) -> None:
current_dependencies = self.dependencies.copy()
if dependencies:
@ -1225,6 +1231,8 @@ class APIRouter(routing.Router):
name=name,
dependencies=current_dependencies,
dependency_overrides_provider=self.dependency_overrides_provider,
subscribe_schema=subscribe_schema,
publish_schema=publish_schema,
)
self.routes.append(route)
@ -1259,6 +1267,25 @@ class APIRouter(routing.Router):
"""
),
] = None,
subscribe_schema: Annotated[
type[Any] | None,
Doc(
"""
Pydantic model for messages the client sends (subscribe operation).
Used to generate AsyncAPI message payload schema when the route
does not use Body() in dependencies.
"""
),
] = None,
publish_schema: Annotated[
type[Any] | None,
Doc(
"""
Pydantic model for messages the server sends (publish operation).
Used to generate AsyncAPI message payload schema.
"""
),
] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
"""
Decorate a WebSocket function.
@ -1289,7 +1316,12 @@ class APIRouter(routing.Router):
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_api_websocket_route(
path, func, name=name, dependencies=dependencies
path,
func,
name=name,
dependencies=dependencies,
subscribe_schema=subscribe_schema,
publish_schema=publish_schema,
)
return func
@ -1543,6 +1575,8 @@ class APIRouter(routing.Router):
route.endpoint,
dependencies=current_dependencies,
name=route.name,
subscribe_schema=route.subscribe_schema,
publish_schema=route.publish_schema,
)
elif isinstance(route, routing.WebSocketRoute):
self.add_websocket_route(

View File

@ -1,6 +1,7 @@
from fastapi import FastAPI, WebSocket
from fastapi.openapi.asyncapi_utils import get_asyncapi, get_asyncapi_channel
from fastapi import APIRouter, Body, Depends, FastAPI, WebSocket
from fastapi.asyncapi.utils import get_asyncapi, get_asyncapi_channel
from fastapi.testclient import TestClient
from pydantic import BaseModel
def test_asyncapi_schema():
@ -480,6 +481,115 @@ def test_asyncapi_url_none_no_link_in_swagger():
assert response.status_code == 404
def test_asyncapi_components_and_message_payload():
"""Test AsyncAPI schema includes components/schemas and message payload when models are used."""
app = FastAPI(title="Test API", version="1.0.0")
class QueryMessage(BaseModel):
"""Message sent on /query channel."""
text: str
limit: int = 10
def get_query_message(
msg: QueryMessage = Body(default=QueryMessage(text="", limit=10))
) -> QueryMessage:
return msg
@app.websocket("/query")
async def query_ws(websocket: WebSocket, msg: QueryMessage = Depends(get_query_message)):
await websocket.accept()
await websocket.close()
# Connect to websocket so handler and dependency are covered (body default used)
client = TestClient(app)
with client.websocket_connect("/query"):
pass
# Generate schema and assert components
response = client.get("/asyncapi.json")
assert response.status_code == 200, response.text
schema = response.json()
# Should have components with schemas (reusable model definitions)
assert "components" in schema
assert "schemas" in schema["components"]
assert "QueryMessage" in schema["components"]["schemas"]
query_schema = schema["components"]["schemas"]["QueryMessage"]
assert query_schema.get("title") == "QueryMessage"
assert "text" in query_schema.get("properties", {})
assert "limit" in query_schema.get("properties", {})
# Channel messages should reference the payload schema
channel = schema["channels"]["/query"]
for operation_key in ("subscribe", "publish"):
msg_spec = channel[operation_key]["message"]
assert msg_spec["contentType"] == "application/json"
assert "payload" in msg_spec
assert msg_spec["payload"] == {"$ref": "#/components/schemas/QueryMessage"}
def test_asyncapi_explicit_subscribe_publish_schema():
"""Test AsyncAPI schema when websocket uses subscribe_schema and publish_schema (no Body in deps).
Covers: components/schemas built from explicit subscribe_schema/publish_schema ModelFields,
and channel message payloads set from explicit subscribe_model/publish_model $refs.
"""
app = FastAPI(title="Test API", version="1.0.0")
router = APIRouter()
class ClientMessage(BaseModel):
"""Message the client sends."""
action: str
payload: str = ""
class ServerMessage(BaseModel):
"""Message the server sends."""
event: str
data: dict = {}
@router.websocket(
"/chat",
subscribe_schema=ClientMessage,
publish_schema=ServerMessage,
)
async def chat_ws(websocket: WebSocket):
await websocket.accept()
await websocket.close()
app.include_router(router)
client = TestClient(app)
with client.websocket_connect("/chat"):
pass
response = client.get("/asyncapi.json")
assert response.status_code == 200, response.text
schema = response.json()
# Components should include both models (from explicit subscribe_schema/publish_schema ModelFields)
assert "components" in schema
assert "schemas" in schema["components"]
assert "ClientMessage" in schema["components"]["schemas"]
assert "ServerMessage" in schema["components"]["schemas"]
client_schema = schema["components"]["schemas"]["ClientMessage"]
server_schema = schema["components"]["schemas"]["ServerMessage"]
assert client_schema.get("title") == "ClientMessage"
assert "action" in client_schema.get("properties", {})
assert server_schema.get("title") == "ServerMessage"
assert "event" in server_schema.get("properties", {})
# Channel subscribe/publish should use explicit $refs (subscribe_model / publish_model path)
channel = schema["channels"]["/chat"]
sub_msg = channel["subscribe"]["message"]
pub_msg = channel["publish"]["message"]
assert sub_msg["contentType"] == "application/json"
assert sub_msg["payload"] == {"$ref": "#/components/schemas/ClientMessage"}
assert pub_msg["contentType"] == "application/json"
assert pub_msg["payload"] == {"$ref": "#/components/schemas/ServerMessage"}
def test_asyncapi_with_root_path_in_servers():
"""Test AsyncAPI schema includes root_path in servers when root_path_in_servers is True."""
app = FastAPI(