diff --git a/docs/en/docs/img/tutorial/query-params/image01.png b/docs/en/docs/img/tutorial/query-params/image01.png new file mode 100644 index 0000000000..8fa5f85b70 Binary files /dev/null and b/docs/en/docs/img/tutorial/query-params/image01.png differ diff --git a/docs/en/docs/tutorial/query-params.md b/docs/en/docs/tutorial/query-params.md index 3c9c225fb0..e0532870b4 100644 --- a/docs/en/docs/tutorial/query-params.md +++ b/docs/en/docs/tutorial/query-params.md @@ -186,3 +186,17 @@ In this case, there are 3 query parameters: You could also use `Enum`s the same way as with [Path Parameters](path-params.md#predefined-values){.internal-link target=_blank}. /// + +## Free Form Query Parameters { #free-form-query-parameters } + +Sometimes you want to receive some query parameters, but you don't know in advance what they are called. **FastAPI** provides support for this use case as well. + +=== "Python 3.10+" + + ```Python hl_lines="8" + {!> ../../../docs_src/query_params/tutorial007_py310.py!} + ``` + +And when you open your browser at http://127.0.0.1:8000/docs, you will that OpenAPI supports this format of query parameter: + + diff --git a/docs_src/query_params/tutorial007_py310.py b/docs_src/query_params/tutorial007_py310.py new file mode 100644 index 0000000000..848dd57bb8 --- /dev/null +++ b/docs_src/query_params/tutorial007_py310.py @@ -0,0 +1,20 @@ +from typing import Annotated + +from fastapi import FastAPI, Query + +app = FastAPI() + + +@app.get("/query/mixed-type-params") +def get_mixed_mapping_mixed_type_query_params( + query: Annotated[int, Query()] = None, + mapping_query_str: Annotated[dict[str, str], Query()] = None, + mapping_query_int: Annotated[dict[str, int], Query()] = None, + sequence_mapping_int: Annotated[dict[str, list[int]], Query()] = None, +): + return { + "query": query, + "mapping_query_str": mapping_query_str, + "mapping_query_int": mapping_query_int, + "sequence_mapping_int": sequence_mapping_int, + } diff --git a/fastapi/_compat/__init__.py b/fastapi/_compat/__init__.py index 3dfaf9b712..0ea6ca32fc 100644 --- a/fastapi/_compat/__init__.py +++ b/fastapi/_compat/__init__.py @@ -2,6 +2,12 @@ from .shared import PYDANTIC_V2 as PYDANTIC_V2 from .shared import PYDANTIC_VERSION_MINOR_TUPLE as PYDANTIC_VERSION_MINOR_TUPLE from .shared import annotation_is_pydantic_v1 as annotation_is_pydantic_v1 from .shared import field_annotation_is_scalar as field_annotation_is_scalar +from .shared import ( + field_annotation_is_scalar_mapping as field_annotation_is_scalar_mapping, +) +from .shared import ( + field_annotation_is_scalar_sequence_mapping as field_annotation_is_scalar_sequence_mapping, +) from .shared import is_pydantic_v1_model_class as is_pydantic_v1_model_class from .shared import is_pydantic_v1_model_instance as is_pydantic_v1_model_instance from .shared import ( @@ -33,8 +39,11 @@ from .v2 import get_schema_from_model_field as get_schema_from_model_field from .v2 import is_bytes_field as is_bytes_field from .v2 import is_bytes_sequence_field as is_bytes_sequence_field from .v2 import is_scalar_field as is_scalar_field +from .v2 import is_scalar_mapping_field as is_scalar_mapping_field from .v2 import is_scalar_sequence_field as is_scalar_sequence_field +from .v2 import is_scalar_sequence_mapping_field as is_scalar_sequence_mapping_field from .v2 import is_sequence_field as is_sequence_field +from .v2 import omit_by_default as omit_by_default from .v2 import serialize_sequence_value as serialize_sequence_value from .v2 import ( with_info_plain_validator_function as with_info_plain_validator_function, diff --git a/fastapi/_compat/shared.py b/fastapi/_compat/shared.py index 68b9bbdf1e..51cb6691f7 100644 --- a/fastapi/_compat/shared.py +++ b/fastapi/_compat/shared.py @@ -125,6 +125,52 @@ def field_annotation_is_scalar_sequence(annotation: Union[type[Any], None]) -> b ) +def field_annotation_is_scalar_mapping( + annotation: Union[type[Any], None], +) -> bool: + origin = get_origin(annotation) + if origin is Annotated: + return field_annotation_is_scalar_mapping(get_args(annotation)[0]) + if origin is Union or origin is UnionType: + at_least_one_scalar_mapping = False + for arg in get_args(annotation): + if field_annotation_is_scalar_mapping(arg): + at_least_one_scalar_mapping = True + continue + elif not field_annotation_is_scalar(arg): + return False + return at_least_one_scalar_mapping + return lenient_issubclass(origin, Mapping) and all( + field_annotation_is_scalar(sub_annotation) + for sub_annotation in get_args(annotation) + ) + + +def field_annotation_is_scalar_sequence_mapping( + annotation: Union[type[Any], None], +) -> bool: + origin = get_origin(annotation) + if origin is Annotated: + return field_annotation_is_scalar_sequence_mapping(get_args(annotation)[0]) + if origin is Union or origin is UnionType: + at_least_one_scalar_mapping = False + for arg in get_args(annotation): + if field_annotation_is_scalar_sequence_mapping(arg): + at_least_one_scalar_mapping = True + continue + elif not ( + field_annotation_is_scalar_sequence_mapping(arg) + or field_annotation_is_scalar_mapping(arg) + ): + return False + return at_least_one_scalar_mapping + return lenient_issubclass(origin, Mapping) and all( + field_annotation_is_scalar_sequence(sub_annotation) + or field_annotation_is_scalar(sub_annotation) + for sub_annotation in get_args(annotation) + ) + + def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool: if lenient_issubclass(annotation, bytes): return True diff --git a/fastapi/_compat/v2.py b/fastapi/_compat/v2.py index dae78a32e0..1a6aabb2c7 100644 --- a/fastapi/_compat/v2.py +++ b/fastapi/_compat/v2.py @@ -5,17 +5,20 @@ from copy import copy, deepcopy from dataclasses import dataclass, is_dataclass from enum import Enum from functools import lru_cache -from typing import ( - Annotated, - Any, - Union, - cast, -) +from typing import Annotated, Any, Callable, Union, cast from fastapi._compat import shared from fastapi.openapi.constants import REF_TEMPLATE from fastapi.types import IncEx, ModelNameMap, UnionType -from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model +from pydantic import ( + BaseModel, + ConfigDict, + Field, + OnErrorOmit, + TypeAdapter, + WrapValidator, + create_model, +) from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation from pydantic import ValidationError as ValidationError @@ -411,6 +414,16 @@ def is_scalar_sequence_field(field: ModelField) -> bool: return shared.field_annotation_is_scalar_sequence(field.field_info.annotation) +def is_scalar_mapping_field(field: ModelField) -> bool: + return shared.field_annotation_is_scalar_mapping(field.field_info.annotation) + + +def is_scalar_sequence_mapping_field(field: ModelField) -> bool: + return shared.field_annotation_is_scalar_sequence_mapping( + field.field_info.annotation + ) + + def is_bytes_field(field: ModelField) -> bool: return shared.is_bytes_or_nonable_bytes_annotation(field.type_) @@ -566,3 +579,83 @@ def _regenerate_error_with_loc( ] return updated_loc_errors + + +if shared.PYDANTIC_VERSION_MINOR_TUPLE >= (2, 6): + # Omit by default for scalar mapping and scalar sequence mapping annotations + # added in Pydantic v2.6 https://github.com/pydantic/pydantic/releases/tag/v2.6.0 + def _omit_by_default(annotation: Any, depth: int = 0) -> Any: + origin = get_origin(annotation) + args = get_args(annotation) + + if (origin is Union or origin is UnionType) and depth == 0: + # making the depth check since the values of dicts being Union types + # is not working as expected as of Pydantic v2.12.3 so we just omit at + # the top level Union here for now + # https://github.com/pydantic/pydantic-core/issues/1900 + # https://github.com/pydantic/pydantic/issues/12750 + return Union[tuple(_omit_by_default(arg) for arg in args)] + elif origin is list: + return list[_omit_by_default(args[0], depth=depth + 1)] # type: ignore[misc] + elif origin is dict: + return dict[args[0], _omit_by_default(args[1], depth=depth + 1)] # type: ignore[misc,valid-type] + else: + return OnErrorOmit[annotation] # type: ignore[misc] + + def omit_by_default( + field_info: FieldInfo, + ) -> tuple[FieldInfo, dict[str, Callable[..., Any]]]: + new_annotation = _omit_by_default(field_info.annotation) + new_field_info = copy_field_info( + field_info=field_info, annotation=new_annotation + ) + return new_field_info, {} + +else: # pragma: no cover + + def ignore_invalid(v: Any, handler: Callable[[Any], Any]) -> Any: + try: + return handler(v) + except ValidationError as exc: + # pop the keys or elements that caused the validation errors and revalidate + for error in exc.errors(): + loc = error["loc"] + if len(loc) == 0: + continue + if isinstance(loc[0], int) and isinstance(v, list): + index = loc[0] + if 0 <= index < len(v): + v[index] = None + + # Handle nested list validation errors (e.g., dict[str, list[str]]) + elif isinstance(loc[0], str) and isinstance(v, dict): + key = loc[0] + if ( + len(loc) > 1 + and isinstance(loc[1], int) + and key in v + and isinstance(v[key], list) + ): + list_index = loc[1] + v[key][list_index] = None + elif key in v: + v.pop(key) + + if isinstance(v, list): + v = [el for el in v if el is not None] + + if isinstance(v, dict): + for key in v.keys(): + if isinstance(v[key], list): + v[key] = [el for el in v[key] if el is not None] + + return handler(v) + + def omit_by_default( + field_info: FieldInfo, + ) -> tuple[FieldInfo, dict[str, Callable[..., Any]]]: + """add a wrap validator to omit invalid values by default.""" + field_info.metadata = field_info.metadata or [] + [ + WrapValidator(ignore_invalid) + ] + return field_info, {} diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index fc5dfed85a..107049be2d 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -26,16 +26,21 @@ from fastapi._compat import ( create_body_model, evaluate_forwardref, field_annotation_is_scalar, + field_annotation_is_scalar_mapping, + field_annotation_is_scalar_sequence_mapping, get_cached_model_fields, get_missing_field_error, is_bytes_field, is_bytes_sequence_field, is_scalar_field, + is_scalar_mapping_field, is_scalar_sequence_field, + is_scalar_sequence_mapping_field, is_sequence_field, is_uploadfile_or_nonable_uploadfile_annotation, is_uploadfile_sequence_annotation, lenient_issubclass, + omit_by_default, sequence_types, serialize_sequence_value, value_is_sequence, @@ -502,7 +507,17 @@ def analyze_param( alias = param_name.replace("_", "-") else: alias = field_info.alias or param_name + field_info.alias = alias + + # Omit by default for scalar mapping and scalar sequence mapping query fields + class_validators: dict[str, Callable[..., Any]] = {} + if isinstance(field_info, params.Query) and ( + field_annotation_is_scalar_sequence_mapping(use_annotation_from_field_info) + or field_annotation_is_scalar_mapping(use_annotation_from_field_info) + ): + field_info, class_validators = omit_by_default(field_info) + field = create_model_field( name=param_name, type_=use_annotation_from_field_info, @@ -510,6 +525,7 @@ def analyze_param( alias=alias, required=field_info.default in (RequiredParam, Undefined), field_info=field_info, + class_validators=class_validators, ) if is_path_param: assert is_scalar_field(field=field), ( @@ -519,6 +535,8 @@ def analyze_param( assert ( is_scalar_field(field) or is_scalar_sequence_field(field) + or is_scalar_mapping_field(field) + or is_scalar_sequence_mapping_field(field) or ( lenient_issubclass(field.type_, BaseModel) # For Pydantic v1 @@ -714,6 +732,7 @@ def _validate_value_with_model_field( else: return deepcopy(field.default), [] v_, errors_ = field.validate(value, values, loc=loc) + if isinstance(errors_, list): new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=()) return None, new_errors @@ -725,10 +744,19 @@ def _get_multidict_value( field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None ) -> Any: alias = alias or get_validation_alias(field) + value: Any = None if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)): value = values.getlist(alias) - else: - value = values.get(alias, None) + elif alias in values: + value = values[alias] + elif values and is_scalar_mapping_field(field) and isinstance(values, QueryParams): + value = dict(values) + elif ( + values + and is_scalar_sequence_mapping_field(field) + and isinstance(values, QueryParams) + ): + value = {key: values.getlist(key) for key in values.keys()} if ( value is None or ( @@ -825,6 +853,14 @@ def request_params_to_args( errors.extend(errors_) else: values[field.name] = v_ + # remove keys which were captured by a mapping query field but were + # specified as individual fields + for field in fields: + if isinstance(values.get(field.name), dict) and ( + is_scalar_mapping_field(field) or is_scalar_sequence_mapping_field(field) + ): + for f_ in fields: + values[field.name].pop(f_.alias, None) return values, errors diff --git a/tests/main.py b/tests/main.py index 7edb16c615..7e291f6e86 100644 --- a/tests/main.py +++ b/tests/main.py @@ -189,6 +189,48 @@ def get_query_param_required_type(query: int = Query()): return f"foo bar {query}" +@app.get("/query/mapping-params") +def get_mapping_query_params(queries: dict[str, str] = Query({})): + return {"queries": queries} + + +@app.get("/query/mixed-params") +def get_mixed_mapping_query_params( + sequence_mapping_queries: dict[str, list[int]] = Query({}), + mapping_query: dict[str, str] = Query(), + query: str = Query(), +): + return { + "queries": { + "query": query, + "mapping_query": mapping_query, + "sequence_mapping_queries": sequence_mapping_queries, + } + } + + +@app.get("/query/mapping-sequence-params") +def get_sequence_mapping_query_params(queries: dict[str, list[int]] = Query({})): + return {"queries": queries} + + +@app.get("/query/mixed-type-params") +def get_mixed_mapping_mixed_type_query_params( + sequence_mapping_queries: dict[str, list[int]] = Query({}), + mapping_query_str: dict[str, str] = Query({}), + mapping_query_int: dict[str, int] = Query({}), + query: int = Query(), +): + return { + "queries": { + "query": query, + "mapping_query_str": mapping_query_str, + "mapping_query_int": mapping_query_int, + "sequence_mapping_queries": sequence_mapping_queries, + } + } + + @app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED) def get_enum_status_code(): return "foo bar" diff --git a/tests/test_application.py b/tests/test_application.py index fe97e674c0..9153398947 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1111,6 +1111,235 @@ def test_openapi_schema(): ], } }, + "/query/mapping-params": { + "get": { + "operationId": "get_mapping_query_params_query_mapping_params_get", + "parameters": [ + { + "in": "query", + "name": "queries", + "required": False, + "schema": { + "additionalProperties": { + "type": "string", + }, + "default": {}, + "title": "Queries", + "type": "object", + }, + }, + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": {}, + }, + }, + "description": "Successful Response", + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError", + }, + }, + }, + "description": "Validation Error", + }, + }, + "summary": "Get Mapping Query Params", + }, + }, + "/query/mapping-sequence-params": { + "get": { + "operationId": "get_sequence_mapping_query_params_query_mapping_sequence_params_get", + "parameters": [ + { + "in": "query", + "name": "queries", + "required": False, + "schema": { + "additionalProperties": { + "items": { + "type": "integer", + }, + "type": "array", + }, + "default": {}, + "title": "Queries", + "type": "object", + }, + }, + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": {}, + }, + }, + "description": "Successful Response", + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError", + }, + }, + }, + "description": "Validation Error", + }, + }, + "summary": "Get Sequence Mapping Query Params", + }, + }, + "/query/mixed-params": { + "get": { + "operationId": "get_mixed_mapping_query_params_query_mixed_params_get", + "parameters": [ + { + "in": "query", + "name": "sequence_mapping_queries", + "required": False, + "schema": { + "additionalProperties": { + "items": { + "type": "integer", + }, + "type": "array", + }, + "default": {}, + "title": "Sequence Mapping Queries", + "type": "object", + }, + }, + { + "in": "query", + "name": "mapping_query", + "required": True, + "schema": { + "additionalProperties": { + "type": "string", + }, + "title": "Mapping Query", + "type": "object", + }, + }, + { + "in": "query", + "name": "query", + "required": True, + "schema": { + "title": "Query", + "type": "string", + }, + }, + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": {}, + }, + }, + "description": "Successful Response", + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError", + }, + }, + }, + "description": "Validation Error", + }, + }, + "summary": "Get Mixed Mapping Query Params", + }, + }, + "/query/mixed-type-params": { + "get": { + "operationId": "get_mixed_mapping_mixed_type_query_params_query_mixed_type_params_get", + "parameters": [ + { + "in": "query", + "name": "sequence_mapping_queries", + "required": False, + "schema": { + "additionalProperties": { + "items": { + "type": "integer", + }, + "type": "array", + }, + "default": {}, + "title": "Sequence Mapping Queries", + "type": "object", + }, + }, + { + "in": "query", + "name": "mapping_query_str", + "required": False, + "schema": { + "additionalProperties": { + "type": "string", + }, + "default": {}, + "title": "Mapping Query Str", + "type": "object", + }, + }, + { + "in": "query", + "name": "mapping_query_int", + "required": False, + "schema": { + "additionalProperties": { + "type": "integer", + }, + "default": {}, + "title": "Mapping Query Int", + "type": "object", + }, + }, + { + "in": "query", + "name": "query", + "required": True, + "schema": { + "title": "Query", + "type": "integer", + }, + }, + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": {}, + }, + }, + "description": "Successful Response", + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError", + }, + }, + }, + "description": "Validation Error", + }, + }, + "summary": "Get Mixed Mapping Mixed Type Query Params", + }, + }, "/enum-status-code": { "get": { "responses": { diff --git a/tests/test_invalid_mapping_param.py b/tests/test_invalid_mapping_param.py new file mode 100644 index 0000000000..511d7d97cf --- /dev/null +++ b/tests/test_invalid_mapping_param.py @@ -0,0 +1,11 @@ +import pytest +from fastapi import FastAPI, Query + + +def test_invalid_sequence(): + with pytest.raises(AssertionError): + app = FastAPI() + + @app.get("/items/") + def read_items(q: dict[str, list[list[str]]] = Query(default=None)): + pass # pragma: no cover diff --git a/tests/test_invalid_sequence_param.py b/tests/test_invalid_sequence_param.py index 3695344f7a..d0cc2a0d11 100644 --- a/tests/test_invalid_sequence_param.py +++ b/tests/test_invalid_sequence_param.py @@ -1,5 +1,3 @@ -from typing import Optional - import pytest from fastapi import FastAPI, Query from pydantic import BaseModel @@ -48,18 +46,3 @@ def test_invalid_dict(): @app.get("/items/") def read_items(q: dict[str, Item] = Query(default=None)): pass # pragma: no cover - - -def test_invalid_simple_dict(): - with pytest.raises( - AssertionError, - match="Query parameter 'q' must be one of the supported types", - ): - app = FastAPI() - - class Item(BaseModel): - title: str - - @app.get("/items/") - def read_items(q: Optional[dict] = Query(default=None)): - pass # pragma: no cover diff --git a/tests/test_query.py b/tests/test_query.py index c25960caca..0d48b9c294 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -248,12 +248,6 @@ def test_query_param_required_int_query_foo(): } -def test_query_frozenset_query_1_query_1_query_2(): - response = client.get("/query/frozenset/?query=1&query=1&query=2") - assert response.status_code == 200 - assert response.json() == "1,2" - - def test_query_list(): response = client.get("/query/list/?device_ids=1&device_ids=2") assert response.status_code == 200 @@ -275,3 +269,49 @@ def test_query_list_default_empty(): response = client.get("/query/list-default/") assert response.status_code == 200 assert response.json() == [] + + +def test_query_frozenset_query_1_query_1_query_2(): + response = client.get("/query/frozenset/?query=1&query=1&query=2") + assert response.status_code == 200 + assert response.json() == "1,2" + + +def test_mapping_query(): + response = client.get("/query/mapping-params/?foo=fuzz&bar=buzz") + assert response.status_code == 200 + assert response.json() == {"queries": {"bar": "buzz", "foo": "fuzz"}} + + +def test_sequence_mapping_query(): + response = client.get("/query/mapping-sequence-params/?foo=1&foo=2") + assert response.status_code == 200 + assert response.json() == {"queries": {"foo": [1, 2]}} + + +def test_mixed_sequence_mapping_query(): + response = client.get("/query/mixed-type-params?query=2&foo=1&bar=3&foo=2&foo=baz") + assert response.status_code == 200 + assert response.json() == { + "queries": { + "mapping_query_int": {"bar": 3}, + "mapping_query_str": {"bar": "3", "foo": "baz"}, + "query": 2, + "sequence_mapping_queries": {"bar": [3], "foo": [1, 2]}, + } + } + + +def test_mapping_with_non_mapping_query(): + response = client.get("/query/mixed-params/?foo=1&foo=2&bar=3&query=fizz") + assert response.status_code == 200 + assert response.json() == { + "queries": { + "query": "fizz", + "mapping_query": {"foo": "2", "bar": "3"}, + "sequence_mapping_queries": { + "foo": [1, 2], + "bar": [3], + }, + } + } diff --git a/tests/test_request_params/test_query/test_free_form.py b/tests/test_request_params/test_query/test_free_form.py new file mode 100644 index 0000000000..cfb555cf6c --- /dev/null +++ b/tests/test_request_params/test_query/test_free_form.py @@ -0,0 +1,298 @@ +from typing import Annotated, Union + +import pytest +from dirty_equals import IsOneOf +from fastapi import FastAPI, Query +from fastapi.testclient import TestClient +from pydantic import BaseModel + +app = FastAPI() + +# ===================================================================================== +# Without aliases which exercise the "Wildcard" capture behavior + + +@app.get("/required-dict-str") +async def read_required_dict_str(p: Annotated[dict[str, str], Query()]): + return {"p": p} + + +class QueryModelRequiredDictStr(BaseModel): + p: dict[str, str] + + +@app.get("/model-required-dict-str") +def read_model_required_dict_str(p: Annotated[QueryModelRequiredDictStr, Query()]): + return {"p": p.p} + + +@pytest.mark.parametrize( + "path", + ["/required-dict-str", "/model-required-dict-str"], +) +def test_required_dict_str_schema(path: str): + assert app.openapi()["paths"][path]["get"]["parameters"] == [ + { + "required": True, + "schema": { + "title": "P", + "type": "object", + "additionalProperties": {"type": "string"}, + }, + "name": "p", + "in": "query", + } + ] + + +@pytest.mark.parametrize( + "path", + ["/required-dict-str", "/model-required-dict-str"], +) +def test_required_dict_str_missing(path: str): + client = TestClient(app) + response = client.get(path) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "type": "missing", + "loc": ["query", "p"], + "msg": "Field required", + "input": IsOneOf(None, {}), + } + ] + } + + +@pytest.mark.parametrize( + "path", + ["/required-dict-str", "/model-required-dict-str"], +) +def test_required_dict_str(path: str): + client = TestClient(app) + response = client.get(f"{path}?foo=bar&baz=qux") + assert response.status_code == 200 + assert response.json() == {"p": {"foo": "bar", "baz": "qux"}} + + +# ===================================================================================== +# With union types + + +@app.get("/required-dict-union") +async def read_required_dict_union( + p: Annotated[Union[dict[str, str], dict[str, int]], Query()], +): + return {"p": p} + + +class QueryModelRequiredDictUnion(BaseModel): + p: Union[dict[str, str], dict[str, int]] + + +@app.get("/model-required-dict-union") +def read_model_required_dict_union(p: Annotated[QueryModelRequiredDictUnion, Query()]): + return {"p": p.p} + + +@pytest.mark.parametrize( + "path", + ["/required-dict-union", "/model-required-dict-union"], +) +def test_required_dict_union_schema(path: str): + assert app.openapi()["paths"][path]["get"]["parameters"] == [ + { + "required": True, + "schema": { + "title": "P", + "anyOf": [ + { + "type": "object", + "additionalProperties": {"type": "string"}, + }, + { + "type": "object", + "additionalProperties": {"type": "integer"}, + }, + ], + }, + "name": "p", + "in": "query", + } + ] + + +@pytest.mark.parametrize( + "path", + ["/required-dict-union", "/model-required-dict-union"], +) +def test_required_dict_union_missing(path: str): + client = TestClient(app) + response = client.get(path) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "type": "missing", + "loc": ["query", "p"], + "msg": "Field required", + "input": IsOneOf(None, {}), + } + ] + } + + +@pytest.mark.parametrize( + "path", + ["/required-dict-union", "/model-required-dict-union"], +) +def test_required_dict_union(path: str): + client = TestClient(app) + response = client.get(f"{path}?foo=bar&baz=42") + assert response.status_code == 200 + assert response.json() == {"p": {"foo": "bar", "baz": "42"}} + + +@app.get("/required-dict-of-union") +async def read_required_dict_of_union( + p: Annotated[dict[str, Union[int, bool]], Query()], +): + return {"p": p} + + +class QueryModelRequiredDictOfUnion(BaseModel): + p: dict[str, Union[int, bool]] + + +@app.get("/model-required-dict-of-union") +def read_model_required_dict_of_union( + p: Annotated[QueryModelRequiredDictOfUnion, Query()], +): + return {"p": p.p} + + +@pytest.mark.parametrize( + "path", + ["/required-dict-of-union", "/model-required-dict-of-union"], +) +def test_required_dict_of_union_schema(path: str): + assert app.openapi()["paths"][path]["get"]["parameters"] == [ + { + "required": True, + "schema": { + "title": "P", + "type": "object", + "additionalProperties": { + "anyOf": [ + {"type": "integer"}, + {"type": "boolean"}, + ] + }, + }, + "name": "p", + "in": "query", + } + ] + + +@pytest.mark.parametrize( + "path", + ["/required-dict-of-union", "/model-required-dict-of-union"], +) +def test_required_dict_of_union_missing(path: str): + client = TestClient(app) + response = client.get(path) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "type": "missing", + "loc": ["query", "p"], + "msg": "Field required", + "input": IsOneOf(None, {}), + } + ] + } + + +@pytest.mark.parametrize( + "path", + ["/required-dict-of-union", "/model-required-dict-of-union"], +) +def test_required_dict_of_union(path: str): + client = TestClient(app) + # Testing the "Wildcard" capture behavior for dicts + response = client.get(f"{path}?foo=True&baz=42") + assert response.status_code == 200 + assert response.json() == {"p": {"foo": True, "baz": 42}} + + +@app.get("/required-dict-of-list") +async def read_required_dict_of_list(p: Annotated[dict[str, list[int]], Query()]): + return {"p": p} + + +class QueryModelRequiredDictOfList(BaseModel): + p: dict[str, list[int]] + + +@app.get("/model-required-dict-of-list") +def read_model_required_dict_of_list( + p: Annotated[QueryModelRequiredDictOfList, Query()], +): + return {"p": p.p} + + +@pytest.mark.parametrize( + "path", + ["/required-dict-of-list", "/model-required-dict-of-list"], +) +def test_required_dict_of_list_schema(path: str): + assert app.openapi()["paths"][path]["get"]["parameters"] == [ + { + "required": True, + "schema": { + "title": "P", + "type": "object", + "additionalProperties": { + "type": "array", + "items": {"type": "integer"}, + }, + }, + "name": "p", + "in": "query", + } + ] + + +@pytest.mark.parametrize( + "path", + ["/required-dict-of-list", "/model-required-dict-of-list"], +) +def test_required_dict_of_list_missing(path: str): + client = TestClient(app) + response = client.get(path) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "type": "missing", + "loc": ["query", "p"], + "msg": "Field required", + "input": IsOneOf(None, {}), + } + ] + } + + +@pytest.mark.parametrize( + "path", + ["/required-dict-of-list", "/model-required-dict-of-list"], +) +def test_required_dict_of_list(path: str): + client = TestClient(app) + # Testing the "Wildcard" capture behavior for dicts with list values + response = client.get(f"{path}?foo=1&foo=2&baz=3") + assert response.status_code == 200 + assert response.json() == {"p": {"foo": [1, 2], "baz": [3]}} diff --git a/tests/test_tutorial/test_query_params/test_tutorial007_py310.py b/tests/test_tutorial/test_query_params/test_tutorial007_py310.py new file mode 100644 index 0000000000..7101c00119 --- /dev/null +++ b/tests/test_tutorial/test_query_params/test_tutorial007_py310.py @@ -0,0 +1,36 @@ +import pytest +from fastapi.testclient import TestClient + +from tests.utils import needs_py310 + + +@pytest.fixture(name="client") +def get_client(): + from docs_src.query_params.tutorial007_py310 import app + + c = TestClient(app) + return c + + +@needs_py310 +def test_foo_needy_very(client: TestClient): + response = client.get("/query/mixed-type-params?query=1&query=2&foo=bar&foo=baz") + assert response.status_code == 200 + assert response.json() == { + "query": 2, + "mapping_query_str": {"foo": "baz"}, + "mapping_query_int": {}, + "sequence_mapping_int": {"foo": []}, + } + + +@needs_py310 +def test_just_string_not_scalar_mapping(client: TestClient): + response = client.get("/query/mixed-type-params?query=2&foo=1&bar=3&foo=2&foo=baz") + assert response.status_code == 200 + assert response.json() == { + "query": 2, + "mapping_query_str": {"bar": "3", "foo": "baz"}, + "mapping_query_int": {"bar": 3}, + "sequence_mapping_int": {"bar": [3], "foo": [1, 2]}, + }