diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 9b545e4e5..af168a177 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -2,7 +2,7 @@ import inspect import sys from dataclasses import dataclass, field from functools import cached_property, partial -from typing import Any, Callable, List, Optional, Sequence, Union +from typing import Any, Callable, List, Optional, Union from fastapi._compat import ModelField from fastapi.security.base import SecurityBase @@ -28,12 +28,6 @@ def _impartial(func: Callable[..., Any]) -> Callable[..., Any]: return func -@dataclass -class SecurityRequirement: - security_scheme: SecurityBase - scopes: Optional[Sequence[str]] = None - - @dataclass class Dependant: path_params: List[ModelField] = field(default_factory=list) @@ -42,7 +36,6 @@ class Dependant: cookie_params: List[ModelField] = field(default_factory=list) body_params: List[ModelField] = field(default_factory=list) dependencies: List["Dependant"] = field(default_factory=list) - security_requirements: List[SecurityRequirement] = field(default_factory=list) name: Optional[str] = None call: Optional[Callable[..., Any]] = None request_param_name: Optional[str] = None @@ -83,11 +76,32 @@ class Dependant: return True if self.security_scopes_param_name is not None: return True + if self._is_security_scheme: + return True for sub_dep in self.dependencies: if sub_dep._uses_scopes: return True return False + @cached_property + def _is_security_scheme(self) -> bool: + if self.call is None: + return False # pragma: no cover + unwrapped = _unwrapped_call(self.call) + return isinstance(unwrapped, SecurityBase) + + # Mainly to get the type of SecurityBase, but it's the same self.call + @cached_property + def _security_scheme(self) -> SecurityBase: + unwrapped = _unwrapped_call(self.call) + assert isinstance(unwrapped, SecurityBase) + return unwrapped + + @cached_property + def _security_dependencies(self) -> List["Dependant"]: + security_deps = [dep for dep in self.dependencies if dep._is_security_scheme] + return security_deps + @cached_property def is_gen_callable(self) -> bool: if self.call is None: diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 1ff35f648..23bca6f2a 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -55,10 +55,9 @@ from fastapi.concurrency import ( asynccontextmanager, contextmanager_in_threadpool, ) -from fastapi.dependencies.models import Dependant, SecurityRequirement +from fastapi.dependencies.models import Dependant from fastapi.exceptions import DependencyScopeError from fastapi.logger import logger -from fastapi.security.base import SecurityBase from fastapi.security.oauth2 import SecurityScopes from fastapi.types import DependencyCacheKey from fastapi.utils import create_model_field, get_path_param_names @@ -142,10 +141,14 @@ def get_flat_dependant( *, skip_repeats: bool = False, visited: Optional[List[DependencyCacheKey]] = None, + parent_oauth_scopes: Optional[List[str]] = None, ) -> Dependant: if visited is None: visited = [] visited.append(dependant.cache_key) + use_parent_oauth_scopes = (parent_oauth_scopes or []) + ( + dependant.oauth_scopes or [] + ) flat_dependant = Dependant( path_params=dependant.path_params.copy(), @@ -153,22 +156,37 @@ def get_flat_dependant( header_params=dependant.header_params.copy(), cookie_params=dependant.cookie_params.copy(), body_params=dependant.body_params.copy(), - security_requirements=dependant.security_requirements.copy(), + name=dependant.name, + call=dependant.call, + request_param_name=dependant.request_param_name, + websocket_param_name=dependant.websocket_param_name, + http_connection_param_name=dependant.http_connection_param_name, + response_param_name=dependant.response_param_name, + background_tasks_param_name=dependant.background_tasks_param_name, + security_scopes_param_name=dependant.security_scopes_param_name, + own_oauth_scopes=dependant.own_oauth_scopes, + parent_oauth_scopes=use_parent_oauth_scopes, use_cache=dependant.use_cache, path=dependant.path, + scope=dependant.scope, ) for sub_dependant in dependant.dependencies: if skip_repeats and sub_dependant.cache_key in visited: continue flat_sub = get_flat_dependant( - sub_dependant, skip_repeats=skip_repeats, visited=visited + sub_dependant, + skip_repeats=skip_repeats, + visited=visited, + parent_oauth_scopes=flat_dependant.oauth_scopes, ) + flat_dependant.dependencies.append(flat_sub) flat_dependant.path_params.extend(flat_sub.path_params) flat_dependant.query_params.extend(flat_sub.query_params) flat_dependant.header_params.extend(flat_sub.header_params) flat_dependant.cookie_params.extend(flat_sub.cookie_params) flat_dependant.body_params.extend(flat_sub.body_params) - flat_dependant.security_requirements.extend(flat_sub.security_requirements) + flat_dependant.dependencies.extend(flat_sub.dependencies) + return flat_dependant @@ -258,11 +276,6 @@ def get_dependant( path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) signature_params = endpoint_signature.parameters - if isinstance(call, SecurityBase): - security_requirement = SecurityRequirement( - security_scheme=call, scopes=current_scopes - ) - dependant.security_requirements.append(security_requirement) for param_name, param in signature_params.items(): is_path_param = param_name in path_param_names param_details = analyze_param( diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index e7e6da2f7..06c14861a 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -81,18 +81,18 @@ def get_openapi_security_definitions( security_definitions = {} # Use a dict to merge scopes for same security scheme operation_security_dict: Dict[str, List[str]] = {} - for security_requirement in flat_dependant.security_requirements: + for security_dependency in flat_dependant._security_dependencies: security_definition = jsonable_encoder( - security_requirement.security_scheme.model, + security_dependency._security_scheme.model, by_alias=True, exclude_none=True, ) - security_name = security_requirement.security_scheme.scheme_name + security_name = security_dependency._security_scheme.scheme_name security_definitions[security_name] = security_definition # Merge scopes for the same security scheme if security_name not in operation_security_dict: operation_security_dict[security_name] = [] - for scope in security_requirement.scopes or []: + for scope in security_dependency.oauth_scopes or []: if scope not in operation_security_dict[security_name]: operation_security_dict[security_name].append(scope) operation_security = [