This commit is contained in:
Merlin 2026-02-17 10:01:26 +00:00 committed by GitHub
commit 2fbba95ecb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 140 additions and 11 deletions

View File

@ -23,6 +23,7 @@ from .v2 import ModelField as ModelField
from .v2 import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from .v2 import RequiredParam as RequiredParam
from .v2 import Undefined as Undefined
from .v2 import UndefinedType as UndefinedType
from .v2 import Url as Url
from .v2 import copy_field_info as copy_field_info
from .v2 import create_body_model as create_body_model

View File

@ -30,7 +30,7 @@ from pydantic.fields import FieldInfo as FieldInfo
from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema
from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
from pydantic_core import CoreSchema as CoreSchema
from pydantic_core import PydanticUndefined
from pydantic_core import PydanticUndefined, PydanticUndefinedType
from pydantic_core import Url as Url
from pydantic_core.core_schema import (
with_info_plain_validator_function as with_info_plain_validator_function,
@ -38,6 +38,7 @@ from pydantic_core.core_schema import (
RequiredParam = PydanticUndefined
Undefined = PydanticUndefined
UndefinedType = PydanticUndefinedType
evaluate_forwardref = eval_type_lenient
# TODO: remove when dropping support for Pydantic < v2.12.3

View File

@ -1,6 +1,7 @@
import dataclasses
import inspect
import sys
import types
from collections.abc import Callable, Mapping, Sequence
from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy
@ -21,6 +22,7 @@ from fastapi._compat import (
ModelField,
RequiredParam,
Undefined,
UndefinedType,
copy_field_info,
create_body_model,
evaluate_forwardref,
@ -566,7 +568,7 @@ async def solve_dependencies(
*,
request: Request | WebSocket,
dependant: Dependant,
body: dict[str, Any] | FormData | None = None,
body: dict[str, Any] | FormData | UndefinedType | None = None,
background_tasks: StarletteBackgroundTasks | None = None,
response: Response | None = None,
dependency_overrides_provider: Any | None = None,
@ -702,10 +704,32 @@ async def solve_dependencies(
)
if sys.hexversion >= 0x030A0000 and sys.hexversion < 0x030E0000:
def _allows_none(field: ModelField) -> bool:
origin = get_origin(field.field_info.annotation)
return (origin is Union or origin is types.UnionType) and type(
None
) in get_args(field.field_info.annotation)
else:
def _allows_none(field: ModelField) -> bool:
origin = get_origin(field.field_info.annotation)
return origin is Union and type(None) in get_args(field.field_info.annotation)
def _validate_value_with_model_field(
*, field: ModelField, value: Any, values: dict[str, Any], loc: tuple[str, ...]
) -> tuple[Any, list[Any]]:
if value is Undefined:
if field.field_info.is_required():
return None, [get_missing_field_error(loc=loc)]
else:
return deepcopy(field.default), []
if value is None:
if _allows_none(field):
return value, []
if field.field_info.is_required():
return None, [get_missing_field_error(loc=loc)]
else:
@ -721,6 +745,7 @@ def _get_multidict_value(
field: ModelField, values: Mapping[str, Any], alias: str | None = None
) -> Any:
alias = alias or get_validation_alias(field)
value: Any
if (
(not _is_json_field(field))
and field_annotation_is_sequence(field.field_info.annotation)
@ -728,9 +753,9 @@ def _get_multidict_value(
):
value = values.getlist(alias)
else:
value = values.get(alias, None)
value = values.get(alias, Undefined)
if (
value is None
value is Undefined
or (
isinstance(field.field_info, params.Form)
and isinstance(value, str) # For type checks
@ -742,7 +767,7 @@ def _get_multidict_value(
)
):
if field.field_info.is_required():
return
return Undefined
else:
return deepcopy(field.default)
return value
@ -790,7 +815,7 @@ def request_params_to_args(
if alias == field.name:
alias = alias.replace("_", "-")
value = _get_multidict_value(field, received_params, alias=alias)
if value is not None:
if value is not Undefined and value is not None:
params_to_process[get_validation_alias(field)] = value
processed_keys.add(alias or get_validation_alias(field))
@ -902,7 +927,7 @@ async def _extract_form_body(
for sub_value in value:
results.append(await sub_value.read())
value = serialize_sequence_value(field=field, value=results)
if value is not None:
if value is not Undefined and value is not None:
values[get_validation_alias(field)] = value
field_aliases = {get_validation_alias(field) for field in body_fields}
for key in received_body.keys():
@ -917,7 +942,7 @@ async def _extract_form_body(
async def request_body_to_args(
body_fields: list[ModelField],
received_body: dict[str, Any] | FormData | None,
received_body: dict[str, Any] | FormData | UndefinedType | None,
embed_body_fields: bool,
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
values: dict[str, Any] = {}
@ -947,10 +972,12 @@ async def request_body_to_args(
return {first_field.name: v_}, errors_
for field in body_fields:
loc = ("body", get_validation_alias(field))
value: Any | None = None
if body_to_process is not None:
value: Any = Undefined
if body_to_process is not None and not isinstance(
body_to_process, UndefinedType
):
try:
value = body_to_process.get(get_validation_alias(field))
value = body_to_process.get(get_validation_alias(field), Undefined)
# If the received body is a list, not a dict
except AttributeError:
errors.append(get_missing_field_error(loc))

View File

@ -0,0 +1,100 @@
import sys
import pytest
from dirty_equals import IsDict
from fastapi import Body, FastAPI
from fastapi.testclient import TestClient
app = FastAPI()
DEFAULT = 1234567890
endpoints = []
if sys.hexversion >= 0x30A0000:
from typing import Annotated
@app.post("/api1")
def api1(
integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT,
) -> dict:
return {"received": integer_or_null}
endpoints.append("/api1")
if sys.hexversion >= 0x3090000:
from typing import Annotated
@app.post("/api2")
def api2(
integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT,
) -> dict:
return {"received": integer_or_null}
endpoints.append("/api2")
@app.post("/api3")
def api3(
integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT,
) -> dict:
return {"received": integer_or_null}
endpoints.append("/api3")
@app.post("/api4")
def api4(integer_or_null: int | None = Body(embed=True, default=DEFAULT)) -> dict:
return {"received": integer_or_null}
endpoints.append("/api4")
@app.post("/api5")
def api5(integer: int = Body(embed=True)) -> dict:
return {"received": integer}
client = TestClient(app)
@pytest.mark.parametrize("api", endpoints)
def test_apis(api):
response = client.post(api, json={"integer_or_null": 100})
assert response.status_code == 200, response.text
assert response.json() == {"received": 100}
response = client.post(api, json={"integer_or_null": None})
assert response.status_code == 200, response.text
assert response.json() == {"received": None}
def test_required_field():
response = client.post("/api5", json={"integer": 100})
assert response.status_code == 200, response.text
assert response.json() == {"received": 100}
response = client.post("/api5", json={"integer": None})
assert response.status_code == 422, response.text
assert response.json() == IsDict(
{
"detail": [
{
"loc": ["body", "integer"],
"msg": "Field required",
"type": "missing",
"input": None,
}
]
}
) | IsDict(
{
"detail": [
{
"loc": ["body", "integer"],
"msg": "field required",
"type": "value_error.missing",
}
]
}
)