mirror of https://github.com/tiangolo/fastapi.git
Merge branch 'free-form-queries-100' into master-mj
This commit is contained in:
commit
13cd1af00b
|
|
@ -25,3 +25,6 @@ archive.zip
|
|||
*~
|
||||
.*.sw?
|
||||
.cache
|
||||
|
||||
main.py
|
||||
.devcontainer
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# See https://pre-commit.com for more information
|
||||
# See https://pre-commit.com/hooks.html for more hooks
|
||||
default_language_version:
|
||||
python: python3.10
|
||||
python: python3.11
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
|
|
|
|||
|
|
@ -43,6 +43,13 @@ sequence_annotation_to_type = {
|
|||
|
||||
sequence_types = tuple(sequence_annotation_to_type.keys())
|
||||
|
||||
mapping_annotation_to_type = {
|
||||
Mapping: list,
|
||||
}
|
||||
|
||||
mapping_types = tuple(mapping_annotation_to_type.keys())
|
||||
|
||||
|
||||
if PYDANTIC_V2:
|
||||
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
|
||||
from pydantic import TypeAdapter
|
||||
|
|
@ -228,6 +235,12 @@ if PYDANTIC_V2:
|
|||
def is_scalar_sequence_field(field: ModelField) -> bool:
|
||||
return field_annotation_is_scalar_sequence(field.field_info.annotation)
|
||||
|
||||
def is_scalar_sequence_mapping_field(field: ModelField) -> bool:
|
||||
return field_annotation_is_scalar_sequence_mapping(field.field_info.annotation)
|
||||
|
||||
def is_scalar_mapping_field(field: ModelField) -> bool:
|
||||
return field_annotation_is_scalar_mapping(field.field_info.annotation)
|
||||
|
||||
def is_bytes_field(field: ModelField) -> bool:
|
||||
return is_bytes_or_nonable_bytes_annotation(field.type_)
|
||||
|
||||
|
|
@ -275,6 +288,7 @@ else:
|
|||
from pydantic.fields import ( # type: ignore[attr-defined]
|
||||
SHAPE_FROZENSET,
|
||||
SHAPE_LIST,
|
||||
SHAPE_MAPPING,
|
||||
SHAPE_SEQUENCE,
|
||||
SHAPE_SET,
|
||||
SHAPE_SINGLETON,
|
||||
|
|
@ -325,6 +339,7 @@ else:
|
|||
SHAPE_SEQUENCE,
|
||||
SHAPE_TUPLE_ELLIPSIS,
|
||||
}
|
||||
|
||||
sequence_shape_to_type = {
|
||||
SHAPE_LIST: list,
|
||||
SHAPE_SET: set,
|
||||
|
|
@ -333,6 +348,11 @@ else:
|
|||
SHAPE_TUPLE_ELLIPSIS: list,
|
||||
}
|
||||
|
||||
mapping_shapes = {
|
||||
SHAPE_MAPPING,
|
||||
}
|
||||
mapping_shapes_to_type = {SHAPE_MAPPING: Mapping}
|
||||
|
||||
@dataclass
|
||||
class GenerateJsonSchema: # type: ignore[no-redef]
|
||||
ref_template: str
|
||||
|
|
@ -400,6 +420,30 @@ else:
|
|||
return True
|
||||
return False
|
||||
|
||||
def is_pv1_scalar_mapping_field(field: ModelField) -> bool:
|
||||
if (field.shape in mapping_shapes) and not lenient_issubclass(
|
||||
field.type_, BaseModel
|
||||
):
|
||||
if field.sub_fields is None:
|
||||
return False
|
||||
for sub_field in field.sub_fields:
|
||||
if not is_scalar_field(sub_field):
|
||||
return False
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_pv1_scalar_sequence_mapping_field(field: ModelField) -> bool:
|
||||
if (field.shape in mapping_shapes) and not lenient_issubclass(
|
||||
field.type_, BaseModel
|
||||
):
|
||||
if field.sub_fields is None:
|
||||
return False
|
||||
for sub_field in field.sub_fields:
|
||||
if not is_scalar_sequence_field(sub_field):
|
||||
return False
|
||||
return True
|
||||
return False
|
||||
|
||||
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
|
||||
use_errors: List[Any] = []
|
||||
for error in errors:
|
||||
|
|
@ -468,6 +512,12 @@ else:
|
|||
def is_scalar_sequence_field(field: ModelField) -> bool:
|
||||
return is_pv1_scalar_sequence_field(field)
|
||||
|
||||
def is_scalar_sequence_mapping_field(field: ModelField) -> bool:
|
||||
return is_pv1_scalar_sequence_mapping_field(field)
|
||||
|
||||
def is_scalar_mapping_field(field: ModelField) -> bool:
|
||||
return is_pv1_scalar_mapping_field(field)
|
||||
|
||||
def is_bytes_field(field: ModelField) -> bool:
|
||||
return lenient_issubclass(field.type_, bytes)
|
||||
|
||||
|
|
@ -517,14 +567,27 @@ def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
|
|||
)
|
||||
|
||||
|
||||
def _annotation_is_mapping(annotation: Union[Type[Any], None]) -> bool:
|
||||
if lenient_issubclass(annotation, (str, bytes)):
|
||||
return False
|
||||
return lenient_issubclass(annotation, mapping_types)
|
||||
|
||||
|
||||
def field_annotation_is_mapping(annotation: Union[Type[Any], None]) -> bool:
|
||||
return _annotation_is_mapping(annotation) or _annotation_is_mapping(
|
||||
get_origin(annotation)
|
||||
)
|
||||
|
||||
|
||||
def value_is_sequence(value: Any) -> bool:
|
||||
return isinstance(value, sequence_types) and not isinstance(value, (str, bytes)) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
|
||||
return (
|
||||
lenient_issubclass(annotation, (BaseModel, Mapping, UploadFile))
|
||||
lenient_issubclass(annotation, (BaseModel, UploadFile))
|
||||
or _annotation_is_sequence(annotation)
|
||||
or _annotation_is_mapping(annotation)
|
||||
or is_dataclass(annotation)
|
||||
)
|
||||
|
||||
|
|
@ -564,6 +627,45 @@ def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> b
|
|||
)
|
||||
|
||||
|
||||
def field_annotation_is_scalar_mapping(annotation: Union[Type[Any], None]) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
at_least_one_scalar_mapping = False
|
||||
for arg in get_args(annotation):
|
||||
if field_annotation_is_scalar_mapping(arg):
|
||||
at_least_one_scalar_mapping = True
|
||||
continue
|
||||
elif not field_annotation_is_scalar(arg):
|
||||
return False
|
||||
return at_least_one_scalar_mapping
|
||||
return field_annotation_is_mapping(annotation) and all(
|
||||
field_annotation_is_scalar(sub_annotation)
|
||||
for sub_annotation in get_args(annotation)
|
||||
)
|
||||
|
||||
|
||||
def field_annotation_is_scalar_sequence_mapping(
|
||||
annotation: Union[Type[Any], None]
|
||||
) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
at_least_one_scalar_mapping = False
|
||||
for arg in get_args(annotation):
|
||||
if field_annotation_is_scalar_mapping(arg):
|
||||
at_least_one_scalar_mapping = True
|
||||
continue
|
||||
elif not field_annotation_is_scalar(arg):
|
||||
return False
|
||||
return at_least_one_scalar_mapping
|
||||
return field_annotation_is_mapping(annotation) and all(
|
||||
(
|
||||
field_annotation_is_scalar_sequence(sub_annotation)
|
||||
or field_annotation_is_scalar(sub_annotation)
|
||||
)
|
||||
for sub_annotation in get_args(annotation)
|
||||
)
|
||||
|
||||
|
||||
def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
|
||||
if lenient_issubclass(annotation, bytes):
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -36,7 +36,9 @@ from fastapi._compat import (
|
|||
is_bytes_field,
|
||||
is_bytes_sequence_field,
|
||||
is_scalar_field,
|
||||
is_scalar_mapping_field,
|
||||
is_scalar_sequence_field,
|
||||
is_scalar_sequence_mapping_field,
|
||||
is_sequence_field,
|
||||
is_uploadfile_or_nonable_uploadfile_annotation,
|
||||
is_uploadfile_sequence_annotation,
|
||||
|
|
@ -451,9 +453,10 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
|
|||
param_field.field_info, (params.Query, params.Header)
|
||||
) and is_scalar_sequence_field(param_field):
|
||||
return False
|
||||
elif isinstance(
|
||||
param_field.field_info, (params.Query, params.Header)
|
||||
) and (is_scalar_sequence_mapping_field(param_field) or is_scalar_mapping_field(param_field)):
|
||||
elif isinstance(param_field.field_info, (params.Query, params.Header)) and (
|
||||
is_scalar_sequence_mapping_field(param_field)
|
||||
or is_scalar_mapping_field(param_field)
|
||||
):
|
||||
return False
|
||||
else:
|
||||
assert isinstance(
|
||||
|
|
@ -649,7 +652,7 @@ def request_params_to_args(
|
|||
received_params, (QueryParams, Headers)
|
||||
):
|
||||
value = received_params.getlist(field.alias) or field.default
|
||||
if is_scalar_mapping_field(field) and isinstance(
|
||||
elif is_scalar_mapping_field(field) and isinstance(
|
||||
received_params, (QueryParams, Headers)
|
||||
):
|
||||
value = dict(received_params.multi_items()) or field.default
|
||||
|
|
@ -658,7 +661,7 @@ def request_params_to_args(
|
|||
):
|
||||
if not len(received_params.multi_items()):
|
||||
value = field.default
|
||||
else:
|
||||
else:
|
||||
value = defaultdict(list)
|
||||
for k, v in received_params.multi_items():
|
||||
value[k].append(v)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,5 +1,5 @@
|
|||
import http
|
||||
from typing import FrozenSet, Optional
|
||||
from typing import FrozenSet, List, Mapping, Optional
|
||||
|
||||
from fastapi import FastAPI, Path, Query
|
||||
|
||||
|
|
@ -184,6 +184,21 @@ def get_query_param_required_type(query: int = Query()):
|
|||
return f"foo bar {query}"
|
||||
|
||||
|
||||
@app.get("/query/sequence-params")
|
||||
def get_sequence_query_params(query: Mapping[str, List[int]] = Query({})):
|
||||
return f"foo bar {query}"
|
||||
|
||||
|
||||
@app.get("/query/mapping-params")
|
||||
def get_mapping_query_params(queries: Mapping[str, str] = Query({})):
|
||||
return f"foo bar {queries['foo']} {queries['bar']}"
|
||||
|
||||
|
||||
@app.get("/query/mapping-sequence-params")
|
||||
def get_sequence_mapping_query_params(queries: Mapping[str, List[int]] = Query({})):
|
||||
return f"foo bar {queries}"
|
||||
|
||||
|
||||
@app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED)
|
||||
def get_enum_status_code():
|
||||
return "foo bar"
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,13 @@
|
|||
from typing import List, Mapping
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Query
|
||||
|
||||
|
||||
def test_invalid_sequence():
|
||||
with pytest.raises(AssertionError):
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/items/")
|
||||
def read_items(q: Mapping[str, List[List[str]]] = Query(default=None)):
|
||||
pass # pragma: no cover
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Query
|
||||
|
|
@ -39,15 +39,3 @@ def test_invalid_dict():
|
|||
@app.get("/items/")
|
||||
def read_items(q: Dict[str, Item] = Query(default=None)):
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
def test_invalid_simple_dict():
|
||||
with pytest.raises(AssertionError):
|
||||
app = FastAPI()
|
||||
|
||||
class Item(BaseModel):
|
||||
title: str
|
||||
|
||||
@app.get("/items/")
|
||||
def read_items(q: Optional[dict] = Query(default=None)):
|
||||
pass # pragma: no cover
|
||||
|
|
|
|||
|
|
@ -408,3 +408,9 @@ def test_query_frozenset_query_1_query_1_query_2():
|
|||
response = client.get("/query/frozenset/?query=1&query=1&query=2")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == "1,2"
|
||||
|
||||
|
||||
def test_mapping_query():
|
||||
response = client.get("/query/mapping-params/?foo=fuzz&bar=buzz")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == "foo bar fuzz buzz"
|
||||
|
|
|
|||
Loading…
Reference in New Issue