Merge branch 'master' into fix-duplicate-special-dependency-handling

This commit is contained in:
Motov Yurii 2025-12-05 07:46:01 +01:00 committed by GitHub
commit fff0a93ecd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 850 additions and 48 deletions

View File

@ -7,6 +7,32 @@ hide:
## Latest Changes
## 0.123.9
### Fixes
* 🐛 Fix OAuth2 scopes in OpenAPI in extra corner cases, parent dependency with scopes, sub-dependency security scheme without scopes. PR [#14459](https://github.com/fastapi/fastapi/pull/14459) by [@tiangolo](https://github.com/tiangolo).
## 0.123.8
### Fixes
* 🐛 Fix OpenAPI security scheme OAuth2 scopes declaration, deduplicate security schemes with different scopes. PR [#14455](https://github.com/fastapi/fastapi/pull/14455) by [@tiangolo](https://github.com/tiangolo).
## 0.123.7
### Fixes
* 🐛 Fix evaluating stringified annotations in Python 3.10. PR [#11355](https://github.com/fastapi/fastapi/pull/11355) by [@chaen](https://github.com/chaen).
## 0.123.6
### Fixes
* 🐛 Fix support for functools wraps and partial combined, for async and regular functions and classes in path operations and dependencies. PR [#14448](https://github.com/fastapi/fastapi/pull/14448) by [@tiangolo](https://github.com/tiangolo).
## 0.123.5
### Features
* ✨ Allow using dependables with `functools.partial()`. PR [#9753](https://github.com/fastapi/fastapi/pull/9753) by [@lieryan](https://github.com/lieryan).

View File

@ -1,6 +1,6 @@
"""FastAPI framework, high performance, easy to learn, fast to code, ready for production"""
__version__ = "0.123.4"
__version__ = "0.123.9"
from starlette import status as status

View File

@ -2,7 +2,7 @@ import inspect
import sys
from dataclasses import dataclass, field
from functools import cached_property, partial
from typing import Any, Callable, List, Optional, Sequence, Set, Union
from typing import Any, Callable, List, Optional, Set, Union
from fastapi._compat import ModelField
from fastapi.security.base import SecurityBase
@ -15,10 +15,17 @@ else: # pragma: no cover
from asyncio import iscoroutinefunction
@dataclass
class SecurityRequirement:
security_scheme: SecurityBase
scopes: Optional[Sequence[str]] = None
def _unwrapped_call(call: Optional[Callable[..., Any]]) -> Any:
if call is None:
return call # pragma: no cover
unwrapped = inspect.unwrap(_impartial(call))
return unwrapped
def _impartial(func: Callable[..., Any]) -> Callable[..., Any]:
while isinstance(func, partial):
func = func.func
return func
@dataclass
@ -29,7 +36,6 @@ class Dependant:
cookie_params: List[ModelField] = field(default_factory=list)
body_params: List[ModelField] = field(default_factory=list)
dependencies: List["Dependant"] = field(default_factory=list)
security_requirements: List[SecurityRequirement] = field(default_factory=list)
name: Optional[str] = None
call: Optional[Callable[..., Any]] = None
request_param_names: Set[str] = field(default_factory=set)
@ -70,42 +76,108 @@ class Dependant:
return True
if self.security_scopes_param_names:
return True
if self._is_security_scheme:
return True
for sub_dep in self.dependencies:
if sub_dep._uses_scopes:
return True
return False
@cached_property
def _unwrapped_call(self) -> Any:
def _is_security_scheme(self) -> bool:
if self.call is None:
return self.call # pragma: no cover
unwrapped = inspect.unwrap(self.call)
if isinstance(unwrapped, partial):
unwrapped = unwrapped.func
return False # pragma: no cover
unwrapped = _unwrapped_call(self.call)
return isinstance(unwrapped, SecurityBase)
# Mainly to get the type of SecurityBase, but it's the same self.call
@cached_property
def _security_scheme(self) -> SecurityBase:
unwrapped = _unwrapped_call(self.call)
assert isinstance(unwrapped, SecurityBase)
return unwrapped
@cached_property
def _security_dependencies(self) -> List["Dependant"]:
security_deps = [dep for dep in self.dependencies if dep._is_security_scheme]
return security_deps
@cached_property
def is_gen_callable(self) -> bool:
if inspect.isgeneratorfunction(self._unwrapped_call):
if self.call is None:
return False # pragma: no cover
if inspect.isgeneratorfunction(
_impartial(self.call)
) or inspect.isgeneratorfunction(_unwrapped_call(self.call)):
return True
dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004
return inspect.isgeneratorfunction(dunder_call)
dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004
if dunder_call is None:
return False # pragma: no cover
if inspect.isgeneratorfunction(
_impartial(dunder_call)
) or inspect.isgeneratorfunction(_unwrapped_call(dunder_call)):
return True
dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004
if dunder_unwrapped_call is None:
return False # pragma: no cover
if inspect.isgeneratorfunction(
_impartial(dunder_unwrapped_call)
) or inspect.isgeneratorfunction(_unwrapped_call(dunder_unwrapped_call)):
return True
return False
@cached_property
def is_async_gen_callable(self) -> bool:
if inspect.isasyncgenfunction(self._unwrapped_call):
if self.call is None:
return False # pragma: no cover
if inspect.isasyncgenfunction(
_impartial(self.call)
) or inspect.isasyncgenfunction(_unwrapped_call(self.call)):
return True
dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004
return inspect.isasyncgenfunction(dunder_call)
dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004
if dunder_call is None:
return False # pragma: no cover
if inspect.isasyncgenfunction(
_impartial(dunder_call)
) or inspect.isasyncgenfunction(_unwrapped_call(dunder_call)):
return True
dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004
if dunder_unwrapped_call is None:
return False # pragma: no cover
if inspect.isasyncgenfunction(
_impartial(dunder_unwrapped_call)
) or inspect.isasyncgenfunction(_unwrapped_call(dunder_unwrapped_call)):
return True
return False
@cached_property
def is_coroutine_callable(self) -> bool:
if inspect.isroutine(self._unwrapped_call):
return iscoroutinefunction(self._unwrapped_call)
if inspect.isclass(self._unwrapped_call):
return False
dunder_call = getattr(self._unwrapped_call, "__call__", None) # noqa: B004
return iscoroutinefunction(dunder_call)
if self.call is None:
return False # pragma: no cover
if inspect.isroutine(_impartial(self.call)) and iscoroutinefunction(
_impartial(self.call)
):
return True
if inspect.isroutine(_unwrapped_call(self.call)) and iscoroutinefunction(
_unwrapped_call(self.call)
):
return True
dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004
if dunder_call is None:
return False # pragma: no cover
if iscoroutinefunction(_impartial(dunder_call)) or iscoroutinefunction(
_unwrapped_call(dunder_call)
):
return True
dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004
if dunder_unwrapped_call is None:
return False # pragma: no cover
if iscoroutinefunction(
_impartial(dunder_unwrapped_call)
) or iscoroutinefunction(_unwrapped_call(dunder_unwrapped_call)):
return True
# if inspect.isclass(self.call): False, covered by default return
return False
@cached_property
def computed_scope(self) -> Union[str, None]:

View File

@ -1,5 +1,6 @@
import dataclasses
import inspect
import sys
from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy
from dataclasses import dataclass
@ -54,10 +55,9 @@ from fastapi.concurrency import (
asynccontextmanager,
contextmanager_in_threadpool,
)
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.dependencies.models import Dependant
from fastapi.exceptions import DependencyScopeError
from fastapi.logger import logger
from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import SecurityScopes
from fastapi.types import DependencyCacheKey
from fastapi.utils import create_model_field, get_path_param_names
@ -141,10 +141,14 @@ def get_flat_dependant(
*,
skip_repeats: bool = False,
visited: Optional[List[DependencyCacheKey]] = None,
parent_oauth_scopes: Optional[List[str]] = None,
) -> Dependant:
if visited is None:
visited = []
visited.append(dependant.cache_key)
use_parent_oauth_scopes = (parent_oauth_scopes or []) + (
dependant.oauth_scopes or []
)
flat_dependant = Dependant(
path_params=dependant.path_params.copy(),
@ -152,22 +156,37 @@ def get_flat_dependant(
header_params=dependant.header_params.copy(),
cookie_params=dependant.cookie_params.copy(),
body_params=dependant.body_params.copy(),
security_requirements=dependant.security_requirements.copy(),
name=dependant.name,
call=dependant.call,
request_param_name=dependant.request_param_name,
websocket_param_name=dependant.websocket_param_name,
http_connection_param_name=dependant.http_connection_param_name,
response_param_name=dependant.response_param_name,
background_tasks_param_name=dependant.background_tasks_param_name,
security_scopes_param_name=dependant.security_scopes_param_name,
own_oauth_scopes=dependant.own_oauth_scopes,
parent_oauth_scopes=use_parent_oauth_scopes,
use_cache=dependant.use_cache,
path=dependant.path,
scope=dependant.scope,
)
for sub_dependant in dependant.dependencies:
if skip_repeats and sub_dependant.cache_key in visited:
continue
flat_sub = get_flat_dependant(
sub_dependant, skip_repeats=skip_repeats, visited=visited
sub_dependant,
skip_repeats=skip_repeats,
visited=visited,
parent_oauth_scopes=flat_dependant.oauth_scopes,
)
flat_dependant.dependencies.append(flat_sub)
flat_dependant.path_params.extend(flat_sub.path_params)
flat_dependant.query_params.extend(flat_sub.query_params)
flat_dependant.header_params.extend(flat_sub.header_params)
flat_dependant.cookie_params.extend(flat_sub.cookie_params)
flat_dependant.body_params.extend(flat_sub.body_params)
flat_dependant.security_requirements.extend(flat_sub.security_requirements)
flat_dependant.dependencies.extend(flat_sub.dependencies)
return flat_dependant
@ -191,7 +210,10 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]:
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
signature = inspect.signature(call)
if sys.version_info >= (3, 10):
signature = inspect.signature(call, eval_str=True)
else:
signature = inspect.signature(call)
unwrapped = inspect.unwrap(call)
globalns = getattr(unwrapped, "__globals__", {})
typed_params = [
@ -217,7 +239,10 @@ def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
signature = inspect.signature(call)
if sys.version_info >= (3, 10):
signature = inspect.signature(call, eval_str=True)
else:
signature = inspect.signature(call)
unwrapped = inspect.unwrap(call)
annotation = signature.return_annotation
@ -251,11 +276,6 @@ def get_dependant(
path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call)
signature_params = endpoint_signature.parameters
if isinstance(call, SecurityBase):
security_requirement = SecurityRequirement(
security_scheme=call, scopes=current_scopes
)
dependant.security_requirements.append(security_requirement)
for param_name, param in signature_params.items():
is_path_param = param_name in path_param_names
param_details = analyze_param(
@ -548,10 +568,10 @@ async def _solve_generator(
*, dependant: Dependant, stack: AsyncExitStack, sub_values: Dict[str, Any]
) -> Any:
assert dependant.call
if dependant.is_gen_callable:
cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values))
elif dependant.is_async_gen_callable:
if dependant.is_async_gen_callable:
cm = asynccontextmanager(dependant.call)(**sub_values)
elif dependant.is_gen_callable:
cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values))
return await stack.enter_async_context(cm)

View File

@ -79,16 +79,25 @@ def get_openapi_security_definitions(
flat_dependant: Dependant,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
security_definitions = {}
operation_security = []
for security_requirement in flat_dependant.security_requirements:
# Use a dict to merge scopes for same security scheme
operation_security_dict: Dict[str, List[str]] = {}
for security_dependency in flat_dependant._security_dependencies:
security_definition = jsonable_encoder(
security_requirement.security_scheme.model,
security_dependency._security_scheme.model,
by_alias=True,
exclude_none=True,
)
security_name = security_requirement.security_scheme.scheme_name
security_name = security_dependency._security_scheme.scheme_name
security_definitions[security_name] = security_definition
operation_security.append({security_name: security_requirement.scopes})
# Merge scopes for the same security scheme
if security_name not in operation_security_dict:
operation_security_dict[security_name] = []
for scope in security_dependency.oauth_scopes or []:
if scope not in operation_security_dict[security_name]:
operation_security_dict[security_name].append(scope)
operation_security = [
{name: scopes} for name, scopes in operation_security_dict.items()
]
return security_definitions, operation_security

View File

@ -1,10 +1,18 @@
import inspect
import sys
from functools import wraps
from typing import AsyncGenerator, Generator
import pytest
from fastapi import Depends, FastAPI
from fastapi.concurrency import iterate_in_threadpool, run_in_threadpool
from fastapi.testclient import TestClient
if sys.version_info >= (3, 13): # pragma: no cover
from inspect import iscoroutinefunction
else: # pragma: no cover
from asyncio import iscoroutinefunction
def noop_wrap(func):
@wraps(func)
@ -14,8 +22,163 @@ def noop_wrap(func):
return wrapper
def noop_wrap_async(func):
if inspect.isgeneratorfunction(func):
@wraps(func)
async def gen_wrapper(*args, **kwargs):
async for item in iterate_in_threadpool(func(*args, **kwargs)):
yield item
return gen_wrapper
elif inspect.isasyncgenfunction(func):
@wraps(func)
async def async_gen_wrapper(*args, **kwargs):
async for item in func(*args, **kwargs):
yield item
return async_gen_wrapper
@wraps(func)
async def wrapper(*args, **kwargs):
if inspect.isroutine(func) and iscoroutinefunction(func):
return await func(*args, **kwargs)
if inspect.isclass(func):
return await run_in_threadpool(func, *args, **kwargs)
dunder_call = getattr(func, "__call__", None) # noqa: B004
if iscoroutinefunction(dunder_call):
return await dunder_call(*args, **kwargs)
return await run_in_threadpool(func, *args, **kwargs)
return wrapper
class ClassInstanceDep:
def __call__(self):
return True
class_instance_dep = ClassInstanceDep()
wrapped_class_instance_dep = noop_wrap(class_instance_dep)
wrapped_class_instance_dep_async_wrapper = noop_wrap_async(class_instance_dep)
class ClassInstanceGenDep:
def __call__(self):
yield True
class_instance_gen_dep = ClassInstanceGenDep()
wrapped_class_instance_gen_dep = noop_wrap(class_instance_gen_dep)
class ClassInstanceWrappedDep:
@noop_wrap
def __call__(self):
return True
class_instance_wrapped_dep = ClassInstanceWrappedDep()
class ClassInstanceWrappedAsyncDep:
@noop_wrap_async
def __call__(self):
return True
class_instance_wrapped_async_dep = ClassInstanceWrappedAsyncDep()
class ClassInstanceWrappedGenDep:
@noop_wrap
def __call__(self):
yield True
class_instance_wrapped_gen_dep = ClassInstanceWrappedGenDep()
class ClassInstanceWrappedAsyncGenDep:
@noop_wrap_async
def __call__(self):
yield True
class_instance_wrapped_async_gen_dep = ClassInstanceWrappedAsyncGenDep()
class ClassDep:
def __init__(self):
self.value = True
wrapped_class_dep = noop_wrap(ClassDep)
wrapped_class_dep_async_wrapper = noop_wrap_async(ClassDep)
class ClassInstanceAsyncDep:
async def __call__(self):
return True
class_instance_async_dep = ClassInstanceAsyncDep()
wrapped_class_instance_async_dep = noop_wrap(class_instance_async_dep)
wrapped_class_instance_async_dep_async_wrapper = noop_wrap_async(
class_instance_async_dep
)
class ClassInstanceAsyncGenDep:
async def __call__(self):
yield True
class_instance_async_gen_dep = ClassInstanceAsyncGenDep()
wrapped_class_instance_async_gen_dep = noop_wrap(class_instance_async_gen_dep)
class ClassInstanceAsyncWrappedDep:
@noop_wrap
async def __call__(self):
return True
class_instance_async_wrapped_dep = ClassInstanceAsyncWrappedDep()
class ClassInstanceAsyncWrappedAsyncDep:
@noop_wrap_async
async def __call__(self):
return True
class_instance_async_wrapped_async_dep = ClassInstanceAsyncWrappedAsyncDep()
class ClassInstanceAsyncWrappedGenDep:
@noop_wrap
async def __call__(self):
yield True
class_instance_async_wrapped_gen_dep = ClassInstanceAsyncWrappedGenDep()
class ClassInstanceAsyncWrappedGenAsyncDep:
@noop_wrap_async
async def __call__(self):
yield True
class_instance_async_wrapped_gen_async_dep = ClassInstanceAsyncWrappedGenAsyncDep()
app = FastAPI()
# Sync wrapper
@noop_wrap
def wrapped_dependency() -> bool:
@ -59,16 +222,225 @@ async def get_async_wrapped_gen_dependency(
return value
@app.get("/wrapped-class-instance-dependency/")
async def get_wrapped_class_instance_dependency(
value: bool = Depends(wrapped_class_instance_dep),
):
return value
@app.get("/wrapped-class-instance-async-dependency/")
async def get_wrapped_class_instance_async_dependency(
value: bool = Depends(wrapped_class_instance_async_dep),
):
return value
@app.get("/wrapped-class-instance-gen-dependency/")
async def get_wrapped_class_instance_gen_dependency(
value: bool = Depends(wrapped_class_instance_gen_dep),
):
return value
@app.get("/wrapped-class-instance-async-gen-dependency/")
async def get_wrapped_class_instance_async_gen_dependency(
value: bool = Depends(wrapped_class_instance_async_gen_dep),
):
return value
@app.get("/class-instance-wrapped-dependency/")
async def get_class_instance_wrapped_dependency(
value: bool = Depends(class_instance_wrapped_dep),
):
return value
@app.get("/class-instance-wrapped-async-dependency/")
async def get_class_instance_wrapped_async_dependency(
value: bool = Depends(class_instance_wrapped_async_dep),
):
return value
@app.get("/class-instance-async-wrapped-dependency/")
async def get_class_instance_async_wrapped_dependency(
value: bool = Depends(class_instance_async_wrapped_dep),
):
return value
@app.get("/class-instance-async-wrapped-async-dependency/")
async def get_class_instance_async_wrapped_async_dependency(
value: bool = Depends(class_instance_async_wrapped_async_dep),
):
return value
@app.get("/class-instance-wrapped-gen-dependency/")
async def get_class_instance_wrapped_gen_dependency(
value: bool = Depends(class_instance_wrapped_gen_dep),
):
return value
@app.get("/class-instance-wrapped-async-gen-dependency/")
async def get_class_instance_wrapped_async_gen_dependency(
value: bool = Depends(class_instance_wrapped_async_gen_dep),
):
return value
@app.get("/class-instance-async-wrapped-gen-dependency/")
async def get_class_instance_async_wrapped_gen_dependency(
value: bool = Depends(class_instance_async_wrapped_gen_dep),
):
return value
@app.get("/class-instance-async-wrapped-gen-async-dependency/")
async def get_class_instance_async_wrapped_gen_async_dependency(
value: bool = Depends(class_instance_async_wrapped_gen_async_dep),
):
return value
@app.get("/wrapped-class-dependency/")
async def get_wrapped_class_dependency(value: ClassDep = Depends(wrapped_class_dep)):
return value.value
@app.get("/wrapped-endpoint/")
@noop_wrap
def get_wrapped_endpoint():
return True
@app.get("/async-wrapped-endpoint/")
@noop_wrap
async def get_async_wrapped_endpoint():
return True
# Async wrapper
@noop_wrap_async
def wrapped_dependency_async_wrapper() -> bool:
return True
@noop_wrap_async
def wrapped_gen_dependency_async_wrapper() -> Generator[bool, None, None]:
yield True
@noop_wrap_async
async def async_wrapped_dependency_async_wrapper() -> bool:
return True
@noop_wrap_async
async def async_wrapped_gen_dependency_async_wrapper() -> AsyncGenerator[bool, None]:
yield True
@app.get("/wrapped-dependency-async-wrapper/")
async def get_wrapped_dependency_async_wrapper(
value: bool = Depends(wrapped_dependency_async_wrapper),
):
return value
@app.get("/wrapped-gen-dependency-async-wrapper/")
async def get_wrapped_gen_dependency_async_wrapper(
value: bool = Depends(wrapped_gen_dependency_async_wrapper),
):
return value
@app.get("/async-wrapped-dependency-async-wrapper/")
async def get_async_wrapped_dependency_async_wrapper(
value: bool = Depends(async_wrapped_dependency_async_wrapper),
):
return value
@app.get("/async-wrapped-gen-dependency-async-wrapper/")
async def get_async_wrapped_gen_dependency_async_wrapper(
value: bool = Depends(async_wrapped_gen_dependency_async_wrapper),
):
return value
@app.get("/wrapped-class-instance-dependency-async-wrapper/")
async def get_wrapped_class_instance_dependency_async_wrapper(
value: bool = Depends(wrapped_class_instance_dep_async_wrapper),
):
return value
@app.get("/wrapped-class-instance-async-dependency-async-wrapper/")
async def get_wrapped_class_instance_async_dependency_async_wrapper(
value: bool = Depends(wrapped_class_instance_async_dep_async_wrapper),
):
return value
@app.get("/wrapped-class-dependency-async-wrapper/")
async def get_wrapped_class_dependency_async_wrapper(
value: ClassDep = Depends(wrapped_class_dep_async_wrapper),
):
return value.value
@app.get("/wrapped-endpoint-async-wrapper/")
@noop_wrap_async
def get_wrapped_endpoint_async_wrapper():
return True
@app.get("/async-wrapped-endpoint-async-wrapper/")
@noop_wrap_async
async def get_async_wrapped_endpoint_async_wrapper():
return True
client = TestClient(app)
@pytest.mark.parametrize(
"route",
[
"/wrapped-dependency",
"/wrapped-gen-dependency",
"/async-wrapped-dependency",
"/async-wrapped-gen-dependency",
"/wrapped-dependency/",
"/wrapped-gen-dependency/",
"/async-wrapped-dependency/",
"/async-wrapped-gen-dependency/",
"/wrapped-class-instance-dependency/",
"/wrapped-class-instance-async-dependency/",
"/wrapped-class-instance-gen-dependency/",
"/wrapped-class-instance-async-gen-dependency/",
"/class-instance-wrapped-dependency/",
"/class-instance-wrapped-async-dependency/",
"/class-instance-async-wrapped-dependency/",
"/class-instance-async-wrapped-async-dependency/",
"/class-instance-wrapped-gen-dependency/",
"/class-instance-wrapped-async-gen-dependency/",
"/class-instance-async-wrapped-gen-dependency/",
"/class-instance-async-wrapped-gen-async-dependency/",
"/wrapped-class-dependency/",
"/wrapped-endpoint/",
"/async-wrapped-endpoint/",
"/wrapped-dependency-async-wrapper/",
"/wrapped-gen-dependency-async-wrapper/",
"/async-wrapped-dependency-async-wrapper/",
"/async-wrapped-gen-dependency-async-wrapper/",
"/wrapped-class-instance-dependency-async-wrapper/",
"/wrapped-class-instance-async-dependency-async-wrapper/",
"/wrapped-class-dependency-async-wrapper/",
"/wrapped-endpoint-async-wrapper/",
"/async-wrapped-endpoint-async-wrapper/",
],
)
def test_class_dependency(route):

View File

@ -0,0 +1,198 @@
# Ref: https://github.com/fastapi/fastapi/issues/14454
from typing import Optional
from fastapi import APIRouter, Depends, FastAPI, Security
from fastapi.security import OAuth2AuthorizationCodeBearer
from fastapi.testclient import TestClient
from inline_snapshot import snapshot
from typing_extensions import Annotated
oauth2_scheme = OAuth2AuthorizationCodeBearer(
authorizationUrl="authorize",
tokenUrl="token",
auto_error=True,
scopes={"read": "Read access", "write": "Write access"},
)
async def get_token(token: Annotated[str, Depends(oauth2_scheme)]) -> str:
return token
app = FastAPI(dependencies=[Depends(get_token)])
@app.get("/")
async def root():
return {"message": "Hello World"}
@app.get(
"/with-oauth2-scheme",
dependencies=[Security(oauth2_scheme, scopes=["read", "write"])],
)
async def read_with_oauth2_scheme():
return {"message": "Admin Access"}
@app.get(
"/with-get-token", dependencies=[Security(get_token, scopes=["read", "write"])]
)
async def read_with_get_token():
return {"message": "Admin Access"}
router = APIRouter(dependencies=[Security(oauth2_scheme, scopes=["read"])])
@router.get("/items/")
async def read_items(token: Optional[str] = Depends(oauth2_scheme)):
return {"token": token}
@router.post("/items/")
async def create_item(
token: Optional[str] = Security(oauth2_scheme, scopes=["read", "write"]),
):
return {"token": token}
app.include_router(router)
client = TestClient(app)
def test_root():
response = client.get("/", headers={"Authorization": "Bearer testtoken"})
assert response.status_code == 200, response.text
assert response.json() == {"message": "Hello World"}
def test_read_with_oauth2_scheme():
response = client.get(
"/with-oauth2-scheme", headers={"Authorization": "Bearer testtoken"}
)
assert response.status_code == 200, response.text
assert response.json() == {"message": "Admin Access"}
def test_read_with_get_token():
response = client.get(
"/with-get-token", headers={"Authorization": "Bearer testtoken"}
)
assert response.status_code == 200, response.text
assert response.json() == {"message": "Admin Access"}
def test_read_token():
response = client.get("/items/", headers={"Authorization": "Bearer testtoken"})
assert response.status_code == 200, response.text
assert response.json() == {"token": "testtoken"}
def test_create_token():
response = client.post("/items/", headers={"Authorization": "Bearer testtoken"})
assert response.status_code == 200, response.text
assert response.json() == {"token": "testtoken"}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == snapshot(
{
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/": {
"get": {
"summary": "Root",
"operationId": "root__get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"security": [{"OAuth2AuthorizationCodeBearer": []}],
}
},
"/with-oauth2-scheme": {
"get": {
"summary": "Read With Oauth2 Scheme",
"operationId": "read_with_oauth2_scheme_with_oauth2_scheme_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"security": [
{"OAuth2AuthorizationCodeBearer": ["read", "write"]}
],
}
},
"/with-get-token": {
"get": {
"summary": "Read With Get Token",
"operationId": "read_with_get_token_with_get_token_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"security": [
{"OAuth2AuthorizationCodeBearer": ["read", "write"]}
],
}
},
"/items/": {
"get": {
"summary": "Read Items",
"operationId": "read_items_items__get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"security": [
{"OAuth2AuthorizationCodeBearer": ["read"]},
],
},
"post": {
"summary": "Create Item",
"operationId": "create_item_items__post",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"security": [
{"OAuth2AuthorizationCodeBearer": ["read", "write"]},
],
},
},
},
"components": {
"securitySchemes": {
"OAuth2AuthorizationCodeBearer": {
"type": "oauth2",
"flows": {
"authorizationCode": {
"scopes": {
"read": "Read access",
"write": "Write access",
},
"authorizationUrl": "authorize",
"tokenUrl": "token",
}
},
}
}
},
}
)

