From 22381558446c5d1ac376680a6581dd63b3a04119 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 1 Mar 2026 01:21:52 -0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20support=20for=20Server=20Sent?= =?UTF-8?q?=20Events=20(#15030)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/en/docs/advanced/stream-data.md | 6 + docs/en/docs/tutorial/server-sent-events.md | 120 +++++++ docs/en/docs/tutorial/stream-json-lines.md | 10 +- docs/en/mkdocs.yml | 1 + docs_src/server_sent_events/__init__.py | 0 .../server_sent_events/tutorial001_py310.py | 43 +++ .../server_sent_events/tutorial002_py310.py | 26 ++ .../server_sent_events/tutorial003_py310.py | 17 + .../server_sent_events/tutorial004_py310.py | 31 ++ .../server_sent_events/tutorial005_py310.py | 19 ++ fastapi/.agents/skills/fastapi/SKILL.md | 50 +++ fastapi/openapi/utils.py | 21 ++ fastapi/responses.py | 1 + fastapi/routing.py | 180 ++++++++-- fastapi/sse.py | 222 ++++++++++++ pyproject.toml | 1 + tests/test_sse.py | 318 ++++++++++++++++++ .../test_server_sent_events/__init__.py | 0 .../test_tutorial001.py | 191 +++++++++++ .../test_tutorial002.py | 83 +++++ .../test_tutorial003.py | 73 ++++ .../test_tutorial004.py | 164 +++++++++ .../test_tutorial005.py | 141 ++++++++ 23 files changed, 1681 insertions(+), 37 deletions(-) create mode 100644 docs/en/docs/tutorial/server-sent-events.md create mode 100644 docs_src/server_sent_events/__init__.py create mode 100644 docs_src/server_sent_events/tutorial001_py310.py create mode 100644 docs_src/server_sent_events/tutorial002_py310.py create mode 100644 docs_src/server_sent_events/tutorial003_py310.py create mode 100644 docs_src/server_sent_events/tutorial004_py310.py create mode 100644 docs_src/server_sent_events/tutorial005_py310.py create mode 100644 fastapi/sse.py create mode 100644 tests/test_sse.py create mode 100644 tests/test_tutorial/test_server_sent_events/__init__.py create mode 100644 tests/test_tutorial/test_server_sent_events/test_tutorial001.py create mode 100644 tests/test_tutorial/test_server_sent_events/test_tutorial002.py create mode 100644 tests/test_tutorial/test_server_sent_events/test_tutorial003.py create mode 100644 tests/test_tutorial/test_server_sent_events/test_tutorial004.py create mode 100644 tests/test_tutorial/test_server_sent_events/test_tutorial005.py diff --git a/docs/en/docs/advanced/stream-data.md b/docs/en/docs/advanced/stream-data.md index 422ade867a..fe9212a24b 100644 --- a/docs/en/docs/advanced/stream-data.md +++ b/docs/en/docs/advanced/stream-data.md @@ -4,6 +4,12 @@ If you want to stream data that can be structured as JSON, you should [Stream JS But if you want to **stream pure binary data** or strings, here's how you can do it. +/// info + +Added in FastAPI 0.134.0. + +/// + ## Use Cases { #use-cases } You could use this if you want to stream pure strings, for example directly from the output of an **AI LLM** service. diff --git a/docs/en/docs/tutorial/server-sent-events.md b/docs/en/docs/tutorial/server-sent-events.md new file mode 100644 index 0000000000..0a4bed2660 --- /dev/null +++ b/docs/en/docs/tutorial/server-sent-events.md @@ -0,0 +1,120 @@ +# Server-Sent Events (SSE) { #server-sent-events-sse } + +You can stream data to the client using **Server-Sent Events** (SSE). + +This is similar to [Stream JSON Lines](stream-json-lines.md){.internal-link target=_blank}, but uses the `text/event-stream` format, which is supported natively by browsers with the `EventSource` API. + +/// info + +Added in FastAPI 0.135.0. + +/// + +## What are Server-Sent Events? { #what-are-server-sent-events } + +SSE is a standard for streaming data from the server to the client over HTTP. + +Each event is a small text block with "fields" like `data`, `event`, `id`, and `retry`, separated by blank lines. + +It looks like this: + +``` +data: {"name": "Portal Gun", "price": 999.99} + +data: {"name": "Plumbus", "price": 32.99} + +``` + +SSE is commonly used for AI chat streaming, live notifications, logs and observability, and other cases where the server pushes updates to the client. + +/// tip + +If you want to stream binary data, for example video or audio, check the advanced guide: [Stream Data](../advanced/stream-data.md){.internal-link target=_blank}. + +/// + +## Stream SSE with FastAPI { #stream-sse-with-fastapi } + +To stream SSE with FastAPI, use `yield` in your *path operation function* and set `response_class=EventSourceResponse`. + +Import `EventSourceResponse` from `fastapi.sse`: + +{* ../../docs_src/server_sent_events/tutorial001_py310.py ln[1:25] hl[4,22] *} + +Each yielded item is encoded as JSON and sent in the `data:` field of an SSE event. + +If you declare the return type as `AsyncIterable[Item]`, FastAPI will use it to **validate**, **document**, and **serialize** the data using Pydantic. + +{* ../../docs_src/server_sent_events/tutorial001_py310.py ln[1:25] hl[10:12,23] *} + +/// tip + +As Pydantic will serialize it in the **Rust** side, you will get much higher **performance** than if you don't declare a return type. + +/// + +### Non-async *path operation functions* { #non-async-path-operation-functions } + +You can also use regular `def` functions (without `async`), and use `yield` the same way. + +FastAPI will make sure it's run correctly so that it doesn't block the event loop. + +As in this case the function is not async, the right return type would be `Iterable[Item]`: + +{* ../../docs_src/server_sent_events/tutorial001_py310.py ln[28:31] hl[29] *} + +### No Return Type { #no-return-type } + +You can also omit the return type. FastAPI will use the [`jsonable_encoder`](./encoder.md){.internal-link target=_blank} to convert the data and send it. + +{* ../../docs_src/server_sent_events/tutorial001_py310.py ln[34:37] hl[35] *} + +## `ServerSentEvent` { #serversentevent } + +If you need to set SSE fields like `event`, `id`, `retry`, or `comment`, you can yield `ServerSentEvent` objects instead of plain data. + +Import `ServerSentEvent` from `fastapi.sse`: + +{* ../../docs_src/server_sent_events/tutorial002_py310.py hl[4,26] *} + +The `data` field is always encoded as JSON. You can pass any value that can be serialized as JSON, including Pydantic models. + +## Raw Data { #raw-data } + +If you need to send data **without** JSON encoding, use `raw_data` instead of `data`. + +This is useful for sending pre-formatted text, log lines, or special "sentinel" values like `[DONE]`. + +{* ../../docs_src/server_sent_events/tutorial003_py310.py hl[17] *} + +/// note + +`data` and `raw_data` are mutually exclusive. You can only set one of them on each `ServerSentEvent`. + +/// + +## Resuming with `Last-Event-ID` { #resuming-with-last-event-id } + +When a browser reconnects after a connection drop, it sends the last received `id` in the `Last-Event-ID` header. + +You can read it as a header parameter and use it to resume the stream from where the client left off: + +{* ../../docs_src/server_sent_events/tutorial004_py310.py hl[25,27,31] *} + +## SSE with POST { #sse-with-post } + +SSE works with **any HTTP method**, not just `GET`. + +This is useful for protocols like MCP that stream SSE over `POST`: + +{* ../../docs_src/server_sent_events/tutorial005_py310.py hl[14] *} + +## Technical Details { #technical-details } + +FastAPI implements some SSE best practices out of the box. + +* Send a **"keep alive" `ping` comment** every 15 seconds when there hasn't been any message, to prevent some proxies from closing the connection, as suggested in the HTML specification: Server-Sent Events. +* Set the `Cache-Control: no-cache` header to **prevent caching** of the stream. +* Set a special header `X-Accel-Buffering: no` to **prevent buffering** in some proxies like Nginx. + +You don't have to do anything about it, it works out of the box. 🤓 diff --git a/docs/en/docs/tutorial/stream-json-lines.md b/docs/en/docs/tutorial/stream-json-lines.md index b65d0c0fab..2ee3aacc6c 100644 --- a/docs/en/docs/tutorial/stream-json-lines.md +++ b/docs/en/docs/tutorial/stream-json-lines.md @@ -2,6 +2,12 @@ You could have a sequence of data that you would like to send in a "**stream**", you could do it with **JSON Lines**. +/// info + +Added in FastAPI 0.134.0. + +/// + ## What is a Stream? { #what-is-a-stream } "**Streaming**" data means that your app will start sending data items to the client without waiting for the entire sequence of items to be ready. @@ -100,6 +106,6 @@ You can also omit the return type. FastAPI will then use the [`jsonable_encoder` {* ../../docs_src/stream_json_lines/tutorial001_py310.py ln[33:36] hl[34] *} -## Server Sent Events (SSE) { #server-sent-events-sse } +## Server-Sent Events (SSE) { #server-sent-events-sse } -A future version of FastAPI will also have first-class support for Server Sent Events (SSE), which are quite similar, but with a couple of extra details. 🤓 +FastAPI also has first-class support for Server-Sent Events (SSE), which are quite similar but with a couple of extra details. You can learn about them in the next chapter: [Server-Sent Events (SSE)](server-sent-events.md){.internal-link target=_blank}. 🤓 diff --git a/docs/en/mkdocs.yml b/docs/en/mkdocs.yml index 4c017e1b5a..78f03bf443 100644 --- a/docs/en/mkdocs.yml +++ b/docs/en/mkdocs.yml @@ -155,6 +155,7 @@ nav: - tutorial/sql-databases.md - tutorial/bigger-applications.md - tutorial/stream-json-lines.md + - tutorial/server-sent-events.md - tutorial/background-tasks.md - tutorial/metadata.md - tutorial/static-files.md diff --git a/docs_src/server_sent_events/__init__.py b/docs_src/server_sent_events/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs_src/server_sent_events/tutorial001_py310.py b/docs_src/server_sent_events/tutorial001_py310.py new file mode 100644 index 0000000000..8fa470da50 --- /dev/null +++ b/docs_src/server_sent_events/tutorial001_py310.py @@ -0,0 +1,43 @@ +from collections.abc import AsyncIterable, Iterable + +from fastapi import FastAPI +from fastapi.sse import EventSourceResponse +from pydantic import BaseModel + +app = FastAPI() + + +class Item(BaseModel): + name: str + description: str | None + + +items = [ + Item(name="Plumbus", description="A multi-purpose household device."), + Item(name="Portal Gun", description="A portal opening device."), + Item(name="Meeseeks Box", description="A box that summons a Meeseeks."), +] + + +@app.get("/items/stream", response_class=EventSourceResponse) +async def sse_items() -> AsyncIterable[Item]: + for item in items: + yield item + + +@app.get("/items/stream-no-async", response_class=EventSourceResponse) +def sse_items_no_async() -> Iterable[Item]: + for item in items: + yield item + + +@app.get("/items/stream-no-annotation", response_class=EventSourceResponse) +async def sse_items_no_annotation(): + for item in items: + yield item + + +@app.get("/items/stream-no-async-no-annotation", response_class=EventSourceResponse) +def sse_items_no_async_no_annotation(): + for item in items: + yield item diff --git a/docs_src/server_sent_events/tutorial002_py310.py b/docs_src/server_sent_events/tutorial002_py310.py new file mode 100644 index 0000000000..0f6136f4fd --- /dev/null +++ b/docs_src/server_sent_events/tutorial002_py310.py @@ -0,0 +1,26 @@ +from collections.abc import AsyncIterable + +from fastapi import FastAPI +from fastapi.sse import EventSourceResponse, ServerSentEvent +from pydantic import BaseModel + +app = FastAPI() + + +class Item(BaseModel): + name: str + price: float + + +items = [ + Item(name="Plumbus", price=32.99), + Item(name="Portal Gun", price=999.99), + Item(name="Meeseeks Box", price=49.99), +] + + +@app.get("/items/stream", response_class=EventSourceResponse) +async def stream_items() -> AsyncIterable[ServerSentEvent]: + yield ServerSentEvent(comment="stream of item updates") + for i, item in enumerate(items): + yield ServerSentEvent(data=item, event="item_update", id=str(i + 1), retry=5000) diff --git a/docs_src/server_sent_events/tutorial003_py310.py b/docs_src/server_sent_events/tutorial003_py310.py new file mode 100644 index 0000000000..3006deb86d --- /dev/null +++ b/docs_src/server_sent_events/tutorial003_py310.py @@ -0,0 +1,17 @@ +from collections.abc import AsyncIterable + +from fastapi import FastAPI +from fastapi.sse import EventSourceResponse, ServerSentEvent + +app = FastAPI() + + +@app.get("/logs/stream", response_class=EventSourceResponse) +async def stream_logs() -> AsyncIterable[ServerSentEvent]: + logs = [ + "2025-01-01 INFO Application started", + "2025-01-01 DEBUG Connected to database", + "2025-01-01 WARN High memory usage detected", + ] + for log_line in logs: + yield ServerSentEvent(raw_data=log_line) diff --git a/docs_src/server_sent_events/tutorial004_py310.py b/docs_src/server_sent_events/tutorial004_py310.py new file mode 100644 index 0000000000..3e8f8d113f --- /dev/null +++ b/docs_src/server_sent_events/tutorial004_py310.py @@ -0,0 +1,31 @@ +from collections.abc import AsyncIterable +from typing import Annotated + +from fastapi import FastAPI, Header +from fastapi.sse import EventSourceResponse, ServerSentEvent +from pydantic import BaseModel + +app = FastAPI() + + +class Item(BaseModel): + name: str + price: float + + +items = [ + Item(name="Plumbus", price=32.99), + Item(name="Portal Gun", price=999.99), + Item(name="Meeseeks Box", price=49.99), +] + + +@app.get("/items/stream", response_class=EventSourceResponse) +async def stream_items( + last_event_id: Annotated[int | None, Header()] = None, +) -> AsyncIterable[ServerSentEvent]: + start = last_event_id + 1 if last_event_id is not None else 0 + for i, item in enumerate(items): + if i < start: + continue + yield ServerSentEvent(data=item, id=str(i)) diff --git a/docs_src/server_sent_events/tutorial005_py310.py b/docs_src/server_sent_events/tutorial005_py310.py new file mode 100644 index 0000000000..4e6730e5aa --- /dev/null +++ b/docs_src/server_sent_events/tutorial005_py310.py @@ -0,0 +1,19 @@ +from collections.abc import AsyncIterable + +from fastapi import FastAPI +from fastapi.sse import EventSourceResponse, ServerSentEvent +from pydantic import BaseModel + +app = FastAPI() + + +class Prompt(BaseModel): + text: str + + +@app.post("/chat/stream", response_class=EventSourceResponse) +async def stream_chat(prompt: Prompt) -> AsyncIterable[ServerSentEvent]: + words = prompt.text.split() + for word in words: + yield ServerSentEvent(data=word, event="token") + yield ServerSentEvent(raw_data="[DONE]", event="done") diff --git a/fastapi/.agents/skills/fastapi/SKILL.md b/fastapi/.agents/skills/fastapi/SKILL.md index ead0f61749..8e2329cbdb 100644 --- a/fastapi/.agents/skills/fastapi/SKILL.md +++ b/fastapi/.agents/skills/fastapi/SKILL.md @@ -521,6 +521,56 @@ async def stream_items() -> AsyncIterable[Item]: yield item ``` +## Server-Sent Events (SSE) + +To stream Server-Sent Events, use `response_class=EventSourceResponse` and `yield` items from the endpoint. + +Plain objects are automatically JSON-serialized as `data:` fields, declare the return type so the serialization is done by Pydantic: + +```python +from collections.abc import AsyncIterable + +from fastapi import FastAPI +from fastapi.sse import EventSourceResponse +from pydantic import BaseModel + +app = FastAPI() + + +class Item(BaseModel): + name: str + price: float + + +@app.get("/items/stream", response_class=EventSourceResponse) +async def stream_items() -> AsyncIterable[Item]: + yield Item(name="Plumbus", price=32.99) + yield Item(name="Portal Gun", price=999.99) +``` + +For full control over SSE fields (`event`, `id`, `retry`, `comment`), yield `ServerSentEvent` instances: + +```python +from collections.abc import AsyncIterable + +from fastapi import FastAPI +from fastapi.sse import EventSourceResponse, ServerSentEvent + +app = FastAPI() + + +@app.get("/events", response_class=EventSourceResponse) +async def stream_events() -> AsyncIterable[ServerSentEvent]: + yield ServerSentEvent(data={"status": "started"}, event="status", id="1") + yield ServerSentEvent(data={"progress": 50}, event="progress", id="2") +``` + +Use `raw_data` instead of `data` to send pre-formatted strings without JSON encoding: + +```python +yield ServerSentEvent(raw_data="plain text line", event="log") +``` + ## Stream bytes To stream bytes, declare a `response_class=` of `StreamingResponse` or a sub-class, and use `yield` to return the data. diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 3ddc0c14a9..828442559b 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -29,6 +29,7 @@ from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX from fastapi.openapi.models import OpenAPI from fastapi.params import Body, ParamTypes from fastapi.responses import Response +from fastapi.sse import _SSE_EVENT_SCHEMA from fastapi.types import ModelNameMap from fastapi.utils import ( deep_dict_update, @@ -372,6 +373,26 @@ def get_openapi_path( operation.setdefault("responses", {}).setdefault( status_code, {} ).setdefault("content", {})["application/jsonl"] = jsonl_content + elif route.is_sse_stream: + sse_content: dict[str, Any] = {} + item_schema = copy.deepcopy(_SSE_EVENT_SCHEMA) + if route.stream_item_field: + content_schema = get_schema_from_model_field( + field=route.stream_item_field, + model_name_map=model_name_map, + field_mapping=field_mapping, + separate_input_output_schemas=separate_input_output_schemas, + ) + item_schema["required"] = ["data"] + item_schema["properties"]["data"] = { + "type": "string", + "contentMediaType": "application/json", + "contentSchema": content_schema, + } + sse_content["itemSchema"] = item_schema + operation.setdefault("responses", {}).setdefault( + status_code, {} + ).setdefault("content", {})["text/event-stream"] = sse_content elif route_response_media_type: response_schema = {"type": "string"} if lenient_issubclass(current_response_class, JSONResponse): diff --git a/fastapi/responses.py b/fastapi/responses.py index 5b1154c046..554b0952b0 100644 --- a/fastapi/responses.py +++ b/fastapi/responses.py @@ -1,6 +1,7 @@ from typing import Any from fastapi.exceptions import FastAPIDeprecationWarning +from fastapi.sse import EventSourceResponse as EventSourceResponse # noqa from starlette.responses import FileResponse as FileResponse # noqa from starlette.responses import HTMLResponse as HTMLResponse # noqa from starlette.responses import JSONResponse as JSONResponse # noqa diff --git a/fastapi/routing.py b/fastapi/routing.py index f00cd2ca75..a52271690f 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -56,6 +56,13 @@ from fastapi.exceptions import ( ResponseValidationError, WebSocketRequestValidationError, ) +from fastapi.sse import ( + _PING_INTERVAL, + KEEPALIVE_COMMENT, + EventSourceResponse, + ServerSentEvent, + format_sse_event, +) from fastapi.types import DecoratedCallable, IncEx from fastapi.utils import ( create_model_field, @@ -66,7 +73,7 @@ from fastapi.utils import ( from starlette import routing from starlette._exception_handler import wrap_app_handling_exceptions from starlette._utils import is_async_callable -from starlette.concurrency import run_in_threadpool +from starlette.concurrency import iterate_in_threadpool, run_in_threadpool from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.responses import JSONResponse, Response, StreamingResponse @@ -361,6 +368,7 @@ def get_request_handler( actual_response_class: type[Response] = response_class.value else: actual_response_class = response_class + is_sse_stream = lenient_issubclass(actual_response_class, EventSourceResponse) if isinstance(strict_content_type, DefaultPlaceholder): actual_strict_content_type: bool = strict_content_type.value else: @@ -452,35 +460,125 @@ def get_request_handler( errors = solved_result.errors assert dependant.call # For types if not errors: - if is_json_stream: + # Shared serializer for stream items (JSONL and SSE). + # Validates against stream_item_field when set, then + # serializes to JSON bytes. + def _serialize_data(data: Any) -> bytes: + if stream_item_field: + value, errors_ = stream_item_field.validate( + data, {}, loc=("response",) + ) + if errors_: + ctx = endpoint_ctx or EndpointContext() + raise ResponseValidationError( + errors=errors_, + body=data, + endpoint_ctx=ctx, + ) + return stream_item_field.serialize_json( + value, + include=response_model_include, + exclude=response_model_exclude, + by_alias=response_model_by_alias, + exclude_unset=response_model_exclude_unset, + exclude_defaults=response_model_exclude_defaults, + exclude_none=response_model_exclude_none, + ) + else: + data = jsonable_encoder(data) + return json.dumps(data).encode("utf-8") + + if is_sse_stream: + # Generator endpoint: stream as Server-Sent Events + gen = dependant.call(**solved_result.values) + + def _serialize_sse_item(item: Any) -> bytes: + if isinstance(item, ServerSentEvent): + # User controls the event structure. + # Serialize the data payload if present. + # For ServerSentEvent items we skip stream_item_field + # validation (the user may mix types intentionally). + if item.raw_data is not None: + data_str: str | None = item.raw_data + elif item.data is not None: + if hasattr(item.data, "model_dump_json"): + data_str = item.data.model_dump_json() + else: + data_str = json.dumps(jsonable_encoder(item.data)) + else: + data_str = None + return format_sse_event( + data_str=data_str, + event=item.event, + id=item.id, + retry=item.retry, + comment=item.comment, + ) + else: + # Plain object: validate + serialize via + # stream_item_field (if set) and wrap in data field + return format_sse_event( + data_str=_serialize_data(item).decode("utf-8") + ) + + if dependant.is_async_gen_callable: + sse_aiter: AsyncIterator[Any] = gen.__aiter__() + else: + sse_aiter = iterate_in_threadpool(gen) + + async def _async_stream_sse() -> AsyncIterator[bytes]: + # Use a memory stream to decouple generator iteration + # from the keepalive timer. A producer task pulls items + # from the generator independently, so + # `anyio.fail_after` never wraps the generator's + # `__anext__` directly - avoiding CancelledError that + # would finalize the generator and also working for sync + # generators running in a thread pool. + send_stream, receive_stream = anyio.create_memory_object_stream[ + bytes + ](max_buffer_size=1) + + async def _producer() -> None: + async with send_stream: + async for raw_item in sse_aiter: + await send_stream.send(_serialize_sse_item(raw_item)) + + async with anyio.create_task_group() as tg: + tg.start_soon(_producer) + async with receive_stream: + try: + while True: + try: + with anyio.fail_after(_PING_INTERVAL): + data = await receive_stream.receive() + yield data + # To allow for cancellation to trigger + # Ref: https://github.com/fastapi/fastapi/issues/14680 + await anyio.sleep(0) + except TimeoutError: + yield KEEPALIVE_COMMENT + except anyio.EndOfStream: + pass + + sse_stream_content: AsyncIterator[bytes] | Iterator[bytes] = ( + _async_stream_sse() + ) + + response = StreamingResponse( + sse_stream_content, + media_type="text/event-stream", + background=solved_result.background_tasks, + ) + response.headers["Cache-Control"] = "no-cache" + # For Nginx proxies to not buffer server sent events + response.headers["X-Accel-Buffering"] = "no" + response.headers.raw.extend(solved_result.response.headers.raw) + elif is_json_stream: # Generator endpoint: stream as JSONL gen = dependant.call(**solved_result.values) def _serialize_item(item: Any) -> bytes: - if stream_item_field: - value, errors = stream_item_field.validate( - item, {}, loc=("response",) - ) - if errors: - ctx = endpoint_ctx or EndpointContext() - raise ResponseValidationError( - errors=errors, - body=item, - endpoint_ctx=ctx, - ) - line = stream_item_field.serialize_json( - value, - include=response_model_include, - exclude=response_model_exclude, - by_alias=response_model_by_alias, - exclude_unset=response_model_exclude_unset, - exclude_defaults=response_model_exclude_defaults, - exclude_none=response_model_exclude_none, - ) - return line + b"\n" - else: - data = jsonable_encoder(item) - return json.dumps(data).encode("utf-8") + b"\n" + return _serialize_data(item) + b"\n" if dependant.is_async_gen_callable: @@ -491,7 +589,7 @@ def get_request_handler( # Ref: https://github.com/fastapi/fastapi/issues/14680 await anyio.sleep(0) - stream_content: AsyncIterator[bytes] | Iterator[bytes] = ( + jsonl_stream_content: AsyncIterator[bytes] | Iterator[bytes] = ( _async_stream_jsonl() ) else: @@ -500,10 +598,10 @@ def get_request_handler( for item in gen: yield _serialize_item(item) - stream_content = _sync_stream_jsonl() + jsonl_stream_content = _sync_stream_jsonl() response = StreamingResponse( - stream_content, + jsonl_stream_content, media_type="application/jsonl", background=solved_result.background_tasks, ) @@ -709,9 +807,16 @@ class APIRoute(routing.Route): else: stream_item = get_stream_item_type(return_annotation) if stream_item is not None: - # Only extract item type for JSONL streaming when no - # explicit response_class (e.g. StreamingResponse) was set - if isinstance(response_class, DefaultPlaceholder): + # Extract item type for JSONL or SSE streaming when + # response_class is DefaultPlaceholder (JSONL) or + # EventSourceResponse (SSE). + # ServerSentEvent is excluded: it's a transport + # wrapper, not a data model, so it shouldn't feed + # into validation or OpenAPI schema generation. + if ( + isinstance(response_class, DefaultPlaceholder) + or lenient_issubclass(response_class, EventSourceResponse) + ) and not lenient_issubclass(stream_item, ServerSentEvent): self.stream_item_type = stream_item response_model = None else: @@ -814,11 +919,16 @@ class APIRoute(routing.Route): name=self.unique_id, embed_body_fields=self._embed_body_fields, ) - # Detect generator endpoints that should stream as JSONL - # (only when no explicit response_class like StreamingResponse is set) - self.is_json_stream = isinstance(response_class, DefaultPlaceholder) and ( + # Detect generator endpoints that should stream as JSONL or SSE + is_generator = ( self.dependant.is_async_gen_callable or self.dependant.is_gen_callable ) + self.is_sse_stream = is_generator and lenient_issubclass( + response_class, EventSourceResponse + ) + self.is_json_stream = is_generator and isinstance( + response_class, DefaultPlaceholder + ) self.app = request_response(self.get_route_handler()) def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: diff --git a/fastapi/sse.py b/fastapi/sse.py new file mode 100644 index 0000000000..901d824964 --- /dev/null +++ b/fastapi/sse.py @@ -0,0 +1,222 @@ +from typing import Annotated, Any + +from annotated_doc import Doc +from pydantic import AfterValidator, BaseModel, Field, model_validator +from starlette.responses import StreamingResponse + +# Canonical SSE event schema matching the OpenAPI 3.2 spec +# (Section 4.14.4 "Special Considerations for Server-Sent Events") +_SSE_EVENT_SCHEMA: dict[str, Any] = { + "type": "object", + "properties": { + "data": {"type": "string"}, + "event": {"type": "string"}, + "id": {"type": "string"}, + "retry": {"type": "integer", "minimum": 0}, + }, +} + + +class EventSourceResponse(StreamingResponse): + """Streaming response with `text/event-stream` media type. + + Use as `response_class=EventSourceResponse` on a *path operation* that uses `yield` + to enable Server Sent Events (SSE) responses. + + Works with **any HTTP method** (`GET`, `POST`, etc.), which makes it compatible + with protocols like MCP that stream SSE over `POST`. + + The actual encoding logic lives in the FastAPI routing layer. This class + serves mainly as a marker and sets the correct `Content-Type`. + """ + + media_type = "text/event-stream" + + +def _check_id_no_null(v: str | None) -> str | None: + if v is not None and "\0" in v: + raise ValueError("SSE 'id' must not contain null characters") + return v + + +class ServerSentEvent(BaseModel): + """Represents a single Server-Sent Event. + + When `yield`ed from a *path operation function* that uses + `response_class=EventSourceResponse`, each `ServerSentEvent` is encoded + into the [SSE wire format](https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream) + (`text/event-stream`). + + If you yield a plain object (dict, Pydantic model, etc.) instead, it is + automatically JSON-encoded and sent as the `data:` field. + + All `data` values **including plain strings** are JSON-serialized. + + For example, `data="hello"` produces `data: "hello"` on the wire (with + quotes). + """ + + data: Annotated[ + Any, + Doc( + """ + The event payload. + + Can be any JSON-serializable value: a Pydantic model, dict, list, + string, number, etc. It is **always** serialized to JSON: strings + are quoted (`"hello"` becomes `data: "hello"` on the wire). + + Mutually exclusive with `raw_data`. + """ + ), + ] = None + raw_data: Annotated[ + str | None, + Doc( + """ + Raw string to send as the `data:` field **without** JSON encoding. + + Use this when you need to send pre-formatted text, HTML fragments, + CSV lines, or any non-JSON payload. The string is placed directly + into the `data:` field as-is. + + Mutually exclusive with `data`. + """ + ), + ] = None + event: Annotated[ + str | None, + Doc( + """ + Optional event type name. + + Maps to `addEventListener(event, ...)` on the browser. When omitted, + the browser dispatches on the generic `message` event. + """ + ), + ] = None + id: Annotated[ + str | None, + AfterValidator(_check_id_no_null), + Doc( + """ + Optional event ID. + + The browser sends this value back as the `Last-Event-ID` header on + automatic reconnection. **Must not contain null (`\\0`) characters.** + """ + ), + ] = None + retry: Annotated[ + int | None, + Field(ge=0), + Doc( + """ + Optional reconnection time in **milliseconds**. + + Tells the browser how long to wait before reconnecting after the + connection is lost. Must be a non-negative integer. + """ + ), + ] = None + comment: Annotated[ + str | None, + Doc( + """ + Optional comment line(s). + + Comment lines start with `:` in the SSE wire format and are ignored by + `EventSource` clients. Useful for keep-alive pings to prevent + proxy/load-balancer timeouts. + """ + ), + ] = None + + @model_validator(mode="after") + def _check_data_exclusive(self) -> "ServerSentEvent": + if self.data is not None and self.raw_data is not None: + raise ValueError( + "Cannot set both 'data' and 'raw_data' on the same " + "ServerSentEvent. Use 'data' for JSON-serialized payloads " + "or 'raw_data' for pre-formatted strings." + ) + return self + + +def format_sse_event( + *, + data_str: Annotated[ + str | None, + Doc( + """ + Pre-serialized data string to use as the `data:` field. + """ + ), + ] = None, + event: Annotated[ + str | None, + Doc( + """ + Optional event type name (`event:` field). + """ + ), + ] = None, + id: Annotated[ + str | None, + Doc( + """ + Optional event ID (`id:` field). + """ + ), + ] = None, + retry: Annotated[ + int | None, + Doc( + """ + Optional reconnection time in milliseconds (`retry:` field). + """ + ), + ] = None, + comment: Annotated[ + str | None, + Doc( + """ + Optional comment line(s) (`:` prefix). + """ + ), + ] = None, +) -> bytes: + """Build SSE wire-format bytes from **pre-serialized** data. + + The result always ends with `\n\n` (the event terminator). + """ + lines: list[str] = [] + + if comment is not None: + for line in comment.splitlines(): + lines.append(f": {line}") + + if event is not None: + lines.append(f"event: {event}") + + if data_str is not None: + for line in data_str.splitlines(): + lines.append(f"data: {line}") + + if id is not None: + lines.append(f"id: {id}") + + if retry is not None: + lines.append(f"retry: {retry}") + + lines.append("") + lines.append("") + return "\n".join(lines).encode("utf-8") + + +# Keep-alive comment, per the SSE spec recommendation +KEEPALIVE_COMMENT = b": ping\n\n" + +# Seconds between keep-alive pings when a generator is idle. +# Private but importable so tests can monkeypatch it. +_PING_INTERVAL: float = 15.0 diff --git a/pyproject.toml b/pyproject.toml index 37caa322f6..3d699f68fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -324,6 +324,7 @@ ignore = [ "docs_src/stream_json_lines/tutorial001_py310.py" = ["UP028"] "docs_src/stream_data/tutorial001_py310.py" = ["UP028"] "docs_src/stream_data/tutorial002_py310.py" = ["UP028"] +"docs_src/server_sent_events/tutorial001_py310.py" = ["UP028"] [tool.ruff.lint.isort] known-third-party = ["fastapi", "pydantic", "starlette"] diff --git a/tests/test_sse.py b/tests/test_sse.py new file mode 100644 index 0000000000..6dfec61838 --- /dev/null +++ b/tests/test_sse.py @@ -0,0 +1,318 @@ +import asyncio +import time +from collections.abc import AsyncIterable, Iterable + +import fastapi.routing +import pytest +from fastapi import APIRouter, FastAPI +from fastapi.responses import EventSourceResponse +from fastapi.sse import ServerSentEvent +from fastapi.testclient import TestClient +from pydantic import BaseModel + + +class Item(BaseModel): + name: str + description: str | None = None + + +items = [ + Item(name="Plumbus", description="A multi-purpose household device."), + Item(name="Portal Gun", description="A portal opening device."), + Item(name="Meeseeks Box", description="A box that summons a Meeseeks."), +] + + +app = FastAPI() + + +@app.get("/items/stream", response_class=EventSourceResponse) +async def sse_items() -> AsyncIterable[Item]: + for item in items: + yield item + + +@app.get("/items/stream-sync", response_class=EventSourceResponse) +def sse_items_sync() -> Iterable[Item]: + yield from items + + +@app.get("/items/stream-no-annotation", response_class=EventSourceResponse) +async def sse_items_no_annotation(): + for item in items: + yield item + + +@app.get("/items/stream-sync-no-annotation", response_class=EventSourceResponse) +def sse_items_sync_no_annotation(): + yield from items + + +@app.get("/items/stream-dict", response_class=EventSourceResponse) +async def sse_items_dict(): + for item in items: + yield {"name": item.name, "description": item.description} + + +@app.get("/items/stream-sse-event", response_class=EventSourceResponse) +async def sse_items_event(): + yield ServerSentEvent(data="hello", event="greeting", id="1") + yield ServerSentEvent(data={"key": "value"}, event="json-data", id="2") + yield ServerSentEvent(comment="just a comment") + yield ServerSentEvent(data="retry-test", retry=5000) + + +@app.get("/items/stream-mixed", response_class=EventSourceResponse) +async def sse_items_mixed() -> AsyncIterable[Item]: + yield items[0] + yield ServerSentEvent(data="custom-event", event="special") + yield items[1] + + +@app.get("/items/stream-string", response_class=EventSourceResponse) +async def sse_items_string(): + yield ServerSentEvent(data="plain text data") + + +@app.post("/items/stream-post", response_class=EventSourceResponse) +async def sse_items_post() -> AsyncIterable[Item]: + for item in items: + yield item + + +@app.get("/items/stream-raw", response_class=EventSourceResponse) +async def sse_items_raw(): + yield ServerSentEvent(raw_data="plain text without quotes") + yield ServerSentEvent(raw_data="
html fragment
", event="html") + yield ServerSentEvent(raw_data="cpu,87.3,1709145600", event="csv") + + +router = APIRouter() + + +@router.get("/events", response_class=EventSourceResponse) +async def stream_events(): + yield {"msg": "hello"} + yield {"msg": "world"} + + +app.include_router(router, prefix="/api") + + +@pytest.fixture(name="client") +def client_fixture(): + with TestClient(app) as c: + yield c + + +def test_async_generator_with_model(client: TestClient): + response = client.get("/items/stream") + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + assert response.headers["cache-control"] == "no-cache" + assert response.headers["x-accel-buffering"] == "no" + + lines = response.text.strip().split("\n") + data_lines = [line for line in lines if line.startswith("data: ")] + assert len(data_lines) == 3 + assert '"name":"Plumbus"' in data_lines[0] or '"name": "Plumbus"' in data_lines[0] + assert ( + '"name":"Portal Gun"' in data_lines[1] + or '"name": "Portal Gun"' in data_lines[1] + ) + assert ( + '"name":"Meeseeks Box"' in data_lines[2] + or '"name": "Meeseeks Box"' in data_lines[2] + ) + + +def test_sync_generator_with_model(client: TestClient): + response = client.get("/items/stream-sync") + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + data_lines = [ + line for line in response.text.strip().split("\n") if line.startswith("data: ") + ] + assert len(data_lines) == 3 + + +def test_async_generator_no_annotation(client: TestClient): + response = client.get("/items/stream-no-annotation") + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + data_lines = [ + line for line in response.text.strip().split("\n") if line.startswith("data: ") + ] + assert len(data_lines) == 3 + + +def test_sync_generator_no_annotation(client: TestClient): + response = client.get("/items/stream-sync-no-annotation") + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + data_lines = [ + line for line in response.text.strip().split("\n") if line.startswith("data: ") + ] + assert len(data_lines) == 3 + + +def test_dict_items(client: TestClient): + response = client.get("/items/stream-dict") + assert response.status_code == 200 + data_lines = [ + line for line in response.text.strip().split("\n") if line.startswith("data: ") + ] + assert len(data_lines) == 3 + assert '"name"' in data_lines[0] + + +def test_post_method_sse(client: TestClient): + """SSE should work with POST (needed for MCP compatibility).""" + response = client.post("/items/stream-post") + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + data_lines = [ + line for line in response.text.strip().split("\n") if line.startswith("data: ") + ] + assert len(data_lines) == 3 + + +def test_sse_events_with_fields(client: TestClient): + response = client.get("/items/stream-sse-event") + assert response.status_code == 200 + text = response.text + + assert "event: greeting\n" in text + assert 'data: "hello"\n' in text + assert "id: 1\n" in text + + assert "event: json-data\n" in text + assert "id: 2\n" in text + assert 'data: {"key": "value"}\n' in text + + assert ": just a comment\n" in text + + assert "retry: 5000\n" in text + assert 'data: "retry-test"\n' in text + + +def test_mixed_plain_and_sse_events(client: TestClient): + response = client.get("/items/stream-mixed") + assert response.status_code == 200 + text = response.text + + assert "event: special\n" in text + assert 'data: "custom-event"\n' in text + assert '"name"' in text + + +def test_string_data_json_encoded(client: TestClient): + """Strings are always JSON-encoded (quoted).""" + response = client.get("/items/stream-string") + assert response.status_code == 200 + assert 'data: "plain text data"\n' in response.text + + +def test_server_sent_event_null_id_rejected(): + with pytest.raises(ValueError, match="null"): + ServerSentEvent(data="test", id="has\0null") + + +def test_server_sent_event_negative_retry_rejected(): + with pytest.raises(ValueError): + ServerSentEvent(data="test", retry=-1) + + +def test_server_sent_event_float_retry_rejected(): + with pytest.raises(ValueError): + ServerSentEvent(data="test", retry=1.5) # type: ignore[arg-type] + + +def test_raw_data_sent_without_json_encoding(client: TestClient): + """raw_data is sent as-is, not JSON-encoded.""" + response = client.get("/items/stream-raw") + assert response.status_code == 200 + text = response.text + + # raw_data should appear without JSON quotes + assert "data: plain text without quotes\n" in text + # Not JSON-quoted + assert 'data: "plain text without quotes"' not in text + + assert "event: html\n" in text + assert "data:
html fragment
\n" in text + + assert "event: csv\n" in text + assert "data: cpu,87.3,1709145600\n" in text + + +def test_data_and_raw_data_mutually_exclusive(): + """Cannot set both data and raw_data.""" + with pytest.raises(ValueError, match="Cannot set both"): + ServerSentEvent(data="json", raw_data="raw") + + +def test_sse_on_router_included_in_app(client: TestClient): + response = client.get("/api/events") + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + data_lines = [ + line for line in response.text.strip().split("\n") if line.startswith("data: ") + ] + assert len(data_lines) == 2 + + +# Keepalive ping tests + + +keepalive_app = FastAPI() + + +@keepalive_app.get("/slow-async", response_class=EventSourceResponse) +async def slow_async_stream(): + yield {"n": 1} + # Sleep longer than the (monkeypatched) ping interval so a keepalive + # comment is emitted before the next item. + await asyncio.sleep(0.3) + yield {"n": 2} + + +@keepalive_app.get("/slow-sync", response_class=EventSourceResponse) +def slow_sync_stream(): + yield {"n": 1} + time.sleep(0.3) + yield {"n": 2} + + +def test_keepalive_ping_async(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(fastapi.routing, "_PING_INTERVAL", 0.05) + with TestClient(keepalive_app) as c: + response = c.get("/slow-async") + assert response.status_code == 200 + text = response.text + # The keepalive comment ": ping" should appear between the two data events + assert ": ping\n" in text + data_lines = [line for line in text.split("\n") if line.startswith("data: ")] + assert len(data_lines) == 2 + + +def test_keepalive_ping_sync(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(fastapi.routing, "_PING_INTERVAL", 0.05) + with TestClient(keepalive_app) as c: + response = c.get("/slow-sync") + assert response.status_code == 200 + text = response.text + assert ": ping\n" in text + data_lines = [line for line in text.split("\n") if line.startswith("data: ")] + assert len(data_lines) == 2 + + +def test_no_keepalive_when_fast(client: TestClient): + """No keepalive comment when items arrive quickly.""" + response = client.get("/items/stream") + assert response.status_code == 200 + # KEEPALIVE_COMMENT is ": ping\n\n". + assert ": ping\n" not in response.text diff --git a/tests/test_tutorial/test_server_sent_events/__init__.py b/tests/test_tutorial/test_server_sent_events/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_tutorial/test_server_sent_events/test_tutorial001.py b/tests/test_tutorial/test_server_sent_events/test_tutorial001.py new file mode 100644 index 0000000000..75fdda3601 --- /dev/null +++ b/tests/test_tutorial/test_server_sent_events/test_tutorial001.py @@ -0,0 +1,191 @@ +import importlib + +import pytest +from fastapi.testclient import TestClient +from inline_snapshot import snapshot + + +@pytest.fixture( + name="client", + params=[ + pytest.param("tutorial001_py310"), + ], +) +def get_client(request: pytest.FixtureRequest): + mod = importlib.import_module(f"docs_src.server_sent_events.{request.param}") + + client = TestClient(mod.app) + return client + + +@pytest.mark.parametrize( + "path", + [ + "/items/stream", + "/items/stream-no-async", + "/items/stream-no-annotation", + "/items/stream-no-async-no-annotation", + ], +) +def test_stream_items(client: TestClient, path: str): + response = client.get(path) + assert response.status_code == 200, response.text + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + data_lines = [ + line for line in response.text.strip().split("\n") if line.startswith("data: ") + ] + assert len(data_lines) == 3 + + +def test_openapi_schema(client: TestClient): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == snapshot( + { + "openapi": "3.1.0", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/items/stream": { + "get": { + "summary": "Sse Items", + "operationId": "sse_items_items_stream_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "text/event-stream": { + "itemSchema": { + "type": "object", + "properties": { + "data": { + "type": "string", + "contentMediaType": "application/json", + "contentSchema": { + "$ref": "#/components/schemas/Item" + }, + }, + "event": {"type": "string"}, + "id": {"type": "string"}, + "retry": { + "type": "integer", + "minimum": 0, + }, + }, + "required": ["data"], + } + } + }, + } + }, + } + }, + "/items/stream-no-async": { + "get": { + "summary": "Sse Items No Async", + "operationId": "sse_items_no_async_items_stream_no_async_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "text/event-stream": { + "itemSchema": { + "type": "object", + "properties": { + "data": { + "type": "string", + "contentMediaType": "application/json", + "contentSchema": { + "$ref": "#/components/schemas/Item" + }, + }, + "event": {"type": "string"}, + "id": {"type": "string"}, + "retry": { + "type": "integer", + "minimum": 0, + }, + }, + "required": ["data"], + } + } + }, + } + }, + } + }, + "/items/stream-no-annotation": { + "get": { + "summary": "Sse Items No Annotation", + "operationId": "sse_items_no_annotation_items_stream_no_annotation_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "text/event-stream": { + "itemSchema": { + "type": "object", + "properties": { + "data": {"type": "string"}, + "event": {"type": "string"}, + "id": {"type": "string"}, + "retry": { + "type": "integer", + "minimum": 0, + }, + }, + } + } + }, + } + }, + } + }, + "/items/stream-no-async-no-annotation": { + "get": { + "summary": "Sse Items No Async No Annotation", + "operationId": "sse_items_no_async_no_annotation_items_stream_no_async_no_annotation_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "text/event-stream": { + "itemSchema": { + "type": "object", + "properties": { + "data": {"type": "string"}, + "event": {"type": "string"}, + "id": {"type": "string"}, + "retry": { + "type": "integer", + "minimum": 0, + }, + }, + } + } + }, + } + }, + } + }, + }, + "components": { + "schemas": { + "Item": { + "properties": { + "name": {"type": "string", "title": "Name"}, + "description": { + "anyOf": [ + {"type": "string"}, + {"type": "null"}, + ], + "title": "Description", + }, + }, + "type": "object", + "required": ["name", "description"], + "title": "Item", + } + } + }, + } + ) diff --git a/tests/test_tutorial/test_server_sent_events/test_tutorial002.py b/tests/test_tutorial/test_server_sent_events/test_tutorial002.py new file mode 100644 index 0000000000..b9cbf43854 --- /dev/null +++ b/tests/test_tutorial/test_server_sent_events/test_tutorial002.py @@ -0,0 +1,83 @@ +import importlib + +import pytest +from fastapi.testclient import TestClient +from inline_snapshot import snapshot + + +@pytest.fixture( + name="client", + params=[ + pytest.param("tutorial002_py310"), + ], +) +def get_client(request: pytest.FixtureRequest): + mod = importlib.import_module(f"docs_src.server_sent_events.{request.param}") + client = TestClient(mod.app) + return client + + +def test_stream_items(client: TestClient): + response = client.get("/items/stream") + assert response.status_code == 200, response.text + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + lines = response.text.strip().split("\n") + + # First event is a comment-only event + assert lines[0] == ": stream of item updates" + + # Remaining lines contain event:, data:, id:, retry: fields + event_lines = [line for line in lines if line.startswith("event: ")] + assert len(event_lines) == 3 + assert all(line == "event: item_update" for line in event_lines) + + data_lines = [line for line in lines if line.startswith("data: ")] + assert len(data_lines) == 3 + + id_lines = [line for line in lines if line.startswith("id: ")] + assert id_lines == ["id: 1", "id: 2", "id: 3"] + + retry_lines = [line for line in lines if line.startswith("retry: ")] + assert len(retry_lines) == 3 + assert all(line == "retry: 5000" for line in retry_lines) + + +def test_openapi_schema(client: TestClient): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == snapshot( + { + "openapi": "3.1.0", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/items/stream": { + "get": { + "summary": "Stream Items", + "operationId": "stream_items_items_stream_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "text/event-stream": { + "itemSchema": { + "type": "object", + "properties": { + "data": {"type": "string"}, + "event": {"type": "string"}, + "id": {"type": "string"}, + "retry": { + "type": "integer", + "minimum": 0, + }, + }, + } + } + }, + } + }, + } + } + }, + } + ) diff --git a/tests/test_tutorial/test_server_sent_events/test_tutorial003.py b/tests/test_tutorial/test_server_sent_events/test_tutorial003.py new file mode 100644 index 0000000000..6277a27c90 --- /dev/null +++ b/tests/test_tutorial/test_server_sent_events/test_tutorial003.py @@ -0,0 +1,73 @@ +import importlib + +import pytest +from fastapi.testclient import TestClient +from inline_snapshot import snapshot + + +@pytest.fixture( + name="client", + params=[ + pytest.param("tutorial003_py310"), + ], +) +def get_client(request: pytest.FixtureRequest): + mod = importlib.import_module(f"docs_src.server_sent_events.{request.param}") + client = TestClient(mod.app) + return client + + +def test_stream_logs(client: TestClient): + response = client.get("/logs/stream") + assert response.status_code == 200, response.text + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + data_lines = [ + line for line in response.text.strip().split("\n") if line.startswith("data: ") + ] + assert len(data_lines) == 3 + + # raw_data is sent without JSON encoding (no quotes around the string) + assert data_lines[0] == "data: 2025-01-01 INFO Application started" + assert data_lines[1] == "data: 2025-01-01 DEBUG Connected to database" + assert data_lines[2] == "data: 2025-01-01 WARN High memory usage detected" + + +def test_openapi_schema(client: TestClient): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == snapshot( + { + "openapi": "3.1.0", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/logs/stream": { + "get": { + "summary": "Stream Logs", + "operationId": "stream_logs_logs_stream_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "text/event-stream": { + "itemSchema": { + "type": "object", + "properties": { + "data": {"type": "string"}, + "event": {"type": "string"}, + "id": {"type": "string"}, + "retry": { + "type": "integer", + "minimum": 0, + }, + }, + } + } + }, + } + }, + } + } + }, + } + ) diff --git a/tests/test_tutorial/test_server_sent_events/test_tutorial004.py b/tests/test_tutorial/test_server_sent_events/test_tutorial004.py new file mode 100644 index 0000000000..38ce888c1c --- /dev/null +++ b/tests/test_tutorial/test_server_sent_events/test_tutorial004.py @@ -0,0 +1,164 @@ +import importlib + +import pytest +from fastapi.testclient import TestClient +from inline_snapshot import snapshot + + +@pytest.fixture( + name="client", + params=[ + pytest.param("tutorial004_py310"), + ], +) +def get_client(request: pytest.FixtureRequest): + mod = importlib.import_module(f"docs_src.server_sent_events.{request.param}") + client = TestClient(mod.app) + return client + + +def test_stream_all_items(client: TestClient): + response = client.get("/items/stream") + assert response.status_code == 200, response.text + + data_lines = [ + line for line in response.text.strip().split("\n") if line.startswith("data: ") + ] + assert len(data_lines) == 3 + + id_lines = [ + line for line in response.text.strip().split("\n") if line.startswith("id: ") + ] + assert id_lines == ["id: 0", "id: 1", "id: 2"] + + +def test_resume_from_last_event_id(client: TestClient): + response = client.get( + "/items/stream", + headers={"last-event-id": "0"}, + ) + assert response.status_code == 200, response.text + + data_lines = [ + line for line in response.text.strip().split("\n") if line.startswith("data: ") + ] + assert len(data_lines) == 2 + + id_lines = [ + line for line in response.text.strip().split("\n") if line.startswith("id: ") + ] + assert id_lines == ["id: 1", "id: 2"] + + +def test_resume_from_last_item(client: TestClient): + response = client.get( + "/items/stream", + headers={"last-event-id": "1"}, + ) + assert response.status_code == 200, response.text + + data_lines = [ + line for line in response.text.strip().split("\n") if line.startswith("data: ") + ] + assert len(data_lines) == 1 + + id_lines = [ + line for line in response.text.strip().split("\n") if line.startswith("id: ") + ] + assert id_lines == ["id: 2"] + + +def test_openapi_schema(client: TestClient): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == snapshot( + { + "openapi": "3.1.0", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/items/stream": { + "get": { + "summary": "Stream Items", + "operationId": "stream_items_items_stream_get", + "parameters": [ + { + "name": "last-event-id", + "in": "header", + "required": False, + "schema": { + "anyOf": [{"type": "integer"}, {"type": "null"}], + "title": "Last-Event-Id", + }, + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "text/event-stream": { + "itemSchema": { + "type": "object", + "properties": { + "data": {"type": "string"}, + "event": {"type": "string"}, + "id": {"type": "string"}, + "retry": { + "type": "integer", + "minimum": 0, + }, + }, + } + } + }, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + } + } + }, + "components": { + "schemas": { + "HTTPValidationError": { + "properties": { + "detail": { + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + "type": "array", + "title": "Detail", + } + }, + "type": "object", + "title": "HTTPValidationError", + }, + "ValidationError": { + "properties": { + "loc": { + "items": { + "anyOf": [{"type": "string"}, {"type": "integer"}] + }, + "type": "array", + "title": "Location", + }, + "msg": {"type": "string", "title": "Message"}, + "type": {"type": "string", "title": "Error Type"}, + "input": {"title": "Input"}, + "ctx": {"type": "object", "title": "Context"}, + }, + "type": "object", + "required": ["loc", "msg", "type"], + "title": "ValidationError", + }, + } + }, + } + ) diff --git a/tests/test_tutorial/test_server_sent_events/test_tutorial005.py b/tests/test_tutorial/test_server_sent_events/test_tutorial005.py new file mode 100644 index 0000000000..1b5c3492f7 --- /dev/null +++ b/tests/test_tutorial/test_server_sent_events/test_tutorial005.py @@ -0,0 +1,141 @@ +import importlib + +import pytest +from fastapi.testclient import TestClient +from inline_snapshot import snapshot + + +@pytest.fixture( + name="client", + params=[ + pytest.param("tutorial005_py310"), + ], +) +def get_client(request: pytest.FixtureRequest): + mod = importlib.import_module(f"docs_src.server_sent_events.{request.param}") + client = TestClient(mod.app) + return client + + +def test_stream_chat(client: TestClient): + response = client.post( + "/chat/stream", + json={"text": "hello world"}, + ) + assert response.status_code == 200, response.text + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + lines = response.text.strip().split("\n") + + event_lines = [line for line in lines if line.startswith("event: ")] + assert event_lines == [ + "event: token", + "event: token", + "event: done", + ] + + data_lines = [line for line in lines if line.startswith("data: ")] + assert data_lines == [ + 'data: "hello"', + 'data: "world"', + "data: [DONE]", + ] + + +def test_openapi_schema(client: TestClient): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == snapshot( + { + "openapi": "3.1.0", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/chat/stream": { + "post": { + "summary": "Stream Chat", + "operationId": "stream_chat_chat_stream_post", + "requestBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Prompt"} + } + }, + "required": True, + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "text/event-stream": { + "itemSchema": { + "type": "object", + "properties": { + "data": {"type": "string"}, + "event": {"type": "string"}, + "id": {"type": "string"}, + "retry": { + "type": "integer", + "minimum": 0, + }, + }, + } + } + }, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + } + } + }, + "components": { + "schemas": { + "HTTPValidationError": { + "properties": { + "detail": { + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + "type": "array", + "title": "Detail", + } + }, + "type": "object", + "title": "HTTPValidationError", + }, + "Prompt": { + "properties": {"text": {"type": "string", "title": "Text"}}, + "type": "object", + "required": ["text"], + "title": "Prompt", + }, + "ValidationError": { + "properties": { + "loc": { + "items": { + "anyOf": [{"type": "string"}, {"type": "integer"}] + }, + "type": "array", + "title": "Location", + }, + "msg": {"type": "string", "title": "Message"}, + "type": {"type": "string", "title": "Error Type"}, + "input": {"title": "Input"}, + "ctx": {"type": "object", "title": "Context"}, + }, + "type": "object", + "required": ["loc", "msg", "type"], + "title": "ValidationError", + }, + } + }, + } + )