This commit is contained in:
commonism 2025-12-12 14:33:23 +00:00 committed by GitHub
commit e8c2b479d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 192 additions and 7 deletions

View File

@ -1,5 +1,7 @@
import collections
import dataclasses
import inspect
import re
import sys
from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy
@ -21,7 +23,7 @@ from typing import (
)
import anyio
from fastapi import params
from fastapi import params, temp_pydantic_v1_params
from fastapi._compat import (
PYDANTIC_V2,
ModelField,
@ -77,8 +79,6 @@ from starlette.responses import Response
from starlette.websockets import WebSocket
from typing_extensions import Annotated, Literal, get_args, get_origin
from .. import temp_pydantic_v1_params
multipart_not_installed_error = (
'Form data requires "python-multipart" to be installed. \n'
'You can install "python-multipart" with: \n\n'
@ -773,6 +773,66 @@ def _get_multidict_value(
return value
class ParameterCodec:
@staticmethod
def _default() -> Dict[str, Any]:
return collections.defaultdict(lambda: ParameterCodec._default())
@staticmethod
def decode(
field_info: Union[params.Param, temp_pydantic_v1_params.Param],
received_params: Union[Mapping[str, Any], QueryParams, Headers],
field: ModelField,
) -> Dict[str, Any]:
fn: Callable[
[
Union[params.Param, temp_pydantic_v1_params.Param],
Union[Mapping[str, Any], QueryParams, Headers],
ModelField,
],
Dict[str, Any],
]
fn = getattr(ParameterCodec, f"decode_{field_info.style}")
return fn(field_info, received_params, field)
@staticmethod
def decode_deepObject(
field_info: Union[params.Param, temp_pydantic_v1_params.Param],
received_params: Union[Mapping[str, Any], QueryParams, Headers],
field: ModelField,
) -> Dict[str, Any]:
data: List[Tuple[str, str]] = []
for k, v in received_params.items():
if k.startswith(f"{field.alias}["):
data.append((k, v))
r = ParameterCodec._default()
for k, v in data:
"""
k: name[attr0][attr1]
v: "5"
-> {"name":{"attr0":{"attr1":"5"}}}
"""
# p = tuple(map(lambda x: x[:-1] if x[-1] == ']' else x, k.split("[")))
# would do as well, but add basic validation …
p0 = re.split(r"(\[|\]\[|\]$)", k)
s = p0[1::2]
assert (
p0[-1] == ""
and s[0] == "["
and s[-1] == "]"
and all(x == "][" for x in s[1:-1])
)
p1 = tuple(p0[::2][:-1])
o = r
for i in p1[:-1]:
o = o[i]
o[p1[-1]] = v
return r
def request_params_to_args(
fields: Sequence[ModelField],
received_params: Union[Mapping[str, Any], QueryParams, Headers],
@ -836,17 +896,30 @@ def request_params_to_args(
"Params must be subclasses of Param"
)
loc: Tuple[str, ...] = (field_info.in_.value,)
v_, errors_ = _validate_value_with_model_field(
field=first_field, value=params_to_process, values=values, loc=loc
)
if field_info.style == "deepObject":
value = ParameterCodec.decode(field_info, received_params, first_field)
value = value[first_field.alias]
v_, errors_ = _validate_value_with_model_field(
field=first_field, value=value, values=value, loc=loc
)
else:
v_, errors_ = _validate_value_with_model_field(
field=first_field, value=params_to_process, values=values, loc=loc
)
return {first_field.name: v_}, errors_
for field in fields:
value = _get_multidict_value(field, received_params)
field_info = field.field_info
assert isinstance(field_info, (params.Param, temp_pydantic_v1_params.Param)), (
"Params must be subclasses of Param"
)
if field_info.style == "deepObject":
value = ParameterCodec.decode(field_info, received_params, field)
value = value[field.alias]
else:
value = _get_multidict_value(field, received_params)
loc = (field_info.in_.value, field.alias)
v_, errors_ = _validate_value_with_model_field(
field=field, value=value, values=values, loc=loc

View File

@ -71,6 +71,8 @@ class Param(FieldInfo): # type: ignore[misc]
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
style: str = _Unset,
explode: bool = _Unset,
**extra: Any,
):
if example is not _Unset:
@ -132,6 +134,8 @@ class Param(FieldInfo): # type: ignore[misc]
use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset}
super().__init__(**use_kwargs)
self.style = style
self.explode = explode
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.default})"
@ -185,6 +189,8 @@ class Path(Param): # type: ignore[misc]
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
style: Literal["matrix", "label", "simple"] = "simple",
explode: bool = False,
**extra: Any,
):
assert default is ..., "Path parameters cannot have a default value"
@ -219,6 +225,8 @@ class Path(Param): # type: ignore[misc]
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
style=style,
explode=explode,
**extra,
)
@ -271,8 +279,14 @@ class Query(Param): # type: ignore[misc]
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
style: Literal[
"form", "spaceDelimited", "pipeDelimited", "deepObject"
] = "form",
explode: bool = _Unset,
**extra: Any,
):
if explode is _Unset:
explode = False if style != "form" else True
super().__init__(
default=default,
default_factory=default_factory,
@ -303,6 +317,8 @@ class Query(Param): # type: ignore[misc]
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
style=style,
explode=explode,
**extra,
)
@ -356,6 +372,8 @@ class Header(Param): # type: ignore[misc]
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
style: Literal["simple"] = "simple",
explode: bool = False,
**extra: Any,
):
self.convert_underscores = convert_underscores
@ -389,6 +407,8 @@ class Header(Param): # type: ignore[misc]
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
style=style,
explode=explode,
**extra,
)
@ -441,6 +461,8 @@ class Cookie(Param): # type: ignore[misc]
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
style: Literal["form"] = "form",
explode: bool = False,
**extra: Any,
):
super().__init__(
@ -473,6 +495,8 @@ class Cookie(Param): # type: ignore[misc]
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
style=style,
explode=explode,
**extra,
)

