This commit is contained in:
Merlin 2025-12-16 21:07:31 +00:00 committed by GitHub
commit 821e24d96b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 142 additions and 10 deletions

View File

@ -1,6 +1,7 @@
import dataclasses import dataclasses
import inspect import inspect
import sys import sys
import types
from contextlib import AsyncExitStack, contextmanager from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy from copy import copy, deepcopy
from dataclasses import dataclass from dataclasses import dataclass
@ -27,6 +28,7 @@ from fastapi._compat import (
ModelField, ModelField,
RequiredParam, RequiredParam,
Undefined, Undefined,
UndefinedType,
_is_error_wrapper, _is_error_wrapper,
_is_model_class, _is_model_class,
copy_field_info, copy_field_info,
@ -595,7 +597,7 @@ async def solve_dependencies(
*, *,
request: Union[Request, WebSocket], request: Union[Request, WebSocket],
dependant: Dependant, dependant: Dependant,
body: Optional[Union[Dict[str, Any], FormData]] = None, body: Optional[Union[Dict[str, Any], FormData, UndefinedType]] = Undefined,
background_tasks: Optional[StarletteBackgroundTasks] = None, background_tasks: Optional[StarletteBackgroundTasks] = None,
response: Optional[Response] = None, response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None, dependency_overrides_provider: Optional[Any] = None,
@ -731,10 +733,36 @@ async def solve_dependencies(
) )
if PYDANTIC_V2:
if sys.hexversion >= 0x30A0000:
def _allows_none(field: ModelField) -> bool:
origin = get_origin(field.type_)
return (origin is Union or origin is types.UnionType) and type(
None
) in get_args(field.type_)
else:
def _allows_none(field: ModelField) -> bool:
origin = get_origin(field.type_)
return origin is Union and type(None) in get_args(field.type_)
else:
def _allows_none(field: ModelField) -> bool:
return field.allow_none # type: ignore
def _validate_value_with_model_field( def _validate_value_with_model_field(
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...] *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
) -> Tuple[Any, List[Any]]: ) -> Tuple[Any, List[Any]]:
if value is Undefined:
if field.required:
return None, [get_missing_field_error(loc=loc)]
else:
return deepcopy(field.default), []
if value is None: if value is None:
if _allows_none(field):
return value, []
if field.required: if field.required:
return None, [get_missing_field_error(loc=loc)] return None, [get_missing_field_error(loc=loc)]
else: else:
@ -753,12 +781,13 @@ def _get_multidict_value(
field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
) -> Any: ) -> Any:
alias = alias or get_validation_alias(field) alias = alias or get_validation_alias(field)
value: Any
if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)): if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
value = values.getlist(alias) value = values.getlist(alias)
else: else:
value = values.get(alias, None) value = values.get(alias, Undefined)
if ( if (
value is None value is Undefined
or ( or (
isinstance(field.field_info, (params.Form, temp_pydantic_v1_params.Form)) isinstance(field.field_info, (params.Form, temp_pydantic_v1_params.Form))
and isinstance(value, str) # For type checks and isinstance(value, str) # For type checks
@ -767,7 +796,7 @@ def _get_multidict_value(
or (is_sequence_field(field) and len(value) == 0) or (is_sequence_field(field) and len(value) == 0)
): ):
if field.required: if field.required:
return return Undefined
else: else:
return deepcopy(field.default) return deepcopy(field.default)
return value return value
@ -813,7 +842,7 @@ def request_params_to_args(
if alias == field.name: if alias == field.name:
alias = alias.replace("_", "-") alias = alias.replace("_", "-")
value = _get_multidict_value(field, received_params, alias=alias) value = _get_multidict_value(field, received_params, alias=alias)
if value is not None: if value is not Undefined and value is not None:
params_to_process[get_validation_alias(field)] = value params_to_process[get_validation_alias(field)] = value
processed_keys.add(alias or get_validation_alias(field)) processed_keys.add(alias or get_validation_alias(field))
@ -933,7 +962,7 @@ async def _extract_form_body(
for sub_value in value: for sub_value in value:
tg.start_soon(process_fn, sub_value.read) tg.start_soon(process_fn, sub_value.read)
value = serialize_sequence_value(field=field, value=results) value = serialize_sequence_value(field=field, value=results)
if value is not None: if value is not Undefined and value is not None:
values[get_validation_alias(field)] = value values[get_validation_alias(field)] = value
field_aliases = {get_validation_alias(field) for field in body_fields} field_aliases = {get_validation_alias(field) for field in body_fields}
for key in received_body.keys(): for key in received_body.keys():
@ -948,7 +977,7 @@ async def _extract_form_body(
async def request_body_to_args( async def request_body_to_args(
body_fields: List[ModelField], body_fields: List[ModelField],
received_body: Optional[Union[Dict[str, Any], FormData]], received_body: Optional[Union[Dict[str, Any], FormData, UndefinedType]],
embed_body_fields: bool, embed_body_fields: bool,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
@ -978,10 +1007,12 @@ async def request_body_to_args(
return {first_field.name: v_}, errors_ return {first_field.name: v_}, errors_
for field in body_fields: for field in body_fields:
loc = ("body", get_validation_alias(field)) loc = ("body", get_validation_alias(field))
value: Optional[Any] = None value: Optional[Any] = Undefined
if body_to_process is not None: if body_to_process is not None and not isinstance(
body_to_process, UndefinedType
):
try: try:
value = body_to_process.get(get_validation_alias(field)) value = body_to_process.get(get_validation_alias(field), Undefined)
# If the received body is a list, not a dict # If the received body is a list, not a dict
except AttributeError: except AttributeError:
errors.append(get_missing_field_error(loc)) errors.append(get_missing_field_error(loc))

View File

@ -0,0 +1,101 @@
import sys
from typing import Optional, Union
import pytest
from dirty_equals import IsDict
from fastapi import Body, FastAPI
from fastapi.testclient import TestClient
app = FastAPI()
DEFAULT = 1234567890
endpoints = []
if sys.hexversion >= 0x30A0000:
from typing import Annotated
@app.post("/api1")
def api1(
integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT,
) -> dict:
return {"received": integer_or_null}
endpoints.append("/api1")
if sys.hexversion >= 0x3090000:
from typing import Annotated
@app.post("/api2")
def api2(
integer_or_null: Annotated[Optional[int], Body(embed=True)] = DEFAULT,
) -> dict:
return {"received": integer_or_null}
endpoints.append("/api2")
@app.post("/api3")
def api3(
integer_or_null: Annotated[Union[int, None], Body(embed=True)] = DEFAULT,
) -> dict:
return {"received": integer_or_null}
endpoints.append("/api3")
@app.post("/api4")
def api4(integer_or_null: Optional[int] = Body(embed=True, default=DEFAULT)) -> dict:
return {"received": integer_or_null}
endpoints.append("/api4")
@app.post("/api5")
def api5(integer: int = Body(embed=True)) -> dict:
return {"received": integer}
client = TestClient(app)
@pytest.mark.parametrize("api", endpoints)
def test_apis(api):
response = client.post(api, json={"integer_or_null": 100})
assert response.status_code == 200, response.text
assert response.json() == {"received": 100}
response = client.post(api, json={"integer_or_null": None})
assert response.status_code == 200, response.text
assert response.json() == {"received": None}
def test_required_field():
response = client.post("/api5", json={"integer": 100})
assert response.status_code == 200, response.text
assert response.json() == {"received": 100}
response = client.post("/api5", json={"integer": None})
assert response.status_code == 422, response.text
assert response.json() == IsDict(
{
"detail": [
{
"loc": ["body", "integer"],
"msg": "Field required",
"type": "missing",
"input": None,
}
]
}
) | IsDict(
{
"detail": [
{
"loc": ["body", "integer"],
"msg": "field required",
"type": "value_error.missing",
}
]
}
)