From 38610a8fd471af60e0a6b98b08fdb1a16942c8ca Mon Sep 17 00:00:00 2001 From: merlinz01 <158784988+merlinz01@users.noreply.github.com> Date: Thu, 5 Sep 2024 14:32:14 -0400 Subject: [PATCH 1/8] Pass None instead of the default value to parameters that accept it when null is given Signed-off-by: merlinz01 <158784988+merlinz01@users.noreply.github.com> --- fastapi/dependencies/utils.py | 37 ++++++++++++---- tests/test_none_passed_when_null_received.py | 44 ++++++++++++++++++++ 2 files changed, 72 insertions(+), 9 deletions(-) create mode 100644 tests/test_none_passed_when_null_received.py diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index cc7e55b4b..b4f1938b4 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -1,6 +1,7 @@ import dataclasses import inspect import sys +import types from contextlib import AsyncExitStack, contextmanager from copy import copy, deepcopy from dataclasses import dataclass @@ -27,6 +28,7 @@ from fastapi._compat import ( ModelField, RequiredParam, Undefined, + UndefinedType, _is_error_wrapper, _is_model_class, copy_field_info, @@ -595,7 +597,7 @@ async def solve_dependencies( *, request: Union[Request, WebSocket], dependant: Dependant, - body: Optional[Union[Dict[str, Any], FormData]] = None, + body: Optional[Union[Dict[str, Any], FormData, UndefinedType]] = Undefined, background_tasks: Optional[StarletteBackgroundTasks] = None, response: Optional[Response] = None, dependency_overrides_provider: Optional[Any] = None, @@ -731,10 +733,24 @@ async def solve_dependencies( ) +def _allows_none(field: ModelField) -> bool: + origin = get_origin(field.type_) + return (origin is Union or origin is types.UnionType) and type(None) in get_args( + field.type_ + ) + + def _validate_value_with_model_field( *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...] ) -> Tuple[Any, List[Any]]: + if value is Undefined: + if field.required: + return None, [get_missing_field_error(loc=loc)] + else: + return deepcopy(field.default), [] if value is None: + if _allows_none(field): + return value, [] if field.required: return None, [get_missing_field_error(loc=loc)] else: @@ -753,12 +769,13 @@ def _get_multidict_value( field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None ) -> Any: alias = alias or get_validation_alias(field) + value: Any if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)): value = values.getlist(alias) else: - value = values.get(alias, None) + value = values.get(alias, Undefined) if ( - value is None + value is Undefined or ( isinstance(field.field_info, (params.Form, temp_pydantic_v1_params.Form)) and isinstance(value, str) # For type checks @@ -767,7 +784,7 @@ def _get_multidict_value( or (is_sequence_field(field) and len(value) == 0) ): if field.required: - return + return Undefined else: return deepcopy(field.default) return value @@ -933,7 +950,7 @@ async def _extract_form_body( for sub_value in value: tg.start_soon(process_fn, sub_value.read) value = serialize_sequence_value(field=field, value=results) - if value is not None: + if value is not Undefined and value is not None: values[get_validation_alias(field)] = value field_aliases = {get_validation_alias(field) for field in body_fields} for key in received_body.keys(): @@ -948,7 +965,7 @@ async def _extract_form_body( async def request_body_to_args( body_fields: List[ModelField], - received_body: Optional[Union[Dict[str, Any], FormData]], + received_body: Optional[Union[Dict[str, Any], FormData, UndefinedType]], embed_body_fields: bool, ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: values: Dict[str, Any] = {} @@ -978,10 +995,12 @@ async def request_body_to_args( return {first_field.name: v_}, errors_ for field in body_fields: loc = ("body", get_validation_alias(field)) - value: Optional[Any] = None - if body_to_process is not None: + value: Optional[Any] = Undefined + if body_to_process is not None and not isinstance( + body_to_process, UndefinedType + ): try: - value = body_to_process.get(get_validation_alias(field)) + value = body_to_process.get(get_validation_alias(field), Undefined) # If the received body is a list, not a dict except AttributeError: errors.append(get_missing_field_error(loc)) diff --git a/tests/test_none_passed_when_null_received.py b/tests/test_none_passed_when_null_received.py new file mode 100644 index 000000000..f4e171202 --- /dev/null +++ b/tests/test_none_passed_when_null_received.py @@ -0,0 +1,44 @@ +from typing import Annotated, Optional, Union + +import pytest +from fastapi import Body, FastAPI +from fastapi.testclient import TestClient + +app = FastAPI() +DEFAULT = 1234567890 + + +@app.post("/api1") +def api1(integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT) -> dict: + return {"received": integer_or_null} + + +@app.post("/api2") +def api2(integer_or_null: Annotated[Optional[int], Body(embed=True)] = DEFAULT) -> dict: + return {"received": integer_or_null} + + +@app.post("/api3") +def api3( + integer_or_null: Annotated[Union[int, None], Body(embed=True)] = DEFAULT, +) -> dict: + return {"received": integer_or_null} + + +@app.post("/api4") +def api4(integer_or_null: Optional[int] = Body(embed=True, default=DEFAULT)) -> dict: + return {"received": integer_or_null} + + +client = TestClient(app) + + +@pytest.mark.parametrize("api", ["/api1", "/api2", "/api3", "/api4"]) +def test_api1_integer(api): + response = client.post(api, json={"integer_or_null": 100}) + assert response.status_code == 200, response.text + assert response.json() == {"received": 100} + + response = client.post(api, json={"integer_or_null": None}) + assert response.status_code == 200, response.text + assert response.json() == {"received": None} From 826f11a611a9df58bb9aec6798cf11ecb7fdda2c Mon Sep 17 00:00:00 2001 From: merlinz01 <158784988+merlinz01@users.noreply.github.com> Date: Thu, 5 Sep 2024 21:36:10 -0400 Subject: [PATCH 2/8] make tests compatible with Python 3.8 and 3.9 Signed-off-by: merlinz01 <158784988+merlinz01@users.noreply.github.com> --- tests/test_none_passed_when_null_received.py | 46 ++++++++++++++------ 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/tests/test_none_passed_when_null_received.py b/tests/test_none_passed_when_null_received.py index f4e171202..4d1b3b0a8 100644 --- a/tests/test_none_passed_when_null_received.py +++ b/tests/test_none_passed_when_null_received.py @@ -1,4 +1,5 @@ -from typing import Annotated, Optional, Union +import sys +from typing import Optional, Union import pytest from fastapi import Body, FastAPI @@ -7,22 +8,38 @@ from fastapi.testclient import TestClient app = FastAPI() DEFAULT = 1234567890 +endpoints = [] -@app.post("/api1") -def api1(integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT) -> dict: - return {"received": integer_or_null} +if sys.hexversion >= 0x31000000: + from typing import Annotated + + @app.post("/api1") + def api1( + integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT, + ) -> dict: + return {"received": integer_or_null} + + endpoints.append("/api1") -@app.post("/api2") -def api2(integer_or_null: Annotated[Optional[int], Body(embed=True)] = DEFAULT) -> dict: - return {"received": integer_or_null} +if sys.hexversion >= 0x30900000: + from typing import Annotated + @app.post("/api2") + def api2( + integer_or_null: Annotated[Optional[int], Body(embed=True)] = DEFAULT, + ) -> dict: + return {"received": integer_or_null} -@app.post("/api3") -def api3( - integer_or_null: Annotated[Union[int, None], Body(embed=True)] = DEFAULT, -) -> dict: - return {"received": integer_or_null} + endpoints.append("/api2") + + @app.post("/api3") + def api3( + integer_or_null: Annotated[Union[int, None], Body(embed=True)] = DEFAULT, + ) -> dict: + return {"received": integer_or_null} + + endpoints.append("/api3") @app.post("/api4") @@ -30,10 +47,13 @@ def api4(integer_or_null: Optional[int] = Body(embed=True, default=DEFAULT)) -> return {"received": integer_or_null} +endpoints.append("/api4") + + client = TestClient(app) -@pytest.mark.parametrize("api", ["/api1", "/api2", "/api3", "/api4"]) +@pytest.mark.parametrize("api", endpoints) def test_api1_integer(api): response = client.post(api, json={"integer_or_null": 100}) assert response.status_code == 200, response.text From 6d1617df7f8e837e2b52c219b8c017c9d4ca2a81 Mon Sep 17 00:00:00 2001 From: merlinz01 <158784988+merlinz01@users.noreply.github.com> Date: Thu, 5 Sep 2024 21:58:36 -0400 Subject: [PATCH 3/8] make compatible with Python <3.10 and Pydantic v1 Signed-off-by: merlinz01 <158784988+merlinz01@users.noreply.github.com> --- fastapi/dependencies/utils.py | 22 +++++++++++++++----- tests/test_none_passed_when_null_received.py | 4 ++-- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index b4f1938b4..2f0b806ec 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -733,11 +733,23 @@ async def solve_dependencies( ) -def _allows_none(field: ModelField) -> bool: - origin = get_origin(field.type_) - return (origin is Union or origin is types.UnionType) and type(None) in get_args( - field.type_ - ) +if PYDANTIC_V2: + if sys.hexversion >= 0x30A00000: + + def _allows_none(field: ModelField) -> bool: + origin = get_origin(field.type_) + return (origin is Union or origin is types.UnionType) and type( + None + ) in get_args(field.type_) + else: + + def _allows_none(field: ModelField) -> bool: + origin = get_origin(field.type_) + return origin is Union and type(None) in get_args(field.type_) +else: + + def _allows_none(field: ModelField) -> bool: + return field.allow_none def _validate_value_with_model_field( diff --git a/tests/test_none_passed_when_null_received.py b/tests/test_none_passed_when_null_received.py index 4d1b3b0a8..51d3991f0 100644 --- a/tests/test_none_passed_when_null_received.py +++ b/tests/test_none_passed_when_null_received.py @@ -10,7 +10,7 @@ DEFAULT = 1234567890 endpoints = [] -if sys.hexversion >= 0x31000000: +if sys.hexversion >= 0x30A0000: from typing import Annotated @app.post("/api1") @@ -22,7 +22,7 @@ if sys.hexversion >= 0x31000000: endpoints.append("/api1") -if sys.hexversion >= 0x30900000: +if sys.hexversion >= 0x3090000: from typing import Annotated @app.post("/api2") From 1252ed2df506c63570594c7261157161992b7172 Mon Sep 17 00:00:00 2001 From: merlinz01 <158784988+merlinz01@users.noreply.github.com> Date: Thu, 5 Sep 2024 22:10:42 -0400 Subject: [PATCH 4/8] fix Python version check Signed-off-by: merlinz01 <158784988+merlinz01@users.noreply.github.com> --- fastapi/dependencies/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 2f0b806ec..184f10613 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -734,7 +734,7 @@ async def solve_dependencies( if PYDANTIC_V2: - if sys.hexversion >= 0x30A00000: + if sys.hexversion >= 0x30A0000: def _allows_none(field: ModelField) -> bool: origin = get_origin(field.type_) @@ -749,7 +749,7 @@ if PYDANTIC_V2: else: def _allows_none(field: ModelField) -> bool: - return field.allow_none + return field.allow_none # type: ignore def _validate_value_with_model_field( From f433590560b4d83316ec833b57829b0a00baca2d Mon Sep 17 00:00:00 2001 From: merlinz01 <158784988+merlinz01@users.noreply.github.com> Date: Thu, 5 Sep 2024 23:04:10 -0400 Subject: [PATCH 5/8] add test for required field and passing null Signed-off-by: merlinz01 <158784988+merlinz01@users.noreply.github.com> --- tests/test_none_passed_when_null_received.py | 22 +++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/test_none_passed_when_null_received.py b/tests/test_none_passed_when_null_received.py index 51d3991f0..1a0be42fd 100644 --- a/tests/test_none_passed_when_null_received.py +++ b/tests/test_none_passed_when_null_received.py @@ -50,11 +50,16 @@ def api4(integer_or_null: Optional[int] = Body(embed=True, default=DEFAULT)) -> endpoints.append("/api4") +@app.post("/api5") +def api5(integer: int = Body(embed=True)) -> dict: + return {"received": integer} + + client = TestClient(app) @pytest.mark.parametrize("api", endpoints) -def test_api1_integer(api): +def test_apis(api): response = client.post(api, json={"integer_or_null": 100}) assert response.status_code == 200, response.text assert response.json() == {"received": 100} @@ -62,3 +67,18 @@ def test_api1_integer(api): response = client.post(api, json={"integer_or_null": None}) assert response.status_code == 200, response.text assert response.json() == {"received": None} + + +def test_required_field(): + response = client.post("/api5", json={"integer": None}) + assert response.status_code == 422, response.text + assert response.json() == { + "detail": [ + { + "loc": ["body", "integer"], + "msg": "Field required", + "type": "missing", + "input": None, + } + ] + } From 78b49ac9ae13bf4dd825f25ccc3381a924328cc5 Mon Sep 17 00:00:00 2001 From: merlinz01 <158784988+merlinz01@users.noreply.github.com> Date: Fri, 6 Sep 2024 06:57:14 -0400 Subject: [PATCH 6/8] update required field test for Pydantic v1 Signed-off-by: merlinz01 <158784988+merlinz01@users.noreply.github.com> --- tests/test_none_passed_when_null_received.py | 33 ++++++++++++++------ 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/tests/test_none_passed_when_null_received.py b/tests/test_none_passed_when_null_received.py index 1a0be42fd..868dd3350 100644 --- a/tests/test_none_passed_when_null_received.py +++ b/tests/test_none_passed_when_null_received.py @@ -2,6 +2,7 @@ import sys from typing import Optional, Union import pytest +from dirty_equals import IsDict from fastapi import Body, FastAPI from fastapi.testclient import TestClient @@ -72,13 +73,25 @@ def test_apis(api): def test_required_field(): response = client.post("/api5", json={"integer": None}) assert response.status_code == 422, response.text - assert response.json() == { - "detail": [ - { - "loc": ["body", "integer"], - "msg": "Field required", - "type": "missing", - "input": None, - } - ] - } + assert response.json() == IsDict( + { + "detail": [ + { + "loc": ["body", "integer"], + "msg": "Field required", + "type": "missing", + "input": None, + } + ] + } + ) | IsDict( + { + "detail": [ + { + "loc": ["body", "integer"], + "msg": "field required", + "type": "value_error.missing", + } + ] + } + ) From 96d2ca4426900701df22bcc68a77cb4944c134db Mon Sep 17 00:00:00 2001 From: merlinz01 <158784988+merlinz01@users.noreply.github.com> Date: Fri, 6 Sep 2024 07:04:14 -0400 Subject: [PATCH 7/8] add test for full coverage Signed-off-by: merlinz01 <158784988+merlinz01@users.noreply.github.com> --- tests/test_none_passed_when_null_received.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_none_passed_when_null_received.py b/tests/test_none_passed_when_null_received.py index 868dd3350..3da094c6a 100644 --- a/tests/test_none_passed_when_null_received.py +++ b/tests/test_none_passed_when_null_received.py @@ -71,6 +71,10 @@ def test_apis(api): def test_required_field(): + response = client.post("/api5", json={"integer": 100}) + assert response.status_code == 200, response.text + assert response.json() == {"received": 100} + response = client.post("/api5", json={"integer": None}) assert response.status_code == 422, response.text assert response.json() == IsDict( From af05854e916ca237474537124b5113b42d9cb5f9 Mon Sep 17 00:00:00 2001 From: merlinz01 <158784988+merlinz01@users.noreply.github.com> Date: Sat, 12 Oct 2024 21:59:00 -0400 Subject: [PATCH 8/8] update to make work with latest master changes Signed-off-by: merlinz01 <158784988+merlinz01@users.noreply.github.com> --- fastapi/dependencies/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 184f10613..2bebef62a 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -842,7 +842,7 @@ def request_params_to_args( if alias == field.name: alias = alias.replace("_", "-") value = _get_multidict_value(field, received_params, alias=alias) - if value is not None: + if value is not Undefined and value is not None: params_to_process[get_validation_alias(field)] = value processed_keys.add(alias or get_validation_alias(field))