mirror of https://github.com/tiangolo/fastapi.git
Fix defaults for model-based params
This commit is contained in:
parent
937d3075f9
commit
6cfd51bd3b
|
|
@ -751,7 +751,11 @@ def _is_json_field(field: ModelField) -> bool:
|
|||
|
||||
|
||||
def _get_multidict_value(
|
||||
field: ModelField, values: Mapping[str, Any], alias: str | None = None
|
||||
field: ModelField,
|
||||
values: Mapping[str, Any],
|
||||
alias: str | None = None,
|
||||
*,
|
||||
use_default_when_missing: bool = True,
|
||||
) -> Any:
|
||||
alias = alias or get_validation_alias(field)
|
||||
if (
|
||||
|
|
@ -776,8 +780,9 @@ def _get_multidict_value(
|
|||
):
|
||||
if field.field_info.is_required():
|
||||
return
|
||||
else:
|
||||
if use_default_when_missing:
|
||||
return deepcopy(field.default)
|
||||
return
|
||||
return value
|
||||
|
||||
|
||||
|
|
@ -795,11 +800,13 @@ def request_params_to_args(
|
|||
fields_to_extract = fields
|
||||
single_not_embedded_field = False
|
||||
default_convert_underscores = True
|
||||
is_model_param = False
|
||||
if len(fields) == 1 and lenient_issubclass(
|
||||
first_field.field_info.annotation, BaseModel
|
||||
):
|
||||
fields_to_extract = get_cached_model_fields(first_field.field_info.annotation)
|
||||
single_not_embedded_field = True
|
||||
is_model_param = True
|
||||
# If headers are in a Pydantic model, the way to disable convert_underscores
|
||||
# would be with Header(convert_underscores=False) at the Pydantic model level
|
||||
default_convert_underscores = getattr(
|
||||
|
|
@ -822,7 +829,12 @@ def request_params_to_args(
|
|||
alias = get_validation_alias(field)
|
||||
if alias == field.name:
|
||||
alias = alias.replace("_", "-")
|
||||
value = _get_multidict_value(field, received_params, alias=alias)
|
||||
value = _get_multidict_value(
|
||||
field,
|
||||
received_params,
|
||||
alias=alias,
|
||||
use_default_when_missing=not is_model_param,
|
||||
)
|
||||
if value is not None:
|
||||
params_to_process[get_validation_alias(field)] = value
|
||||
processed_keys.add(alias or get_validation_alias(field))
|
||||
|
|
@ -912,11 +924,15 @@ def _should_embed_body_fields(fields: list[ModelField]) -> bool:
|
|||
async def _extract_form_body(
|
||||
body_fields: list[ModelField],
|
||||
received_body: FormData,
|
||||
*,
|
||||
use_default_when_missing: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
values = {}
|
||||
|
||||
for field in body_fields:
|
||||
value = _get_multidict_value(field, received_body)
|
||||
value = _get_multidict_value(
|
||||
field, received_body, use_default_when_missing=use_default_when_missing
|
||||
)
|
||||
field_info = field.field_info
|
||||
if (
|
||||
isinstance(field_info, params.File)
|
||||
|
|
@ -970,7 +986,16 @@ async def request_body_to_args(
|
|||
fields_to_extract = get_cached_model_fields(first_field.field_info.annotation)
|
||||
|
||||
if isinstance(received_body, FormData):
|
||||
body_to_process = await _extract_form_body(fields_to_extract, received_body)
|
||||
body_to_process = await _extract_form_body(
|
||||
fields_to_extract,
|
||||
received_body,
|
||||
# Keep omitted fields absent so Pydantic can apply defaults without
|
||||
# marking them as explicitly provided on the resulting model.
|
||||
use_default_when_missing=not (
|
||||
single_not_embedded_field
|
||||
and lenient_issubclass(first_field.field_info.annotation, BaseModel)
|
||||
),
|
||||
)
|
||||
|
||||
if single_not_embedded_field:
|
||||
loc: tuple[str, ...] = ("body",)
|
||||
|
|
|
|||
|
|
@ -99,13 +99,13 @@ def test_no_data():
|
|||
"type": "missing",
|
||||
"loc": ["body", "username"],
|
||||
"msg": "Field required",
|
||||
"input": {"tags": ["foo", "bar"], "with": "nothing"},
|
||||
"input": {},
|
||||
},
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ["body", "lastname"],
|
||||
"msg": "Field required",
|
||||
"input": {"tags": ["foo", "bar"], "with": "nothing"},
|
||||
"input": {},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,96 @@
|
|||
from typing import Annotated
|
||||
|
||||
import pytest
|
||||
from fastapi import Cookie, FastAPI, Form, Header, Query
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class DefaultModel(BaseModel):
|
||||
field_1: bool = True
|
||||
|
||||
|
||||
class InvalidDefaultModel(BaseModel):
|
||||
field_1: Annotated[str, Field(default=0)]
|
||||
|
||||
|
||||
@app.get("/query")
|
||||
def read_query(model: Annotated[DefaultModel, Query()]) -> dict[str, object]:
|
||||
return {"fields_set": sorted(model.model_fields_set), "model": model.model_dump()}
|
||||
|
||||
|
||||
@app.get("/header")
|
||||
def read_header(model: Annotated[DefaultModel, Header()]) -> dict[str, object]:
|
||||
return {"fields_set": sorted(model.model_fields_set), "model": model.model_dump()}
|
||||
|
||||
|
||||
@app.get("/cookie")
|
||||
def read_cookie(model: Annotated[DefaultModel, Cookie()]) -> dict[str, object]:
|
||||
return {"fields_set": sorted(model.model_fields_set), "model": model.model_dump()}
|
||||
|
||||
|
||||
@app.post("/form")
|
||||
def read_form(model: Annotated[DefaultModel, Form()]) -> dict[str, object]:
|
||||
return {"fields_set": sorted(model.model_fields_set), "model": model.model_dump()}
|
||||
|
||||
|
||||
@app.post("/body-invalid-default")
|
||||
def read_body_invalid_default(model: InvalidDefaultModel) -> dict[str, list[str]]:
|
||||
return {"fields_set": sorted(model.model_fields_set)}
|
||||
|
||||
|
||||
@app.post("/form-invalid-default")
|
||||
def read_form_invalid_default(
|
||||
model: Annotated[InvalidDefaultModel, Form()],
|
||||
) -> dict[str, list[str]]:
|
||||
return {"fields_set": sorted(model.model_fields_set)}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method", "path", "kwargs"),
|
||||
[
|
||||
("get", "/query", {}),
|
||||
("get", "/header", {}),
|
||||
("get", "/cookie", {}),
|
||||
("post", "/form", {"data": {}}),
|
||||
],
|
||||
)
|
||||
def test_missing_model_defaults_not_marked_as_set(
|
||||
method: str, path: str, kwargs: dict[str, object]
|
||||
) -> None:
|
||||
response = getattr(client, method)(path, **kwargs)
|
||||
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"fields_set": [],
|
||||
"model": {"field_1": True},
|
||||
}
|
||||
|
||||
|
||||
def test_explicit_form_model_value_is_still_marked_as_set() -> None:
|
||||
response = client.post("/form", data={"field_1": "false"})
|
||||
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"fields_set": ["field_1"],
|
||||
"model": {"field_1": False},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
["/body-invalid-default", "/form-invalid-default"],
|
||||
)
|
||||
def test_omitted_invalid_defaults_do_not_trigger_validation(path: str) -> None:
|
||||
if path == "/body-invalid-default":
|
||||
response = client.post(path, json={})
|
||||
else:
|
||||
response = client.post(path, data={})
|
||||
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"fields_set": []}
|
||||
Loading…
Reference in New Issue