diff --git a/docs/en/docs/release-notes.md b/docs/en/docs/release-notes.md index c397528aa..8c80851b4 100644 --- a/docs/en/docs/release-notes.md +++ b/docs/en/docs/release-notes.md @@ -7,6 +7,12 @@ hide: ## Latest Changes +## 0.123.6 + +### Fixes + +* 🐛 Fix support for functools wraps and partial combined, for async and regular functions and classes in path operations and dependencies. PR [#14448](https://github.com/fastapi/fastapi/pull/14448) by [@tiangolo](https://github.com/tiangolo). + ## 0.123.5 ### Features diff --git a/fastapi/__init__.py b/fastapi/__init__.py index 80b2a99c1..32223231e 100644 --- a/fastapi/__init__.py +++ b/fastapi/__init__.py @@ -1,6 +1,6 @@ """FastAPI framework, high performance, easy to learn, fast to code, ready for production""" -__version__ = "0.123.5" +__version__ = "0.123.6" from starlette import status as status diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 2a4d9a010..9b545e4e5 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -15,6 +15,19 @@ else: # pragma: no cover from asyncio import iscoroutinefunction +def _unwrapped_call(call: Optional[Callable[..., Any]]) -> Any: + if call is None: + return call # pragma: no cover + unwrapped = inspect.unwrap(_impartial(call)) + return unwrapped + + +def _impartial(func: Callable[..., Any]) -> Callable[..., Any]: + while isinstance(func, partial): + func = func.func + return func + + @dataclass class SecurityRequirement: security_scheme: SecurityBase @@ -75,37 +88,82 @@ class Dependant: return True return False - @cached_property - def _unwrapped_call(self) -> Any: - if self.call is None: - return self.call # pragma: no cover - unwrapped = inspect.unwrap(self.call) - if isinstance(unwrapped, partial): - unwrapped = unwrapped.func - return unwrapped - @cached_property def is_gen_callable(self) -> bool: - if inspect.isgeneratorfunction(self._unwrapped_call): + if self.call is None: + return False # pragma: no cover + if inspect.isgeneratorfunction( + _impartial(self.call) + ) or inspect.isgeneratorfunction(_unwrapped_call(self.call)): return True - dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004 - return inspect.isgeneratorfunction(dunder_call) + dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004 + if dunder_call is None: + return False # pragma: no cover + if inspect.isgeneratorfunction( + _impartial(dunder_call) + ) or inspect.isgeneratorfunction(_unwrapped_call(dunder_call)): + return True + dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004 + if dunder_unwrapped_call is None: + return False # pragma: no cover + if inspect.isgeneratorfunction( + _impartial(dunder_unwrapped_call) + ) or inspect.isgeneratorfunction(_unwrapped_call(dunder_unwrapped_call)): + return True + return False @cached_property def is_async_gen_callable(self) -> bool: - if inspect.isasyncgenfunction(self._unwrapped_call): + if self.call is None: + return False # pragma: no cover + if inspect.isasyncgenfunction( + _impartial(self.call) + ) or inspect.isasyncgenfunction(_unwrapped_call(self.call)): return True - dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004 - return inspect.isasyncgenfunction(dunder_call) + dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004 + if dunder_call is None: + return False # pragma: no cover + if inspect.isasyncgenfunction( + _impartial(dunder_call) + ) or inspect.isasyncgenfunction(_unwrapped_call(dunder_call)): + return True + dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004 + if dunder_unwrapped_call is None: + return False # pragma: no cover + if inspect.isasyncgenfunction( + _impartial(dunder_unwrapped_call) + ) or inspect.isasyncgenfunction(_unwrapped_call(dunder_unwrapped_call)): + return True + return False @cached_property def is_coroutine_callable(self) -> bool: - if inspect.isroutine(self._unwrapped_call): - return iscoroutinefunction(self._unwrapped_call) - if inspect.isclass(self._unwrapped_call): - return False - dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004 - return iscoroutinefunction(dunder_call) + if self.call is None: + return False # pragma: no cover + if inspect.isroutine(_impartial(self.call)) and iscoroutinefunction( + _impartial(self.call) + ): + return True + if inspect.isroutine(_unwrapped_call(self.call)) and iscoroutinefunction( + _unwrapped_call(self.call) + ): + return True + dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004 + if dunder_call is None: + return False # pragma: no cover + if iscoroutinefunction(_impartial(dunder_call)) or iscoroutinefunction( + _unwrapped_call(dunder_call) + ): + return True + dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004 + if dunder_unwrapped_call is None: + return False # pragma: no cover + if iscoroutinefunction( + _impartial(dunder_unwrapped_call) + ) or iscoroutinefunction(_unwrapped_call(dunder_unwrapped_call)): + return True + # if inspect.isclass(self.call): False, covered by default return + return False @cached_property def computed_scope(self) -> Union[str, None]: diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index f574ffca0..1ff35f648 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -555,10 +555,10 @@ async def _solve_generator( *, dependant: Dependant, stack: AsyncExitStack, sub_values: Dict[str, Any] ) -> Any: assert dependant.call - if dependant.is_gen_callable: - cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values)) - elif dependant.is_async_gen_callable: + if dependant.is_async_gen_callable: cm = asynccontextmanager(dependant.call)(**sub_values) + elif dependant.is_gen_callable: + cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values)) return await stack.enter_async_context(cm) diff --git a/tests/test_dependency_wrapped.py b/tests/test_dependency_wrapped.py index f581ccba4..08356712d 100644 --- a/tests/test_dependency_wrapped.py +++ b/tests/test_dependency_wrapped.py @@ -1,10 +1,18 @@ +import inspect +import sys from functools import wraps from typing import AsyncGenerator, Generator import pytest from fastapi import Depends, FastAPI +from fastapi.concurrency import iterate_in_threadpool, run_in_threadpool from fastapi.testclient import TestClient +if sys.version_info >= (3, 13): # pragma: no cover + from inspect import iscoroutinefunction +else: # pragma: no cover + from asyncio import iscoroutinefunction + def noop_wrap(func): @wraps(func) @@ -14,8 +22,163 @@ def noop_wrap(func): return wrapper +def noop_wrap_async(func): + if inspect.isgeneratorfunction(func): + + @wraps(func) + async def gen_wrapper(*args, **kwargs): + async for item in iterate_in_threadpool(func(*args, **kwargs)): + yield item + + return gen_wrapper + + elif inspect.isasyncgenfunction(func): + + @wraps(func) + async def async_gen_wrapper(*args, **kwargs): + async for item in func(*args, **kwargs): + yield item + + return async_gen_wrapper + + @wraps(func) + async def wrapper(*args, **kwargs): + if inspect.isroutine(func) and iscoroutinefunction(func): + return await func(*args, **kwargs) + if inspect.isclass(func): + return await run_in_threadpool(func, *args, **kwargs) + dunder_call = getattr(func, "__call__", None) # noqa: B004 + if iscoroutinefunction(dunder_call): + return await dunder_call(*args, **kwargs) + return await run_in_threadpool(func, *args, **kwargs) + + return wrapper + + +class ClassInstanceDep: + def __call__(self): + return True + + +class_instance_dep = ClassInstanceDep() +wrapped_class_instance_dep = noop_wrap(class_instance_dep) +wrapped_class_instance_dep_async_wrapper = noop_wrap_async(class_instance_dep) + + +class ClassInstanceGenDep: + def __call__(self): + yield True + + +class_instance_gen_dep = ClassInstanceGenDep() +wrapped_class_instance_gen_dep = noop_wrap(class_instance_gen_dep) + + +class ClassInstanceWrappedDep: + @noop_wrap + def __call__(self): + return True + + +class_instance_wrapped_dep = ClassInstanceWrappedDep() + + +class ClassInstanceWrappedAsyncDep: + @noop_wrap_async + def __call__(self): + return True + + +class_instance_wrapped_async_dep = ClassInstanceWrappedAsyncDep() + + +class ClassInstanceWrappedGenDep: + @noop_wrap + def __call__(self): + yield True + + +class_instance_wrapped_gen_dep = ClassInstanceWrappedGenDep() + + +class ClassInstanceWrappedAsyncGenDep: + @noop_wrap_async + def __call__(self): + yield True + + +class_instance_wrapped_async_gen_dep = ClassInstanceWrappedAsyncGenDep() + + +class ClassDep: + def __init__(self): + self.value = True + + +wrapped_class_dep = noop_wrap(ClassDep) +wrapped_class_dep_async_wrapper = noop_wrap_async(ClassDep) + + +class ClassInstanceAsyncDep: + async def __call__(self): + return True + + +class_instance_async_dep = ClassInstanceAsyncDep() +wrapped_class_instance_async_dep = noop_wrap(class_instance_async_dep) +wrapped_class_instance_async_dep_async_wrapper = noop_wrap_async( + class_instance_async_dep +) + + +class ClassInstanceAsyncGenDep: + async def __call__(self): + yield True + + +class_instance_async_gen_dep = ClassInstanceAsyncGenDep() +wrapped_class_instance_async_gen_dep = noop_wrap(class_instance_async_gen_dep) + + +class ClassInstanceAsyncWrappedDep: + @noop_wrap + async def __call__(self): + return True + + +class_instance_async_wrapped_dep = ClassInstanceAsyncWrappedDep() + + +class ClassInstanceAsyncWrappedAsyncDep: + @noop_wrap_async + async def __call__(self): + return True + + +class_instance_async_wrapped_async_dep = ClassInstanceAsyncWrappedAsyncDep() + + +class ClassInstanceAsyncWrappedGenDep: + @noop_wrap + async def __call__(self): + yield True + + +class_instance_async_wrapped_gen_dep = ClassInstanceAsyncWrappedGenDep() + + +class ClassInstanceAsyncWrappedGenAsyncDep: + @noop_wrap_async + async def __call__(self): + yield True + + +class_instance_async_wrapped_gen_async_dep = ClassInstanceAsyncWrappedGenAsyncDep() + app = FastAPI() +# Sync wrapper + @noop_wrap def wrapped_dependency() -> bool: @@ -59,16 +222,225 @@ async def get_async_wrapped_gen_dependency( return value +@app.get("/wrapped-class-instance-dependency/") +async def get_wrapped_class_instance_dependency( + value: bool = Depends(wrapped_class_instance_dep), +): + return value + + +@app.get("/wrapped-class-instance-async-dependency/") +async def get_wrapped_class_instance_async_dependency( + value: bool = Depends(wrapped_class_instance_async_dep), +): + return value + + +@app.get("/wrapped-class-instance-gen-dependency/") +async def get_wrapped_class_instance_gen_dependency( + value: bool = Depends(wrapped_class_instance_gen_dep), +): + return value + + +@app.get("/wrapped-class-instance-async-gen-dependency/") +async def get_wrapped_class_instance_async_gen_dependency( + value: bool = Depends(wrapped_class_instance_async_gen_dep), +): + return value + + +@app.get("/class-instance-wrapped-dependency/") +async def get_class_instance_wrapped_dependency( + value: bool = Depends(class_instance_wrapped_dep), +): + return value + + +@app.get("/class-instance-wrapped-async-dependency/") +async def get_class_instance_wrapped_async_dependency( + value: bool = Depends(class_instance_wrapped_async_dep), +): + return value + + +@app.get("/class-instance-async-wrapped-dependency/") +async def get_class_instance_async_wrapped_dependency( + value: bool = Depends(class_instance_async_wrapped_dep), +): + return value + + +@app.get("/class-instance-async-wrapped-async-dependency/") +async def get_class_instance_async_wrapped_async_dependency( + value: bool = Depends(class_instance_async_wrapped_async_dep), +): + return value + + +@app.get("/class-instance-wrapped-gen-dependency/") +async def get_class_instance_wrapped_gen_dependency( + value: bool = Depends(class_instance_wrapped_gen_dep), +): + return value + + +@app.get("/class-instance-wrapped-async-gen-dependency/") +async def get_class_instance_wrapped_async_gen_dependency( + value: bool = Depends(class_instance_wrapped_async_gen_dep), +): + return value + + +@app.get("/class-instance-async-wrapped-gen-dependency/") +async def get_class_instance_async_wrapped_gen_dependency( + value: bool = Depends(class_instance_async_wrapped_gen_dep), +): + return value + + +@app.get("/class-instance-async-wrapped-gen-async-dependency/") +async def get_class_instance_async_wrapped_gen_async_dependency( + value: bool = Depends(class_instance_async_wrapped_gen_async_dep), +): + return value + + +@app.get("/wrapped-class-dependency/") +async def get_wrapped_class_dependency(value: ClassDep = Depends(wrapped_class_dep)): + return value.value + + +@app.get("/wrapped-endpoint/") +@noop_wrap +def get_wrapped_endpoint(): + return True + + +@app.get("/async-wrapped-endpoint/") +@noop_wrap +async def get_async_wrapped_endpoint(): + return True + + +# Async wrapper + + +@noop_wrap_async +def wrapped_dependency_async_wrapper() -> bool: + return True + + +@noop_wrap_async +def wrapped_gen_dependency_async_wrapper() -> Generator[bool, None, None]: + yield True + + +@noop_wrap_async +async def async_wrapped_dependency_async_wrapper() -> bool: + return True + + +@noop_wrap_async +async def async_wrapped_gen_dependency_async_wrapper() -> AsyncGenerator[bool, None]: + yield True + + +@app.get("/wrapped-dependency-async-wrapper/") +async def get_wrapped_dependency_async_wrapper( + value: bool = Depends(wrapped_dependency_async_wrapper), +): + return value + + +@app.get("/wrapped-gen-dependency-async-wrapper/") +async def get_wrapped_gen_dependency_async_wrapper( + value: bool = Depends(wrapped_gen_dependency_async_wrapper), +): + return value + + +@app.get("/async-wrapped-dependency-async-wrapper/") +async def get_async_wrapped_dependency_async_wrapper( + value: bool = Depends(async_wrapped_dependency_async_wrapper), +): + return value + + +@app.get("/async-wrapped-gen-dependency-async-wrapper/") +async def get_async_wrapped_gen_dependency_async_wrapper( + value: bool = Depends(async_wrapped_gen_dependency_async_wrapper), +): + return value + + +@app.get("/wrapped-class-instance-dependency-async-wrapper/") +async def get_wrapped_class_instance_dependency_async_wrapper( + value: bool = Depends(wrapped_class_instance_dep_async_wrapper), +): + return value + + +@app.get("/wrapped-class-instance-async-dependency-async-wrapper/") +async def get_wrapped_class_instance_async_dependency_async_wrapper( + value: bool = Depends(wrapped_class_instance_async_dep_async_wrapper), +): + return value + + +@app.get("/wrapped-class-dependency-async-wrapper/") +async def get_wrapped_class_dependency_async_wrapper( + value: ClassDep = Depends(wrapped_class_dep_async_wrapper), +): + return value.value + + +@app.get("/wrapped-endpoint-async-wrapper/") +@noop_wrap_async +def get_wrapped_endpoint_async_wrapper(): + return True + + +@app.get("/async-wrapped-endpoint-async-wrapper/") +@noop_wrap_async +async def get_async_wrapped_endpoint_async_wrapper(): + return True + + client = TestClient(app) @pytest.mark.parametrize( "route", [ - "/wrapped-dependency", - "/wrapped-gen-dependency", - "/async-wrapped-dependency", - "/async-wrapped-gen-dependency", + "/wrapped-dependency/", + "/wrapped-gen-dependency/", + "/async-wrapped-dependency/", + "/async-wrapped-gen-dependency/", + "/wrapped-class-instance-dependency/", + "/wrapped-class-instance-async-dependency/", + "/wrapped-class-instance-gen-dependency/", + "/wrapped-class-instance-async-gen-dependency/", + "/class-instance-wrapped-dependency/", + "/class-instance-wrapped-async-dependency/", + "/class-instance-async-wrapped-dependency/", + "/class-instance-async-wrapped-async-dependency/", + "/class-instance-wrapped-gen-dependency/", + "/class-instance-wrapped-async-gen-dependency/", + "/class-instance-async-wrapped-gen-dependency/", + "/class-instance-async-wrapped-gen-async-dependency/", + "/wrapped-class-dependency/", + "/wrapped-endpoint/", + "/async-wrapped-endpoint/", + "/wrapped-dependency-async-wrapper/", + "/wrapped-gen-dependency-async-wrapper/", + "/async-wrapped-dependency-async-wrapper/", + "/async-wrapped-gen-dependency-async-wrapper/", + "/wrapped-class-instance-dependency-async-wrapper/", + "/wrapped-class-instance-async-dependency-async-wrapper/", + "/wrapped-class-dependency-async-wrapper/", + "/wrapped-endpoint-async-wrapper/", + "/async-wrapped-endpoint-async-wrapper/", ], ) def test_class_dependency(route):