From 73c411e1b92a1659fef76655562e6d2f28be064e Mon Sep 17 00:00:00 2001 From: Matthew Martin Date: Tue, 2 Dec 2025 07:34:19 -0600 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Handle=20wrapped=20dependencies=20(?= =?UTF-8?q?#9555)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Motov Yurii <109919500+YuriiMotov@users.noreply.github.com> Co-authored-by: Yurii Motov Co-authored-by: Sebastián Ramírez --- fastapi/dependencies/models.py | 22 +++++---- tests/test_dependency_wrapped.py | 77 ++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 8 deletions(-) create mode 100644 tests/test_dependency_wrapped.py diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index fbb666a7d..13486dd18 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -75,27 +75,33 @@ class Dependant: return True return False + @cached_property + def _unwrapped_call(self) -> Any: + if self.call is None: + return self.call # pragma: no cover + return inspect.unwrap(self.call) + @cached_property def is_gen_callable(self) -> bool: - if inspect.isgeneratorfunction(self.call): + if inspect.isgeneratorfunction(self._unwrapped_call): return True - dunder_call = getattr(self.call, "__call__", None) # noqa: B004 + dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004 return inspect.isgeneratorfunction(dunder_call) @cached_property def is_async_gen_callable(self) -> bool: - if inspect.isasyncgenfunction(self.call): + if inspect.isasyncgenfunction(self._unwrapped_call): return True - dunder_call = getattr(self.call, "__call__", None) # noqa: B004 + dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004 return inspect.isasyncgenfunction(dunder_call) @cached_property def is_coroutine_callable(self) -> bool: - if inspect.isroutine(self.call): - return iscoroutinefunction(self.call) - if inspect.isclass(self.call): + if inspect.isroutine(self._unwrapped_call): + return iscoroutinefunction(self._unwrapped_call) + if inspect.isclass(self._unwrapped_call): return False - dunder_call = getattr(self.call, "__call__", None) # noqa: B004 + dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004 return iscoroutinefunction(dunder_call) @cached_property diff --git a/tests/test_dependency_wrapped.py b/tests/test_dependency_wrapped.py new file mode 100644 index 000000000..f581ccba4 --- /dev/null +++ b/tests/test_dependency_wrapped.py @@ -0,0 +1,77 @@ +from functools import wraps +from typing import AsyncGenerator, Generator + +import pytest +from fastapi import Depends, FastAPI +from fastapi.testclient import TestClient + + +def noop_wrap(func): + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +app = FastAPI() + + +@noop_wrap +def wrapped_dependency() -> bool: + return True + + +@noop_wrap +def wrapped_gen_dependency() -> Generator[bool, None, None]: + yield True + + +@noop_wrap +async def async_wrapped_dependency() -> bool: + return True + + +@noop_wrap +async def async_wrapped_gen_dependency() -> AsyncGenerator[bool, None]: + yield True + + +@app.get("/wrapped-dependency/") +async def get_wrapped_dependency(value: bool = Depends(wrapped_dependency)): + return value + + +@app.get("/wrapped-gen-dependency/") +async def get_wrapped_gen_dependency(value: bool = Depends(wrapped_gen_dependency)): + return value + + +@app.get("/async-wrapped-dependency/") +async def get_async_wrapped_dependency(value: bool = Depends(async_wrapped_dependency)): + return value + + +@app.get("/async-wrapped-gen-dependency/") +async def get_async_wrapped_gen_dependency( + value: bool = Depends(async_wrapped_gen_dependency), +): + return value + + +client = TestClient(app) + + +@pytest.mark.parametrize( + "route", + [ + "/wrapped-dependency", + "/wrapped-gen-dependency", + "/async-wrapped-dependency", + "/async-wrapped-gen-dependency", + ], +) +def test_class_dependency(route): + response = client.get(route) + assert response.status_code == 200, response.text + assert response.json() is True