View File

@ -0,0 +1,79 @@
# Ref: https://github.com/fastapi/fastapi/issues/14454
from fastapi import Depends, FastAPI, Security
from fastapi.security import OAuth2AuthorizationCodeBearer
from fastapi.testclient import TestClient
from inline_snapshot import snapshot
from typing_extensions import Annotated
oauth2_scheme = OAuth2AuthorizationCodeBearer(
authorizationUrl="api/oauth/authorize",
tokenUrl="/api/oauth/token",
scopes={"read": "Read access", "write": "Write access"},
)
async def get_token(token: Annotated[str, Depends(oauth2_scheme)]) -> str:
return token
app = FastAPI(dependencies=[Depends(get_token)])
@app.get("/admin", dependencies=[Security(get_token, scopes=["read", "write"])])
async def read_admin():
return {"message": "Admin Access"}
client = TestClient(app)
def test_read_admin():
response = client.get("/admin", headers={"Authorization": "Bearer faketoken"})
assert response.status_code == 200, response.text
assert response.json() == {"message": "Admin Access"}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == snapshot(
{
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/admin": {
"get": {
"summary": "Read Admin",
"operationId": "read_admin_admin_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"security": [
{"OAuth2AuthorizationCodeBearer": ["read", "write"]}
],
}
}
},
"components": {
"securitySchemes": {
"OAuth2AuthorizationCodeBearer": {
"type": "oauth2",
"flows": {
"authorizationCode": {
"scopes": {
"read": "Read access",
"write": "Write access",
},
"authorizationUrl": "api/oauth/authorize",
"tokenUrl": "/api/oauth/token",
}
},
}
}
},
}
)

View File

@ -0,0 +1,26 @@
from __future__ import annotations
from fastapi import Depends, FastAPI, Request
from fastapi.testclient import TestClient
from typing_extensions import Annotated
from .utils import needs_py310
class Dep:
def __call__(self, request: Request):
return "test"
@needs_py310
def test_stringified_annotations():
app = FastAPI()
client = TestClient(app)
@app.get("/test/")
def call(test: Annotated[str, Depends(Dep())]):
return {"test": test}
response = client.get("/test")
assert response.status_code == 200