diff --git a/fastapi/_compat/model_field.py b/fastapi/_compat/model_field.py index fa2008c5e..3b7dd6337 100644 --- a/fastapi/_compat/model_field.py +++ b/fastapi/_compat/model_field.py @@ -51,3 +51,15 @@ class ModelField(Protocol): exclude_defaults: bool = False, exclude_none: bool = False, ) -> Any: ... + + def serialize_json( + self, + value: Any, + *, + include: Union[IncEx, None] = None, + exclude: Union[IncEx, None] = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> bytes: ... diff --git a/fastapi/_compat/v2.py b/fastapi/_compat/v2.py index a17d62556..f3e8f4261 100644 --- a/fastapi/_compat/v2.py +++ b/fastapi/_compat/v2.py @@ -208,6 +208,27 @@ class ModelField: exclude_none=exclude_none, ) + def serialize_json( + self, + value: Any, + *, + include: Union[IncEx, None] = None, + exclude: Union[IncEx, None] = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> bytes: + return self._type_adapter.dump_json( + value, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + def __hash__(self) -> int: # Each ModelField is unique for our purposes, to allow making a dict from # ModelField to its JSON Schema. diff --git a/fastapi/responses.py b/fastapi/responses.py index 6c8db6f33..dd31e00ec 100644 --- a/fastapi/responses.py +++ b/fastapi/responses.py @@ -46,3 +46,15 @@ class ORJSONResponse(JSONResponse): return orjson.dumps( content, option=orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY ) + + +class PydanticJSONResponse(JSONResponse): + """ + JSON response using Pydantic v2's built-in JSON serialization + """ + + def render(self, content: Any) -> bytes: + assert isinstance(content, bytes), ( + "PydanticJSONResponse must be used with a response model" + ) + return content diff --git a/fastapi/routing.py b/fastapi/routing.py index 9be2b44bc..2cc96ab76 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -52,6 +52,7 @@ from fastapi.exceptions import ( ResponseValidationError, WebSocketRequestValidationError, ) +from fastapi.responses import JSONResponse, PydanticJSONResponse, Response from fastapi.types import DecoratedCallable, IncEx from fastapi.utils import ( create_cloned_field, @@ -67,7 +68,6 @@ from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException from starlette.requests import Request -from starlette.responses import JSONResponse, Response from starlette.routing import ( BaseRoute, Match, @@ -244,6 +244,7 @@ async def serialize_response( *, field: Optional[ModelField] = None, response_content: Any, + to_json: bool = False, include: Optional[IncEx] = None, exclude: Optional[IncEx] = None, by_alias: bool = True, @@ -257,6 +258,7 @@ async def serialize_response( errors = [] if not hasattr(field, "serialize"): # pydantic v1 + assert not to_json, "PydanticJSONResponse requires a pydantic v2 model" response_content = _prepare_response_content( response_content, exclude_unset=exclude_unset, @@ -282,6 +284,16 @@ async def serialize_response( ) if hasattr(field, "serialize"): + if to_json: + return field.serialize_json( + value, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) return field.serialize( value, include=include, @@ -452,6 +464,7 @@ def get_request_handler( content = await serialize_response( field=response_field, response_content=raw_response, + to_json=issubclass(actual_response_class, PydanticJSONResponse), include=response_model_include, exclude=response_model_exclude, by_alias=response_model_by_alias, diff --git a/tests/test_pydantic_json_response_class.py b/tests/test_pydantic_json_response_class.py new file mode 100644 index 000000000..e020f89a6 --- /dev/null +++ b/tests/test_pydantic_json_response_class.py @@ -0,0 +1,182 @@ +import math +from typing import List, Optional + +import pytest +from dirty_equals import IsFloatNan +from fastapi import FastAPI +from fastapi._compat import PYDANTIC_V2 +from fastapi.responses import PydanticJSONResponse +from fastapi.testclient import TestClient +from pydantic import BaseModel, Field + +from .utils import needs_py_lt_314, needs_pydanticv2 + +app = FastAPI(default_response_class=PydanticJSONResponse) + + +class CustomResponse(PydanticJSONResponse): + media_type = "application/x-custom" + + +class Item(BaseModel): + name: str + price: float + category: str = Field("food", alias="CAT") + tax: float = 8.875 + description: Optional[str] = None + + +@app.get("/response-model", response_model=Item) +@app.get( + "/response-model-include", + response_model=Item, + response_model_include={"name", "category"}, +) +@app.get( + "/response-model-exclude", + response_model=Item, + response_model_exclude={"tax", "description"}, +) +@app.get( + "/response-model-by-alias-false", + response_model=Item, + response_model_by_alias=False, +) +@app.get( + "/response-model-exclude-unset", + response_model=Item, + response_model_exclude_unset=True, +) +@app.get( + "/response-model-exclude-defaults", + response_model=Item, + response_model_exclude_defaults=True, +) +@app.get( + "/response-model-exclude-none", + response_model=Item, + response_model_exclude_none=True, +) +def get_response_model_params(): + return {"name": "cheese", "price": 1.23, "tax": 8.875, "description": None} + + +class FloatsNone(BaseModel): + # pydantic converts inf/nan to None by default + numbers: List[float] + + +class FloatsNum(FloatsNone): + model_config = {"ser_json_inf_nan": "constants"} + + +class FloatsStr(FloatsNone): + model_config = {"ser_json_inf_nan": "strings"} + + +@app.get("/floats-none", response_model=FloatsNone) +@app.get("/floats-num", response_model=FloatsNum) +@app.get("/floats-str", response_model=FloatsStr) +@app.get("/custom-class", response_class=CustomResponse, response_model=FloatsStr) +def get_floats(): + return {"numbers": [3.14, math.inf, math.nan, 2.72]} + + +client = TestClient(app) + + +@needs_pydanticv2 +@pytest.mark.parametrize( + "path,expected_response", + [ + ( + "/response-model", + { + "name": "cheese", + "price": 1.23, + "CAT": "food", + "tax": 8.875, + "description": None, + }, + ), + ("/response-model-include", {"name": "cheese", "CAT": "food"}), + ("/response-model-exclude", {"name": "cheese", "price": 1.23, "CAT": "food"}), + ( + "/response-model-by-alias-false", + { + "name": "cheese", + "price": 1.23, + "category": "food", + "tax": 8.875, + "description": None, + }, + ), + ( + "/response-model-exclude-unset", + { + "name": "cheese", + "price": 1.23, + "tax": 8.875, + "description": None, + }, + ), + ("/response-model-exclude-defaults", {"name": "cheese", "price": 1.23}), + ( + "/response-model-exclude-none", + { + "name": "cheese", + "price": 1.23, + "CAT": "food", + "tax": 8.875, + }, + ), + ], +) +def test_response_model_params(path: str, expected_response: dict): + response = client.get(path) + assert response.status_code == 200 + assert response.json() == expected_response + + +@needs_pydanticv2 +@pytest.mark.parametrize( + "path,expected_numbers", + [ + ("/floats-none", [3.14, None, None, 2.72]), + ("/floats-num", [3.14, math.inf, IsFloatNan, 2.72]), + ("/floats-str", [3.14, "Infinity", "NaN", 2.72]), + ], +) +def test_floats(path: str, expected_numbers: list): + response = client.get(path) + assert response.status_code == 200 + assert response.json() == {"numbers": expected_numbers} + + +@needs_pydanticv2 +def test_custom_response_class(): + response = client.get("/custom-class") + assert response.status_code == 200 + assert response.headers["content-type"] == "application/x-custom" + assert response.json() == {"numbers": [3.14, "Infinity", "NaN", 2.72]} + + +@needs_py_lt_314 +def test_requires_pydantic_v2_model(): + if PYDANTIC_V2: + from pydantic.v1 import BaseModel as BaseModelV1 + else: + from pydantic import BaseModel as BaseModelV1 + + app = FastAPI(default_response_class=PydanticJSONResponse) + + class ModelV1(BaseModelV1): + data: str + + @app.get("/model-v1") + def get_model_v1() -> ModelV1: + return ModelV1(data="hello") + + client = TestClient(app) + with pytest.raises(AssertionError, match="requires a pydantic v2 model"): + client.get("/model-v1")