From 19698b436dbfa7eb36bcfa80c43f08920fe79c1e Mon Sep 17 00:00:00 2001 From: JONEMI19 Date: Sat, 8 Jul 2023 08:35:51 +0000 Subject: [PATCH] drop invalid query parameters --- fastapi/_compat.py | 7 ++----- fastapi/dependencies/utils.py | 39 +++++++++++++++++++++++++++-------- tests/main.py | 25 ++++++++++++++++++++-- tests/test_query.py | 24 +++++++++++++++++++++ 4 files changed, 79 insertions(+), 16 deletions(-) diff --git a/fastapi/_compat.py b/fastapi/_compat.py index c5c1b376d..68302fdbe 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -658,11 +658,8 @@ def field_annotation_is_scalar_sequence_mapping( return False return at_least_one_scalar_mapping return field_annotation_is_mapping(annotation) and all( - ( - field_annotation_is_scalar_sequence(sub_annotation) - or field_annotation_is_scalar(sub_annotation) - ) - for sub_annotation in get_args(annotation) + field_annotation_is_scalar_sequence(sub_annotation) + for sub_annotation in get_args(annotation)[1:] ) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index e1d173f6b..5898dd057 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -453,7 +453,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: param_field.field_info, (params.Query, params.Header) ) and is_scalar_sequence_field(param_field): return False - elif isinstance(param_field.field_info, (params.Query, params.Header)) and ( + elif isinstance(param_field.field_info, params.Query) and ( is_scalar_sequence_mapping_field(param_field) or is_scalar_mapping_field(param_field) ): @@ -640,6 +640,7 @@ async def solve_dependencies( ) return values, errors, background_tasks, response, dependency_cache +class Marker: pass def request_params_to_args( required_params: Sequence[ModelField], @@ -653,18 +654,15 @@ def request_params_to_args( ): value = received_params.getlist(field.alias) or field.default elif is_scalar_mapping_field(field) and isinstance( - received_params, (QueryParams) + received_params, QueryParams ): value = dict(received_params.multi_items()) or field.default elif is_scalar_sequence_mapping_field(field) and isinstance( - received_params, (QueryParams) + received_params, QueryParams ): - if not len(received_params.multi_items()): - value = field.default - else: - value = defaultdict(list) - for k, v in received_params.multi_items(): - value[k].append(v) + value = defaultdict(list) + for k, v in received_params.multi_items(): + value[k].append(v) else: value = received_params.get(field.alias) field_info = field.field_info @@ -681,6 +679,29 @@ def request_params_to_args( v_, errors_ = field.validate(value, values, loc=loc) if isinstance(errors_, ErrorWrapper): errors.append(errors_) + elif ( + isinstance(errors_, list) + and is_scalar_sequence_mapping_field(field) + and isinstance(received_params, QueryParams) + ): + new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=()) + # remove all invalid parameters + marker = Marker() + for _, _, key, index in [error["loc"] for error in new_errors]: + value[key][index] = marker + for key in value: + value[key] = [x for x in value[key] if x != marker] + values[field.name] = value + elif ( + isinstance(errors_, list) + and is_scalar_mapping_field(field) + and isinstance(received_params, QueryParams) + ): + new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=()) + # remove all invalid parameters + for _, _, key in [error["loc"] for error in new_errors]: + value.pop(key) + values[field.name] = value elif isinstance(errors_, list): new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=()) errors.extend(new_errors) diff --git a/tests/main.py b/tests/main.py index 1a590add8..3eaf4863c 100644 --- a/tests/main.py +++ b/tests/main.py @@ -5,7 +5,6 @@ from fastapi import FastAPI, Path, Query app = FastAPI() - @app.api_route("/api_route") def non_operation(): return {"message": "Hello World"} @@ -196,7 +195,29 @@ def get_mapping_query_params(queries: Mapping[str, str] = Query({})): @app.get("/query/mapping-sequence-params") def get_sequence_mapping_query_params(queries: Mapping[str, List[int]] = Query({})): - return f"foo bar {queries}" + return f"foo bar {dict(queries)}" + + +@app.get("/query/mixed-params") +def get_mixed_mapping_query_params( + sequence_mapping_queries: Mapping[str, List[str]] = Query({}), + mapping_query: Mapping[str, str] = Query(), + query: str = Query(), +): + return ( + f"foo bar {sequence_mapping_queries['foo'][0]} {sequence_mapping_queries['foo'][1]} " + f"{mapping_query['foo']} {mapping_query['bar']} {query}" + ) + + +@app.get("/query/mixed-type-params") +def get_mixed_mapping_mixed_type_query_params( + sequence_mapping_queries: Mapping[str, List[int]] = Query({}), + mapping_query_str: Mapping[str, str] = Query({}), + mapping_query_int: Mapping[str, int] = Query({}), + query: int = Query(), +): + return f"foo bar {query} {mapping_query_str} {mapping_query_int} {dict(sequence_mapping_queries)}" @app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED) diff --git a/tests/test_query.py b/tests/test_query.py index 43f0a8038..7619762da 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -414,3 +414,27 @@ def test_mapping_query(): response = client.get("/query/mapping-params/?foo=fuzz&bar=buzz") assert response.status_code == 200 assert response.json() == "foo bar fuzz buzz" + + +def test_mapping_with_non_mapping_query(): + response = client.get("/query/mixed-params/?foo=fuzz&foo=baz&bar=buzz&query=fizz") + assert response.status_code == 200 + assert response.json() == "foo bar fuzz baz baz buzz fizz" + + +def test_mapping_with_non_mapping_query_mixed_types(): + response = client.get("/query/mixed-type-params/?foo=fuzz&foo=baz&bar=buzz&query=1") + assert response.status_code == 200 + assert response.json() == "foo bar 1 {'foo': 'baz', 'bar': 'buzz', 'query': '1'} {'query': '1'} {'foo': [], 'bar': [], 'query': ['1']}" + + +def test_sequence_mapping_query(): + response = client.get("/query/mapping-sequence-params/?foo=1&foo=2") + assert response.status_code == 200 + assert response.json() == "foo bar {'foo': [1, 2]}" + + +def test_sequence_mapping_query_drops_invalid(): + response = client.get("/query/mapping-sequence-params/?foo=fuzz&foo=buzz") + assert response.status_code == 200 + assert response.json() == "foo bar {'foo': []}"