diff --git a/tests/test_request_params/test_body/test_nullable_and_defaults.py b/tests/test_request_params/test_body/test_nullable_and_defaults.py new file mode 100644 index 0000000000..13d2df4796 --- /dev/null +++ b/tests/test_request_params/test_body/test_nullable_and_defaults.py @@ -0,0 +1,1072 @@ +from typing import Annotated, Any, Union +from unittest.mock import Mock, patch + +import pytest +from dirty_equals import IsList, IsOneOf, IsPartialDict +from fastapi import Body, FastAPI +from fastapi.testclient import TestClient +from pydantic import BaseModel, BeforeValidator, field_validator + +from .utils import get_body_model_name + +app = FastAPI() + + +def convert(v: Any) -> Any: + return v + + +# ===================================================================================== +# Nullable required + + +@app.post("/nullable-required") +async def read_nullable_required( + int_val: Annotated[Union[int, None], Body(), BeforeValidator(lambda v: convert(v))], + str_val: Annotated[Union[str, None], Body(), BeforeValidator(lambda v: convert(v))], + list_val: Annotated[ + Union[list[int], None], + Body(), + BeforeValidator(lambda v: convert(v)), + ], +): + return { + "int_val": int_val, + "str_val": str_val, + "list_val": list_val, + "fields_set": None, + } + + +class ModelNullableRequired(BaseModel): + int_val: Union[int, None] + str_val: Union[str, None] + list_val: Union[list[int], None] + + @field_validator("*", mode="before") + def validate_all(cls, v): + return convert(v) + + +@app.post("/model-nullable-required") +async def read_model_nullable_required(params: ModelNullableRequired): + return { + "int_val": params.int_val, + "str_val": params.str_val, + "list_val": params.list_val, + "fields_set": params.model_fields_set, + } + + +@app.post("/nullable-required-str") +async def read_nullable_required_no_embed_str( + str_val: Annotated[Union[str, None], Body(), BeforeValidator(lambda v: convert(v))], +): + return {"val": str_val} + + +@app.post("/nullable-required-int") +async def read_nullable_required_no_embed_int( + int_val: Annotated[Union[int, None], Body(), BeforeValidator(lambda v: convert(v))], +): + return {"val": int_val} + + +@app.post("/nullable-required-list") +async def read_nullable_required_no_embed_list( + list_val: Annotated[ + Union[list[int], None], Body(), BeforeValidator(lambda v: convert(v)) + ], +): + return {"val": list_val} + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required", + "/model-nullable-required", + ], +) +def test_nullable_required_schema(path: str): + openapi = app.openapi() + body_model_name = get_body_model_name(openapi, path) + + assert app.openapi()["components"]["schemas"][body_model_name] == { + "properties": { + "int_val": { + "title": "Int Val", + "anyOf": [{"type": "integer"}, {"type": "null"}], + }, + "str_val": { + "title": "Str Val", + "anyOf": [{"type": "string"}, {"type": "null"}], + }, + "list_val": { + "title": "List Val", + "anyOf": [ + {"type": "array", "items": {"type": "integer"}}, + {"type": "null"}, + ], + }, + }, + "required": ["int_val", "str_val", "list_val"], + "title": body_model_name, + "type": "object", + } + + +@pytest.mark.parametrize( + ("path", "schema"), + [ + ( + "/nullable-required-str", + { + "anyOf": [{"type": "string"}, {"type": "null"}], + "title": "Str Val", + }, + ), + ( + "/nullable-required-int", + { + "anyOf": [{"type": "integer"}, {"type": "null"}], + "title": "Int Val", + }, + ), + ( + "/nullable-required-list", + { + "anyOf": [ + {"type": "array", "items": {"type": "integer"}}, + {"type": "null"}, + ], + "title": "List Val", + }, + ), + ], +) +def test_nullable_required_no_embed_schema(path: str, schema: dict): + openapi = app.openapi() + path_operation = openapi["paths"][path]["post"] + assert ( + path_operation["requestBody"]["content"]["application/json"]["schema"] == schema + ) + assert path_operation["requestBody"]["required"] is True + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required", + "/model-nullable-required", + ], +) +def test_nullable_required_missing(path: str): + client = TestClient(app) + response = client.post(path, json={}) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "type": "missing", + "loc": ["body", "int_val"], + "msg": "Field required", + "input": IsOneOf(None, {}), + }, + { + "type": "missing", + "loc": ["body", "str_val"], + "msg": "Field required", + "input": IsOneOf(None, {}), + }, + { + "type": "missing", + "loc": ["body", "list_val"], + "msg": "Field required", + "input": IsOneOf(None, {}), + }, + ] + } + + +@pytest.mark.parametrize( + "path", + [ + pytest.param( + "/nullable-required", + marks=pytest.mark.xfail( + reason="For non-model Body parameters, gives errors for each parameter separately" + ), + ), + "/model-nullable-required", + ], +) +def test_nullable_required_no_body(path: str): + client = TestClient(app) + response = client.post(path) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "type": "missing", + "loc": ["body"], + "msg": "Field required", + "input": None, + }, + ] + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required-str", + "/nullable-required-int", + "/nullable-required-list", + ], +) +def test_nullable_required_no_embed_missing(path: str): + client = TestClient(app) + response = client.post(path) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "input": None, + "loc": ["body"], + "msg": "Field required", + "type": "missing", + } + ] + } + + +@pytest.mark.parametrize( + ("path", "msg", "error_type"), + [ + ( + "/nullable-required-str", + "Input should be a valid string", + "string_type", + ), + ( + "/nullable-required-int", + "Input should be a valid integer", + "int_type", + ), + ( + "/nullable-required-list", + "Input should be a valid list", + "list_type", + ), + ], +) +def test_nullable_required_pass_empty_dict(path: str, msg: str, error_type: str): + client = TestClient(app) + response = client.post(path, json={}) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "input": {}, + "loc": ["body"], + "msg": msg, + "type": error_type, + } + ] + } + + +@pytest.mark.parametrize( + "path", + [ + pytest.param( + "/nullable-required", + marks=pytest.mark.xfail( + reason="Null values are treated as missing for non-model Body parameters" + ), + ), + pytest.param( + "/model-nullable-required", + ), + ], +) +def test_nullable_required_pass_null(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, + json={ + "int_val": None, + "str_val": None, + "list_val": None, + }, + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": None, + "str_val": None, + "list_val": None, + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required-str", + "/nullable-required-int", + "/nullable-required-list", + ], +) +@pytest.mark.xfail(reason="Explicit null-body is treated as missing") +def test_nullable_required_no_embed_pass_null(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path, content="null") + + assert mock_convert.call_count == 1, "Validator should be called once for the field" + assert response.status_code == 200, response.text # pragma: no cover + assert response.json() == {"val": None} # pragma: no cover + # TODO: Remove 'no cover' when the issue is fixed + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required", + "/model-nullable-required", + ], +) +def test_nullable_required_pass_value(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, json={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]} + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": 1, + "str_val": "test", + "list_val": [1, 2], + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + } + + +@pytest.mark.parametrize( + ("path", "value"), + [ + ("/nullable-required-str", "test"), + ("/nullable-required-int", 1), + ("/nullable-required-list", [1, 2]), + ], +) +def test_nullable_required_no_embed_pass_value(path: str, value: Any): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path, json=value) + + assert mock_convert.call_count == 1, "Validator should be called once for the field" + assert response.status_code == 200, response.text + assert response.json() == {"val": value} + + +# ===================================================================================== +# Nullable with default=None + + +@app.post("/nullable-non-required") +async def read_nullable_non_required( + int_val: Annotated[ + Union[int, None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = None, + str_val: Annotated[ + Union[str, None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = None, + list_val: Annotated[ + Union[list[int], None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = None, +): + return { + "int_val": int_val, + "str_val": str_val, + "list_val": list_val, + "fields_set": None, + } + + +class ModelNullableNonRequired(BaseModel): + int_val: Union[int, None] = None + str_val: Union[str, None] = None + list_val: Union[list[int], None] = None + + @field_validator("*", mode="before") + def validate_all(cls, v): + return convert(v) + + +@app.post("/model-nullable-non-required") +async def read_model_nullable_non_required( + params: ModelNullableNonRequired, +): + return { + "int_val": params.int_val, + "str_val": params.str_val, + "list_val": params.list_val, + "fields_set": params.model_fields_set, + } + + +@app.post("/nullable-non-required-str") +async def read_nullable_non_required_no_embed_str( + str_val: Annotated[ + Union[str, None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = None, +): + return {"val": str_val} + + +@app.post("/nullable-non-required-int") +async def read_nullable_non_required_no_embed_int( + int_val: Annotated[ + Union[int, None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = None, +): + return {"val": int_val} + + +@app.post("/nullable-non-required-list") +async def read_nullable_non_required_no_embed_list( + list_val: Annotated[ + Union[list[int], None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = None, +): + return {"val": list_val} + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required", + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_schema(path: str): + openapi = app.openapi() + body_model_name = get_body_model_name(openapi, path) + + assert app.openapi()["components"]["schemas"][body_model_name] == { + "properties": { + "int_val": { + "title": "Int Val", + "anyOf": [{"type": "integer"}, {"type": "null"}], + # "default": None, # `None` values are omitted in OpenAPI schema + }, + "str_val": { + "title": "Str Val", + "anyOf": [{"type": "string"}, {"type": "null"}], + # "default": None, # `None` values are omitted in OpenAPI schema + }, + "list_val": { + "title": "List Val", + "anyOf": [ + {"type": "array", "items": {"type": "integer"}}, + {"type": "null"}, + ], + # "default": None, # `None` values are omitted in OpenAPI schema + }, + }, + "title": body_model_name, + "type": "object", + } + + +@pytest.mark.parametrize( + ("path", "schema"), + [ + ( + "/nullable-non-required-str", + { + "anyOf": [{"type": "string"}, {"type": "null"}], + "title": "Str Val", + # "default": None, # `None` values are omitted in OpenAPI schema + }, + ), + ( + "/nullable-non-required-int", + { + "anyOf": [{"type": "integer"}, {"type": "null"}], + "title": "Int Val", + # "default": None, # `None` values are omitted in OpenAPI schema + }, + ), + ( + "/nullable-non-required-list", + { + "anyOf": [ + {"type": "array", "items": {"type": "integer"}}, + {"type": "null"}, + ], + "title": "List Val", + # "default": None, # `None` values are omitted in OpenAPI schema + }, + ), + ], +) +def test_nullable_non_required_no_embed_schema(path: str, schema: dict): + openapi = app.openapi() + path_operation = openapi["paths"][path]["post"] + assert ( + path_operation["requestBody"]["content"]["application/json"]["schema"] == schema + ) + assert "required" not in path_operation["requestBody"] + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required", + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_missing(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path, json={}) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 200 + assert response.json() == { + "int_val": None, + "str_val": None, + "list_val": None, + "fields_set": IsOneOf(None, []), + } + + +@pytest.mark.parametrize( + "path", + [ + pytest.param( + "/nullable-non-required", + marks=pytest.mark.xfail( + reason="For non-model Body parameters, validates each parameter separately" + ), + ), + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_no_body(path: str): + client = TestClient(app) + response = client.post(path) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "type": "missing", + "loc": ["body"], + "msg": "Field required", + "input": None, + }, + ] + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required-str", + "/nullable-non-required-int", + "/nullable-non-required-list", + ], +) +def test_nullable_non_required_no_embed_missing(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 200 + assert response.json() == {"val": None} + + +@pytest.mark.parametrize( + "path", + [ + pytest.param( + "/nullable-non-required", + marks=pytest.mark.xfail( + reason="Null values are treated as missing for non-model Body parameters" + ), + ), + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_pass_null(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, + json={ + "int_val": None, + "str_val": None, + "list_val": None, + }, + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": None, + "str_val": None, + "list_val": None, + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required-str", + "/nullable-non-required-int", + "/nullable-non-required-list", + ], +) +@pytest.mark.xfail(reason="Explicit null-body is treated as missing") +def test_nullable_non_required_no_embed_pass_null(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path, content="null") + + assert mock_convert.call_count == 1, "Validator should be called once for the field" + assert response.status_code == 200, response.text # pragma: no cover + assert response.json() == {"val": None} # pragma: no cover + # TODO: Remove 'no cover' when the issue is fixed + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required", + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_pass_value(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, json={"int_val": 1, "str_val": "test", "list_val": [1, 2]} + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": 1, + "str_val": "test", + "list_val": [1, 2], + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + } + + +@pytest.mark.parametrize( + ("path", "value"), + [ + ("/nullable-non-required-str", "test"), + ("/nullable-non-required-int", 1), + ("/nullable-non-required-list", [1, 2]), + ], +) +def test_nullable_non_required_no_embed_pass_value(path: str, value: Any): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path, json=value) + + assert mock_convert.call_count == 1, "Validator should be called once for the field" + assert response.status_code == 200, response.text + assert response.json() == {"val": value} + + +# ===================================================================================== +# Nullable with not-None default + + +@app.post("/nullable-with-non-null-default") +async def read_nullable_with_non_null_default( + *, + int_val: Annotated[ + Union[int, None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = -1, + str_val: Annotated[ + Union[str, None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = "default", + list_val: Annotated[ + Union[list[int], None], + Body(default_factory=lambda: [0]), + BeforeValidator(lambda v: convert(v)), + ], +): + return { + "int_val": int_val, + "str_val": str_val, + "list_val": list_val, + "fields_set": None, + } + + +class ModelNullableWithNonNullDefault(BaseModel): + int_val: Union[int, None] = -1 + str_val: Union[str, None] = "default" + list_val: Union[list[int], None] = [0] + + @field_validator("*", mode="before") + def validate_all(cls, v): + return convert(v) + + +@app.post("/model-nullable-with-non-null-default") +async def read_model_nullable_with_non_null_default( + params: ModelNullableWithNonNullDefault, +): + return { + "int_val": params.int_val, + "str_val": params.str_val, + "list_val": params.list_val, + "fields_set": params.model_fields_set, + } + + +@app.post("/nullable-with-non-null-default-str") +async def read_nullable_with_non_null_default_no_embed_str( + str_val: Annotated[ + Union[str, None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = "default", +): + return {"val": str_val} + + +@app.post("/nullable-with-non-null-default-int") +async def read_nullable_with_non_null_default_no_embed_int( + int_val: Annotated[ + Union[int, None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = -1, +): + return {"val": int_val} + + +@app.post("/nullable-with-non-null-default-list") +async def read_nullable_with_non_null_default_no_embed_list( + list_val: Annotated[ + Union[list[int], None], + Body(default_factory=lambda: [0]), + BeforeValidator(lambda v: convert(v)), + ], +): + return {"val": list_val} + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default", + "/model-nullable-with-non-null-default", + ], +) +def test_nullable_with_non_null_default_schema(path: str): + openapi = app.openapi() + body_model_name = get_body_model_name(openapi, path) + body_model = app.openapi()["components"]["schemas"][body_model_name] + + assert body_model == { + "properties": { + "int_val": { + "title": "Int Val", + "anyOf": [{"type": "integer"}, {"type": "null"}], + "default": -1, + }, + "str_val": { + "title": "Str Val", + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": "default", + }, + "list_val": IsPartialDict( + { + "title": "List Val", + "anyOf": [ + {"type": "array", "items": {"type": "integer"}}, + {"type": "null"}, + ], + }, + ), + }, + "title": body_model_name, + "type": "object", + } + + if path == "/model-nullable-with-non-null-default": + # Check default value for list_val param for model-based Body parameters only. + # default_factory is not reflected in OpenAPI schema + assert body_model["properties"]["list_val"]["default"] == [0] + + +@pytest.mark.parametrize( + ("path", "schema"), + [ + ( + "/nullable-with-non-null-default-str", + { + "anyOf": [{"type": "string"}, {"type": "null"}], + "title": "Str Val", + "default": "default", + }, + ), + ( + "/nullable-with-non-null-default-int", + { + "anyOf": [{"type": "integer"}, {"type": "null"}], + "title": "Int Val", + "default": -1, + }, + ), + ( + "/nullable-with-non-null-default-list", + { + "anyOf": [ + {"type": "array", "items": {"type": "integer"}}, + {"type": "null"}, + ], + "title": "List Val", + }, + ), + ], +) +def test_nullable_with_non_null_default_no_embed_schema(path: str, schema: dict): + openapi = app.openapi() + path_operation = openapi["paths"][path]["post"] + assert ( + path_operation["requestBody"]["content"]["application/json"]["schema"] == schema + ) + assert "required" not in path_operation["requestBody"] + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default", + "/model-nullable-with-non-null-default", + ], +) +def test_nullable_with_non_null_default_missing(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path, json={}) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": -1, + "str_val": "default", + "list_val": [0], + "fields_set": IsOneOf(None, []), + } + + +@pytest.mark.parametrize( + "path", + [ + pytest.param( + "/nullable-with-non-null-default", + marks=pytest.mark.xfail( + reason="For non-model Body parameters, validates each parameter separately" + ), + ), + "/model-nullable-with-non-null-default", + ], +) +def test_nullable_with_non_null_default_no_body(path: str): + client = TestClient(app) + response = client.post(path) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "type": "missing", + "loc": ["body"], + "msg": "Field required", + "input": None, + }, + ] + } + + +@pytest.mark.parametrize( + ("path", "expected"), + [ + ("/nullable-with-non-null-default-str", "default"), + ("/nullable-with-non-null-default-int", -1), + ("/nullable-with-non-null-default-list", [0]), + ], +) +def test_nullable_with_non_null_default_no_embed_missing(path: str, expected: Any): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 200, response.text + assert response.json() == {"val": expected} + + +@pytest.mark.parametrize( + "path", + [ + pytest.param( + "/nullable-with-non-null-default", + marks=pytest.mark.xfail( + reason="Null values are treated as missing for non-model Body parameters" + ), + ), + "/model-nullable-with-non-null-default", + ], +) +def test_nullable_with_non_null_default_pass_null(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, + json={ + "int_val": None, + "str_val": None, + "list_val": None, + }, + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": None, + "str_val": None, + "list_val": None, + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default-str", + "/nullable-with-non-null-default-int", + "/nullable-with-non-null-default-list", + ], +) +@pytest.mark.xfail(reason="Explicit null-body is treated as missing") +def test_nullable_with_non_null_default_no_embed_pass_null(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path, content="null") + + assert mock_convert.call_count == 1, "Validator should be called once for the field" + assert response.status_code == 200, response.text # pragma: no cover + assert response.json() == {"val": None} # pragma: no cover + # TODO: Remove 'no cover' when the issue is fixed + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default", + "/model-nullable-with-non-null-default", + ], +) +def test_nullable_with_non_null_default_pass_value(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, json={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]} + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": 1, + "str_val": "test", + "list_val": [1, 2], + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + } + + +@pytest.mark.parametrize( + ("path", "value"), + [ + ("/nullable-with-non-null-default-str", "test"), + ("/nullable-with-non-null-default-int", 1), + ("/nullable-with-non-null-default-list", [1, 2]), + ], +) +def test_nullable_with_non_null_default_no_embed_pass_value(path: str, value: Any): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path, json=value) + + assert mock_convert.call_count == 1, "Validator should be called once for the field" + assert response.status_code == 200, response.text + assert response.json() == {"val": value} diff --git a/tests/test_request_params/test_cookie/test_nullable_and_defaults.py b/tests/test_request_params/test_cookie/test_nullable_and_defaults.py new file mode 100644 index 0000000000..88fa4f78d3 --- /dev/null +++ b/tests/test_request_params/test_cookie/test_nullable_and_defaults.py @@ -0,0 +1,402 @@ +from typing import Annotated, Any, Union +from unittest.mock import Mock, patch + +import pytest +from dirty_equals import IsList, IsOneOf +from fastapi import Cookie, FastAPI +from fastapi.testclient import TestClient +from pydantic import BaseModel, BeforeValidator, field_validator + +app = FastAPI() + + +def convert(v: Any) -> Any: + return v + + +# ===================================================================================== +# Nullable required + + +@app.get("/nullable-required") +async def read_nullable_required( + int_val: Annotated[ + Union[int, None], + Cookie(), + BeforeValidator(lambda v: convert(v)), + ], + str_val: Annotated[ + Union[str, None], + Cookie(), + BeforeValidator(lambda v: convert(v)), + ], +): + return { + "int_val": int_val, + "str_val": str_val, + "fields_set": None, + } + + +class ModelNullableRequired(BaseModel): + int_val: Union[int, None] + str_val: Union[str, None] + + @field_validator("*", mode="before") + @classmethod + def convert_fields(cls, v): + return convert(v) + + +@app.get("/model-nullable-required") +async def read_model_nullable_required( + params: Annotated[ModelNullableRequired, Cookie()], +): + return { + "int_val": params.int_val, + "str_val": params.str_val, + "fields_set": params.model_fields_set, + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required", + "/model-nullable-required", + ], +) +def test_nullable_required_schema(path: str): + assert app.openapi()["paths"][path]["get"]["parameters"] == [ + { + "required": True, + "schema": { + "title": "Int Val", + "anyOf": [{"type": "integer"}, {"type": "null"}], + }, + "name": "int_val", + "in": "cookie", + }, + { + "required": True, + "schema": { + "title": "Str Val", + "anyOf": [{"type": "string"}, {"type": "null"}], + }, + "name": "str_val", + "in": "cookie", + }, + ] + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required", + "/model-nullable-required", + ], +) +def test_nullable_required_missing(path: str): + client = TestClient(app) + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "type": "missing", + "loc": ["cookie", "int_val"], + "msg": "Field required", + "input": IsOneOf(None, {}), + }, + { + "type": "missing", + "loc": ["cookie", "str_val"], + "msg": "Field required", + "input": IsOneOf(None, {}), + }, + ] + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required", + "/model-nullable-required", + ], +) +def test_nullable_required_pass_value(path: str): + client = TestClient(app) + client.cookies.set("int_val", "1") + client.cookies.set("str_val", "test") + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get(path) + + assert mock_convert.call_count == 2, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": 1, + "str_val": "test", + "fields_set": IsOneOf(None, IsList("int_val", "str_val", check_order=False)), + } + + +# ===================================================================================== +# Nullable with default=None + + +@app.get("/nullable-non-required") +async def read_nullable_non_required( + int_val: Annotated[ + Union[int, None], + Cookie(), + BeforeValidator(lambda v: convert(v)), + ] = None, + str_val: Annotated[ + Union[str, None], + Cookie(), + BeforeValidator(lambda v: convert(v)), + ] = None, +): + return { + "int_val": int_val, + "str_val": str_val, + "fields_set": None, + } + + +class ModelNullableNonRequired(BaseModel): + int_val: Union[int, None] = None + str_val: Union[str, None] = None + + @field_validator("*", mode="before") + @classmethod + def convert_fields(cls, v): + return convert(v) + + +@app.get("/model-nullable-non-required") +async def read_model_nullable_non_required( + params: Annotated[ModelNullableNonRequired, Cookie()], +): + return { + "int_val": params.int_val, + "str_val": params.str_val, + "fields_set": params.model_fields_set, + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required", + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_schema(path: str): + assert app.openapi()["paths"][path]["get"]["parameters"] == [ + { + "required": False, + "schema": { + "title": "Int Val", + "anyOf": [{"type": "integer"}, {"type": "null"}], + # "default": None, # `None` values are omitted in OpenAPI schema + }, + "name": "int_val", + "in": "cookie", + }, + { + "required": False, + "schema": { + "title": "Str Val", + "anyOf": [{"type": "string"}, {"type": "null"}], + # "default": None, # `None` values are omitted in OpenAPI schema + }, + "name": "str_val", + "in": "cookie", + }, + ] + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required", + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_missing(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 200 + assert response.json() == { + "int_val": None, + "str_val": None, + "fields_set": IsOneOf(None, []), + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required", + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_pass_value(path: str): + client = TestClient(app) + client.cookies.set("int_val", "1") + client.cookies.set("str_val", "test") + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get(path) + + assert mock_convert.call_count == 2, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": 1, + "str_val": "test", + "fields_set": IsOneOf(None, IsList("int_val", "str_val", check_order=False)), + } + + +# ===================================================================================== +# Nullable with not-None default + + +@app.get("/nullable-with-non-null-default") +async def read_nullable_with_non_null_default( + *, + int_val: Annotated[ + Union[int, None], + Cookie(), + BeforeValidator(lambda v: convert(v)), + ] = -1, + str_val: Annotated[ + Union[str, None], + Cookie(), + BeforeValidator(lambda v: convert(v)), + ] = "default", +): + return { + "int_val": int_val, + "str_val": str_val, + "fields_set": None, + } + + +class ModelNullableWithNonNullDefault(BaseModel): + int_val: Union[int, None] = -1 + str_val: Union[str, None] = "default" + + @field_validator("*", mode="before") + @classmethod + def convert_fields(cls, v): + return convert(v) + + +@app.get("/model-nullable-with-non-null-default") +async def read_model_nullable_with_non_null_default( + params: Annotated[ModelNullableWithNonNullDefault, Cookie()], +): + return { + "int_val": params.int_val, + "str_val": params.str_val, + "fields_set": params.model_fields_set, + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default", + "/model-nullable-with-non-null-default", + ], +) +def test_nullable_with_non_null_default_schema(path: str): + assert app.openapi()["paths"][path]["get"]["parameters"] == [ + { + "required": False, + "schema": { + "title": "Int Val", + "anyOf": [{"type": "integer"}, {"type": "null"}], + "default": -1, + }, + "name": "int_val", + "in": "cookie", + }, + { + "required": False, + "schema": { + "title": "Str Val", + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": "default", + }, + "name": "str_val", + "in": "cookie", + }, + ] + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default", + "/model-nullable-with-non-null-default", + ], +) +@pytest.mark.xfail( + reason="Missing parameters are pre-populated with default values before validation" +) +def test_nullable_with_non_null_default_missing(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 200 # pragma: no cover + assert response.json() == { # pragma: no cover + "int_val": -1, + "str_val": "default", + "fields_set": IsOneOf(None, []), + } + # TODO: Remove 'no cover' when the issue is fixed + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default", + "/model-nullable-with-non-null-default", + ], +) +def test_nullable_with_non_null_default_pass_value(path: str): + client = TestClient(app) + client.cookies.set("int_val", "1") + client.cookies.set("str_val", "test") + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get(path) + + assert mock_convert.call_count == 2, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": 1, + "str_val": "test", + "fields_set": IsOneOf(None, IsList("int_val", "str_val", check_order=False)), + } diff --git a/tests/test_request_params/test_file/test_nullable_and_defaults.py b/tests/test_request_params/test_file/test_nullable_and_defaults.py new file mode 100644 index 0000000000..4c34889e52 --- /dev/null +++ b/tests/test_request_params/test_file/test_nullable_and_defaults.py @@ -0,0 +1,480 @@ +from typing import Annotated, Any, Union +from unittest.mock import Mock, patch + +import pytest +from dirty_equals import IsOneOf +from fastapi import FastAPI, File, UploadFile +from fastapi.testclient import TestClient +from pydantic import BeforeValidator +from starlette.datastructures import UploadFile as StarletteUploadFile + +from .utils import get_body_model_name + +app = FastAPI() + + +def convert(v: Any) -> Any: + return v + + +# ===================================================================================== +# Nullable required + + +@app.post("/nullable-required-bytes") +async def read_nullable_required_bytes( + file: Annotated[ + Union[bytes, None], + File(), + BeforeValidator(lambda v: convert(v)), + ], + files: Annotated[ + Union[list[bytes], None], + File(), + BeforeValidator(lambda v: convert(v)), + ], +): + return { + "file": len(file) if file is not None else None, + "files": [len(f) for f in files] if files is not None else None, + } + + +@app.post("/nullable-required-uploadfile") +async def read_nullable_required_uploadfile( + file: Annotated[ + Union[UploadFile, None], + File(), + BeforeValidator(lambda v: convert(v)), + ], + files: Annotated[ + Union[list[UploadFile], None], + File(), + BeforeValidator(lambda v: convert(v)), + ], +): + return { + "file": file.size if file is not None else None, + "files": [f.size for f in files] if files is not None else None, + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required-bytes", + "/nullable-required-uploadfile", + ], +) +def test_nullable_required_schema(path: str): + openapi = app.openapi() + body_model_name = get_body_model_name(openapi, path) + + assert app.openapi()["components"]["schemas"][body_model_name] == { + "properties": { + "file": { + "title": "File", + "anyOf": [{"type": "string", "format": "binary"}, {"type": "null"}], + }, + "files": { + "title": "Files", + "anyOf": [ + {"type": "array", "items": {"type": "string", "format": "binary"}}, + {"type": "null"}, + ], + }, + }, + "required": ["file", "files"], + "title": body_model_name, + "type": "object", + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required-bytes", + "/nullable-required-uploadfile", + ], +) +def test_nullable_required_missing(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "type": "missing", + "loc": ["body", "file"], + "msg": "Field required", + "input": IsOneOf(None, {}), + }, + { + "type": "missing", + "loc": ["body", "files"], + "msg": "Field required", + "input": IsOneOf(None, {}), + }, + ] + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required-bytes", + "/nullable-required-uploadfile", + ], +) +def test_nullable_required_pass_empty_file(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, + files=[("file", b""), ("files", b""), ("files", b"")], + ) + + assert mock_convert.call_count == 2, "Validator should be called for each field" + call_args = [call_args_item.args for call_args_item in mock_convert.call_args_list] + file_call_arg_1 = call_args[0][0] + files_call_arg_1 = call_args[1][0] + + assert ( + (file_call_arg_1 == b"") # file as bytes + or isinstance(file_call_arg_1, StarletteUploadFile) # file as UploadFile + ) + assert ( + (files_call_arg_1 == [b"", b""]) # files as bytes + or all( # files as UploadFile + isinstance(f, StarletteUploadFile) for f in files_call_arg_1 + ) + ) + + assert response.status_code == 200, response.text + assert response.json() == { + "file": 0, + "files": [0, 0], + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required-bytes", + "/nullable-required-uploadfile", + ], +) +def test_nullable_required_pass_file(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, + files=[ + ("file", b"test 1"), + ("files", b"test 2"), + ("files", b"test 3"), + ], + ) + + assert mock_convert.call_count == 2, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == {"file": 6, "files": [6, 6]} + + +# ===================================================================================== +# Nullable with default=None + + +@app.post("/nullable-non-required-bytes") +async def read_nullable_non_required_bytes( + file: Annotated[ + Union[bytes, None], + File(), + BeforeValidator(lambda v: convert(v)), + ] = None, + files: Annotated[ + Union[list[bytes], None], + File(), + BeforeValidator(lambda v: convert(v)), + ] = None, +): + return { + "file": len(file) if file is not None else None, + "files": [len(f) for f in files] if files is not None else None, + } + + +@app.post("/nullable-non-required-uploadfile") +async def read_nullable_non_required_uploadfile( + file: Annotated[ + Union[UploadFile, None], + File(), + BeforeValidator(lambda v: convert(v)), + ] = None, + files: Annotated[ + Union[list[UploadFile], None], + File(), + BeforeValidator(lambda v: convert(v)), + ] = None, +): + return { + "file": file.size if file is not None else None, + "files": [f.size for f in files] if files is not None else None, + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required-bytes", + "/nullable-non-required-uploadfile", + ], +) +def test_nullable_non_required_schema(path: str): + openapi = app.openapi() + body_model_name = get_body_model_name(openapi, path) + + assert app.openapi()["components"]["schemas"][body_model_name] == { + "properties": { + "file": { + "title": "File", + "anyOf": [{"type": "string", "format": "binary"}, {"type": "null"}], + # "default": None, # `None` values are omitted in OpenAPI schema + }, + "files": { + "title": "Files", + "anyOf": [ + {"type": "array", "items": {"type": "string", "format": "binary"}}, + {"type": "null"}, + ], + # "default": None, # `None` values are omitted in OpenAPI schema + }, + }, + "title": body_model_name, + "type": "object", + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required-bytes", + "/nullable-non-required-uploadfile", + ], +) +def test_nullable_non_required_missing(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 200 + assert response.json() == { + "file": None, + "files": None, + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required-bytes", + "/nullable-non-required-uploadfile", + ], +) +def test_nullable_non_required_pass_empty_file(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, + files=[("file", b""), ("files", b""), ("files", b"")], + ) + + assert mock_convert.call_count == 2, "Validator should be called for each field" + call_args = [call_args_item.args for call_args_item in mock_convert.call_args_list] + file_call_arg_1 = call_args[0][0] + files_call_arg_1 = call_args[1][0] + + assert ( + (file_call_arg_1 == b"") # file as bytes + or isinstance(file_call_arg_1, StarletteUploadFile) # file as UploadFile + ) + assert ( + (files_call_arg_1 == [b"", b""]) # files as bytes + or all( # files as UploadFile + isinstance(f, StarletteUploadFile) for f in files_call_arg_1 + ) + ) + + assert response.status_code == 200, response.text + assert response.json() == {"file": 0, "files": [0, 0]} + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required-bytes", + "/nullable-non-required-uploadfile", + ], +) +def test_nullable_non_required_pass_file(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, + files=[("file", b"test 1"), ("files", b"test 2"), ("files", b"test 3")], + ) + + assert mock_convert.call_count == 2, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == {"file": 6, "files": [6, 6]} + + +# ===================================================================================== +# Nullable with not-None default + + +@app.post("/nullable-with-non-null-default-bytes") +async def read_nullable_with_non_null_default_bytes( + *, + file: Annotated[ + Union[bytes, None], + File(), + BeforeValidator(lambda v: convert(v)), + ] = b"default", + files: Annotated[ + Union[list[bytes], None], + File(default_factory=lambda: [b"default"]), + BeforeValidator(lambda v: convert(v)), + ], +): + return { + "file": len(file) if file is not None else None, + "files": [len(f) for f in files] if files is not None else None, + } + + +# Note: It seems to be not possible to create endpoint with UploadFile and non-None default + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default-bytes", + ], +) +def test_nullable_with_non_null_default_schema(path: str): + openapi = app.openapi() + body_model_name = get_body_model_name(openapi, path) + + assert app.openapi()["components"]["schemas"][body_model_name] == { + "properties": { + "file": { + "title": "File", + "anyOf": [{"type": "string", "format": "binary"}, {"type": "null"}], + "default": "default", # <= Default value for file looks strange to me + }, + "files": { + "title": "Files", + "anyOf": [ + {"type": "array", "items": {"type": "string", "format": "binary"}}, + {"type": "null"}, + ], + }, + }, + "title": body_model_name, + "type": "object", + } + + +@pytest.mark.parametrize( + "path", + [ + pytest.param( + "/nullable-with-non-null-default-bytes", + marks=pytest.mark.xfail( + reason="AttributeError: 'bytes' object has no attribute 'read'", + ), + ), + ], +) +def test_nullable_with_non_null_default_missing(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path) + + assert mock_convert.call_count == 0, ( # pragma: no cover + "Validator should not be called if the value is missing" + ) + assert response.status_code == 200 # pragma: no cover + assert response.json() == {"file": None, "files": None} # pragma: no cover + # TODO: Remove 'no cover' when the issue is fixed + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default-bytes", + ], +) +def test_nullable_with_non_null_default_pass_empty_file(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, + files=[("file", b""), ("files", b""), ("files", b"")], + ) + + assert mock_convert.call_count == 2, "Validator should be called for each field" + call_args = [call_args_item.args for call_args_item in mock_convert.call_args_list] + file_call_arg_1 = call_args[0][0] + files_call_arg_1 = call_args[1][0] + + assert ( + (file_call_arg_1 == b"") # file as bytes + or isinstance(file_call_arg_1, StarletteUploadFile) # file as UploadFile + ) + assert ( + (files_call_arg_1 == [b"", b""]) # files as bytes + or all( # files as UploadFile + isinstance(f, StarletteUploadFile) for f in files_call_arg_1 + ) + ) + + assert response.status_code == 200, response.text + assert response.json() == {"file": 0, "files": [0, 0]} + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default-bytes", + ], +) +def test_nullable_with_non_null_default_pass_file(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, + files=[("file", b"test 1"), ("files", b"test 2"), ("files", b"test 3")], + ) + + assert mock_convert.call_count == 2, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == {"file": 6, "files": [6, 6]} diff --git a/tests/test_request_params/test_form/test_nullable_and_defaults.py b/tests/test_request_params/test_form/test_nullable_and_defaults.py new file mode 100644 index 0000000000..d8147ca790 --- /dev/null +++ b/tests/test_request_params/test_form/test_nullable_and_defaults.py @@ -0,0 +1,609 @@ +from typing import Annotated, Any, Union +from unittest.mock import Mock, patch + +import pytest +from dirty_equals import IsList, IsOneOf, IsPartialDict +from fastapi import FastAPI, Form +from fastapi.testclient import TestClient +from pydantic import BaseModel, BeforeValidator, field_validator + +from .utils import get_body_model_name + +app = FastAPI() + + +def convert(v: Any) -> Any: + return v + + +# ===================================================================================== +# Nullable required + + +@app.post("/nullable-required") +async def read_nullable_required( + int_val: Annotated[ + Union[int, None], + Form(), + BeforeValidator(lambda v: convert(v)), + ], + str_val: Annotated[ + Union[str, None], + Form(), + BeforeValidator(lambda v: convert(v)), + ], + list_val: Annotated[ + Union[list[int], None], + Form(), + BeforeValidator(lambda v: convert(v)), + ], +): + return { + "int_val": int_val, + "str_val": str_val, + "list_val": list_val, + "fields_set": None, + } + + +class ModelNullableRequired(BaseModel): + int_val: Union[int, None] + str_val: Union[str, None] + list_val: Union[list[int], None] + + @field_validator("*", mode="before") + def convert_fields(cls, v): + return convert(v) + + +@app.post("/model-nullable-required") +async def read_model_nullable_required( + params: Annotated[ModelNullableRequired, Form()], +): + return { + "int_val": params.int_val, + "str_val": params.str_val, + "list_val": params.list_val, + "fields_set": params.model_fields_set, + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required", + "/model-nullable-required", + ], +) +def test_nullable_required_schema(path: str): + openapi = app.openapi() + body_model_name = get_body_model_name(openapi, path) + + assert app.openapi()["components"]["schemas"][body_model_name] == { + "properties": { + "int_val": { + "title": "Int Val", + "anyOf": [{"type": "integer"}, {"type": "null"}], + }, + "str_val": { + "title": "Str Val", + "anyOf": [{"type": "string"}, {"type": "null"}], + }, + "list_val": { + "title": "List Val", + "anyOf": [ + {"type": "array", "items": {"type": "integer"}}, + {"type": "null"}, + ], + }, + }, + "required": ["int_val", "str_val", "list_val"], + "title": body_model_name, + "type": "object", + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required", + "/model-nullable-required", + ], +) +def test_nullable_required_missing(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "type": "missing", + "loc": ["body", "int_val"], + "msg": "Field required", + "input": IsOneOf(None, {}), + }, + { + "type": "missing", + "loc": ["body", "str_val"], + "msg": "Field required", + "input": IsOneOf(None, {}), + }, + { + "type": "missing", + "loc": ["body", "list_val"], + "msg": "Field required", + "input": IsOneOf(None, {}), + }, + ] + } + + +@pytest.mark.parametrize( + "path", + [ + pytest.param( + "/nullable-required", + marks=pytest.mark.xfail( + reason="Empty str is replaced with None, but then None gets dropped" + ), + ), + pytest.param( + "/model-nullable-required", + marks=pytest.mark.xfail( + reason="Empty strings are not replaced with None for models" + ), + ), + ], +) +def test_nullable_required_pass_empty_str(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, + data={ + "int_val": "", + "str_val": "", + "list_val": "0", # Empty strings are not treated as null for lists. It's Ok + }, + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert mock_convert.call_args_list == [ + (""), # int_val + (""), # str_val + (["0"]), # list_val + ] + assert response.status_code == 200, response.text # pragma: no cover + assert response.json() == { # pragma: no cover + "int_val": None, + "str_val": None, + "list_val": [0], + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + } + # TODO: Remove 'no cover' when the issue is fixed + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required", + "/model-nullable-required", + ], +) +def test_nullable_required_pass_value(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, data={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]} + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": 1, + "str_val": "test", + "list_val": [1, 2], + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + } + + +# ===================================================================================== +# Nullable with default=None + + +@app.post("/nullable-non-required") +async def read_nullable_non_required( + int_val: Annotated[ + Union[int, None], + Form(), + BeforeValidator(lambda v: convert(v)), + ] = None, + str_val: Annotated[ + Union[str, None], + Form(), + BeforeValidator(lambda v: convert(v)), + ] = None, + list_val: Annotated[ + Union[list[int], None], + Form(), + BeforeValidator(lambda v: convert(v)), + ] = None, +): + return { + "int_val": int_val, + "str_val": str_val, + "list_val": list_val, + "fields_set": None, + } + + +class ModelNullableNonRequired(BaseModel): + int_val: Union[int, None] = None + str_val: Union[str, None] = None + list_val: Union[list[int], None] = None + + @field_validator("*", mode="before") + def convert_fields(cls, v): + return convert(v) + + +@app.post("/model-nullable-non-required") +async def read_model_nullable_non_required( + params: Annotated[ModelNullableNonRequired, Form()], +): + return { + "int_val": params.int_val, + "str_val": params.str_val, + "list_val": params.list_val, + "fields_set": params.model_fields_set, + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required", + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_schema(path: str): + openapi = app.openapi() + body_model_name = get_body_model_name(openapi, path) + + assert app.openapi()["components"]["schemas"][body_model_name] == { + "properties": { + "int_val": { + "title": "Int Val", + "anyOf": [{"type": "integer"}, {"type": "null"}], + # "default": None, # `None` values are omitted in OpenAPI schema + }, + "str_val": { + "title": "Str Val", + "anyOf": [{"type": "string"}, {"type": "null"}], + # "default": None, # `None` values are omitted in OpenAPI schema + }, + "list_val": { + "title": "List Val", + "anyOf": [ + {"type": "array", "items": {"type": "integer"}}, + {"type": "null"}, + ], + # "default": None, # `None` values are omitted in OpenAPI schema + }, + }, + "title": body_model_name, + "type": "object", + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required", + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_missing(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 200 + assert response.json() == { + "int_val": None, + "str_val": None, + "list_val": None, + "fields_set": IsOneOf(None, []), + } + + +@pytest.mark.parametrize( + "path", + [ + pytest.param( + "/nullable-non-required", + marks=pytest.mark.xfail( + reason="Empty str is replaced with None, but then None gets dropped" + ), + ), + pytest.param( + "/model-nullable-non-required", + marks=pytest.mark.xfail( + reason="Empty strings are not replaced with None for models" + ), + ), + ], +) +def test_nullable_non_required_pass_empty_str(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, + data={ + "int_val": "", + "str_val": "", + "list_val": "0", # Empty strings are not treated as null for lists. It's Ok + }, + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert mock_convert.call_args_list == [ + (""), # int_val + (""), # str_val + (["0"]), # list_val + ] + assert response.status_code == 200, response.text # pragma: no cover + assert response.json() == { # pragma: no cover + "int_val": None, + "str_val": None, + "list_val": [0], + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + } + # TODO: Remove 'no cover' when the issue is fixed + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required", + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_pass_value(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, data={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]} + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": 1, + "str_val": "test", + "list_val": [1, 2], + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + } + + +# ===================================================================================== +# Nullable with not-None default + + +@app.post("/nullable-with-non-null-default") +async def read_nullable_with_non_null_default( + *, + int_val: Annotated[ + Union[int, None], + Form(), + BeforeValidator(lambda v: convert(v)), + ] = -1, + str_val: Annotated[ + Union[str, None], + Form(), + BeforeValidator(lambda v: convert(v)), + ] = "default", + list_val: Annotated[ + Union[list[int], None], + Form(default_factory=lambda: [0]), + BeforeValidator(lambda v: convert(v)), + ], +): + return { + "int_val": int_val, + "str_val": str_val, + "list_val": list_val, + "fields_set": None, + } + + +class ModelNullableWithNonNullDefault(BaseModel): + int_val: Union[int, None] = -1 + str_val: Union[str, None] = "default" + list_val: Union[list[int], None] = [0] + + @field_validator("*", mode="before") + def convert_fields(cls, v): + return convert(v) + + +@app.post("/model-nullable-with-non-null-default") +async def read_model_nullable_with_non_null_default( + params: Annotated[ModelNullableWithNonNullDefault, Form()], +): + return { + "int_val": params.int_val, + "str_val": params.str_val, + "list_val": params.list_val, + "fields_set": params.model_fields_set, + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default", + "/model-nullable-with-non-null-default", + ], +) +def test_nullable_with_non_null_default_schema(path: str): + openapi = app.openapi() + body_model_name = get_body_model_name(openapi, path) + body_model = app.openapi()["components"]["schemas"][body_model_name] + + assert body_model == { + "properties": { + "int_val": { + "title": "Int Val", + "anyOf": [{"type": "integer"}, {"type": "null"}], + "default": -1, + }, + "str_val": { + "title": "Str Val", + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": "default", + }, + "list_val": IsPartialDict( + { + "title": "List Val", + "anyOf": [ + {"type": "array", "items": {"type": "integer"}}, + {"type": "null"}, + ], + } + ), + }, + "title": body_model_name, + "type": "object", + } + + if path == "/model-nullable-with-non-null-default": + # Check default value for list_val param for model-based Body parameters only. + # default_factory is not reflected in OpenAPI schema + assert body_model["properties"]["list_val"]["default"] == [0] + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default", + "/model-nullable-with-non-null-default", + ], +) +@pytest.mark.xfail( + reason="Missing parameters are pre-populated with default values before validation" +) +def test_nullable_with_non_null_default_missing(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 200 # pragma: no cover + assert response.json() == { # pragma: no cover + "int_val": -1, + "str_val": "default", + "list_val": [0], + "fields_set": IsOneOf(None, []), + } + # TODO: Remove 'no cover' when the issue is fixed + + +@pytest.mark.parametrize( + "path", + [ + pytest.param( + "/nullable-with-non-null-default", + marks=pytest.mark.xfail( + reason="Empty str is replaced with default value, not with None" # Is this correct ??? + ), + ), + pytest.param( + "/model-nullable-with-non-null-default", + marks=pytest.mark.xfail( + reason="Empty strings are not replaced with None for models" + ), + ), + ], +) +def test_nullable_with_non_null_default_pass_empty_str(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, + data={ + "int_val": "", + "str_val": "", + "list_val": "0", # Empty strings are not treated as null for lists. It's Ok + }, + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert mock_convert.call_args_list == [ + (""), # int_val + (""), # str_val + (["0"]), # list_val + ] + assert response.status_code == 200, response.text # pragma: no cover + assert response.json() == { # pragma: no cover + "int_val": None, + "str_val": None, + "list_val": [0], + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + } + # TODO: Remove 'no cover' when the issue is fixed + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default", + "/model-nullable-with-non-null-default", + ], +) +def test_nullable_with_non_null_default_pass_value(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, data={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]} + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": 1, + "str_val": "test", + "list_val": [1, 2], + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + } diff --git a/tests/test_request_params/test_header/test_nullable_and_defaults.py b/tests/test_request_params/test_header/test_nullable_and_defaults.py new file mode 100644 index 0000000000..5aaae8592a --- /dev/null +++ b/tests/test_request_params/test_header/test_nullable_and_defaults.py @@ -0,0 +1,526 @@ +from typing import Annotated, Any, Union +from unittest.mock import Mock, patch + +import pytest +from dirty_equals import AnyThing, IsList, IsOneOf, IsPartialDict +from fastapi import FastAPI, Header +from fastapi.testclient import TestClient +from pydantic import BaseModel, BeforeValidator, field_validator + +app = FastAPI() + + +def convert(v: Any) -> Any: + return v + + +# ===================================================================================== +# Nullable required + + +@app.get("/nullable-required") +async def read_nullable_required( + int_val: Annotated[ + Union[int, None], + Header(), + BeforeValidator(lambda v: convert(v)), + ], + str_val: Annotated[ + Union[str, None], + Header(), + BeforeValidator(lambda v: convert(v)), + ], + list_val: Annotated[ + Union[list[int], None], + Header(), + BeforeValidator(lambda v: convert(v)), + ], +): + return { + "int_val": int_val, + "str_val": str_val, + "list_val": list_val, + "fields_set": None, + } + + +class ModelNullableRequired(BaseModel): + int_val: Union[int, None] + str_val: Union[str, None] + list_val: Union[list[int], None] + + @field_validator("*", mode="before") + @classmethod + def convert_fields(cls, v): + return convert(v) + + +@app.get("/model-nullable-required") +async def read_model_nullable_required( + params: Annotated[ModelNullableRequired, Header()], +): + return { + "int_val": params.int_val, + "str_val": params.str_val, + "list_val": params.list_val, + "fields_set": params.model_fields_set, + } + + +@pytest.mark.parametrize( + "path", + [ + pytest.param( + "/nullable-required", + marks=pytest.mark.xfail( + reason="Title contains hyphens for single Header parameters" + ), + ), + "/model-nullable-required", + ], +) +def test_nullable_required_schema(path: str): + assert app.openapi()["paths"][path]["get"]["parameters"] == [ + { + "required": True, + "schema": { + "title": "Int Val", + "anyOf": [{"type": "integer"}, {"type": "null"}], + }, + "name": "int-val", + "in": "header", + }, + { + "required": True, + "schema": { + "title": "Str Val", + "anyOf": [{"type": "string"}, {"type": "null"}], + }, + "name": "str-val", + "in": "header", + }, + { + "required": True, + "schema": { + "title": "List Val", + "anyOf": [ + {"type": "array", "items": {"type": "integer"}}, + {"type": "null"}, + ], + }, + "name": "list-val", + "in": "header", + }, + ] + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required", + pytest.param( + "/model-nullable-required", + marks=pytest.mark.xfail( + reason="With Header model fields use underscores in error locs for headers" + ), + ), + ], +) +def test_nullable_required_missing(path: str): + client = TestClient(app) + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "type": "missing", + "loc": ["header", "int-val"], + "msg": "Field required", + "input": AnyThing(), + }, + { + "type": "missing", + "loc": ["header", "str-val"], + "msg": "Field required", + "input": AnyThing(), + }, + { + "type": "missing", + "loc": ["header", "list-val"], + "msg": "Field required", + "input": AnyThing(), + }, + ] + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required", + "/model-nullable-required", + ], +) +def test_nullable_required_pass_value(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get( + path, + headers=[ + ("int-val", "1"), + ("str-val", "test"), + ("list-val", "1"), + ("list-val", "2"), + ], + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": 1, + "str_val": "test", + "list_val": [1, 2], + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + } + + +# ===================================================================================== +# Nullable with default=None + + +@app.get("/nullable-non-required") +async def read_nullable_non_required( + int_val: Annotated[ + Union[int, None], + Header(), + BeforeValidator(lambda v: convert(v)), + ] = None, + str_val: Annotated[ + Union[str, None], + Header(), + BeforeValidator(lambda v: convert(v)), + ] = None, + list_val: Annotated[ + Union[list[int], None], + Header(), + BeforeValidator(lambda v: convert(v)), + ] = None, +): + return { + "int_val": int_val, + "str_val": str_val, + "list_val": list_val, + "fields_set": None, + } + + +class ModelNullableNonRequired(BaseModel): + int_val: Union[int, None] = None + str_val: Union[str, None] = None + list_val: Union[list[int], None] = None + + @field_validator("*", mode="before") + @classmethod + def convert_fields(cls, v): + return convert(v) + + +@app.get("/model-nullable-non-required") +async def read_model_nullable_non_required( + params: Annotated[ModelNullableNonRequired, Header()], +): + return { + "int_val": params.int_val, + "str_val": params.str_val, + "list_val": params.list_val, + "fields_set": params.model_fields_set, + } + + +@pytest.mark.parametrize( + "path", + [ + pytest.param( + "/nullable-non-required", + marks=pytest.mark.xfail( + reason="Title contains hyphens for single Header parameters" + ), + ), + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_schema(path: str): + assert app.openapi()["paths"][path]["get"]["parameters"] == [ + { + "required": False, + "schema": { + "title": "Int Val", + "anyOf": [{"type": "integer"}, {"type": "null"}], + # "default": None, # `None` values are omitted in OpenAPI schema + }, + "name": "int-val", + "in": "header", + }, + { + "required": False, + "schema": { + "title": "Str Val", + "anyOf": [{"type": "string"}, {"type": "null"}], + # "default": None, # `None` values are omitted in OpenAPI schema + }, + "name": "str-val", + "in": "header", + }, + { + "required": False, + "schema": { + "title": "List Val", + "anyOf": [ + {"type": "array", "items": {"type": "integer"}}, + {"type": "null"}, + ], + # "default": None, # `None` values are omitted in OpenAPI schema + }, + "name": "list-val", + "in": "header", + }, + ] + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required", + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_missing(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 200 + assert response.json() == { + "int_val": None, + "str_val": None, + "list_val": None, + "fields_set": IsOneOf(None, []), + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required", + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_pass_value(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get( + path, + headers=[ + ("int-val", "1"), + ("str-val", "test"), + ("list-val", "1"), + ("list-val", "2"), + ], + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": 1, + "str_val": "test", + "list_val": [1, 2], + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + } + + +# ===================================================================================== +# Nullable with not-None default + + +@app.get("/nullable-with-non-null-default") +async def read_nullable_with_non_null_default( + *, + int_val: Annotated[ + Union[int, None], + Header(), + BeforeValidator(lambda v: convert(v)), + ] = -1, + str_val: Annotated[ + Union[str, None], + Header(), + BeforeValidator(lambda v: convert(v)), + ] = "default", + list_val: Annotated[ + Union[list[int], None], + Header(default_factory=lambda: [0]), + BeforeValidator(lambda v: convert(v)), + ], +): + return { + "int_val": int_val, + "str_val": str_val, + "list_val": list_val, + "fields_set": None, + } + + +class ModelNullableWithNonNullDefault(BaseModel): + int_val: Union[int, None] = -1 + str_val: Union[str, None] = "default" + list_val: Union[list[int], None] = [0] + + @field_validator("*", mode="before") + @classmethod + def convert_fields(cls, v): + return convert(v) + + +@app.get("/model-nullable-with-non-null-default") +async def read_model_nullable_with_non_null_default( + params: Annotated[ModelNullableWithNonNullDefault, Header()], +): + return { + "int_val": params.int_val, + "str_val": params.str_val, + "list_val": params.list_val, + "fields_set": params.model_fields_set, + } + + +@pytest.mark.parametrize( + "path", + [ + pytest.param( + "/nullable-with-non-null-default", + marks=pytest.mark.xfail( + reason="Title contains hyphens for single Header parameters" + ), + ), + "/model-nullable-with-non-null-default", + ], +) +def test_nullable_with_non_null_default_schema(path: str): + parameters = app.openapi()["paths"][path]["get"]["parameters"] + assert parameters == [ + { + "required": False, + "schema": { + "title": "Int Val", + "anyOf": [{"type": "integer"}, {"type": "null"}], + "default": -1, + }, + "name": "int-val", + "in": "header", + }, + { + "required": False, + "schema": { + "title": "Str Val", + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": "default", + }, + "name": "str-val", + "in": "header", + }, + { + "required": False, + "schema": IsPartialDict( + { + "title": "List Val", + "anyOf": [ + {"type": "array", "items": {"type": "integer"}}, + {"type": "null"}, + ], + } + ), + "name": "list-val", + "in": "header", + }, + ] + + if path == "/model-nullable-with-non-null-default": + # Check default value for list_val param for model-based Body parameters only. + # default_factory is not reflected in OpenAPI schema + assert parameters[2]["schema"]["default"] == [0] + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default", + "/model-nullable-with-non-null-default", + ], +) +@pytest.mark.xfail( + reason="Missing parameters are pre-populated with default values before validation" +) +def test_nullable_with_non_null_default_missing(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 200 # pragma: no cover + assert response.json() == { # pragma: no cover + "int_val": -1, + "str_val": "default", + "list_val": [0], + "fields_set": IsOneOf(None, []), + } + # TODO: Remove 'no cover' when the issue is fixed + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default", + "/model-nullable-with-non-null-default", + ], +) +def test_nullable_with_non_null_default_pass_value(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get( + path, + headers=[ + ("int-val", "1"), + ("str-val", "test"), + ("list-val", "1"), + ("list-val", "2"), + ], + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": 1, + "str_val": "test", + "list_val": [1, 2], + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + } diff --git a/tests/test_request_params/test_path/test_nullable_and_defaults.py b/tests/test_request_params/test_path/test_nullable_and_defaults.py new file mode 100644 index 0000000000..e065538222 --- /dev/null +++ b/tests/test_request_params/test_path/test_nullable_and_defaults.py @@ -0,0 +1,2 @@ +# Not appllicable for Path parameters +# Path parameters cannot have default values or be nullable diff --git a/tests/test_request_params/test_query/test_nullable_and_defaults.py b/tests/test_request_params/test_query/test_nullable_and_defaults.py new file mode 100644 index 0000000000..e72bc2d629 --- /dev/null +++ b/tests/test_request_params/test_query/test_nullable_and_defaults.py @@ -0,0 +1,483 @@ +from typing import Annotated, Any, Union +from unittest.mock import Mock, patch + +import pytest +from dirty_equals import IsList, IsOneOf, IsPartialDict +from fastapi import FastAPI, Query +from fastapi.testclient import TestClient +from pydantic import BaseModel, BeforeValidator, field_validator + +app = FastAPI() + + +def convert(v: Any) -> Any: + return v + + +# ===================================================================================== +# Nullable required + + +@app.get("/nullable-required") +async def read_nullable_required( + int_val: Annotated[ + Union[int, None], + BeforeValidator(lambda v: convert(v)), + ], + str_val: Annotated[ + Union[str, None], + BeforeValidator(lambda v: convert(v)), + ], + list_val: Annotated[ + Union[list[int], None], + Query(), + BeforeValidator(lambda v: convert(v)), + ], +): + return { + "int_val": int_val, + "str_val": str_val, + "list_val": list_val, + "fields_set": None, + } + + +class ModelNullableRequired(BaseModel): + int_val: Union[int, None] + str_val: Union[str, None] + list_val: Union[list[int], None] + + @field_validator("*", mode="before") + @classmethod + def convert_all(cls, v: Any) -> Any: + return convert(v) + + +@app.get("/model-nullable-required") +async def read_model_nullable_required( + params: Annotated[ModelNullableRequired, Query()], +): + return { + "int_val": params.int_val, + "str_val": params.str_val, + "list_val": params.list_val, + "fields_set": params.model_fields_set, + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required", + "/model-nullable-required", + ], +) +def test_nullable_required_schema(path: str): + assert app.openapi()["paths"][path]["get"]["parameters"] == [ + { + "required": True, + "schema": { + "title": "Int Val", + "anyOf": [{"type": "integer"}, {"type": "null"}], + }, + "name": "int_val", + "in": "query", + }, + { + "required": True, + "schema": { + "title": "Str Val", + "anyOf": [{"type": "string"}, {"type": "null"}], + }, + "name": "str_val", + "in": "query", + }, + { + "in": "query", + "name": "list_val", + "required": True, + "schema": { + "anyOf": [ + {"items": {"type": "integer"}, "type": "array"}, + {"type": "null"}, + ], + "title": "List Val", + }, + }, + ] + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required", + "/model-nullable-required", + ], +) +def test_nullable_required_missing(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "type": "missing", + "loc": ["query", "int_val"], + "msg": "Field required", + "input": IsOneOf(None, {}), + }, + { + "type": "missing", + "loc": ["query", "str_val"], + "msg": "Field required", + "input": IsOneOf(None, {}), + }, + { + "type": "missing", + "loc": ["query", "list_val"], + "msg": "Field required", + "input": IsOneOf(None, {}), + }, + ] + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required", + "/model-nullable-required", + ], +) +def test_nullable_required_pass_value(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get( + path, params={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]} + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": 1, + "str_val": "test", + "list_val": [1, 2], + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + } + + +# ===================================================================================== +# Nullable with default=None + + +@app.get("/nullable-non-required") +async def read_nullable_non_required( + int_val: Annotated[ + Union[int, None], + BeforeValidator(lambda v: convert(v)), + ] = None, + str_val: Annotated[ + Union[str, None], + BeforeValidator(lambda v: convert(v)), + ] = None, + list_val: Annotated[ + Union[list[int], None], + Query(), + BeforeValidator(lambda v: convert(v)), + ] = None, +): + return { + "int_val": int_val, + "str_val": str_val, + "list_val": list_val, + "fields_set": None, + } + + +class ModelNullableNonRequired(BaseModel): + int_val: Union[int, None] = None + str_val: Union[str, None] = None + list_val: Union[list[int], None] = None + + @field_validator("*", mode="before") + @classmethod + def convert_all(cls, v: Any) -> Any: + return convert(v) + + +@app.get("/model-nullable-non-required") +async def read_model_nullable_non_required( + params: Annotated[ModelNullableNonRequired, Query()], +): + return { + "int_val": params.int_val, + "str_val": params.str_val, + "list_val": params.list_val, + "fields_set": params.model_fields_set, + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required", + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_schema(path: str): + assert app.openapi()["paths"][path]["get"]["parameters"] == [ + { + "required": False, + "schema": { + "title": "Int Val", + "anyOf": [{"type": "integer"}, {"type": "null"}], + # "default": None, # `None` values are omitted in OpenAPI schema + }, + "name": "int_val", + "in": "query", + }, + { + "required": False, + "schema": { + "title": "Str Val", + "anyOf": [{"type": "string"}, {"type": "null"}], + # "default": None, # `None` values are omitted in OpenAPI schema + }, + "name": "str_val", + "in": "query", + }, + { + "in": "query", + "name": "list_val", + "required": False, + "schema": { + "anyOf": [ + {"items": {"type": "integer"}, "type": "array"}, + {"type": "null"}, + ], + "title": "List Val", + # "default": None, # `None` values are omitted in OpenAPI schema + }, + }, + ] + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required", + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_missing(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 200 + assert response.json() == { + "int_val": None, + "str_val": None, + "list_val": None, + "fields_set": IsOneOf(None, []), + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required", + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_pass_value(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get( + path, params={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]} + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": 1, + "str_val": "test", + "list_val": [1, 2], + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + } + + +# ===================================================================================== +# Nullable with not-None default + + +@app.get("/nullable-with-non-null-default") +async def read_nullable_with_non_null_default( + *, + int_val: Annotated[ + Union[int, None], + BeforeValidator(lambda v: convert(v)), + ] = -1, + str_val: Annotated[ + Union[str, None], + BeforeValidator(lambda v: convert(v)), + ] = "default", + list_val: Annotated[ + Union[list[int], None], + Query(default_factory=lambda: [0]), + BeforeValidator(lambda v: convert(v)), + ], +): + return { + "int_val": int_val, + "str_val": str_val, + "list_val": list_val, + "fields_set": None, + } + + +class ModelNullableWithNonNullDefault(BaseModel): + int_val: Union[int, None] = -1 + str_val: Union[str, None] = "default" + list_val: Union[list[int], None] = [0] + + @field_validator("*", mode="before") + @classmethod + def convert_all(cls, v: Any) -> Any: + return convert(v) + + +@app.get("/model-nullable-with-non-null-default") +async def read_model_nullable_with_non_null_default( + params: Annotated[ModelNullableWithNonNullDefault, Query()], +): + return { + "int_val": params.int_val, + "str_val": params.str_val, + "list_val": params.list_val, + "fields_set": params.model_fields_set, + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default", + "/model-nullable-with-non-null-default", + ], +) +def test_nullable_with_non_null_default_schema(path: str): + parameters = app.openapi()["paths"][path]["get"]["parameters"] + assert parameters == [ + { + "required": False, + "schema": { + "title": "Int Val", + "anyOf": [{"type": "integer"}, {"type": "null"}], + "default": -1, + }, + "name": "int_val", + "in": "query", + }, + { + "required": False, + "schema": { + "title": "Str Val", + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": "default", + }, + "name": "str_val", + "in": "query", + }, + { + "in": "query", + "name": "list_val", + "required": False, + "schema": IsPartialDict( + { + "anyOf": [ + {"items": {"type": "integer"}, "type": "array"}, + {"type": "null"}, + ], + "title": "List Val", + } + ), + }, + ] + + if path == "/model-nullable-with-non-null-default": + # Check default value for list_val param for model-based Body parameters only. + # default_factory is not reflected in OpenAPI schema + assert parameters[2]["schema"]["default"] == [0] + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default", + "/model-nullable-with-non-null-default", + ], +) +@pytest.mark.xfail( + reason="Missing parameters are pre-populated with default values before validation" +) +def test_nullable_with_non_null_default_missing(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 200 # pragma: no cover + assert response.json() == { # pragma: no cover + "int_val": -1, + "str_val": "default", + "list_val": [0], + "fields_set": IsOneOf(None, []), + } + # TODO: Remove 'no cover' when the issue is fixed + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default", + "/model-nullable-with-non-null-default", + ], +) +def test_nullable_with_non_null_default_pass_value(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get( + path, params={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]} + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": 1, + "str_val": "test", + "list_val": [1, 2], + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + }