mirror of https://github.com/tiangolo/fastapi.git
✨ Refactor param extraction using Pydantic Field (#278)
* ✨ Refactor parameter dependency using Pydantic Field * ⬆️ Upgrade required Pydantic version with latest Shape values * ✨ Add tutorials and code for using Enum and Optional * ✅ Add tests for tutorials with new types and extra cases * ♻️ Format, clean, and add annotations to dependencies.utils * 📝 Update tutorial for query parameters with list defaults * ✅ Add tests for query param with list default
This commit is contained in:
parent
83b1a117cc
commit
bd407cc4ed
Binary file not shown.
|
After Width: | Height: | Size: 82 KiB |
|
|
@ -0,0 +1,21 @@
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
|
||||||
|
class ModelName(Enum):
|
||||||
|
alexnet = "alexnet"
|
||||||
|
resnet = "resnet"
|
||||||
|
lenet = "lenet"
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/model/{model_name}")
|
||||||
|
async def get_model(model_name: ModelName):
|
||||||
|
if model_name == ModelName.alexnet:
|
||||||
|
return {"model_name": model_name, "message": "Deep Learning FTW!"}
|
||||||
|
if model_name.value == "lenet":
|
||||||
|
return {"model_name": model_name, "message": "LeCNN all the images"}
|
||||||
|
return {"model_name": model_name, "message": "Have some residuals"}
|
||||||
|
|
@ -0,0 +1,11 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/items/{item_id}")
|
||||||
|
async def read_user_item(item_id: str, limit: Optional[int] = None):
|
||||||
|
item = {"item_id": item_id, "limit": limit}
|
||||||
|
return item
|
||||||
|
|
@ -0,0 +1,11 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from fastapi import FastAPI, Query
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/items/")
|
||||||
|
async def read_items(q: List[str] = Query(["foo", "bar"])):
|
||||||
|
query_items = {"q": q}
|
||||||
|
return query_items
|
||||||
|
|
@ -35,7 +35,7 @@ If you run this example and open your browser at <a href="http://127.0.0.1:8000/
|
||||||
|
|
||||||
!!! check
|
!!! check
|
||||||
Notice that the value your function received (and returned) is `3`, as a Python `int`, not a string `"3"`.
|
Notice that the value your function received (and returned) is `3`, as a Python `int`, not a string `"3"`.
|
||||||
|
|
||||||
So, with that type declaration, **FastAPI** gives you automatic request <abbr title="converting the string that comes from an HTTP request into Python data">"parsing"</abbr>.
|
So, with that type declaration, **FastAPI** gives you automatic request <abbr title="converting the string that comes from an HTTP request into Python data">"parsing"</abbr>.
|
||||||
|
|
||||||
## Data validation
|
## Data validation
|
||||||
|
|
@ -61,12 +61,11 @@ because the path parameter `item_id` had a value of `"foo"`, which is not an `in
|
||||||
|
|
||||||
The same error would appear if you provided a `float` instead of an int, as in: <a href="http://127.0.0.1:8000/items/4.2" target="_blank">http://127.0.0.1:8000/items/4.2</a>
|
The same error would appear if you provided a `float` instead of an int, as in: <a href="http://127.0.0.1:8000/items/4.2" target="_blank">http://127.0.0.1:8000/items/4.2</a>
|
||||||
|
|
||||||
|
|
||||||
!!! check
|
!!! check
|
||||||
So, with the same Python type declaration, **FastAPI** gives you data validation.
|
So, with the same Python type declaration, **FastAPI** gives you data validation.
|
||||||
|
|
||||||
Notice that the error also clearly states exactly the point where the validation didn't pass.
|
Notice that the error also clearly states exactly the point where the validation didn't pass.
|
||||||
|
|
||||||
This is incredibly helpful while developing and debugging code that interacts with your API.
|
This is incredibly helpful while developing and debugging code that interacts with your API.
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
@ -96,8 +95,7 @@ All the data validation is performed under the hood by <a href="https://pydantic
|
||||||
|
|
||||||
You can use the same type declarations with `str`, `float`, `bool` and many other complex data types.
|
You can use the same type declarations with `str`, `float`, `bool` and many other complex data types.
|
||||||
|
|
||||||
These are explored in the next chapters of the tutorial.
|
Several of these are explored in the next chapters of the tutorial.
|
||||||
|
|
||||||
|
|
||||||
## Order matters
|
## Order matters
|
||||||
|
|
||||||
|
|
@ -115,6 +113,73 @@ Because path operations are evaluated in order, you need to make sure that the p
|
||||||
|
|
||||||
Otherwise, the path for `/users/{user_id}` would match also for `/users/me`, "thinking" that it's receiving a parameter `user_id` with a value of `"me"`.
|
Otherwise, the path for `/users/{user_id}` would match also for `/users/me`, "thinking" that it's receiving a parameter `user_id` with a value of `"me"`.
|
||||||
|
|
||||||
|
## Predefined values
|
||||||
|
|
||||||
|
If you have a *path operation* that receives a *path parameter*, but you want the possible valid *path parameter* values to be predefined, you can use a standard Python <abbr title="Enumeration">`Enum`</abbr>.
|
||||||
|
|
||||||
|
### Create an `Enum` class
|
||||||
|
|
||||||
|
Import `Enum` and create a sub-class that inherits from it.
|
||||||
|
|
||||||
|
And create class attributes with fixed values, those fixed values will be the available valid values:
|
||||||
|
|
||||||
|
```Python hl_lines="1 6 7 8 9"
|
||||||
|
{!./src/path_params/tutorial005.py!}
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
<a href="https://docs.python.org/3/library/enum.html" target="_blank">Enumerations (or enums) are available in Python</a> since version 3.4.
|
||||||
|
|
||||||
|
!!! tip
|
||||||
|
If you are wondering, "AlexNet", "ResNet", and "LeNet" are just names of Machine Learning <abbr title="Technically, Deep Learning model architectures">models</abbr>.
|
||||||
|
|
||||||
|
### Declare a *path parameter*
|
||||||
|
|
||||||
|
Then create a *path parameter* with a type annotation using the enum class you created (`ModelName`):
|
||||||
|
|
||||||
|
```Python hl_lines="16"
|
||||||
|
{!./src/path_params/tutorial005.py!}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Check the docs
|
||||||
|
|
||||||
|
Because the available values for the *path parameter* are specified, the interactive docs can show them nicely:
|
||||||
|
|
||||||
|
<img src="/img/tutorial/path-params/image03.png">
|
||||||
|
|
||||||
|
### Working with Python *enumerations*
|
||||||
|
|
||||||
|
The value of the *path parameter* will be an *enumeration member*.
|
||||||
|
|
||||||
|
#### Compare *enumeration members*
|
||||||
|
|
||||||
|
You can compare it with the *enumeration member* in your created enum `ModelName`:
|
||||||
|
|
||||||
|
```Python hl_lines="17"
|
||||||
|
{!./src/path_params/tutorial005.py!}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Get the *enumeration value*
|
||||||
|
|
||||||
|
You can get the actual value (a `str` in this case) using `model_name.value`, or in general, `your_enum_member.value`:
|
||||||
|
|
||||||
|
```Python hl_lines="19"
|
||||||
|
{!./src/path_params/tutorial005.py!}
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! tip
|
||||||
|
You could also access the value `"lenet"` with `ModelName.lenet.value`.
|
||||||
|
|
||||||
|
#### Return *enumeration members*
|
||||||
|
|
||||||
|
You can return *enum members* from your *path operation*, even nested in a JSON body (e.g. a `dict`).
|
||||||
|
|
||||||
|
They will be converted to their corresponding values before returning them to the client:
|
||||||
|
|
||||||
|
```Python hl_lines="18 20 21"
|
||||||
|
{!./src/path_params/tutorial005.py!}
|
||||||
|
```
|
||||||
|
|
||||||
## Path parameters containing paths
|
## Path parameters containing paths
|
||||||
|
|
||||||
Let's say you have a *path operation* with a path `/files/{file_path}`.
|
Let's say you have a *path operation* with a path `/files/{file_path}`.
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,6 @@ The query parameter `q` is of type `str`, and by default is `None`, so it is opt
|
||||||
|
|
||||||
We are going to enforce that even though `q` is optional, whenever it is provided, it **doesn't exceed a length of 50 characters**.
|
We are going to enforce that even though `q` is optional, whenever it is provided, it **doesn't exceed a length of 50 characters**.
|
||||||
|
|
||||||
|
|
||||||
### Import `Query`
|
### Import `Query`
|
||||||
|
|
||||||
To achieve that, first import `Query` from `fastapi`:
|
To achieve that, first import `Query` from `fastapi`:
|
||||||
|
|
@ -29,7 +28,7 @@ And now use it as the default value of your parameter, setting the parameter `ma
|
||||||
{!./src/query_params_str_validations/tutorial002.py!}
|
{!./src/query_params_str_validations/tutorial002.py!}
|
||||||
```
|
```
|
||||||
|
|
||||||
As we have to replace the default value `None` with `Query(None)`, the first parameter to `Query` serves the same purpose of defining that default value.
|
As we have to replace the default value `None` with `Query(None)`, the first parameter to `Query` serves the same purpose of defining that default value.
|
||||||
|
|
||||||
So:
|
So:
|
||||||
|
|
||||||
|
|
@ -41,7 +40,7 @@ q: str = Query(None)
|
||||||
|
|
||||||
```Python
|
```Python
|
||||||
q: str = None
|
q: str = None
|
||||||
```
|
```
|
||||||
|
|
||||||
But it declares it explicitly as being a query parameter.
|
But it declares it explicitly as being a query parameter.
|
||||||
|
|
||||||
|
|
@ -53,7 +52,6 @@ q: str = Query(None, max_length=50)
|
||||||
|
|
||||||
This will validate the data, show a clear error when the data is not valid, and document the parameter in the OpenAPI schema path operation.
|
This will validate the data, show a clear error when the data is not valid, and document the parameter in the OpenAPI schema path operation.
|
||||||
|
|
||||||
|
|
||||||
## Add more validations
|
## Add more validations
|
||||||
|
|
||||||
You can also add a parameter `min_length`:
|
You can also add a parameter `min_length`:
|
||||||
|
|
@ -119,7 +117,7 @@ So, when you need to declare a value as required while using `Query`, you can us
|
||||||
{!./src/query_params_str_validations/tutorial006.py!}
|
{!./src/query_params_str_validations/tutorial006.py!}
|
||||||
```
|
```
|
||||||
|
|
||||||
!!! info
|
!!! info
|
||||||
If you hadn't seen that `...` before: it is a a special single value, it is <a href="https://docs.python.org/3/library/constants.html#Ellipsis" target="_blank">part of Python and is called "Ellipsis"</a>.
|
If you hadn't seen that `...` before: it is a a special single value, it is <a href="https://docs.python.org/3/library/constants.html#Ellipsis" target="_blank">part of Python and is called "Ellipsis"</a>.
|
||||||
|
|
||||||
This will let **FastAPI** know that this parameter is required.
|
This will let **FastAPI** know that this parameter is required.
|
||||||
|
|
@ -156,11 +154,35 @@ So, the response to that URL would be:
|
||||||
!!! tip
|
!!! tip
|
||||||
To declare a query parameter with a type of `list`, like in the example above, you need to explicitly use `Query`, otherwise it would be interpreted as a request body.
|
To declare a query parameter with a type of `list`, like in the example above, you need to explicitly use `Query`, otherwise it would be interpreted as a request body.
|
||||||
|
|
||||||
|
|
||||||
The interactive API docs will update accordingly, to allow multiple values:
|
The interactive API docs will update accordingly, to allow multiple values:
|
||||||
|
|
||||||
<img src="/img/tutorial/query-params-str-validations/image02.png">
|
<img src="/img/tutorial/query-params-str-validations/image02.png">
|
||||||
|
|
||||||
|
### Query parameter list / multiple values with defaults
|
||||||
|
|
||||||
|
And you can also define a default `list` of values if none are provided:
|
||||||
|
|
||||||
|
```Python hl_lines="9"
|
||||||
|
{!./src/query_params_str_validations/tutorial012.py!}
|
||||||
|
```
|
||||||
|
|
||||||
|
If you go to:
|
||||||
|
|
||||||
|
```
|
||||||
|
http://localhost:8000/items/
|
||||||
|
```
|
||||||
|
|
||||||
|
the default of `q` will be: `["foo", "bar"]` and your response will be:
|
||||||
|
|
||||||
|
```JSON
|
||||||
|
{
|
||||||
|
"q": [
|
||||||
|
"foo",
|
||||||
|
"bar"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
## Declare more metadata
|
## Declare more metadata
|
||||||
|
|
||||||
You can add more information about the parameter.
|
You can add more information about the parameter.
|
||||||
|
|
|
||||||
|
|
@ -186,3 +186,39 @@ In this case, there are 3 query parameters:
|
||||||
* `needy`, a required `str`.
|
* `needy`, a required `str`.
|
||||||
* `skip`, an `int` with a default value of `0`.
|
* `skip`, an `int` with a default value of `0`.
|
||||||
* `limit`, an optional `int`.
|
* `limit`, an optional `int`.
|
||||||
|
|
||||||
|
!!! tip
|
||||||
|
You could also use `Enum`s <a href="https://fastapi.tiangolo.com/tutorial/path-params/#predefined-values" target="_blank">the same way as with *path parameters*</a>.
|
||||||
|
|
||||||
|
## Optional type declarations
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
This might be an advanced use case.
|
||||||
|
|
||||||
|
You might want to skip it.
|
||||||
|
|
||||||
|
If you are using `mypy` it could complain with type declarations like:
|
||||||
|
|
||||||
|
```Python
|
||||||
|
limit: int = None
|
||||||
|
```
|
||||||
|
|
||||||
|
With an error like:
|
||||||
|
|
||||||
|
```
|
||||||
|
Incompatible types in assignment (expression has type "None", variable has type "int")
|
||||||
|
```
|
||||||
|
|
||||||
|
In those cases you can use `Optional` to tell `mypy` that the value could be `None`, like:
|
||||||
|
|
||||||
|
```Python
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
limit: Optional[int] = None
|
||||||
|
```
|
||||||
|
|
||||||
|
In a *path operation* that could look like:
|
||||||
|
|
||||||
|
```Python hl_lines="9"
|
||||||
|
{!./src/query_params/tutorial007.py!}
|
||||||
|
```
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from datetime import date, datetime, time, timedelta
|
|
||||||
from decimal import Decimal
|
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
|
|
@ -14,8 +12,8 @@ from typing import (
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from fastapi import params
|
from fastapi import params
|
||||||
from fastapi.dependencies.models import Dependant, SecurityRequirement
|
from fastapi.dependencies.models import Dependant, SecurityRequirement
|
||||||
|
|
@ -23,7 +21,7 @@ from fastapi.security.base import SecurityBase
|
||||||
from fastapi.security.oauth2 import OAuth2, SecurityScopes
|
from fastapi.security.oauth2 import OAuth2, SecurityScopes
|
||||||
from fastapi.security.open_id_connect_url import OpenIdConnect
|
from fastapi.security.open_id_connect_url import OpenIdConnect
|
||||||
from fastapi.utils import get_path_param_names
|
from fastapi.utils import get_path_param_names
|
||||||
from pydantic import BaseConfig, Schema, create_model
|
from pydantic import BaseConfig, BaseModel, Schema, create_model
|
||||||
from pydantic.error_wrappers import ErrorWrapper
|
from pydantic.error_wrappers import ErrorWrapper
|
||||||
from pydantic.errors import MissingError
|
from pydantic.errors import MissingError
|
||||||
from pydantic.fields import Field, Required, Shape
|
from pydantic.fields import Field, Required, Shape
|
||||||
|
|
@ -35,22 +33,21 @@ from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.websockets import WebSocket
|
from starlette.websockets import WebSocket
|
||||||
|
|
||||||
param_supported_types = (
|
sequence_shapes = {
|
||||||
str,
|
Shape.LIST,
|
||||||
int,
|
Shape.SET,
|
||||||
float,
|
Shape.TUPLE,
|
||||||
bool,
|
Shape.SEQUENCE,
|
||||||
UUID,
|
Shape.TUPLE_ELLIPS,
|
||||||
date,
|
}
|
||||||
datetime,
|
|
||||||
time,
|
|
||||||
timedelta,
|
|
||||||
Decimal,
|
|
||||||
)
|
|
||||||
|
|
||||||
sequence_shapes = {Shape.LIST, Shape.SET, Shape.TUPLE}
|
|
||||||
sequence_types = (list, set, tuple)
|
sequence_types = (list, set, tuple)
|
||||||
sequence_shape_to_type = {Shape.LIST: list, Shape.SET: set, Shape.TUPLE: tuple}
|
sequence_shape_to_type = {
|
||||||
|
Shape.LIST: list,
|
||||||
|
Shape.SET: set,
|
||||||
|
Shape.TUPLE: tuple,
|
||||||
|
Shape.SEQUENCE: list,
|
||||||
|
Shape.TUPLE_ELLIPS: list,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_param_sub_dependant(
|
def get_param_sub_dependant(
|
||||||
|
|
@ -126,6 +123,26 @@ def get_flat_dependant(dependant: Dependant) -> Dependant:
|
||||||
return flat_dependant
|
return flat_dependant
|
||||||
|
|
||||||
|
|
||||||
|
def is_scalar_field(field: Field) -> bool:
|
||||||
|
return (
|
||||||
|
field.shape == Shape.SINGLETON
|
||||||
|
and not lenient_issubclass(field.type_, BaseModel)
|
||||||
|
and not isinstance(field.schema, params.Body)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_scalar_sequence_field(field: Field) -> bool:
|
||||||
|
if field.shape in sequence_shapes and not lenient_issubclass(
|
||||||
|
field.type_, BaseModel
|
||||||
|
):
|
||||||
|
if field.sub_fields is not None:
|
||||||
|
for sub_field in field.sub_fields:
|
||||||
|
if not is_scalar_field(sub_field):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_dependant(
|
def get_dependant(
|
||||||
*, path: str, call: Callable, name: str = None, security_scopes: List[str] = None
|
*, path: str, call: Callable, name: str = None, security_scopes: List[str] = None
|
||||||
) -> Dependant:
|
) -> Dependant:
|
||||||
|
|
@ -133,83 +150,78 @@ def get_dependant(
|
||||||
endpoint_signature = inspect.signature(call)
|
endpoint_signature = inspect.signature(call)
|
||||||
signature_params = endpoint_signature.parameters
|
signature_params = endpoint_signature.parameters
|
||||||
dependant = Dependant(call=call, name=name)
|
dependant = Dependant(call=call, name=name)
|
||||||
for param_name in signature_params:
|
for param_name, param in signature_params.items():
|
||||||
param = signature_params[param_name]
|
|
||||||
if isinstance(param.default, params.Depends):
|
if isinstance(param.default, params.Depends):
|
||||||
sub_dependant = get_param_sub_dependant(
|
sub_dependant = get_param_sub_dependant(
|
||||||
param=param, path=path, security_scopes=security_scopes
|
param=param, path=path, security_scopes=security_scopes
|
||||||
)
|
)
|
||||||
dependant.dependencies.append(sub_dependant)
|
dependant.dependencies.append(sub_dependant)
|
||||||
for param_name in signature_params:
|
for param_name, param in signature_params.items():
|
||||||
param = signature_params[param_name]
|
if isinstance(param.default, params.Depends):
|
||||||
if (
|
continue
|
||||||
(param.default == param.empty) or isinstance(param.default, params.Path)
|
if add_non_field_param_to_dependency(param=param, dependant=dependant):
|
||||||
) and (param_name in path_param_names):
|
continue
|
||||||
assert (
|
param_field = get_param_field(param=param, default_schema=params.Query)
|
||||||
lenient_issubclass(param.annotation, param_supported_types)
|
if param_name in path_param_names:
|
||||||
or param.annotation == param.empty
|
assert param.default == param.empty or isinstance(
|
||||||
|
param.default, params.Path
|
||||||
|
), "Path params must have no defaults or use Path(...)"
|
||||||
|
assert is_scalar_field(
|
||||||
|
field=param_field
|
||||||
), f"Path params must be of one of the supported types"
|
), f"Path params must be of one of the supported types"
|
||||||
add_param_to_fields(
|
param_field = get_param_field(
|
||||||
param=param,
|
param=param,
|
||||||
dependant=dependant,
|
|
||||||
default_schema=params.Path,
|
default_schema=params.Path,
|
||||||
force_type=params.ParamTypes.path,
|
force_type=params.ParamTypes.path,
|
||||||
)
|
)
|
||||||
elif (
|
add_param_to_fields(field=param_field, dependant=dependant)
|
||||||
param.default == param.empty
|
elif is_scalar_field(field=param_field):
|
||||||
or param.default is None
|
add_param_to_fields(field=param_field, dependant=dependant)
|
||||||
or isinstance(param.default, param_supported_types)
|
elif isinstance(
|
||||||
) and (
|
param.default, (params.Query, params.Header)
|
||||||
param.annotation == param.empty
|
) and is_scalar_sequence_field(param_field):
|
||||||
or lenient_issubclass(param.annotation, param_supported_types)
|
add_param_to_fields(field=param_field, dependant=dependant)
|
||||||
):
|
else:
|
||||||
add_param_to_fields(
|
assert isinstance(
|
||||||
param=param, dependant=dependant, default_schema=params.Query
|
param_field.schema, params.Body
|
||||||
)
|
), f"Param: {param_field.name} can only be a request body, using Body(...)"
|
||||||
elif isinstance(param.default, params.Param):
|
dependant.body_params.append(param_field)
|
||||||
if param.annotation != param.empty:
|
|
||||||
origin = getattr(param.annotation, "__origin__", None)
|
|
||||||
param_all_types = param_supported_types + (list, tuple, set)
|
|
||||||
if isinstance(param.default, (params.Query, params.Header)):
|
|
||||||
assert lenient_issubclass(
|
|
||||||
param.annotation, param_all_types
|
|
||||||
) or lenient_issubclass(
|
|
||||||
origin, param_all_types
|
|
||||||
), f"Parameters for Query and Header must be of type str, int, float, bool, list, tuple or set: {param}"
|
|
||||||
else:
|
|
||||||
assert lenient_issubclass(
|
|
||||||
param.annotation, param_supported_types
|
|
||||||
), f"Parameters for Path and Cookies must be of type str, int, float, bool: {param}"
|
|
||||||
add_param_to_fields(
|
|
||||||
param=param, dependant=dependant, default_schema=params.Query
|
|
||||||
)
|
|
||||||
elif lenient_issubclass(param.annotation, Request):
|
|
||||||
dependant.request_param_name = param_name
|
|
||||||
elif lenient_issubclass(param.annotation, WebSocket):
|
|
||||||
dependant.websocket_param_name = param_name
|
|
||||||
elif lenient_issubclass(param.annotation, BackgroundTasks):
|
|
||||||
dependant.background_tasks_param_name = param_name
|
|
||||||
elif lenient_issubclass(param.annotation, SecurityScopes):
|
|
||||||
dependant.security_scopes_param_name = param_name
|
|
||||||
elif not isinstance(param.default, params.Depends):
|
|
||||||
add_param_to_body_fields(param=param, dependant=dependant)
|
|
||||||
return dependant
|
return dependant
|
||||||
|
|
||||||
|
|
||||||
def add_param_to_fields(
|
def add_non_field_param_to_dependency(
|
||||||
|
*, param: inspect.Parameter, dependant: Dependant
|
||||||
|
) -> Optional[bool]:
|
||||||
|
if lenient_issubclass(param.annotation, Request):
|
||||||
|
dependant.request_param_name = param.name
|
||||||
|
return True
|
||||||
|
elif lenient_issubclass(param.annotation, WebSocket):
|
||||||
|
dependant.websocket_param_name = param.name
|
||||||
|
return True
|
||||||
|
elif lenient_issubclass(param.annotation, BackgroundTasks):
|
||||||
|
dependant.background_tasks_param_name = param.name
|
||||||
|
return True
|
||||||
|
elif lenient_issubclass(param.annotation, SecurityScopes):
|
||||||
|
dependant.security_scopes_param_name = param.name
|
||||||
|
return True
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_param_field(
|
||||||
*,
|
*,
|
||||||
param: inspect.Parameter,
|
param: inspect.Parameter,
|
||||||
dependant: Dependant,
|
default_schema: Type[params.Param] = params.Param,
|
||||||
default_schema: Type[Schema] = params.Param,
|
|
||||||
force_type: params.ParamTypes = None,
|
force_type: params.ParamTypes = None,
|
||||||
) -> None:
|
) -> Field:
|
||||||
default_value = Required
|
default_value = Required
|
||||||
|
had_schema = False
|
||||||
if not param.default == param.empty:
|
if not param.default == param.empty:
|
||||||
default_value = param.default
|
default_value = param.default
|
||||||
if isinstance(default_value, params.Param):
|
if isinstance(default_value, Schema):
|
||||||
|
had_schema = True
|
||||||
schema = default_value
|
schema = default_value
|
||||||
default_value = schema.default
|
default_value = schema.default
|
||||||
if getattr(schema, "in_", None) is None:
|
if isinstance(schema, params.Param) and getattr(schema, "in_", None) is None:
|
||||||
schema.in_ = default_schema.in_
|
schema.in_ = default_schema.in_
|
||||||
if force_type:
|
if force_type:
|
||||||
schema.in_ = force_type
|
schema.in_ = force_type
|
||||||
|
|
@ -234,43 +246,26 @@ def add_param_to_fields(
|
||||||
class_validators={},
|
class_validators={},
|
||||||
schema=schema,
|
schema=schema,
|
||||||
)
|
)
|
||||||
if schema.in_ == params.ParamTypes.path:
|
if not had_schema and not is_scalar_field(field=field):
|
||||||
|
field.schema = params.Body(schema.default)
|
||||||
|
return field
|
||||||
|
|
||||||
|
|
||||||
|
def add_param_to_fields(*, field: Field, dependant: Dependant) -> None:
|
||||||
|
field.schema = cast(params.Param, field.schema)
|
||||||
|
if field.schema.in_ == params.ParamTypes.path:
|
||||||
dependant.path_params.append(field)
|
dependant.path_params.append(field)
|
||||||
elif schema.in_ == params.ParamTypes.query:
|
elif field.schema.in_ == params.ParamTypes.query:
|
||||||
dependant.query_params.append(field)
|
dependant.query_params.append(field)
|
||||||
elif schema.in_ == params.ParamTypes.header:
|
elif field.schema.in_ == params.ParamTypes.header:
|
||||||
dependant.header_params.append(field)
|
dependant.header_params.append(field)
|
||||||
else:
|
else:
|
||||||
assert (
|
assert (
|
||||||
schema.in_ == params.ParamTypes.cookie
|
field.schema.in_ == params.ParamTypes.cookie
|
||||||
), f"non-body parameters must be in path, query, header or cookie: {param.name}"
|
), f"non-body parameters must be in path, query, header or cookie: {field.name}"
|
||||||
dependant.cookie_params.append(field)
|
dependant.cookie_params.append(field)
|
||||||
|
|
||||||
|
|
||||||
def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant) -> None:
|
|
||||||
default_value = Required
|
|
||||||
if not param.default == param.empty:
|
|
||||||
default_value = param.default
|
|
||||||
if isinstance(default_value, Schema):
|
|
||||||
schema = default_value
|
|
||||||
default_value = schema.default
|
|
||||||
else:
|
|
||||||
schema = Schema(default_value)
|
|
||||||
required = default_value == Required
|
|
||||||
annotation = get_annotation_from_schema(param.annotation, schema)
|
|
||||||
field = Field(
|
|
||||||
name=param.name,
|
|
||||||
type_=annotation,
|
|
||||||
default=None if required else default_value,
|
|
||||||
alias=schema.alias or param.name,
|
|
||||||
required=required,
|
|
||||||
model_config=BaseConfig,
|
|
||||||
class_validators={},
|
|
||||||
schema=schema,
|
|
||||||
)
|
|
||||||
dependant.body_params.append(field)
|
|
||||||
|
|
||||||
|
|
||||||
def is_coroutine_callable(call: Callable) -> bool:
|
def is_coroutine_callable(call: Callable) -> bool:
|
||||||
if inspect.isfunction(call):
|
if inspect.isfunction(call):
|
||||||
return asyncio.iscoroutinefunction(call)
|
return asyncio.iscoroutinefunction(call)
|
||||||
|
|
@ -354,7 +349,7 @@ def request_params_to_args(
|
||||||
if field.shape in sequence_shapes and isinstance(
|
if field.shape in sequence_shapes and isinstance(
|
||||||
received_params, (QueryParams, Headers)
|
received_params, (QueryParams, Headers)
|
||||||
):
|
):
|
||||||
value = received_params.getlist(field.alias)
|
value = received_params.getlist(field.alias) or field.default
|
||||||
else:
|
else:
|
||||||
value = received_params.get(field.alias)
|
value = received_params.get(field.alias)
|
||||||
schema: params.Param = field.schema
|
schema: params.Param = field.schema
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, cast
|
||||||
|
|
||||||
from fastapi import routing
|
from fastapi import routing
|
||||||
from fastapi.dependencies.models import Dependant
|
from fastapi.dependencies.models import Dependant
|
||||||
|
|
@ -9,7 +9,7 @@ from fastapi.openapi.models import OpenAPI
|
||||||
from fastapi.params import Body, Param
|
from fastapi.params import Body, Param
|
||||||
from fastapi.utils import get_flat_models_from_routes, get_model_definitions
|
from fastapi.utils import get_flat_models_from_routes, get_model_definitions
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
from pydantic.schema import Schema, field_schema, get_model_name_map
|
from pydantic.schema import field_schema, get_model_name_map
|
||||||
from pydantic.utils import lenient_issubclass
|
from pydantic.utils import lenient_issubclass
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
from starlette.routing import BaseRoute
|
from starlette.routing import BaseRoute
|
||||||
|
|
@ -97,12 +97,8 @@ def get_openapi_operation_request_body(
|
||||||
body_schema, _ = field_schema(
|
body_schema, _ = field_schema(
|
||||||
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
||||||
)
|
)
|
||||||
schema: Schema = body_field.schema
|
body_field.schema = cast(Body, body_field.schema)
|
||||||
if isinstance(schema, Body):
|
request_media_type = body_field.schema.media_type
|
||||||
request_media_type = schema.media_type
|
|
||||||
else:
|
|
||||||
# Includes not declared media types (Schema)
|
|
||||||
request_media_type = "application/json"
|
|
||||||
required = body_field.required
|
required = body_field.required
|
||||||
request_body_oai: Dict[str, Any] = {}
|
request_body_oai: Dict[str, Any] = {}
|
||||||
if required:
|
if required:
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ classifiers = [
|
||||||
]
|
]
|
||||||
requires = [
|
requires = [
|
||||||
"starlette >=0.11.1,<=0.12.0",
|
"starlette >=0.11.1,<=0.12.0",
|
||||||
"pydantic >=0.17,<=0.26.0"
|
"pydantic >=0.26,<=0.26.0"
|
||||||
]
|
]
|
||||||
description-file = "README.md"
|
description-file = "README.md"
|
||||||
requires-python = ">=3.6"
|
requires-python = ">=3.6"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,29 @@
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI, Query
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_sequence():
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
class Item(BaseModel):
|
||||||
|
title: str
|
||||||
|
|
||||||
|
@app.get("/items/")
|
||||||
|
def read_items(q: List[Item] = Query(None)):
|
||||||
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_tuple():
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
class Item(BaseModel):
|
||||||
|
title: str
|
||||||
|
|
||||||
|
@app.get("/items/")
|
||||||
|
def read_items(q: Tuple[Item, Item] = Query(None)):
|
||||||
|
pass # pragma: no cover
|
||||||
|
|
@ -0,0 +1,120 @@
|
||||||
|
import pytest
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
from path_params.tutorial005 import app
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
openapi_schema = {
|
||||||
|
"openapi": "3.0.2",
|
||||||
|
"info": {"title": "Fast API", "version": "0.1.0"},
|
||||||
|
"paths": {
|
||||||
|
"/model/{model_name}": {
|
||||||
|
"get": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Successful Response",
|
||||||
|
"content": {"application/json": {"schema": {}}},
|
||||||
|
},
|
||||||
|
"422": {
|
||||||
|
"description": "Validation Error",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/HTTPValidationError"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"summary": "Get Model",
|
||||||
|
"operationId": "get_model_model__model_name__get",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"required": True,
|
||||||
|
"schema": {
|
||||||
|
"title": "Model_Name",
|
||||||
|
"enum": ["alexnet", "resnet", "lenet"],
|
||||||
|
},
|
||||||
|
"name": "model_name",
|
||||||
|
"in": "path",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"components": {
|
||||||
|
"schemas": {
|
||||||
|
"ValidationError": {
|
||||||
|
"title": "ValidationError",
|
||||||
|
"required": ["loc", "msg", "type"],
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"loc": {
|
||||||
|
"title": "Location",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
},
|
||||||
|
"msg": {"title": "Message", "type": "string"},
|
||||||
|
"type": {"title": "Error Type", "type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"HTTPValidationError": {
|
||||||
|
"title": "HTTPValidationError",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"detail": {
|
||||||
|
"title": "Detail",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"$ref": "#/components/schemas/ValidationError"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_openapi():
|
||||||
|
response = client.get("/openapi.json")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == openapi_schema
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"url,status_code,expected",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"/model/alexnet",
|
||||||
|
200,
|
||||||
|
{"model_name": "alexnet", "message": "Deep Learning FTW!"},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"/model/lenet",
|
||||||
|
200,
|
||||||
|
{"model_name": "lenet", "message": "LeCNN all the images"},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"/model/resnet",
|
||||||
|
200,
|
||||||
|
{"model_name": "resnet", "message": "Have some residuals"},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"/model/foo",
|
||||||
|
422,
|
||||||
|
{
|
||||||
|
"detail": [
|
||||||
|
{
|
||||||
|
"loc": ["path", "model_name"],
|
||||||
|
"msg": "value is not a valid enumeration member",
|
||||||
|
"type": "type_error.enum",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_get_enums(url, status_code, expected):
|
||||||
|
response = client.get(url)
|
||||||
|
assert response.status_code == status_code
|
||||||
|
assert response.json() == expected
|
||||||
|
|
@ -0,0 +1,95 @@
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
from query_params.tutorial007 import app
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
openapi_schema = {
|
||||||
|
"openapi": "3.0.2",
|
||||||
|
"info": {"title": "Fast API", "version": "0.1.0"},
|
||||||
|
"paths": {
|
||||||
|
"/items/{item_id}": {
|
||||||
|
"get": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Successful Response",
|
||||||
|
"content": {"application/json": {"schema": {}}},
|
||||||
|
},
|
||||||
|
"422": {
|
||||||
|
"description": "Validation Error",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/HTTPValidationError"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"summary": "Read User Item",
|
||||||
|
"operationId": "read_user_item_items__item_id__get",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"required": True,
|
||||||
|
"schema": {"title": "Item_Id", "type": "string"},
|
||||||
|
"name": "item_id",
|
||||||
|
"in": "path",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"required": False,
|
||||||
|
"schema": {"title": "Limit", "type": "integer"},
|
||||||
|
"name": "limit",
|
||||||
|
"in": "query",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"components": {
|
||||||
|
"schemas": {
|
||||||
|
"ValidationError": {
|
||||||
|
"title": "ValidationError",
|
||||||
|
"required": ["loc", "msg", "type"],
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"loc": {
|
||||||
|
"title": "Location",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
},
|
||||||
|
"msg": {"title": "Message", "type": "string"},
|
||||||
|
"type": {"title": "Error Type", "type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"HTTPValidationError": {
|
||||||
|
"title": "HTTPValidationError",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"detail": {
|
||||||
|
"title": "Detail",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"$ref": "#/components/schemas/ValidationError"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_openapi():
|
||||||
|
response = client.get("/openapi.json")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == openapi_schema
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_item():
|
||||||
|
response = client.get("/items/foo")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"item_id": "foo", "limit": None}
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_item_query():
|
||||||
|
response = client.get("/items/foo?limit=5")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"item_id": "foo", "limit": 5}
|
||||||
|
|
@ -0,0 +1,96 @@
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
from query_params_str_validations.tutorial012 import app
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
openapi_schema = {
|
||||||
|
"openapi": "3.0.2",
|
||||||
|
"info": {"title": "Fast API", "version": "0.1.0"},
|
||||||
|
"paths": {
|
||||||
|
"/items/": {
|
||||||
|
"get": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Successful Response",
|
||||||
|
"content": {"application/json": {"schema": {}}},
|
||||||
|
},
|
||||||
|
"422": {
|
||||||
|
"description": "Validation Error",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/HTTPValidationError"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"summary": "Read Items",
|
||||||
|
"operationId": "read_items_items__get",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"required": False,
|
||||||
|
"schema": {
|
||||||
|
"title": "Q",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"default": ["foo", "bar"],
|
||||||
|
},
|
||||||
|
"name": "q",
|
||||||
|
"in": "query",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"components": {
|
||||||
|
"schemas": {
|
||||||
|
"ValidationError": {
|
||||||
|
"title": "ValidationError",
|
||||||
|
"required": ["loc", "msg", "type"],
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"loc": {
|
||||||
|
"title": "Location",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
},
|
||||||
|
"msg": {"title": "Message", "type": "string"},
|
||||||
|
"type": {"title": "Error Type", "type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"HTTPValidationError": {
|
||||||
|
"title": "HTTPValidationError",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"detail": {
|
||||||
|
"title": "Detail",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"$ref": "#/components/schemas/ValidationError"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_openapi_schema():
|
||||||
|
response = client.get("/openapi.json")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == openapi_schema
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_query_values():
|
||||||
|
url = "/items/"
|
||||||
|
response = client.get(url)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"q": ["foo", "bar"]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_query_values():
|
||||||
|
url = "/items/?q=baz&q=foobar"
|
||||||
|
response = client.get(url)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"q": ["baz", "foobar"]}
|
||||||
Loading…
Reference in New Issue