From 83a5e5c7959b9a0fa708f668956c46316d9da525 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 4 Dec 2025 00:45:16 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=85=20Update=20tests=20to=20cover=20wrapp?= =?UTF-8?q?ers=20for=20generators=20and=20async=20generators?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_dependency_wrapped.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/tests/test_dependency_wrapped.py b/tests/test_dependency_wrapped.py index 52dd0048de..276f9460fe 100644 --- a/tests/test_dependency_wrapped.py +++ b/tests/test_dependency_wrapped.py @@ -5,7 +5,7 @@ from typing import AsyncGenerator, Generator import pytest from fastapi import Depends, FastAPI -from fastapi.concurrency import run_in_threadpool +from fastapi.concurrency import iterate_in_threadpool, run_in_threadpool from fastapi.testclient import TestClient @@ -18,10 +18,26 @@ def noop_wrap(func): def noop_wrap_async(func): + if inspect.isgeneratorfunction(func): + + @wraps(func) + async def gen_wrapper(*args, **kwargs): + async for item in iterate_in_threadpool(func(*args, **kwargs)): + yield item + + return gen_wrapper + + elif inspect.isasyncgenfunction(func): + + @wraps(func) + async def async_gen_wrapper(*args, **kwargs): + async for item in func(*args, **kwargs): + yield item + + return async_gen_wrapper + @wraps(func) async def wrapper(*args, **kwargs): - if inspect.isgeneratorfunction(func) or inspect.isasyncgenfunction(func): - return # TODO: Handle generator functions if inspect.isroutine(func) and iscoroutinefunction(func): return await func(*args, **kwargs) if inspect.isclass(func):