diff --git a/tests/test_dependency_wrapped.py b/tests/test_dependency_wrapped.py index 52dd0048de..276f9460fe 100644 --- a/tests/test_dependency_wrapped.py +++ b/tests/test_dependency_wrapped.py @@ -5,7 +5,7 @@ from typing import AsyncGenerator, Generator import pytest from fastapi import Depends, FastAPI -from fastapi.concurrency import run_in_threadpool +from fastapi.concurrency import iterate_in_threadpool, run_in_threadpool from fastapi.testclient import TestClient @@ -18,10 +18,26 @@ def noop_wrap(func): def noop_wrap_async(func): + if inspect.isgeneratorfunction(func): + + @wraps(func) + async def gen_wrapper(*args, **kwargs): + async for item in iterate_in_threadpool(func(*args, **kwargs)): + yield item + + return gen_wrapper + + elif inspect.isasyncgenfunction(func): + + @wraps(func) + async def async_gen_wrapper(*args, **kwargs): + async for item in func(*args, **kwargs): + yield item + + return async_gen_wrapper + @wraps(func) async def wrapper(*args, **kwargs): - if inspect.isgeneratorfunction(func) or inspect.isasyncgenfunction(func): - return # TODO: Handle generator functions if inspect.isroutine(func) and iscoroutinefunction(func): return await func(*args, **kwargs) if inspect.isclass(func):