From baa5cd2ca6c67c760387c1bf47ef56037c94343c Mon Sep 17 00:00:00 2001 From: Daniyar Yeralin Date: Wed, 12 Aug 2020 14:37:07 -0400 Subject: [PATCH 1/4] Introduce mapping shapes --- fastapi/dependencies/utils.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 7c9f7e847..d09ccde8f 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -41,7 +41,7 @@ from pydantic.fields import ( SHAPE_TUPLE_ELLIPSIS, FieldInfo, ModelField, - Required, + Required, SHAPE_MAPPING, ) from pydantic.schema import get_annotation_from_field_info from pydantic.typing import ForwardRef, evaluate_forwardref @@ -69,6 +69,13 @@ sequence_shape_to_type = { SHAPE_TUPLE_ELLIPSIS: list, } +mapping_shapes = { + SHAPE_MAPPING +} +mapping_types = (dict) +mapping_shapes_to_type = { + SHAPE_MAPPING: dict +} multipart_not_installed_error = ( 'Form data requires "python-multipart" to be installed. \n' @@ -240,6 +247,20 @@ def is_scalar_sequence_field(field: ModelField) -> bool: return False +def is_scalar_mapping_field(field: ModelField) -> bool: + if (field.shape in mapping_shapes) and not lenient_issubclass( + field.type_, BaseModel + ): + if field.sub_fields is not None: + for sub_field in field.sub_fields: + if not is_scalar_field(sub_field): + return False + return True + if lenient_issubclass(field.type_, mapping_types): + return True + return False + + def get_typed_signature(call: Callable) -> inspect.Signature: signature = inspect.signature(call) globalns = getattr(call, "__globals__", {}) @@ -324,7 +345,8 @@ def get_dependant( add_param_to_fields(field=param_field, dependant=dependant) elif isinstance( param.default, (params.Query, params.Header) - ) and is_scalar_sequence_field(param_field): + ) and (is_scalar_sequence_field(param_field) + or is_scalar_mapping_field(param_field)): add_param_to_fields(field=param_field, dependant=dependant) else: field_info = param_field.field_info @@ -610,6 +632,10 @@ def request_params_to_args( received_params, (QueryParams, Headers) ): value = received_params.getlist(field.alias) or field.default + elif is_scalar_mapping_field(field) and isinstance( + received_params, (QueryParams,) + ): + value = received_params._dict else: value = received_params.get(field.alias) field_info = field.field_info From e8c83f100e35f012ec3716d9341aee56fb8ddb1b Mon Sep 17 00:00:00 2001 From: Daniyar Yeralin Date: Wed, 12 Aug 2020 14:37:24 -0400 Subject: [PATCH 2/4] Cover mapping shapes with tests --- tests/main.py | 7 ++++++- tests/test_query.py | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/main.py b/tests/main.py index d5603d0e6..fdd16f2d2 100644 --- a/tests/main.py +++ b/tests/main.py @@ -1,5 +1,5 @@ import http -from typing import Optional +from typing import Optional, Dict from fastapi import FastAPI, Path, Query @@ -184,6 +184,11 @@ def get_query_param_required(query=Query(...)): return f"foo bar {query}" +@app.get("/query/params") +def get_query_params(queries: Dict[str, int] = Query({})): + return f"foo bar {queries}" + + @app.get("/query/param-required/int") def get_query_param_required_type(query: int = Query(...)): return f"foo bar {query}" diff --git a/tests/test_query.py b/tests/test_query.py index cdbdd1ccd..6c4f4b0e8 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -53,6 +53,8 @@ response_not_valid_int = { ("/query/param-required/int", 422, response_missing), ("/query/param-required/int?query=50", 200, "foo bar 50"), ("/query/param-required/int?query=foo", 422, response_not_valid_int), + ("/query/params?first-query=1&second-query=2", 200, "foo bar {'first-query': 1, " + "'second-query': 2}") ], ) def test_get_path(path, expected_status, expected_response): From ec624a27be2bc9abc9063263fd7a6173ab47f323 Mon Sep 17 00:00:00 2001 From: Daniyar Yeralin Date: Wed, 12 Aug 2020 14:40:38 -0400 Subject: [PATCH 3/4] Format proposed code --- fastapi/dependencies/utils.py | 21 +++++++++------------ tests/main.py | 2 +- tests/test_query.py | 7 +++++-- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index d09ccde8f..843f59bf6 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -34,6 +34,7 @@ from pydantic.error_wrappers import ErrorWrapper from pydantic.errors import MissingError from pydantic.fields import ( SHAPE_LIST, + SHAPE_MAPPING, SHAPE_SEQUENCE, SHAPE_SET, SHAPE_SINGLETON, @@ -41,7 +42,7 @@ from pydantic.fields import ( SHAPE_TUPLE_ELLIPSIS, FieldInfo, ModelField, - Required, SHAPE_MAPPING, + Required, ) from pydantic.schema import get_annotation_from_field_info from pydantic.typing import ForwardRef, evaluate_forwardref @@ -69,13 +70,9 @@ sequence_shape_to_type = { SHAPE_TUPLE_ELLIPSIS: list, } -mapping_shapes = { - SHAPE_MAPPING -} -mapping_types = (dict) -mapping_shapes_to_type = { - SHAPE_MAPPING: dict -} +mapping_shapes = {SHAPE_MAPPING} +mapping_types = dict +mapping_shapes_to_type = {SHAPE_MAPPING: dict} multipart_not_installed_error = ( 'Form data requires "python-multipart" to be installed. \n' @@ -343,10 +340,10 @@ def get_dependant( add_param_to_fields(field=param_field, dependant=dependant) elif is_scalar_field(field=param_field): add_param_to_fields(field=param_field, dependant=dependant) - elif isinstance( - param.default, (params.Query, params.Header) - ) and (is_scalar_sequence_field(param_field) - or is_scalar_mapping_field(param_field)): + elif isinstance(param.default, (params.Query, params.Header)) and ( + is_scalar_sequence_field(param_field) + or is_scalar_mapping_field(param_field) + ): add_param_to_fields(field=param_field, dependant=dependant) else: field_info = param_field.field_info diff --git a/tests/main.py b/tests/main.py index fdd16f2d2..17b380f6d 100644 --- a/tests/main.py +++ b/tests/main.py @@ -1,5 +1,5 @@ import http -from typing import Optional, Dict +from typing import Dict, Optional from fastapi import FastAPI, Path, Query diff --git a/tests/test_query.py b/tests/test_query.py index 6c4f4b0e8..934d34136 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -53,8 +53,11 @@ response_not_valid_int = { ("/query/param-required/int", 422, response_missing), ("/query/param-required/int?query=50", 200, "foo bar 50"), ("/query/param-required/int?query=foo", 422, response_not_valid_int), - ("/query/params?first-query=1&second-query=2", 200, "foo bar {'first-query': 1, " - "'second-query': 2}") + ( + "/query/params?first-query=1&second-query=2", + 200, + "foo bar {'first-query': 1, " "'second-query': 2}", + ), ], ) def test_get_path(path, expected_status, expected_response): From 23740cf5248a1a2e3ad448ce631915f0da0724a4 Mon Sep 17 00:00:00 2001 From: Daniyar Yeralin Date: Wed, 12 Aug 2020 15:02:49 -0400 Subject: [PATCH 4/4] Adapt tests to a new endpoint Remove unused import Add entry to openapi.json Reformat openapi.json --- tests/main.py | 10 ++++---- tests/test_application.py | 35 ++++++++++++++++++++++++++++ tests/test_invalid_sequence_param.py | 14 +---------- 3 files changed, 41 insertions(+), 18 deletions(-) diff --git a/tests/main.py b/tests/main.py index 17b380f6d..7700b4dbc 100644 --- a/tests/main.py +++ b/tests/main.py @@ -184,16 +184,16 @@ def get_query_param_required(query=Query(...)): return f"foo bar {query}" -@app.get("/query/params") -def get_query_params(queries: Dict[str, int] = Query({})): - return f"foo bar {queries}" - - @app.get("/query/param-required/int") def get_query_param_required_type(query: int = Query(...)): return f"foo bar {query}" +@app.get("/query/params") +def get_query_params(queries: Dict[str, int] = Query({})): + return f"foo bar {queries}" + + @app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED) def get_enum_status_code(): return "foo bar" diff --git a/tests/test_application.py b/tests/test_application.py index 5ba737307..dea2524a2 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1078,6 +1078,41 @@ openapi_schema = { ], } }, + "/query/params": { + "get": { + "summary": "Get Query Params", + "operationId": "get_query_params_query_params_get", + "parameters": [ + { + "required": False, + "schema": { + "title": "Queries", + "type": "object", + "additionalProperties": {"type": "integer"}, + "default": {}, + }, + "name": "queries", + "in": "query", + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + } + }, "/enum-status-code": { "get": { "responses": { diff --git a/tests/test_invalid_sequence_param.py b/tests/test_invalid_sequence_param.py index f00dd7b93..836b5f947 100644 --- a/tests/test_invalid_sequence_param.py +++ b/tests/test_invalid_sequence_param.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple import pytest from fastapi import FastAPI, Query @@ -39,15 +39,3 @@ def test_invalid_dict(): @app.get("/items/") def read_items(q: Dict[str, Item] = Query(None)): pass # pragma: no cover - - -def test_invalid_simple_dict(): - with pytest.raises(AssertionError): - app = FastAPI() - - class Item(BaseModel): - title: str - - @app.get("/items/") - def read_items(q: Optional[dict] = Query(None)): - pass # pragma: no cover