rm omit by default fn

This commit is contained in:
JONEMI21 2025-11-10 07:08:41 +00:00
parent 62344d272d
commit 076a8cf3fa
10 changed files with 141 additions and 52 deletions

View File

@ -189,7 +189,7 @@ You could also use `Enum`s the same way as with [Path Parameters](path-params.md
## Free Form Query Parameters
Sometimes you want to receive some query parameters, but you don't know in advance what they are called. **FastAPI** provides support for this use case.
Sometimes you want to receive some query parameters, but you don't know in advance what they are called. **FastAPI** provides support for this use case as well.
=== "Python 3.10+"

View File

@ -10,10 +10,10 @@ app = FastAPI()
def get_mixed_mapping_mixed_type_query_params(
query: Annotated[int, Query()] = None,
mapping_query_str_or_int: Annotated[
Union[Dict[str, OnErrorOmit[str]], Dict[str, int]], Query()
Union[Dict[str, OnErrorOmit[str]], Dict[str, OnErrorOmit[int]]], Query()
] = None,
mapping_query_int: Annotated[Dict[str, int], Query()] = None,
sequence_mapping_int: Annotated[Dict[str, List[int]], Query()] = None,
mapping_query_int: Annotated[Dict[str, OnErrorOmit[int]], Query()] = None,
sequence_mapping_int: Annotated[Dict[str, List[OnErrorOmit[int]]], Query()] = None,
):
return {
"query": query,

View File

@ -28,7 +28,6 @@ from .main import is_scalar_mapping_field as is_scalar_mapping_field
from .main import is_scalar_sequence_field as is_scalar_sequence_field
from .main import is_scalar_sequence_mapping_field as is_scalar_sequence_mapping_field
from .main import is_sequence_field as is_sequence_field
from .main import omit_by_default as omit_by_default
from .main import serialize_sequence_value as serialize_sequence_value
from .main import (
with_info_plain_validator_function as with_info_plain_validator_function,
@ -42,6 +41,9 @@ from .shared import PYDANTIC_V2 as PYDANTIC_V2
from .shared import PYDANTIC_VERSION_MINOR_TUPLE as PYDANTIC_VERSION_MINOR_TUPLE
from .shared import annotation_is_pydantic_v1 as annotation_is_pydantic_v1
from .shared import field_annotation_is_scalar as field_annotation_is_scalar
from .shared import (
field_annotation_is_scalar_mapping as field_annotation_is_scalar_mapping,
)
from .shared import (
field_annotation_is_scalar_sequence_mapping as field_annotation_is_scalar_sequence_mapping,
)

View File

@ -28,7 +28,6 @@ if PYDANTIC_V2:
from .v2 import Validator as Validator
from .v2 import evaluate_forwardref as evaluate_forwardref
from .v2 import get_missing_field_error as get_missing_field_error
from .v2 import omit_by_default as omit_by_default
from .v2 import (
with_info_plain_validator_function as with_info_plain_validator_function,
)

View File

@ -17,8 +17,8 @@ from typing import (
from fastapi._compat import may_v1, shared
from fastapi.openapi.constants import REF_TEMPLATE
from fastapi.types import IncEx, ModelNameMap, UnionType
from pydantic import BaseModel, OnErrorOmit, TypeAdapter, create_model
from fastapi.types import IncEx, ModelNameMap
from pydantic import BaseModel, TypeAdapter, create_model
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation
from pydantic import ValidationError as ValidationError
@ -487,22 +487,3 @@ def get_flat_models_from_fields(
def get_long_model_name(model: TypeModelOrEnum) -> str:
return f"{model.__module__}__{model.__qualname__}".replace(".", "__")
def omit_by_default(annotation: Any) -> Any:
# Update the annotation to use OnErrorOmit for the inner type(s)
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
new_args = []
for arg in get_args(annotation):
new_arg = omit_by_default(arg)
new_args.append(new_arg)
return Union[tuple(new_args)] # type: ignore[return-value]
elif origin is Annotated:
annotated_args = get_args(annotation)
base_annotation = annotated_args[0]
new_base_annotation = omit_by_default(base_annotation)
new_metadata = annotated_args[1:]
return Annotated[new_base_annotation + new_metadata] # type: ignore[return-value]
else:
return OnErrorOmit[annotation] # type: ignore[return-value]

View File

@ -31,7 +31,6 @@ from fastapi._compat import (
create_body_model,
evaluate_forwardref,
field_annotation_is_scalar,
field_annotation_is_scalar_sequence_mapping,
get_annotation_from_field_info,
get_cached_model_fields,
get_missing_field_error,
@ -46,7 +45,6 @@ from fastapi._compat import (
is_uploadfile_sequence_annotation,
lenient_issubclass,
may_v1,
omit_by_default,
sequence_types,
serialize_sequence_value,
value_is_sequence,
@ -488,11 +486,6 @@ def analyze_param(
):
field_info.in_ = params.ParamTypes.query
if isinstance(
field_info, (params.Query, temp_pydantic_v1_params.Query)
) and field_annotation_is_scalar_sequence_mapping(use_annotation):
use_annotation = omit_by_default(use_annotation)
use_annotation_from_field_info = get_annotation_from_field_info(
use_annotation,
field_info,
@ -525,6 +518,7 @@ def analyze_param(
is_scalar_field(field)
or is_scalar_sequence_field(field)
or is_scalar_sequence_mapping_field(field)
or is_scalar_mapping_field(field)
or (
_is_model_class(field.type_)
# For Pydantic v1

View File

@ -193,15 +193,11 @@ def get_query_param_required_type(query: int = Query()):
def get_mapping_query_params(queries: Dict[str, str] = Query({})):
return {"queries": queries}
@app.get("/query/mapping-sequence-params")
def get_sequence_mapping_query_params(queries: Dict[str, List[int]] = Query({})):
return {"queries": queries}
from pydantic import OnErrorOmit
@app.get("/query/mixed-params")
def get_mixed_mapping_query_params(
sequence_mapping_queries: Dict[str, List[Union[str, int]]] = Query({}),
sequence_mapping_queries: Dict[str, List[Union[str, OnErrorOmit[int]]]] = Query({}),
mapping_query: Dict[str, str] = Query(),
query: str = Query(),
):
@ -213,12 +209,16 @@ def get_mixed_mapping_query_params(
}
}
@app.get("/query/mapping-sequence-params")
def get_sequence_mapping_query_params(queries: Dict[str, List[OnErrorOmit[int]]] = Query({})):
return {"queries": queries}
@app.get("/query/mixed-type-params")
def get_mixed_mapping_mixed_type_query_params(
sequence_mapping_queries: Dict[str, List[int]] = Query({}),
mapping_query_str: Dict[str, str] = Query({}),
mapping_query_int: Dict[str, int] = Query({}),
sequence_mapping_queries: Dict[str, List[OnErrorOmit[int]]] = Query({}),
mapping_query_str: Dict[str, OnErrorOmit[str]] = Query({}),
mapping_query_int: Dict[str, OnErrorOmit[int]] = Query({}),
query: int = Query(),
):
return {

View File

@ -0,0 +1,111 @@
import sys
from typing import Dict, List, Optional, Union
import pytest
from pydantic import OnErrorOmit
from typing_extensions import Annotated
def omit_by_default(annotation):
"""A simplified version of the omit_by_default function for testing purposes."""
origin = getattr(annotation, "__origin__", None)
args = getattr(annotation, "__args__", ())
if origin is Annotated:
new_args = (omit_by_default(args[0]),) + args[1:]
return Annotated[new_args[0], *new_args[1:]]
elif origin is Union:
new_args = tuple(omit_by_default(arg) for arg in args)
return Union[new_args]
elif origin in (list, List):
return List[omit_by_default(args[0])]
elif origin in (dict, Dict):
return Dict[args[0], omit_by_default(args[1])]
else:
return OnErrorOmit[annotation]
def test_omit_by_default_simple_type():
result = omit_by_default(int)
assert result == OnErrorOmit[int]
def test_omit_by_default_union():
result = omit_by_default(Union[int, str])
assert result == Union[OnErrorOmit[int], OnErrorOmit[str]]
def test_omit_by_default_optional():
result = omit_by_default(Optional[int])
assert result == Union[OnErrorOmit[int], OnErrorOmit[type(None)]]
def test_omit_by_default_annotated():
result = omit_by_default(Annotated[int, "metadata"])
origin = result.__origin__ if hasattr(result, "__origin__") else None
assert origin is Annotated
args = result.__args__ if hasattr(result, "__args__") else ()
assert len(args) == 2
assert args[0] == OnErrorOmit[int]
assert args[1] == "metadata"
def test_omit_by_default_annotated_union():
result = omit_by_default(Annotated[Union[int, str], "metadata"])
origin = result.__origin__ if hasattr(result, "__origin__") else None
assert origin is Annotated
args = result.__args__ if hasattr(result, "__args__") else ()
assert len(args) == 2
assert args[0] == Union[OnErrorOmit[int], OnErrorOmit[str]]
assert args[1] == "metadata"
def test_omit_by_default_list():
result = omit_by_default(List[int])
assert result == List[OnErrorOmit[int]]
def test_omit_by_default_dict():
result = omit_by_default(Dict[str, int])
assert result == Dict[str, OnErrorOmit[int]]
def test_omit_by_default_nested_union():
result = omit_by_default(Union[int, Union[str, float]])
assert result == Union[OnErrorOmit[int], OnErrorOmit[Union[str, float]]]
def test_omit_by_default_annotated_with_multiple_metadata():
result = omit_by_default(Annotated[str, "meta1", "meta2"])
origin = result.__origin__ if hasattr(result, "__origin__") else None
assert origin is Annotated
args = result.__args__ if hasattr(result, "__args__") else ()
assert len(args) == 3
assert args[0] == OnErrorOmit[str]
assert args[1] == "meta1"
assert args[2] == "meta2"
@pytest.mark.skipif(
sys.version_info < (3, 10), reason="Union type syntax requires Python 3.10+"
)
def test_omit_by_default_pipe_union():
annotation = eval("int | str")
result = omit_by_default(annotation)
assert result == Union[OnErrorOmit[int], OnErrorOmit[str]]
def test_omit_by_default_complex_nested():
result = omit_by_default(Annotated[Union[int, Optional[str]], "metadata"])
origin = result.__origin__ if hasattr(result, "__origin__") else None
assert origin is Annotated
args = result.__args__ if hasattr(result, "__args__") else ()
assert len(args) == 2
expected_union = Union[OnErrorOmit[int], OnErrorOmit[Union[str, type(None)]]]
assert args[0] == expected_union
assert args[1] == "metadata"
def test_omit_by_default_dict_with_union_value():
result = omit_by_default(Dict[str, Union[int, str]])
assert result == Dict[str, Union[OnErrorOmit[int], OnErrorOmit[str]]]

View File

@ -426,6 +426,10 @@ def test_mapping_query():
assert response.status_code == 200
assert response.json() == {"queries": {"bar": "buzz", "foo": "fuzz"}}
def test_sequence_mapping_query():
response = client.get("/query/mapping-sequence-params/?foo=1&foo=2")
assert response.status_code == 200
assert response.json() == {"queries": {"foo": [1, 2]}}
def test_mapping_with_non_mapping_query():
response = client.get("/query/mixed-params/?foo=fuzz&foo=baz&bar=buzz&query=fizz")
@ -448,20 +452,14 @@ def test_mapping_with_non_mapping_query_mixed_types():
assert response.json() == {
"queries": {
"query": 1,
"mapping_query_str": {"foo": "baz", "bar": "buzz"},
"mapping_query_int": {},
"sequence_mapping_queries": {},
"mapping_query_str": {"bar": "buzz", "foo": "baz"},
"sequence_mapping_queries": {"bar": [], "foo": []},
}
}
def test_sequence_mapping_query():
response = client.get("/query/mapping-sequence-params/?foo=1&foo=2")
assert response.status_code == 200
assert response.json() == {"queries": {"foo": [1, 2]}}
def test_sequence_mapping_query_drops_invalid():
response = client.get("/query/mapping-sequence-params/?foo=fuzz&foo=buzz")
assert response.status_code == 200
assert response.json() == {"queries": {}}
assert response.json() == {"queries": {"foo": []}}

View File

@ -1,6 +1,10 @@
import pytest
from fastapi._compat import PYDANTIC_V2
from fastapi.testclient import TestClient
if not PYDANTIC_V2:
pytest.skip("This test is only for Pydantic v2", allow_module_level=True)
@pytest.fixture(name="client")
def get_client():