mirror of https://github.com/tiangolo/fastapi.git
Merge ed24198186 into 5420847d9f
This commit is contained in:
commit
1e6366c3b0
|
|
@ -180,13 +180,13 @@ def get_flat_dependant(
|
|||
def _get_flat_fields_from_params(fields: list[ModelField]) -> list[ModelField]:
|
||||
if not fields:
|
||||
return fields
|
||||
first_field = fields[0]
|
||||
if len(fields) == 1 and lenient_issubclass(
|
||||
first_field.field_info.annotation, BaseModel
|
||||
):
|
||||
fields_to_extract = get_cached_model_fields(first_field.field_info.annotation)
|
||||
return fields_to_extract
|
||||
return fields
|
||||
fields_to_extract = []
|
||||
for f in fields:
|
||||
if lenient_issubclass(f.field_info.annotation, BaseModel):
|
||||
fields_to_extract.extend(get_cached_model_fields(f.field_info.annotation))
|
||||
else:
|
||||
fields_to_extract.append(f)
|
||||
return fields_to_extract
|
||||
|
||||
|
||||
def get_flat_params(dependant: Dependant) -> list[ModelField]:
|
||||
|
|
@ -758,32 +758,27 @@ def request_params_to_args(
|
|||
if not fields:
|
||||
return values, errors
|
||||
|
||||
first_field = fields[0]
|
||||
fields_to_extract = fields
|
||||
single_not_embedded_field = False
|
||||
default_convert_underscores = True
|
||||
if len(fields) == 1 and lenient_issubclass(
|
||||
first_field.field_info.annotation, BaseModel
|
||||
):
|
||||
fields_to_extract = get_cached_model_fields(first_field.field_info.annotation)
|
||||
single_not_embedded_field = True
|
||||
# If headers are in a Pydantic model, the way to disable convert_underscores
|
||||
# would be with Header(convert_underscores=False) at the Pydantic model level
|
||||
default_convert_underscores = getattr(
|
||||
first_field.field_info, "convert_underscores", True
|
||||
)
|
||||
|
||||
params_to_process: dict[str, Any] = {}
|
||||
|
||||
fields_to_extract = [
|
||||
(field, cached_field)
|
||||
for field in fields
|
||||
if lenient_issubclass(field.field_info.annotation, BaseModel)
|
||||
for cached_field in get_cached_model_fields(field.field_info.annotation)
|
||||
]
|
||||
|
||||
processed_keys = set()
|
||||
|
||||
for field in fields_to_extract:
|
||||
for parent_field, field in fields_to_extract:
|
||||
alias = None
|
||||
if isinstance(received_params, Headers):
|
||||
# Handle fields extracted from a Pydantic Model for a header, each field
|
||||
# doesn't have a FieldInfo of type Header with the default convert_underscores=True
|
||||
convert_underscores = getattr(
|
||||
field.field_info, "convert_underscores", default_convert_underscores
|
||||
parent_field.field_info,
|
||||
"convert_underscores",
|
||||
default_convert_underscores,
|
||||
)
|
||||
if convert_underscores:
|
||||
alias = get_validation_alias(field)
|
||||
|
|
@ -805,27 +800,24 @@ def request_params_to_args(
|
|||
else:
|
||||
params_to_process[key] = received_params.get(key)
|
||||
|
||||
if single_not_embedded_field:
|
||||
field_info = first_field.field_info
|
||||
assert isinstance(field_info, params.Param), (
|
||||
"Params must be subclasses of Param"
|
||||
)
|
||||
loc: tuple[str, ...] = (field_info.in_.value,)
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=first_field, value=params_to_process, values=values, loc=loc
|
||||
)
|
||||
return {first_field.name: v_}, errors_
|
||||
|
||||
for field in fields:
|
||||
value = _get_multidict_value(field, received_params)
|
||||
field_info = field.field_info
|
||||
assert isinstance(field_info, params.Param), (
|
||||
"Params must be subclasses of Param"
|
||||
)
|
||||
loc = (field_info.in_.value, get_validation_alias(field))
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=field, value=value, values=values, loc=loc
|
||||
)
|
||||
|
||||
if lenient_issubclass(field.field_info.annotation, BaseModel):
|
||||
loc: tuple[str, ...] = (field_info.in_.value,)
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=field, value=params_to_process, values=values, loc=loc
|
||||
)
|
||||
else:
|
||||
value = _get_multidict_value(field, received_params)
|
||||
loc = (field_info.in_.value, get_validation_alias(field))
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=field, value=value, values=values, loc=loc
|
||||
)
|
||||
|
||||
if errors_:
|
||||
errors.extend(errors_)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,136 @@
|
|||
from typing import Annotated, Any, Callable
|
||||
|
||||
import pytest
|
||||
from fastapi import APIRouter, Cookie, FastAPI, Header, Query, status
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import BaseModel
|
||||
|
||||
app = FastAPI()
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
class NameModel(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class AgeModel(BaseModel):
|
||||
age: int
|
||||
|
||||
|
||||
def add_routes(
|
||||
in_: Callable[..., Any],
|
||||
prefix: str,
|
||||
) -> None:
|
||||
router = APIRouter(prefix=prefix)
|
||||
|
||||
@router.get("/models")
|
||||
async def route_models(
|
||||
name_model: Annotated[NameModel, in_()],
|
||||
age_model: Annotated[AgeModel, in_()],
|
||||
):
|
||||
return {
|
||||
"name": name_model.name,
|
||||
"age": age_model.age,
|
||||
}
|
||||
|
||||
@router.get("/mixed")
|
||||
async def route_mixed(
|
||||
name_model: Annotated[NameModel, in_()],
|
||||
age: Annotated[int, in_()],
|
||||
):
|
||||
return {
|
||||
"name": name_model.name,
|
||||
"age": age,
|
||||
}
|
||||
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
add_routes(Query, "/query")
|
||||
add_routes(Header, "/header")
|
||||
add_routes(Cookie, "/cookie")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("in_", "prefix", "call_arg"),
|
||||
[
|
||||
(Query, "/query", "params"),
|
||||
(Header, "/header", "headers"),
|
||||
(Cookie, "/cookie", "cookies"),
|
||||
],
|
||||
ids=[
|
||||
"query",
|
||||
"header",
|
||||
"cookie",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"type_",
|
||||
[
|
||||
"models",
|
||||
"mixed",
|
||||
],
|
||||
ids=[
|
||||
"models",
|
||||
"mixed",
|
||||
],
|
||||
)
|
||||
def test_multiple_params(in_, prefix, call_arg, type_):
|
||||
params = {"name": "John", "age": "42"}
|
||||
kwargs = {}
|
||||
|
||||
if call_arg == "cookies":
|
||||
client.cookies = params
|
||||
else:
|
||||
kwargs[call_arg] = params
|
||||
|
||||
response = client.get(f"{prefix}/{type_}", **kwargs)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json() == {"name": "John", "age": 42}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("prefix", "in_"),
|
||||
[
|
||||
("/query", "query"),
|
||||
("/header", "header"),
|
||||
("/cookie", "cookie"),
|
||||
],
|
||||
ids=[
|
||||
"query",
|
||||
"header",
|
||||
"cookie",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"type_",
|
||||
[
|
||||
"models",
|
||||
"mixed",
|
||||
],
|
||||
ids=[
|
||||
"models",
|
||||
"mixed",
|
||||
],
|
||||
)
|
||||
def test_openapi_schema(prefix, in_, type_):
|
||||
response = client.get("/openapi.json")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
schema = response.json()
|
||||
assert schema["paths"][f"{prefix}/{type_}"]["get"]["parameters"] == [
|
||||
{
|
||||
"required": True,
|
||||
"in": in_,
|
||||
"name": "name",
|
||||
"schema": {"title": "Name", "type": "string"},
|
||||
},
|
||||
{
|
||||
"required": True,
|
||||
"in": in_,
|
||||
"name": "age",
|
||||
"schema": {"title": "Age", "type": "integer"},
|
||||
},
|
||||
]
|
||||
Loading…
Reference in New Issue