mirror of https://github.com/tiangolo/fastapi.git
omit by default query params
This commit is contained in:
parent
3847b353ce
commit
45825d7d11
|
|
@ -28,6 +28,7 @@ from .main import is_scalar_mapping_field as is_scalar_mapping_field
|
|||
from .main import is_scalar_sequence_field as is_scalar_sequence_field
|
||||
from .main import is_scalar_sequence_mapping_field as is_scalar_sequence_mapping_field
|
||||
from .main import is_sequence_field as is_sequence_field
|
||||
from .main import omit_by_default as omit_by_default
|
||||
from .main import serialize_sequence_value as serialize_sequence_value
|
||||
from .main import (
|
||||
with_info_plain_validator_function as with_info_plain_validator_function,
|
||||
|
|
|
|||
|
|
@ -384,3 +384,22 @@ def _is_model_class(value: Any) -> bool:
|
|||
|
||||
return lenient_issubclass(value, v2.BaseModel) # type: ignore[attr-defined]
|
||||
return False
|
||||
|
||||
|
||||
def omit_by_default(annotation):
|
||||
from typing import Union
|
||||
|
||||
from pydantic import OnErrorOmit
|
||||
|
||||
origin = getattr(annotation, "__origin__", None)
|
||||
args = getattr(annotation, "__args__", ())
|
||||
|
||||
if origin is Union:
|
||||
new_args = tuple(omit_by_default(arg) for arg in args)
|
||||
return Union[new_args]
|
||||
elif origin in (list, List):
|
||||
return List[omit_by_default(args[0])]
|
||||
elif origin in (dict, Dict):
|
||||
return Dict[args[0], omit_by_default(args[1])]
|
||||
else:
|
||||
return OnErrorOmit[annotation]
|
||||
|
|
|
|||
|
|
@ -45,11 +45,16 @@ from fastapi._compat import (
|
|||
is_uploadfile_sequence_annotation,
|
||||
lenient_issubclass,
|
||||
may_v1,
|
||||
omit_by_default,
|
||||
sequence_types,
|
||||
serialize_sequence_value,
|
||||
value_is_sequence,
|
||||
)
|
||||
from fastapi._compat.shared import annotation_is_pydantic_v1
|
||||
from fastapi._compat.shared import (
|
||||
annotation_is_pydantic_v1,
|
||||
field_annotation_is_scalar_mapping,
|
||||
field_annotation_is_scalar_sequence_mapping,
|
||||
)
|
||||
from fastapi.background import BackgroundTasks
|
||||
from fastapi.concurrency import (
|
||||
asynccontextmanager,
|
||||
|
|
@ -500,6 +505,12 @@ def analyze_param(
|
|||
|
||||
field_info.alias = alias
|
||||
|
||||
if hasattr(field_info, "annotation") and (
|
||||
field_annotation_is_scalar_sequence_mapping(field_info.annotation)
|
||||
or field_annotation_is_scalar_mapping(field_info.annotation)
|
||||
):
|
||||
field_info.annotation = omit_by_default(field_info.annotation)
|
||||
|
||||
field = create_model_field(
|
||||
name=param_name,
|
||||
type_=use_annotation_from_field_info,
|
||||
|
|
@ -833,7 +844,8 @@ def request_params_to_args(
|
|||
errors.extend(errors_)
|
||||
else:
|
||||
values[field.name] = v_
|
||||
# remove keys which were captured by a mapping query field but were otherwise specified
|
||||
# remove keys which were captured by a mapping query field but were
|
||||
# specified as individual fields
|
||||
for field in fields:
|
||||
if isinstance(values.get(field.name), dict) and (
|
||||
is_scalar_mapping_field(field) or is_scalar_sequence_mapping_field(field)
|
||||
|
|
|
|||
|
|
@ -194,13 +194,10 @@ def get_mapping_query_params(queries: Dict[str, str] = Query({})):
|
|||
return {"queries": queries}
|
||||
|
||||
|
||||
from pydantic import OnErrorOmit
|
||||
|
||||
|
||||
@app.get("/query/mixed-params")
|
||||
def get_mixed_mapping_query_params(
|
||||
sequence_mapping_queries: Dict[str, List[Union[str, OnErrorOmit[int]]]] = Query({}),
|
||||
mapping_query: Dict[str, str] = Query(),
|
||||
sequence_mapping_queries: Dict[str, List[Union[int]]] = Query({}),
|
||||
mapping_query: Dict[str, int] = Query(),
|
||||
query: str = Query(),
|
||||
):
|
||||
return {
|
||||
|
|
@ -213,29 +210,10 @@ def get_mixed_mapping_query_params(
|
|||
|
||||
|
||||
@app.get("/query/mapping-sequence-params")
|
||||
def get_sequence_mapping_query_params(
|
||||
queries: Dict[str, List[OnErrorOmit[int]]] = Query({}),
|
||||
):
|
||||
def get_sequence_mapping_query_params(queries: Dict[str, List[int]] = Query({})):
|
||||
return {"queries": queries}
|
||||
|
||||
|
||||
@app.get("/query/mixed-type-params")
|
||||
def get_mixed_mapping_mixed_type_query_params(
|
||||
sequence_mapping_queries: Dict[str, List[OnErrorOmit[int]]] = Query({}),
|
||||
mapping_query_str: Dict[str, OnErrorOmit[str]] = Query({}),
|
||||
mapping_query_int: Dict[str, OnErrorOmit[int]] = Query({}),
|
||||
query: int = Query(),
|
||||
):
|
||||
return {
|
||||
"queries": {
|
||||
"query": query,
|
||||
"mapping_query_str": mapping_query_str,
|
||||
"mapping_query_int": mapping_query_int,
|
||||
"sequence_mapping_queries": sequence_mapping_queries,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED)
|
||||
def get_enum_status_code():
|
||||
return "foo bar"
|
||||
|
|
|
|||
|
|
@ -1,111 +0,0 @@
|
|||
import sys
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import pytest
|
||||
from pydantic import OnErrorOmit
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
def omit_by_default(annotation):
|
||||
"""A simplified version of the omit_by_default function for testing purposes."""
|
||||
origin = getattr(annotation, "__origin__", None)
|
||||
args = getattr(annotation, "__args__", ())
|
||||
|
||||
if origin is Annotated:
|
||||
new_args = (omit_by_default(args[0]),) + args[1:]
|
||||
return Annotated[new_args[0], *new_args[1:]]
|
||||
elif origin is Union:
|
||||
new_args = tuple(omit_by_default(arg) for arg in args)
|
||||
return Union[new_args]
|
||||
elif origin in (list, List):
|
||||
return List[omit_by_default(args[0])]
|
||||
elif origin in (dict, Dict):
|
||||
return Dict[args[0], omit_by_default(args[1])]
|
||||
else:
|
||||
return OnErrorOmit[annotation]
|
||||
|
||||
|
||||
def test_omit_by_default_simple_type():
|
||||
result = omit_by_default(int)
|
||||
assert result == OnErrorOmit[int]
|
||||
|
||||
|
||||
def test_omit_by_default_union():
|
||||
result = omit_by_default(Union[int, str])
|
||||
assert result == Union[OnErrorOmit[int], OnErrorOmit[str]]
|
||||
|
||||
|
||||
def test_omit_by_default_optional():
|
||||
result = omit_by_default(Optional[int])
|
||||
assert result == Union[OnErrorOmit[int], OnErrorOmit[type(None)]]
|
||||
|
||||
|
||||
def test_omit_by_default_annotated():
|
||||
result = omit_by_default(Annotated[int, "metadata"])
|
||||
origin = result.__origin__ if hasattr(result, "__origin__") else None
|
||||
assert origin is Annotated
|
||||
args = result.__args__ if hasattr(result, "__args__") else ()
|
||||
assert len(args) == 2
|
||||
assert args[0] == OnErrorOmit[int]
|
||||
assert args[1] == "metadata"
|
||||
|
||||
|
||||
def test_omit_by_default_annotated_union():
|
||||
result = omit_by_default(Annotated[Union[int, str], "metadata"])
|
||||
origin = result.__origin__ if hasattr(result, "__origin__") else None
|
||||
assert origin is Annotated
|
||||
args = result.__args__ if hasattr(result, "__args__") else ()
|
||||
assert len(args) == 2
|
||||
assert args[0] == Union[OnErrorOmit[int], OnErrorOmit[str]]
|
||||
assert args[1] == "metadata"
|
||||
|
||||
|
||||
def test_omit_by_default_list():
|
||||
result = omit_by_default(List[int])
|
||||
assert result == List[OnErrorOmit[int]]
|
||||
|
||||
|
||||
def test_omit_by_default_dict():
|
||||
result = omit_by_default(Dict[str, int])
|
||||
assert result == Dict[str, OnErrorOmit[int]]
|
||||
|
||||
|
||||
def test_omit_by_default_nested_union():
|
||||
result = omit_by_default(Union[int, Union[str, float]])
|
||||
assert result == Union[OnErrorOmit[int], OnErrorOmit[Union[str, float]]]
|
||||
|
||||
|
||||
def test_omit_by_default_annotated_with_multiple_metadata():
|
||||
result = omit_by_default(Annotated[str, "meta1", "meta2"])
|
||||
origin = result.__origin__ if hasattr(result, "__origin__") else None
|
||||
assert origin is Annotated
|
||||
args = result.__args__ if hasattr(result, "__args__") else ()
|
||||
assert len(args) == 3
|
||||
assert args[0] == OnErrorOmit[str]
|
||||
assert args[1] == "meta1"
|
||||
assert args[2] == "meta2"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 10), reason="Union type syntax requires Python 3.10+"
|
||||
)
|
||||
def test_omit_by_default_pipe_union():
|
||||
annotation = eval("int | str")
|
||||
result = omit_by_default(annotation)
|
||||
assert result == Union[OnErrorOmit[int], OnErrorOmit[str]]
|
||||
|
||||
|
||||
def test_omit_by_default_complex_nested():
|
||||
result = omit_by_default(Annotated[Union[int, Optional[str]], "metadata"])
|
||||
origin = result.__origin__ if hasattr(result, "__origin__") else None
|
||||
assert origin is Annotated
|
||||
args = result.__args__ if hasattr(result, "__args__") else ()
|
||||
assert len(args) == 2
|
||||
expected_union = Union[OnErrorOmit[int], OnErrorOmit[Union[str, type(None)]]]
|
||||
assert args[0] == expected_union
|
||||
assert args[1] == "metadata"
|
||||
|
||||
|
||||
def test_omit_by_default_dict_with_union_value():
|
||||
result = omit_by_default(Dict[str, Union[int, str]])
|
||||
assert result == Dict[str, Union[OnErrorOmit[int], OnErrorOmit[str]]]
|
||||
|
|
@ -434,34 +434,15 @@ def test_sequence_mapping_query():
|
|||
|
||||
|
||||
def test_mapping_with_non_mapping_query():
|
||||
response = client.get("/query/mixed-params/?foo=fuzz&foo=baz&bar=buzz&query=fizz")
|
||||
response = client.get("/query/mixed-params/?foo=1&foo=2&bar=3&query=fizz")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"queries": {
|
||||
"query": "fizz",
|
||||
"mapping_query": {"foo": "baz", "bar": "buzz"},
|
||||
"mapping_query": {"foo": 2, "bar": 3},
|
||||
"sequence_mapping_queries": {
|
||||
"foo": ["fuzz", "baz"],
|
||||
"bar": ["buzz"],
|
||||
"foo": [1, 2],
|
||||
"bar": [3],
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_mapping_with_non_mapping_query_mixed_types():
|
||||
response = client.get("/query/mixed-type-params/?foo=fuzz&foo=baz&bar=buzz&query=1")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"queries": {
|
||||
"query": 1,
|
||||
"mapping_query_int": {},
|
||||
"mapping_query_str": {"bar": "buzz", "foo": "baz"},
|
||||
"sequence_mapping_queries": {"bar": [], "foo": []},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_sequence_mapping_query_drops_invalid():
|
||||
response = client.get("/query/mapping-sequence-params/?foo=fuzz&foo=buzz")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"queries": {"foo": []}}
|
||||
|
|
|
|||
Loading…
Reference in New Issue