This commit is contained in:
commonism 2025-12-12 14:33:23 +00:00 committed by GitHub
commit e8c2b479d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 192 additions and 7 deletions

View File

@ -1,5 +1,7 @@
import collections
import dataclasses import dataclasses
import inspect import inspect
import re
import sys import sys
from contextlib import AsyncExitStack, contextmanager from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy from copy import copy, deepcopy
@ -21,7 +23,7 @@ from typing import (
) )
import anyio import anyio
from fastapi import params from fastapi import params, temp_pydantic_v1_params
from fastapi._compat import ( from fastapi._compat import (
PYDANTIC_V2, PYDANTIC_V2,
ModelField, ModelField,
@ -77,8 +79,6 @@ from starlette.responses import Response
from starlette.websockets import WebSocket from starlette.websockets import WebSocket
from typing_extensions import Annotated, Literal, get_args, get_origin from typing_extensions import Annotated, Literal, get_args, get_origin
from .. import temp_pydantic_v1_params
multipart_not_installed_error = ( multipart_not_installed_error = (
'Form data requires "python-multipart" to be installed. \n' 'Form data requires "python-multipart" to be installed. \n'
'You can install "python-multipart" with: \n\n' 'You can install "python-multipart" with: \n\n'
@ -773,6 +773,66 @@ def _get_multidict_value(
return value return value
class ParameterCodec:
@staticmethod
def _default() -> Dict[str, Any]:
return collections.defaultdict(lambda: ParameterCodec._default())
@staticmethod
def decode(
field_info: Union[params.Param, temp_pydantic_v1_params.Param],
received_params: Union[Mapping[str, Any], QueryParams, Headers],
field: ModelField,
) -> Dict[str, Any]:
fn: Callable[
[
Union[params.Param, temp_pydantic_v1_params.Param],
Union[Mapping[str, Any], QueryParams, Headers],
ModelField,
],
Dict[str, Any],
]
fn = getattr(ParameterCodec, f"decode_{field_info.style}")
return fn(field_info, received_params, field)
@staticmethod
def decode_deepObject(
field_info: Union[params.Param, temp_pydantic_v1_params.Param],
received_params: Union[Mapping[str, Any], QueryParams, Headers],
field: ModelField,
) -> Dict[str, Any]:
data: List[Tuple[str, str]] = []
for k, v in received_params.items():
if k.startswith(f"{field.alias}["):
data.append((k, v))
r = ParameterCodec._default()
for k, v in data:
"""
k: name[attr0][attr1]
v: "5"
-> {"name":{"attr0":{"attr1":"5"}}}
"""
# p = tuple(map(lambda x: x[:-1] if x[-1] == ']' else x, k.split("[")))
# would do as well, but add basic validation …
p0 = re.split(r"(\[|\]\[|\]$)", k)
s = p0[1::2]
assert (
p0[-1] == ""
and s[0] == "["
and s[-1] == "]"
and all(x == "][" for x in s[1:-1])
)
p1 = tuple(p0[::2][:-1])
o = r
for i in p1[:-1]:
o = o[i]
o[p1[-1]] = v
return r
def request_params_to_args( def request_params_to_args(
fields: Sequence[ModelField], fields: Sequence[ModelField],
received_params: Union[Mapping[str, Any], QueryParams, Headers], received_params: Union[Mapping[str, Any], QueryParams, Headers],
@ -836,17 +896,30 @@ def request_params_to_args(
"Params must be subclasses of Param" "Params must be subclasses of Param"
) )
loc: Tuple[str, ...] = (field_info.in_.value,) loc: Tuple[str, ...] = (field_info.in_.value,)
if field_info.style == "deepObject":
value = ParameterCodec.decode(field_info, received_params, first_field)
value = value[first_field.alias]
v_, errors_ = _validate_value_with_model_field(
field=first_field, value=value, values=value, loc=loc
)
else:
v_, errors_ = _validate_value_with_model_field( v_, errors_ = _validate_value_with_model_field(
field=first_field, value=params_to_process, values=values, loc=loc field=first_field, value=params_to_process, values=values, loc=loc
) )
return {first_field.name: v_}, errors_ return {first_field.name: v_}, errors_
for field in fields: for field in fields:
value = _get_multidict_value(field, received_params)
field_info = field.field_info field_info = field.field_info
assert isinstance(field_info, (params.Param, temp_pydantic_v1_params.Param)), ( assert isinstance(field_info, (params.Param, temp_pydantic_v1_params.Param)), (
"Params must be subclasses of Param" "Params must be subclasses of Param"
) )
if field_info.style == "deepObject":
value = ParameterCodec.decode(field_info, received_params, field)
value = value[field.alias]
else:
value = _get_multidict_value(field, received_params)
loc = (field_info.in_.value, field.alias) loc = (field_info.in_.value, field.alias)
v_, errors_ = _validate_value_with_model_field( v_, errors_ = _validate_value_with_model_field(
field=field, value=value, values=values, loc=loc field=field, value=value, values=values, loc=loc

View File

@ -71,6 +71,8 @@ class Param(FieldInfo): # type: ignore[misc]
deprecated: Union[deprecated, str, bool, None] = None, deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True, include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None, json_schema_extra: Union[Dict[str, Any], None] = None,
style: str = _Unset,
explode: bool = _Unset,
**extra: Any, **extra: Any,
): ):
if example is not _Unset: if example is not _Unset:
@ -132,6 +134,8 @@ class Param(FieldInfo): # type: ignore[misc]
use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset} use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset}
super().__init__(**use_kwargs) super().__init__(**use_kwargs)
self.style = style
self.explode = explode
def __repr__(self) -> str: def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.default})" return f"{self.__class__.__name__}({self.default})"
@ -185,6 +189,8 @@ class Path(Param): # type: ignore[misc]
deprecated: Union[deprecated, str, bool, None] = None, deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True, include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None, json_schema_extra: Union[Dict[str, Any], None] = None,
style: Literal["matrix", "label", "simple"] = "simple",
explode: bool = False,
**extra: Any, **extra: Any,
): ):
assert default is ..., "Path parameters cannot have a default value" assert default is ..., "Path parameters cannot have a default value"
@ -219,6 +225,8 @@ class Path(Param): # type: ignore[misc]
openapi_examples=openapi_examples, openapi_examples=openapi_examples,
include_in_schema=include_in_schema, include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra, json_schema_extra=json_schema_extra,
style=style,
explode=explode,
**extra, **extra,
) )
@ -271,8 +279,14 @@ class Query(Param): # type: ignore[misc]
deprecated: Union[deprecated, str, bool, None] = None, deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True, include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None, json_schema_extra: Union[Dict[str, Any], None] = None,
style: Literal[
"form", "spaceDelimited", "pipeDelimited", "deepObject"
] = "form",
explode: bool = _Unset,
**extra: Any, **extra: Any,
): ):
if explode is _Unset:
explode = False if style != "form" else True
super().__init__( super().__init__(
default=default, default=default,
default_factory=default_factory, default_factory=default_factory,
@ -303,6 +317,8 @@ class Query(Param): # type: ignore[misc]
openapi_examples=openapi_examples, openapi_examples=openapi_examples,
include_in_schema=include_in_schema, include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra, json_schema_extra=json_schema_extra,
style=style,
explode=explode,
**extra, **extra,
) )
@ -356,6 +372,8 @@ class Header(Param): # type: ignore[misc]
deprecated: Union[deprecated, str, bool, None] = None, deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True, include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None, json_schema_extra: Union[Dict[str, Any], None] = None,
style: Literal["simple"] = "simple",
explode: bool = False,
**extra: Any, **extra: Any,
): ):
self.convert_underscores = convert_underscores self.convert_underscores = convert_underscores
@ -389,6 +407,8 @@ class Header(Param): # type: ignore[misc]
openapi_examples=openapi_examples, openapi_examples=openapi_examples,
include_in_schema=include_in_schema, include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra, json_schema_extra=json_schema_extra,
style=style,
explode=explode,
**extra, **extra,
) )
@ -441,6 +461,8 @@ class Cookie(Param): # type: ignore[misc]
deprecated: Union[deprecated, str, bool, None] = None, deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True, include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None, json_schema_extra: Union[Dict[str, Any], None] = None,
style: Literal["form"] = "form",
explode: bool = False,
**extra: Any, **extra: Any,
): ):
super().__init__( super().__init__(
@ -473,6 +495,8 @@ class Cookie(Param): # type: ignore[misc]
openapi_examples=openapi_examples, openapi_examples=openapi_examples,
include_in_schema=include_in_schema, include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra, json_schema_extra=json_schema_extra,
style=style,
explode=explode,
**extra, **extra,
) )

View File

@ -59,6 +59,8 @@ class Param(FieldInfo): # type: ignore[misc]
deprecated: Union[deprecated, str, bool, None] = None, deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True, include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None, json_schema_extra: Union[Dict[str, Any], None] = None,
style: str = _Unset,
explode: bool = _Unset,
**extra: Any, **extra: Any,
): ):
if example is not _Unset: if example is not _Unset:
@ -107,6 +109,8 @@ class Param(FieldInfo): # type: ignore[misc]
use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset} use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset}
super().__init__(**use_kwargs) super().__init__(**use_kwargs)
self.style = style
self.explode = explode
def __repr__(self) -> str: def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.default})" return f"{self.__class__.__name__}({self.default})"

View File

@ -122,6 +122,7 @@ all = [
"pydantic-settings >=2.0.0", "pydantic-settings >=2.0.0",
# Extra Pydantic data types # Extra Pydantic data types
"pydantic-extra-types >=2.0.0", "pydantic-extra-types >=2.0.0",
"more-itertools"
] ]
[project.scripts] [project.scripts]

83
tests/test_param_style.py Normal file
View File

@ -0,0 +1,83 @@
from typing import List, Optional
import pydantic
import pytest
from fastapi import FastAPI, Query
from fastapi._compat import PYDANTIC_V2
from fastapi.testclient import TestClient
from pydantic import BaseModel
from typing_extensions import Literal
class Dog(BaseModel):
pet_type: Literal["dog"]
name: str
class Matrjoschka(BaseModel):
size: str = 0 # without type coecerion Query parameters are limited to str
inner: Optional["Matrjoschka"] = None
app = FastAPI()
@app.post(
"/pet",
operation_id="createPet",
)
def createPet(pet: Dog = Query(style="deepObject")) -> Dog:
return pet
@app.post(
"/toy",
operation_id="createToy",
)
def createToy(toy: Matrjoschka = Query(style="deepObject")) -> Matrjoschka:
return toy
@app.post("/multi", operation_id="createMulti")
def createMulti(
a: Matrjoschka = Query(style="deepObject"),
b: Matrjoschka = Query(style="deepObject"),
) -> List[Matrjoschka]:
return [a, b]
client = TestClient(app)
def test_pet():
response = client.post("""/pet?pet[pet_type]=dog&pet[name]=doggy""")
if PYDANTIC_V2:
dog = Dog.model_validate(response.json())
else:
dog = Dog.parse_obj(response.json())
assert response.status_code == 200
assert dog.pet_type == "dog" and dog.name == "doggy"
def test_matrjoschka():
response = client.post(
"""/toy?toy[size]=3&toy[inner][size]=2&toy[inner][inner][size]=1"""
)
print(response)
if PYDANTIC_V2:
toy = Matrjoschka.model_validate(response.json())
else:
toy = Matrjoschka.parse_obj(response.json())
assert response.status_code == 200
assert toy
assert toy.inner.size == "2"
@pytest.mark.skipif(not PYDANTIC_V2, reason="Only for Pydantic v2")
def test_multi():
response = client.post("""/multi?a[size]=1&b[size]=1""")
print(response)
t = pydantic.TypeAdapter(List[Matrjoschka])
v = t.validate_python(response.json())
assert all(i.size == "1" for i in v), v