diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index ab18ec2db..07926a2eb 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -180,13 +180,13 @@ def get_flat_dependant( def _get_flat_fields_from_params(fields: list[ModelField]) -> list[ModelField]: if not fields: return fields - first_field = fields[0] - if len(fields) == 1 and lenient_issubclass( - first_field.field_info.annotation, BaseModel - ): - fields_to_extract = get_cached_model_fields(first_field.field_info.annotation) - return fields_to_extract - return fields + fields_to_extract = [] + for f in fields: + if lenient_issubclass(f.field_info.annotation, BaseModel): + fields_to_extract.extend(get_cached_model_fields(f.field_info.annotation)) + else: + fields_to_extract.append(f) + return fields_to_extract def get_flat_params(dependant: Dependant) -> list[ModelField]: @@ -758,32 +758,27 @@ def request_params_to_args( if not fields: return values, errors - first_field = fields[0] - fields_to_extract = fields - single_not_embedded_field = False default_convert_underscores = True - if len(fields) == 1 and lenient_issubclass( - first_field.field_info.annotation, BaseModel - ): - fields_to_extract = get_cached_model_fields(first_field.field_info.annotation) - single_not_embedded_field = True - # If headers are in a Pydantic model, the way to disable convert_underscores - # would be with Header(convert_underscores=False) at the Pydantic model level - default_convert_underscores = getattr( - first_field.field_info, "convert_underscores", True - ) - params_to_process: dict[str, Any] = {} + fields_to_extract = [ + (field, cached_field) + for field in fields + if lenient_issubclass(field.field_info.annotation, BaseModel) + for cached_field in get_cached_model_fields(field.field_info.annotation) + ] + processed_keys = set() - for field in fields_to_extract: + for parent_field, field in fields_to_extract: alias = None if isinstance(received_params, Headers): # Handle fields extracted from a Pydantic Model for a header, each field # doesn't have a FieldInfo of type Header with the default convert_underscores=True convert_underscores = getattr( - field.field_info, "convert_underscores", default_convert_underscores + parent_field.field_info, + "convert_underscores", + default_convert_underscores, ) if convert_underscores: alias = get_validation_alias(field) @@ -805,27 +800,24 @@ def request_params_to_args( else: params_to_process[key] = received_params.get(key) - if single_not_embedded_field: - field_info = first_field.field_info - assert isinstance(field_info, params.Param), ( - "Params must be subclasses of Param" - ) - loc: tuple[str, ...] = (field_info.in_.value,) - v_, errors_ = _validate_value_with_model_field( - field=first_field, value=params_to_process, values=values, loc=loc - ) - return {first_field.name: v_}, errors_ - for field in fields: - value = _get_multidict_value(field, received_params) field_info = field.field_info assert isinstance(field_info, params.Param), ( "Params must be subclasses of Param" ) - loc = (field_info.in_.value, get_validation_alias(field)) - v_, errors_ = _validate_value_with_model_field( - field=field, value=value, values=values, loc=loc - ) + + if lenient_issubclass(field.field_info.annotation, BaseModel): + loc: tuple[str, ...] = (field_info.in_.value,) + v_, errors_ = _validate_value_with_model_field( + field=field, value=params_to_process, values=values, loc=loc + ) + else: + value = _get_multidict_value(field, received_params) + loc = (field_info.in_.value, get_validation_alias(field)) + v_, errors_ = _validate_value_with_model_field( + field=field, value=value, values=values, loc=loc + ) + if errors_: errors.extend(errors_) else: diff --git a/tests/test_multiple_params_models.py b/tests/test_multiple_params_models.py new file mode 100644 index 000000000..021a6b944 --- /dev/null +++ b/tests/test_multiple_params_models.py @@ -0,0 +1,136 @@ +from typing import Annotated, Any, Callable + +import pytest +from fastapi import APIRouter, Cookie, FastAPI, Header, Query, status +from fastapi.testclient import TestClient +from pydantic import BaseModel + +app = FastAPI() +client = TestClient(app) + + +class NameModel(BaseModel): + name: str + + +class AgeModel(BaseModel): + age: int + + +def add_routes( + in_: Callable[..., Any], + prefix: str, +) -> None: + router = APIRouter(prefix=prefix) + + @router.get("/models") + async def route_models( + name_model: Annotated[NameModel, in_()], + age_model: Annotated[AgeModel, in_()], + ): + return { + "name": name_model.name, + "age": age_model.age, + } + + @router.get("/mixed") + async def route_mixed( + name_model: Annotated[NameModel, in_()], + age: Annotated[int, in_()], + ): + return { + "name": name_model.name, + "age": age, + } + + app.include_router(router) + + +add_routes(Query, "/query") +add_routes(Header, "/header") +add_routes(Cookie, "/cookie") + + +@pytest.mark.parametrize( + ("in_", "prefix", "call_arg"), + [ + (Query, "/query", "params"), + (Header, "/header", "headers"), + (Cookie, "/cookie", "cookies"), + ], + ids=[ + "query", + "header", + "cookie", + ], +) +@pytest.mark.parametrize( + "type_", + [ + "models", + "mixed", + ], + ids=[ + "models", + "mixed", + ], +) +def test_multiple_params(in_, prefix, call_arg, type_): + params = {"name": "John", "age": "42"} + kwargs = {} + + if call_arg == "cookies": + client.cookies = params + else: + kwargs[call_arg] = params + + response = client.get(f"{prefix}/{type_}", **kwargs) + + assert response.status_code == status.HTTP_200_OK + assert response.json() == {"name": "John", "age": 42} + + +@pytest.mark.parametrize( + ("prefix", "in_"), + [ + ("/query", "query"), + ("/header", "header"), + ("/cookie", "cookie"), + ], + ids=[ + "query", + "header", + "cookie", + ], +) +@pytest.mark.parametrize( + "type_", + [ + "models", + "mixed", + ], + ids=[ + "models", + "mixed", + ], +) +def test_openapi_schema(prefix, in_, type_): + response = client.get("/openapi.json") + + assert response.status_code == status.HTTP_200_OK + + schema = response.json() + assert schema["paths"][f"{prefix}/{type_}"]["get"]["parameters"] == [ + { + "required": True, + "in": in_, + "name": "name", + "schema": {"title": "Name", "type": "string"}, + }, + { + "required": True, + "in": in_, + "name": "age", + "schema": {"title": "Age", "type": "integer"}, + }, + ]