mirror of https://github.com/tiangolo/fastapi.git
update and remove otherwise captured query params
This commit is contained in:
parent dcd67c7feb
commit 1a57459eda
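The commit title above describes the behavior change: a query parameter declared as a mapping (e.g. Dict[str, str]) should no longer swallow keys that belong to other declared parameters, and keys that fail validation should be dropped instead of failing the whole request. Below is a minimal sketch of how the feature is exercised, based on the endpoints and tests added in this diff. It assumes a FastAPI build that includes this branch; released FastAPI versions do not handle Dict-typed Query parameters this way.

from typing import Dict

from fastapi import FastAPI, Query
from fastapi.testclient import TestClient

app = FastAPI()


@app.get("/query/mapping-params")
def get_mapping_query_params(queries: Dict[str, str] = Query({})):
    # All query string keys are collected into a single dict parameter.
    return {"queries": queries}


client = TestClient(app)
response = client.get("/query/mapping-params?foo=fuzz&bar=buzz")
assert response.status_code == 200
assert response.json() == {"queries": {"foo": "fuzz", "bar": "buzz"}}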
@@ -24,7 +24,9 @@ from .main import get_schema_from_model_field as get_schema_from_model_field
 from .main import is_bytes_field as is_bytes_field
 from .main import is_bytes_sequence_field as is_bytes_sequence_field
 from .main import is_scalar_field as is_scalar_field
+from .main import is_scalar_mapping_field as is_scalar_mapping_field
 from .main import is_scalar_sequence_field as is_scalar_sequence_field
+from .main import is_scalar_sequence_mapping_field as is_scalar_sequence_mapping_field
 from .main import is_sequence_field as is_sequence_field
 from .main import serialize_sequence_value as serialize_sequence_value
 from .main import (
@@ -209,6 +209,30 @@ def is_sequence_field(field: ModelField) -> bool:
     return v2.is_sequence_field(field)  # type: ignore[arg-type]


+def is_scalar_mapping_field(field: ModelField) -> bool:
+    if isinstance(field, may_v1.ModelField):
+        from fastapi._compat import v1
+
+        return v1.is_scalar_mapping_field(field)
+    else:
+        assert PYDANTIC_V2
+        from . import v2
+
+        return v2.is_scalar_mapping_field(field)  # type: ignore[arg-type]
+
+
+def is_scalar_sequence_mapping_field(field: ModelField) -> bool:
+    if isinstance(field, may_v1.ModelField):
+        from fastapi._compat import v1
+
+        return v1.is_scalar_sequence_mapping_field(field)
+    else:
+        assert PYDANTIC_V2
+        from . import v2
+
+        return v2.is_scalar_sequence_mapping_field(field)  # type: ignore[arg-type]
+
+
 def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
     if isinstance(field, may_v1.ModelField):
         from fastapi._compat import v1
@@ -144,6 +144,45 @@ def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool:
     )


+def field_annotation_is_scalar_mapping(
+    annotation: Union[Type[Any], None],
+) -> bool:
+    origin = get_origin(annotation)
+    if origin is Union or origin is UnionType:
+        at_least_one_scalar_mapping = False
+        for arg in get_args(annotation):
+            if field_annotation_is_scalar_mapping(arg):
+                at_least_one_scalar_mapping = True
+                continue
+            elif not field_annotation_is_scalar(arg):
+                return False
+        return at_least_one_scalar_mapping
+    return lenient_issubclass(origin, Mapping) and all(
+        field_annotation_is_scalar(sub_annotation)
+        for sub_annotation in get_args(annotation)
+    )
+
+
+def field_annotation_is_scalar_sequence_mapping(
+    annotation: Union[Type[Any], None],
+) -> bool:
+    origin = get_origin(annotation)
+    if origin is Union or origin is UnionType:
+        at_least_one_scalar_mapping = False
+        for arg in get_args(annotation):
+            if field_annotation_is_scalar_sequence_mapping(arg):
+                at_least_one_scalar_mapping = True
+                continue
+            elif not field_annotation_is_scalar(arg):
+                return False
+        return at_least_one_scalar_mapping
+    return lenient_issubclass(origin, Mapping) and all(
+        field_annotation_is_scalar_sequence(sub_annotation)
+        or field_annotation_is_scalar(sub_annotation)
+        for sub_annotation in get_args(annotation)
+    )
+
+
 def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
     if lenient_issubclass(annotation, bytes):
         return True
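For orientation, here is a standalone sketch (illustration only, not FastAPI's API) of what "scalar mapping" means in the predicates above, with scalars reduced to a few builtin types; the real helpers delegate to field_annotation_is_scalar and also handle Union/Optional annotations.

from collections.abc import Mapping
from typing import Dict, List, get_args, get_origin

SCALARS = (str, int, float, bool)


def looks_like_scalar_mapping(annotation) -> bool:
    # Simplified re-implementation for illustration: a Mapping whose key and
    # value annotations are all plain scalars.
    origin = get_origin(annotation)
    if origin is None or not issubclass(origin, Mapping):
        return False
    return all(arg in SCALARS for arg in get_args(annotation))


assert looks_like_scalar_mapping(Dict[str, str]) is True
assert looks_like_scalar_mapping(Dict[str, List[int]]) is False  # values are sequences, not scalars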
@@ -6,6 +6,7 @@ from typing import (
     Callable,
     Dict,
     List,
+    Mapping,
     Sequence,
     Set,
     Tuple,
@@ -84,6 +85,7 @@ else:
     from pydantic.v1.fields import (
         SHAPE_FROZENSET,
         SHAPE_LIST,
+        SHAPE_MAPPING,
         SHAPE_SEQUENCE,
         SHAPE_SET,
         SHAPE_SINGLETON,
@@ -144,6 +146,11 @@ sequence_shape_to_type = {
     SHAPE_TUPLE_ELLIPSIS: list,
 }

+mapping_shapes = {
+    SHAPE_MAPPING,
+}
+mapping_shapes_to_type = {SHAPE_MAPPING: Mapping}
+

 @dataclass
 class GenerateJsonSchema:
@@ -219,6 +226,30 @@ def is_pv1_scalar_sequence_field(field: ModelField) -> bool:
     return False


+def is_pv1_scalar_sequence_mapping_field(field: ModelField) -> bool:
+    if (field.shape in mapping_shapes) and not lenient_issubclass(  # type: ignore[attr-defined]
+        field.type_, BaseModel
+    ):
+        if field.sub_fields is not None:  # type: ignore[attr-defined]
+            for sub_field in field.sub_fields:  # type: ignore[attr-defined]
+                if not is_scalar_sequence_field(sub_field):
+                    return False
+        return True
+    return False
+
+
+def is_pv1_scalar_mapping_field(field: ModelField) -> bool:
+    if (field.shape in mapping_shapes) and not lenient_issubclass(  # type: ignore[attr-defined]
+        field.type_, BaseModel
+    ):
+        if field.sub_fields is not None:  # type: ignore[attr-defined]
+            for sub_field in field.sub_fields:  # type: ignore[attr-defined]
+                if not is_scalar_field(sub_field):
+                    return False
+        return True
+    return False
+
+
 def _model_rebuild(model: Type[BaseModel]) -> None:
     model.update_forward_refs()

@@ -277,6 +308,14 @@ def is_scalar_sequence_field(field: ModelField) -> bool:
     return is_pv1_scalar_sequence_field(field)


+def is_scalar_mapping_field(field: ModelField) -> bool:
+    return is_pv1_scalar_mapping_field(field)
+
+
+def is_scalar_sequence_mapping_field(field: ModelField) -> bool:
+    return is_pv1_scalar_sequence_mapping_field(field)
+
+
 def is_bytes_field(field: ModelField) -> bool:
     return lenient_issubclass(field.type_, bytes)  # type: ignore[no-any-return]

@@ -352,6 +352,16 @@ def is_scalar_sequence_field(field: ModelField) -> bool:
     return shared.field_annotation_is_scalar_sequence(field.field_info.annotation)


+def is_scalar_mapping_field(field: ModelField) -> bool:
+    return shared.field_annotation_is_scalar_mapping(field.field_info.annotation)
+
+
+def is_scalar_sequence_mapping_field(field: ModelField) -> bool:
+    return shared.field_annotation_is_scalar_sequence_mapping(
+        field.field_info.annotation
+    )
+
+
 def is_bytes_field(field: ModelField) -> bool:
     return shared.is_bytes_or_nonable_bytes_annotation(field.type_)

@@ -37,7 +37,9 @@ from fastapi._compat import (
     is_bytes_field,
     is_bytes_sequence_field,
     is_scalar_field,
+    is_scalar_mapping_field,
     is_scalar_sequence_field,
+    is_scalar_sequence_mapping_field,
     is_sequence_field,
     is_uploadfile_or_nonable_uploadfile_annotation,
     is_uploadfile_sequence_annotation,
@@ -512,6 +514,7 @@ def analyze_param(
         assert (
             is_scalar_field(field)
             or is_scalar_sequence_field(field)
+            or is_scalar_sequence_mapping_field(field)
             or (
                 _is_model_class(field.type_)
                 # For Pydantic v1
@@ -707,6 +710,20 @@ def _validate_value_with_model_field(
         else:
             return deepcopy(field.default), []
     v_, errors_ = field.validate(value, values, loc=loc)
+    if (
+        errors_
+        and isinstance(field.field_info, params.Query)
+        and isinstance(value, Mapping)
+        and (is_scalar_sequence_mapping_field(field) or is_scalar_mapping_field(field))
+    ):
+        # Remove failing keys from the dict and try to re-validate
+        invalid_keys = {err["loc"][2] for err in errors_ if len(err["loc"]) >= 3}
+        v_, errors_ = field.validate(
+            {k: v for k, v in value.items() if k not in invalid_keys},
+            values,
+            loc=loc,
+        )
+
     if _is_error_wrapper(errors_):  # type: ignore[arg-type]
         return None, [errors_]
     elif isinstance(errors_, list):
@@ -720,10 +737,19 @@ def _get_multidict_value(
     field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
 ) -> Any:
     alias = alias or field.alias
+    value: Any = None
     if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
         value = values.getlist(alias)
-    else:
-        value = values.get(alias, None)
+    elif alias in values:
+        value = values[alias]
+    elif values and is_scalar_mapping_field(field) and isinstance(values, QueryParams):
+        value = dict(values)
+    elif (
+        values
+        and is_scalar_sequence_mapping_field(field)
+        and isinstance(values, QueryParams)
+    ):
+        value = {key: values.getlist(key) for key in values.keys()}
     if (
         value is None
         or (
@@ -820,6 +846,13 @@ def request_params_to_args(
             errors.extend(errors_)
         else:
            values[field.name] = v_
+    # remove keys which were captured by a mapping query field but were otherwise specified
+    for field in fields:
+        if isinstance(values.get(field.name), dict) and (
+            is_scalar_mapping_field(field) or is_scalar_sequence_mapping_field(field)
+        ):
+            for f_ in fields:
+                values[field.name].pop(f_.alias, None)
     return values, errors

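The loop added above strips, from any mapping query value, the keys that other declared parameters already captured. A standalone sketch of that step follows; drop_otherwise_captured is a hypothetical name used only for illustration and is not part of FastAPI.

from typing import Any, Dict, List


def drop_otherwise_captured(values: Dict[str, Any], aliases: List[str]) -> Dict[str, Any]:
    # For every captured value that is a dict (a mapping query field),
    # drop keys that match another declared parameter's alias.
    for value in values.values():
        if isinstance(value, dict):
            for alias in aliases:
                value.pop(alias, None)
    return values


values = {
    "query": "fizz",
    "mapping_query": {"foo": "baz", "bar": "buzz", "query": "fizz"},
}
print(drop_otherwise_captured(values, ["query", "mapping_query"]))
# -> {'query': 'fizz', 'mapping_query': {'foo': 'baz', 'bar': 'buzz'}}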
@@ -191,12 +191,12 @@ def get_query_param_required_type(query: int = Query()):

 @app.get("/query/mapping-params")
 def get_mapping_query_params(queries: Dict[str, str] = Query({})):
-    return f"foo bar {queries['foo']} {queries['bar']}"
+    return {"queries": queries}


 @app.get("/query/mapping-sequence-params")
 def get_sequence_mapping_query_params(queries: Dict[str, List[int]] = Query({})):
-    return f"foo bar {dict(queries)}"
+    return {"queries": queries}


 @app.get("/query/mixed-params")
@@ -205,10 +205,13 @@ def get_mixed_mapping_query_params(
     mapping_query: Dict[str, str] = Query(),
     query: str = Query(),
 ):
-    return (
-        f"foo bar {sequence_mapping_queries['foo'][0]} {sequence_mapping_queries['foo'][1]} "
-        f"{mapping_query['foo']} {mapping_query['bar']} {query}"
-    )
+    return {
+        "queries": {
+            "query": query,
+            "mapping_query": mapping_query,
+            "sequence_mapping_queries": sequence_mapping_queries,
+        }
+    }


 @app.get("/query/mixed-type-params")
@@ -218,7 +221,14 @@ def get_mixed_mapping_mixed_type_query_params(
     mapping_query_int: Dict[str, int] = Query({}),
     query: int = Query(),
 ):
-    return f"foo bar {query} {mapping_query_str} {mapping_query_int} {dict(sequence_mapping_queries)}"
+    return {
+        "queries": {
+            "query": query,
+            "mapping_query_str": mapping_query_str,
+            "mapping_query_int": mapping_query_int,
+            "sequence_mapping_queries": sequence_mapping_queries,
+        }
+    }


 @app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED)
@@ -424,31 +424,44 @@ def test_query_frozenset_query_1_query_1_query_2():
 def test_mapping_query():
     response = client.get("/query/mapping-params/?foo=fuzz&bar=buzz")
     assert response.status_code == 200
-    assert response.json() == "foo bar fuzz buzz"
+    assert response.json() == {"queries": {"bar": "buzz", "foo": "fuzz"}}


 def test_mapping_with_non_mapping_query():
     response = client.get("/query/mixed-params/?foo=fuzz&foo=baz&bar=buzz&query=fizz")
     assert response.status_code == 200
-    assert response.json() == "foo bar fuzz baz baz buzz fizz"
+    assert response.json() == {
+        "queries": {
+            "query": "fizz",
+            "mapping_query": {"foo": "baz", "bar": "buzz"},
+            "sequence_mapping_queries": {
+                "foo": ["fuzz", "baz"],
+                "bar": ["buzz"],
+            },
+        }
+    }


 def test_mapping_with_non_mapping_query_mixed_types():
     response = client.get("/query/mixed-type-params/?foo=fuzz&foo=baz&bar=buzz&query=1")
     assert response.status_code == 200
-    assert (
-        response.json()
-        == "foo bar 1 {'foo': 'baz', 'bar': 'buzz', 'query': '1'} {'query': 1} {'foo': [], 'bar': [], 'query': [1]}"
-    )
+    assert response.json() == {
+        "queries": {
+            "query": 1,
+            "mapping_query_str": {"foo": "baz", "bar": "buzz"},
+            "mapping_query_int": {},
+            "sequence_mapping_queries": {},
+        }
+    }


 def test_sequence_mapping_query():
     response = client.get("/query/mapping-sequence-params/?foo=1&foo=2")
     assert response.status_code == 200
-    assert response.json() == "foo bar {'foo': [1, 2]}"
+    assert response.json() == {"queries": {"foo": [1, 2]}}


 def test_sequence_mapping_query_drops_invalid():
     response = client.get("/query/mapping-sequence-params/?foo=fuzz&foo=buzz")
     assert response.status_code == 200
-    assert response.json() == "foo bar {'foo': []}"
+    assert response.json() == {"queries": {}}
@@ -15,7 +15,7 @@ def test_foo_needy_very(client: TestClient):
     assert response.status_code == 200
     assert response.json() == {
         "query": 2,
-        "string_mapping": {"query": "2", "foo": "baz"},
-        "mapping_query_int": {"query": 2},
-        "sequence_mapping_queries": {"query": [1, 2], "foo": []},
+        "string_mapping": {"foo": "baz"},
+        "mapping_query_int": {},
+        "sequence_mapping_queries": {},
     }