Allow using dependables with `functools.partial()` (#9753)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Motov Yurii <109919500+YuriiMotov@users.noreply.github.com>
Co-authored-by: Yurii Motov <yurii.motov.monte@gmail.com>
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
Lie Ryan 2025-12-03 07:58:30 +11:00 committed by GitHub
parent aee8e78078
commit 9824486616
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 256 additions and 2 deletions

View File

@ -1,7 +1,7 @@
import inspect
import sys
from dataclasses import dataclass, field
from functools import cached_property
from functools import cached_property, partial
from typing import Any, Callable, List, Optional, Sequence, Union
from fastapi._compat import ModelField
@ -79,7 +79,10 @@ class Dependant:
def _unwrapped_call(self) -> Any:
if self.call is None:
return self.call # pragma: no cover
return inspect.unwrap(self.call)
unwrapped = inspect.unwrap(self.call)
if isinstance(unwrapped, partial):
unwrapped = unwrapped.func
return unwrapped
@cached_property
def is_gen_callable(self) -> bool:

View File

@ -0,0 +1,251 @@
from functools import partial
from typing import AsyncGenerator, Generator
import pytest
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from typing_extensions import Annotated
app = FastAPI()
def function_dependency(value: str) -> str:
return value
async def async_function_dependency(value: str) -> str:
return value
def gen_dependency(value: str) -> Generator[str, None, None]:
yield value
async def async_gen_dependency(value: str) -> AsyncGenerator[str, None]:
yield value
class CallableDependency:
def __call__(self, value: str) -> str:
return value
class CallableGenDependency:
def __call__(self, value: str) -> Generator[str, None, None]:
yield value
class AsyncCallableDependency:
async def __call__(self, value: str) -> str:
return value
class AsyncCallableGenDependency:
async def __call__(self, value: str) -> AsyncGenerator[str, None]:
yield value
class MethodsDependency:
def synchronous(self, value: str) -> str:
return value
async def asynchronous(self, value: str) -> str:
return value
def synchronous_gen(self, value: str) -> Generator[str, None, None]:
yield value
async def asynchronous_gen(self, value: str) -> AsyncGenerator[str, None]:
yield value
callable_dependency = CallableDependency()
callable_gen_dependency = CallableGenDependency()
async_callable_dependency = AsyncCallableDependency()
async_callable_gen_dependency = AsyncCallableGenDependency()
methods_dependency = MethodsDependency()
@app.get("/partial-function-dependency")
async def get_partial_function_dependency(
value: Annotated[
str, Depends(partial(function_dependency, "partial-function-dependency"))
],
) -> str:
return value
@app.get("/partial-async-function-dependency")
async def get_partial_async_function_dependency(
value: Annotated[
str,
Depends(
partial(async_function_dependency, "partial-async-function-dependency")
),
],
) -> str:
return value
@app.get("/partial-gen-dependency")
async def get_partial_gen_dependency(
value: Annotated[str, Depends(partial(gen_dependency, "partial-gen-dependency"))],
) -> str:
return value
@app.get("/partial-async-gen-dependency")
async def get_partial_async_gen_dependency(
value: Annotated[
str, Depends(partial(async_gen_dependency, "partial-async-gen-dependency"))
],
) -> str:
return value
@app.get("/partial-callable-dependency")
async def get_partial_callable_dependency(
value: Annotated[
str, Depends(partial(callable_dependency, "partial-callable-dependency"))
],
) -> str:
return value
@app.get("/partial-callable-gen-dependency")
async def get_partial_callable_gen_dependency(
value: Annotated[
str,
Depends(partial(callable_gen_dependency, "partial-callable-gen-dependency")),
],
) -> str:
return value
@app.get("/partial-async-callable-dependency")
async def get_partial_async_callable_dependency(
value: Annotated[
str,
Depends(
partial(async_callable_dependency, "partial-async-callable-dependency")
),
],
) -> str:
return value
@app.get("/partial-async-callable-gen-dependency")
async def get_partial_async_callable_gen_dependency(
value: Annotated[
str,
Depends(
partial(
async_callable_gen_dependency, "partial-async-callable-gen-dependency"
)
),
],
) -> str:
return value
@app.get("/partial-synchronous-method-dependency")
async def get_partial_synchronous_method_dependency(
value: Annotated[
str,
Depends(
partial(
methods_dependency.synchronous, "partial-synchronous-method-dependency"
)
),
],
) -> str:
return value
@app.get("/partial-synchronous-method-gen-dependency")
async def get_partial_synchronous_method_gen_dependency(
value: Annotated[
str,
Depends(
partial(
methods_dependency.synchronous_gen,
"partial-synchronous-method-gen-dependency",
)
),
],
) -> str:
return value
@app.get("/partial-asynchronous-method-dependency")
async def get_partial_asynchronous_method_dependency(
value: Annotated[
str,
Depends(
partial(
methods_dependency.asynchronous,
"partial-asynchronous-method-dependency",
)
),
],
) -> str:
return value
@app.get("/partial-asynchronous-method-gen-dependency")
async def get_partial_asynchronous_method_gen_dependency(
value: Annotated[
str,
Depends(
partial(
methods_dependency.asynchronous_gen,
"partial-asynchronous-method-gen-dependency",
)
),
],
) -> str:
return value
client = TestClient(app)
@pytest.mark.parametrize(
"route,value",
[
("/partial-function-dependency", "partial-function-dependency"),
(
"/partial-async-function-dependency",
"partial-async-function-dependency",
),
("/partial-gen-dependency", "partial-gen-dependency"),
("/partial-async-gen-dependency", "partial-async-gen-dependency"),
("/partial-callable-dependency", "partial-callable-dependency"),
("/partial-callable-gen-dependency", "partial-callable-gen-dependency"),
("/partial-async-callable-dependency", "partial-async-callable-dependency"),
(
"/partial-async-callable-gen-dependency",
"partial-async-callable-gen-dependency",
),
(
"/partial-synchronous-method-dependency",
"partial-synchronous-method-dependency",
),
(
"/partial-synchronous-method-gen-dependency",
"partial-synchronous-method-gen-dependency",
),
(
"/partial-asynchronous-method-dependency",
"partial-asynchronous-method-dependency",
),
(
"/partial-asynchronous-method-gen-dependency",
"partial-asynchronous-method-gen-dependency",
),
],
)
def test_dependency_types_with_partial(route: str, value: str) -> None:
response = client.get(route)
assert response.status_code == 200, response.text
assert response.json() == value