diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 80f9c76e9..e0a2a1b33 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -1,6 +1,7 @@ import dataclasses import inspect import sys +import types from collections.abc import Coroutine, Mapping, Sequence from contextlib import AsyncExitStack, contextmanager from copy import copy, deepcopy @@ -21,6 +22,7 @@ from fastapi._compat import ( ModelField, RequiredParam, Undefined, + UndefinedType, copy_field_info, create_body_model, evaluate_forwardref, @@ -570,7 +572,7 @@ async def solve_dependencies( *, request: Union[Request, WebSocket], dependant: Dependant, - body: Optional[Union[dict[str, Any], FormData]] = None, + body: Optional[Union[dict[str, Any], FormData, UndefinedType]] = Undefined, background_tasks: Optional[StarletteBackgroundTasks] = None, response: Optional[Response] = None, dependency_overrides_provider: Optional[Any] = None, @@ -706,10 +708,24 @@ async def solve_dependencies( ) +def _allows_none(field: ModelField) -> bool: + origin = get_origin(field.type_) + return (origin is Union or origin is types.UnionType) and type(None) in get_args( + field.type_ + ) + + def _validate_value_with_model_field( *, field: ModelField, value: Any, values: dict[str, Any], loc: tuple[str, ...] ) -> tuple[Any, list[Any]]: + if value is Undefined: + if field.required: + return None, [get_missing_field_error(loc=loc)] + else: + return deepcopy(field.default), [] if value is None: + if _allows_none(field): + return value, [] if field.field_info.is_required(): return None, [get_missing_field_error(loc=loc)] else: @@ -732,9 +748,9 @@ def _get_multidict_value( ): value = values.getlist(alias) else: - value = values.get(alias, None) + value = values.get(alias, Undefined) if ( - value is None + value is Undefined or ( isinstance(field.field_info, params.Form) and isinstance(value, str) # For type checks @@ -746,7 +762,7 @@ def _get_multidict_value( ) ): if field.field_info.is_required(): - return + return Undefined else: return deepcopy(field.default) return value @@ -914,7 +930,7 @@ async def _extract_form_body( for sub_value in value: tg.start_soon(process_fn, sub_value.read) value = serialize_sequence_value(field=field, value=results) - if value is not None: + if value is not Undefined and value is not None: values[get_validation_alias(field)] = value field_aliases = {get_validation_alias(field) for field in body_fields} for key in received_body.keys(): @@ -929,7 +945,7 @@ async def _extract_form_body( async def request_body_to_args( body_fields: list[ModelField], - received_body: Optional[Union[dict[str, Any], FormData]], + received_body: Optional[Union[dict[str, Any], FormData, UndefinedType]], embed_body_fields: bool, ) -> tuple[dict[str, Any], list[dict[str, Any]]]: values: dict[str, Any] = {} @@ -959,10 +975,12 @@ async def request_body_to_args( return {first_field.name: v_}, errors_ for field in body_fields: loc = ("body", get_validation_alias(field)) - value: Optional[Any] = None - if body_to_process is not None: + value: Optional[Any] = Undefined + if body_to_process is not None and not isinstance( + body_to_process, UndefinedType + ): try: - value = body_to_process.get(get_validation_alias(field)) + value = body_to_process.get(get_validation_alias(field), Undefined) # If the received body is a list, not a dict except AttributeError: errors.append(get_missing_field_error(loc)) diff --git a/tests/test_none_passed_when_null_received.py b/tests/test_none_passed_when_null_received.py new file mode 100644 index 000000000..f4e171202 --- /dev/null +++ b/tests/test_none_passed_when_null_received.py @@ -0,0 +1,44 @@ +from typing import Annotated, Optional, Union + +import pytest +from fastapi import Body, FastAPI +from fastapi.testclient import TestClient + +app = FastAPI() +DEFAULT = 1234567890 + + +@app.post("/api1") +def api1(integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT) -> dict: + return {"received": integer_or_null} + + +@app.post("/api2") +def api2(integer_or_null: Annotated[Optional[int], Body(embed=True)] = DEFAULT) -> dict: + return {"received": integer_or_null} + + +@app.post("/api3") +def api3( + integer_or_null: Annotated[Union[int, None], Body(embed=True)] = DEFAULT, +) -> dict: + return {"received": integer_or_null} + + +@app.post("/api4") +def api4(integer_or_null: Optional[int] = Body(embed=True, default=DEFAULT)) -> dict: + return {"received": integer_or_null} + + +client = TestClient(app) + + +@pytest.mark.parametrize("api", ["/api1", "/api2", "/api3", "/api4"]) +def test_api1_integer(api): + response = client.post(api, json={"integer_or_null": 100}) + assert response.status_code == 200, response.text + assert response.json() == {"received": 100} + + response = client.post(api, json={"integer_or_null": None}) + assert response.status_code == 200, response.text + assert response.json() == {"received": None}