🐛 Fix Pydantic field clone logic with validators (#899)

This commit is contained in:
Andy Smith 2020-02-03 22:03:51 -05:00 committed by GitHub
parent 4f964939a1
commit 70bdade23b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 13 deletions

View File

@ -93,12 +93,9 @@ def create_cloned_field(field: ModelField) -> ModelField:
use_type = original_type use_type = original_type
if lenient_issubclass(original_type, BaseModel): if lenient_issubclass(original_type, BaseModel):
original_type = cast(Type[BaseModel], original_type) original_type = cast(Type[BaseModel], original_type)
use_type = create_model( use_type = create_model(original_type.__name__, __base__=original_type)
original_type.__name__, __config__=original_type.__config__
)
for f in original_type.__fields__.values(): for f in original_type.__fields__.values():
use_type.__fields__[f.name] = create_cloned_field(f) use_type.__fields__[f.name] = create_cloned_field(f)
use_type.__validators__ = original_type.__validators__
if PYDANTIC_1: if PYDANTIC_1:
new_field = ModelField( new_field = ModelField(
name=field.name, name=field.name,

View File

@ -1,5 +1,6 @@
import pytest
from fastapi import Depends, FastAPI from fastapi import Depends, FastAPI
from pydantic import BaseModel from pydantic import BaseModel, ValidationError, validator
from starlette.testclient import TestClient from starlette.testclient import TestClient
app = FastAPI() app = FastAPI()
@ -18,14 +19,20 @@ class ModelA(BaseModel):
description: str = None description: str = None
model_b: ModelB model_b: ModelB
@validator("name")
def lower_username(cls, name: str, values):
if not name.endswith("A"):
raise ValueError("name must end in A")
return name
async def get_model_c() -> ModelC: async def get_model_c() -> ModelC:
return ModelC(username="test-user", password="test-password") return ModelC(username="test-user", password="test-password")
@app.get("/model", response_model=ModelA) @app.get("/model/{name}", response_model=ModelA)
async def get_model_a(model_c=Depends(get_model_c)): async def get_model_a(name: str, model_c=Depends(get_model_c)):
return {"name": "model-a-name", "description": "model-a-desc", "model_b": model_c} return {"name": name, "description": "model-a-desc", "model_b": model_c}
client = TestClient(app) client = TestClient(app)
@ -35,10 +42,18 @@ openapi_schema = {
"openapi": "3.0.2", "openapi": "3.0.2",
"info": {"title": "FastAPI", "version": "0.1.0"}, "info": {"title": "FastAPI", "version": "0.1.0"},
"paths": { "paths": {
"/model": { "/model/{name}": {
"get": { "get": {
"summary": "Get Model A", "summary": "Get Model A",
"operationId": "get_model_a_model_get", "operationId": "get_model_a_model__name__get",
"parameters": [
{
"required": True,
"schema": {"title": "Name", "type": "string"},
"name": "name",
"in": "path",
}
],
"responses": { "responses": {
"200": { "200": {
"description": "Successful Response", "description": "Successful Response",
@ -47,13 +62,34 @@ openapi_schema = {
"schema": {"$ref": "#/components/schemas/ModelA"} "schema": {"$ref": "#/components/schemas/ModelA"}
} }
}, },
} },
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
}, },
} }
} }
}, },
"components": { "components": {
"schemas": { "schemas": {
"HTTPValidationError": {
"title": "HTTPValidationError",
"type": "object",
"properties": {
"detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": "#/components/schemas/ValidationError"},
}
},
},
"ModelA": { "ModelA": {
"title": "ModelA", "title": "ModelA",
"required": ["name", "model_b"], "required": ["name", "model_b"],
@ -70,6 +106,20 @@ openapi_schema = {
"type": "object", "type": "object",
"properties": {"username": {"title": "Username", "type": "string"}}, "properties": {"username": {"title": "Username", "type": "string"}},
}, },
"ValidationError": {
"title": "ValidationError",
"required": ["loc", "msg", "type"],
"type": "object",
"properties": {
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},
},
},
} }
}, },
} }
@ -82,10 +132,22 @@ def test_openapi_schema():
def test_filter_sub_model(): def test_filter_sub_model():
response = client.get("/model") response = client.get("/model/modelA")
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == { assert response.json() == {
"name": "model-a-name", "name": "modelA",
"description": "model-a-desc", "description": "model-a-desc",
"model_b": {"username": "test-user"}, "model_b": {"username": "test-user"},
} }
def test_validator_is_cloned():
with pytest.raises(ValidationError) as err:
client.get("/model/modelX")
assert err.value.errors() == [
{
"loc": ("response", "name"),
"msg": "name must end in A",
"type": "value_error",
}
]