mirror of https://github.com/tiangolo/fastapi.git
drop invalid query parameters
This commit is contained in:
parent
cd34cdc02c
commit
19698b436d
|
|
@ -658,11 +658,8 @@ def field_annotation_is_scalar_sequence_mapping(
|
||||||
return False
|
return False
|
||||||
return at_least_one_scalar_mapping
|
return at_least_one_scalar_mapping
|
||||||
return field_annotation_is_mapping(annotation) and all(
|
return field_annotation_is_mapping(annotation) and all(
|
||||||
(
|
field_annotation_is_scalar_sequence(sub_annotation)
|
||||||
field_annotation_is_scalar_sequence(sub_annotation)
|
for sub_annotation in get_args(annotation)[1:]
|
||||||
or field_annotation_is_scalar(sub_annotation)
|
|
||||||
)
|
|
||||||
for sub_annotation in get_args(annotation)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -453,7 +453,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
|
||||||
param_field.field_info, (params.Query, params.Header)
|
param_field.field_info, (params.Query, params.Header)
|
||||||
) and is_scalar_sequence_field(param_field):
|
) and is_scalar_sequence_field(param_field):
|
||||||
return False
|
return False
|
||||||
elif isinstance(param_field.field_info, (params.Query, params.Header)) and (
|
elif isinstance(param_field.field_info, params.Query) and (
|
||||||
is_scalar_sequence_mapping_field(param_field)
|
is_scalar_sequence_mapping_field(param_field)
|
||||||
or is_scalar_mapping_field(param_field)
|
or is_scalar_mapping_field(param_field)
|
||||||
):
|
):
|
||||||
|
|
@ -640,6 +640,7 @@ async def solve_dependencies(
|
||||||
)
|
)
|
||||||
return values, errors, background_tasks, response, dependency_cache
|
return values, errors, background_tasks, response, dependency_cache
|
||||||
|
|
||||||
|
class Marker: pass
|
||||||
|
|
||||||
def request_params_to_args(
|
def request_params_to_args(
|
||||||
required_params: Sequence[ModelField],
|
required_params: Sequence[ModelField],
|
||||||
|
|
@ -653,18 +654,15 @@ def request_params_to_args(
|
||||||
):
|
):
|
||||||
value = received_params.getlist(field.alias) or field.default
|
value = received_params.getlist(field.alias) or field.default
|
||||||
elif is_scalar_mapping_field(field) and isinstance(
|
elif is_scalar_mapping_field(field) and isinstance(
|
||||||
received_params, (QueryParams)
|
received_params, QueryParams
|
||||||
):
|
):
|
||||||
value = dict(received_params.multi_items()) or field.default
|
value = dict(received_params.multi_items()) or field.default
|
||||||
elif is_scalar_sequence_mapping_field(field) and isinstance(
|
elif is_scalar_sequence_mapping_field(field) and isinstance(
|
||||||
received_params, (QueryParams)
|
received_params, QueryParams
|
||||||
):
|
):
|
||||||
if not len(received_params.multi_items()):
|
value = defaultdict(list)
|
||||||
value = field.default
|
for k, v in received_params.multi_items():
|
||||||
else:
|
value[k].append(v)
|
||||||
value = defaultdict(list)
|
|
||||||
for k, v in received_params.multi_items():
|
|
||||||
value[k].append(v)
|
|
||||||
else:
|
else:
|
||||||
value = received_params.get(field.alias)
|
value = received_params.get(field.alias)
|
||||||
field_info = field.field_info
|
field_info = field.field_info
|
||||||
|
|
@ -681,6 +679,29 @@ def request_params_to_args(
|
||||||
v_, errors_ = field.validate(value, values, loc=loc)
|
v_, errors_ = field.validate(value, values, loc=loc)
|
||||||
if isinstance(errors_, ErrorWrapper):
|
if isinstance(errors_, ErrorWrapper):
|
||||||
errors.append(errors_)
|
errors.append(errors_)
|
||||||
|
elif (
|
||||||
|
isinstance(errors_, list)
|
||||||
|
and is_scalar_sequence_mapping_field(field)
|
||||||
|
and isinstance(received_params, QueryParams)
|
||||||
|
):
|
||||||
|
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
|
||||||
|
# remove all invalid parameters
|
||||||
|
marker = Marker()
|
||||||
|
for _, _, key, index in [error["loc"] for error in new_errors]:
|
||||||
|
value[key][index] = marker
|
||||||
|
for key in value:
|
||||||
|
value[key] = [x for x in value[key] if x != marker]
|
||||||
|
values[field.name] = value
|
||||||
|
elif (
|
||||||
|
isinstance(errors_, list)
|
||||||
|
and is_scalar_mapping_field(field)
|
||||||
|
and isinstance(received_params, QueryParams)
|
||||||
|
):
|
||||||
|
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
|
||||||
|
# remove all invalid parameters
|
||||||
|
for _, _, key in [error["loc"] for error in new_errors]:
|
||||||
|
value.pop(key)
|
||||||
|
values[field.name] = value
|
||||||
elif isinstance(errors_, list):
|
elif isinstance(errors_, list):
|
||||||
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
|
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
|
||||||
errors.extend(new_errors)
|
errors.extend(new_errors)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ from fastapi import FastAPI, Path, Query
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/api_route")
|
@app.api_route("/api_route")
|
||||||
def non_operation():
|
def non_operation():
|
||||||
return {"message": "Hello World"}
|
return {"message": "Hello World"}
|
||||||
|
|
@ -196,7 +195,29 @@ def get_mapping_query_params(queries: Mapping[str, str] = Query({})):
|
||||||
|
|
||||||
@app.get("/query/mapping-sequence-params")
|
@app.get("/query/mapping-sequence-params")
|
||||||
def get_sequence_mapping_query_params(queries: Mapping[str, List[int]] = Query({})):
|
def get_sequence_mapping_query_params(queries: Mapping[str, List[int]] = Query({})):
|
||||||
return f"foo bar {queries}"
|
return f"foo bar {dict(queries)}"
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/query/mixed-params")
|
||||||
|
def get_mixed_mapping_query_params(
|
||||||
|
sequence_mapping_queries: Mapping[str, List[str]] = Query({}),
|
||||||
|
mapping_query: Mapping[str, str] = Query(),
|
||||||
|
query: str = Query(),
|
||||||
|
):
|
||||||
|
return (
|
||||||
|
f"foo bar {sequence_mapping_queries['foo'][0]} {sequence_mapping_queries['foo'][1]} "
|
||||||
|
f"{mapping_query['foo']} {mapping_query['bar']} {query}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/query/mixed-type-params")
|
||||||
|
def get_mixed_mapping_mixed_type_query_params(
|
||||||
|
sequence_mapping_queries: Mapping[str, List[int]] = Query({}),
|
||||||
|
mapping_query_str: Mapping[str, str] = Query({}),
|
||||||
|
mapping_query_int: Mapping[str, int] = Query({}),
|
||||||
|
query: int = Query(),
|
||||||
|
):
|
||||||
|
return f"foo bar {query} {mapping_query_str} {mapping_query_int} {dict(sequence_mapping_queries)}"
|
||||||
|
|
||||||
|
|
||||||
@app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED)
|
@app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED)
|
||||||
|
|
|
||||||
|
|
@ -414,3 +414,27 @@ def test_mapping_query():
|
||||||
response = client.get("/query/mapping-params/?foo=fuzz&bar=buzz")
|
response = client.get("/query/mapping-params/?foo=fuzz&bar=buzz")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == "foo bar fuzz buzz"
|
assert response.json() == "foo bar fuzz buzz"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mapping_with_non_mapping_query():
|
||||||
|
response = client.get("/query/mixed-params/?foo=fuzz&foo=baz&bar=buzz&query=fizz")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == "foo bar fuzz baz baz buzz fizz"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mapping_with_non_mapping_query_mixed_types():
|
||||||
|
response = client.get("/query/mixed-type-params/?foo=fuzz&foo=baz&bar=buzz&query=1")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == "foo bar 1 {'foo': 'baz', 'bar': 'buzz', 'query': '1'} {'query': '1'} {'foo': [], 'bar': [], 'query': ['1']}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sequence_mapping_query():
|
||||||
|
response = client.get("/query/mapping-sequence-params/?foo=1&foo=2")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == "foo bar {'foo': [1, 2]}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sequence_mapping_query_drops_invalid():
|
||||||
|
response = client.get("/query/mapping-sequence-params/?foo=fuzz&foo=buzz")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == "foo bar {'foo': []}"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue