Add Service() annotation

This commit is contained in:
Maxim Martynov 2024-01-06 21:45:05 +03:00
parent a94ef3351e
commit 9a18adbc60
No known key found for this signature in database
GPG Key ID: B65B10B4DB25B7E9
8 changed files with 354 additions and 7 deletions

View File

@ -1,4 +1,4 @@
# Dependencies - `Depends()` and `Security()` # Dependencies - `Depends()`, `Security()` and `Service()`
## `Depends()` ## `Depends()`
@ -27,3 +27,18 @@ from fastapi import Security
``` ```
::: fastapi.Security ::: fastapi.Security
## `Service()`
`Depends()` is designed to include matching fields of annotated class as body/query params.
To avoid this, you can annotate class using `Service()` instead of `Depends()`.
Here is the reference for it and its parameters.
You can import it directly from `fastapi`:
```python
from fastapi import Service
```
::: fastapi.Service

View File

@ -18,6 +18,7 @@ from .param_functions import Header as Header
from .param_functions import Path as Path from .param_functions import Path as Path
from .param_functions import Query as Query from .param_functions import Query as Query
from .param_functions import Security as Security from .param_functions import Security as Security
from .param_functions import Service as Service
from .requests import Request as Request from .requests import Request as Request
from .responses import Response as Response from .responses import Response as Response
from .routing import APIRouter as APIRouter from .routing import APIRouter as APIRouter

View File

@ -32,6 +32,7 @@ class Dependant:
background_tasks_param_name: Optional[str] = None, background_tasks_param_name: Optional[str] = None,
security_scopes_param_name: Optional[str] = None, security_scopes_param_name: Optional[str] = None,
security_scopes: Optional[List[str]] = None, security_scopes: Optional[List[str]] = None,
expose: bool = True,
use_cache: bool = True, use_cache: bool = True,
path: Optional[str] = None, path: Optional[str] = None,
) -> None: ) -> None:
@ -51,6 +52,7 @@ class Dependant:
self.security_scopes_param_name = security_scopes_param_name self.security_scopes_param_name = security_scopes_param_name
self.name = name self.name = name
self.call = call self.call = call
self.expose = expose
self.use_cache = use_cache self.use_cache = use_cache
# Store the path to be able to re-generate a dependable from it in overrides # Store the path to be able to re-generate a dependable from it in overrides
self.path = path self.path = path

View File

@ -134,6 +134,9 @@ def get_sub_dependant(
) -> Dependant: ) -> Dependant:
security_requirement = None security_requirement = None
security_scopes = security_scopes or [] security_scopes = security_scopes or []
expose = True
if isinstance(depends, params.Service):
expose = False
if isinstance(depends, params.Security): if isinstance(depends, params.Security):
dependency_scopes = depends.scopes dependency_scopes = depends.scopes
security_scopes.extend(dependency_scopes) security_scopes.extend(dependency_scopes)
@ -149,6 +152,7 @@ def get_sub_dependant(
call=dependency, call=dependency,
name=name, name=name,
security_scopes=security_scopes, security_scopes=security_scopes,
expose=expose,
use_cache=depends.use_cache, use_cache=depends.use_cache,
) )
if security_requirement: if security_requirement:
@ -244,6 +248,7 @@ def get_dependant(
call: Callable[..., Any], call: Callable[..., Any],
name: Optional[str] = None, name: Optional[str] = None,
security_scopes: Optional[List[str]] = None, security_scopes: Optional[List[str]] = None,
expose: bool = True,
use_cache: bool = True, use_cache: bool = True,
) -> Dependant: ) -> Dependant:
path_param_names = get_path_param_names(path) path_param_names = get_path_param_names(path)
@ -254,6 +259,7 @@ def get_dependant(
name=name, name=name,
path=path, path=path,
security_scopes=security_scopes, security_scopes=security_scopes,
expose=expose,
use_cache=use_cache, use_cache=use_cache,
) )
for param_name, param in signature_params.items(): for param_name, param in signature_params.items():
@ -283,7 +289,9 @@ def get_dependant(
), f"Cannot specify multiple FastAPI annotations for {param_name!r}" ), f"Cannot specify multiple FastAPI annotations for {param_name!r}"
continue continue
assert param_field is not None assert param_field is not None
if is_body_param(param_field=param_field, is_path_param=is_path_param): if not expose:
continue
elif is_body_param(param_field=param_field, is_path_param=is_path_param):
dependant.body_params.append(param_field) dependant.body_params.append(param_field)
else: else:
add_param_to_fields(field=param_field, dependant=dependant) add_param_to_fields(field=param_field, dependant=dependant)
@ -567,6 +575,7 @@ async def solve_dependencies(
call=call, call=call,
name=sub_dependant.name, name=sub_dependant.name,
security_scopes=sub_dependant.security_scopes, security_scopes=sub_dependant.security_scopes,
expose=sub_dependant.expose,
) )
solved_result = await solve_dependencies( solved_result = await solve_dependencies(

View File

@ -2246,9 +2246,15 @@ def Depends( # noqa: N802
] = True, ] = True,
) -> Any: ) -> Any:
""" """
Declare a FastAPI dependency. Declare a FastAPI Field dependency.
It takes a single "dependable" callable (like a function). Objects of annotated class are automatically created and filled up by FastAPI
dependency injection mechanism (they should be annotated with `Body()`/`Query()`/etc).
Fields are automatically exposed to OpenAPI schema.
It takes a single "dependable" callable (like a function) which is factory creating objects of
annotated class. If "dependable" is omitted, FastAPI will use class constructor.
Don't call it directly, FastAPI will call it for you. Don't call it directly, FastAPI will call it for you.
@ -2266,6 +2272,7 @@ def Depends( # noqa: N802
async def common_parameters(q: str | None = None, skip: int = 0, limit: int = 100): async def common_parameters(q: str | None = None, skip: int = 0, limit: int = 100):
# Query params
return {"q": q, "skip": skip, "limit": limit} return {"q": q, "skip": skip, "limit": limit}
@ -2298,7 +2305,7 @@ def Security( # noqa: N802
dependency. dependency.
The term "scope" comes from the OAuth2 specification, it seems to be The term "scope" comes from the OAuth2 specification, it seems to be
intentionaly vague and interpretable. It normally refers to permissions, intentionally vague and interpretable. It normally refers to permissions,
in cases to roles. in cases to roles.
These scopes are integrated with OpenAPI (and the API docs at `/docs`). These scopes are integrated with OpenAPI (and the API docs at `/docs`).
@ -2358,3 +2365,82 @@ def Security( # noqa: N802
``` ```
""" """
return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache) return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache)
def Service( # noqa: N802
dependency: Annotated[
Optional[Callable[..., Any]],
Doc(
"""
A "dependable" callable (like a function).
Don't call it directly, FastAPI will call it for you, just pass the object
directly.
"""
),
] = None,
*,
use_cache: Annotated[
bool,
Doc(
"""
By default, after a dependency is called the first time in a request, if
the dependency is declared again for the rest of the request (for example
if the dependency is needed by several dependencies), the value will be
re-used for the rest of the request.
Set `use_cache` to `False` to disable this behavior and ensure the
dependency is called again (if declared more than once) in the same request.
"""
),
] = True,
) -> Any:
"""
Declare a FastAPI Service dependency.
Objects of annotated class are automatically created and filled up by FastAPI
dependency injection mechanism (they should be annotated with `Depends()`/`Service()`).
Unlike `Depends()`, `Service()` does not expose fields to OpenAPI schema.
It takes a single "dependable" callable (like a function) which is factory creating objects of
annotated class. If "dependable" is omitted, FastAPI will use class constructor.
Don't call it directly, FastAPI will call it for you.
Read more about it in the
[FastAPI docs for Dependencies](https://fastapi.tiangolo.com/tutorial/dependencies/).
**Example**
```python
from typing import Annotated
from fastapi import FastAPI, Service, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from .models import Item
from .db import async_session_factory
app = FastAPI()
app.dependency_overrides[AsyncSession] = async_session_factory
class ItemsService:
def __init__(self, session: Annotated[AsyncSession, Depends()]):
self.session = session
async def get_items(self) -> list[Item]:
result = await session.scalars(select(Item))
return result.all()
@app.get("/items/")
async def read_items(items_service: Annotated[ItemsService, Service()]) -> list[Item]:
return await items_service.get_items()
```
"""
return params.Service(dependency=dependency, use_cache=use_cache)

View File

@ -760,7 +760,10 @@ class File(Form):
class Depends: class Depends:
def __init__( def __init__(
self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True self,
dependency: Optional[Callable[..., Any]] = None,
*,
use_cache: bool = True,
): ):
self.dependency = dependency self.dependency = dependency
self.use_cache = use_cache self.use_cache = use_cache
@ -781,3 +784,7 @@ class Security(Depends):
): ):
super().__init__(dependency=dependency, use_cache=use_cache) super().__init__(dependency=dependency, use_cache=use_cache)
self.scopes = scopes or [] self.scopes = scopes or []
class Service(Depends):
pass

