mirror of https://github.com/tiangolo/fastapi.git
update and remove otherwise captured query params
This commit is contained in:
parent
dcd67c7feb
commit
1a57459eda
|
|
@ -24,7 +24,9 @@ from .main import get_schema_from_model_field as get_schema_from_model_field
|
||||||
from .main import is_bytes_field as is_bytes_field
|
from .main import is_bytes_field as is_bytes_field
|
||||||
from .main import is_bytes_sequence_field as is_bytes_sequence_field
|
from .main import is_bytes_sequence_field as is_bytes_sequence_field
|
||||||
from .main import is_scalar_field as is_scalar_field
|
from .main import is_scalar_field as is_scalar_field
|
||||||
|
from .main import is_scalar_mapping_field as is_scalar_mapping_field
|
||||||
from .main import is_scalar_sequence_field as is_scalar_sequence_field
|
from .main import is_scalar_sequence_field as is_scalar_sequence_field
|
||||||
|
from .main import is_scalar_sequence_mapping_field as is_scalar_sequence_mapping_field
|
||||||
from .main import is_sequence_field as is_sequence_field
|
from .main import is_sequence_field as is_sequence_field
|
||||||
from .main import serialize_sequence_value as serialize_sequence_value
|
from .main import serialize_sequence_value as serialize_sequence_value
|
||||||
from .main import (
|
from .main import (
|
||||||
|
|
|
||||||
|
|
@ -209,6 +209,30 @@ def is_sequence_field(field: ModelField) -> bool:
|
||||||
return v2.is_sequence_field(field) # type: ignore[arg-type]
|
return v2.is_sequence_field(field) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
def is_scalar_mapping_field(field: ModelField) -> bool:
|
||||||
|
if isinstance(field, may_v1.ModelField):
|
||||||
|
from fastapi._compat import v1
|
||||||
|
|
||||||
|
return v1.is_scalar_mapping_field(field)
|
||||||
|
else:
|
||||||
|
assert PYDANTIC_V2
|
||||||
|
from . import v2
|
||||||
|
|
||||||
|
return v2.is_scalar_mapping_field(field) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
def is_scalar_sequence_mapping_field(field: ModelField) -> bool:
|
||||||
|
if isinstance(field, may_v1.ModelField):
|
||||||
|
from fastapi._compat import v1
|
||||||
|
|
||||||
|
return v1.is_scalar_sequence_mapping_field(field)
|
||||||
|
else:
|
||||||
|
assert PYDANTIC_V2
|
||||||
|
from . import v2
|
||||||
|
|
||||||
|
return v2.is_scalar_sequence_mapping_field(field) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
|
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
|
||||||
if isinstance(field, may_v1.ModelField):
|
if isinstance(field, may_v1.ModelField):
|
||||||
from fastapi._compat import v1
|
from fastapi._compat import v1
|
||||||
|
|
|
||||||
|
|
@ -144,6 +144,45 @@ def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> b
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def field_annotation_is_scalar_mapping(
|
||||||
|
annotation: Union[Type[Any], None],
|
||||||
|
) -> bool:
|
||||||
|
origin = get_origin(annotation)
|
||||||
|
if origin is Union or origin is UnionType:
|
||||||
|
at_least_one_scalar_mapping = False
|
||||||
|
for arg in get_args(annotation):
|
||||||
|
if field_annotation_is_scalar_mapping(arg):
|
||||||
|
at_least_one_scalar_mapping = True
|
||||||
|
continue
|
||||||
|
elif not field_annotation_is_scalar(arg):
|
||||||
|
return False
|
||||||
|
return at_least_one_scalar_mapping
|
||||||
|
return lenient_issubclass(origin, Mapping) and all(
|
||||||
|
field_annotation_is_scalar(sub_annotation)
|
||||||
|
for sub_annotation in get_args(annotation)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def field_annotation_is_scalar_sequence_mapping(
|
||||||
|
annotation: Union[Type[Any], None],
|
||||||
|
) -> bool:
|
||||||
|
origin = get_origin(annotation)
|
||||||
|
if origin is Union or origin is UnionType:
|
||||||
|
at_least_one_scalar_mapping = False
|
||||||
|
for arg in get_args(annotation):
|
||||||
|
if field_annotation_is_scalar_sequence_mapping(arg):
|
||||||
|
at_least_one_scalar_mapping = True
|
||||||
|
continue
|
||||||
|
elif not field_annotation_is_scalar(arg):
|
||||||
|
return False
|
||||||
|
return at_least_one_scalar_mapping
|
||||||
|
return lenient_issubclass(origin, Mapping) and all(
|
||||||
|
field_annotation_is_scalar_sequence(sub_annotation)
|
||||||
|
or field_annotation_is_scalar(sub_annotation)
|
||||||
|
for sub_annotation in get_args(annotation)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
|
def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
|
||||||
if lenient_issubclass(annotation, bytes):
|
if lenient_issubclass(annotation, bytes):
|
||||||
return True
|
return True
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from typing import (
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
|
Mapping,
|
||||||
Sequence,
|
Sequence,
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
|
@ -84,6 +85,7 @@ else:
|
||||||
from pydantic.v1.fields import (
|
from pydantic.v1.fields import (
|
||||||
SHAPE_FROZENSET,
|
SHAPE_FROZENSET,
|
||||||
SHAPE_LIST,
|
SHAPE_LIST,
|
||||||
|
SHAPE_MAPPING,
|
||||||
SHAPE_SEQUENCE,
|
SHAPE_SEQUENCE,
|
||||||
SHAPE_SET,
|
SHAPE_SET,
|
||||||
SHAPE_SINGLETON,
|
SHAPE_SINGLETON,
|
||||||
|
|
@ -144,6 +146,11 @@ sequence_shape_to_type = {
|
||||||
SHAPE_TUPLE_ELLIPSIS: list,
|
SHAPE_TUPLE_ELLIPSIS: list,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mapping_shapes = {
|
||||||
|
SHAPE_MAPPING,
|
||||||
|
}
|
||||||
|
mapping_shapes_to_type = {SHAPE_MAPPING: Mapping}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GenerateJsonSchema:
|
class GenerateJsonSchema:
|
||||||
|
|
@ -219,6 +226,30 @@ def is_pv1_scalar_sequence_field(field: ModelField) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_pv1_scalar_sequence_mapping_field(field: ModelField) -> bool:
|
||||||
|
if (field.shape in mapping_shapes) and not lenient_issubclass( # type: ignore[attr-defined]
|
||||||
|
field.type_, BaseModel
|
||||||
|
):
|
||||||
|
if field.sub_fields is not None: # type: ignore[attr-defined]
|
||||||
|
for sub_field in field.sub_fields: # type: ignore[attr-defined]
|
||||||
|
if not is_scalar_sequence_field(sub_field):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_pv1_scalar_mapping_field(field: ModelField) -> bool:
|
||||||
|
if (field.shape in mapping_shapes) and not lenient_issubclass( # type: ignore[attr-defined]
|
||||||
|
field.type_, BaseModel
|
||||||
|
):
|
||||||
|
if field.sub_fields is not None: # type: ignore[attr-defined]
|
||||||
|
for sub_field in field.sub_fields: # type: ignore[attr-defined]
|
||||||
|
if not is_scalar_field(sub_field):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _model_rebuild(model: Type[BaseModel]) -> None:
|
def _model_rebuild(model: Type[BaseModel]) -> None:
|
||||||
model.update_forward_refs()
|
model.update_forward_refs()
|
||||||
|
|
||||||
|
|
@ -277,6 +308,14 @@ def is_scalar_sequence_field(field: ModelField) -> bool:
|
||||||
return is_pv1_scalar_sequence_field(field)
|
return is_pv1_scalar_sequence_field(field)
|
||||||
|
|
||||||
|
|
||||||
|
def is_scalar_mapping_field(field: ModelField) -> bool:
|
||||||
|
return is_pv1_scalar_mapping_field(field)
|
||||||
|
|
||||||
|
|
||||||
|
def is_scalar_sequence_mapping_field(field: ModelField) -> bool:
|
||||||
|
return is_pv1_scalar_sequence_mapping_field(field)
|
||||||
|
|
||||||
|
|
||||||
def is_bytes_field(field: ModelField) -> bool:
|
def is_bytes_field(field: ModelField) -> bool:
|
||||||
return lenient_issubclass(field.type_, bytes) # type: ignore[no-any-return]
|
return lenient_issubclass(field.type_, bytes) # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -352,6 +352,16 @@ def is_scalar_sequence_field(field: ModelField) -> bool:
|
||||||
return shared.field_annotation_is_scalar_sequence(field.field_info.annotation)
|
return shared.field_annotation_is_scalar_sequence(field.field_info.annotation)
|
||||||
|
|
||||||
|
|
||||||
|
def is_scalar_mapping_field(field: ModelField) -> bool:
|
||||||
|
return shared.field_annotation_is_scalar_mapping(field.field_info.annotation)
|
||||||
|
|
||||||
|
|
||||||
|
def is_scalar_sequence_mapping_field(field: ModelField) -> bool:
|
||||||
|
return shared.field_annotation_is_scalar_sequence_mapping(
|
||||||
|
field.field_info.annotation
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def is_bytes_field(field: ModelField) -> bool:
|
def is_bytes_field(field: ModelField) -> bool:
|
||||||
return shared.is_bytes_or_nonable_bytes_annotation(field.type_)
|
return shared.is_bytes_or_nonable_bytes_annotation(field.type_)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,9 @@ from fastapi._compat import (
|
||||||
is_bytes_field,
|
is_bytes_field,
|
||||||
is_bytes_sequence_field,
|
is_bytes_sequence_field,
|
||||||
is_scalar_field,
|
is_scalar_field,
|
||||||
|
is_scalar_mapping_field,
|
||||||
is_scalar_sequence_field,
|
is_scalar_sequence_field,
|
||||||
|
is_scalar_sequence_mapping_field,
|
||||||
is_sequence_field,
|
is_sequence_field,
|
||||||
is_uploadfile_or_nonable_uploadfile_annotation,
|
is_uploadfile_or_nonable_uploadfile_annotation,
|
||||||
is_uploadfile_sequence_annotation,
|
is_uploadfile_sequence_annotation,
|
||||||
|
|
@ -512,6 +514,7 @@ def analyze_param(
|
||||||
assert (
|
assert (
|
||||||
is_scalar_field(field)
|
is_scalar_field(field)
|
||||||
or is_scalar_sequence_field(field)
|
or is_scalar_sequence_field(field)
|
||||||
|
or is_scalar_sequence_mapping_field(field)
|
||||||
or (
|
or (
|
||||||
_is_model_class(field.type_)
|
_is_model_class(field.type_)
|
||||||
# For Pydantic v1
|
# For Pydantic v1
|
||||||
|
|
@ -707,6 +710,20 @@ def _validate_value_with_model_field(
|
||||||
else:
|
else:
|
||||||
return deepcopy(field.default), []
|
return deepcopy(field.default), []
|
||||||
v_, errors_ = field.validate(value, values, loc=loc)
|
v_, errors_ = field.validate(value, values, loc=loc)
|
||||||
|
if (
|
||||||
|
errors_
|
||||||
|
and isinstance(field.field_info, params.Query)
|
||||||
|
and isinstance(value, Mapping)
|
||||||
|
and (is_scalar_sequence_mapping_field(field) or is_scalar_mapping_field(field))
|
||||||
|
):
|
||||||
|
# Remove failing keys from the dict and try to re-validate
|
||||||
|
invalid_keys = {err["loc"][2] for err in errors_ if len(err["loc"]) >= 3}
|
||||||
|
v_, errors_ = field.validate(
|
||||||
|
{k: v for k, v in value.items() if k not in invalid_keys},
|
||||||
|
values,
|
||||||
|
loc=loc,
|
||||||
|
)
|
||||||
|
|
||||||
if _is_error_wrapper(errors_): # type: ignore[arg-type]
|
if _is_error_wrapper(errors_): # type: ignore[arg-type]
|
||||||
return None, [errors_]
|
return None, [errors_]
|
||||||
elif isinstance(errors_, list):
|
elif isinstance(errors_, list):
|
||||||
|
|
@ -720,10 +737,19 @@ def _get_multidict_value(
|
||||||
field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
|
field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
|
||||||
) -> Any:
|
) -> Any:
|
||||||
alias = alias or field.alias
|
alias = alias or field.alias
|
||||||
|
value: Any = None
|
||||||
if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
|
if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
|
||||||
value = values.getlist(alias)
|
value = values.getlist(alias)
|
||||||
else:
|
elif alias in values:
|
||||||
value = values.get(alias, None)
|
value = values[alias]
|
||||||
|
elif values and is_scalar_mapping_field(field) and isinstance(values, QueryParams):
|
||||||
|
value = dict(values)
|
||||||
|
elif (
|
||||||
|
values
|
||||||
|
and is_scalar_sequence_mapping_field(field)
|
||||||
|
and isinstance(values, QueryParams)
|
||||||
|
):
|
||||||
|
value = {key: values.getlist(key) for key in values.keys()}
|
||||||
if (
|
if (
|
||||||
value is None
|
value is None
|
||||||
or (
|
or (
|
||||||
|
|
@ -820,6 +846,13 @@ def request_params_to_args(
|
||||||
errors.extend(errors_)
|
errors.extend(errors_)
|
||||||
else:
|
else:
|
||||||
values[field.name] = v_
|
values[field.name] = v_
|
||||||
|
# remove keys which were captured by a mapping query field but were otherwise specified
|
||||||
|
for field in fields:
|
||||||
|
if isinstance(values.get(field.name), dict) and (
|
||||||
|
is_scalar_mapping_field(field) or is_scalar_sequence_mapping_field(field)
|
||||||
|
):
|
||||||
|
for f_ in fields:
|
||||||
|
values[field.name].pop(f_.alias, None)
|
||||||
return values, errors
|
return values, errors
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -191,12 +191,12 @@ def get_query_param_required_type(query: int = Query()):
|
||||||
|
|
||||||
@app.get("/query/mapping-params")
|
@app.get("/query/mapping-params")
|
||||||
def get_mapping_query_params(queries: Dict[str, str] = Query({})):
|
def get_mapping_query_params(queries: Dict[str, str] = Query({})):
|
||||||
return f"foo bar {queries['foo']} {queries['bar']}"
|
return {"queries": queries}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/query/mapping-sequence-params")
|
@app.get("/query/mapping-sequence-params")
|
||||||
def get_sequence_mapping_query_params(queries: Dict[str, List[int]] = Query({})):
|
def get_sequence_mapping_query_params(queries: Dict[str, List[int]] = Query({})):
|
||||||
return f"foo bar {dict(queries)}"
|
return {"queries": queries}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/query/mixed-params")
|
@app.get("/query/mixed-params")
|
||||||
|
|
@ -205,10 +205,13 @@ def get_mixed_mapping_query_params(
|
||||||
mapping_query: Dict[str, str] = Query(),
|
mapping_query: Dict[str, str] = Query(),
|
||||||
query: str = Query(),
|
query: str = Query(),
|
||||||
):
|
):
|
||||||
return (
|
return {
|
||||||
f"foo bar {sequence_mapping_queries['foo'][0]} {sequence_mapping_queries['foo'][1]} "
|
"queries": {
|
||||||
f"{mapping_query['foo']} {mapping_query['bar']} {query}"
|
"query": query,
|
||||||
)
|
"mapping_query": mapping_query,
|
||||||
|
"sequence_mapping_queries": sequence_mapping_queries,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/query/mixed-type-params")
|
@app.get("/query/mixed-type-params")
|
||||||
|
|
@ -218,7 +221,14 @@ def get_mixed_mapping_mixed_type_query_params(
|
||||||
mapping_query_int: Dict[str, int] = Query({}),
|
mapping_query_int: Dict[str, int] = Query({}),
|
||||||
query: int = Query(),
|
query: int = Query(),
|
||||||
):
|
):
|
||||||
return f"foo bar {query} {mapping_query_str} {mapping_query_int} {dict(sequence_mapping_queries)}"
|
return {
|
||||||
|
"queries": {
|
||||||
|
"query": query,
|
||||||
|
"mapping_query_str": mapping_query_str,
|
||||||
|
"mapping_query_int": mapping_query_int,
|
||||||
|
"sequence_mapping_queries": sequence_mapping_queries,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED)
|
@app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED)
|
||||||
|
|
|
||||||
|
|
@ -424,31 +424,44 @@ def test_query_frozenset_query_1_query_1_query_2():
|
||||||
def test_mapping_query():
|
def test_mapping_query():
|
||||||
response = client.get("/query/mapping-params/?foo=fuzz&bar=buzz")
|
response = client.get("/query/mapping-params/?foo=fuzz&bar=buzz")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == "foo bar fuzz buzz"
|
assert response.json() == {"queries": {"bar": "buzz", "foo": "fuzz"}}
|
||||||
|
|
||||||
|
|
||||||
def test_mapping_with_non_mapping_query():
|
def test_mapping_with_non_mapping_query():
|
||||||
response = client.get("/query/mixed-params/?foo=fuzz&foo=baz&bar=buzz&query=fizz")
|
response = client.get("/query/mixed-params/?foo=fuzz&foo=baz&bar=buzz&query=fizz")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == "foo bar fuzz baz baz buzz fizz"
|
assert response.json() == {
|
||||||
|
"queries": {
|
||||||
|
"query": "fizz",
|
||||||
|
"mapping_query": {"foo": "baz", "bar": "buzz"},
|
||||||
|
"sequence_mapping_queries": {
|
||||||
|
"foo": ["fuzz", "baz"],
|
||||||
|
"bar": ["buzz"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_mapping_with_non_mapping_query_mixed_types():
|
def test_mapping_with_non_mapping_query_mixed_types():
|
||||||
response = client.get("/query/mixed-type-params/?foo=fuzz&foo=baz&bar=buzz&query=1")
|
response = client.get("/query/mixed-type-params/?foo=fuzz&foo=baz&bar=buzz&query=1")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert (
|
assert response.json() == {
|
||||||
response.json()
|
"queries": {
|
||||||
== "foo bar 1 {'foo': 'baz', 'bar': 'buzz', 'query': '1'} {'query': 1} {'foo': [], 'bar': [], 'query': [1]}"
|
"query": 1,
|
||||||
)
|
"mapping_query_str": {"foo": "baz", "bar": "buzz"},
|
||||||
|
"mapping_query_int": {},
|
||||||
|
"sequence_mapping_queries": {},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_sequence_mapping_query():
|
def test_sequence_mapping_query():
|
||||||
response = client.get("/query/mapping-sequence-params/?foo=1&foo=2")
|
response = client.get("/query/mapping-sequence-params/?foo=1&foo=2")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == "foo bar {'foo': [1, 2]}"
|
assert response.json() == {"queries": {"foo": [1, 2]}}
|
||||||
|
|
||||||
|
|
||||||
def test_sequence_mapping_query_drops_invalid():
|
def test_sequence_mapping_query_drops_invalid():
|
||||||
response = client.get("/query/mapping-sequence-params/?foo=fuzz&foo=buzz")
|
response = client.get("/query/mapping-sequence-params/?foo=fuzz&foo=buzz")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == "foo bar {'foo': []}"
|
assert response.json() == {"queries": {}}
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ def test_foo_needy_very(client: TestClient):
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {
|
assert response.json() == {
|
||||||
"query": 2,
|
"query": 2,
|
||||||
"string_mapping": {"query": "2", "foo": "baz"},
|
"string_mapping": {"foo": "baz"},
|
||||||
"mapping_query_int": {"query": 2},
|
"mapping_query_int": {},
|
||||||
"sequence_mapping_queries": {"query": [1, 2], "foo": []},
|
"sequence_mapping_queries": {},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue