From a67f8cf074b6dbe26fc2a3c75f4e1466179ac704 Mon Sep 17 00:00:00 2001 From: JSCU-CNI <121175071+JSCU-CNI@users.noreply.github.com> Date: Mon, 14 Jul 2025 17:16:22 +0200 Subject: [PATCH] Test and resolve support for `Header(convert_underscores=False)` --- fastapi/dependencies/utils.py | 33 +++--- fastapi/openapi/utils.py | 26 ++--- tests/test_multiple_parameter_models.py | 149 ++++++++++++++---------- 3 files changed, 118 insertions(+), 90 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 27ec3f5f0..d8fc86931 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -222,13 +222,27 @@ def get_flat_dependant( def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]: if not fields: return fields + + return [field for field, _ in _get_flat_fields_from_params_with_origin(fields)] + + +def _get_flat_fields_from_params_with_origin( + fields: Sequence[ModelField], +) -> Sequence[Tuple[ModelField, ModelField]]: + """Same as :func:`_get_flat_fields_from_params`, but returns the result + as tuples ``(flat_field, origin_field)``. + """ result = [] for field in fields: if _is_model_class(field.type_): - fields_to_extract = get_cached_model_fields(field.type_) - result.extend(fields_to_extract) + result.extend( + [ + (model_field, field) + for model_field in get_cached_model_fields(field.type_) + ] + ) else: - result.append(field) + result.append((field, field)) return result @@ -785,18 +799,7 @@ def request_params_to_args( if not fields: return values, errors - fields_to_extract = [] - for field in fields: - if lenient_issubclass(field.type_, BaseModel): - fields_to_extract.extend( - [ - (model_field, field) - for model_field in get_cached_model_fields(field.type_) - ] - ) - else: - fields_to_extract.append((field, field)) - + fields_to_extract = _get_flat_fields_from_params_with_origin(fields) params_to_process: Dict[str, Any] = {} processed_keys = set() diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index dbc93d289..cc137971e 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -16,7 +16,7 @@ from fastapi._compat import ( from fastapi.datastructures import DefaultPlaceholder from fastapi.dependencies.models import Dependant from fastapi.dependencies.utils import ( - _get_flat_fields_from_params, + _get_flat_fields_from_params_with_origin, get_flat_dependant, get_flat_params, ) @@ -31,7 +31,6 @@ from fastapi.utils import ( generate_operation_id_for_path, is_body_allowed_for_status_code, ) -from pydantic import BaseModel from starlette.responses import JSONResponse from starlette.routing import BaseRoute from typing_extensions import Literal @@ -103,25 +102,22 @@ def _get_openapi_operation_parameters( ) -> List[Dict[str, Any]]: parameters = [] flat_dependant = get_flat_dependant(dependant, skip_repeats=True) - path_params = _get_flat_fields_from_params(flat_dependant.path_params) - query_params = _get_flat_fields_from_params(flat_dependant.query_params) - header_params = _get_flat_fields_from_params(flat_dependant.header_params) - cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params) + path_params = _get_flat_fields_from_params_with_origin(flat_dependant.path_params) + query_params = _get_flat_fields_from_params_with_origin(flat_dependant.query_params) + header_params = _get_flat_fields_from_params_with_origin( + flat_dependant.header_params + ) + cookie_params = _get_flat_fields_from_params_with_origin( + flat_dependant.cookie_params + ) parameter_groups = [ (ParamTypes.path, path_params), (ParamTypes.query, query_params), (ParamTypes.header, header_params), (ParamTypes.cookie, cookie_params), ] - default_convert_underscores = True - if len(flat_dependant.header_params) == 1: - first_field = flat_dependant.header_params[0] - if lenient_issubclass(first_field.type_, BaseModel): - default_convert_underscores = getattr( - first_field.field_info, "convert_underscores", True - ) for param_type, param_group in parameter_groups: - for param in param_group: + for param, base_field in param_group: field_info = param.field_info # field_info = cast(Param, field_info) if not getattr(field_info, "include_in_schema", True): @@ -136,7 +132,7 @@ def _get_openapi_operation_parameters( convert_underscores = getattr( param.field_info, "convert_underscores", - default_convert_underscores, + getattr(base_field.field_info, "convert_underscores", True), ) if ( param_type == ParamTypes.header diff --git a/tests/test_multiple_parameter_models.py b/tests/test_multiple_parameter_models.py index dd3c64e8c..4194e03fe 100644 --- a/tests/test_multiple_parameter_models.py +++ b/tests/test_multiple_parameter_models.py @@ -8,15 +8,15 @@ app = FastAPI() class Model(BaseModel): - field1: int = Field(0) + field_1: int = Field(0) class Model2(BaseModel): - field2: int = Field(0) + field_2: int = Field(0) class ModelNoExtra(BaseModel): - field1: int = Field(0) + field_1: int = Field(0) if PYDANTIC_V2: model_config = ConfigDict(extra="forbid") else: @@ -25,50 +25,60 @@ class ModelNoExtra(BaseModel): extra = "forbid" -for param in (Query, Header, Cookie): +def HeaderU(*args, **kwargs): + """Header callable that ensures that convert_underscores is False.""" + return Header(*args, convert_underscores=False, **kwargs) + + +for param in (Query, Header, HeaderU, Cookie): # Generates 4 views for all three Query, Header, and Cookie params: # i.e. /query-depends/, /query-arguments/, /query-argument/, /query-models/ for query - def dependency(field2: int = param(0)): - return field2 + def dependency(field_2: int = param(0, title="Field 2")): + return field_2 @app.get(f"/{param.__name__.lower()}-depends/") async def with_depends(model1: Model = param(), dependency=Depends(dependency)): """Model1 is specified via Query()/Header()/Cookie() and Model2 through Depends""" - return {"field1": model1.field1, "field2": dependency} + return {"field_1": model1.field_1, "field_2": dependency} + + @app.get(f"/{param.__name__.lower()}-arguments/") + async def with_argument( + field_1: int = param(0, title="Field 1"), + field_2: int = param(0, title="Field 2"), + ): + """Model1 and Model2 are specified as direct arguments (sanity check)""" + return {"field_1": field_1, "field_2": field_2} @app.get(f"/{param.__name__.lower()}-argument/") - async def with_model_and_argument(model1: Model = param(), field2: int = param(0)): + async def with_model_and_argument( + model1: Model = param(), field_2: int = param(0, title="Field 2") + ): """Model1 is specified via Query()/Header()/Cookie() and Model2 as direct argument""" - return {"field1": model1.field1, "field2": field2} + return {"field_1": model1.field_1, "field_2": field_2} @app.get(f"/{param.__name__.lower()}-models/") async def with_models(model1: Model = param(), model2: Model2 = param()): """Model1 and Model2 are specified via Query()/Header()/Cookie()""" - return {"field1": model1.field1, "field2": model2.field2} - - @app.get(f"/{param.__name__.lower()}-arguments/") - async def with_argument(field1: int = param(0), field2: int = param(0)): - """Model1 and Model2 are specified as direct arguments (sanity check)""" - return {"field1": field1, "field2": field2} + return {"field_1": model1.field_1, "field_2": model2.field_2} @app.get("/mixed/") async def mixed_model_sources(model1: Model = Query(), model2: Model2 = Header()): """Model1 is specified as Query(), Model2 as Header()""" - return {"field1": model1.field1, "field2": model2.field2} + return {"field_1": model1.field_1, "field_2": model2.field_2} @app.get("/duplicate/") async def duplicate_name(model: Model = Query(), same_model: Model = Query()): """Model1 is specified twice in Query()""" - return {"field1": model.field1, "duplicate": same_model.field1} + return {"field_1": model.field_1, "duplicate": same_model.field_1} @app.get("/duplicate2/") async def duplicate_name2(model: Model = Query(), same_model: Model = Header()): """Model1 is specified twice, once in Query(), once in Header()""" - return {"field1": model.field1, "duplicate": same_model.field1} + return {"field_1": model.field_1, "duplicate": same_model.field_1} @app.get("/duplicate-no-extra/") @@ -76,7 +86,7 @@ async def duplicate_name_no_extra( model: Model = Query(), same_model: ModelNoExtra = Query() ): """Uses Model and ModelNoExtra, but they have overlapping names""" - return {"field1": model.field1, "duplicate": same_model.field1} + return {"field_1": model.field_1, "duplicate": same_model.field_1} @app.get("/no-extra/") @@ -93,9 +103,9 @@ client = TestClient(app) ["/query-depends/", "/query-arguments/", "/query-argument/", "/query-models/"], ) def test_query_depends(path): - response = client.get(path, params={"field1": 0, "field2": 1}) + response = client.get(path, params={"field_1": 0, "field_2": 1}) assert response.status_code == 200 - assert response.json() == {"field1": 0, "field2": 1} + assert response.json() == {"field_1": 0, "field_2": 1} @pytest.mark.parametrize( @@ -103,9 +113,24 @@ def test_query_depends(path): ["/header-depends/", "/header-arguments/", "/header-argument/", "/header-models/"], ) def test_header_depends(path): - response = client.get(path, headers={"field1": "0", "field2": "1"}) + response = client.get(path, headers={"field-1": "0", "field-2": "1"}) assert response.status_code == 200 - assert response.json() == {"field1": 0, "field2": 1} + assert response.json() == {"field_1": 0, "field_2": 1} + + +@pytest.mark.parametrize( + "path", + [ + "/headeru-depends/", + "/headeru-arguments/", + "/headeru-argument/", + "/headeru-models/", + ], +) +def test_headeru_depends(path): + response = client.get(path, headers={"field_1": "0", "field_2": "1"}) + assert response.status_code == 200 + assert response.json() == {"field_1": 0, "field_2": 1} @pytest.mark.parametrize( @@ -113,16 +138,16 @@ def test_header_depends(path): ["/cookie-depends/", "/cookie-arguments/", "/cookie-argument/", "/cookie-models/"], ) def test_cookie_depends(path): - client.cookies = {"field1": "0", "field2": "1"} + client.cookies = {"field_1": "0", "field_2": "1"} response = client.get(path) assert response.status_code == 200 - assert response.json() == {"field1": 0, "field2": 1} + assert response.json() == {"field_1": 0, "field_2": 1} def test_mixed(): - response = client.get("/mixed/", params={"field1": 0}, headers={"field2": "1"}) + response = client.get("/mixed/", params={"field_1": 0}, headers={"field-2": "1"}) assert response.status_code == 200 - assert response.json() == {"field1": 0, "field2": 1} + assert response.json() == {"field_1": 0, "field_2": 1} @pytest.mark.parametrize( @@ -130,20 +155,20 @@ def test_mixed(): ["/duplicate/", "/duplicate2/", "/duplicate-no-extra/"], ) def test_duplicate_name(path): - response = client.get(path, params={"field1": 0}) + response = client.get(path, params={"field_1": 0}) assert response.status_code == 200 - assert response.json() == {"field1": 0, "duplicate": 0} + assert response.json() == {"field_1": 0, "duplicate": 0} def test_no_extra(): - response = client.get("/no-extra/", params={"field1": 0, "field2": 1}) + response = client.get("/no-extra/", params={"field_1": 0, "field_2": 1}) assert response.status_code == 422 if PYDANTIC_V2: assert response.json() == { "detail": [ { "input": "1", - "loc": ["query", "field2"], + "loc": ["query", "field_2"], "msg": "Extra inputs are not permitted", "type": "extra_forbidden", } @@ -153,7 +178,7 @@ def test_no_extra(): assert response.json() == { "detail": [ { - "loc": ["query", "field2"], + "loc": ["query", "field_2"], "msg": "extra fields not permitted", "type": "value_error.extra", } @@ -162,37 +187,41 @@ def test_no_extra(): @pytest.mark.parametrize( - ("path", "in_"), + ("path", "in_", "convert_underscores"), [ - ("/query-depends/", "query"), - ("/query-arguments/", "query"), - ("/query-argument/", "query"), - ("/query-models/", "query"), - ("/header-depends/", "header"), - ("/header-arguments/", "header"), - ("/header-argument/", "header"), - ("/header-models/", "header"), - ("/cookie-depends/", "cookie"), - ("/cookie-arguments/", "cookie"), - ("/cookie-argument/", "cookie"), - ("/cookie-models/", "cookie"), + ("/query-depends/", "query", False), + ("/query-arguments/", "query", False), + ("/query-argument/", "query", False), + ("/query-models/", "query", False), + ("/header-depends/", "header", True), + ("/header-arguments/", "header", True), + ("/header-argument/", "header", True), + ("/header-models/", "header", True), + ("/headeru-depends/", "header", False), + ("/headeru-arguments/", "header", False), + ("/headeru-argument/", "header", False), + ("/headeru-models/", "header", False), + ("/cookie-depends/", "cookie", False), + ("/cookie-arguments/", "cookie", False), + ("/cookie-argument/", "cookie", False), + ("/cookie-models/", "cookie", False), ], ) -def test_parameters_openapi_schema(path, in_): +def test_parameters_openapi_schema(path, in_, convert_underscores): response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json()["paths"][path]["get"]["parameters"] == [ { - "name": "field1", + "name": "field-1" if convert_underscores else "field_1", "in": in_, "required": False, - "schema": {"type": "integer", "default": 0, "title": "Field1"}, + "schema": {"type": "integer", "default": 0, "title": "Field 1"}, }, { - "name": "field2", + "name": "field-2" if convert_underscores else "field_2", "in": in_, "required": False, - "schema": {"type": "integer", "default": 0, "title": "Field2"}, + "schema": {"type": "integer", "default": 0, "title": "Field 2"}, }, ] @@ -202,16 +231,16 @@ def test_parameters_openapi_mixed(): assert response.status_code == 200, response.text assert response.json()["paths"]["/mixed/"]["get"]["parameters"] == [ { - "name": "field1", + "name": "field_1", "in": "query", "required": False, - "schema": {"type": "integer", "default": 0, "title": "Field1"}, + "schema": {"type": "integer", "default": 0, "title": "Field 1"}, }, { - "name": "field2", + "name": "field-2", "in": "header", "required": False, - "schema": {"type": "integer", "default": 0, "title": "Field2"}, + "schema": {"type": "integer", "default": 0, "title": "Field 2"}, }, ] @@ -221,10 +250,10 @@ def test_parameters_openapi_duplicate_name(): assert response.status_code == 200, response.text assert response.json()["paths"]["/duplicate/"]["get"]["parameters"] == [ { - "name": "field1", + "name": "field_1", "in": "query", "required": False, - "schema": {"type": "integer", "default": 0, "title": "Field1"}, + "schema": {"type": "integer", "default": 0, "title": "Field 1"}, }, ] @@ -234,15 +263,15 @@ def test_parameters_openapi_duplicate_name2(): assert response.status_code == 200, response.text assert response.json()["paths"]["/duplicate2/"]["get"]["parameters"] == [ { - "name": "field1", + "name": "field_1", "in": "query", "required": False, - "schema": {"type": "integer", "default": 0, "title": "Field1"}, + "schema": {"type": "integer", "default": 0, "title": "Field 1"}, }, { - "name": "field1", + "name": "field-1", "in": "header", "required": False, - "schema": {"type": "integer", "default": 0, "title": "Field1"}, + "schema": {"type": "integer", "default": 0, "title": "Field 1"}, }, ]