mirror of https://github.com/tiangolo/fastapi.git
Merge 59f2819799 into 0127069d47
This commit is contained in:
commit
4289cdf35e
|
|
@ -4,6 +4,8 @@ from typing import Annotated, Any, TypeVar
|
|||
|
||||
from annotated_doc import Doc
|
||||
from fastapi import routing
|
||||
from fastapi.asyncapi.docs import get_asyncapi_html
|
||||
from fastapi.asyncapi.utils import get_asyncapi
|
||||
from fastapi.datastructures import Default, DefaultPlaceholder
|
||||
from fastapi.exception_handlers import (
|
||||
http_exception_handler,
|
||||
|
|
@ -442,6 +444,49 @@ class FastAPI(Starlette):
|
|||
"""
|
||||
),
|
||||
] = "/redoc",
|
||||
asyncapi_url: Annotated[
|
||||
str | None,
|
||||
Doc(
|
||||
"""
|
||||
The URL where the AsyncAPI schema will be served from.
|
||||
|
||||
If you set it to `None`, no AsyncAPI schema will be served publicly, and
|
||||
the default automatic endpoint `/asyncapi-docs` will also be disabled.
|
||||
|
||||
AsyncAPI is used to document WebSocket endpoints, similar to how OpenAPI
|
||||
documents HTTP endpoints.
|
||||
|
||||
**Example**
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI(asyncapi_url="/api/v1/asyncapi.json")
|
||||
```
|
||||
"""
|
||||
),
|
||||
] = "/asyncapi.json",
|
||||
asyncapi_docs_url: Annotated[
|
||||
str | None,
|
||||
Doc(
|
||||
"""
|
||||
The URL where the AsyncAPI documentation UI will be served from.
|
||||
|
||||
If you set it to `None`, the AsyncAPI documentation UI will be disabled.
|
||||
|
||||
This provides an interactive UI for viewing WebSocket endpoint documentation,
|
||||
similar to how `/docs` provides Swagger UI for HTTP endpoints.
|
||||
|
||||
**Example**
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI(asyncapi_docs_url="/async-docs")
|
||||
```
|
||||
"""
|
||||
),
|
||||
] = "/asyncapi-docs",
|
||||
swagger_ui_oauth2_redirect_url: Annotated[
|
||||
str | None,
|
||||
Doc(
|
||||
|
|
@ -882,6 +927,8 @@ class FastAPI(Starlette):
|
|||
self.root_path_in_servers = root_path_in_servers
|
||||
self.docs_url = docs_url
|
||||
self.redoc_url = redoc_url
|
||||
self.asyncapi_url = asyncapi_url
|
||||
self.asyncapi_docs_url = asyncapi_docs_url
|
||||
self.swagger_ui_oauth2_redirect_url = swagger_ui_oauth2_redirect_url
|
||||
self.swagger_ui_init_oauth = swagger_ui_init_oauth
|
||||
self.swagger_ui_parameters = swagger_ui_parameters
|
||||
|
|
@ -921,9 +968,15 @@ class FastAPI(Starlette):
|
|||
),
|
||||
] = "3.1.0"
|
||||
self.openapi_schema: dict[str, Any] | None = None
|
||||
self.asyncapi_schema: dict[str, Any] | None = None
|
||||
if self.openapi_url:
|
||||
assert self.title, "A title must be provided for OpenAPI, e.g.: 'My API'"
|
||||
assert self.version, "A version must be provided for OpenAPI, e.g.: '2.1.0'"
|
||||
if self.asyncapi_url:
|
||||
assert self.title, "A title must be provided for AsyncAPI, e.g.: 'My API'"
|
||||
assert self.version, (
|
||||
"A version must be provided for AsyncAPI, e.g.: '2.1.0'"
|
||||
)
|
||||
# TODO: remove when discarding the openapi_prefix parameter
|
||||
if openapi_prefix:
|
||||
logger.warning(
|
||||
|
|
@ -1098,6 +1151,36 @@ class FastAPI(Starlette):
|
|||
)
|
||||
return self.openapi_schema
|
||||
|
||||
def asyncapi(self) -> dict[str, Any]:
|
||||
"""
|
||||
Generate the AsyncAPI schema of the application. This is called by FastAPI
|
||||
internally.
|
||||
|
||||
The first time it is called it stores the result in the attribute
|
||||
`app.asyncapi_schema`, and next times it is called, it just returns that same
|
||||
result. To avoid the cost of generating the schema every time.
|
||||
|
||||
If you need to modify the generated AsyncAPI schema, you could modify it.
|
||||
|
||||
AsyncAPI is used to document WebSocket endpoints, similar to how OpenAPI
|
||||
documents HTTP endpoints.
|
||||
"""
|
||||
if not self.asyncapi_schema:
|
||||
self.asyncapi_schema = get_asyncapi(
|
||||
title=self.title,
|
||||
version=self.version,
|
||||
asyncapi_version="2.6.0",
|
||||
summary=self.summary,
|
||||
description=self.description,
|
||||
routes=self.routes,
|
||||
servers=self.servers,
|
||||
terms_of_service=self.terms_of_service,
|
||||
contact=self.contact,
|
||||
license_info=self.license_info,
|
||||
external_docs=self.openapi_external_docs,
|
||||
)
|
||||
return self.asyncapi_schema
|
||||
|
||||
def setup(self) -> None:
|
||||
if self.openapi_url:
|
||||
|
||||
|
|
@ -1114,6 +1197,21 @@ class FastAPI(Starlette):
|
|||
return JSONResponse(schema)
|
||||
|
||||
self.add_route(self.openapi_url, openapi, include_in_schema=False)
|
||||
if self.asyncapi_url:
|
||||
|
||||
async def asyncapi(req: Request) -> JSONResponse:
|
||||
root_path = req.scope.get("root_path", "").rstrip("/")
|
||||
schema = self.asyncapi()
|
||||
if root_path and self.root_path_in_servers:
|
||||
server_urls = {s.get("url") for s in schema.get("servers", [])}
|
||||
if root_path not in server_urls:
|
||||
schema = dict(schema)
|
||||
schema["servers"] = [{"url": root_path}] + schema.get(
|
||||
"servers", []
|
||||
)
|
||||
return JSONResponse(schema)
|
||||
|
||||
self.add_route(self.asyncapi_url, asyncapi, include_in_schema=False)
|
||||
if self.openapi_url and self.docs_url:
|
||||
|
||||
async def swagger_ui_html(req: Request) -> HTMLResponse:
|
||||
|
|
@ -1122,12 +1220,16 @@ class FastAPI(Starlette):
|
|||
oauth2_redirect_url = self.swagger_ui_oauth2_redirect_url
|
||||
if oauth2_redirect_url:
|
||||
oauth2_redirect_url = root_path + oauth2_redirect_url
|
||||
asyncapi_docs_url = None
|
||||
if self.asyncapi_url and self.asyncapi_docs_url:
|
||||
asyncapi_docs_url = root_path + self.asyncapi_docs_url
|
||||
return get_swagger_ui_html(
|
||||
openapi_url=openapi_url,
|
||||
title=f"{self.title} - Swagger UI",
|
||||
oauth2_redirect_url=oauth2_redirect_url,
|
||||
init_oauth=self.swagger_ui_init_oauth,
|
||||
swagger_ui_parameters=self.swagger_ui_parameters,
|
||||
asyncapi_docs_url=asyncapi_docs_url,
|
||||
)
|
||||
|
||||
self.add_route(self.docs_url, swagger_ui_html, include_in_schema=False)
|
||||
|
|
@ -1152,6 +1254,21 @@ class FastAPI(Starlette):
|
|||
)
|
||||
|
||||
self.add_route(self.redoc_url, redoc_html, include_in_schema=False)
|
||||
if self.asyncapi_url and self.asyncapi_docs_url:
|
||||
|
||||
async def asyncapi_ui_html(req: Request) -> HTMLResponse:
|
||||
root_path = req.scope.get("root_path", "").rstrip("/")
|
||||
asyncapi_url = root_path + self.asyncapi_url
|
||||
docs_url = root_path + self.docs_url if self.docs_url else None
|
||||
return get_asyncapi_html(
|
||||
asyncapi_url=asyncapi_url,
|
||||
title=f"{self.title} - AsyncAPI",
|
||||
docs_url=docs_url,
|
||||
)
|
||||
|
||||
self.add_route(
|
||||
self.asyncapi_docs_url, asyncapi_ui_html, include_in_schema=False
|
||||
)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if self.root_path:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,2 @@
|
|||
ASYNCAPI_VERSION = "2.6.0"
|
||||
REF_PREFIX = "#/components/schemas/"
|
||||
|
|
@ -0,0 +1,127 @@
|
|||
from typing import Annotated
|
||||
|
||||
from annotated_doc import Doc
|
||||
from starlette.responses import HTMLResponse
|
||||
|
||||
|
||||
def get_asyncapi_html(
|
||||
*,
|
||||
asyncapi_url: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
"""
|
||||
The AsyncAPI URL that AsyncAPI Studio should load and use.
|
||||
|
||||
This is normally done automatically by FastAPI using the default URL
|
||||
`/asyncapi.json`.
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for AsyncAPI](https://fastapi.tiangolo.com/advanced/asyncapi/).
|
||||
"""
|
||||
),
|
||||
],
|
||||
title: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
"""
|
||||
The HTML `<title>` content, normally shown in the browser tab.
|
||||
"""
|
||||
),
|
||||
],
|
||||
asyncapi_js_url: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
"""
|
||||
The URL to use to load the AsyncAPI Studio JavaScript.
|
||||
|
||||
It is normally set to a CDN URL.
|
||||
"""
|
||||
),
|
||||
] = "https://unpkg.com/@asyncapi/react-component@latest/browser/standalone/index.js",
|
||||
asyncapi_favicon_url: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
"""
|
||||
The URL of the favicon to use. It is normally shown in the browser tab.
|
||||
"""
|
||||
),
|
||||
] = "https://fastapi.tiangolo.com/img/favicon.png",
|
||||
docs_url: Annotated[
|
||||
str | None,
|
||||
Doc(
|
||||
"""
|
||||
The URL to the OpenAPI docs (Swagger UI) for navigation link.
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
) -> HTMLResponse:
|
||||
"""
|
||||
Generate and return the HTML that loads AsyncAPI Studio for the interactive
|
||||
WebSocket API docs (normally served at `/asyncapi-docs`).
|
||||
|
||||
You would only call this function yourself if you needed to override some parts,
|
||||
for example the URLs to use to load AsyncAPI Studio's JavaScript.
|
||||
"""
|
||||
navigation_html = ""
|
||||
if docs_url:
|
||||
navigation_html = f"""
|
||||
<div style="padding: 10px; background-color: #f5f5f5; border-bottom: 1px solid #ddd;">
|
||||
<a href="{docs_url}" style="color: #007bff; text-decoration: none; margin-right: 20px;">
|
||||
📄 OpenAPI Docs (REST API)
|
||||
</a>
|
||||
<span style="color: #666;">WebSocket API Documentation</span>
|
||||
</div>
|
||||
"""
|
||||
|
||||
html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<link rel="shortcut icon" href="{asyncapi_favicon_url}">
|
||||
<title>{title}</title>
|
||||
<style>
|
||||
body {{
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}}
|
||||
#asyncapi {{
|
||||
height: 100vh;
|
||||
width: 100%;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
{navigation_html}
|
||||
<div id="asyncapi"></div>
|
||||
<script src="{asyncapi_js_url}"></script>
|
||||
<script>
|
||||
(async function() {{
|
||||
const asyncapiSpec = await fetch('{asyncapi_url}').then(res => res.json());
|
||||
const AsyncApiStandalone = window.AsyncApiStandalone || window.AsyncAPIStandalone;
|
||||
if (AsyncApiStandalone) {{
|
||||
AsyncApiStandalone.render({{
|
||||
schema: asyncapiSpec,
|
||||
config: {{
|
||||
show: {{
|
||||
sidebar: true,
|
||||
info: true,
|
||||
servers: true,
|
||||
operations: true,
|
||||
messages: true,
|
||||
}},
|
||||
}},
|
||||
}}, document.getElementById('asyncapi'));
|
||||
}} else {{
|
||||
document.getElementById('asyncapi').innerHTML =
|
||||
'<div style="padding: 20px; text-align: center;">' +
|
||||
'<h2>Failed to load AsyncAPI Studio</h2>' +
|
||||
'<p>Please check your internet connection and try again.</p>' +
|
||||
'</div>';
|
||||
}}
|
||||
}})();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return HTMLResponse(html)
|
||||
|
|
@ -0,0 +1,235 @@
|
|||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from fastapi import routing
|
||||
from fastapi.asyncapi.constants import ASYNCAPI_VERSION, REF_PREFIX
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from starlette.routing import BaseRoute
|
||||
|
||||
|
||||
def get_asyncapi_channel(
|
||||
*,
|
||||
route: routing.APIWebSocketRoute,
|
||||
subscribe_payload_schema: dict[str, Any] | None = None,
|
||||
publish_payload_schema: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Generate AsyncAPI channel definition for a WebSocket route."""
|
||||
channel: dict[str, Any] = {}
|
||||
|
||||
# WebSocket channels typically have subscribe operation
|
||||
# (client subscribes to receive messages from server)
|
||||
operation: dict[str, Any] = {
|
||||
"operationId": route.name or f"websocket_{route.path_format}",
|
||||
}
|
||||
|
||||
# Message schema: contentType and optional payload (schema for message body)
|
||||
subscribe_message: dict[str, Any] = {
|
||||
"contentType": "application/json",
|
||||
}
|
||||
if subscribe_payload_schema:
|
||||
subscribe_message["payload"] = subscribe_payload_schema
|
||||
|
||||
operation["message"] = subscribe_message
|
||||
channel["subscribe"] = operation
|
||||
|
||||
# WebSockets are bidirectional, so we also include publish
|
||||
# (client can publish messages to server)
|
||||
publish_operation: dict[str, Any] = {
|
||||
"operationId": f"{route.name or f'websocket_{route.path_format}'}_publish",
|
||||
"message": {
|
||||
"contentType": "application/json",
|
||||
**({"payload": publish_payload_schema} if publish_payload_schema else {}),
|
||||
},
|
||||
}
|
||||
channel["publish"] = publish_operation
|
||||
|
||||
return channel
|
||||
|
||||
|
||||
def _get_fields_from_websocket_routes(
|
||||
routes: Sequence[BaseRoute],
|
||||
) -> list[Any]:
|
||||
"""Collect body (ModelField) params from WebSocket routes for schema generation."""
|
||||
from fastapi._compat import ModelField
|
||||
from fastapi.dependencies.utils import get_flat_dependant
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
fields: list[Any] = []
|
||||
seen_models: set[type[BaseModel]] = set()
|
||||
for route in routes or []:
|
||||
if not isinstance(route, routing.APIWebSocketRoute):
|
||||
continue
|
||||
flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
|
||||
fields.extend(flat_dependant.body_params)
|
||||
# Add explicit subscribe_schema / publish_schema as ModelFields so they get definitions
|
||||
for model in (
|
||||
getattr(route, "subscribe_schema", None),
|
||||
getattr(route, "publish_schema", None),
|
||||
):
|
||||
if (
|
||||
model is not None
|
||||
and isinstance(model, type)
|
||||
and issubclass(model, BaseModel)
|
||||
and model not in seen_models
|
||||
):
|
||||
seen_models.add(model)
|
||||
fields.append(
|
||||
ModelField(
|
||||
field_info=FieldInfo(annotation=model),
|
||||
name=model.__name__,
|
||||
mode="validation",
|
||||
)
|
||||
)
|
||||
return fields
|
||||
|
||||
|
||||
def get_asyncapi(
|
||||
*,
|
||||
title: str,
|
||||
version: str,
|
||||
asyncapi_version: str = ASYNCAPI_VERSION,
|
||||
summary: str | None = None,
|
||||
description: str | None = None,
|
||||
routes: Sequence[BaseRoute],
|
||||
servers: list[dict[str, str | Any]] | None = None,
|
||||
terms_of_service: str | None = None,
|
||||
contact: dict[str, str | Any] | None = None,
|
||||
license_info: dict[str, str | Any] | None = None,
|
||||
external_docs: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate AsyncAPI schema from FastAPI application routes.
|
||||
|
||||
Filters for WebSocket routes and generates AsyncAPI 2.6.0 compliant schema.
|
||||
Includes components/schemas for message payloads when WebSocket routes use
|
||||
Pydantic models (e.g. via Body() in dependencies).
|
||||
"""
|
||||
from fastapi._compat import (
|
||||
ModelField,
|
||||
get_definitions,
|
||||
get_flat_models_from_fields,
|
||||
get_model_name_map,
|
||||
get_schema_from_model_field,
|
||||
)
|
||||
|
||||
info: dict[str, Any] = {"title": title, "version": version}
|
||||
if summary:
|
||||
info["summary"] = summary
|
||||
if description:
|
||||
info["description"] = description
|
||||
if terms_of_service:
|
||||
info["termsOfService"] = terms_of_service
|
||||
if contact:
|
||||
info["contact"] = contact
|
||||
if license_info:
|
||||
info["license"] = license_info
|
||||
|
||||
output: dict[str, Any] = {"asyncapi": asyncapi_version, "info": info}
|
||||
|
||||
# Add default WebSocket server if no servers provided and we have WebSocket routes
|
||||
websocket_routes = [
|
||||
route for route in routes or [] if isinstance(route, routing.APIWebSocketRoute)
|
||||
]
|
||||
if websocket_routes and not servers:
|
||||
# Default WebSocket server - can be overridden by providing servers parameter
|
||||
output["servers"] = [
|
||||
{
|
||||
"url": "ws://localhost:8000",
|
||||
"protocol": "ws",
|
||||
"description": "WebSocket server",
|
||||
}
|
||||
]
|
||||
elif servers:
|
||||
output["servers"] = servers
|
||||
|
||||
# Build components/schemas from WebSocket body params and explicit subscribe/publish_schema
|
||||
ws_fields = _get_fields_from_websocket_routes(routes or [])
|
||||
components: dict[str, Any] = {}
|
||||
route_subscribe_schemas: dict[str, dict[str, Any] | None] = {}
|
||||
route_publish_schemas: dict[str, dict[str, Any] | None] = {}
|
||||
if ws_fields:
|
||||
flat_models = get_flat_models_from_fields(ws_fields, known_models=set())
|
||||
model_name_map = get_model_name_map(flat_models)
|
||||
field_mapping, definitions = get_definitions(
|
||||
fields=ws_fields,
|
||||
model_name_map=model_name_map,
|
||||
separate_input_output_schemas=True,
|
||||
)
|
||||
if definitions:
|
||||
components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
|
||||
# For each WebSocket route, resolve subscribe and publish payload schemas
|
||||
for route in routes or []:
|
||||
if not isinstance(route, routing.APIWebSocketRoute):
|
||||
continue
|
||||
sub_schema: dict[str, Any] | None = None
|
||||
pub_schema: dict[str, Any] | None = None
|
||||
# Explicit subscribe_schema / publish_schema (e.g. when route has no Body() in Depends)
|
||||
subscribe_model = getattr(route, "subscribe_schema", None)
|
||||
publish_model = getattr(route, "publish_schema", None)
|
||||
if (
|
||||
subscribe_model is not None
|
||||
and isinstance(subscribe_model, type)
|
||||
and issubclass(subscribe_model, BaseModel)
|
||||
):
|
||||
sub_schema = {"$ref": f"{REF_PREFIX}{subscribe_model.__name__}"}
|
||||
if (
|
||||
publish_model is not None
|
||||
and isinstance(publish_model, type)
|
||||
and issubclass(publish_model, BaseModel)
|
||||
):
|
||||
pub_schema = {"$ref": f"{REF_PREFIX}{publish_model.__name__}"}
|
||||
# Fall back to first body param (Depends with Body()) for both if not set
|
||||
if sub_schema is None or pub_schema is None:
|
||||
flat_dependant = route._flat_dependant
|
||||
if flat_dependant.body_params:
|
||||
first_body = flat_dependant.body_params[0]
|
||||
if isinstance(first_body, ModelField):
|
||||
body_schema = get_schema_from_model_field(
|
||||
field=first_body,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=True,
|
||||
)
|
||||
# Use only $ref for channel payload when schema is in components
|
||||
if "$ref" in body_schema and body_schema["$ref"].startswith(
|
||||
REF_PREFIX
|
||||
):
|
||||
body_schema = {"$ref": body_schema["$ref"]}
|
||||
if sub_schema is None:
|
||||
sub_schema = body_schema
|
||||
if pub_schema is None:
|
||||
pub_schema = body_schema
|
||||
route_subscribe_schemas[route.path_format] = sub_schema
|
||||
route_publish_schemas[route.path_format] = pub_schema
|
||||
else:
|
||||
for route in routes or []:
|
||||
if not isinstance(route, routing.APIWebSocketRoute):
|
||||
continue
|
||||
route_subscribe_schemas[route.path_format] = None
|
||||
route_publish_schemas[route.path_format] = None
|
||||
|
||||
channels: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# Filter routes to only include WebSocket routes
|
||||
for route in routes or []:
|
||||
if isinstance(route, routing.APIWebSocketRoute):
|
||||
sub_schema = route_subscribe_schemas.get(route.path_format)
|
||||
pub_schema = route_publish_schemas.get(route.path_format)
|
||||
channel = get_asyncapi_channel(
|
||||
route=route,
|
||||
subscribe_payload_schema=sub_schema,
|
||||
publish_payload_schema=pub_schema,
|
||||
)
|
||||
if channel:
|
||||
channels[route.path_format] = channel
|
||||
|
||||
output["channels"] = channels
|
||||
|
||||
if components:
|
||||
output["components"] = components
|
||||
|
||||
if external_docs:
|
||||
output["externalDocs"] = external_docs
|
||||
|
||||
return jsonable_encoder(output, by_alias=True, exclude_none=True) # type: ignore
|
||||
|
|
@ -133,6 +133,14 @@ def get_swagger_ui_html(
|
|||
"""
|
||||
),
|
||||
] = None,
|
||||
asyncapi_docs_url: Annotated[
|
||||
str | None,
|
||||
Doc(
|
||||
"""
|
||||
The URL to the AsyncAPI docs for navigation link.
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
) -> HTMLResponse:
|
||||
"""
|
||||
Generate and return the HTML that loads Swagger UI for the interactive
|
||||
|
|
@ -149,6 +157,17 @@ def get_swagger_ui_html(
|
|||
if swagger_ui_parameters:
|
||||
current_swagger_ui_parameters.update(swagger_ui_parameters)
|
||||
|
||||
navigation_html = ""
|
||||
if asyncapi_docs_url:
|
||||
navigation_html = f"""
|
||||
<div style="padding: 10px; background-color: #f5f5f5; border-bottom: 1px solid #ddd;">
|
||||
<span style="color: #666;">REST API Documentation</span>
|
||||
<a href="{asyncapi_docs_url}" style="color: #007bff; text-decoration: none; margin-left: 20px;">
|
||||
🔌 AsyncAPI Docs (WebSocket API)
|
||||
</a>
|
||||
</div>
|
||||
"""
|
||||
|
||||
html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
|
|
@ -159,6 +178,7 @@ def get_swagger_ui_html(
|
|||
<title>{title}</title>
|
||||
</head>
|
||||
<body>
|
||||
{navigation_html}
|
||||
<div id="swagger-ui">
|
||||
</div>
|
||||
<script src="{swagger_js_url}"></script>
|
||||
|
|
|
|||
|
|
@ -775,11 +775,15 @@ class APIWebSocketRoute(routing.WebSocketRoute):
|
|||
name: str | None = None,
|
||||
dependencies: Sequence[params.Depends] | None = None,
|
||||
dependency_overrides_provider: Any | None = None,
|
||||
subscribe_schema: type[Any] | None = None,
|
||||
publish_schema: type[Any] | None = None,
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.endpoint = endpoint
|
||||
self.name = get_name(endpoint) if name is None else name
|
||||
self.dependencies = list(dependencies or [])
|
||||
self.subscribe_schema = subscribe_schema
|
||||
self.publish_schema = publish_schema
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
self.dependant = get_dependant(
|
||||
path=self.path_format, call=self.endpoint, scope="function"
|
||||
|
|
@ -1485,6 +1489,8 @@ class APIRouter(routing.Router):
|
|||
name: str | None = None,
|
||||
*,
|
||||
dependencies: Sequence[params.Depends] | None = None,
|
||||
subscribe_schema: type[Any] | None = None,
|
||||
publish_schema: type[Any] | None = None,
|
||||
) -> None:
|
||||
current_dependencies = self.dependencies.copy()
|
||||
if dependencies:
|
||||
|
|
@ -1496,6 +1502,8 @@ class APIRouter(routing.Router):
|
|||
name=name,
|
||||
dependencies=current_dependencies,
|
||||
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||
subscribe_schema=subscribe_schema,
|
||||
publish_schema=publish_schema,
|
||||
)
|
||||
self.routes.append(route)
|
||||
|
||||
|
|
@ -1530,6 +1538,25 @@ class APIRouter(routing.Router):
|
|||
"""
|
||||
),
|
||||
] = None,
|
||||
subscribe_schema: Annotated[
|
||||
type[Any] | None,
|
||||
Doc(
|
||||
"""
|
||||
Pydantic model for messages the client sends (subscribe operation).
|
||||
Used to generate AsyncAPI message payload schema when the route
|
||||
does not use Body() in dependencies.
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
publish_schema: Annotated[
|
||||
type[Any] | None,
|
||||
Doc(
|
||||
"""
|
||||
Pydantic model for messages the server sends (publish operation).
|
||||
Used to generate AsyncAPI message payload schema.
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
"""
|
||||
Decorate a WebSocket function.
|
||||
|
|
@ -1560,7 +1587,12 @@ class APIRouter(routing.Router):
|
|||
|
||||
def decorator(func: DecoratedCallable) -> DecoratedCallable:
|
||||
self.add_api_websocket_route(
|
||||
path, func, name=name, dependencies=dependencies
|
||||
path,
|
||||
func,
|
||||
name=name,
|
||||
dependencies=dependencies,
|
||||
subscribe_schema=subscribe_schema,
|
||||
publish_schema=publish_schema,
|
||||
)
|
||||
return func
|
||||
|
||||
|
|
@ -1814,6 +1846,8 @@ class APIRouter(routing.Router):
|
|||
route.endpoint,
|
||||
dependencies=current_dependencies,
|
||||
name=route.name,
|
||||
subscribe_schema=route.subscribe_schema,
|
||||
publish_schema=route.publish_schema,
|
||||
)
|
||||
elif isinstance(route, routing.WebSocketRoute):
|
||||
self.add_websocket_route(
|
||||
|
|
|
|||
|
|
@ -254,6 +254,9 @@ omit = [
|
|||
"docs_src/response_model/tutorial003_04_py310.py",
|
||||
"docs_src/dependencies/tutorial013_an_py310.py", # temporary code example?
|
||||
"docs_src/dependencies/tutorial014_an_py310.py", # temporary code example?
|
||||
# Only run (and cover) on Python 3.14+
|
||||
"docs_src/dependencies/tutorial008_an_py310.py",
|
||||
"tests/test_stringified_annotation_dependency_py314.py",
|
||||
# Pydantic v1 migration, no longer tested
|
||||
"docs_src/pydantic_v1_in_v2/tutorial001_an_py310.py",
|
||||
"docs_src/pydantic_v1_in_v2/tutorial002_an_py310.py",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,618 @@
|
|||
from fastapi import APIRouter, Body, Depends, FastAPI, WebSocket
|
||||
from fastapi.asyncapi.utils import get_asyncapi, get_asyncapi_channel
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def test_asyncapi_schema():
|
||||
"""Test AsyncAPI schema endpoint with WebSocket routes."""
|
||||
app = FastAPI(title="Test API", version="1.0.0")
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
@app.websocket("/ws/{item_id}")
|
||||
async def websocket_with_param(websocket: WebSocket, item_id: str):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/ws"):
|
||||
pass
|
||||
with client.websocket_connect("/ws/foo"):
|
||||
pass
|
||||
response = client.get("/asyncapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
schema = response.json()
|
||||
assert schema["asyncapi"] == "2.6.0"
|
||||
assert schema["info"]["title"] == "Test API"
|
||||
assert schema["info"]["version"] == "1.0.0"
|
||||
assert "channels" in schema
|
||||
assert "/ws" in schema["channels"]
|
||||
assert "/ws/{item_id}" in schema["channels"]
|
||||
|
||||
|
||||
def test_asyncapi_no_websockets():
|
||||
"""Test AsyncAPI schema with no WebSocket routes."""
|
||||
app = FastAPI(title="Test API", version="1.0.0")
|
||||
|
||||
@app.get("/")
|
||||
def read_root():
|
||||
return {"message": "Hello World"}
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Hello World"}
|
||||
response = client.get("/asyncapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
schema = response.json()
|
||||
assert schema["asyncapi"] == "2.6.0"
|
||||
assert schema["info"]["title"] == "Test API"
|
||||
assert schema["channels"] == {}
|
||||
|
||||
|
||||
def test_asyncapi_caching():
|
||||
"""Test that AsyncAPI schema is cached."""
|
||||
app = FastAPI(title="Test API", version="1.0.0")
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/ws"):
|
||||
pass
|
||||
schema1 = app.asyncapi()
|
||||
schema2 = app.asyncapi()
|
||||
# Should return the same object (identity check)
|
||||
assert schema1 is schema2
|
||||
|
||||
|
||||
def test_asyncapi_ui():
|
||||
"""Test AsyncAPI UI endpoint."""
|
||||
app = FastAPI(title="Test API", version="1.0.0")
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/ws"):
|
||||
pass
|
||||
response = client.get("/asyncapi-docs")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.headers["content-type"] == "text/html; charset=utf-8"
|
||||
assert "@asyncapi/react-component" in response.text
|
||||
assert "/asyncapi.json" in response.text
|
||||
|
||||
|
||||
def test_asyncapi_ui_navigation():
|
||||
"""Test navigation links in AsyncAPI UI."""
|
||||
app = FastAPI(title="Test API", version="1.0.0")
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/ws"):
|
||||
pass
|
||||
response = client.get("/asyncapi-docs")
|
||||
assert response.status_code == 200, response.text
|
||||
# Should contain link to OpenAPI docs
|
||||
assert "/docs" in response.text
|
||||
assert "OpenAPI Docs" in response.text
|
||||
|
||||
|
||||
def test_swagger_ui_asyncapi_navigation():
|
||||
"""Test navigation link to AsyncAPI in Swagger UI."""
|
||||
app = FastAPI(title="Test API", version="1.0.0")
|
||||
|
||||
@app.get("/")
|
||||
def read_root():
|
||||
return {"message": "Hello World"}
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Hello World"}
|
||||
with client.websocket_connect("/ws"):
|
||||
pass
|
||||
response = client.get("/docs")
|
||||
assert response.status_code == 200, response.text
|
||||
# Should contain link to AsyncAPI docs
|
||||
assert "/asyncapi-docs" in response.text
|
||||
assert "AsyncAPI Docs" in response.text
|
||||
|
||||
|
||||
def test_asyncapi_custom_urls():
|
||||
"""Test custom AsyncAPI URLs."""
|
||||
app = FastAPI(
|
||||
title="Test API",
|
||||
version="1.0.0",
|
||||
asyncapi_url="/custom/asyncapi.json",
|
||||
asyncapi_docs_url="/custom/asyncapi-docs",
|
||||
)
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/ws"):
|
||||
pass
|
||||
# Test custom JSON endpoint
|
||||
response = client.get("/custom/asyncapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
schema = response.json()
|
||||
assert schema["asyncapi"] == "2.6.0"
|
||||
|
||||
# Test custom UI endpoint
|
||||
response = client.get("/custom/asyncapi-docs")
|
||||
assert response.status_code == 200, response.text
|
||||
assert "/custom/asyncapi.json" in response.text
|
||||
|
||||
# Default endpoints should not exist
|
||||
response = client.get("/asyncapi.json")
|
||||
assert response.status_code == 404
|
||||
response = client.get("/asyncapi-docs")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_asyncapi_disabled():
|
||||
"""Test when AsyncAPI is disabled."""
|
||||
app = FastAPI(
|
||||
title="Test API",
|
||||
version="1.0.0",
|
||||
asyncapi_url=None,
|
||||
)
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/ws"):
|
||||
pass
|
||||
# Endpoints should return 404
|
||||
response = client.get("/asyncapi.json")
|
||||
assert response.status_code == 404
|
||||
response = client.get("/asyncapi-docs")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_asyncapi_channel_structure():
|
||||
"""Test AsyncAPI channel structure."""
|
||||
app = FastAPI(title="Test API", version="1.0.0")
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/ws"):
|
||||
pass
|
||||
response = client.get("/asyncapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
schema = response.json()
|
||||
channel = schema["channels"]["/ws"]
|
||||
assert "subscribe" in channel
|
||||
assert "operationId" in channel["subscribe"]
|
||||
assert "message" in channel["subscribe"]
|
||||
|
||||
|
||||
def test_asyncapi_multiple_websockets():
|
||||
"""Test AsyncAPI with multiple WebSocket routes."""
|
||||
app = FastAPI(title="Test API", version="1.0.0")
|
||||
|
||||
@app.websocket("/ws1")
|
||||
async def websocket1(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
@app.websocket("/ws2")
|
||||
async def websocket2(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
@app.websocket("/ws3/{param}")
|
||||
async def websocket3(websocket: WebSocket, param: str):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/ws1"):
|
||||
pass
|
||||
with client.websocket_connect("/ws2"):
|
||||
pass
|
||||
with client.websocket_connect("/ws3/bar"):
|
||||
pass
|
||||
response = client.get("/asyncapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
schema = response.json()
|
||||
assert len(schema["channels"]) == 3
|
||||
assert "/ws1" in schema["channels"]
|
||||
assert "/ws2" in schema["channels"]
|
||||
assert "/ws3/{param}" in schema["channels"]
|
||||
|
||||
|
||||
def test_asyncapi_with_metadata():
|
||||
"""Test AsyncAPI schema includes app metadata."""
|
||||
app = FastAPI(
|
||||
title="My API",
|
||||
version="2.0.0",
|
||||
summary="Test summary",
|
||||
description="Test description",
|
||||
)
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/ws"):
|
||||
pass
|
||||
response = client.get("/asyncapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
schema = response.json()
|
||||
assert schema["info"]["title"] == "My API"
|
||||
assert schema["info"]["version"] == "2.0.0"
|
||||
assert schema["info"]["summary"] == "Test summary"
|
||||
assert schema["info"]["description"] == "Test description"
|
||||
|
||||
|
||||
def test_asyncapi_ui_no_docs_url():
|
||||
"""Test AsyncAPI UI when docs_url is None."""
|
||||
app = FastAPI(
|
||||
title="Test API",
|
||||
version="1.0.0",
|
||||
docs_url=None,
|
||||
)
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/ws"):
|
||||
pass
|
||||
response = client.get("/asyncapi-docs")
|
||||
assert response.status_code == 200, response.text
|
||||
# Should not contain link to /docs if docs_url is None
|
||||
# But navigation should still work (just won't show the link)
|
||||
assert "/asyncapi.json" in response.text
|
||||
|
||||
|
||||
def test_asyncapi_with_servers():
|
||||
"""Test AsyncAPI schema with custom servers."""
|
||||
app = FastAPI(
|
||||
title="Test API",
|
||||
version="1.0.0",
|
||||
servers=[{"url": "wss://example.com", "protocol": "wss"}],
|
||||
)
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/ws"):
|
||||
pass
|
||||
response = client.get("/asyncapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
schema = response.json()
|
||||
assert "servers" in schema
|
||||
assert schema["servers"] == [{"url": "wss://example.com", "protocol": "wss"}]
|
||||
|
||||
|
||||
def test_asyncapi_with_all_metadata():
|
||||
"""Test AsyncAPI schema with all optional metadata fields."""
|
||||
app = FastAPI(
|
||||
title="Test API",
|
||||
version="1.0.0",
|
||||
summary="Test summary",
|
||||
description="Test description",
|
||||
terms_of_service="https://example.com/terms",
|
||||
contact={"name": "API Support", "email": "support@example.com"},
|
||||
license_info={"name": "MIT", "url": "https://opensource.org/licenses/MIT"},
|
||||
)
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/ws"):
|
||||
pass
|
||||
response = client.get("/asyncapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
schema = response.json()
|
||||
assert schema["info"]["summary"] == "Test summary"
|
||||
assert schema["info"]["description"] == "Test description"
|
||||
assert schema["info"]["termsOfService"] == "https://example.com/terms"
|
||||
assert schema["info"]["contact"] == {
|
||||
"name": "API Support",
|
||||
"email": "support@example.com",
|
||||
}
|
||||
assert schema["info"]["license"] == {
|
||||
"name": "MIT",
|
||||
"url": "https://opensource.org/licenses/MIT",
|
||||
}
|
||||
|
||||
|
||||
def test_asyncapi_with_external_docs():
|
||||
"""Test AsyncAPI schema with external documentation."""
|
||||
app = FastAPI(
|
||||
title="Test API",
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
# Set external_docs after app creation
|
||||
app.openapi_external_docs = {
|
||||
"description": "External API documentation",
|
||||
"url": "https://docs.example.com",
|
||||
}
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/ws"):
|
||||
pass
|
||||
response = client.get("/asyncapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
schema = response.json()
|
||||
assert "externalDocs" in schema
|
||||
assert schema["externalDocs"] == {
|
||||
"description": "External API documentation",
|
||||
"url": "https://docs.example.com",
|
||||
}
|
||||
|
||||
|
||||
def test_asyncapi_channel_with_route_name():
|
||||
"""Test AsyncAPI channel with named route."""
|
||||
app = FastAPI(title="Test API", version="1.0.0")
|
||||
|
||||
@app.websocket("/ws", name="my_websocket")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/ws"):
|
||||
pass
|
||||
response = client.get("/asyncapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
schema = response.json()
|
||||
channel = schema["channels"]["/ws"]
|
||||
assert channel["subscribe"]["operationId"] == "my_websocket"
|
||||
assert channel["publish"]["operationId"] == "my_websocket_publish"
|
||||
|
||||
|
||||
def test_get_asyncapi_channel_direct():
|
||||
"""Test get_asyncapi_channel function directly."""
|
||||
from fastapi import routing
|
||||
|
||||
app = FastAPI(title="Test API", version="1.0.0")
|
||||
|
||||
@app.websocket("/ws", name="test_ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/ws"):
|
||||
pass
|
||||
# Get the route from the app
|
||||
route = next(r for r in app.routes if isinstance(r, routing.APIWebSocketRoute))
|
||||
channel = get_asyncapi_channel(route=route)
|
||||
assert "subscribe" in channel
|
||||
assert "publish" in channel
|
||||
assert channel["subscribe"]["operationId"] == "test_ws"
|
||||
assert channel["publish"]["operationId"] == "test_ws_publish"
|
||||
|
||||
|
||||
def test_get_asyncapi_direct():
|
||||
"""Test get_asyncapi function directly."""
|
||||
app = FastAPI(title="Test API", version="1.0.0")
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/ws"):
|
||||
pass
|
||||
schema = get_asyncapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
routes=app.routes,
|
||||
)
|
||||
assert schema["asyncapi"] == "2.6.0"
|
||||
assert schema["info"]["title"] == "Test API"
|
||||
assert "/ws" in schema["channels"]
|
||||
|
||||
|
||||
def test_asyncapi_url_none_no_link_in_swagger():
|
||||
"""Test that Swagger UI doesn't show AsyncAPI link when asyncapi_url is None."""
|
||||
app = FastAPI(
|
||||
title="Test API",
|
||||
version="1.0.0",
|
||||
asyncapi_url=None, # Explicitly disabled
|
||||
# asyncapi_docs_url defaults to "/asyncapi-docs"
|
||||
)
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/ws"):
|
||||
pass
|
||||
# Swagger UI should not show AsyncAPI link when asyncapi_url is None
|
||||
response = client.get("/docs")
|
||||
assert response.status_code == 200, response.text
|
||||
assert "/asyncapi-docs" not in response.text
|
||||
|
||||
# AsyncAPI endpoint should not exist
|
||||
response = client.get("/asyncapi-docs")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_asyncapi_components_and_message_payload():
|
||||
"""Test AsyncAPI schema includes components/schemas and message payload when models are used."""
|
||||
app = FastAPI(title="Test API", version="1.0.0")
|
||||
|
||||
class QueryMessage(BaseModel):
|
||||
"""Message sent on /query channel."""
|
||||
|
||||
text: str
|
||||
limit: int = 10
|
||||
|
||||
def get_query_message(
|
||||
msg: QueryMessage = Body(default=QueryMessage(text="", limit=10)),
|
||||
) -> QueryMessage:
|
||||
return msg
|
||||
|
||||
@app.websocket("/query")
|
||||
async def query_ws(
|
||||
websocket: WebSocket, msg: QueryMessage = Depends(get_query_message)
|
||||
):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
# Connect to websocket so handler and dependency are covered (body default used)
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/query"):
|
||||
pass
|
||||
|
||||
# Generate schema and assert components
|
||||
response = client.get("/asyncapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
schema = response.json()
|
||||
|
||||
# Should have components with schemas (reusable model definitions)
|
||||
assert "components" in schema
|
||||
assert "schemas" in schema["components"]
|
||||
assert "QueryMessage" in schema["components"]["schemas"]
|
||||
query_schema = schema["components"]["schemas"]["QueryMessage"]
|
||||
assert query_schema.get("title") == "QueryMessage"
|
||||
assert "text" in query_schema.get("properties", {})
|
||||
assert "limit" in query_schema.get("properties", {})
|
||||
|
||||
# Channel messages should reference the payload schema
|
||||
channel = schema["channels"]["/query"]
|
||||
for operation_key in ("subscribe", "publish"):
|
||||
msg_spec = channel[operation_key]["message"]
|
||||
assert msg_spec["contentType"] == "application/json"
|
||||
assert "payload" in msg_spec
|
||||
assert msg_spec["payload"] == {"$ref": "#/components/schemas/QueryMessage"}
|
||||
|
||||
|
||||
def test_asyncapi_explicit_subscribe_publish_schema():
|
||||
"""Test AsyncAPI schema when websocket uses subscribe_schema and publish_schema (no Body in deps).
|
||||
|
||||
Covers: components/schemas built from explicit subscribe_schema/publish_schema ModelFields,
|
||||
and channel message payloads set from explicit subscribe_model/publish_model $refs.
|
||||
"""
|
||||
app = FastAPI(title="Test API", version="1.0.0")
|
||||
router = APIRouter()
|
||||
|
||||
class ClientMessage(BaseModel):
|
||||
"""Message the client sends."""
|
||||
|
||||
action: str
|
||||
payload: str = ""
|
||||
|
||||
class ServerMessage(BaseModel):
|
||||
"""Message the server sends."""
|
||||
|
||||
event: str
|
||||
data: dict = {}
|
||||
|
||||
@router.websocket(
|
||||
"/chat",
|
||||
subscribe_schema=ClientMessage,
|
||||
publish_schema=ServerMessage,
|
||||
)
|
||||
async def chat_ws(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
app.include_router(router)
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/chat"):
|
||||
pass
|
||||
|
||||
response = client.get("/asyncapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
schema = response.json()
|
||||
|
||||
# Components should include both models (from explicit subscribe_schema/publish_schema ModelFields)
|
||||
assert "components" in schema
|
||||
assert "schemas" in schema["components"]
|
||||
assert "ClientMessage" in schema["components"]["schemas"]
|
||||
assert "ServerMessage" in schema["components"]["schemas"]
|
||||
client_schema = schema["components"]["schemas"]["ClientMessage"]
|
||||
server_schema = schema["components"]["schemas"]["ServerMessage"]
|
||||
assert client_schema.get("title") == "ClientMessage"
|
||||
assert "action" in client_schema.get("properties", {})
|
||||
assert server_schema.get("title") == "ServerMessage"
|
||||
assert "event" in server_schema.get("properties", {})
|
||||
|
||||
# Channel subscribe/publish should use explicit $refs (subscribe_model / publish_model path)
|
||||
channel = schema["channels"]["/chat"]
|
||||
sub_msg = channel["subscribe"]["message"]
|
||||
pub_msg = channel["publish"]["message"]
|
||||
assert sub_msg["contentType"] == "application/json"
|
||||
assert sub_msg["payload"] == {"$ref": "#/components/schemas/ClientMessage"}
|
||||
assert pub_msg["contentType"] == "application/json"
|
||||
assert pub_msg["payload"] == {"$ref": "#/components/schemas/ServerMessage"}
|
||||
|
||||
|
||||
def test_asyncapi_with_root_path_in_servers():
|
||||
"""Test AsyncAPI schema includes root_path in servers when root_path_in_servers is True."""
|
||||
app = FastAPI(
|
||||
title="Test API",
|
||||
version="1.0.0",
|
||||
root_path_in_servers=True,
|
||||
)
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.close()
|
||||
|
||||
# Use TestClient with root_path to trigger the root_path logic
|
||||
client = TestClient(app, root_path="/api/v1")
|
||||
with client.websocket_connect("/ws"):
|
||||
pass
|
||||
response = client.get("/asyncapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
schema = response.json()
|
||||
assert "servers" in schema
|
||||
# Root path should be added to servers
|
||||
server_urls = [s["url"] for s in schema["servers"]]
|
||||
assert "/api/v1" in server_urls
|
||||
|
|
@ -1,4 +1,9 @@
|
|||
from fastapi.dependencies.utils import get_typed_annotation
|
||||
import inspect
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi.dependencies.utils import get_typed_annotation, get_typed_signature
|
||||
|
||||
|
||||
def test_get_typed_annotation():
|
||||
|
|
@ -6,3 +11,31 @@ def test_get_typed_annotation():
|
|||
annotation = "None"
|
||||
typed_annotation = get_typed_annotation(annotation, globals())
|
||||
assert typed_annotation is None
|
||||
|
||||
|
||||
def test_get_signature_nameerror_py314_branch():
|
||||
"""Cover _get_signature NameError branch with Python 3.14+ annotation_format path."""
|
||||
real_signature = inspect.signature
|
||||
|
||||
def mock_signature(call, *args, **kwargs):
|
||||
if kwargs.get("eval_str") is True:
|
||||
raise NameError("undefined name")
|
||||
# On Python < 3.14, inspect.signature does not accept annotation_format
|
||||
kwargs.pop("annotation_format", None)
|
||||
return real_signature(call, *args, **kwargs)
|
||||
|
||||
def simple_dep(x: int) -> int:
|
||||
return x
|
||||
|
||||
# annotationlib is only available on Python 3.14+; provide a minimal mock # noqa: E501
|
||||
fake_annotationlib = SimpleNamespace(Format=SimpleNamespace(FORWARDREF=object()))
|
||||
|
||||
with (
|
||||
patch.object(sys, "version_info", (3, 14)),
|
||||
patch.dict("sys.modules", {"annotationlib": fake_annotationlib}),
|
||||
patch("fastapi.dependencies.utils.inspect.signature", mock_signature),
|
||||
):
|
||||
sig = get_typed_signature(simple_dep)
|
||||
assert len(sig.parameters) == 1
|
||||
assert sig.parameters["x"].annotation is int
|
||||
assert simple_dep(42) == 42 # cover simple_dep body
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from tests.utils import skip_module_if_py_gte_314
|
||||
|
||||
if sys.version_info >= (3, 14):
|
||||
skip_module_if_py_gte_314()
|
||||
skip_module_if_py_gte_314() # pragma: no cover
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.exceptions import PydanticV1NotSupportedError
|
||||
|
|
|
|||
|
|
@ -14,5 +14,5 @@ workdir_lock = pytest.mark.xdist_group("workdir_lock")
|
|||
|
||||
def skip_module_if_py_gte_314():
|
||||
"""Skip entire module on Python 3.14+ at import time."""
|
||||
if sys.version_info >= (3, 14):
|
||||
pytest.skip("requires python3.13-", allow_module_level=True)
|
||||
if sys.version_info >= (3, 14): # pragma: no cover
|
||||
pytest.skip("requires python3.13-", allow_module_level=True) # pragma: no cover
|
||||
|
|
|
|||
Loading…
Reference in New Issue