update and remove otherwise captured query params

This commit is contained in:
JONEMI21 2025-11-06 22:04:36 +00:00
parent dcd67c7feb
commit 1a57459eda
9 changed files with 190 additions and 20 deletions

View File

@ -24,7 +24,9 @@ from .main import get_schema_from_model_field as get_schema_from_model_field
from .main import is_bytes_field as is_bytes_field
from .main import is_bytes_sequence_field as is_bytes_sequence_field
from .main import is_scalar_field as is_scalar_field
from .main import is_scalar_mapping_field as is_scalar_mapping_field
from .main import is_scalar_sequence_field as is_scalar_sequence_field
from .main import is_scalar_sequence_mapping_field as is_scalar_sequence_mapping_field
from .main import is_sequence_field as is_sequence_field
from .main import serialize_sequence_value as serialize_sequence_value
from .main import (

View File

@ -209,6 +209,30 @@ def is_sequence_field(field: ModelField) -> bool:
return v2.is_sequence_field(field) # type: ignore[arg-type]
def is_scalar_mapping_field(field: ModelField) -> bool:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.is_scalar_mapping_field(field)
else:
assert PYDANTIC_V2
from . import v2
return v2.is_scalar_mapping_field(field) # type: ignore[arg-type]
def is_scalar_sequence_mapping_field(field: ModelField) -> bool:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.is_scalar_sequence_mapping_field(field)
else:
assert PYDANTIC_V2
from . import v2
return v2.is_scalar_sequence_mapping_field(field) # type: ignore[arg-type]
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1

View File

@ -144,6 +144,45 @@ def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> b
)
def field_annotation_is_scalar_mapping(
annotation: Union[Type[Any], None],
) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one_scalar_mapping = False
for arg in get_args(annotation):
if field_annotation_is_scalar_mapping(arg):
at_least_one_scalar_mapping = True
continue
elif not field_annotation_is_scalar(arg):
return False
return at_least_one_scalar_mapping
return lenient_issubclass(origin, Mapping) and all(
field_annotation_is_scalar(sub_annotation)
for sub_annotation in get_args(annotation)
)
def field_annotation_is_scalar_sequence_mapping(
annotation: Union[Type[Any], None],
) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one_scalar_mapping = False
for arg in get_args(annotation):
if field_annotation_is_scalar_sequence_mapping(arg):
at_least_one_scalar_mapping = True
continue
elif not field_annotation_is_scalar(arg):
return False
return at_least_one_scalar_mapping
return lenient_issubclass(origin, Mapping) and all(
field_annotation_is_scalar_sequence(sub_annotation)
or field_annotation_is_scalar(sub_annotation)
for sub_annotation in get_args(annotation)
)
def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
if lenient_issubclass(annotation, bytes):
return True

View File

