mirror of https://github.com/tiangolo/fastapi.git
🐛 Fix callable class generator dependencies (#1365)
* Fix callable class generator dependencies * workaround to support asynccontextmanager backfill for pre python3.7 Co-authored-by: Micah Rosales <mrosales@users.noreply.github.com>
This commit is contained in:
parent
a552cbdf59
commit
b90bf2da9e
|
|
@ -274,7 +274,7 @@ def get_dependant(
|
||||||
path_param_names = get_path_param_names(path)
|
path_param_names = get_path_param_names(path)
|
||||||
endpoint_signature = get_typed_signature(call)
|
endpoint_signature = get_typed_signature(call)
|
||||||
signature_params = endpoint_signature.parameters
|
signature_params = endpoint_signature.parameters
|
||||||
if inspect.isgeneratorfunction(call) or inspect.isasyncgenfunction(call):
|
if is_gen_callable(call) or is_async_gen_callable(call):
|
||||||
check_dependency_contextmanagers()
|
check_dependency_contextmanagers()
|
||||||
dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
|
dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
|
||||||
for param_name, param in signature_params.items():
|
for param_name, param in signature_params.items():
|
||||||
|
|
@ -412,19 +412,41 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
|
||||||
|
|
||||||
def is_coroutine_callable(call: Callable) -> bool:
|
def is_coroutine_callable(call: Callable) -> bool:
|
||||||
if inspect.isroutine(call):
|
if inspect.isroutine(call):
|
||||||
return asyncio.iscoroutinefunction(call)
|
return inspect.iscoroutinefunction(call)
|
||||||
if inspect.isclass(call):
|
if inspect.isclass(call):
|
||||||
return False
|
return False
|
||||||
call = getattr(call, "__call__", None)
|
call = getattr(call, "__call__", None)
|
||||||
return asyncio.iscoroutinefunction(call)
|
return inspect.iscoroutinefunction(call)
|
||||||
|
|
||||||
|
|
||||||
|
def is_async_gen_callable(call: Callable) -> bool:
|
||||||
|
if inspect.isasyncgenfunction(call):
|
||||||
|
return True
|
||||||
|
call = getattr(call, "__call__", None)
|
||||||
|
return inspect.isasyncgenfunction(call)
|
||||||
|
|
||||||
|
|
||||||
|
def is_gen_callable(call: Callable) -> bool:
|
||||||
|
if inspect.isgeneratorfunction(call):
|
||||||
|
return True
|
||||||
|
call = getattr(call, "__call__", None)
|
||||||
|
return inspect.isgeneratorfunction(call)
|
||||||
|
|
||||||
|
|
||||||
async def solve_generator(
|
async def solve_generator(
|
||||||
*, call: Callable, stack: AsyncExitStack, sub_values: Dict[str, Any]
|
*, call: Callable, stack: AsyncExitStack, sub_values: Dict[str, Any]
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if inspect.isgeneratorfunction(call):
|
if is_gen_callable(call):
|
||||||
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
|
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
|
||||||
elif inspect.isasyncgenfunction(call):
|
elif is_async_gen_callable(call):
|
||||||
|
if not inspect.isasyncgenfunction(call):
|
||||||
|
# asynccontextmanager from the async_generator backfill pre python3.7
|
||||||
|
# does not support callables that are not functions or methods.
|
||||||
|
# See https://github.com/python-trio/async_generator/issues/32
|
||||||
|
#
|
||||||
|
# Expand the callable class into its __call__ method before decorating it.
|
||||||
|
# This approach will work on newer python versions as well.
|
||||||
|
call = getattr(call, "__call__", None)
|
||||||
cm = asynccontextmanager(call)(**sub_values)
|
cm = asynccontextmanager(call)(**sub_values)
|
||||||
return await stack.enter_async_context(cm)
|
return await stack.enter_async_context(cm)
|
||||||
|
|
||||||
|
|
@ -505,7 +527,7 @@ async def solve_dependencies(
|
||||||
continue
|
continue
|
||||||
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
|
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
|
||||||
solved = dependency_cache[sub_dependant.cache_key]
|
solved = dependency_cache[sub_dependant.cache_key]
|
||||||
elif inspect.isgeneratorfunction(call) or inspect.isasyncgenfunction(call):
|
elif is_gen_callable(call) or is_async_gen_callable(call):
|
||||||
stack = request.scope.get("fastapi_astack")
|
stack = request.scope.get("fastapi_astack")
|
||||||
if stack is None:
|
if stack is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
from typing import AsyncGenerator, Generator
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import Depends, FastAPI
|
from fastapi import Depends, FastAPI
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
@ -10,11 +12,21 @@ class CallableDependency:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class CallableGenDependency:
|
||||||
|
def __call__(self, value: str) -> Generator[str, None, None]:
|
||||||
|
yield value
|
||||||
|
|
||||||
|
|
||||||
class AsyncCallableDependency:
|
class AsyncCallableDependency:
|
||||||
async def __call__(self, value: str) -> str:
|
async def __call__(self, value: str) -> str:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncCallableGenDependency:
|
||||||
|
async def __call__(self, value: str) -> AsyncGenerator[str, None]:
|
||||||
|
yield value
|
||||||
|
|
||||||
|
|
||||||
class MethodsDependency:
|
class MethodsDependency:
|
||||||
def synchronous(self, value: str) -> str:
|
def synchronous(self, value: str) -> str:
|
||||||
return value
|
return value
|
||||||
|
|
@ -22,9 +34,17 @@ class MethodsDependency:
|
||||||
async def asynchronous(self, value: str) -> str:
|
async def asynchronous(self, value: str) -> str:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
def synchronous_gen(self, value: str) -> Generator[str, None, None]:
|
||||||
|
yield value
|
||||||
|
|
||||||
|
async def asynchronous_gen(self, value: str) -> AsyncGenerator[str, None]:
|
||||||
|
yield value
|
||||||
|
|
||||||
|
|
||||||
callable_dependency = CallableDependency()
|
callable_dependency = CallableDependency()
|
||||||
|
callable_gen_dependency = CallableGenDependency()
|
||||||
async_callable_dependency = AsyncCallableDependency()
|
async_callable_dependency = AsyncCallableDependency()
|
||||||
|
async_callable_gen_dependency = AsyncCallableGenDependency()
|
||||||
methods_dependency = MethodsDependency()
|
methods_dependency = MethodsDependency()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -33,11 +53,23 @@ async def get_callable_dependency(value: str = Depends(callable_dependency)):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/callable-gen-dependency")
|
||||||
|
async def get_callable_gen_dependency(value: str = Depends(callable_gen_dependency)):
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
@app.get("/async-callable-dependency")
|
@app.get("/async-callable-dependency")
|
||||||
async def get_callable_dependency(value: str = Depends(async_callable_dependency)):
|
async def get_callable_dependency(value: str = Depends(async_callable_dependency)):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/async-callable-gen-dependency")
|
||||||
|
async def get_callable_gen_dependency(
|
||||||
|
value: str = Depends(async_callable_gen_dependency),
|
||||||
|
):
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
@app.get("/synchronous-method-dependency")
|
@app.get("/synchronous-method-dependency")
|
||||||
async def get_synchronous_method_dependency(
|
async def get_synchronous_method_dependency(
|
||||||
value: str = Depends(methods_dependency.synchronous),
|
value: str = Depends(methods_dependency.synchronous),
|
||||||
|
|
@ -45,6 +77,13 @@ async def get_synchronous_method_dependency(
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/synchronous-method-gen-dependency")
|
||||||
|
async def get_synchronous_method_gen_dependency(
|
||||||
|
value: str = Depends(methods_dependency.synchronous_gen),
|
||||||
|
):
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
@app.get("/asynchronous-method-dependency")
|
@app.get("/asynchronous-method-dependency")
|
||||||
async def get_asynchronous_method_dependency(
|
async def get_asynchronous_method_dependency(
|
||||||
value: str = Depends(methods_dependency.asynchronous),
|
value: str = Depends(methods_dependency.asynchronous),
|
||||||
|
|
@ -52,6 +91,13 @@ async def get_asynchronous_method_dependency(
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/asynchronous-method-gen-dependency")
|
||||||
|
async def get_asynchronous_method_gen_dependency(
|
||||||
|
value: str = Depends(methods_dependency.asynchronous_gen),
|
||||||
|
):
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -59,9 +105,13 @@ client = TestClient(app)
|
||||||
"route,value",
|
"route,value",
|
||||||
[
|
[
|
||||||
("/callable-dependency", "callable-dependency"),
|
("/callable-dependency", "callable-dependency"),
|
||||||
|
("/callable-gen-dependency", "callable-gen-dependency"),
|
||||||
("/async-callable-dependency", "async-callable-dependency"),
|
("/async-callable-dependency", "async-callable-dependency"),
|
||||||
|
("/async-callable-gen-dependency", "async-callable-gen-dependency"),
|
||||||
("/synchronous-method-dependency", "synchronous-method-dependency"),
|
("/synchronous-method-dependency", "synchronous-method-dependency"),
|
||||||
|
("/synchronous-method-gen-dependency", "synchronous-method-gen-dependency"),
|
||||||
("/asynchronous-method-dependency", "asynchronous-method-dependency"),
|
("/asynchronous-method-dependency", "asynchronous-method-dependency"),
|
||||||
|
("/asynchronous-method-gen-dependency", "asynchronous-method-gen-dependency"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_class_dependency(route, value):
|
def test_class_dependency(route, value):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue