Merge remote-tracking branch 'yerlin/free-form-queries' into free-form-queries

This commit is contained in:
JONEMI19 2023-03-13 15:46:28 +00:00
commit b092a0be32
5 changed files with 72 additions and 4 deletions

View File

@ -37,6 +37,7 @@ from pydantic.errors import MissingError
from pydantic.fields import (
SHAPE_FROZENSET,
SHAPE_LIST,
SHAPE_MAPPING,
SHAPE_SEQUENCE,
SHAPE_SET,
SHAPE_SINGLETON,
@ -74,6 +75,9 @@ sequence_shape_to_type = {
SHAPE_TUPLE_ELLIPSIS: list,
}
mapping_shapes = {SHAPE_MAPPING}
mapping_types = dict
mapping_shapes_to_type = {SHAPE_MAPPING: dict}
multipart_not_installed_error = (
'Form data requires "python-multipart" to be installed. \n'
@ -245,6 +249,20 @@ def is_scalar_sequence_field(field: ModelField) -> bool:
return False
def is_scalar_mapping_field(field: ModelField) -> bool:
if (field.shape in mapping_shapes) and not lenient_issubclass(
field.type_, BaseModel
):
if field.sub_fields is not None:
for sub_field in field.sub_fields:
if not is_scalar_field(sub_field):
return False
return True
if lenient_issubclass(field.type_, mapping_types):
return True
return False
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
@ -324,9 +342,10 @@ def get_dependant(
add_param_to_fields(field=param_field, dependant=dependant)
elif is_scalar_field(field=param_field):
add_param_to_fields(field=param_field, dependant=dependant)
elif isinstance(
param.default, (params.Query, params.Header)
) and is_scalar_sequence_field(param_field):
elif isinstance(param.default, (params.Query, params.Header)) and (
is_scalar_sequence_field(param_field)
or is_scalar_mapping_field(param_field)
):
add_param_to_fields(field=param_field, dependant=dependant)
else:
field_info = param_field.field_info
@ -603,6 +622,10 @@ def request_params_to_args(
received_params, (QueryParams, Headers)
):
value = received_params.getlist(field.alias) or field.default
elif is_scalar_mapping_field(field) and isinstance(
received_params, (QueryParams,)
):
value = received_params._dict
else:
value = received_params.get(field.alias)
field_info = field.field_info

View File

@ -189,6 +189,11 @@ def get_query_param_required_type(query: int = Query()):
return f"foo bar {query}"
@app.get("/query/params")
def get_query_params(queries: Dict[str, int] = Query({})):
return f"foo bar {queries}"
@app.get("/enum-status-code", status_code=http.HTTPStatus.CREATED)
def get_enum_status_code():
return "foo bar"

View File

@ -1078,6 +1078,41 @@ openapi_schema = {
],
}
},
"/query/params": {
"get": {
"summary": "Get Query Params",
"operationId": "get_query_params_query_params_get",
"parameters": [
{
"required": False,
"schema": {
"title": "Queries",
"type": "object",
"additionalProperties": {"type": "integer"},
"default": {},
},
"name": "queries",
"in": "query",
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
},
}
},
"/enum-status-code": {
"get": {
"responses": {

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Tuple
import pytest
from fastapi import FastAPI, Query

View File

@ -54,6 +54,11 @@ response_not_valid_int = {
("/query/param-required/int?query=50", 200, "foo bar 50"),
("/query/param-required/int?query=foo", 422, response_not_valid_int),
("/query/frozenset/?query=1&query=1&query=2", 200, "1,2"),
(
"/query/params?first-query=1&second-query=2",
200,
"foo bar {'first-query': 1, " "'second-query': 2}",
),
],
)
def test_get_path(path, expected_status, expected_response):