View File

@ -502,7 +502,11 @@ class APIRoute(routing.Route):
additional_status_code additional_status_code
), f"Status code {additional_status_code} must not have a response body" ), f"Status code {additional_status_code} must not have a response body"
response_name = f"Response_{additional_status_code}_{self.unique_id}" response_name = f"Response_{additional_status_code}_{self.unique_id}"
response_field = create_response_field(name=response_name, type_=model) response_field = create_response_field(
name=response_name,
type_=model,
mode="serialization",
)
response_fields[additional_status_code] = response_field response_fields[additional_status_code] = response_field
if response_fields: if response_fields:
self.response_fields: Dict[Union[int, str], ModelField] = response_fields self.response_fields: Dict[Union[int, str], ModelField] = response_fields

View File

@ -0,0 +1,223 @@
from abc import ABC, abstractmethod
from typing import List, runtime_checkable
import pytest
from fastapi import APIRouter, FastAPI, Service
from fastapi.testclient import TestClient
from pydantic import BaseModel
from typing_extensions import Protocol
app = FastAPI()
router = APIRouter()
class Item(BaseModel):
name: str
value: int
@runtime_checkable
class ItemsProtocol(Protocol):
def get_items(self) -> List[Item]:
... # pragma: nocover
class ItemsInterface(ABC):
@abstractmethod
def get_items(self) -> List[Item]:
... # pragma: nocover
class ItemsService(ItemsInterface):
def __init__(self, name: str, value: int):
self.name = name
self.value = value
def get_items(self) -> List[Item]:
return [Item(name=self.name, value=self.value)]
class ClassNestedProtocol:
def __init__(self, impl: ItemsProtocol = Service()):
self.impl = impl
class ClassNestedInterface:
def __init__(self, impl: ItemsInterface = Service()):
self.impl = impl
class ClassNested:
def __init__(self, impl: ItemsService = Service()):
self.impl = impl
@app.get("/depends-on-protocol/")
async def depends_on_protocol(service: ItemsProtocol = Service()):
return {"in": "depends-on-protocol", "items": service.get_items()}
@app.get("/depends-on-interface/")
async def depends_on_interface(service: ItemsInterface = Service()):
return {"in": "depends-on-interface", "items": service.get_items()}
@app.get("/depends-on-class/")
async def depends_on_class(service: ItemsService = Service()):
return {"in": "depends-on-class", "items": service.get_items()}
@app.get("/depends-on-nested-protocol/")
async def depends_on_nested_protocol(service: ClassNestedProtocol = Service()):
return {"in": "depends-on-nested-protocol", "items": service.impl.get_items()}
@app.get("/depends-on-nested-interface/")
async def depends_on_nested_interface(service: ClassNestedInterface = Service()):
return {"in": "depends-on-nested-interface", "items": service.impl.get_items()}
@app.get("/depends-on-nested-class/")
async def depends_on_nested_class(service: ClassNested = Service()):
return {"in": "depends-on-nested-class", "items": service.impl.get_items()}
app.include_router(router)
client = TestClient(app)
def test_depends_on_protocol_no_override():
with pytest.raises(TypeError, match="Protocols cannot be instantiated"):
# not 422 error about missing args and kwargs inputs
client.get("/depends-on-protocol/")
def test_depends_on_protocol_override():
app.dependency_overrides[ItemsProtocol] = lambda: ItemsService(name="abc", value=1)
response = client.get("/depends-on-protocol/")
assert response.status_code == 200
assert response.json() == {
"in": "depends-on-protocol",
"items": [
{
"name": "abc",
"value": 1,
},
],
}
app.dependency_overrides = {}
def test_depends_on_interface_no_override():
error_msg = "Can't instantiate abstract class ItemsInterface with.*abstract method"
with pytest.raises(TypeError, match=error_msg):
client.get("/depends-on-interface/")
def test_depends_on_interface_override():
app.dependency_overrides[ItemsInterface] = lambda: ItemsService(name="abc", value=1)
response = client.get("/depends-on-interface/")
assert response.status_code == 200
assert response.json() == {
"in": "depends-on-interface",
"items": [
{
"name": "abc",
"value": 1,
},
],
}
app.dependency_overrides = {}
def test_depends_on_class_no_override():
error_msg = "missing 2 required positional arguments: 'name' and 'value'"
with pytest.raises(TypeError, match=error_msg):
# not 422 error about missing body fields
client.get("/depends-on-class/")
def test_depends_on_class_override():
app.dependency_overrides[ItemsService] = lambda: ItemsService(name="abc", value=1)
response = client.get("/depends-on-class/")
assert response.status_code == 200
assert response.json() == {
"in": "depends-on-class",
"items": [
{
"name": "abc",
"value": 1,
},
],
}
app.dependency_overrides = {}
def test_depends_on_nested_protocol_no_override():
with pytest.raises(TypeError, match="Protocols cannot be instantiated"):
client.get("/depends-on-nested-protocol/")
def test_depends_on_nested_protocol_override_top_level():
service = ClassNestedProtocol(impl=ItemsService(name="abc", value=1))
app.dependency_overrides[ClassNestedProtocol] = lambda: service
response = client.get("/depends-on-nested-protocol/")
assert response.status_code == 200
assert response.json() == {
"in": "depends-on-nested-protocol",
"items": [
{
"name": "abc",
"value": 1,
},
],
}
app.dependency_overrides = {}
def test_depends_on_nested_interface_no_override():
error_msg = "Can't instantiate abstract class ItemsInterface with.*abstract method"
with pytest.raises(TypeError, match=error_msg):
client.get("/depends-on-nested-interface/")
def test_depends_on_nested_interface_override_top_level():
service = ClassNestedInterface(impl=ItemsService(name="abc", value=1))
app.dependency_overrides[ClassNestedInterface] = lambda: service
response = client.get("/depends-on-nested-interface/")
assert response.status_code == 200
assert response.json() == {
"in": "depends-on-nested-interface",
"items": [
{
"name": "abc",
"value": 1,
},
],
}
app.dependency_overrides = {}
def test_depends_on_nested_class_no_override():
error_msg = "missing 2 required positional arguments: 'name' and 'value'"
with pytest.raises(TypeError, match=error_msg):
client.get("/depends-on-nested-class/")
def test_depends_on_nested_class_override_top_level():
service = ClassNested(impl=ItemsService(name="abc", value=1))
app.dependency_overrides[ClassNested] = lambda: service
response = client.get("/depends-on-nested-class/")
assert response.status_code == 200
assert response.json() == {
"in": "depends-on-nested-class",
"items": [
{
"name": "abc",
"value": 1,
},
],
}
app.dependency_overrides = {}