diff --git a/fastapi/routing.py b/fastapi/routing.py index 36acb6b89d..bd0448f935 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -844,28 +844,6 @@ class APIRoute(routing.Route): self.path = path self.endpoint = endpoint self.stream_item_type: Any | None = None - if isinstance(response_model, DefaultPlaceholder): - return_annotation = get_typed_return_annotation(endpoint) - if lenient_issubclass(return_annotation, Response): - response_model = None - else: - stream_item = get_stream_item_type(return_annotation) - if stream_item is not None: - # Extract item type for JSONL or SSE streaming when - # response_class is DefaultPlaceholder (JSONL) or - # EventSourceResponse (SSE). - # ServerSentEvent is excluded: it's a transport - # wrapper, not a data model, so it shouldn't feed - # into validation or OpenAPI schema generation. - if ( - isinstance(response_class, DefaultPlaceholder) - or lenient_issubclass(response_class, EventSourceResponse) - ) and not lenient_issubclass(stream_item, ServerSentEvent): - self.stream_item_type = stream_item - response_model = None - else: - response_model = return_annotation - self.response_model = response_model self.summary = summary self.response_description = response_description self.deprecated = deprecated @@ -901,27 +879,6 @@ class APIRoute(routing.Route): if isinstance(status_code, IntEnum): status_code = int(status_code) self.status_code = status_code - if self.response_model: - assert is_body_allowed_for_status_code(status_code), ( - f"Status code {status_code} must not have a response body" - ) - response_name = "Response_" + self.unique_id - self.response_field = create_model_field( - name=response_name, - type_=self.response_model, - mode="serialization", - ) - else: - self.response_field = None # type: ignore # ty: ignore[unused-ignore-comment] - if self.stream_item_type: - stream_item_name = "StreamItem_" + self.unique_id - self.stream_item_field: ModelField | None = create_model_field( - name=stream_item_name, - type_=self.stream_item_type, - mode="serialization", - ) - else: - self.stream_item_field = None self.dependencies = list(dependencies or []) self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") # if a "form feed" character (page break) is found in the description text, @@ -973,6 +930,50 @@ class APIRoute(routing.Route): self.is_json_stream = is_generator and isinstance( response_class, DefaultPlaceholder ) + if isinstance(response_model, DefaultPlaceholder): + return_annotation = get_typed_return_annotation(endpoint) + if lenient_issubclass(return_annotation, Response): + response_model = None + else: + stream_item = get_stream_item_type(return_annotation) + if stream_item is not None and is_generator: + # Extract item type for JSONL or SSE streaming for + # generator endpoints when response_class is + # DefaultPlaceholder (JSONL) or EventSourceResponse + # (SSE). + # ServerSentEvent is excluded: it's a transport + # wrapper, not a data model, so it shouldn't feed + # into validation or OpenAPI schema generation. + if ( + isinstance(response_class, DefaultPlaceholder) + or lenient_issubclass(response_class, EventSourceResponse) + ) and not lenient_issubclass(stream_item, ServerSentEvent): + self.stream_item_type = stream_item + response_model = None + else: + response_model = return_annotation + self.response_model = response_model + if self.response_model: + assert is_body_allowed_for_status_code(status_code), ( + f"Status code {status_code} must not have a response body" + ) + response_name = "Response_" + self.unique_id + self.response_field = create_model_field( + name=response_name, + type_=self.response_model, + mode="serialization", + ) + else: + self.response_field = None # type: ignore # ty: ignore[unused-ignore-comment] + if self.stream_item_type: + stream_item_name = "StreamItem_" + self.unique_id + self.stream_item_field: ModelField | None = create_model_field( + name=stream_item_name, + type_=self.stream_item_type, + mode="serialization", + ) + else: + self.stream_item_field = None self.app = request_response(self.get_route_handler()) def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: diff --git a/tests/test_skip_defaults.py b/tests/test_skip_defaults.py index 238da7392f..f6333c754b 100644 --- a/tests/test_skip_defaults.py +++ b/tests/test_skip_defaults.py @@ -1,3 +1,5 @@ +from collections.abc import Iterable + from fastapi import FastAPI from fastapi.testclient import TestClient from pydantic import BaseModel @@ -65,6 +67,21 @@ def get_exclude_unset_none() -> ModelDefaults: return ModelDefaults(x=None, y="y") +@app.get("/iterable_exclude_unset", response_model_exclude_unset=True) +def get_iterable_exclude_unset() -> Iterable[ModelDefaults]: + return [ModelDefaults(x=None, y="y")] + + +@app.get("/iterable_exclude_defaults", response_model_exclude_defaults=True) +def get_iterable_exclude_defaults() -> Iterable[ModelDefaults]: + return [ModelDefaults(x=None, y="y")] + + +@app.get("/iterable_exclude_none", response_model_exclude_none=True) +def get_iterable_exclude_none() -> Iterable[ModelDefaults]: + return [ModelDefaults(x=None, y="y")] + + client = TestClient(app) @@ -91,3 +108,18 @@ def test_return_exclude_none(): def test_return_exclude_unset_none(): response = client.get("/exclude_unset_none") assert response.json() == {"y": "y"} + + +def test_return_iterable_exclude_unset(): + response = client.get("/iterable_exclude_unset") + assert response.json() == [{"x": None, "y": "y"}] + + +def test_return_iterable_exclude_defaults(): + response = client.get("/iterable_exclude_defaults") + assert response.json() == [{}] + + +def test_return_iterable_exclude_none(): + response = client.get("/iterable_exclude_none") + assert response.json() == [{"y": "y", "z": "z"}]