diff --git a/docs/en/docs/release-notes.md b/docs/en/docs/release-notes.md index 529c80160..ed39da111 100644 --- a/docs/en/docs/release-notes.md +++ b/docs/en/docs/release-notes.md @@ -7,6 +7,32 @@ hide: ## Latest Changes +## 0.123.9 + +### Fixes + +* 🐛 Fix OAuth2 scopes in OpenAPI in extra corner cases, parent dependency with scopes, sub-dependency security scheme without scopes. PR [#14459](https://github.com/fastapi/fastapi/pull/14459) by [@tiangolo](https://github.com/tiangolo). + +## 0.123.8 + +### Fixes + +* 🐛 Fix OpenAPI security scheme OAuth2 scopes declaration, deduplicate security schemes with different scopes. PR [#14455](https://github.com/fastapi/fastapi/pull/14455) by [@tiangolo](https://github.com/tiangolo). + +## 0.123.7 + +### Fixes + +* 🐛 Fix evaluating stringified annotations in Python 3.10. PR [#11355](https://github.com/fastapi/fastapi/pull/11355) by [@chaen](https://github.com/chaen). + +## 0.123.6 + +### Fixes + +* 🐛 Fix support for functools wraps and partial combined, for async and regular functions and classes in path operations and dependencies. PR [#14448](https://github.com/fastapi/fastapi/pull/14448) by [@tiangolo](https://github.com/tiangolo). + +## 0.123.5 + ### Features * ✨ Allow using dependables with `functools.partial()`. PR [#9753](https://github.com/fastapi/fastapi/pull/9753) by [@lieryan](https://github.com/lieryan). diff --git a/fastapi/__init__.py b/fastapi/__init__.py index b1d2dcecc..dc5467b0f 100644 --- a/fastapi/__init__.py +++ b/fastapi/__init__.py @@ -1,6 +1,6 @@ """FastAPI framework, high performance, easy to learn, fast to code, ready for production""" -__version__ = "0.123.4" +__version__ = "0.123.9" from starlette import status as status diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 4944a8f1c..7e1b15a5e 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -2,7 +2,7 @@ import inspect import sys from dataclasses import dataclass, field from functools import cached_property, partial -from typing import Any, Callable, List, Optional, Sequence, Set, Union +from typing import Any, Callable, List, Optional, Set, Union from fastapi._compat import ModelField from fastapi.security.base import SecurityBase @@ -15,10 +15,17 @@ else: # pragma: no cover from asyncio import iscoroutinefunction -@dataclass -class SecurityRequirement: - security_scheme: SecurityBase - scopes: Optional[Sequence[str]] = None +def _unwrapped_call(call: Optional[Callable[..., Any]]) -> Any: + if call is None: + return call # pragma: no cover + unwrapped = inspect.unwrap(_impartial(call)) + return unwrapped + + +def _impartial(func: Callable[..., Any]) -> Callable[..., Any]: + while isinstance(func, partial): + func = func.func + return func @dataclass @@ -29,7 +36,6 @@ class Dependant: cookie_params: List[ModelField] = field(default_factory=list) body_params: List[ModelField] = field(default_factory=list) dependencies: List["Dependant"] = field(default_factory=list) - security_requirements: List[SecurityRequirement] = field(default_factory=list) name: Optional[str] = None call: Optional[Callable[..., Any]] = None request_param_names: Set[str] = field(default_factory=set) @@ -70,42 +76,108 @@ class Dependant: return True if self.security_scopes_param_names: return True + if self._is_security_scheme: + return True for sub_dep in self.dependencies: if sub_dep._uses_scopes: return True return False @cached_property - def _unwrapped_call(self) -> Any: + def _is_security_scheme(self) -> bool: if self.call is None: - return self.call # pragma: no cover - unwrapped = inspect.unwrap(self.call) - if isinstance(unwrapped, partial): - unwrapped = unwrapped.func + return False # pragma: no cover + unwrapped = _unwrapped_call(self.call) + return isinstance(unwrapped, SecurityBase) + + # Mainly to get the type of SecurityBase, but it's the same self.call + @cached_property + def _security_scheme(self) -> SecurityBase: + unwrapped = _unwrapped_call(self.call) + assert isinstance(unwrapped, SecurityBase) return unwrapped + @cached_property + def _security_dependencies(self) -> List["Dependant"]: + security_deps = [dep for dep in self.dependencies if dep._is_security_scheme] + return security_deps + @cached_property def is_gen_callable(self) -> bool: - if inspect.isgeneratorfunction(self._unwrapped_call): + if self.call is None: + return False # pragma: no cover + if inspect.isgeneratorfunction( + _impartial(self.call) + ) or inspect.isgeneratorfunction(_unwrapped_call(self.call)): return True - dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004 - return inspect.isgeneratorfunction(dunder_call) + dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004 + if dunder_call is None: + return False # pragma: no cover + if inspect.isgeneratorfunction( + _impartial(dunder_call) + ) or inspect.isgeneratorfunction(_unwrapped_call(dunder_call)): + return True + dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004 + if dunder_unwrapped_call is None: + return False # pragma: no cover + if inspect.isgeneratorfunction( + _impartial(dunder_unwrapped_call) + ) or inspect.isgeneratorfunction(_unwrapped_call(dunder_unwrapped_call)): + return True + return False @cached_property def is_async_gen_callable(self) -> bool: - if inspect.isasyncgenfunction(self._unwrapped_call): + if self.call is None: + return False # pragma: no cover + if inspect.isasyncgenfunction( + _impartial(self.call) + ) or inspect.isasyncgenfunction(_unwrapped_call(self.call)): return True - dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004 - return inspect.isasyncgenfunction(dunder_call) + dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004 + if dunder_call is None: + return False # pragma: no cover + if inspect.isasyncgenfunction( + _impartial(dunder_call) + ) or inspect.isasyncgenfunction(_unwrapped_call(dunder_call)): + return True + dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004 + if dunder_unwrapped_call is None: + return False # pragma: no cover + if inspect.isasyncgenfunction( + _impartial(dunder_unwrapped_call) + ) or inspect.isasyncgenfunction(_unwrapped_call(dunder_unwrapped_call)): + return True + return False @cached_property def is_coroutine_callable(self) -> bool: - if inspect.isroutine(self._unwrapped_call): - return iscoroutinefunction(self._unwrapped_call) - if inspect.isclass(self._unwrapped_call): - return False - dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004 - return iscoroutinefunction(dunder_call) + if self.call is None: + return False # pragma: no cover + if inspect.isroutine(_impartial(self.call)) and iscoroutinefunction( + _impartial(self.call) + ): + return True + if inspect.isroutine(_unwrapped_call(self.call)) and iscoroutinefunction( + _unwrapped_call(self.call) + ): + return True + dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004 + if dunder_call is None: + return False # pragma: no cover + if iscoroutinefunction(_impartial(dunder_call)) or iscoroutinefunction( + _unwrapped_call(dunder_call) + ): + return True + dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004 + if dunder_unwrapped_call is None: + return False # pragma: no cover + if iscoroutinefunction( + _impartial(dunder_unwrapped_call) + ) or iscoroutinefunction(_unwrapped_call(dunder_unwrapped_call)): + return True + # if inspect.isclass(self.call): False, covered by default return + return False @cached_property def computed_scope(self) -> Union[str, None]: diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 4b6d70322..6732a1876 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -1,5 +1,6 @@ import dataclasses import inspect +import sys from contextlib import AsyncExitStack, contextmanager from copy import copy, deepcopy from dataclasses import dataclass @@ -54,10 +55,9 @@ from fastapi.concurrency import ( asynccontextmanager, contextmanager_in_threadpool, ) -from fastapi.dependencies.models import Dependant, SecurityRequirement +from fastapi.dependencies.models import Dependant from fastapi.exceptions import DependencyScopeError from fastapi.logger import logger -from fastapi.security.base import SecurityBase from fastapi.security.oauth2 import SecurityScopes from fastapi.types import DependencyCacheKey from fastapi.utils import create_model_field, get_path_param_names @@ -141,10 +141,14 @@ def get_flat_dependant( *, skip_repeats: bool = False, visited: Optional[List[DependencyCacheKey]] = None, + parent_oauth_scopes: Optional[List[str]] = None, ) -> Dependant: if visited is None: visited = [] visited.append(dependant.cache_key) + use_parent_oauth_scopes = (parent_oauth_scopes or []) + ( + dependant.oauth_scopes or [] + ) flat_dependant = Dependant( path_params=dependant.path_params.copy(), @@ -152,22 +156,37 @@ def get_flat_dependant( header_params=dependant.header_params.copy(), cookie_params=dependant.cookie_params.copy(), body_params=dependant.body_params.copy(), - security_requirements=dependant.security_requirements.copy(), + name=dependant.name, + call=dependant.call, + request_param_name=dependant.request_param_name, + websocket_param_name=dependant.websocket_param_name, + http_connection_param_name=dependant.http_connection_param_name, + response_param_name=dependant.response_param_name, + background_tasks_param_name=dependant.background_tasks_param_name, + security_scopes_param_name=dependant.security_scopes_param_name, + own_oauth_scopes=dependant.own_oauth_scopes, + parent_oauth_scopes=use_parent_oauth_scopes, use_cache=dependant.use_cache, path=dependant.path, + scope=dependant.scope, ) for sub_dependant in dependant.dependencies: if skip_repeats and sub_dependant.cache_key in visited: continue flat_sub = get_flat_dependant( - sub_dependant, skip_repeats=skip_repeats, visited=visited + sub_dependant, + skip_repeats=skip_repeats, + visited=visited, + parent_oauth_scopes=flat_dependant.oauth_scopes, ) + flat_dependant.dependencies.append(flat_sub) flat_dependant.path_params.extend(flat_sub.path_params) flat_dependant.query_params.extend(flat_sub.query_params) flat_dependant.header_params.extend(flat_sub.header_params) flat_dependant.cookie_params.extend(flat_sub.cookie_params) flat_dependant.body_params.extend(flat_sub.body_params) - flat_dependant.security_requirements.extend(flat_sub.security_requirements) + flat_dependant.dependencies.extend(flat_sub.dependencies) + return flat_dependant @@ -191,7 +210,10 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]: def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: - signature = inspect.signature(call) + if sys.version_info >= (3, 10): + signature = inspect.signature(call, eval_str=True) + else: + signature = inspect.signature(call) unwrapped = inspect.unwrap(call) globalns = getattr(unwrapped, "__globals__", {}) typed_params = [ @@ -217,7 +239,10 @@ def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: def get_typed_return_annotation(call: Callable[..., Any]) -> Any: - signature = inspect.signature(call) + if sys.version_info >= (3, 10): + signature = inspect.signature(call, eval_str=True) + else: + signature = inspect.signature(call) unwrapped = inspect.unwrap(call) annotation = signature.return_annotation @@ -251,11 +276,6 @@ def get_dependant( path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) signature_params = endpoint_signature.parameters - if isinstance(call, SecurityBase): - security_requirement = SecurityRequirement( - security_scheme=call, scopes=current_scopes - ) - dependant.security_requirements.append(security_requirement) for param_name, param in signature_params.items(): is_path_param = param_name in path_param_names param_details = analyze_param( @@ -548,10 +568,10 @@ async def _solve_generator( *, dependant: Dependant, stack: AsyncExitStack, sub_values: Dict[str, Any] ) -> Any: assert dependant.call - if dependant.is_gen_callable: - cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values)) - elif dependant.is_async_gen_callable: + if dependant.is_async_gen_callable: cm = asynccontextmanager(dependant.call)(**sub_values) + elif dependant.is_gen_callable: + cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values)) return await stack.enter_async_context(cm) diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index dbc93d289..06c14861a 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -79,16 +79,25 @@ def get_openapi_security_definitions( flat_dependant: Dependant, ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: security_definitions = {} - operation_security = [] - for security_requirement in flat_dependant.security_requirements: + # Use a dict to merge scopes for same security scheme + operation_security_dict: Dict[str, List[str]] = {} + for security_dependency in flat_dependant._security_dependencies: security_definition = jsonable_encoder( - security_requirement.security_scheme.model, + security_dependency._security_scheme.model, by_alias=True, exclude_none=True, ) - security_name = security_requirement.security_scheme.scheme_name + security_name = security_dependency._security_scheme.scheme_name security_definitions[security_name] = security_definition - operation_security.append({security_name: security_requirement.scopes}) + # Merge scopes for the same security scheme + if security_name not in operation_security_dict: + operation_security_dict[security_name] = [] + for scope in security_dependency.oauth_scopes or []: + if scope not in operation_security_dict[security_name]: + operation_security_dict[security_name].append(scope) + operation_security = [ + {name: scopes} for name, scopes in operation_security_dict.items() + ] return security_definitions, operation_security diff --git a/tests/test_dependency_wrapped.py b/tests/test_dependency_wrapped.py index f581ccba4..08356712d 100644 --- a/tests/test_dependency_wrapped.py +++ b/tests/test_dependency_wrapped.py @@ -1,10 +1,18 @@ +import inspect +import sys from functools import wraps from typing import AsyncGenerator, Generator import pytest from fastapi import Depends, FastAPI +from fastapi.concurrency import iterate_in_threadpool, run_in_threadpool from fastapi.testclient import TestClient +if sys.version_info >= (3, 13): # pragma: no cover + from inspect import iscoroutinefunction +else: # pragma: no cover + from asyncio import iscoroutinefunction + def noop_wrap(func): @wraps(func) @@ -14,8 +22,163 @@ def noop_wrap(func): return wrapper +def noop_wrap_async(func): + if inspect.isgeneratorfunction(func): + + @wraps(func) + async def gen_wrapper(*args, **kwargs): + async for item in iterate_in_threadpool(func(*args, **kwargs)): + yield item + + return gen_wrapper + + elif inspect.isasyncgenfunction(func): + + @wraps(func) + async def async_gen_wrapper(*args, **kwargs): + async for item in func(*args, **kwargs): + yield item + + return async_gen_wrapper + + @wraps(func) + async def wrapper(*args, **kwargs): + if inspect.isroutine(func) and iscoroutinefunction(func): + return await func(*args, **kwargs) + if inspect.isclass(func): + return await run_in_threadpool(func, *args, **kwargs) + dunder_call = getattr(func, "__call__", None) # noqa: B004 + if iscoroutinefunction(dunder_call): + return await dunder_call(*args, **kwargs) + return await run_in_threadpool(func, *args, **kwargs) + + return wrapper + + +class ClassInstanceDep: + def __call__(self): + return True + + +class_instance_dep = ClassInstanceDep() +wrapped_class_instance_dep = noop_wrap(class_instance_dep) +wrapped_class_instance_dep_async_wrapper = noop_wrap_async(class_instance_dep) + + +class ClassInstanceGenDep: + def __call__(self): + yield True + + +class_instance_gen_dep = ClassInstanceGenDep() +wrapped_class_instance_gen_dep = noop_wrap(class_instance_gen_dep) + + +class ClassInstanceWrappedDep: + @noop_wrap + def __call__(self): + return True + + +class_instance_wrapped_dep = ClassInstanceWrappedDep() + + +class ClassInstanceWrappedAsyncDep: + @noop_wrap_async + def __call__(self): + return True + + +class_instance_wrapped_async_dep = ClassInstanceWrappedAsyncDep() + + +class ClassInstanceWrappedGenDep: + @noop_wrap + def __call__(self): + yield True + + +class_instance_wrapped_gen_dep = ClassInstanceWrappedGenDep() + + +class ClassInstanceWrappedAsyncGenDep: + @noop_wrap_async + def __call__(self): + yield True + + +class_instance_wrapped_async_gen_dep = ClassInstanceWrappedAsyncGenDep() + + +class ClassDep: + def __init__(self): + self.value = True + + +wrapped_class_dep = noop_wrap(ClassDep) +wrapped_class_dep_async_wrapper = noop_wrap_async(ClassDep) + + +class ClassInstanceAsyncDep: + async def __call__(self): + return True + + +class_instance_async_dep = ClassInstanceAsyncDep() +wrapped_class_instance_async_dep = noop_wrap(class_instance_async_dep) +wrapped_class_instance_async_dep_async_wrapper = noop_wrap_async( + class_instance_async_dep +) + + +class ClassInstanceAsyncGenDep: + async def __call__(self): + yield True + + +class_instance_async_gen_dep = ClassInstanceAsyncGenDep() +wrapped_class_instance_async_gen_dep = noop_wrap(class_instance_async_gen_dep) + + +class ClassInstanceAsyncWrappedDep: + @noop_wrap + async def __call__(self): + return True + + +class_instance_async_wrapped_dep = ClassInstanceAsyncWrappedDep() + + +class ClassInstanceAsyncWrappedAsyncDep: + @noop_wrap_async + async def __call__(self): + return True + + +class_instance_async_wrapped_async_dep = ClassInstanceAsyncWrappedAsyncDep() + + +class ClassInstanceAsyncWrappedGenDep: + @noop_wrap + async def __call__(self): + yield True + + +class_instance_async_wrapped_gen_dep = ClassInstanceAsyncWrappedGenDep() + + +class ClassInstanceAsyncWrappedGenAsyncDep: + @noop_wrap_async + async def __call__(self): + yield True + + +class_instance_async_wrapped_gen_async_dep = ClassInstanceAsyncWrappedGenAsyncDep() + app = FastAPI() +# Sync wrapper + @noop_wrap def wrapped_dependency() -> bool: @@ -59,16 +222,225 @@ async def get_async_wrapped_gen_dependency( return value +@app.get("/wrapped-class-instance-dependency/") +async def get_wrapped_class_instance_dependency( + value: bool = Depends(wrapped_class_instance_dep), +): + return value + + +@app.get("/wrapped-class-instance-async-dependency/") +async def get_wrapped_class_instance_async_dependency( + value: bool = Depends(wrapped_class_instance_async_dep), +): + return value + + +@app.get("/wrapped-class-instance-gen-dependency/") +async def get_wrapped_class_instance_gen_dependency( + value: bool = Depends(wrapped_class_instance_gen_dep), +): + return value + + +@app.get("/wrapped-class-instance-async-gen-dependency/") +async def get_wrapped_class_instance_async_gen_dependency( + value: bool = Depends(wrapped_class_instance_async_gen_dep), +): + return value + + +@app.get("/class-instance-wrapped-dependency/") +async def get_class_instance_wrapped_dependency( + value: bool = Depends(class_instance_wrapped_dep), +): + return value + + +@app.get("/class-instance-wrapped-async-dependency/") +async def get_class_instance_wrapped_async_dependency( + value: bool = Depends(class_instance_wrapped_async_dep), +): + return value + + +@app.get("/class-instance-async-wrapped-dependency/") +async def get_class_instance_async_wrapped_dependency( + value: bool = Depends(class_instance_async_wrapped_dep), +): + return value + + +@app.get("/class-instance-async-wrapped-async-dependency/") +async def get_class_instance_async_wrapped_async_dependency( + value: bool = Depends(class_instance_async_wrapped_async_dep), +): + return value + + +@app.get("/class-instance-wrapped-gen-dependency/") +async def get_class_instance_wrapped_gen_dependency( + value: bool = Depends(class_instance_wrapped_gen_dep), +): + return value + + +@app.get("/class-instance-wrapped-async-gen-dependency/") +async def get_class_instance_wrapped_async_gen_dependency( + value: bool = Depends(class_instance_wrapped_async_gen_dep), +): + return value + + +@app.get("/class-instance-async-wrapped-gen-dependency/") +async def get_class_instance_async_wrapped_gen_dependency( + value: bool = Depends(class_instance_async_wrapped_gen_dep), +): + return value + + +@app.get("/class-instance-async-wrapped-gen-async-dependency/") +async def get_class_instance_async_wrapped_gen_async_dependency( + value: bool = Depends(class_instance_async_wrapped_gen_async_dep), +): + return value + + +@app.get("/wrapped-class-dependency/") +async def get_wrapped_class_dependency(value: ClassDep = Depends(wrapped_class_dep)): + return value.value + + +@app.get("/wrapped-endpoint/") +@noop_wrap +def get_wrapped_endpoint(): + return True + + +@app.get("/async-wrapped-endpoint/") +@noop_wrap +async def get_async_wrapped_endpoint(): + return True + + +# Async wrapper + + +@noop_wrap_async +def wrapped_dependency_async_wrapper() -> bool: + return True + + +@noop_wrap_async +def wrapped_gen_dependency_async_wrapper() -> Generator[bool, None, None]: + yield True + + +@noop_wrap_async +async def async_wrapped_dependency_async_wrapper() -> bool: + return True + + +@noop_wrap_async +async def async_wrapped_gen_dependency_async_wrapper() -> AsyncGenerator[bool, None]: + yield True + + +@app.get("/wrapped-dependency-async-wrapper/") +async def get_wrapped_dependency_async_wrapper( + value: bool = Depends(wrapped_dependency_async_wrapper), +): + return value + + +@app.get("/wrapped-gen-dependency-async-wrapper/") +async def get_wrapped_gen_dependency_async_wrapper( + value: bool = Depends(wrapped_gen_dependency_async_wrapper), +): + return value + + +@app.get("/async-wrapped-dependency-async-wrapper/") +async def get_async_wrapped_dependency_async_wrapper( + value: bool = Depends(async_wrapped_dependency_async_wrapper), +): + return value + + +@app.get("/async-wrapped-gen-dependency-async-wrapper/") +async def get_async_wrapped_gen_dependency_async_wrapper( + value: bool = Depends(async_wrapped_gen_dependency_async_wrapper), +): + return value + + +@app.get("/wrapped-class-instance-dependency-async-wrapper/") +async def get_wrapped_class_instance_dependency_async_wrapper( + value: bool = Depends(wrapped_class_instance_dep_async_wrapper), +): + return value + + +@app.get("/wrapped-class-instance-async-dependency-async-wrapper/") +async def get_wrapped_class_instance_async_dependency_async_wrapper( + value: bool = Depends(wrapped_class_instance_async_dep_async_wrapper), +): + return value + + +@app.get("/wrapped-class-dependency-async-wrapper/") +async def get_wrapped_class_dependency_async_wrapper( + value: ClassDep = Depends(wrapped_class_dep_async_wrapper), +): + return value.value + + +@app.get("/wrapped-endpoint-async-wrapper/") +@noop_wrap_async +def get_wrapped_endpoint_async_wrapper(): + return True + + +@app.get("/async-wrapped-endpoint-async-wrapper/") +@noop_wrap_async +async def get_async_wrapped_endpoint_async_wrapper(): + return True + + client = TestClient(app) @pytest.mark.parametrize( "route", [ - "/wrapped-dependency", - "/wrapped-gen-dependency", - "/async-wrapped-dependency", - "/async-wrapped-gen-dependency", + "/wrapped-dependency/", + "/wrapped-gen-dependency/", + "/async-wrapped-dependency/", + "/async-wrapped-gen-dependency/", + "/wrapped-class-instance-dependency/", + "/wrapped-class-instance-async-dependency/", + "/wrapped-class-instance-gen-dependency/", + "/wrapped-class-instance-async-gen-dependency/", + "/class-instance-wrapped-dependency/", + "/class-instance-wrapped-async-dependency/", + "/class-instance-async-wrapped-dependency/", + "/class-instance-async-wrapped-async-dependency/", + "/class-instance-wrapped-gen-dependency/", + "/class-instance-wrapped-async-gen-dependency/", + "/class-instance-async-wrapped-gen-dependency/", + "/class-instance-async-wrapped-gen-async-dependency/", + "/wrapped-class-dependency/", + "/wrapped-endpoint/", + "/async-wrapped-endpoint/", + "/wrapped-dependency-async-wrapper/", + "/wrapped-gen-dependency-async-wrapper/", + "/async-wrapped-dependency-async-wrapper/", + "/async-wrapped-gen-dependency-async-wrapper/", + "/wrapped-class-instance-dependency-async-wrapper/", + "/wrapped-class-instance-async-dependency-async-wrapper/", + "/wrapped-class-dependency-async-wrapper/", + "/wrapped-endpoint-async-wrapper/", + "/async-wrapped-endpoint-async-wrapper/", ], ) def test_class_dependency(route): diff --git a/tests/test_security_oauth2_authorization_code_bearer_scopes_openapi.py b/tests/test_security_oauth2_authorization_code_bearer_scopes_openapi.py new file mode 100644 index 000000000..d41f1dc1f --- /dev/null +++ b/tests/test_security_oauth2_authorization_code_bearer_scopes_openapi.py @@ -0,0 +1,198 @@ +# Ref: https://github.com/fastapi/fastapi/issues/14454 + +from typing import Optional + +from fastapi import APIRouter, Depends, FastAPI, Security +from fastapi.security import OAuth2AuthorizationCodeBearer +from fastapi.testclient import TestClient +from inline_snapshot import snapshot +from typing_extensions import Annotated + +oauth2_scheme = OAuth2AuthorizationCodeBearer( + authorizationUrl="authorize", + tokenUrl="token", + auto_error=True, + scopes={"read": "Read access", "write": "Write access"}, +) + + +async def get_token(token: Annotated[str, Depends(oauth2_scheme)]) -> str: + return token + + +app = FastAPI(dependencies=[Depends(get_token)]) + + +@app.get("/") +async def root(): + return {"message": "Hello World"} + + +@app.get( + "/with-oauth2-scheme", + dependencies=[Security(oauth2_scheme, scopes=["read", "write"])], +) +async def read_with_oauth2_scheme(): + return {"message": "Admin Access"} + + +@app.get( + "/with-get-token", dependencies=[Security(get_token, scopes=["read", "write"])] +) +async def read_with_get_token(): + return {"message": "Admin Access"} + + +router = APIRouter(dependencies=[Security(oauth2_scheme, scopes=["read"])]) + + +@router.get("/items/") +async def read_items(token: Optional[str] = Depends(oauth2_scheme)): + return {"token": token} + + +@router.post("/items/") +async def create_item( + token: Optional[str] = Security(oauth2_scheme, scopes=["read", "write"]), +): + return {"token": token} + + +app.include_router(router) + +client = TestClient(app) + + +def test_root(): + response = client.get("/", headers={"Authorization": "Bearer testtoken"}) + assert response.status_code == 200, response.text + assert response.json() == {"message": "Hello World"} + + +def test_read_with_oauth2_scheme(): + response = client.get( + "/with-oauth2-scheme", headers={"Authorization": "Bearer testtoken"} + ) + assert response.status_code == 200, response.text + assert response.json() == {"message": "Admin Access"} + + +def test_read_with_get_token(): + response = client.get( + "/with-get-token", headers={"Authorization": "Bearer testtoken"} + ) + assert response.status_code == 200, response.text + assert response.json() == {"message": "Admin Access"} + + +def test_read_token(): + response = client.get("/items/", headers={"Authorization": "Bearer testtoken"}) + assert response.status_code == 200, response.text + assert response.json() == {"token": "testtoken"} + + +def test_create_token(): + response = client.post("/items/", headers={"Authorization": "Bearer testtoken"}) + assert response.status_code == 200, response.text + assert response.json() == {"token": "testtoken"} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == snapshot( + { + "openapi": "3.1.0", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/": { + "get": { + "summary": "Root", + "operationId": "root__get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "security": [{"OAuth2AuthorizationCodeBearer": []}], + } + }, + "/with-oauth2-scheme": { + "get": { + "summary": "Read With Oauth2 Scheme", + "operationId": "read_with_oauth2_scheme_with_oauth2_scheme_get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "security": [ + {"OAuth2AuthorizationCodeBearer": ["read", "write"]} + ], + } + }, + "/with-get-token": { + "get": { + "summary": "Read With Get Token", + "operationId": "read_with_get_token_with_get_token_get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "security": [ + {"OAuth2AuthorizationCodeBearer": ["read", "write"]} + ], + } + }, + "/items/": { + "get": { + "summary": "Read Items", + "operationId": "read_items_items__get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "security": [ + {"OAuth2AuthorizationCodeBearer": ["read"]}, + ], + }, + "post": { + "summary": "Create Item", + "operationId": "create_item_items__post", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "security": [ + {"OAuth2AuthorizationCodeBearer": ["read", "write"]}, + ], + }, + }, + }, + "components": { + "securitySchemes": { + "OAuth2AuthorizationCodeBearer": { + "type": "oauth2", + "flows": { + "authorizationCode": { + "scopes": { + "read": "Read access", + "write": "Write access", + }, + "authorizationUrl": "authorize", + "tokenUrl": "token", + } + }, + } + } + }, + } + ) diff --git a/tests/test_security_oauth2_authorization_code_bearer_scopes_openapi_simple.py b/tests/test_security_oauth2_authorization_code_bearer_scopes_openapi_simple.py new file mode 100644 index 000000000..ff866d4fc --- /dev/null +++ b/tests/test_security_oauth2_authorization_code_bearer_scopes_openapi_simple.py @@ -0,0 +1,79 @@ +# Ref: https://github.com/fastapi/fastapi/issues/14454 + +from fastapi import Depends, FastAPI, Security +from fastapi.security import OAuth2AuthorizationCodeBearer +from fastapi.testclient import TestClient +from inline_snapshot import snapshot +from typing_extensions import Annotated + +oauth2_scheme = OAuth2AuthorizationCodeBearer( + authorizationUrl="api/oauth/authorize", + tokenUrl="/api/oauth/token", + scopes={"read": "Read access", "write": "Write access"}, +) + + +async def get_token(token: Annotated[str, Depends(oauth2_scheme)]) -> str: + return token + + +app = FastAPI(dependencies=[Depends(get_token)]) + + +@app.get("/admin", dependencies=[Security(get_token, scopes=["read", "write"])]) +async def read_admin(): + return {"message": "Admin Access"} + + +client = TestClient(app) + + +def test_read_admin(): + response = client.get("/admin", headers={"Authorization": "Bearer faketoken"}) + assert response.status_code == 200, response.text + assert response.json() == {"message": "Admin Access"} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == snapshot( + { + "openapi": "3.1.0", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/admin": { + "get": { + "summary": "Read Admin", + "operationId": "read_admin_admin_get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "security": [ + {"OAuth2AuthorizationCodeBearer": ["read", "write"]} + ], + } + } + }, + "components": { + "securitySchemes": { + "OAuth2AuthorizationCodeBearer": { + "type": "oauth2", + "flows": { + "authorizationCode": { + "scopes": { + "read": "Read access", + "write": "Write access", + }, + "authorizationUrl": "api/oauth/authorize", + "tokenUrl": "/api/oauth/token", + } + }, + } + } + }, + } + ) diff --git a/tests/test_stringified_annotations_simple.py b/tests/test_stringified_annotations_simple.py new file mode 100644 index 000000000..9bd6d09d6 --- /dev/null +++ b/tests/test_stringified_annotations_simple.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from fastapi import Depends, FastAPI, Request +from fastapi.testclient import TestClient +from typing_extensions import Annotated + +from .utils import needs_py310 + + +class Dep: + def __call__(self, request: Request): + return "test" + + +@needs_py310 +def test_stringified_annotations(): + app = FastAPI() + + client = TestClient(app) + + @app.get("/test/") + def call(test: Annotated[str, Depends(Dep())]): + return {"test": test} + + response = client.get("/test") + assert response.status_code == 200