View File

@ -59,6 +59,8 @@ class Param(FieldInfo): # type: ignore[misc]
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
style: str = _Unset,
explode: bool = _Unset,
**extra: Any,
):
if example is not _Unset:
@ -107,6 +109,8 @@ class Param(FieldInfo): # type: ignore[misc]
use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset}
super().__init__(**use_kwargs)
self.style = style
self.explode = explode
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.default})"

View File

@ -122,6 +122,7 @@ all = [
"pydantic-settings >=2.0.0",
# Extra Pydantic data types
"pydantic-extra-types >=2.0.0",
"more-itertools"
]
[project.scripts]

83
tests/test_param_style.py Normal file
View File

@ -0,0 +1,83 @@
from typing import List, Optional
import pydantic
import pytest
from fastapi import FastAPI, Query
from fastapi._compat import PYDANTIC_V2
from fastapi.testclient import TestClient
from pydantic import BaseModel
from typing_extensions import Literal
class Dog(BaseModel):
pet_type: Literal["dog"]
name: str
class Matrjoschka(BaseModel):
size: str = 0 # without type coecerion Query parameters are limited to str
inner: Optional["Matrjoschka"] = None
app = FastAPI()
@app.post(
"/pet",
operation_id="createPet",
)
def createPet(pet: Dog = Query(style="deepObject")) -> Dog:
return pet
@app.post(
"/toy",
operation_id="createToy",
)
def createToy(toy: Matrjoschka = Query(style="deepObject")) -> Matrjoschka:
return toy
@app.post("/multi", operation_id="createMulti")
def createMulti(
a: Matrjoschka = Query(style="deepObject"),
b: Matrjoschka = Query(style="deepObject"),
) -> List[Matrjoschka]:
return [a, b]
client = TestClient(app)
def test_pet():
response = client.post("""/pet?pet[pet_type]=dog&pet[name]=doggy""")
if PYDANTIC_V2:
dog = Dog.model_validate(response.json())
else:
dog = Dog.parse_obj(response.json())
assert response.status_code == 200
assert dog.pet_type == "dog" and dog.name == "doggy"
def test_matrjoschka():
response = client.post(
"""/toy?toy[size]=3&toy[inner][size]=2&toy[inner][inner][size]=1"""
)
print(response)
if PYDANTIC_V2:
toy = Matrjoschka.model_validate(response.json())
else:
toy = Matrjoschka.parse_obj(response.json())
assert response.status_code == 200
assert toy
assert toy.inner.size == "2"
@pytest.mark.skipif(not PYDANTIC_V2, reason="Only for Pydantic v2")
def test_multi():
response = client.post("""/multi?a[size]=1&b[size]=1""")
print(response)
t = pydantic.TypeAdapter(List[Matrjoschka])
v = t.validate_python(response.json())
assert all(i.size == "1" for i in v), v