Merge remote-tracking branch 'origin/free-form-queries'

This commit is contained in:
JONEMI19 2023-07-07 19:08:20 +00:00
commit 1dde63024a
5 changed files with 1239 additions and 14 deletions

View File

@ -1,4 +1,5 @@
import inspect
from collections import defaultdict
from contextlib import contextmanager
from copy import deepcopy
from typing import (
@ -450,6 +451,11 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
param_field.field_info, (params.Query, params.Header)
) and is_scalar_sequence_field(param_field):
return False
elif isinstance(param_field.field_info, (params.Query, params.Header)) and (
is_scalar_sequence_mapping_field(param_field)
or is_scalar_mapping_field(param_field)
):
return False
else:
assert isinstance(
param_field.field_info, params.Body
@ -644,6 +650,19 @@ def request_params_to_args(
received_params, (QueryParams, Headers)
):
value = received_params.getlist(field.alias) or field.default
if is_scalar_mapping_field(field) and isinstance(
received_params, (QueryParams, Headers)
):
value = dict(received_params.multi_items()) or field.default
elif is_scalar_sequence_mapping_field(field) and isinstance(
received_params, (QueryParams, Headers)
):
if not len(received_params.multi_items()):
value = field.default
else:
value = defaultdict(list)
for k, v in received_params.multi_items():
value[k].append(v)
else:
value = received_params.get(field.alias)
field_info = field.field_info

View File

@ -1,5 +1,5 @@
import http
from typing import FrozenSet, Optional
from typing import FrozenSet, List, Mapping, Optional
from fastapi import FastAPI, Path, Query
@ -184,6 +184,16 @@ def get_query_param_required_type(query: int = Query()):
return f"foo bar {query}"
@app.get("/query/params")
def get_query_params(query: Mapping[str, int] = Query({})):
return f"foo bar {query}"
@app.get("/query/sequence-params")
def get_sequence_query_params(query: Mapping[str, List[int]] = Query({})):
return f"foo bar {query}"
@app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED)
def get_enum_status_code():
return "foo bar"

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,13 @@
from typing import List, Mapping
import pytest
from fastapi import FastAPI, Query
def test_invalid_sequence():
with pytest.raises(AssertionError):
app = FastAPI()
@app.get("/items/")
def read_items(q: Mapping[str, List[List[str]]] = Query(default=None)):
pass # pragma: no cover

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Tuple
import pytest
from fastapi import FastAPI, Query
@ -39,15 +39,3 @@ def test_invalid_dict():
@app.get("/items/")
def read_items(q: Dict[str, Item] = Query(default=None)):
pass # pragma: no cover
def test_invalid_simple_dict():
with pytest.raises(AssertionError):
app = FastAPI()
class Item(BaseModel):
title: str
@app.get("/items/")
def read_items(q: Optional[dict] = Query(default=None)):
pass # pragma: no cover