diff --git a/fastapi/routing.py b/fastapi/routing.py index e2c83aa7b3..f65478c978 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -836,10 +836,11 @@ class APIRoute(routing.Route): generate_unique_id_function: Callable[["APIRoute"], str] | DefaultPlaceholder = Default(generate_unique_id), strict_content_type: bool | DefaultPlaceholder = Default(True), + stream_item_type: Any | None = None, ) -> None: self.path = path self.endpoint = endpoint - self.stream_item_type: Any | None = None + self.stream_item_type: Any | None = stream_item_type if isinstance(response_model, DefaultPlaceholder): return_annotation = get_typed_return_annotation(endpoint) if lenient_issubclass(return_annotation, Response): @@ -1360,6 +1361,7 @@ class APIRouter(routing.Router): generate_unique_id_function: Callable[[APIRoute], str] | DefaultPlaceholder = Default(generate_unique_id), strict_content_type: bool | DefaultPlaceholder = Default(True), + stream_item_type: Any | None = None, ) -> None: route_class = route_class_override or self.route_class responses = responses or {} @@ -1409,6 +1411,7 @@ class APIRouter(routing.Router): strict_content_type=get_value_or_default( strict_content_type, self.strict_content_type ), + stream_item_type=stream_item_type, ) self.routes.append(route) @@ -1789,6 +1792,7 @@ class APIRouter(routing.Router): router.strict_content_type, self.strict_content_type, ), + stream_item_type=route.stream_item_type, ) elif isinstance(route, routing.Route): methods = list(route.methods or []) diff --git a/tests/test_sse.py b/tests/test_sse.py index 6dfec61838..b58e547353 100644 --- a/tests/test_sse.py +++ b/tests/test_sse.py @@ -96,6 +96,18 @@ async def stream_events(): yield {"msg": "world"} +@router.get("/events-typed", response_class=EventSourceResponse) +async def stream_events_typed() -> AsyncIterable[Item]: + for item in items: + yield item + + +@router.get("/events-jsonl") +async def stream_events_jsonl() -> AsyncIterable[Item]: + for item in items: + yield item + + app.include_router(router, prefix="/api") @@ -265,6 +277,49 @@ def test_sse_on_router_included_in_app(client: TestClient): assert len(data_lines) == 2 +def test_sse_router_typed_openapi_schema(client: TestClient): + """Typed SSE endpoint on a router should preserve itemSchema with contentSchema.""" + response = client.get("/openapi.json") + assert response.status_code == 200 + paths = response.json()["paths"] + sse_response = paths["/api/events-typed"]["get"]["responses"]["200"] + assert sse_response == { + "description": "Successful Response", + "content": { + "text/event-stream": { + "itemSchema": { + "type": "object", + "properties": { + "data": { + "type": "string", + "contentMediaType": "application/json", + "contentSchema": {"$ref": "#/components/schemas/Item"}, + }, + "event": {"type": "string"}, + "id": {"type": "string"}, + "retry": {"type": "integer", "minimum": 0}, + }, + "required": ["data"], + } + } + }, + } + + +def test_jsonl_router_typed_openapi_schema(client: TestClient): + """Typed JSONL endpoint on a router should preserve itemSchema.""" + response = client.get("/openapi.json") + assert response.status_code == 200 + paths = response.json()["paths"] + jsonl_response = paths["/api/events-jsonl"]["get"]["responses"]["200"] + assert jsonl_response == { + "description": "Successful Response", + "content": { + "application/jsonl": {"itemSchema": {"$ref": "#/components/schemas/Item"}} + }, + } + + # Keepalive ping tests