Add Service() annotation

This commit is contained in:
Maxim Martynov 2024-01-06 21:45:05 +03:00
parent a94ef3351e
commit 9a18adbc60
No known key found for this signature in database
GPG Key ID: B65B10B4DB25B7E9
8 changed files with 354 additions and 7 deletions

View File

@ -1,4 +1,4 @@
# Dependencies - `Depends()` and `Security()`
# Dependencies - `Depends()`, `Security()` and `Service()`
## `Depends()`
@ -27,3 +27,18 @@ from fastapi import Security
```
::: fastapi.Security
## `Service()`
`Depends()` is designed to include matching fields of annotated class as body/query params.
To avoid this, you can annotate class using `Service()` instead of `Depends()`.
Here is the reference for it and its parameters.
You can import it directly from `fastapi`:
```python
from fastapi import Service
```
::: fastapi.Service

View File

@ -18,6 +18,7 @@ from .param_functions import Header as Header
from .param_functions import Path as Path
from .param_functions import Query as Query
from .param_functions import Security as Security
from .param_functions import Service as Service
from .requests import Request as Request
from .responses import Response as Response
from .routing import APIRouter as APIRouter

View File

@ -32,6 +32,7 @@ class Dependant:
background_tasks_param_name: Optional[str] = None,
security_scopes_param_name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
expose: bool = True,
use_cache: bool = True,
path: Optional[str] = None,
) -> None:
@ -51,6 +52,7 @@ class Dependant:
self.security_scopes_param_name = security_scopes_param_name
self.name = name
self.call = call
self.expose = expose
self.use_cache = use_cache
# Store the path to be able to re-generate a dependable from it in overrides
self.path = path

View File

@ -134,6 +134,9 @@ def get_sub_dependant(
) -> Dependant:
security_requirement = None
security_scopes = security_scopes or []
expose = True
if isinstance(depends, params.Service):
expose = False
if isinstance(depends, params.Security):
dependency_scopes = depends.scopes
security_scopes.extend(dependency_scopes)
@ -149,6 +152,7 @@ def get_sub_dependant(
call=dependency,
name=name,
security_scopes=security_scopes,
expose=expose,
use_cache=depends.use_cache,
)
if security_requirement:
@ -244,6 +248,7 @@ def get_dependant(
call: Callable[..., Any],
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
expose: bool = True,
use_cache: bool = True,
) -> Dependant:
path_param_names = get_path_param_names(path)
@ -254,6 +259,7 @@ def get_dependant(
name=name,
path=path,
security_scopes=security_scopes,
expose=expose,
use_cache=use_cache,
)
for param_name, param in signature_params.items():
@ -283,7 +289,9 @@ def get_dependant(
), f"Cannot specify multiple FastAPI annotations for {param_name!r}"
continue
assert param_field is not None
if is_body_param(param_field=param_field, is_path_param=is_path_param):
if not expose:
continue
elif is_body_param(param_field=param_field, is_path_param=is_path_param):
dependant.body_params.append(param_field)
else:
add_param_to_fields(field=param_field, dependant=dependant)
@ -567,6 +575,7 @@ async def solve_dependencies(
call=call,
name=sub_dependant.name,
security_scopes=sub_dependant.security_scopes,
expose=sub_dependant.expose,
)
solved_result = await solve_dependencies(

View File

@ -2246,9 +2246,15 @@ def Depends( # noqa: N802
] = True,
) -> Any:
"""
Declare a FastAPI dependency.
Declare a FastAPI Field dependency.
It takes a single "dependable" callable (like a function).
Objects of annotated class are automatically created and filled up by FastAPI
dependency injection mechanism (they should be annotated with `Body()`/`Query()`/etc).
Fields are automatically exposed to OpenAPI schema.
It takes a single "dependable" callable (like a function) which is factory creating objects of
annotated class. If "dependable" is omitted, FastAPI will use class constructor.
Don't call it directly, FastAPI will call it for you.
@ -2266,6 +2272,7 @@ def Depends( # noqa: N802
async def common_parameters(q: str | None = None, skip: int = 0, limit: int = 100):
# Query params
return {"q": q, "skip": skip, "limit": limit}
@ -2298,7 +2305,7 @@ def Security( # noqa: N802
dependency.
The term "scope" comes from the OAuth2 specification, it seems to be
intentionaly vague and interpretable. It normally refers to permissions,
intentionally vague and interpretable. It normally refers to permissions,
in cases to roles.
These scopes are integrated with OpenAPI (and the API docs at `/docs`).
@ -2358,3 +2365,82 @@ def Security( # noqa: N802
```
"""
return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache)
def Service( # noqa: N802
dependency: Annotated[
Optional[Callable[..., Any]],
Doc(
"""
A "dependable" callable (like a function).
Don't call it directly, FastAPI will call it for you, just pass the object
directly.
"""
),
] = None,
*,
use_cache: Annotated[
bool,
Doc(
"""
By default, after a dependency is called the first time in a request, if
the dependency is declared again for the rest of the request (for example
if the dependency is needed by several dependencies), the value will be
re-used for the rest of the request.
Set `use_cache` to `False` to disable this behavior and ensure the
dependency is called again (if declared more than once) in the same request.
"""
),
] = True,
) -> Any:
"""
Declare a FastAPI Service dependency.
Objects of annotated class are automatically created and filled up by FastAPI
dependency injection mechanism (they should be annotated with `Depends()`/`Service()`).
Unlike `Depends()`, `Service()` does not expose fields to OpenAPI schema.
It takes a single "dependable" callable (like a function) which is factory creating objects of
annotated class. If "dependable" is omitted, FastAPI will use class constructor.
Don't call it directly, FastAPI will call it for you.
Read more about it in the
[FastAPI docs for Dependencies](https://fastapi.tiangolo.com/tutorial/dependencies/).
**Example**
```python
from typing import Annotated
from fastapi import FastAPI, Service, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from .models import Item
from .db import async_session_factory
app = FastAPI()
app.dependency_overrides[AsyncSession] = async_session_factory
class ItemsService:
def __init__(self, session: Annotated[AsyncSession, Depends()]):
self.session = session
async def get_items(self) -> list[Item]:
result = await session.scalars(select(Item))
return result.all()
@app.get("/items/")
async def read_items(items_service: Annotated[ItemsService, Service()]) -> list[Item]:
return await items_service.get_items()
```
"""
return params.Service(dependency=dependency, use_cache=use_cache)

View File

@ -760,7 +760,10 @@ class File(Form):
class Depends:
def __init__(
self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True
self,
dependency: Optional[Callable[..., Any]] = None,
*,
use_cache: bool = True,
):
self.dependency = dependency
self.use_cache = use_cache
@ -781,3 +784,7 @@ class Security(Depends):
):
super().__init__(dependency=dependency, use_cache=use_cache)
self.scopes = scopes or []
class Service(Depends):
pass

View File

@ -502,7 +502,11 @@ class APIRoute(routing.Route):
additional_status_code
), f"Status code {additional_status_code} must not have a response body"
response_name = f"Response_{additional_status_code}_{self.unique_id}"
response_field = create_response_field(name=response_name, type_=model)
response_field = create_response_field(
name=response_name,
type_=model,
mode="serialization",
)
response_fields[additional_status_code] = response_field
if response_fields:
self.response_fields: Dict[Union[int, str], ModelField] = response_fields

View File

@ -0,0 +1,223 @@
from abc import ABC, abstractmethod
from typing import List, runtime_checkable
import pytest
from fastapi import APIRouter, FastAPI, Service
from fastapi.testclient import TestClient
from pydantic import BaseModel
from typing_extensions import Protocol
app = FastAPI()
router = APIRouter()
class Item(BaseModel):
name: str
value: int
@runtime_checkable
class ItemsProtocol(Protocol):
def get_items(self) -> List[Item]:
... # pragma: nocover
class ItemsInterface(ABC):
@abstractmethod
def get_items(self) -> List[Item]:
... # pragma: nocover
class ItemsService(ItemsInterface):
def __init__(self, name: str, value: int):
self.name = name
self.value = value
def get_items(self) -> List[Item]:
return [Item(name=self.name, value=self.value)]
class ClassNestedProtocol:
def __init__(self, impl: ItemsProtocol = Service()):
self.impl = impl
class ClassNestedInterface:
def __init__(self, impl: ItemsInterface = Service()):
self.impl = impl
class ClassNested:
def __init__(self, impl: ItemsService = Service()):
self.impl = impl
@app.get("/depends-on-protocol/")
async def depends_on_protocol(service: ItemsProtocol = Service()):
return {"in": "depends-on-protocol", "items": service.get_items()}
@app.get("/depends-on-interface/")
async def depends_on_interface(service: ItemsInterface = Service()):
return {"in": "depends-on-interface", "items": service.get_items()}
@app.get("/depends-on-class/")
async def depends_on_class(service: ItemsService = Service()):
return {"in": "depends-on-class", "items": service.get_items()}
@app.get("/depends-on-nested-protocol/")
async def depends_on_nested_protocol(service: ClassNestedProtocol = Service()):
return {"in": "depends-on-nested-protocol", "items": service.impl.get_items()}
@app.get("/depends-on-nested-interface/")
async def depends_on_nested_interface(service: ClassNestedInterface = Service()):
return {"in": "depends-on-nested-interface", "items": service.impl.get_items()}
@app.get("/depends-on-nested-class/")
async def depends_on_nested_class(service: ClassNested = Service()):
return {"in": "depends-on-nested-class", "items": service.impl.get_items()}
app.include_router(router)
client = TestClient(app)
def test_depends_on_protocol_no_override():
with pytest.raises(TypeError, match="Protocols cannot be instantiated"):
# not 422 error about missing args and kwargs inputs
client.get("/depends-on-protocol/")
def test_depends_on_protocol_override():
app.dependency_overrides[ItemsProtocol] = lambda: ItemsService(name="abc", value=1)
response = client.get("/depends-on-protocol/")
assert response.status_code == 200
assert response.json() == {
"in": "depends-on-protocol",
"items": [
{
"name": "abc",
"value": 1,
},
],
}
app.dependency_overrides = {}
def test_depends_on_interface_no_override():
error_msg = "Can't instantiate abstract class ItemsInterface with.*abstract method"
with pytest.raises(TypeError, match=error_msg):
client.get("/depends-on-interface/")
def test_depends_on_interface_override():
app.dependency_overrides[ItemsInterface] = lambda: ItemsService(name="abc", value=1)
response = client.get("/depends-on-interface/")
assert response.status_code == 200
assert response.json() == {
"in": "depends-on-interface",
"items": [
{
"name": "abc",
"value": 1,
},
],
}
app.dependency_overrides = {}
def test_depends_on_class_no_override():
error_msg = "missing 2 required positional arguments: 'name' and 'value'"
with pytest.raises(TypeError, match=error_msg):
# not 422 error about missing body fields
client.get("/depends-on-class/")
def test_depends_on_class_override():
app.dependency_overrides[ItemsService] = lambda: ItemsService(name="abc", value=1)
response = client.get("/depends-on-class/")
assert response.status_code == 200
assert response.json() == {
"in": "depends-on-class",
"items": [
{
"name": "abc",
"value": 1,
},
],
}
app.dependency_overrides = {}
def test_depends_on_nested_protocol_no_override():
with pytest.raises(TypeError, match="Protocols cannot be instantiated"):
client.get("/depends-on-nested-protocol/")
def test_depends_on_nested_protocol_override_top_level():
service = ClassNestedProtocol(impl=ItemsService(name="abc", value=1))
app.dependency_overrides[ClassNestedProtocol] = lambda: service
response = client.get("/depends-on-nested-protocol/")
assert response.status_code == 200
assert response.json() == {
"in": "depends-on-nested-protocol",
"items": [
{
"name": "abc",
"value": 1,
},
],
}
app.dependency_overrides = {}
def test_depends_on_nested_interface_no_override():
error_msg = "Can't instantiate abstract class ItemsInterface with.*abstract method"
with pytest.raises(TypeError, match=error_msg):
client.get("/depends-on-nested-interface/")
def test_depends_on_nested_interface_override_top_level():
service = ClassNestedInterface(impl=ItemsService(name="abc", value=1))
app.dependency_overrides[ClassNestedInterface] = lambda: service
response = client.get("/depends-on-nested-interface/")
assert response.status_code == 200
assert response.json() == {
"in": "depends-on-nested-interface",
"items": [
{
"name": "abc",
"value": 1,
},
],
}
app.dependency_overrides = {}
def test_depends_on_nested_class_no_override():
error_msg = "missing 2 required positional arguments: 'name' and 'value'"
with pytest.raises(TypeError, match=error_msg):
client.get("/depends-on-nested-class/")
def test_depends_on_nested_class_override_top_level():
service = ClassNested(impl=ItemsService(name="abc", value=1))
app.dependency_overrides[ClassNested] = lambda: service
response = client.get("/depends-on-nested-class/")
assert response.status_code == 200
assert response.json() == {
"in": "depends-on-nested-class",
"items": [
{
"name": "abc",
"value": 1,
},
],
}
app.dependency_overrides = {}