mirror of https://github.com/tiangolo/fastapi.git
🐛 Fix Pydantic field clone logic with validators (#899)
This commit is contained in:
parent
4f964939a1
commit
70bdade23b
|
|
@ -93,12 +93,9 @@ def create_cloned_field(field: ModelField) -> ModelField:
|
||||||
use_type = original_type
|
use_type = original_type
|
||||||
if lenient_issubclass(original_type, BaseModel):
|
if lenient_issubclass(original_type, BaseModel):
|
||||||
original_type = cast(Type[BaseModel], original_type)
|
original_type = cast(Type[BaseModel], original_type)
|
||||||
use_type = create_model(
|
use_type = create_model(original_type.__name__, __base__=original_type)
|
||||||
original_type.__name__, __config__=original_type.__config__
|
|
||||||
)
|
|
||||||
for f in original_type.__fields__.values():
|
for f in original_type.__fields__.values():
|
||||||
use_type.__fields__[f.name] = create_cloned_field(f)
|
use_type.__fields__[f.name] = create_cloned_field(f)
|
||||||
use_type.__validators__ = original_type.__validators__
|
|
||||||
if PYDANTIC_1:
|
if PYDANTIC_1:
|
||||||
new_field = ModelField(
|
new_field = ModelField(
|
||||||
name=field.name,
|
name=field.name,
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
|
import pytest
|
||||||
from fastapi import Depends, FastAPI
|
from fastapi import Depends, FastAPI
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, ValidationError, validator
|
||||||
from starlette.testclient import TestClient
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
@ -18,14 +19,20 @@ class ModelA(BaseModel):
|
||||||
description: str = None
|
description: str = None
|
||||||
model_b: ModelB
|
model_b: ModelB
|
||||||
|
|
||||||
|
@validator("name")
|
||||||
|
def lower_username(cls, name: str, values):
|
||||||
|
if not name.endswith("A"):
|
||||||
|
raise ValueError("name must end in A")
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
async def get_model_c() -> ModelC:
|
async def get_model_c() -> ModelC:
|
||||||
return ModelC(username="test-user", password="test-password")
|
return ModelC(username="test-user", password="test-password")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/model", response_model=ModelA)
|
@app.get("/model/{name}", response_model=ModelA)
|
||||||
async def get_model_a(model_c=Depends(get_model_c)):
|
async def get_model_a(name: str, model_c=Depends(get_model_c)):
|
||||||
return {"name": "model-a-name", "description": "model-a-desc", "model_b": model_c}
|
return {"name": name, "description": "model-a-desc", "model_b": model_c}
|
||||||
|
|
||||||
|
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
|
|
@ -35,10 +42,18 @@ openapi_schema = {
|
||||||
"openapi": "3.0.2",
|
"openapi": "3.0.2",
|
||||||
"info": {"title": "FastAPI", "version": "0.1.0"},
|
"info": {"title": "FastAPI", "version": "0.1.0"},
|
||||||
"paths": {
|
"paths": {
|
||||||
"/model": {
|
"/model/{name}": {
|
||||||
"get": {
|
"get": {
|
||||||
"summary": "Get Model A",
|
"summary": "Get Model A",
|
||||||
"operationId": "get_model_a_model_get",
|
"operationId": "get_model_a_model__name__get",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"required": True,
|
||||||
|
"schema": {"title": "Name", "type": "string"},
|
||||||
|
"name": "name",
|
||||||
|
"in": "path",
|
||||||
|
}
|
||||||
|
],
|
||||||
"responses": {
|
"responses": {
|
||||||
"200": {
|
"200": {
|
||||||
"description": "Successful Response",
|
"description": "Successful Response",
|
||||||
|
|
@ -47,13 +62,34 @@ openapi_schema = {
|
||||||
"schema": {"$ref": "#/components/schemas/ModelA"}
|
"schema": {"$ref": "#/components/schemas/ModelA"}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
},
|
||||||
|
"422": {
|
||||||
|
"description": "Validation Error",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/HTTPValidationError"
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"components": {
|
"components": {
|
||||||
"schemas": {
|
"schemas": {
|
||||||
|
"HTTPValidationError": {
|
||||||
|
"title": "HTTPValidationError",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"detail": {
|
||||||
|
"title": "Detail",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"$ref": "#/components/schemas/ValidationError"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
"ModelA": {
|
"ModelA": {
|
||||||
"title": "ModelA",
|
"title": "ModelA",
|
||||||
"required": ["name", "model_b"],
|
"required": ["name", "model_b"],
|
||||||
|
|
@ -70,6 +106,20 @@ openapi_schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {"username": {"title": "Username", "type": "string"}},
|
"properties": {"username": {"title": "Username", "type": "string"}},
|
||||||
},
|
},
|
||||||
|
"ValidationError": {
|
||||||
|
"title": "ValidationError",
|
||||||
|
"required": ["loc", "msg", "type"],
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"loc": {
|
||||||
|
"title": "Location",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
},
|
||||||
|
"msg": {"title": "Message", "type": "string"},
|
||||||
|
"type": {"title": "Error Type", "type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -82,10 +132,22 @@ def test_openapi_schema():
|
||||||
|
|
||||||
|
|
||||||
def test_filter_sub_model():
|
def test_filter_sub_model():
|
||||||
response = client.get("/model")
|
response = client.get("/model/modelA")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {
|
assert response.json() == {
|
||||||
"name": "model-a-name",
|
"name": "modelA",
|
||||||
"description": "model-a-desc",
|
"description": "model-a-desc",
|
||||||
"model_b": {"username": "test-user"},
|
"model_b": {"username": "test-user"},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_validator_is_cloned():
|
||||||
|
with pytest.raises(ValidationError) as err:
|
||||||
|
client.get("/model/modelX")
|
||||||
|
assert err.value.errors() == [
|
||||||
|
{
|
||||||
|
"loc": ("response", "name"),
|
||||||
|
"msg": "name must end in A",
|
||||||
|
"type": "value_error",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue