mirror of https://github.com/tiangolo/fastapi.git
Merge branch 'master' into fix-duplicate-special-dependency-handling
This commit is contained in:
commit
fff0a93ecd
|
|
@ -7,6 +7,32 @@ hide:
|
|||
|
||||
## Latest Changes
|
||||
|
||||
## 0.123.9
|
||||
|
||||
### Fixes
|
||||
|
||||
* 🐛 Fix OAuth2 scopes in OpenAPI in extra corner cases, parent dependency with scopes, sub-dependency security scheme without scopes. PR [#14459](https://github.com/fastapi/fastapi/pull/14459) by [@tiangolo](https://github.com/tiangolo).
|
||||
|
||||
## 0.123.8
|
||||
|
||||
### Fixes
|
||||
|
||||
* 🐛 Fix OpenAPI security scheme OAuth2 scopes declaration, deduplicate security schemes with different scopes. PR [#14455](https://github.com/fastapi/fastapi/pull/14455) by [@tiangolo](https://github.com/tiangolo).
|
||||
|
||||
## 0.123.7
|
||||
|
||||
### Fixes
|
||||
|
||||
* 🐛 Fix evaluating stringified annotations in Python 3.10. PR [#11355](https://github.com/fastapi/fastapi/pull/11355) by [@chaen](https://github.com/chaen).
|
||||
|
||||
## 0.123.6
|
||||
|
||||
### Fixes
|
||||
|
||||
* 🐛 Fix support for functools wraps and partial combined, for async and regular functions and classes in path operations and dependencies. PR [#14448](https://github.com/fastapi/fastapi/pull/14448) by [@tiangolo](https://github.com/tiangolo).
|
||||
|
||||
## 0.123.5
|
||||
|
||||
### Features
|
||||
|
||||
* ✨ Allow using dependables with `functools.partial()`. PR [#9753](https://github.com/fastapi/fastapi/pull/9753) by [@lieryan](https://github.com/lieryan).
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""FastAPI framework, high performance, easy to learn, fast to code, ready for production"""
|
||||
|
||||
__version__ = "0.123.4"
|
||||
__version__ = "0.123.9"
|
||||
|
||||
from starlette import status as status
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import inspect
|
|||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property, partial
|
||||
from typing import Any, Callable, List, Optional, Sequence, Set, Union
|
||||
from typing import Any, Callable, List, Optional, Set, Union
|
||||
|
||||
from fastapi._compat import ModelField
|
||||
from fastapi.security.base import SecurityBase
|
||||
|
|
@ -15,10 +15,17 @@ else: # pragma: no cover
|
|||
from asyncio import iscoroutinefunction
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityRequirement:
|
||||
security_scheme: SecurityBase
|
||||
scopes: Optional[Sequence[str]] = None
|
||||
def _unwrapped_call(call: Optional[Callable[..., Any]]) -> Any:
|
||||
if call is None:
|
||||
return call # pragma: no cover
|
||||
unwrapped = inspect.unwrap(_impartial(call))
|
||||
return unwrapped
|
||||
|
||||
|
||||
def _impartial(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
while isinstance(func, partial):
|
||||
func = func.func
|
||||
return func
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -29,7 +36,6 @@ class Dependant:
|
|||
cookie_params: List[ModelField] = field(default_factory=list)
|
||||
body_params: List[ModelField] = field(default_factory=list)
|
||||
dependencies: List["Dependant"] = field(default_factory=list)
|
||||
security_requirements: List[SecurityRequirement] = field(default_factory=list)
|
||||
name: Optional[str] = None
|
||||
call: Optional[Callable[..., Any]] = None
|
||||
request_param_names: Set[str] = field(default_factory=set)
|
||||
|
|
@ -70,42 +76,108 @@ class Dependant:
|
|||
return True
|
||||
if self.security_scopes_param_names:
|
||||
return True
|
||||
if self._is_security_scheme:
|
||||
return True
|
||||
for sub_dep in self.dependencies:
|
||||
if sub_dep._uses_scopes:
|
||||
return True
|
||||
return False
|
||||
|
||||
@cached_property
|
||||
def _unwrapped_call(self) -> Any:
|
||||
def _is_security_scheme(self) -> bool:
|
||||
if self.call is None:
|
||||
return self.call # pragma: no cover
|
||||
unwrapped = inspect.unwrap(self.call)
|
||||
if isinstance(unwrapped, partial):
|
||||
unwrapped = unwrapped.func
|
||||
return False # pragma: no cover
|
||||
unwrapped = _unwrapped_call(self.call)
|
||||
return isinstance(unwrapped, SecurityBase)
|
||||
|
||||
# Mainly to get the type of SecurityBase, but it's the same self.call
|
||||
@cached_property
|
||||
def _security_scheme(self) -> SecurityBase:
|
||||
unwrapped = _unwrapped_call(self.call)
|
||||
assert isinstance(unwrapped, SecurityBase)
|
||||
return unwrapped
|
||||
|
||||
@cached_property
|
||||
def _security_dependencies(self) -> List["Dependant"]:
|
||||
security_deps = [dep for dep in self.dependencies if dep._is_security_scheme]
|
||||
return security_deps
|
||||
|
||||
@cached_property
|
||||
def is_gen_callable(self) -> bool:
|
||||
if inspect.isgeneratorfunction(self._unwrapped_call):
|
||||
if self.call is None:
|
||||
return False # pragma: no cover
|
||||
if inspect.isgeneratorfunction(
|
||||
_impartial(self.call)
|
||||
) or inspect.isgeneratorfunction(_unwrapped_call(self.call)):
|
||||
return True
|
||||
dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004
|
||||
return inspect.isgeneratorfunction(dunder_call)
|
||||
dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004
|
||||
if dunder_call is None:
|
||||
return False # pragma: no cover
|
||||
if inspect.isgeneratorfunction(
|
||||
_impartial(dunder_call)
|
||||
) or inspect.isgeneratorfunction(_unwrapped_call(dunder_call)):
|
||||
return True
|
||||
dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004
|
||||
if dunder_unwrapped_call is None:
|
||||
return False # pragma: no cover
|
||||
if inspect.isgeneratorfunction(
|
||||
_impartial(dunder_unwrapped_call)
|
||||
) or inspect.isgeneratorfunction(_unwrapped_call(dunder_unwrapped_call)):
|
||||
return True
|
||||
return False
|
||||
|
||||
@cached_property
|
||||
def is_async_gen_callable(self) -> bool:
|
||||
if inspect.isasyncgenfunction(self._unwrapped_call):
|
||||
if self.call is None:
|
||||
return False # pragma: no cover
|
||||
if inspect.isasyncgenfunction(
|
||||
_impartial(self.call)
|
||||
) or inspect.isasyncgenfunction(_unwrapped_call(self.call)):
|
||||
return True
|
||||
dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004
|
||||
return inspect.isasyncgenfunction(dunder_call)
|
||||
dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004
|
||||
if dunder_call is None:
|
||||
return False # pragma: no cover
|
||||
if inspect.isasyncgenfunction(
|
||||
_impartial(dunder_call)
|
||||
) or inspect.isasyncgenfunction(_unwrapped_call(dunder_call)):
|
||||
return True
|
||||
dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004
|
||||
if dunder_unwrapped_call is None:
|
||||
return False # pragma: no cover
|
||||
if inspect.isasyncgenfunction(
|
||||
_impartial(dunder_unwrapped_call)
|
||||
) or inspect.isasyncgenfunction(_unwrapped_call(dunder_unwrapped_call)):
|
||||
return True
|
||||
return False
|
||||
|
||||
@cached_property
|
||||
def is_coroutine_callable(self) -> bool:
|
||||
if inspect.isroutine(self._unwrapped_call):
|
||||
return iscoroutinefunction(self._unwrapped_call)
|
||||
if inspect.isclass(self._unwrapped_call):
|
||||
return False
|
||||
dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004
|
||||
return iscoroutinefunction(dunder_call)
|
||||
if self.call is None:
|
||||
return False # pragma: no cover
|
||||
if inspect.isroutine(_impartial(self.call)) and iscoroutinefunction(
|
||||
_impartial(self.call)
|
||||
):
|
||||
return True
|
||||
if inspect.isroutine(_unwrapped_call(self.call)) and iscoroutinefunction(
|
||||
_unwrapped_call(self.call)
|
||||
):
|
||||
return True
|
||||
dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004
|
||||
if dunder_call is None:
|
||||
return False # pragma: no cover
|
||||
if iscoroutinefunction(_impartial(dunder_call)) or iscoroutinefunction(
|
||||
_unwrapped_call(dunder_call)
|
||||
):
|
||||
return True
|
||||
dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004
|
||||
if dunder_unwrapped_call is None:
|
||||
return False # pragma: no cover
|
||||
if iscoroutinefunction(
|
||||
_impartial(dunder_unwrapped_call)
|
||||
) or iscoroutinefunction(_unwrapped_call(dunder_unwrapped_call)):
|
||||
return True
|
||||
# if inspect.isclass(self.call): False, covered by default return
|
||||
return False
|
||||
|
||||
@cached_property
|
||||
def computed_scope(self) -> Union[str, None]:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import dataclasses
|
||||
import inspect
|
||||
import sys
|
||||
from contextlib import AsyncExitStack, contextmanager
|
||||
from copy import copy, deepcopy
|
||||
from dataclasses import dataclass
|
||||
|
|
@ -54,10 +55,9 @@ from fastapi.concurrency import (
|
|||
asynccontextmanager,
|
||||
contextmanager_in_threadpool,
|
||||
)
|
||||
from fastapi.dependencies.models import Dependant, SecurityRequirement
|
||||
from fastapi.dependencies.models import Dependant
|
||||
from fastapi.exceptions import DependencyScopeError
|
||||
from fastapi.logger import logger
|
||||
from fastapi.security.base import SecurityBase
|
||||
from fastapi.security.oauth2 import SecurityScopes
|
||||
from fastapi.types import DependencyCacheKey
|
||||
from fastapi.utils import create_model_field, get_path_param_names
|
||||
|
|
@ -141,10 +141,14 @@ def get_flat_dependant(
|
|||
*,
|
||||
skip_repeats: bool = False,
|
||||
visited: Optional[List[DependencyCacheKey]] = None,
|
||||
parent_oauth_scopes: Optional[List[str]] = None,
|
||||
) -> Dependant:
|
||||
if visited is None:
|
||||
visited = []
|
||||
visited.append(dependant.cache_key)
|
||||
use_parent_oauth_scopes = (parent_oauth_scopes or []) + (
|
||||
dependant.oauth_scopes or []
|
||||
)
|
||||
|
||||
flat_dependant = Dependant(
|
||||
path_params=dependant.path_params.copy(),
|
||||
|
|
@ -152,22 +156,37 @@ def get_flat_dependant(
|
|||
header_params=dependant.header_params.copy(),
|
||||
cookie_params=dependant.cookie_params.copy(),
|
||||
body_params=dependant.body_params.copy(),
|
||||
security_requirements=dependant.security_requirements.copy(),
|
||||
name=dependant.name,
|
||||
call=dependant.call,
|
||||
request_param_name=dependant.request_param_name,
|
||||
websocket_param_name=dependant.websocket_param_name,
|
||||
http_connection_param_name=dependant.http_connection_param_name,
|
||||
response_param_name=dependant.response_param_name,
|
||||
background_tasks_param_name=dependant.background_tasks_param_name,
|
||||
security_scopes_param_name=dependant.security_scopes_param_name,
|
||||
own_oauth_scopes=dependant.own_oauth_scopes,
|
||||
parent_oauth_scopes=use_parent_oauth_scopes,
|
||||
use_cache=dependant.use_cache,
|
||||
path=dependant.path,
|
||||
scope=dependant.scope,
|
||||
)
|
||||
for sub_dependant in dependant.dependencies:
|
||||
if skip_repeats and sub_dependant.cache_key in visited:
|
||||
continue
|
||||
flat_sub = get_flat_dependant(
|
||||
sub_dependant, skip_repeats=skip_repeats, visited=visited
|
||||
sub_dependant,
|
||||
skip_repeats=skip_repeats,
|
||||
visited=visited,
|
||||
parent_oauth_scopes=flat_dependant.oauth_scopes,
|
||||
)
|
||||
flat_dependant.dependencies.append(flat_sub)
|
||||
flat_dependant.path_params.extend(flat_sub.path_params)
|
||||
flat_dependant.query_params.extend(flat_sub.query_params)
|
||||
flat_dependant.header_params.extend(flat_sub.header_params)
|
||||
flat_dependant.cookie_params.extend(flat_sub.cookie_params)
|
||||
flat_dependant.body_params.extend(flat_sub.body_params)
|
||||
flat_dependant.security_requirements.extend(flat_sub.security_requirements)
|
||||
flat_dependant.dependencies.extend(flat_sub.dependencies)
|
||||
|
||||
return flat_dependant
|
||||
|
||||
|
||||
|
|
@ -191,7 +210,10 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]:
|
|||
|
||||
|
||||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
signature = inspect.signature(call)
|
||||
if sys.version_info >= (3, 10):
|
||||
signature = inspect.signature(call, eval_str=True)
|
||||
else:
|
||||
signature = inspect.signature(call)
|
||||
unwrapped = inspect.unwrap(call)
|
||||
globalns = getattr(unwrapped, "__globals__", {})
|
||||
typed_params = [
|
||||
|
|
@ -217,7 +239,10 @@ def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
|
|||
|
||||
|
||||
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
|
||||
signature = inspect.signature(call)
|
||||
if sys.version_info >= (3, 10):
|
||||
signature = inspect.signature(call, eval_str=True)
|
||||
else:
|
||||
signature = inspect.signature(call)
|
||||
unwrapped = inspect.unwrap(call)
|
||||
annotation = signature.return_annotation
|
||||
|
||||
|
|
@ -251,11 +276,6 @@ def get_dependant(
|
|||
path_param_names = get_path_param_names(path)
|
||||
endpoint_signature = get_typed_signature(call)
|
||||
signature_params = endpoint_signature.parameters
|
||||
if isinstance(call, SecurityBase):
|
||||
security_requirement = SecurityRequirement(
|
||||
security_scheme=call, scopes=current_scopes
|
||||
)
|
||||
dependant.security_requirements.append(security_requirement)
|
||||
for param_name, param in signature_params.items():
|
||||
is_path_param = param_name in path_param_names
|
||||
param_details = analyze_param(
|
||||
|
|
@ -548,10 +568,10 @@ async def _solve_generator(
|
|||
*, dependant: Dependant, stack: AsyncExitStack, sub_values: Dict[str, Any]
|
||||
) -> Any:
|
||||
assert dependant.call
|
||||
if dependant.is_gen_callable:
|
||||
cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values))
|
||||
elif dependant.is_async_gen_callable:
|
||||
if dependant.is_async_gen_callable:
|
||||
cm = asynccontextmanager(dependant.call)(**sub_values)
|
||||
elif dependant.is_gen_callable:
|
||||
cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values))
|
||||
return await stack.enter_async_context(cm)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -79,16 +79,25 @@ def get_openapi_security_definitions(
|
|||
flat_dependant: Dependant,
|
||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
security_definitions = {}
|
||||
operation_security = []
|
||||
for security_requirement in flat_dependant.security_requirements:
|
||||
# Use a dict to merge scopes for same security scheme
|
||||
operation_security_dict: Dict[str, List[str]] = {}
|
||||
for security_dependency in flat_dependant._security_dependencies:
|
||||
security_definition = jsonable_encoder(
|
||||
security_requirement.security_scheme.model,
|
||||
security_dependency._security_scheme.model,
|
||||
by_alias=True,
|
||||
exclude_none=True,
|
||||
)
|
||||
security_name = security_requirement.security_scheme.scheme_name
|
||||
security_name = security_dependency._security_scheme.scheme_name
|
||||
security_definitions[security_name] = security_definition
|
||||
operation_security.append({security_name: security_requirement.scopes})
|
||||
# Merge scopes for the same security scheme
|
||||
if security_name not in operation_security_dict:
|
||||
operation_security_dict[security_name] = []
|
||||
for scope in security_dependency.oauth_scopes or []:
|
||||
if scope not in operation_security_dict[security_name]:
|
||||
operation_security_dict[security_name].append(scope)
|
||||
operation_security = [
|
||||
{name: scopes} for name, scopes in operation_security_dict.items()
|
||||
]
|
||||
return security_definitions, operation_security
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,18 @@
|
|||
import inspect
|
||||
import sys
|
||||
from functools import wraps
|
||||
from typing import AsyncGenerator, Generator
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.concurrency import iterate_in_threadpool, run_in_threadpool
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
if sys.version_info >= (3, 13): # pragma: no cover
|
||||
from inspect import iscoroutinefunction
|
||||
else: # pragma: no cover
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
|
||||
def noop_wrap(func):
|
||||
@wraps(func)
|
||||
|
|
@ -14,8 +22,163 @@ def noop_wrap(func):
|
|||
return wrapper
|
||||
|
||||
|
||||
def noop_wrap_async(func):
|
||||
if inspect.isgeneratorfunction(func):
|
||||
|
||||
@wraps(func)
|
||||
async def gen_wrapper(*args, **kwargs):
|
||||
async for item in iterate_in_threadpool(func(*args, **kwargs)):
|
||||
yield item
|
||||
|
||||
return gen_wrapper
|
||||
|
||||
elif inspect.isasyncgenfunction(func):
|
||||
|
||||
@wraps(func)
|
||||
async def async_gen_wrapper(*args, **kwargs):
|
||||
async for item in func(*args, **kwargs):
|
||||
yield item
|
||||
|
||||
return async_gen_wrapper
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
if inspect.isroutine(func) and iscoroutinefunction(func):
|
||||
return await func(*args, **kwargs)
|
||||
if inspect.isclass(func):
|
||||
return await run_in_threadpool(func, *args, **kwargs)
|
||||
dunder_call = getattr(func, "__call__", None) # noqa: B004
|
||||
if iscoroutinefunction(dunder_call):
|
||||
return await dunder_call(*args, **kwargs)
|
||||
return await run_in_threadpool(func, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ClassInstanceDep:
|
||||
def __call__(self):
|
||||
return True
|
||||
|
||||
|
||||
class_instance_dep = ClassInstanceDep()
|
||||
wrapped_class_instance_dep = noop_wrap(class_instance_dep)
|
||||
wrapped_class_instance_dep_async_wrapper = noop_wrap_async(class_instance_dep)
|
||||
|
||||
|
||||
class ClassInstanceGenDep:
|
||||
def __call__(self):
|
||||
yield True
|
||||
|
||||
|
||||
class_instance_gen_dep = ClassInstanceGenDep()
|
||||
wrapped_class_instance_gen_dep = noop_wrap(class_instance_gen_dep)
|
||||
|
||||
|
||||
class ClassInstanceWrappedDep:
|
||||
@noop_wrap
|
||||
def __call__(self):
|
||||
return True
|
||||
|
||||
|
||||
class_instance_wrapped_dep = ClassInstanceWrappedDep()
|
||||
|
||||
|
||||
class ClassInstanceWrappedAsyncDep:
|
||||
@noop_wrap_async
|
||||
def __call__(self):
|
||||
return True
|
||||
|
||||
|
||||
class_instance_wrapped_async_dep = ClassInstanceWrappedAsyncDep()
|
||||
|
||||
|
||||
class ClassInstanceWrappedGenDep:
|
||||
@noop_wrap
|
||||
def __call__(self):
|
||||
yield True
|
||||
|
||||
|
||||
class_instance_wrapped_gen_dep = ClassInstanceWrappedGenDep()
|
||||
|
||||
|
||||
class ClassInstanceWrappedAsyncGenDep:
|
||||
@noop_wrap_async
|
||||
def __call__(self):
|
||||
yield True
|
||||
|
||||
|
||||
class_instance_wrapped_async_gen_dep = ClassInstanceWrappedAsyncGenDep()
|
||||
|
||||
|
||||
class ClassDep:
|
||||
def __init__(self):
|
||||
self.value = True
|
||||
|
||||
|
||||
wrapped_class_dep = noop_wrap(ClassDep)
|
||||
wrapped_class_dep_async_wrapper = noop_wrap_async(ClassDep)
|
||||
|
||||
|
||||
class ClassInstanceAsyncDep:
|
||||
async def __call__(self):
|
||||
return True
|
||||
|
||||
|
||||
class_instance_async_dep = ClassInstanceAsyncDep()
|
||||
wrapped_class_instance_async_dep = noop_wrap(class_instance_async_dep)
|
||||
wrapped_class_instance_async_dep_async_wrapper = noop_wrap_async(
|
||||
class_instance_async_dep
|
||||
)
|
||||
|
||||
|
||||
class ClassInstanceAsyncGenDep:
|
||||
async def __call__(self):
|
||||
yield True
|
||||
|
||||
|
||||
class_instance_async_gen_dep = ClassInstanceAsyncGenDep()
|
||||
wrapped_class_instance_async_gen_dep = noop_wrap(class_instance_async_gen_dep)
|
||||
|
||||
|
||||
class ClassInstanceAsyncWrappedDep:
|
||||
@noop_wrap
|
||||
async def __call__(self):
|
||||
return True
|
||||
|
||||
|
||||
class_instance_async_wrapped_dep = ClassInstanceAsyncWrappedDep()
|
||||
|
||||
|
||||
class ClassInstanceAsyncWrappedAsyncDep:
|
||||
@noop_wrap_async
|
||||
async def __call__(self):
|
||||
return True
|
||||
|
||||
|
||||
class_instance_async_wrapped_async_dep = ClassInstanceAsyncWrappedAsyncDep()
|
||||
|
||||
|
||||
class ClassInstanceAsyncWrappedGenDep:
|
||||
@noop_wrap
|
||||
async def __call__(self):
|
||||
yield True
|
||||
|
||||
|
||||
class_instance_async_wrapped_gen_dep = ClassInstanceAsyncWrappedGenDep()
|
||||
|
||||
|
||||
class ClassInstanceAsyncWrappedGenAsyncDep:
|
||||
@noop_wrap_async
|
||||
async def __call__(self):
|
||||
yield True
|
||||
|
||||
|
||||
class_instance_async_wrapped_gen_async_dep = ClassInstanceAsyncWrappedGenAsyncDep()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Sync wrapper
|
||||
|
||||
|
||||
@noop_wrap
|
||||
def wrapped_dependency() -> bool:
|
||||
|
|
@ -59,16 +222,225 @@ async def get_async_wrapped_gen_dependency(
|
|||
return value
|
||||
|
||||
|
||||
@app.get("/wrapped-class-instance-dependency/")
|
||||
async def get_wrapped_class_instance_dependency(
|
||||
value: bool = Depends(wrapped_class_instance_dep),
|
||||
):
|
||||
return value
|
||||
|
||||
|
||||
@app.get("/wrapped-class-instance-async-dependency/")
|
||||
async def get_wrapped_class_instance_async_dependency(
|
||||
value: bool = Depends(wrapped_class_instance_async_dep),
|
||||
):
|
||||
return value
|
||||
|
||||
|
||||
@app.get("/wrapped-class-instance-gen-dependency/")
|
||||
async def get_wrapped_class_instance_gen_dependency(
|
||||
value: bool = Depends(wrapped_class_instance_gen_dep),
|
||||
):
|
||||
return value
|
||||
|
||||
|
||||
@app.get("/wrapped-class-instance-async-gen-dependency/")
|
||||
async def get_wrapped_class_instance_async_gen_dependency(
|
||||
value: bool = Depends(wrapped_class_instance_async_gen_dep),
|
||||
):
|
||||
return value
|
||||
|
||||
|
||||
@app.get("/class-instance-wrapped-dependency/")
|
||||
async def get_class_instance_wrapped_dependency(
|
||||
value: bool = Depends(class_instance_wrapped_dep),
|
||||
):
|
||||
return value
|
||||
|
||||
|
||||
@app.get("/class-instance-wrapped-async-dependency/")
|
||||
async def get_class_instance_wrapped_async_dependency(
|
||||
value: bool = Depends(class_instance_wrapped_async_dep),
|
||||
):
|
||||
return value
|
||||
|
||||
|
||||
@app.get("/class-instance-async-wrapped-dependency/")
|
||||
async def get_class_instance_async_wrapped_dependency(
|
||||
value: bool = Depends(class_instance_async_wrapped_dep),
|
||||
):
|
||||
return value
|
||||
|
||||
|
||||
@app.get("/class-instance-async-wrapped-async-dependency/")
|
||||
async def get_class_instance_async_wrapped_async_dependency(
|
||||
value: bool = Depends(class_instance_async_wrapped_async_dep),
|
||||
):
|
||||
return value
|
||||
|
||||
|
||||
@app.get("/class-instance-wrapped-gen-dependency/")
|
||||
async def get_class_instance_wrapped_gen_dependency(
|
||||
value: bool = Depends(class_instance_wrapped_gen_dep),
|
||||
):
|
||||
return value
|
||||
|
||||
|
||||
@app.get("/class-instance-wrapped-async-gen-dependency/")
|
||||
async def get_class_instance_wrapped_async_gen_dependency(
|
||||
value: bool = Depends(class_instance_wrapped_async_gen_dep),
|
||||
):
|
||||
return value
|
||||
|
||||
|
||||
@app.get("/class-instance-async-wrapped-gen-dependency/")
|
||||
async def get_class_instance_async_wrapped_gen_dependency(
|
||||
value: bool = Depends(class_instance_async_wrapped_gen_dep),
|
||||
):
|
||||
return value
|
||||
|
||||
|
||||
@app.get("/class-instance-async-wrapped-gen-async-dependency/")
|
||||
async def get_class_instance_async_wrapped_gen_async_dependency(
|
||||
value: bool = Depends(class_instance_async_wrapped_gen_async_dep),
|
||||
):
|
||||
return value
|
||||
|
||||
|
||||
@app.get("/wrapped-class-dependency/")
|
||||
async def get_wrapped_class_dependency(value: ClassDep = Depends(wrapped_class_dep)):
|
||||
return value.value
|
||||
|
||||
|
||||
@app.get("/wrapped-endpoint/")
|
||||
@noop_wrap
|
||||
def get_wrapped_endpoint():
|
||||
return True
|
||||
|
||||
|
||||
@app.get("/async-wrapped-endpoint/")
|
||||
@noop_wrap
|
||||
async def get_async_wrapped_endpoint():
|
||||
return True
|
||||
|
||||
|
||||
# Async wrapper
|
||||
|
||||
|
||||
@noop_wrap_async
|
||||
def wrapped_dependency_async_wrapper() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@noop_wrap_async
|
||||
def wrapped_gen_dependency_async_wrapper() -> Generator[bool, None, None]:
|
||||
yield True
|
||||
|
||||
|
||||
@noop_wrap_async
|
||||
async def async_wrapped_dependency_async_wrapper() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@noop_wrap_async
|
||||
async def async_wrapped_gen_dependency_async_wrapper() -> AsyncGenerator[bool, None]:
|
||||
yield True
|
||||
|
||||
|
||||
@app.get("/wrapped-dependency-async-wrapper/")
|
||||
async def get_wrapped_dependency_async_wrapper(
|
||||
value: bool = Depends(wrapped_dependency_async_wrapper),
|
||||
):
|
||||
return value
|
||||
|
||||
|
||||
@app.get("/wrapped-gen-dependency-async-wrapper/")
|
||||
async def get_wrapped_gen_dependency_async_wrapper(
|
||||
value: bool = Depends(wrapped_gen_dependency_async_wrapper),
|
||||
):
|
||||
return value
|
||||
|
||||
|
||||
@app.get("/async-wrapped-dependency-async-wrapper/")
|
||||
async def get_async_wrapped_dependency_async_wrapper(
|
||||
value: bool = Depends(async_wrapped_dependency_async_wrapper),
|
||||
):
|
||||
return value
|
||||
|
||||
|
||||
@app.get("/async-wrapped-gen-dependency-async-wrapper/")
|
||||
async def get_async_wrapped_gen_dependency_async_wrapper(
|
||||
value: bool = Depends(async_wrapped_gen_dependency_async_wrapper),
|
||||
):
|
||||
return value
|
||||
|
||||
|
||||
@app.get("/wrapped-class-instance-dependency-async-wrapper/")
|
||||
async def get_wrapped_class_instance_dependency_async_wrapper(
|
||||
value: bool = Depends(wrapped_class_instance_dep_async_wrapper),
|
||||
):
|
||||
return value
|
||||
|
||||
|
||||
@app.get("/wrapped-class-instance-async-dependency-async-wrapper/")
|
||||
async def get_wrapped_class_instance_async_dependency_async_wrapper(
|
||||
value: bool = Depends(wrapped_class_instance_async_dep_async_wrapper),
|
||||
):
|
||||
return value
|
||||
|
||||
|
||||
@app.get("/wrapped-class-dependency-async-wrapper/")
|
||||
async def get_wrapped_class_dependency_async_wrapper(
|
||||
value: ClassDep = Depends(wrapped_class_dep_async_wrapper),
|
||||
):
|
||||
return value.value
|
||||
|
||||
|
||||
@app.get("/wrapped-endpoint-async-wrapper/")
|
||||
@noop_wrap_async
|
||||
def get_wrapped_endpoint_async_wrapper():
|
||||
return True
|
||||
|
||||
|
||||
@app.get("/async-wrapped-endpoint-async-wrapper/")
|
||||
@noop_wrap_async
|
||||
async def get_async_wrapped_endpoint_async_wrapper():
|
||||
return True
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"route",
|
||||
[
|
||||
"/wrapped-dependency",
|
||||
"/wrapped-gen-dependency",
|
||||
"/async-wrapped-dependency",
|
||||
"/async-wrapped-gen-dependency",
|
||||
"/wrapped-dependency/",
|
||||
"/wrapped-gen-dependency/",
|
||||
"/async-wrapped-dependency/",
|
||||
"/async-wrapped-gen-dependency/",
|
||||
"/wrapped-class-instance-dependency/",
|
||||
"/wrapped-class-instance-async-dependency/",
|
||||
"/wrapped-class-instance-gen-dependency/",
|
||||
"/wrapped-class-instance-async-gen-dependency/",
|
||||
"/class-instance-wrapped-dependency/",
|
||||
"/class-instance-wrapped-async-dependency/",
|
||||
"/class-instance-async-wrapped-dependency/",
|
||||
"/class-instance-async-wrapped-async-dependency/",
|
||||
"/class-instance-wrapped-gen-dependency/",
|
||||
"/class-instance-wrapped-async-gen-dependency/",
|
||||
"/class-instance-async-wrapped-gen-dependency/",
|
||||
"/class-instance-async-wrapped-gen-async-dependency/",
|
||||
"/wrapped-class-dependency/",
|
||||
"/wrapped-endpoint/",
|
||||
"/async-wrapped-endpoint/",
|
||||
"/wrapped-dependency-async-wrapper/",
|
||||
"/wrapped-gen-dependency-async-wrapper/",
|
||||
"/async-wrapped-dependency-async-wrapper/",
|
||||
"/async-wrapped-gen-dependency-async-wrapper/",
|
||||
"/wrapped-class-instance-dependency-async-wrapper/",
|
||||
"/wrapped-class-instance-async-dependency-async-wrapper/",
|
||||
"/wrapped-class-dependency-async-wrapper/",
|
||||
"/wrapped-endpoint-async-wrapper/",
|
||||
"/async-wrapped-endpoint-async-wrapper/",
|
||||
],
|
||||
)
|
||||
def test_class_dependency(route):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,198 @@
|
|||
# Ref: https://github.com/fastapi/fastapi/issues/14454
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, Security
|
||||
from fastapi.security import OAuth2AuthorizationCodeBearer
|
||||
from fastapi.testclient import TestClient
|
||||
from inline_snapshot import snapshot
|
||||
from typing_extensions import Annotated
|
||||
|
||||
oauth2_scheme = OAuth2AuthorizationCodeBearer(
|
||||
authorizationUrl="authorize",
|
||||
tokenUrl="token",
|
||||
auto_error=True,
|
||||
scopes={"read": "Read access", "write": "Write access"},
|
||||
)
|
||||
|
||||
|
||||
async def get_token(token: Annotated[str, Depends(oauth2_scheme)]) -> str:
|
||||
return token
|
||||
|
||||
|
||||
app = FastAPI(dependencies=[Depends(get_token)])
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "Hello World"}
|
||||
|
||||
|
||||
@app.get(
|
||||
"/with-oauth2-scheme",
|
||||
dependencies=[Security(oauth2_scheme, scopes=["read", "write"])],
|
||||
)
|
||||
async def read_with_oauth2_scheme():
|
||||
return {"message": "Admin Access"}
|
||||
|
||||
|
||||
@app.get(
|
||||
"/with-get-token", dependencies=[Security(get_token, scopes=["read", "write"])]
|
||||
)
|
||||
async def read_with_get_token():
|
||||
return {"message": "Admin Access"}
|
||||
|
||||
|
||||
router = APIRouter(dependencies=[Security(oauth2_scheme, scopes=["read"])])
|
||||
|
||||
|
||||
@router.get("/items/")
|
||||
async def read_items(token: Optional[str] = Depends(oauth2_scheme)):
|
||||
return {"token": token}
|
||||
|
||||
|
||||
@router.post("/items/")
|
||||
async def create_item(
|
||||
token: Optional[str] = Security(oauth2_scheme, scopes=["read", "write"]),
|
||||
):
|
||||
return {"token": token}
|
||||
|
||||
|
||||
app.include_router(router)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_root():
|
||||
response = client.get("/", headers={"Authorization": "Bearer testtoken"})
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"message": "Hello World"}
|
||||
|
||||
|
||||
def test_read_with_oauth2_scheme():
|
||||
response = client.get(
|
||||
"/with-oauth2-scheme", headers={"Authorization": "Bearer testtoken"}
|
||||
)
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"message": "Admin Access"}
|
||||
|
||||
|
||||
def test_read_with_get_token():
|
||||
response = client.get(
|
||||
"/with-get-token", headers={"Authorization": "Bearer testtoken"}
|
||||
)
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"message": "Admin Access"}
|
||||
|
||||
|
||||
def test_read_token():
|
||||
response = client.get("/items/", headers={"Authorization": "Bearer testtoken"})
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"token": "testtoken"}
|
||||
|
||||
|
||||
def test_create_token():
|
||||
response = client.post("/items/", headers={"Authorization": "Bearer testtoken"})
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"token": "testtoken"}
|
||||
|
||||
|
||||
def test_openapi_schema():
|
||||
response = client.get("/openapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == snapshot(
|
||||
{
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "FastAPI", "version": "0.1.0"},
|
||||
"paths": {
|
||||
"/": {
|
||||
"get": {
|
||||
"summary": "Root",
|
||||
"operationId": "root__get",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {"application/json": {"schema": {}}},
|
||||
}
|
||||
},
|
||||
"security": [{"OAuth2AuthorizationCodeBearer": []}],
|
||||
}
|
||||
},
|
||||
"/with-oauth2-scheme": {
|
||||
"get": {
|
||||
"summary": "Read With Oauth2 Scheme",
|
||||
"operationId": "read_with_oauth2_scheme_with_oauth2_scheme_get",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {"application/json": {"schema": {}}},
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{"OAuth2AuthorizationCodeBearer": ["read", "write"]}
|
||||
],
|
||||
}
|
||||
},
|
||||
"/with-get-token": {
|
||||
"get": {
|
||||
"summary": "Read With Get Token",
|
||||
"operationId": "read_with_get_token_with_get_token_get",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {"application/json": {"schema": {}}},
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{"OAuth2AuthorizationCodeBearer": ["read", "write"]}
|
||||
],
|
||||
}
|
||||
},
|
||||
"/items/": {
|
||||
"get": {
|
||||
"summary": "Read Items",
|
||||
"operationId": "read_items_items__get",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {"application/json": {"schema": {}}},
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{"OAuth2AuthorizationCodeBearer": ["read"]},
|
||||
],
|
||||
},
|
||||
"post": {
|
||||
"summary": "Create Item",
|
||||
"operationId": "create_item_items__post",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {"application/json": {"schema": {}}},
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{"OAuth2AuthorizationCodeBearer": ["read", "write"]},
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
"components": {
|
||||
"securitySchemes": {
|
||||
"OAuth2AuthorizationCodeBearer": {
|
||||
"type": "oauth2",
|
||||
"flows": {
|
||||
"authorizationCode": {
|
||||
"scopes": {
|
||||
"read": "Read access",
|
||||
"write": "Write access",
|
||||
},
|
||||
"authorizationUrl": "authorize",
|
||||
"tokenUrl": "token",
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
# Ref: https://github.com/fastapi/fastapi/issues/14454
|
||||
|
||||
from fastapi import Depends, FastAPI, Security
|
||||
from fastapi.security import OAuth2AuthorizationCodeBearer
|
||||
from fastapi.testclient import TestClient
|
||||
from inline_snapshot import snapshot
|
||||
from typing_extensions import Annotated
|
||||
|
||||
oauth2_scheme = OAuth2AuthorizationCodeBearer(
|
||||
authorizationUrl="api/oauth/authorize",
|
||||
tokenUrl="/api/oauth/token",
|
||||
scopes={"read": "Read access", "write": "Write access"},
|
||||
)
|
||||
|
||||
|
||||
async def get_token(token: Annotated[str, Depends(oauth2_scheme)]) -> str:
|
||||
return token
|
||||
|
||||
|
||||
app = FastAPI(dependencies=[Depends(get_token)])
|
||||
|
||||
|
||||
@app.get("/admin", dependencies=[Security(get_token, scopes=["read", "write"])])
|
||||
async def read_admin():
|
||||
return {"message": "Admin Access"}
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_read_admin():
|
||||
response = client.get("/admin", headers={"Authorization": "Bearer faketoken"})
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"message": "Admin Access"}
|
||||
|
||||
|
||||
def test_openapi_schema():
|
||||
response = client.get("/openapi.json")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == snapshot(
|
||||
{
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "FastAPI", "version": "0.1.0"},
|
||||
"paths": {
|
||||
"/admin": {
|
||||
"get": {
|
||||
"summary": "Read Admin",
|
||||
"operationId": "read_admin_admin_get",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {"application/json": {"schema": {}}},
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{"OAuth2AuthorizationCodeBearer": ["read", "write"]}
|
||||
],
|
||||
}
|
||||
}
|
||||
},
|
||||
"components": {
|
||||
"securitySchemes": {
|
||||
"OAuth2AuthorizationCodeBearer": {
|
||||
"type": "oauth2",
|
||||
"flows": {
|
||||
"authorizationCode": {
|
||||
"scopes": {
|
||||
"read": "Read access",
|
||||
"write": "Write access",
|
||||
},
|
||||
"authorizationUrl": "api/oauth/authorize",
|
||||
"tokenUrl": "/api/oauth/token",
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, FastAPI, Request
|
||||
from fastapi.testclient import TestClient
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from .utils import needs_py310
|
||||
|
||||
|
||||
class Dep:
|
||||
def __call__(self, request: Request):
|
||||
return "test"
|
||||
|
||||
|
||||
@needs_py310
|
||||
def test_stringified_annotations():
|
||||
app = FastAPI()
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
@app.get("/test/")
|
||||
def call(test: Annotated[str, Depends(Dep())]):
|
||||
return {"test": test}
|
||||
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 200
|
||||
Loading…
Reference in New Issue