mirror of https://github.com/tiangolo/fastapi.git
✨ Add support and tests for Pydantic dataclasses in response_model (#454)
This commit is contained in:
parent
c218e0d560
commit
3025a368c6
|
|
@ -1,4 +1,5 @@
|
||||||
import re
|
import re
|
||||||
|
from dataclasses import is_dataclass
|
||||||
from typing import Any, Dict, List, Sequence, Set, Type, cast
|
from typing import Any, Dict, List, Sequence, Set, Type, cast
|
||||||
|
|
||||||
from fastapi import routing
|
from fastapi import routing
|
||||||
|
|
@ -52,6 +53,8 @@ def get_path_param_names(path: str) -> Set[str]:
|
||||||
|
|
||||||
def create_cloned_field(field: Field) -> Field:
|
def create_cloned_field(field: Field) -> Field:
|
||||||
original_type = field.type_
|
original_type = field.type_
|
||||||
|
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
|
||||||
|
original_type = original_type.__pydantic_model__ # type: ignore
|
||||||
use_type = original_type
|
use_type = original_type
|
||||||
if lenient_issubclass(original_type, BaseModel):
|
if lenient_issubclass(original_type, BaseModel):
|
||||||
original_type = cast(Type[BaseModel], original_type)
|
original_type = cast(Type[BaseModel], original_type)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import pytest
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel
|
||||||
from starlette.testclient import TestClient
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
@ -14,38 +13,45 @@ class Item(BaseModel):
|
||||||
owner_ids: List[int] = None
|
owner_ids: List[int] = None
|
||||||
|
|
||||||
|
|
||||||
@app.get("/items/invalid", response_model=Item)
|
@app.get("/items/valid", response_model=Item)
|
||||||
def get_invalid():
|
def get_valid():
|
||||||
return {"name": "invalid", "price": "foo"}
|
return {"name": "valid", "price": 1.0}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/items/innerinvalid", response_model=Item)
|
@app.get("/items/coerce", response_model=Item)
|
||||||
def get_innerinvalid():
|
def get_coerce():
|
||||||
return {"name": "double invalid", "price": "foo", "owner_ids": ["foo", "bar"]}
|
return {"name": "coerce", "price": "1.0"}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/items/invalidlist", response_model=List[Item])
|
@app.get("/items/validlist", response_model=List[Item])
|
||||||
def get_invalidlist():
|
def get_validlist():
|
||||||
return [
|
return [
|
||||||
{"name": "foo"},
|
{"name": "foo"},
|
||||||
{"name": "bar", "price": "bar"},
|
{"name": "bar", "price": 1.0},
|
||||||
{"name": "baz", "price": "baz"},
|
{"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
def test_invalid():
|
def test_valid():
|
||||||
with pytest.raises(ValidationError):
|
response = client.get("/items/valid")
|
||||||
client.get("/items/invalid")
|
response.raise_for_status()
|
||||||
|
assert response.json() == {"name": "valid", "price": 1.0, "owner_ids": None}
|
||||||
|
|
||||||
|
|
||||||
def test_double_invalid():
|
def test_coerce():
|
||||||
with pytest.raises(ValidationError):
|
response = client.get("/items/coerce")
|
||||||
client.get("/items/innerinvalid")
|
response.raise_for_status()
|
||||||
|
assert response.json() == {"name": "coerce", "price": 1.0, "owner_ids": None}
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_list():
|
def test_validlist():
|
||||||
with pytest.raises(ValidationError):
|
response = client.get("/items/validlist")
|
||||||
client.get("/items/invalidlist")
|
response.raise_for_status()
|
||||||
|
assert response.json() == [
|
||||||
|
{"name": "foo", "price": None, "owner_ids": None},
|
||||||
|
{"name": "bar", "price": 1.0, "owner_ids": None},
|
||||||
|
{"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from pydantic.dataclasses import dataclass
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Item:
|
||||||
|
name: str
|
||||||
|
price: float = None
|
||||||
|
owner_ids: List[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/items/valid", response_model=Item)
|
||||||
|
def get_valid():
|
||||||
|
return {"name": "valid", "price": 1.0}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/items/coerce", response_model=Item)
|
||||||
|
def get_coerce():
|
||||||
|
return {"name": "coerce", "price": "1.0"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/items/validlist", response_model=List[Item])
|
||||||
|
def get_validlist():
|
||||||
|
return [
|
||||||
|
{"name": "foo"},
|
||||||
|
{"name": "bar", "price": 1.0},
|
||||||
|
{"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_valid():
|
||||||
|
response = client.get("/items/valid")
|
||||||
|
response.raise_for_status()
|
||||||
|
assert response.json() == {"name": "valid", "price": 1.0, "owner_ids": None}
|
||||||
|
|
||||||
|
|
||||||
|
def test_coerce():
|
||||||
|
response = client.get("/items/coerce")
|
||||||
|
response.raise_for_status()
|
||||||
|
assert response.json() == {"name": "coerce", "price": 1.0, "owner_ids": None}
|
||||||
|
|
||||||
|
|
||||||
|
def test_validlist():
|
||||||
|
response = client.get("/items/validlist")
|
||||||
|
response.raise_for_status()
|
||||||
|
assert response.json() == [
|
||||||
|
{"name": "foo", "price": None, "owner_ids": None},
|
||||||
|
{"name": "bar", "price": 1.0, "owner_ids": None},
|
||||||
|
{"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,51 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
class Item(BaseModel):
|
||||||
|
name: str
|
||||||
|
price: float = None
|
||||||
|
owner_ids: List[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/items/invalid", response_model=Item)
|
||||||
|
def get_invalid():
|
||||||
|
return {"name": "invalid", "price": "foo"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/items/innerinvalid", response_model=Item)
|
||||||
|
def get_innerinvalid():
|
||||||
|
return {"name": "double invalid", "price": "foo", "owner_ids": ["foo", "bar"]}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/items/invalidlist", response_model=List[Item])
|
||||||
|
def get_invalidlist():
|
||||||
|
return [
|
||||||
|
{"name": "foo"},
|
||||||
|
{"name": "bar", "price": "bar"},
|
||||||
|
{"name": "baz", "price": "baz"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid():
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
client.get("/items/invalid")
|
||||||
|
|
||||||
|
|
||||||
|
def test_double_invalid():
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
client.get("/items/innerinvalid")
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_list():
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
client.get("/items/invalidlist")
|
||||||
|
|
@ -0,0 +1,53 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from pydantic.dataclasses import dataclass
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Item:
|
||||||
|
name: str
|
||||||
|
price: float = None
|
||||||
|
owner_ids: List[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/items/invalid", response_model=Item)
|
||||||
|
def get_invalid():
|
||||||
|
return {"name": "invalid", "price": "foo"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/items/innerinvalid", response_model=Item)
|
||||||
|
def get_innerinvalid():
|
||||||
|
return {"name": "double invalid", "price": "foo", "owner_ids": ["foo", "bar"]}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/items/invalidlist", response_model=List[Item])
|
||||||
|
def get_invalidlist():
|
||||||
|
return [
|
||||||
|
{"name": "foo"},
|
||||||
|
{"name": "bar", "price": "bar"},
|
||||||
|
{"name": "baz", "price": "baz"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid():
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
client.get("/items/invalid")
|
||||||
|
|
||||||
|
|
||||||
|
def test_double_invalid():
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
client.get("/items/innerinvalid")
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_list():
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
client.get("/items/invalidlist")
|
||||||
Loading…
Reference in New Issue