From bba4d4c95e0de8fc9c99f88c269909d37d64d494 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 3 Dec 2025 23:29:28 -0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Fix=20support=20for=20functools?= =?UTF-8?q?=20wraps=20and=20partial=20combined,=20for=20async=20and=20regu?= =?UTF-8?q?lar=20functions=20and=20classes=20in=20path=20operations=20and?= =?UTF-8?q?=20dependencies=20(#14448)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Yurii Motov --- fastapi/dependencies/models.py | 100 ++++++-- fastapi/dependencies/utils.py | 6 +- tests/test_dependency_wrapped.py | 380 ++++++++++++++++++++++++++++++- 3 files changed, 458 insertions(+), 28 deletions(-) diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 2a4d9a010..9b545e4e5 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -15,6 +15,19 @@ else: # pragma: no cover from asyncio import iscoroutinefunction +def _unwrapped_call(call: Optional[Callable[..., Any]]) -> Any: + if call is None: + return call # pragma: no cover + unwrapped = inspect.unwrap(_impartial(call)) + return unwrapped + + +def _impartial(func: Callable[..., Any]) -> Callable[..., Any]: + while isinstance(func, partial): + func = func.func + return func + + @dataclass class SecurityRequirement: security_scheme: SecurityBase @@ -75,37 +88,82 @@ class Dependant: return True return False - @cached_property - def _unwrapped_call(self) -> Any: - if self.call is None: - return self.call # pragma: no cover - unwrapped = inspect.unwrap(self.call) - if isinstance(unwrapped, partial): - unwrapped = unwrapped.func - return unwrapped - @cached_property def is_gen_callable(self) -> bool: - if inspect.isgeneratorfunction(self._unwrapped_call): + if self.call is None: + return False # pragma: no cover + if inspect.isgeneratorfunction( + _impartial(self.call) + ) or inspect.isgeneratorfunction(_unwrapped_call(self.call)): return True - dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004 - return inspect.isgeneratorfunction(dunder_call) + dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004 + if dunder_call is None: + return False # pragma: no cover + if inspect.isgeneratorfunction( + _impartial(dunder_call) + ) or inspect.isgeneratorfunction(_unwrapped_call(dunder_call)): + return True + dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004 + if dunder_unwrapped_call is None: + return False # pragma: no cover + if inspect.isgeneratorfunction( + _impartial(dunder_unwrapped_call) + ) or inspect.isgeneratorfunction(_unwrapped_call(dunder_unwrapped_call)): + return True + return False @cached_property def is_async_gen_callable(self) -> bool: - if inspect.isasyncgenfunction(self._unwrapped_call): + if self.call is None: + return False # pragma: no cover + if inspect.isasyncgenfunction( + _impartial(self.call) + ) or inspect.isasyncgenfunction(_unwrapped_call(self.call)): return True - dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004 - return inspect.isasyncgenfunction(dunder_call) + dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004 + if dunder_call is None: + return False # pragma: no cover + if inspect.isasyncgenfunction( + _impartial(dunder_call) + ) or inspect.isasyncgenfunction(_unwrapped_call(dunder_call)): + return True + dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004 + if dunder_unwrapped_call is None: + return False # pragma: no cover + if inspect.isasyncgenfunction( + _impartial(dunder_unwrapped_call) + ) or inspect.isasyncgenfunction(_unwrapped_call(dunder_unwrapped_call)): + return True + return False @cached_property def is_coroutine_callable(self) -> bool: - if inspect.isroutine(self._unwrapped_call): - return iscoroutinefunction(self._unwrapped_call) - if inspect.isclass(self._unwrapped_call): - return False - dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004 - return iscoroutinefunction(dunder_call) + if self.call is None: + return False # pragma: no cover + if inspect.isroutine(_impartial(self.call)) and iscoroutinefunction( + _impartial(self.call) + ): + return True + if inspect.isroutine(_unwrapped_call(self.call)) and iscoroutinefunction( + _unwrapped_call(self.call) + ): + return True + dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004 + if dunder_call is None: + return False # pragma: no cover + if iscoroutinefunction(_impartial(dunder_call)) or iscoroutinefunction( + _unwrapped_call(dunder_call) + ): + return True + dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004 + if dunder_unwrapped_call is None: + return False # pragma: no cover + if iscoroutinefunction( + _impartial(dunder_unwrapped_call) + ) or iscoroutinefunction(_unwrapped_call(dunder_unwrapped_call)): + return True + # if inspect.isclass(self.call): False, covered by default return + return False @cached_property def computed_scope(self) -> Union[str, None]: diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 1a493a9fd..91348c8ea 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -548,10 +548,10 @@ async def _solve_generator( *, dependant: Dependant, stack: AsyncExitStack, sub_values: Dict[str, Any] ) -> Any: assert dependant.call - if dependant.is_gen_callable: - cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values)) - elif dependant.is_async_gen_callable: + if dependant.is_async_gen_callable: cm = asynccontextmanager(dependant.call)(**sub_values) + elif dependant.is_gen_callable: + cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values)) return await stack.enter_async_context(cm) diff --git a/tests/test_dependency_wrapped.py b/tests/test_dependency_wrapped.py index f581ccba4..08356712d 100644 --- a/tests/test_dependency_wrapped.py +++ b/tests/test_dependency_wrapped.py @@ -1,10 +1,18 @@ +import inspect +import sys from functools import wraps from typing import AsyncGenerator, Generator import pytest from fastapi import Depends, FastAPI +from fastapi.concurrency import iterate_in_threadpool, run_in_threadpool from fastapi.testclient import TestClient +if sys.version_info >= (3, 13): # pragma: no cover + from inspect import iscoroutinefunction +else: # pragma: no cover + from asyncio import iscoroutinefunction + def noop_wrap(func): @wraps(func) @@ -14,8 +22,163 @@ def noop_wrap(func): return wrapper +def noop_wrap_async(func): + if inspect.isgeneratorfunction(func): + + @wraps(func) + async def gen_wrapper(*args, **kwargs): + async for item in iterate_in_threadpool(func(*args, **kwargs)): + yield item + + return gen_wrapper + + elif inspect.isasyncgenfunction(func): + + @wraps(func) + async def async_gen_wrapper(*args, **kwargs): + async for item in func(*args, **kwargs): + yield item + + return async_gen_wrapper + + @wraps(func) + async def wrapper(*args, **kwargs): + if inspect.isroutine(func) and iscoroutinefunction(func): + return await func(*args, **kwargs) + if inspect.isclass(func): + return await run_in_threadpool(func, *args, **kwargs) + dunder_call = getattr(func, "__call__", None) # noqa: B004 + if iscoroutinefunction(dunder_call): + return await dunder_call(*args, **kwargs) + return await run_in_threadpool(func, *args, **kwargs) + + return wrapper + + +class ClassInstanceDep: + def __call__(self): + return True + + +class_instance_dep = ClassInstanceDep() +wrapped_class_instance_dep = noop_wrap(class_instance_dep) +wrapped_class_instance_dep_async_wrapper = noop_wrap_async(class_instance_dep) + + +class ClassInstanceGenDep: + def __call__(self): + yield True + + +class_instance_gen_dep = ClassInstanceGenDep() +wrapped_class_instance_gen_dep = noop_wrap(class_instance_gen_dep) + + +class ClassInstanceWrappedDep: + @noop_wrap + def __call__(self): + return True + + +class_instance_wrapped_dep = ClassInstanceWrappedDep() + + +class ClassInstanceWrappedAsyncDep: + @noop_wrap_async + def __call__(self): + return True + + +class_instance_wrapped_async_dep = ClassInstanceWrappedAsyncDep() + + +class ClassInstanceWrappedGenDep: + @noop_wrap + def __call__(self): + yield True + + +class_instance_wrapped_gen_dep = ClassInstanceWrappedGenDep() + + +class ClassInstanceWrappedAsyncGenDep: + @noop_wrap_async + def __call__(self): + yield True + + +class_instance_wrapped_async_gen_dep = ClassInstanceWrappedAsyncGenDep() + + +class ClassDep: + def __init__(self): + self.value = True + + +wrapped_class_dep = noop_wrap(ClassDep) +wrapped_class_dep_async_wrapper = noop_wrap_async(ClassDep) + + +class ClassInstanceAsyncDep: + async def __call__(self): + return True + + +class_instance_async_dep = ClassInstanceAsyncDep() +wrapped_class_instance_async_dep = noop_wrap(class_instance_async_dep) +wrapped_class_instance_async_dep_async_wrapper = noop_wrap_async( + class_instance_async_dep +) + + +class ClassInstanceAsyncGenDep: + async def __call__(self): + yield True + + +class_instance_async_gen_dep = ClassInstanceAsyncGenDep() +wrapped_class_instance_async_gen_dep = noop_wrap(class_instance_async_gen_dep) + + +class ClassInstanceAsyncWrappedDep: + @noop_wrap + async def __call__(self): + return True + + +class_instance_async_wrapped_dep = ClassInstanceAsyncWrappedDep() + + +class ClassInstanceAsyncWrappedAsyncDep: + @noop_wrap_async + async def __call__(self): + return True + + +class_instance_async_wrapped_async_dep = ClassInstanceAsyncWrappedAsyncDep() + + +class ClassInstanceAsyncWrappedGenDep: + @noop_wrap + async def __call__(self): + yield True + + +class_instance_async_wrapped_gen_dep = ClassInstanceAsyncWrappedGenDep() + + +class ClassInstanceAsyncWrappedGenAsyncDep: + @noop_wrap_async + async def __call__(self): + yield True + + +class_instance_async_wrapped_gen_async_dep = ClassInstanceAsyncWrappedGenAsyncDep() + app = FastAPI() +# Sync wrapper + @noop_wrap def wrapped_dependency() -> bool: @@ -59,16 +222,225 @@ async def get_async_wrapped_gen_dependency( return value +@app.get("/wrapped-class-instance-dependency/") +async def get_wrapped_class_instance_dependency( + value: bool = Depends(wrapped_class_instance_dep), +): + return value + + +@app.get("/wrapped-class-instance-async-dependency/") +async def get_wrapped_class_instance_async_dependency( + value: bool = Depends(wrapped_class_instance_async_dep), +): + return value + + +@app.get("/wrapped-class-instance-gen-dependency/") +async def get_wrapped_class_instance_gen_dependency( + value: bool = Depends(wrapped_class_instance_gen_dep), +): + return value + + +@app.get("/wrapped-class-instance-async-gen-dependency/") +async def get_wrapped_class_instance_async_gen_dependency( + value: bool = Depends(wrapped_class_instance_async_gen_dep), +): + return value + + +@app.get("/class-instance-wrapped-dependency/") +async def get_class_instance_wrapped_dependency( + value: bool = Depends(class_instance_wrapped_dep), +): + return value + + +@app.get("/class-instance-wrapped-async-dependency/") +async def get_class_instance_wrapped_async_dependency( + value: bool = Depends(class_instance_wrapped_async_dep), +): + return value + + +@app.get("/class-instance-async-wrapped-dependency/") +async def get_class_instance_async_wrapped_dependency( + value: bool = Depends(class_instance_async_wrapped_dep), +): + return value + + +@app.get("/class-instance-async-wrapped-async-dependency/") +async def get_class_instance_async_wrapped_async_dependency( + value: bool = Depends(class_instance_async_wrapped_async_dep), +): + return value + + +@app.get("/class-instance-wrapped-gen-dependency/") +async def get_class_instance_wrapped_gen_dependency( + value: bool = Depends(class_instance_wrapped_gen_dep), +): + return value + + +@app.get("/class-instance-wrapped-async-gen-dependency/") +async def get_class_instance_wrapped_async_gen_dependency( + value: bool = Depends(class_instance_wrapped_async_gen_dep), +): + return value + + +@app.get("/class-instance-async-wrapped-gen-dependency/") +async def get_class_instance_async_wrapped_gen_dependency( + value: bool = Depends(class_instance_async_wrapped_gen_dep), +): + return value + + +@app.get("/class-instance-async-wrapped-gen-async-dependency/") +async def get_class_instance_async_wrapped_gen_async_dependency( + value: bool = Depends(class_instance_async_wrapped_gen_async_dep), +): + return value + + +@app.get("/wrapped-class-dependency/") +async def get_wrapped_class_dependency(value: ClassDep = Depends(wrapped_class_dep)): + return value.value + + +@app.get("/wrapped-endpoint/") +@noop_wrap +def get_wrapped_endpoint(): + return True + + +@app.get("/async-wrapped-endpoint/") +@noop_wrap +async def get_async_wrapped_endpoint(): + return True + + +# Async wrapper + + +@noop_wrap_async +def wrapped_dependency_async_wrapper() -> bool: + return True + + +@noop_wrap_async +def wrapped_gen_dependency_async_wrapper() -> Generator[bool, None, None]: + yield True + + +@noop_wrap_async +async def async_wrapped_dependency_async_wrapper() -> bool: + return True + + +@noop_wrap_async +async def async_wrapped_gen_dependency_async_wrapper() -> AsyncGenerator[bool, None]: + yield True + + +@app.get("/wrapped-dependency-async-wrapper/") +async def get_wrapped_dependency_async_wrapper( + value: bool = Depends(wrapped_dependency_async_wrapper), +): + return value + + +@app.get("/wrapped-gen-dependency-async-wrapper/") +async def get_wrapped_gen_dependency_async_wrapper( + value: bool = Depends(wrapped_gen_dependency_async_wrapper), +): + return value + + +@app.get("/async-wrapped-dependency-async-wrapper/") +async def get_async_wrapped_dependency_async_wrapper( + value: bool = Depends(async_wrapped_dependency_async_wrapper), +): + return value + + +@app.get("/async-wrapped-gen-dependency-async-wrapper/") +async def get_async_wrapped_gen_dependency_async_wrapper( + value: bool = Depends(async_wrapped_gen_dependency_async_wrapper), +): + return value + + +@app.get("/wrapped-class-instance-dependency-async-wrapper/") +async def get_wrapped_class_instance_dependency_async_wrapper( + value: bool = Depends(wrapped_class_instance_dep_async_wrapper), +): + return value + + +@app.get("/wrapped-class-instance-async-dependency-async-wrapper/") +async def get_wrapped_class_instance_async_dependency_async_wrapper( + value: bool = Depends(wrapped_class_instance_async_dep_async_wrapper), +): + return value + + +@app.get("/wrapped-class-dependency-async-wrapper/") +async def get_wrapped_class_dependency_async_wrapper( + value: ClassDep = Depends(wrapped_class_dep_async_wrapper), +): + return value.value + + +@app.get("/wrapped-endpoint-async-wrapper/") +@noop_wrap_async +def get_wrapped_endpoint_async_wrapper(): + return True + + +@app.get("/async-wrapped-endpoint-async-wrapper/") +@noop_wrap_async +async def get_async_wrapped_endpoint_async_wrapper(): + return True + + client = TestClient(app) @pytest.mark.parametrize( "route", [ - "/wrapped-dependency", - "/wrapped-gen-dependency", - "/async-wrapped-dependency", - "/async-wrapped-gen-dependency", + "/wrapped-dependency/", + "/wrapped-gen-dependency/", + "/async-wrapped-dependency/", + "/async-wrapped-gen-dependency/", + "/wrapped-class-instance-dependency/", + "/wrapped-class-instance-async-dependency/", + "/wrapped-class-instance-gen-dependency/", + "/wrapped-class-instance-async-gen-dependency/", + "/class-instance-wrapped-dependency/", + "/class-instance-wrapped-async-dependency/", + "/class-instance-async-wrapped-dependency/", + "/class-instance-async-wrapped-async-dependency/", + "/class-instance-wrapped-gen-dependency/", + "/class-instance-wrapped-async-gen-dependency/", + "/class-instance-async-wrapped-gen-dependency/", + "/class-instance-async-wrapped-gen-async-dependency/", + "/wrapped-class-dependency/", + "/wrapped-endpoint/", + "/async-wrapped-endpoint/", + "/wrapped-dependency-async-wrapper/", + "/wrapped-gen-dependency-async-wrapper/", + "/async-wrapped-dependency-async-wrapper/", + "/async-wrapped-gen-dependency-async-wrapper/", + "/wrapped-class-instance-dependency-async-wrapper/", + "/wrapped-class-instance-async-dependency-async-wrapper/", + "/wrapped-class-dependency-async-wrapper/", + "/wrapped-endpoint-async-wrapper/", + "/async-wrapped-endpoint-async-wrapper/", ], ) def test_class_dependency(route):