omit by default query params

This commit is contained in:
JONEMI21 2025-11-10 08:47:21 +00:00
parent 3847b353ce
commit 45825d7d11
6 changed files with 41 additions and 161 deletions

View File

@ -28,6 +28,7 @@ from .main import is_scalar_mapping_field as is_scalar_mapping_field
from .main import is_scalar_sequence_field as is_scalar_sequence_field
from .main import is_scalar_sequence_mapping_field as is_scalar_sequence_mapping_field
from .main import is_sequence_field as is_sequence_field
from .main import omit_by_default as omit_by_default
from .main import serialize_sequence_value as serialize_sequence_value
from .main import (
with_info_plain_validator_function as with_info_plain_validator_function,

View File

@ -384,3 +384,22 @@ def _is_model_class(value: Any) -> bool:
return lenient_issubclass(value, v2.BaseModel) # type: ignore[attr-defined]
return False
def omit_by_default(annotation):
from typing import Union
from pydantic import OnErrorOmit
origin = getattr(annotation, "__origin__", None)
args = getattr(annotation, "__args__", ())
if origin is Union:
new_args = tuple(omit_by_default(arg) for arg in args)
return Union[new_args]
elif origin in (list, List):
return List[omit_by_default(args[0])]
elif origin in (dict, Dict):
return Dict[args[0], omit_by_default(args[1])]
else:
return OnErrorOmit[annotation]

View File

@ -45,11 +45,16 @@ from fastapi._compat import (
is_uploadfile_sequence_annotation,
lenient_issubclass,
may_v1,
omit_by_default,
sequence_types,
serialize_sequence_value,
value_is_sequence,
)
from fastapi._compat.shared import annotation_is_pydantic_v1
from fastapi._compat.shared import (
annotation_is_pydantic_v1,
field_annotation_is_scalar_mapping,
field_annotation_is_scalar_sequence_mapping,
)
from fastapi.background import BackgroundTasks
from fastapi.concurrency import (
asynccontextmanager,
@ -500,6 +505,12 @@ def analyze_param(
field_info.alias = alias
if hasattr(field_info, "annotation") and (
field_annotation_is_scalar_sequence_mapping(field_info.annotation)
or field_annotation_is_scalar_mapping(field_info.annotation)
):
field_info.annotation = omit_by_default(field_info.annotation)
field = create_model_field(
name=param_name,
type_=use_annotation_from_field_info,
@ -833,7 +844,8 @@ def request_params_to_args(
errors.extend(errors_)
else:
values[field.name] = v_
# remove keys which were captured by a mapping query field but were otherwise specified
# remove keys which were captured by a mapping query field but were
# specified as individual fields
for field in fields:
if isinstance(values.get(field.name), dict) and (
is_scalar_mapping_field(field) or is_scalar_sequence_mapping_field(field)

View File

@ -194,13 +194,10 @@ def get_mapping_query_params(queries: Dict[str, str] = Query({})):
return {"queries": queries}
from pydantic import OnErrorOmit
@app.get("/query/mixed-params")
def get_mixed_mapping_query_params(
sequence_mapping_queries: Dict[str, List[Union[str, OnErrorOmit[int]]]] = Query({}),
mapping_query: Dict[str, str] = Query(),
sequence_mapping_queries: Dict[str, List[Union[int]]] = Query({}),
mapping_query: Dict[str, int] = Query(),
query: str = Query(),
):
return {
@ -213,29 +210,10 @@ def get_mixed_mapping_query_params(
@app.get("/query/mapping-sequence-params")
def get_sequence_mapping_query_params(
queries: Dict[str, List[OnErrorOmit[int]]] = Query({}),
):
def get_sequence_mapping_query_params(queries: Dict[str, List[int]] = Query({})):
return {"queries": queries}
@app.get("/query/mixed-type-params")
def get_mixed_mapping_mixed_type_query_params(
sequence_mapping_queries: Dict[str, List[OnErrorOmit[int]]] = Query({}),
mapping_query_str: Dict[str, OnErrorOmit[str]] = Query({}),
mapping_query_int: Dict[str, OnErrorOmit[int]] = Query({}),
query: int = Query(),
):
return {
"queries": {
"query": query,
"mapping_query_str": mapping_query_str,
"mapping_query_int": mapping_query_int,
"sequence_mapping_queries": sequence_mapping_queries,
}
}
@app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED)
def get_enum_status_code():
return "foo bar"

View File

@ -1,111 +0,0 @@
import sys
from typing import Dict, List, Optional, Union
import pytest
from pydantic import OnErrorOmit
from typing_extensions import Annotated
def omit_by_default(annotation):
"""A simplified version of the omit_by_default function for testing purposes."""
origin = getattr(annotation, "__origin__", None)
args = getattr(annotation, "__args__", ())
if origin is Annotated:
new_args = (omit_by_default(args[0]),) + args[1:]
return Annotated[new_args[0], *new_args[1:]]
elif origin is Union:
new_args = tuple(omit_by_default(arg) for arg in args)
return Union[new_args]
elif origin in (list, List):
return List[omit_by_default(args[0])]
elif origin in (dict, Dict):
return Dict[args[0], omit_by_default(args[1])]
else:
return OnErrorOmit[annotation]
def test_omit_by_default_simple_type():
result = omit_by_default(int)
assert result == OnErrorOmit[int]
def test_omit_by_default_union():
result = omit_by_default(Union[int, str])
assert result == Union[OnErrorOmit[int], OnErrorOmit[str]]
def test_omit_by_default_optional():
result = omit_by_default(Optional[int])
assert result == Union[OnErrorOmit[int], OnErrorOmit[type(None)]]
def test_omit_by_default_annotated():
result = omit_by_default(Annotated[int, "metadata"])
origin = result.__origin__ if hasattr(result, "__origin__") else None
assert origin is Annotated
args = result.__args__ if hasattr(result, "__args__") else ()
assert len(args) == 2
assert args[0] == OnErrorOmit[int]
assert args[1] == "metadata"
def test_omit_by_default_annotated_union():
result = omit_by_default(Annotated[Union[int, str], "metadata"])
origin = result.__origin__ if hasattr(result, "__origin__") else None
assert origin is Annotated
args = result.__args__ if hasattr(result, "__args__") else ()
assert len(args) == 2
assert args[0] == Union[OnErrorOmit[int], OnErrorOmit[str]]
assert args[1] == "metadata"
def test_omit_by_default_list():
result = omit_by_default(List[int])
assert result == List[OnErrorOmit[int]]
def test_omit_by_default_dict():
result = omit_by_default(Dict[str, int])
assert result == Dict[str, OnErrorOmit[int]]
def test_omit_by_default_nested_union():
result = omit_by_default(Union[int, Union[str, float]])
assert result == Union[OnErrorOmit[int], OnErrorOmit[Union[str, float]]]
def test_omit_by_default_annotated_with_multiple_metadata():
result = omit_by_default(Annotated[str, "meta1", "meta2"])
origin = result.__origin__ if hasattr(result, "__origin__") else None
assert origin is Annotated
args = result.__args__ if hasattr(result, "__args__") else ()
assert len(args) == 3
assert args[0] == OnErrorOmit[str]
assert args[1] == "meta1"
assert args[2] == "meta2"
@pytest.mark.skipif(
sys.version_info < (3, 10), reason="Union type syntax requires Python 3.10+"
)
def test_omit_by_default_pipe_union():
annotation = eval("int | str")
result = omit_by_default(annotation)
assert result == Union[OnErrorOmit[int], OnErrorOmit[str]]
def test_omit_by_default_complex_nested():
result = omit_by_default(Annotated[Union[int, Optional[str]], "metadata"])
origin = result.__origin__ if hasattr(result, "__origin__") else None
assert origin is Annotated
args = result.__args__ if hasattr(result, "__args__") else ()
assert len(args) == 2
expected_union = Union[OnErrorOmit[int], OnErrorOmit[Union[str, type(None)]]]
assert args[0] == expected_union
assert args[1] == "metadata"
def test_omit_by_default_dict_with_union_value():
result = omit_by_default(Dict[str, Union[int, str]])
assert result == Dict[str, Union[OnErrorOmit[int], OnErrorOmit[str]]]

View File

@ -434,34 +434,15 @@ def test_sequence_mapping_query():
def test_mapping_with_non_mapping_query():
response = client.get("/query/mixed-params/?foo=fuzz&foo=baz&bar=buzz&query=fizz")
response = client.get("/query/mixed-params/?foo=1&foo=2&bar=3&query=fizz")
assert response.status_code == 200
assert response.json() == {
"queries": {
"query": "fizz",
"mapping_query": {"foo": "baz", "bar": "buzz"},
"mapping_query": {"foo": 2, "bar": 3},
"sequence_mapping_queries": {
"foo": ["fuzz", "baz"],
"bar": ["buzz"],
"foo": [1, 2],
"bar": [3],
},
}
}
def test_mapping_with_non_mapping_query_mixed_types():
response = client.get("/query/mixed-type-params/?foo=fuzz&foo=baz&bar=buzz&query=1")
assert response.status_code == 200
assert response.json() == {
"queries": {
"query": 1,
"mapping_query_int": {},
"mapping_query_str": {"bar": "buzz", "foo": "baz"},
"sequence_mapping_queries": {"bar": [], "foo": []},
}
}
def test_sequence_mapping_query_drops_invalid():
response = client.get("/query/mapping-sequence-params/?foo=fuzz&foo=buzz")
assert response.status_code == 200
assert response.json() == {"queries": {"foo": []}}