Add BeforeValidator to Body tests

This commit is contained in:
Yurii Motov 2026-02-05 09:41:12 +01:00
parent 3441e14197
commit 9e85c19d3a
1 changed files with 212 additions and 68 deletions

View File

@ -1,25 +1,34 @@
from typing import Annotated, Any, Union from typing import Annotated, Any, Union
from unittest.mock import Mock, patch
import pytest import pytest
from dirty_equals import IsList, IsOneOf from dirty_equals import IsList, IsOneOf
from fastapi import Body, FastAPI from fastapi import Body, FastAPI
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from pydantic import BaseModel from pydantic import BaseModel, BeforeValidator, field_validator
from .utils import get_body_model_name from .utils import get_body_model_name
app = FastAPI() app = FastAPI()
def convert(v: Any) -> Any:
return v
# ===================================================================================== # =====================================================================================
# Nullable required # Nullable required
@app.post("/nullable-required") @app.post("/nullable-required")
async def read_nullable_required( async def read_nullable_required(
int_val: Annotated[Union[int, None], Body()], int_val: Annotated[Union[int, None], Body(), BeforeValidator(lambda v: convert(v))],
str_val: Annotated[Union[str, None], Body()], str_val: Annotated[Union[str, None], Body(), BeforeValidator(lambda v: convert(v))],
list_val: Union[list[int], None], list_val: Annotated[
Union[list[int], None],
Body(),
BeforeValidator(lambda v: convert(v)),
],
): ):
return { return {
"int_val": int_val, "int_val": int_val,
@ -34,6 +43,10 @@ class ModelNullableRequired(BaseModel):
str_val: Union[str, None] str_val: Union[str, None]
list_val: Union[list[int], None] list_val: Union[list[int], None]
@field_validator("*", mode="before")
def validate_all(cls, v):
return convert(v)
@app.post("/model-nullable-required") @app.post("/model-nullable-required")
async def read_model_nullable_required(params: ModelNullableRequired): async def read_model_nullable_required(params: ModelNullableRequired):
@ -47,21 +60,23 @@ async def read_model_nullable_required(params: ModelNullableRequired):
@app.post("/nullable-required-str") @app.post("/nullable-required-str")
async def read_nullable_required_no_embed_str( async def read_nullable_required_no_embed_str(
str_val: Annotated[Union[str, None], Body()], str_val: Annotated[Union[str, None], Body(), BeforeValidator(lambda v: convert(v))],
): ):
return {"val": str_val} return {"val": str_val}
@app.post("/nullable-required-int") @app.post("/nullable-required-int")
async def read_nullable_required_no_embed_int( async def read_nullable_required_no_embed_int(
int_val: Annotated[Union[int, None], Body()], int_val: Annotated[Union[int, None], Body(), BeforeValidator(lambda v: convert(v))],
): ):
return {"val": int_val} return {"val": int_val}
@app.post("/nullable-required-list") @app.post("/nullable-required-list")
async def read_nullable_required_no_embed_list( async def read_nullable_required_no_embed_list(
list_val: Annotated[Union[list[int], None], Body()], list_val: Annotated[
Union[list[int], None], Body(), BeforeValidator(lambda v: convert(v))
],
): ):
return {"val": list_val} return {"val": list_val}
@ -278,14 +293,18 @@ def test_nullable_required_pass_empty_dict(path: str, msg: str, error_type: str)
) )
def test_nullable_required_pass_null(path: str): def test_nullable_required_pass_null(path: str):
client = TestClient(app) client = TestClient(app)
response = client.post(
path, with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
json={ response = client.post(
"int_val": None, path,
"str_val": None, json={
"list_val": None, "int_val": None,
}, "str_val": None,
) "list_val": None,
},
)
assert mock_convert.call_count == 3, "Validator should be called for each field"
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == { assert response.json() == {
"int_val": None, "int_val": None,
@ -308,10 +327,13 @@ def test_nullable_required_pass_null(path: str):
@pytest.mark.xfail(reason="Explicit null-body is treated as missing") @pytest.mark.xfail(reason="Explicit null-body is treated as missing")
def test_nullable_required_no_embed_pass_null(path: str): def test_nullable_required_no_embed_pass_null(path: str):
client = TestClient(app) client = TestClient(app)
response = client.post(path, content="null")
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
response = client.post(path, content="null")
assert mock_convert.call_count == 1, "Validator should be called once for the field"
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == {"val": None} assert response.json() == {"val": None}
# TODO: add test with BeforeValidator to ensure that it recieves `None` value
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -323,9 +345,13 @@ def test_nullable_required_no_embed_pass_null(path: str):
) )
def test_nullable_required_pass_value(path: str): def test_nullable_required_pass_value(path: str):
client = TestClient(app) client = TestClient(app)
response = client.post(
path, json={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]} with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
) response = client.post(
path, json={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]}
)
assert mock_convert.call_count == 3, "Validator should be called for each field"
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == { assert response.json() == {
"int_val": 1, "int_val": 1,
@ -347,10 +373,11 @@ def test_nullable_required_pass_value(path: str):
) )
def test_nullable_required_no_embed_pass_value(path: str, value: Any): def test_nullable_required_no_embed_pass_value(path: str, value: Any):
client = TestClient(app) client = TestClient(app)
response = client.post(
path, with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
json=value, response = client.post(path, json=value)
)
assert mock_convert.call_count == 1, "Validator should be called once for the field"
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == {"val": value} assert response.json() == {"val": value}
@ -361,9 +388,21 @@ def test_nullable_required_no_embed_pass_value(path: str, value: Any):
@app.post("/nullable-non-required") @app.post("/nullable-non-required")
async def read_nullable_non_required( async def read_nullable_non_required(
int_val: Annotated[Union[int, None], Body()] = None, int_val: Annotated[
str_val: Annotated[Union[str, None], Body()] = None, Union[int, None],
list_val: Union[list[int], None] = None, Body(),
BeforeValidator(lambda v: convert(v)),
] = None,
str_val: Annotated[
Union[str, None],
Body(),
BeforeValidator(lambda v: convert(v)),
] = None,
list_val: Annotated[
Union[list[int], None],
Body(),
BeforeValidator(lambda v: convert(v)),
] = None,
): ):
return { return {
"int_val": int_val, "int_val": int_val,
@ -378,6 +417,10 @@ class ModelNullableNonRequired(BaseModel):
str_val: Union[str, None] = None str_val: Union[str, None] = None
list_val: Union[list[int], None] = None list_val: Union[list[int], None] = None
@field_validator("*", mode="before")
def validate_all(cls, v):
return convert(v)
@app.post("/model-nullable-non-required") @app.post("/model-nullable-non-required")
async def read_model_nullable_non_required( async def read_model_nullable_non_required(
@ -393,21 +436,33 @@ async def read_model_nullable_non_required(
@app.post("/nullable-non-required-str") @app.post("/nullable-non-required-str")
async def read_nullable_non_required_no_embed_str( async def read_nullable_non_required_no_embed_str(
str_val: Annotated[Union[str, None], Body()] = None, str_val: Annotated[
Union[str, None],
Body(),
BeforeValidator(lambda v: convert(v)),
] = None,
): ):
return {"val": str_val} return {"val": str_val}
@app.post("/nullable-non-required-int") @app.post("/nullable-non-required-int")
async def read_nullable_non_required_no_embed_int( async def read_nullable_non_required_no_embed_int(
int_val: Annotated[Union[int, None], Body()] = None, int_val: Annotated[
Union[int, None],
Body(),
BeforeValidator(lambda v: convert(v)),
] = None,
): ):
return {"val": int_val} return {"val": int_val}
@app.post("/nullable-non-required-list") @app.post("/nullable-non-required-list")
async def read_nullable_non_required_no_embed_list( async def read_nullable_non_required_no_embed_list(
list_val: Annotated[Union[list[int], None], Body()] = None, list_val: Annotated[
Union[list[int], None],
Body(),
BeforeValidator(lambda v: convert(v)),
] = None,
): ):
return {"val": list_val} return {"val": list_val}
@ -499,7 +554,13 @@ def test_nullable_non_required_no_embed_schema(path: str, schema: dict):
) )
def test_nullable_non_required_missing(path: str): def test_nullable_non_required_missing(path: str):
client = TestClient(app) client = TestClient(app)
response = client.post(path, json={})
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
response = client.post(path, json={})
assert mock_convert.call_count == 0, (
"Validator should not be called if the value is missing"
)
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == { assert response.json() == {
"int_val": None, "int_val": None,
@ -547,7 +608,13 @@ def test_nullable_non_required_no_body(path: str):
) )
def test_nullable_non_required_no_embed_missing(path: str): def test_nullable_non_required_no_embed_missing(path: str):
client = TestClient(app) client = TestClient(app)
response = client.post(path)
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
response = client.post(path)
assert mock_convert.call_count == 0, (
"Validator should not be called if the value is missing"
)
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"val": None} assert response.json() == {"val": None}
@ -555,20 +622,29 @@ def test_nullable_non_required_no_embed_missing(path: str):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"path", "path",
[ [
"/nullable-non-required", pytest.param(
"/nullable-non-required",
marks=pytest.mark.xfail(
reason="Null values are treated as missing for non-model Body parameters"
),
),
"/model-nullable-non-required", "/model-nullable-non-required",
], ],
) )
def test_nullable_non_required_pass_null(path: str): def test_nullable_non_required_pass_null(path: str):
client = TestClient(app) client = TestClient(app)
response = client.post(
path, with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
json={ response = client.post(
"int_val": None, path,
"str_val": None, json={
"list_val": None, "int_val": None,
}, "str_val": None,
) "list_val": None,
},
)
assert mock_convert.call_count == 3, "Validator should be called for each field"
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == { assert response.json() == {
"int_val": None, "int_val": None,
@ -588,12 +664,16 @@ def test_nullable_non_required_pass_null(path: str):
"/nullable-non-required-list", "/nullable-non-required-list",
], ],
) )
@pytest.mark.xfail(reason="Explicit null-body is treated as missing")
def test_nullable_non_required_no_embed_pass_null(path: str): def test_nullable_non_required_no_embed_pass_null(path: str):
client = TestClient(app) client = TestClient(app)
response = client.post(path, content="null")
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
response = client.post(path, content="null")
assert mock_convert.call_count == 1, "Validator should be called once for the field"
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == {"val": None} assert response.json() == {"val": None}
# TODO: add test with BeforeValidator to ensure that it recieves `None` value
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -605,9 +685,13 @@ def test_nullable_non_required_no_embed_pass_null(path: str):
) )
def test_nullable_non_required_pass_value(path: str): def test_nullable_non_required_pass_value(path: str):
client = TestClient(app) client = TestClient(app)
response = client.post(
path, json={"int_val": 1, "str_val": "test", "list_val": [1, 2]} with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
) response = client.post(
path, json={"int_val": 1, "str_val": "test", "list_val": [1, 2]}
)
assert mock_convert.call_count == 3, "Validator should be called for each field"
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == { assert response.json() == {
"int_val": 1, "int_val": 1,
@ -629,7 +713,11 @@ def test_nullable_non_required_pass_value(path: str):
) )
def test_nullable_non_required_no_embed_pass_value(path: str, value: Any): def test_nullable_non_required_no_embed_pass_value(path: str, value: Any):
client = TestClient(app) client = TestClient(app)
response = client.post(path, json=value)
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
response = client.post(path, json=value)
assert mock_convert.call_count == 1, "Validator should be called once for the field"
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == {"val": value} assert response.json() == {"val": value}
@ -641,9 +729,21 @@ def test_nullable_non_required_no_embed_pass_value(path: str, value: Any):
@app.post("/nullable-with-non-null-default") @app.post("/nullable-with-non-null-default")
async def read_nullable_with_non_null_default( async def read_nullable_with_non_null_default(
*, *,
int_val: Annotated[Union[int, None], Body()] = -1, int_val: Annotated[
str_val: Annotated[Union[str, None], Body()] = "default", Union[int, None],
list_val: Annotated[Union[list[int], None], Body(default_factory=lambda: [0])], Body(),
BeforeValidator(lambda v: convert(v)),
] = -1,
str_val: Annotated[
Union[str, None],
Body(),
BeforeValidator(lambda v: convert(v)),
] = "default",
list_val: Annotated[
Union[list[int], None],
Body(default_factory=lambda: [0]),
BeforeValidator(lambda v: convert(v)),
],
): ):
return { return {
"int_val": int_val, "int_val": int_val,
@ -658,6 +758,10 @@ class ModelNullableWithNonNullDefault(BaseModel):
str_val: Union[str, None] = "default" str_val: Union[str, None] = "default"
list_val: Union[list[int], None] = [0] list_val: Union[list[int], None] = [0]
@field_validator("*", mode="before")
def validate_all(cls, v):
return convert(v)
@app.post("/model-nullable-with-non-null-default") @app.post("/model-nullable-with-non-null-default")
async def read_model_nullable_with_non_null_default( async def read_model_nullable_with_non_null_default(
@ -673,21 +777,33 @@ async def read_model_nullable_with_non_null_default(
@app.post("/nullable-with-non-null-default-str") @app.post("/nullable-with-non-null-default-str")
async def read_nullable_with_non_null_default_no_embed_str( async def read_nullable_with_non_null_default_no_embed_str(
str_val: Annotated[Union[str, None], Body()] = "default", str_val: Annotated[
Union[str, None],
Body(),
BeforeValidator(lambda v: convert(v)),
] = "default",
): ):
return {"val": str_val} return {"val": str_val}
@app.post("/nullable-with-non-null-default-int") @app.post("/nullable-with-non-null-default-int")
async def read_nullable_with_non_null_default_no_embed_int( async def read_nullable_with_non_null_default_no_embed_int(
int_val: Annotated[Union[int, None], Body()] = -1, int_val: Annotated[
Union[int, None],
Body(),
BeforeValidator(lambda v: convert(v)),
] = -1,
): ):
return {"val": int_val} return {"val": int_val}
@app.post("/nullable-with-non-null-default-list") @app.post("/nullable-with-non-null-default-list")
async def read_nullable_with_non_null_default_no_embed_list( async def read_nullable_with_non_null_default_no_embed_list(
list_val: Annotated[Union[list[int], None], Body(default_factory=lambda: [0])], list_val: Annotated[
Union[list[int], None],
Body(default_factory=lambda: [0]),
BeforeValidator(lambda v: convert(v)),
],
): ):
return {"val": list_val} return {"val": list_val}
@ -787,7 +903,13 @@ def test_nullable_with_non_null_default_no_embed_schema(path: str, schema: dict)
) )
def test_nullable_with_non_null_default_missing(path: str): def test_nullable_with_non_null_default_missing(path: str):
client = TestClient(app) client = TestClient(app)
response = client.post(path, json={})
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
response = client.post(path, json={})
assert mock_convert.call_count == 0, (
"Validator should not be called if the value is missing"
)
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == { assert response.json() == {
"int_val": -1, "int_val": -1,
@ -835,7 +957,13 @@ def test_nullable_with_non_null_default_no_body(path: str):
) )
def test_nullable_with_non_null_default_no_embed_missing(path: str, expected: Any): def test_nullable_with_non_null_default_no_embed_missing(path: str, expected: Any):
client = TestClient(app) client = TestClient(app)
response = client.post(path)
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
response = client.post(path)
assert mock_convert.call_count == 0, (
"Validator should not be called if the value is missing"
)
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == {"val": expected} assert response.json() == {"val": expected}
@ -854,14 +982,18 @@ def test_nullable_with_non_null_default_no_embed_missing(path: str, expected: An
) )
def test_nullable_with_non_null_default_pass_null(path: str): def test_nullable_with_non_null_default_pass_null(path: str):
client = TestClient(app) client = TestClient(app)
response = client.post(
path, with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
json={ response = client.post(
"int_val": None, path,
"str_val": None, json={
"list_val": None, "int_val": None,
}, "str_val": None,
) "list_val": None,
},
)
assert mock_convert.call_count == 3, "Validator should be called for each field"
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == { assert response.json() == {
"int_val": None, "int_val": None,
@ -884,7 +1016,11 @@ def test_nullable_with_non_null_default_pass_null(path: str):
@pytest.mark.xfail(reason="Explicit null-body is treated as missing") @pytest.mark.xfail(reason="Explicit null-body is treated as missing")
def test_nullable_with_non_null_default_no_embed_pass_null(path: str): def test_nullable_with_non_null_default_no_embed_pass_null(path: str):
client = TestClient(app) client = TestClient(app)
response = client.post(path, content="null")
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
response = client.post(path, content="null")
assert mock_convert.call_count == 1, "Validator should be called once for the field"
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == {"val": None} assert response.json() == {"val": None}
@ -898,9 +1034,13 @@ def test_nullable_with_non_null_default_no_embed_pass_null(path: str):
) )
def test_nullable_with_non_null_default_pass_value(path: str): def test_nullable_with_non_null_default_pass_value(path: str):
client = TestClient(app) client = TestClient(app)
response = client.post(
path, json={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]} with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
) response = client.post(
path, json={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]}
)
assert mock_convert.call_count == 3, "Validator should be called for each field"
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == { assert response.json() == {
"int_val": 1, "int_val": 1,
@ -922,6 +1062,10 @@ def test_nullable_with_non_null_default_pass_value(path: str):
) )
def test_nullable_with_non_null_default_no_embed_pass_value(path: str, value: Any): def test_nullable_with_non_null_default_no_embed_pass_value(path: str, value: Any):
client = TestClient(app) client = TestClient(app)
response = client.post(path, json=value)
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
response = client.post(path, json=value)
assert mock_convert.call_count == 1, "Validator should be called once for the field"
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == {"val": value} assert response.json() == {"val": value}