mirror of https://github.com/tiangolo/fastapi.git
✨ Add Service() annotation
This commit is contained in:
parent
a94ef3351e
commit
9a18adbc60
|
|
@ -1,4 +1,4 @@
|
|||
# Dependencies - `Depends()` and `Security()`
|
||||
# Dependencies - `Depends()`, `Security()` and `Service()`
|
||||
|
||||
## `Depends()`
|
||||
|
||||
|
|
@ -27,3 +27,18 @@ from fastapi import Security
|
|||
```
|
||||
|
||||
::: fastapi.Security
|
||||
|
||||
## `Service()`
|
||||
|
||||
`Depends()` is designed to include matching fields of annotated class as body/query params.
|
||||
To avoid this, you can annotate class using `Service()` instead of `Depends()`.
|
||||
|
||||
Here is the reference for it and its parameters.
|
||||
|
||||
You can import it directly from `fastapi`:
|
||||
|
||||
```python
|
||||
from fastapi import Service
|
||||
```
|
||||
|
||||
::: fastapi.Service
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from .param_functions import Header as Header
|
|||
from .param_functions import Path as Path
|
||||
from .param_functions import Query as Query
|
||||
from .param_functions import Security as Security
|
||||
from .param_functions import Service as Service
|
||||
from .requests import Request as Request
|
||||
from .responses import Response as Response
|
||||
from .routing import APIRouter as APIRouter
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ class Dependant:
|
|||
background_tasks_param_name: Optional[str] = None,
|
||||
security_scopes_param_name: Optional[str] = None,
|
||||
security_scopes: Optional[List[str]] = None,
|
||||
expose: bool = True,
|
||||
use_cache: bool = True,
|
||||
path: Optional[str] = None,
|
||||
) -> None:
|
||||
|
|
@ -51,6 +52,7 @@ class Dependant:
|
|||
self.security_scopes_param_name = security_scopes_param_name
|
||||
self.name = name
|
||||
self.call = call
|
||||
self.expose = expose
|
||||
self.use_cache = use_cache
|
||||
# Store the path to be able to re-generate a dependable from it in overrides
|
||||
self.path = path
|
||||
|
|
|
|||
|
|
@ -134,6 +134,9 @@ def get_sub_dependant(
|
|||
) -> Dependant:
|
||||
security_requirement = None
|
||||
security_scopes = security_scopes or []
|
||||
expose = True
|
||||
if isinstance(depends, params.Service):
|
||||
expose = False
|
||||
if isinstance(depends, params.Security):
|
||||
dependency_scopes = depends.scopes
|
||||
security_scopes.extend(dependency_scopes)
|
||||
|
|
@ -149,6 +152,7 @@ def get_sub_dependant(
|
|||
call=dependency,
|
||||
name=name,
|
||||
security_scopes=security_scopes,
|
||||
expose=expose,
|
||||
use_cache=depends.use_cache,
|
||||
)
|
||||
if security_requirement:
|
||||
|
|
@ -244,6 +248,7 @@ def get_dependant(
|
|||
call: Callable[..., Any],
|
||||
name: Optional[str] = None,
|
||||
security_scopes: Optional[List[str]] = None,
|
||||
expose: bool = True,
|
||||
use_cache: bool = True,
|
||||
) -> Dependant:
|
||||
path_param_names = get_path_param_names(path)
|
||||
|
|
@ -254,6 +259,7 @@ def get_dependant(
|
|||
name=name,
|
||||
path=path,
|
||||
security_scopes=security_scopes,
|
||||
expose=expose,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
for param_name, param in signature_params.items():
|
||||
|
|
@ -283,7 +289,9 @@ def get_dependant(
|
|||
), f"Cannot specify multiple FastAPI annotations for {param_name!r}"
|
||||
continue
|
||||
assert param_field is not None
|
||||
if is_body_param(param_field=param_field, is_path_param=is_path_param):
|
||||
if not expose:
|
||||
continue
|
||||
elif is_body_param(param_field=param_field, is_path_param=is_path_param):
|
||||
dependant.body_params.append(param_field)
|
||||
else:
|
||||
add_param_to_fields(field=param_field, dependant=dependant)
|
||||
|
|
@ -567,6 +575,7 @@ async def solve_dependencies(
|
|||
call=call,
|
||||
name=sub_dependant.name,
|
||||
security_scopes=sub_dependant.security_scopes,
|
||||
expose=sub_dependant.expose,
|
||||
)
|
||||
|
||||
solved_result = await solve_dependencies(
|
||||
|
|
|
|||
|
|
@ -2246,9 +2246,15 @@ def Depends( # noqa: N802
|
|||
] = True,
|
||||
) -> Any:
|
||||
"""
|
||||
Declare a FastAPI dependency.
|
||||
Declare a FastAPI Field dependency.
|
||||
|
||||
It takes a single "dependable" callable (like a function).
|
||||
Objects of annotated class are automatically created and filled up by FastAPI
|
||||
dependency injection mechanism (they should be annotated with `Body()`/`Query()`/etc).
|
||||
|
||||
Fields are automatically exposed to OpenAPI schema.
|
||||
|
||||
It takes a single "dependable" callable (like a function) which is factory creating objects of
|
||||
annotated class. If "dependable" is omitted, FastAPI will use class constructor.
|
||||
|
||||
Don't call it directly, FastAPI will call it for you.
|
||||
|
||||
|
|
@ -2266,6 +2272,7 @@ def Depends( # noqa: N802
|
|||
|
||||
|
||||
async def common_parameters(q: str | None = None, skip: int = 0, limit: int = 100):
|
||||
# Query params
|
||||
return {"q": q, "skip": skip, "limit": limit}
|
||||
|
||||
|
||||
|
|
@ -2298,7 +2305,7 @@ def Security( # noqa: N802
|
|||
dependency.
|
||||
|
||||
The term "scope" comes from the OAuth2 specification, it seems to be
|
||||
intentionaly vague and interpretable. It normally refers to permissions,
|
||||
intentionally vague and interpretable. It normally refers to permissions,
|
||||
in cases to roles.
|
||||
|
||||
These scopes are integrated with OpenAPI (and the API docs at `/docs`).
|
||||
|
|
@ -2358,3 +2365,82 @@ def Security( # noqa: N802
|
|||
```
|
||||
"""
|
||||
return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache)
|
||||
|
||||
|
||||
def Service( # noqa: N802
|
||||
dependency: Annotated[
|
||||
Optional[Callable[..., Any]],
|
||||
Doc(
|
||||
"""
|
||||
A "dependable" callable (like a function).
|
||||
|
||||
Don't call it directly, FastAPI will call it for you, just pass the object
|
||||
directly.
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
*,
|
||||
use_cache: Annotated[
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, after a dependency is called the first time in a request, if
|
||||
the dependency is declared again for the rest of the request (for example
|
||||
if the dependency is needed by several dependencies), the value will be
|
||||
re-used for the rest of the request.
|
||||
|
||||
Set `use_cache` to `False` to disable this behavior and ensure the
|
||||
dependency is called again (if declared more than once) in the same request.
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
) -> Any:
|
||||
"""
|
||||
Declare a FastAPI Service dependency.
|
||||
|
||||
Objects of annotated class are automatically created and filled up by FastAPI
|
||||
dependency injection mechanism (they should be annotated with `Depends()`/`Service()`).
|
||||
|
||||
Unlike `Depends()`, `Service()` does not expose fields to OpenAPI schema.
|
||||
|
||||
It takes a single "dependable" callable (like a function) which is factory creating objects of
|
||||
annotated class. If "dependable" is omitted, FastAPI will use class constructor.
|
||||
|
||||
Don't call it directly, FastAPI will call it for you.
|
||||
|
||||
Read more about it in the
|
||||
[FastAPI docs for Dependencies](https://fastapi.tiangolo.com/tutorial/dependencies/).
|
||||
|
||||
**Example**
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import FastAPI, Service, Depends
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from .models import Item
|
||||
from .db import async_session_factory
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.dependency_overrides[AsyncSession] = async_session_factory
|
||||
|
||||
|
||||
class ItemsService:
|
||||
def __init__(self, session: Annotated[AsyncSession, Depends()]):
|
||||
self.session = session
|
||||
|
||||
async def get_items(self) -> list[Item]:
|
||||
result = await session.scalars(select(Item))
|
||||
return result.all()
|
||||
|
||||
|
||||
@app.get("/items/")
|
||||
async def read_items(items_service: Annotated[ItemsService, Service()]) -> list[Item]:
|
||||
return await items_service.get_items()
|
||||
```
|
||||
"""
|
||||
return params.Service(dependency=dependency, use_cache=use_cache)
|
||||
|
|
|
|||
|
|
@ -760,7 +760,10 @@ class File(Form):
|
|||
|
||||
class Depends:
|
||||
def __init__(
|
||||
self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True
|
||||
self,
|
||||
dependency: Optional[Callable[..., Any]] = None,
|
||||
*,
|
||||
use_cache: bool = True,
|
||||
):
|
||||
self.dependency = dependency
|
||||
self.use_cache = use_cache
|
||||
|
|
@ -781,3 +784,7 @@ class Security(Depends):
|
|||
):
|
||||
super().__init__(dependency=dependency, use_cache=use_cache)
|
||||
self.scopes = scopes or []
|
||||
|
||||
|
||||
class Service(Depends):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -502,7 +502,11 @@ class APIRoute(routing.Route):
|
|||
additional_status_code
|
||||
), f"Status code {additional_status_code} must not have a response body"
|
||||
response_name = f"Response_{additional_status_code}_{self.unique_id}"
|
||||
response_field = create_response_field(name=response_name, type_=model)
|
||||
response_field = create_response_field(
|
||||
name=response_name,
|
||||
type_=model,
|
||||
mode="serialization",
|
||||
)
|
||||
response_fields[additional_status_code] = response_field
|
||||
if response_fields:
|
||||
self.response_fields: Dict[Union[int, str], ModelField] = response_fields
|
||||
|
|
|
|||
|
|
@ -0,0 +1,223 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import List, runtime_checkable
|
||||
|
||||
import pytest
|
||||
from fastapi import APIRouter, FastAPI, Service
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Protocol
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
name: str
|
||||
value: int
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ItemsProtocol(Protocol):
|
||||
def get_items(self) -> List[Item]:
|
||||
... # pragma: nocover
|
||||
|
||||
|
||||
class ItemsInterface(ABC):
|
||||
@abstractmethod
|
||||
def get_items(self) -> List[Item]:
|
||||
... # pragma: nocover
|
||||
|
||||
|
||||
class ItemsService(ItemsInterface):
|
||||
def __init__(self, name: str, value: int):
|
||||
self.name = name
|
||||
self.value = value
|
||||
|
||||
def get_items(self) -> List[Item]:
|
||||
return [Item(name=self.name, value=self.value)]
|
||||
|
||||
|
||||
class ClassNestedProtocol:
|
||||
def __init__(self, impl: ItemsProtocol = Service()):
|
||||
self.impl = impl
|
||||
|
||||
|
||||
class ClassNestedInterface:
|
||||
def __init__(self, impl: ItemsInterface = Service()):
|
||||
self.impl = impl
|
||||
|
||||
|
||||
class ClassNested:
|
||||
def __init__(self, impl: ItemsService = Service()):
|
||||
self.impl = impl
|
||||
|
||||
|
||||
@app.get("/depends-on-protocol/")
|
||||
async def depends_on_protocol(service: ItemsProtocol = Service()):
|
||||
return {"in": "depends-on-protocol", "items": service.get_items()}
|
||||
|
||||
|
||||
@app.get("/depends-on-interface/")
|
||||
async def depends_on_interface(service: ItemsInterface = Service()):
|
||||
return {"in": "depends-on-interface", "items": service.get_items()}
|
||||
|
||||
|
||||
@app.get("/depends-on-class/")
|
||||
async def depends_on_class(service: ItemsService = Service()):
|
||||
return {"in": "depends-on-class", "items": service.get_items()}
|
||||
|
||||
|
||||
@app.get("/depends-on-nested-protocol/")
|
||||
async def depends_on_nested_protocol(service: ClassNestedProtocol = Service()):
|
||||
return {"in": "depends-on-nested-protocol", "items": service.impl.get_items()}
|
||||
|
||||
|
||||
@app.get("/depends-on-nested-interface/")
|
||||
async def depends_on_nested_interface(service: ClassNestedInterface = Service()):
|
||||
return {"in": "depends-on-nested-interface", "items": service.impl.get_items()}
|
||||
|
||||
|
||||
@app.get("/depends-on-nested-class/")
|
||||
async def depends_on_nested_class(service: ClassNested = Service()):
|
||||
return {"in": "depends-on-nested-class", "items": service.impl.get_items()}
|
||||
|
||||
|
||||
app.include_router(router)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_depends_on_protocol_no_override():
|
||||
with pytest.raises(TypeError, match="Protocols cannot be instantiated"):
|
||||
# not 422 error about missing args and kwargs inputs
|
||||
client.get("/depends-on-protocol/")
|
||||
|
||||
|
||||
def test_depends_on_protocol_override():
|
||||
app.dependency_overrides[ItemsProtocol] = lambda: ItemsService(name="abc", value=1)
|
||||
response = client.get("/depends-on-protocol/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"in": "depends-on-protocol",
|
||||
"items": [
|
||||
{
|
||||
"name": "abc",
|
||||
"value": 1,
|
||||
},
|
||||
],
|
||||
}
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_depends_on_interface_no_override():
|
||||
error_msg = "Can't instantiate abstract class ItemsInterface with.*abstract method"
|
||||
with pytest.raises(TypeError, match=error_msg):
|
||||
client.get("/depends-on-interface/")
|
||||
|
||||
|
||||
def test_depends_on_interface_override():
|
||||
app.dependency_overrides[ItemsInterface] = lambda: ItemsService(name="abc", value=1)
|
||||
response = client.get("/depends-on-interface/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"in": "depends-on-interface",
|
||||
"items": [
|
||||
{
|
||||
"name": "abc",
|
||||
"value": 1,
|
||||
},
|
||||
],
|
||||
}
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_depends_on_class_no_override():
|
||||
error_msg = "missing 2 required positional arguments: 'name' and 'value'"
|
||||
with pytest.raises(TypeError, match=error_msg):
|
||||
# not 422 error about missing body fields
|
||||
client.get("/depends-on-class/")
|
||||
|
||||
|
||||
def test_depends_on_class_override():
|
||||
app.dependency_overrides[ItemsService] = lambda: ItemsService(name="abc", value=1)
|
||||
response = client.get("/depends-on-class/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"in": "depends-on-class",
|
||||
"items": [
|
||||
{
|
||||
"name": "abc",
|
||||
"value": 1,
|
||||
},
|
||||
],
|
||||
}
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_depends_on_nested_protocol_no_override():
|
||||
with pytest.raises(TypeError, match="Protocols cannot be instantiated"):
|
||||
client.get("/depends-on-nested-protocol/")
|
||||
|
||||
|
||||
def test_depends_on_nested_protocol_override_top_level():
|
||||
service = ClassNestedProtocol(impl=ItemsService(name="abc", value=1))
|
||||
app.dependency_overrides[ClassNestedProtocol] = lambda: service
|
||||
response = client.get("/depends-on-nested-protocol/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"in": "depends-on-nested-protocol",
|
||||
"items": [
|
||||
{
|
||||
"name": "abc",
|
||||
"value": 1,
|
||||
},
|
||||
],
|
||||
}
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_depends_on_nested_interface_no_override():
|
||||
error_msg = "Can't instantiate abstract class ItemsInterface with.*abstract method"
|
||||
with pytest.raises(TypeError, match=error_msg):
|
||||
client.get("/depends-on-nested-interface/")
|
||||
|
||||
|
||||
def test_depends_on_nested_interface_override_top_level():
|
||||
service = ClassNestedInterface(impl=ItemsService(name="abc", value=1))
|
||||
app.dependency_overrides[ClassNestedInterface] = lambda: service
|
||||
response = client.get("/depends-on-nested-interface/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"in": "depends-on-nested-interface",
|
||||
"items": [
|
||||
{
|
||||
"name": "abc",
|
||||
"value": 1,
|
||||
},
|
||||
],
|
||||
}
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
def test_depends_on_nested_class_no_override():
|
||||
error_msg = "missing 2 required positional arguments: 'name' and 'value'"
|
||||
with pytest.raises(TypeError, match=error_msg):
|
||||
client.get("/depends-on-nested-class/")
|
||||
|
||||
|
||||
def test_depends_on_nested_class_override_top_level():
|
||||
service = ClassNested(impl=ItemsService(name="abc", value=1))
|
||||
app.dependency_overrides[ClassNested] = lambda: service
|
||||
response = client.get("/depends-on-nested-class/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"in": "depends-on-nested-class",
|
||||
"items": [
|
||||
{
|
||||
"name": "abc",
|
||||
"value": 1,
|
||||
},
|
||||
],
|
||||
}
|
||||
app.dependency_overrides = {}
|
||||
Loading…
Reference in New Issue