Pass None instead of the default value to parameters that accept it when

null is given

Signed-off-by: merlinz01 <158784988+merlinz01@users.noreply.github.com>
This commit is contained in:
merlinz01 2024-09-05 14:32:14 -04:00
parent e94028ab60
commit 5122db7eb6
2 changed files with 71 additions and 9 deletions

View File

@ -1,6 +1,7 @@
import dataclasses
import inspect
import sys
import types
from collections.abc import Coroutine, Mapping, Sequence
from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy
@ -21,6 +22,7 @@ from fastapi._compat import (
ModelField,
RequiredParam,
Undefined,
UndefinedType,
copy_field_info,
create_body_model,
evaluate_forwardref,
@ -570,7 +572,7 @@ async def solve_dependencies(
*,
request: Union[Request, WebSocket],
dependant: Dependant,
body: Optional[Union[dict[str, Any], FormData]] = None,
body: Optional[Union[dict[str, Any], FormData, UndefinedType]] = Undefined,
background_tasks: Optional[StarletteBackgroundTasks] = None,
response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None,
@ -706,10 +708,24 @@ async def solve_dependencies(
)
def _allows_none(field: ModelField) -> bool:
origin = get_origin(field.type_)
return (origin is Union or origin is types.UnionType) and type(None) in get_args(
field.type_
)
def _validate_value_with_model_field(
*, field: ModelField, value: Any, values: dict[str, Any], loc: tuple[str, ...]
) -> tuple[Any, list[Any]]:
if value is Undefined:
if field.required:
return None, [get_missing_field_error(loc=loc)]
else:
return deepcopy(field.default), []
if value is None:
if _allows_none(field):
return value, []
if field.field_info.is_required():
return None, [get_missing_field_error(loc=loc)]
else:
@ -732,9 +748,9 @@ def _get_multidict_value(
):
value = values.getlist(alias)
else:
value = values.get(alias, None)
value = values.get(alias, Undefined)
if (
value is None
value is Undefined
or (
isinstance(field.field_info, params.Form)
and isinstance(value, str) # For type checks
@ -746,7 +762,7 @@ def _get_multidict_value(
)
):
if field.field_info.is_required():
return
return Undefined
else:
return deepcopy(field.default)
return value
@ -914,7 +930,7 @@ async def _extract_form_body(
for sub_value in value:
tg.start_soon(process_fn, sub_value.read)
value = serialize_sequence_value(field=field, value=results)
if value is not None:
if value is not Undefined and value is not None:
values[get_validation_alias(field)] = value
field_aliases = {get_validation_alias(field) for field in body_fields}
for key in received_body.keys():
@ -929,7 +945,7 @@ async def _extract_form_body(
async def request_body_to_args(
body_fields: list[ModelField],
received_body: Optional[Union[dict[str, Any], FormData]],
received_body: Optional[Union[dict[str, Any], FormData, UndefinedType]],
embed_body_fields: bool,
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
values: dict[str, Any] = {}
@ -959,10 +975,12 @@ async def request_body_to_args(
return {first_field.name: v_}, errors_
for field in body_fields:
loc = ("body", get_validation_alias(field))
value: Optional[Any] = None
if body_to_process is not None:
value: Optional[Any] = Undefined
if body_to_process is not None and not isinstance(
body_to_process, UndefinedType
):
try:
value = body_to_process.get(get_validation_alias(field))
value = body_to_process.get(get_validation_alias(field), Undefined)
# If the received body is a list, not a dict
except AttributeError:
errors.append(get_missing_field_error(loc))

View File

@ -0,0 +1,44 @@
from typing import Annotated, Optional, Union
import pytest
from fastapi import Body, FastAPI
from fastapi.testclient import TestClient
app = FastAPI()
DEFAULT = 1234567890
@app.post("/api1")
def api1(integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT) -> dict:
return {"received": integer_or_null}
@app.post("/api2")
def api2(integer_or_null: Annotated[Optional[int], Body(embed=True)] = DEFAULT) -> dict:
return {"received": integer_or_null}
@app.post("/api3")
def api3(
integer_or_null: Annotated[Union[int, None], Body(embed=True)] = DEFAULT,
) -> dict:
return {"received": integer_or_null}
@app.post("/api4")
def api4(integer_or_null: Optional[int] = Body(embed=True, default=DEFAULT)) -> dict:
return {"received": integer_or_null}
client = TestClient(app)
@pytest.mark.parametrize("api", ["/api1", "/api2", "/api3", "/api4"])
def test_api1_integer(api):
response = client.post(api, json={"integer_or_null": 100})
assert response.status_code == 200, response.text
assert response.json() == {"received": 100}
response = client.post(api, json={"integer_or_null": None})
assert response.status_code == 200, response.text
assert response.json() == {"received": None}