diff --git a/fastapi/_compat/__init__.py b/fastapi/_compat/__init__.py index 4581c38c8..84076e164 100644 --- a/fastapi/_compat/__init__.py +++ b/fastapi/_compat/__init__.py @@ -23,6 +23,7 @@ from .v2 import ModelField as ModelField from .v2 import PydanticSchemaGenerationError as PydanticSchemaGenerationError from .v2 import RequiredParam as RequiredParam from .v2 import Undefined as Undefined +from .v2 import UndefinedType as UndefinedType from .v2 import Url as Url from .v2 import copy_field_info as copy_field_info from .v2 import create_body_model as create_body_model diff --git a/fastapi/_compat/v2.py b/fastapi/_compat/v2.py index b83bc1b55..d0257acac 100644 --- a/fastapi/_compat/v2.py +++ b/fastapi/_compat/v2.py @@ -30,7 +30,7 @@ from pydantic.fields import FieldInfo as FieldInfo from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue from pydantic_core import CoreSchema as CoreSchema -from pydantic_core import PydanticUndefined +from pydantic_core import PydanticUndefined, PydanticUndefinedType from pydantic_core import Url as Url from pydantic_core.core_schema import ( with_info_plain_validator_function as with_info_plain_validator_function, @@ -38,6 +38,7 @@ from pydantic_core.core_schema import ( RequiredParam = PydanticUndefined Undefined = PydanticUndefined +UndefinedType = PydanticUndefinedType evaluate_forwardref = eval_type_lenient # TODO: remove when dropping support for Pydantic < v2.12.3 diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index ab18ec2db..6621735e5 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -1,6 +1,7 @@ import dataclasses import inspect import sys +import types from collections.abc import Callable, Mapping, Sequence from contextlib import AsyncExitStack, contextmanager from copy import copy, deepcopy @@ -21,6 +22,7 @@ from fastapi._compat import ( ModelField, RequiredParam, Undefined, + UndefinedType, copy_field_info, create_body_model, evaluate_forwardref, @@ -566,7 +568,7 @@ async def solve_dependencies( *, request: Request | WebSocket, dependant: Dependant, - body: dict[str, Any] | FormData | None = None, + body: dict[str, Any] | FormData | UndefinedType | None = None, background_tasks: StarletteBackgroundTasks | None = None, response: Response | None = None, dependency_overrides_provider: Any | None = None, @@ -702,10 +704,32 @@ async def solve_dependencies( ) +if sys.hexversion >= 0x030A0000 and sys.hexversion < 0x030E0000: + + def _allows_none(field: ModelField) -> bool: + origin = get_origin(field.field_info.annotation) + return (origin is Union or origin is types.UnionType) and type( + None + ) in get_args(field.field_info.annotation) + +else: + + def _allows_none(field: ModelField) -> bool: + origin = get_origin(field.field_info.annotation) + return origin is Union and type(None) in get_args(field.field_info.annotation) + + def _validate_value_with_model_field( *, field: ModelField, value: Any, values: dict[str, Any], loc: tuple[str, ...] ) -> tuple[Any, list[Any]]: + if value is Undefined: + if field.field_info.is_required(): + return None, [get_missing_field_error(loc=loc)] + else: + return deepcopy(field.default), [] if value is None: + if _allows_none(field): + return value, [] if field.field_info.is_required(): return None, [get_missing_field_error(loc=loc)] else: @@ -721,6 +745,7 @@ def _get_multidict_value( field: ModelField, values: Mapping[str, Any], alias: str | None = None ) -> Any: alias = alias or get_validation_alias(field) + value: Any if ( (not _is_json_field(field)) and field_annotation_is_sequence(field.field_info.annotation) @@ -728,9 +753,9 @@ def _get_multidict_value( ): value = values.getlist(alias) else: - value = values.get(alias, None) + value = values.get(alias, Undefined) if ( - value is None + value is Undefined or ( isinstance(field.field_info, params.Form) and isinstance(value, str) # For type checks @@ -742,7 +767,7 @@ def _get_multidict_value( ) ): if field.field_info.is_required(): - return + return Undefined else: return deepcopy(field.default) return value @@ -790,7 +815,7 @@ def request_params_to_args( if alias == field.name: alias = alias.replace("_", "-") value = _get_multidict_value(field, received_params, alias=alias) - if value is not None: + if value is not Undefined and value is not None: params_to_process[get_validation_alias(field)] = value processed_keys.add(alias or get_validation_alias(field)) @@ -902,7 +927,7 @@ async def _extract_form_body( for sub_value in value: results.append(await sub_value.read()) value = serialize_sequence_value(field=field, value=results) - if value is not None: + if value is not Undefined and value is not None: values[get_validation_alias(field)] = value field_aliases = {get_validation_alias(field) for field in body_fields} for key in received_body.keys(): @@ -917,7 +942,7 @@ async def _extract_form_body( async def request_body_to_args( body_fields: list[ModelField], - received_body: dict[str, Any] | FormData | None, + received_body: dict[str, Any] | FormData | UndefinedType | None, embed_body_fields: bool, ) -> tuple[dict[str, Any], list[dict[str, Any]]]: values: dict[str, Any] = {} @@ -947,10 +972,12 @@ async def request_body_to_args( return {first_field.name: v_}, errors_ for field in body_fields: loc = ("body", get_validation_alias(field)) - value: Any | None = None - if body_to_process is not None: + value: Any = Undefined + if body_to_process is not None and not isinstance( + body_to_process, UndefinedType + ): try: - value = body_to_process.get(get_validation_alias(field)) + value = body_to_process.get(get_validation_alias(field), Undefined) # If the received body is a list, not a dict except AttributeError: errors.append(get_missing_field_error(loc)) diff --git a/tests/test_none_passed_when_null_received.py b/tests/test_none_passed_when_null_received.py new file mode 100644 index 000000000..5c431af61 --- /dev/null +++ b/tests/test_none_passed_when_null_received.py @@ -0,0 +1,100 @@ +import sys + +import pytest +from dirty_equals import IsDict +from fastapi import Body, FastAPI +from fastapi.testclient import TestClient + +app = FastAPI() +DEFAULT = 1234567890 + +endpoints = [] + +if sys.hexversion >= 0x30A0000: + from typing import Annotated + + @app.post("/api1") + def api1( + integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT, + ) -> dict: + return {"received": integer_or_null} + + endpoints.append("/api1") + + +if sys.hexversion >= 0x3090000: + from typing import Annotated + + @app.post("/api2") + def api2( + integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT, + ) -> dict: + return {"received": integer_or_null} + + endpoints.append("/api2") + + @app.post("/api3") + def api3( + integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT, + ) -> dict: + return {"received": integer_or_null} + + endpoints.append("/api3") + + +@app.post("/api4") +def api4(integer_or_null: int | None = Body(embed=True, default=DEFAULT)) -> dict: + return {"received": integer_or_null} + + +endpoints.append("/api4") + + +@app.post("/api5") +def api5(integer: int = Body(embed=True)) -> dict: + return {"received": integer} + + +client = TestClient(app) + + +@pytest.mark.parametrize("api", endpoints) +def test_apis(api): + response = client.post(api, json={"integer_or_null": 100}) + assert response.status_code == 200, response.text + assert response.json() == {"received": 100} + + response = client.post(api, json={"integer_or_null": None}) + assert response.status_code == 200, response.text + assert response.json() == {"received": None} + + +def test_required_field(): + response = client.post("/api5", json={"integer": 100}) + assert response.status_code == 200, response.text + assert response.json() == {"received": 100} + + response = client.post("/api5", json={"integer": None}) + assert response.status_code == 422, response.text + assert response.json() == IsDict( + { + "detail": [ + { + "loc": ["body", "integer"], + "msg": "Field required", + "type": "missing", + "input": None, + } + ] + } + ) | IsDict( + { + "detail": [ + { + "loc": ["body", "integer"], + "msg": "field required", + "type": "value_error.missing", + } + ] + } + )