This commit is contained in:
Eugene Toder 2025-12-16 21:09:33 +00:00 committed by GitHub
commit fe3dd606cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 241 additions and 1 deletions

View File

@ -51,3 +51,15 @@ class ModelField(Protocol):
exclude_defaults: bool = False, exclude_defaults: bool = False,
exclude_none: bool = False, exclude_none: bool = False,
) -> Any: ... ) -> Any: ...
def serialize_json(
self,
value: Any,
*,
include: Union[IncEx, None] = None,
exclude: Union[IncEx, None] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> bytes: ...

View File

@ -208,6 +208,27 @@ class ModelField:
exclude_none=exclude_none, exclude_none=exclude_none,
) )
def serialize_json(
self,
value: Any,
*,
include: Union[IncEx, None] = None,
exclude: Union[IncEx, None] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> bytes:
return self._type_adapter.dump_json(
value,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
def __hash__(self) -> int: def __hash__(self) -> int:
# Each ModelField is unique for our purposes, to allow making a dict from # Each ModelField is unique for our purposes, to allow making a dict from
# ModelField to its JSON Schema. # ModelField to its JSON Schema.

View File

@ -46,3 +46,15 @@ class ORJSONResponse(JSONResponse):
return orjson.dumps( return orjson.dumps(
content, option=orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY content, option=orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY
) )
class PydanticJSONResponse(JSONResponse):
"""
JSON response using Pydantic v2's built-in JSON serialization
"""
def render(self, content: Any) -> bytes:
assert isinstance(content, bytes), (
"PydanticJSONResponse must be used with a response model"
)
return content

View File

@ -52,6 +52,7 @@ from fastapi.exceptions import (
ResponseValidationError, ResponseValidationError,
WebSocketRequestValidationError, WebSocketRequestValidationError,
) )
from fastapi.responses import JSONResponse, PydanticJSONResponse, Response
from fastapi.types import DecoratedCallable, IncEx from fastapi.types import DecoratedCallable, IncEx
from fastapi.utils import ( from fastapi.utils import (
create_cloned_field, create_cloned_field,
@ -67,7 +68,6 @@ from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import ( from starlette.routing import (
BaseRoute, BaseRoute,
Match, Match,
@ -244,6 +244,7 @@ async def serialize_response(
*, *,
field: Optional[ModelField] = None, field: Optional[ModelField] = None,
response_content: Any, response_content: Any,
to_json: bool = False,
include: Optional[IncEx] = None, include: Optional[IncEx] = None,
exclude: Optional[IncEx] = None, exclude: Optional[IncEx] = None,
by_alias: bool = True, by_alias: bool = True,
@ -257,6 +258,7 @@ async def serialize_response(
errors = [] errors = []
if not hasattr(field, "serialize"): if not hasattr(field, "serialize"):
# pydantic v1 # pydantic v1
assert not to_json, "PydanticJSONResponse requires a pydantic v2 model"
response_content = _prepare_response_content( response_content = _prepare_response_content(
response_content, response_content,
exclude_unset=exclude_unset, exclude_unset=exclude_unset,
@ -282,6 +284,16 @@ async def serialize_response(
) )
if hasattr(field, "serialize"): if hasattr(field, "serialize"):
if to_json:
return field.serialize_json(
value,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
return field.serialize( return field.serialize(
value, value,
include=include, include=include,
@ -452,6 +464,7 @@ def get_request_handler(
content = await serialize_response( content = await serialize_response(
field=response_field, field=response_field,
response_content=raw_response, response_content=raw_response,
to_json=issubclass(actual_response_class, PydanticJSONResponse),
include=response_model_include, include=response_model_include,
exclude=response_model_exclude, exclude=response_model_exclude,
by_alias=response_model_by_alias, by_alias=response_model_by_alias,

View File

@ -0,0 +1,182 @@
import math
from typing import List, Optional
import pytest
from dirty_equals import IsFloatNan
from fastapi import FastAPI
from fastapi._compat import PYDANTIC_V2
from fastapi.responses import PydanticJSONResponse
from fastapi.testclient import TestClient
from pydantic import BaseModel, Field
from .utils import needs_py_lt_314, needs_pydanticv2
app = FastAPI(default_response_class=PydanticJSONResponse)
class CustomResponse(PydanticJSONResponse):
media_type = "application/x-custom"
class Item(BaseModel):
name: str
price: float
category: str = Field("food", alias="CAT")
tax: float = 8.875
description: Optional[str] = None
@app.get("/response-model", response_model=Item)
@app.get(
"/response-model-include",
response_model=Item,
response_model_include={"name", "category"},
)
@app.get(
"/response-model-exclude",
response_model=Item,
response_model_exclude={"tax", "description"},
)
@app.get(
"/response-model-by-alias-false",
response_model=Item,
response_model_by_alias=False,
)
@app.get(
"/response-model-exclude-unset",
response_model=Item,
response_model_exclude_unset=True,
)
@app.get(
"/response-model-exclude-defaults",
response_model=Item,
response_model_exclude_defaults=True,
)
@app.get(
"/response-model-exclude-none",
response_model=Item,
response_model_exclude_none=True,
)
def get_response_model_params():
return {"name": "cheese", "price": 1.23, "tax": 8.875, "description": None}
class FloatsNone(BaseModel):
# pydantic converts inf/nan to None by default
numbers: List[float]
class FloatsNum(FloatsNone):
model_config = {"ser_json_inf_nan": "constants"}
class FloatsStr(FloatsNone):
model_config = {"ser_json_inf_nan": "strings"}
@app.get("/floats-none", response_model=FloatsNone)
@app.get("/floats-num", response_model=FloatsNum)
@app.get("/floats-str", response_model=FloatsStr)
@app.get("/custom-class", response_class=CustomResponse, response_model=FloatsStr)
def get_floats():
return {"numbers": [3.14, math.inf, math.nan, 2.72]}
client = TestClient(app)
@needs_pydanticv2
@pytest.mark.parametrize(
"path,expected_response",
[
(
"/response-model",
{
"name": "cheese",
"price": 1.23,
"CAT": "food",
"tax": 8.875,
"description": None,
},
),
("/response-model-include", {"name": "cheese", "CAT": "food"}),
("/response-model-exclude", {"name": "cheese", "price": 1.23, "CAT": "food"}),
(
"/response-model-by-alias-false",
{
"name": "cheese",
"price": 1.23,
"category": "food",
"tax": 8.875,
"description": None,
},
),
(
"/response-model-exclude-unset",
{
"name": "cheese",
"price": 1.23,
"tax": 8.875,
"description": None,
},
),
("/response-model-exclude-defaults", {"name": "cheese", "price": 1.23}),
(
"/response-model-exclude-none",
{
"name": "cheese",
"price": 1.23,
"CAT": "food",
"tax": 8.875,
},
),
],
)
def test_response_model_params(path: str, expected_response: dict):
response = client.get(path)
assert response.status_code == 200
assert response.json() == expected_response
@needs_pydanticv2
@pytest.mark.parametrize(
"path,expected_numbers",
[
("/floats-none", [3.14, None, None, 2.72]),
("/floats-num", [3.14, math.inf, IsFloatNan, 2.72]),
("/floats-str", [3.14, "Infinity", "NaN", 2.72]),
],
)
def test_floats(path: str, expected_numbers: list):
response = client.get(path)
assert response.status_code == 200
assert response.json() == {"numbers": expected_numbers}
@needs_pydanticv2
def test_custom_response_class():
response = client.get("/custom-class")
assert response.status_code == 200
assert response.headers["content-type"] == "application/x-custom"
assert response.json() == {"numbers": [3.14, "Infinity", "NaN", 2.72]}
@needs_py_lt_314
def test_requires_pydantic_v2_model():
if PYDANTIC_V2:
from pydantic.v1 import BaseModel as BaseModelV1
else:
from pydantic import BaseModel as BaseModelV1
app = FastAPI(default_response_class=PydanticJSONResponse)
class ModelV1(BaseModelV1):
data: str
@app.get("/model-v1")
def get_model_v1() -> ModelV1:
return ModelV1(data="hello")
client = TestClient(app)
with pytest.raises(AssertionError, match="requires a pydantic v2 model"):
client.get("/model-v1")