@ -6,6 +6,7 @@ from typing import (
Callable,
Dict,
List,
Mapping,
Sequence,
Set,
Tuple,
@ -84,6 +85,7 @@ else:
from pydantic.v1.fields import (
SHAPE_FROZENSET,
SHAPE_LIST,
SHAPE_MAPPING,
SHAPE_SEQUENCE,
SHAPE_SET,
SHAPE_SINGLETON,
@ -144,6 +146,11 @@ sequence_shape_to_type = {
SHAPE_TUPLE_ELLIPSIS: list,
}
mapping_shapes = {
SHAPE_MAPPING,
}
mapping_shapes_to_type = {SHAPE_MAPPING: Mapping}
@dataclass
class GenerateJsonSchema:
@ -219,6 +226,30 @@ def is_pv1_scalar_sequence_field(field: ModelField) -> bool:
return False
def is_pv1_scalar_sequence_mapping_field(field: ModelField) -> bool:
if (field.shape in mapping_shapes) and not lenient_issubclass( # type: ignore[attr-defined]
field.type_, BaseModel
):
if field.sub_fields is not None: # type: ignore[attr-defined]
for sub_field in field.sub_fields: # type: ignore[attr-defined]
if not is_scalar_sequence_field(sub_field):
return False
return True
return False
def is_pv1_scalar_mapping_field(field: ModelField) -> bool:
if (field.shape in mapping_shapes) and not lenient_issubclass( # type: ignore[attr-defined]
field.type_, BaseModel
):
if field.sub_fields is not None: # type: ignore[attr-defined]
for sub_field in field.sub_fields: # type: ignore[attr-defined]
if not is_scalar_field(sub_field):
return False
return True
return False
def _model_rebuild(model: Type[BaseModel]) -> None:
model.update_forward_refs()
@ -277,6 +308,14 @@ def is_scalar_sequence_field(field: ModelField) -> bool:
return is_pv1_scalar_sequence_field(field)
def is_scalar_mapping_field(field: ModelField) -> bool:
return is_pv1_scalar_mapping_field(field)
def is_scalar_sequence_mapping_field(field: ModelField) -> bool:
return is_pv1_scalar_sequence_mapping_field(field)
def is_bytes_field(field: ModelField) -> bool:
return lenient_issubclass(field.type_, bytes) # type: ignore[no-any-return]

View File

@ -352,6 +352,16 @@ def is_scalar_sequence_field(field: ModelField) -> bool:
return shared.field_annotation_is_scalar_sequence(field.field_info.annotation)
def is_scalar_mapping_field(field: ModelField) -> bool:
return shared.field_annotation_is_scalar_mapping(field.field_info.annotation)
def is_scalar_sequence_mapping_field(field: ModelField) -> bool:
return shared.field_annotation_is_scalar_sequence_mapping(
field.field_info.annotation
)
def is_bytes_field(field: ModelField) -> bool:
return shared.is_bytes_or_nonable_bytes_annotation(field.type_)

View File

@ -37,7 +37,9 @@ from fastapi._compat import (
is_bytes_field,
is_bytes_sequence_field,
is_scalar_field,
is_scalar_mapping_field,
is_scalar_sequence_field,
is_scalar_sequence_mapping_field,
is_sequence_field,
is_uploadfile_or_nonable_uploadfile_annotation,
is_uploadfile_sequence_annotation,
@ -512,6 +514,7 @@ def analyze_param(
assert (
is_scalar_field(field)
or is_scalar_sequence_field(field)
or is_scalar_sequence_mapping_field(field)
or (
_is_model_class(field.type_)
# For Pydantic v1
@ -707,6 +710,20 @@ def _validate_value_with_model_field(
else:
return deepcopy(field.default), []
v_, errors_ = field.validate(value, values, loc=loc)
if (
errors_
and isinstance(field.field_info, params.Query)
and isinstance(value, Mapping)
and (is_scalar_sequence_mapping_field(field) or is_scalar_mapping_field(field))
):
# Remove failing keys from the dict and try to re-validate
invalid_keys = {err["loc"][2] for err in errors_ if len(err["loc"]) >= 3}
v_, errors_ = field.validate(
{k: v for k, v in value.items() if k not in invalid_keys},
values,
loc=loc,
)
if _is_error_wrapper(errors_): # type: ignore[arg-type]
return None, [errors_]
elif isinstance(errors_, list):
@ -720,10 +737,19 @@ def _get_multidict_value(
field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
) -> Any:
alias = alias or field.alias
value: Any = None
if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
value = values.getlist(alias)
else:
value = values.get(alias, None)
elif alias in values:
value = values[alias]
elif values and is_scalar_mapping_field(field) and isinstance(values, QueryParams):
value = dict(values)
elif (
values
and is_scalar_sequence_mapping_field(field)
and isinstance(values, QueryParams)
):
value = {key: values.getlist(key) for key in values.keys()}
if (
value is None
or (
@ -820,6 +846,13 @@ def request_params_to_args(
errors.extend(errors_)
else:
values[field.name] = v_
# remove keys which were captured by a mapping query field but were otherwise specified
for field in fields:
if isinstance(values.get(field.name), dict) and (
is_scalar_mapping_field(field) or is_scalar_sequence_mapping_field(field)
):
for f_ in fields:
values[field.name].pop(f_.alias, None)
return values, errors

View File

@ -191,12 +191,12 @@ def get_query_param_required_type(query: int = Query()):
@app.get("/query/mapping-params")
def get_mapping_query_params(queries: Dict[str, str] = Query({})):
return f"foo bar {queries['foo']} {queries['bar']}"
return {"queries": queries}
@app.get("/query/mapping-sequence-params")
def get_sequence_mapping_query_params(queries: Dict[str, List[int]] = Query({})):
return f"foo bar {dict(queries)}"
return {"queries": queries}
@app.get("/query/mixed-params")
@ -205,10 +205,13 @@ def get_mixed_mapping_query_params(
mapping_query: Dict[str, str] = Query(),
query: str = Query(),
):
return (
f"foo bar {sequence_mapping_queries['foo'][0]} {sequence_mapping_queries['foo'][1]} "
f"{mapping_query['foo']} {mapping_query['bar']} {query}"
)
return {
"queries": {
"query": query,
"mapping_query": mapping_query,
"sequence_mapping_queries": sequence_mapping_queries,
}
}
@app.get("/query/mixed-type-params")
@ -218,7 +221,14 @@ def get_mixed_mapping_mixed_type_query_params(
mapping_query_int: Dict[str, int] = Query({}),
query: int = Query(),
):
return f"foo bar {query} {mapping_query_str} {mapping_query_int} {dict(sequence_mapping_queries)}"
return {
"queries": {
"query": query,
"mapping_query_str": mapping_query_str,
"mapping_query_int": mapping_query_int,
"sequence_mapping_queries": sequence_mapping_queries,
}
}
@app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED)

View File

@ -424,31 +424,44 @@ def test_query_frozenset_query_1_query_1_query_2():
def test_mapping_query():
response = client.get("/query/mapping-params/?foo=fuzz&bar=buzz")
assert response.status_code == 200
assert response.json() == "foo bar fuzz buzz"
assert response.json() == {"queries": {"bar": "buzz", "foo": "fuzz"}}
def test_mapping_with_non_mapping_query():
response = client.get("/query/mixed-params/?foo=fuzz&foo=baz&bar=buzz&query=fizz")
assert response.status_code == 200
assert response.json() == "foo bar fuzz baz baz buzz fizz"
assert response.json() == {
"queries": {
"query": "fizz",
"mapping_query": {"foo": "baz", "bar": "buzz"},
"sequence_mapping_queries": {
"foo": ["fuzz", "baz"],
"bar": ["buzz"],
},
}
}
def test_mapping_with_non_mapping_query_mixed_types():
response = client.get("/query/mixed-type-params/?foo=fuzz&foo=baz&bar=buzz&query=1")
assert response.status_code == 200
assert (
response.json()
== "foo bar 1 {'foo': 'baz', 'bar': 'buzz', 'query': '1'} {'query': 1} {'foo': [], 'bar': [], 'query': [1]}"
)
assert response.json() == {
"queries": {
"query": 1,
"mapping_query_str": {"foo": "baz", "bar": "buzz"},
"mapping_query_int": {},
"sequence_mapping_queries": {},
}
}
def test_sequence_mapping_query():
response = client.get("/query/mapping-sequence-params/?foo=1&foo=2")
assert response.status_code == 200
assert response.json() == "foo bar {'foo': [1, 2]}"
assert response.json() == {"queries": {"foo": [1, 2]}}
def test_sequence_mapping_query_drops_invalid():
response = client.get("/query/mapping-sequence-params/?foo=fuzz&foo=buzz")
assert response.status_code == 200
assert response.json() == "foo bar {'foo': []}"
assert response.json() == {"queries": {}}

View File

@ -15,7 +15,7 @@ def test_foo_needy_very(client: TestClient):
assert response.status_code == 200
assert response.json() == {
"query": 2,
"string_mapping": {"query": "2", "foo": "baz"},
"mapping_query_int": {"query": 2},
"sequence_mapping_queries": {"query": [1, 2], "foo": []},
"string_mapping": {"foo": "baz"},
"mapping_query_int": {},
"sequence_mapping_queries": {},
}