mirror of https://github.com/tiangolo/fastapi.git
Merge fa9af08f8d into da4135ce1e
This commit is contained in:
commit
2fbba95ecb
|
|
@ -23,6 +23,7 @@ from .v2 import ModelField as ModelField
|
|||
from .v2 import PydanticSchemaGenerationError as PydanticSchemaGenerationError
|
||||
from .v2 import RequiredParam as RequiredParam
|
||||
from .v2 import Undefined as Undefined
|
||||
from .v2 import UndefinedType as UndefinedType
|
||||
from .v2 import Url as Url
|
||||
from .v2 import copy_field_info as copy_field_info
|
||||
from .v2 import create_body_model as create_body_model
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ from pydantic.fields import FieldInfo as FieldInfo
|
|||
from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema
|
||||
from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
|
||||
from pydantic_core import CoreSchema as CoreSchema
|
||||
from pydantic_core import PydanticUndefined
|
||||
from pydantic_core import PydanticUndefined, PydanticUndefinedType
|
||||
from pydantic_core import Url as Url
|
||||
from pydantic_core.core_schema import (
|
||||
with_info_plain_validator_function as with_info_plain_validator_function,
|
||||
|
|
@ -38,6 +38,7 @@ from pydantic_core.core_schema import (
|
|||
|
||||
RequiredParam = PydanticUndefined
|
||||
Undefined = PydanticUndefined
|
||||
UndefinedType = PydanticUndefinedType
|
||||
evaluate_forwardref = eval_type_lenient
|
||||
|
||||
# TODO: remove when dropping support for Pydantic < v2.12.3
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import dataclasses
|
||||
import inspect
|
||||
import sys
|
||||
import types
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from contextlib import AsyncExitStack, contextmanager
|
||||
from copy import copy, deepcopy
|
||||
|
|
@ -21,6 +22,7 @@ from fastapi._compat import (
|
|||
ModelField,
|
||||
RequiredParam,
|
||||
Undefined,
|
||||
UndefinedType,
|
||||
copy_field_info,
|
||||
create_body_model,
|
||||
evaluate_forwardref,
|
||||
|
|
@ -566,7 +568,7 @@ async def solve_dependencies(
|
|||
*,
|
||||
request: Request | WebSocket,
|
||||
dependant: Dependant,
|
||||
body: dict[str, Any] | FormData | None = None,
|
||||
body: dict[str, Any] | FormData | UndefinedType | None = None,
|
||||
background_tasks: StarletteBackgroundTasks | None = None,
|
||||
response: Response | None = None,
|
||||
dependency_overrides_provider: Any | None = None,
|
||||
|
|
@ -702,10 +704,32 @@ async def solve_dependencies(
|
|||
)
|
||||
|
||||
|
||||
if sys.hexversion >= 0x030A0000 and sys.hexversion < 0x030E0000:
|
||||
|
||||
def _allows_none(field: ModelField) -> bool:
|
||||
origin = get_origin(field.field_info.annotation)
|
||||
return (origin is Union or origin is types.UnionType) and type(
|
||||
None
|
||||
) in get_args(field.field_info.annotation)
|
||||
|
||||
else:
|
||||
|
||||
def _allows_none(field: ModelField) -> bool:
|
||||
origin = get_origin(field.field_info.annotation)
|
||||
return origin is Union and type(None) in get_args(field.field_info.annotation)
|
||||
|
||||
|
||||
def _validate_value_with_model_field(
|
||||
*, field: ModelField, value: Any, values: dict[str, Any], loc: tuple[str, ...]
|
||||
) -> tuple[Any, list[Any]]:
|
||||
if value is Undefined:
|
||||
if field.field_info.is_required():
|
||||
return None, [get_missing_field_error(loc=loc)]
|
||||
else:
|
||||
return deepcopy(field.default), []
|
||||
if value is None:
|
||||
if _allows_none(field):
|
||||
return value, []
|
||||
if field.field_info.is_required():
|
||||
return None, [get_missing_field_error(loc=loc)]
|
||||
else:
|
||||
|
|
@ -721,6 +745,7 @@ def _get_multidict_value(
|
|||
field: ModelField, values: Mapping[str, Any], alias: str | None = None
|
||||
) -> Any:
|
||||
alias = alias or get_validation_alias(field)
|
||||
value: Any
|
||||
if (
|
||||
(not _is_json_field(field))
|
||||
and field_annotation_is_sequence(field.field_info.annotation)
|
||||
|
|
@ -728,9 +753,9 @@ def _get_multidict_value(
|
|||
):
|
||||
value = values.getlist(alias)
|
||||
else:
|
||||
value = values.get(alias, None)
|
||||
value = values.get(alias, Undefined)
|
||||
if (
|
||||
value is None
|
||||
value is Undefined
|
||||
or (
|
||||
isinstance(field.field_info, params.Form)
|
||||
and isinstance(value, str) # For type checks
|
||||
|
|
@ -742,7 +767,7 @@ def _get_multidict_value(
|
|||
)
|
||||
):
|
||||
if field.field_info.is_required():
|
||||
return
|
||||
return Undefined
|
||||
else:
|
||||
return deepcopy(field.default)
|
||||
return value
|
||||
|
|
@ -790,7 +815,7 @@ def request_params_to_args(
|
|||
if alias == field.name:
|
||||
alias = alias.replace("_", "-")
|
||||
value = _get_multidict_value(field, received_params, alias=alias)
|
||||
if value is not None:
|
||||
if value is not Undefined and value is not None:
|
||||
params_to_process[get_validation_alias(field)] = value
|
||||
processed_keys.add(alias or get_validation_alias(field))
|
||||
|
||||
|
|
@ -902,7 +927,7 @@ async def _extract_form_body(
|
|||
for sub_value in value:
|
||||
results.append(await sub_value.read())
|
||||
value = serialize_sequence_value(field=field, value=results)
|
||||
if value is not None:
|
||||
if value is not Undefined and value is not None:
|
||||
values[get_validation_alias(field)] = value
|
||||
field_aliases = {get_validation_alias(field) for field in body_fields}
|
||||
for key in received_body.keys():
|
||||
|
|
@ -917,7 +942,7 @@ async def _extract_form_body(
|
|||
|
||||
async def request_body_to_args(
|
||||
body_fields: list[ModelField],
|
||||
received_body: dict[str, Any] | FormData | None,
|
||||
received_body: dict[str, Any] | FormData | UndefinedType | None,
|
||||
embed_body_fields: bool,
|
||||
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
|
||||
values: dict[str, Any] = {}
|
||||
|
|
@ -947,10 +972,12 @@ async def request_body_to_args(
|
|||
return {first_field.name: v_}, errors_
|
||||
for field in body_fields:
|
||||
loc = ("body", get_validation_alias(field))
|
||||
value: Any | None = None
|
||||
if body_to_process is not None:
|
||||
value: Any = Undefined
|
||||
if body_to_process is not None and not isinstance(
|
||||
body_to_process, UndefinedType
|
||||
):
|
||||
try:
|
||||
value = body_to_process.get(get_validation_alias(field))
|
||||
value = body_to_process.get(get_validation_alias(field), Undefined)
|
||||
# If the received body is a list, not a dict
|
||||
except AttributeError:
|
||||
errors.append(get_missing_field_error(loc))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,100 @@
|
|||
import sys
|
||||
|
||||
import pytest
|
||||
from dirty_equals import IsDict
|
||||
from fastapi import Body, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
app = FastAPI()
|
||||
DEFAULT = 1234567890
|
||||
|
||||
endpoints = []
|
||||
|
||||
if sys.hexversion >= 0x30A0000:
|
||||
from typing import Annotated
|
||||
|
||||
@app.post("/api1")
|
||||
def api1(
|
||||
integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT,
|
||||
) -> dict:
|
||||
return {"received": integer_or_null}
|
||||
|
||||
endpoints.append("/api1")
|
||||
|
||||
|
||||
if sys.hexversion >= 0x3090000:
|
||||
from typing import Annotated
|
||||
|
||||
@app.post("/api2")
|
||||
def api2(
|
||||
integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT,
|
||||
) -> dict:
|
||||
return {"received": integer_or_null}
|
||||
|
||||
endpoints.append("/api2")
|
||||
|
||||
@app.post("/api3")
|
||||
def api3(
|
||||
integer_or_null: Annotated[int | None, Body(embed=True)] = DEFAULT,
|
||||
) -> dict:
|
||||
return {"received": integer_or_null}
|
||||
|
||||
endpoints.append("/api3")
|
||||
|
||||
|
||||
@app.post("/api4")
|
||||
def api4(integer_or_null: int | None = Body(embed=True, default=DEFAULT)) -> dict:
|
||||
return {"received": integer_or_null}
|
||||
|
||||
|
||||
endpoints.append("/api4")
|
||||
|
||||
|
||||
@app.post("/api5")
|
||||
def api5(integer: int = Body(embed=True)) -> dict:
|
||||
return {"received": integer}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("api", endpoints)
|
||||
def test_apis(api):
|
||||
response = client.post(api, json={"integer_or_null": 100})
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"received": 100}
|
||||
|
||||
response = client.post(api, json={"integer_or_null": None})
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"received": None}
|
||||
|
||||
|
||||
def test_required_field():
|
||||
response = client.post("/api5", json={"integer": 100})
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"received": 100}
|
||||
|
||||
response = client.post("/api5", json={"integer": None})
|
||||
assert response.status_code == 422, response.text
|
||||
assert response.json() == IsDict(
|
||||
{
|
||||
"detail": [
|
||||
{
|
||||
"loc": ["body", "integer"],
|
||||
"msg": "Field required",
|
||||
"type": "missing",
|
||||
"input": None,
|
||||
}
|
||||
]
|
||||
}
|
||||
) | IsDict(
|
||||
{
|
||||
"detail": [
|
||||
{
|
||||
"loc": ["body", "integer"],
|
||||
"msg": "field required",
|
||||
"type": "value_error.missing",
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
Loading…
Reference in New Issue