mirror of https://github.com/tiangolo/fastapi.git
♻️ Refactor logic to get security scheme dependencies
This commit is contained in:
parent
86ed53f3ef
commit
be1bca1c1f
|
|
@ -2,7 +2,7 @@ import inspect
|
|||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property, partial
|
||||
from typing import Any, Callable, List, Optional, Sequence, Union
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
from fastapi._compat import ModelField
|
||||
from fastapi.security.base import SecurityBase
|
||||
|
|
@ -28,12 +28,6 @@ def _impartial(func: Callable[..., Any]) -> Callable[..., Any]:
|
|||
return func
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityRequirement:
|
||||
security_scheme: SecurityBase
|
||||
scopes: Optional[Sequence[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Dependant:
|
||||
path_params: List[ModelField] = field(default_factory=list)
|
||||
|
|
@ -42,7 +36,6 @@ class Dependant:
|
|||
cookie_params: List[ModelField] = field(default_factory=list)
|
||||
body_params: List[ModelField] = field(default_factory=list)
|
||||
dependencies: List["Dependant"] = field(default_factory=list)
|
||||
security_requirements: List[SecurityRequirement] = field(default_factory=list)
|
||||
name: Optional[str] = None
|
||||
call: Optional[Callable[..., Any]] = None
|
||||
request_param_name: Optional[str] = None
|
||||
|
|
@ -83,11 +76,32 @@ class Dependant:
|
|||
return True
|
||||
if self.security_scopes_param_name is not None:
|
||||
return True
|
||||
if self._is_security_scheme:
|
||||
return True
|
||||
for sub_dep in self.dependencies:
|
||||
if sub_dep._uses_scopes:
|
||||
return True
|
||||
return False
|
||||
|
||||
@cached_property
|
||||
def _is_security_scheme(self) -> bool:
|
||||
if self.call is None:
|
||||
return False # pragma: no cover
|
||||
unwrapped = _unwrapped_call(self.call)
|
||||
return isinstance(unwrapped, SecurityBase)
|
||||
|
||||
# Mainly to get the type of SecurityBase, but it's the same self.call
|
||||
@cached_property
|
||||
def _security_scheme(self) -> SecurityBase:
|
||||
unwrapped = _unwrapped_call(self.call)
|
||||
assert isinstance(unwrapped, SecurityBase)
|
||||
return unwrapped
|
||||
|
||||
@cached_property
|
||||
def _security_dependencies(self) -> List["Dependant"]:
|
||||
security_deps = [dep for dep in self.dependencies if dep._is_security_scheme]
|
||||
return security_deps
|
||||
|
||||
@cached_property
|
||||
def is_gen_callable(self) -> bool:
|
||||
if self.call is None:
|
||||
|
|
|
|||
|
|
@ -55,10 +55,9 @@ from fastapi.concurrency import (
|
|||
asynccontextmanager,
|
||||
contextmanager_in_threadpool,
|
||||
)
|
||||
from fastapi.dependencies.models import Dependant, SecurityRequirement
|
||||
from fastapi.dependencies.models import Dependant
|
||||
from fastapi.exceptions import DependencyScopeError
|
||||
from fastapi.logger import logger
|
||||
from fastapi.security.base import SecurityBase
|
||||
from fastapi.security.oauth2 import SecurityScopes
|
||||
from fastapi.types import DependencyCacheKey
|
||||
from fastapi.utils import create_model_field, get_path_param_names
|
||||
|
|
@ -142,10 +141,14 @@ def get_flat_dependant(
|
|||
*,
|
||||
skip_repeats: bool = False,
|
||||
visited: Optional[List[DependencyCacheKey]] = None,
|
||||
parent_oauth_scopes: Optional[List[str]] = None,
|
||||
) -> Dependant:
|
||||
if visited is None:
|
||||
visited = []
|
||||
visited.append(dependant.cache_key)
|
||||
use_parent_oauth_scopes = (parent_oauth_scopes or []) + (
|
||||
dependant.oauth_scopes or []
|
||||
)
|
||||
|
||||
flat_dependant = Dependant(
|
||||
path_params=dependant.path_params.copy(),
|
||||
|
|
@ -153,22 +156,37 @@ def get_flat_dependant(
|
|||
header_params=dependant.header_params.copy(),
|
||||
cookie_params=dependant.cookie_params.copy(),
|
||||
body_params=dependant.body_params.copy(),
|
||||
security_requirements=dependant.security_requirements.copy(),
|
||||
name=dependant.name,
|
||||
call=dependant.call,
|
||||
request_param_name=dependant.request_param_name,
|
||||
websocket_param_name=dependant.websocket_param_name,
|
||||
http_connection_param_name=dependant.http_connection_param_name,
|
||||
response_param_name=dependant.response_param_name,
|
||||
background_tasks_param_name=dependant.background_tasks_param_name,
|
||||
security_scopes_param_name=dependant.security_scopes_param_name,
|
||||
own_oauth_scopes=dependant.own_oauth_scopes,
|
||||
parent_oauth_scopes=use_parent_oauth_scopes,
|
||||
use_cache=dependant.use_cache,
|
||||
path=dependant.path,
|
||||
scope=dependant.scope,
|
||||
)
|
||||
for sub_dependant in dependant.dependencies:
|
||||
if skip_repeats and sub_dependant.cache_key in visited:
|
||||
continue
|
||||
flat_sub = get_flat_dependant(
|
||||
sub_dependant, skip_repeats=skip_repeats, visited=visited
|
||||
sub_dependant,
|
||||
skip_repeats=skip_repeats,
|
||||
visited=visited,
|
||||
parent_oauth_scopes=flat_dependant.oauth_scopes,
|
||||
)
|
||||
flat_dependant.dependencies.append(flat_sub)
|
||||
flat_dependant.path_params.extend(flat_sub.path_params)
|
||||
flat_dependant.query_params.extend(flat_sub.query_params)
|
||||
flat_dependant.header_params.extend(flat_sub.header_params)
|
||||
flat_dependant.cookie_params.extend(flat_sub.cookie_params)
|
||||
flat_dependant.body_params.extend(flat_sub.body_params)
|
||||
flat_dependant.security_requirements.extend(flat_sub.security_requirements)
|
||||
flat_dependant.dependencies.extend(flat_sub.dependencies)
|
||||
|
||||
return flat_dependant
|
||||
|
||||
|
||||
|
|
@ -258,11 +276,6 @@ def get_dependant(
|
|||
path_param_names = get_path_param_names(path)
|
||||
endpoint_signature = get_typed_signature(call)
|
||||
signature_params = endpoint_signature.parameters
|
||||
if isinstance(call, SecurityBase):
|
||||
security_requirement = SecurityRequirement(
|
||||
security_scheme=call, scopes=current_scopes
|
||||
)
|
||||
dependant.security_requirements.append(security_requirement)
|
||||
for param_name, param in signature_params.items():
|
||||
is_path_param = param_name in path_param_names
|
||||
param_details = analyze_param(
|
||||
|
|
|
|||
|
|
@ -81,18 +81,18 @@ def get_openapi_security_definitions(
|
|||
security_definitions = {}
|
||||
# Use a dict to merge scopes for same security scheme
|
||||
operation_security_dict: Dict[str, List[str]] = {}
|
||||
for security_requirement in flat_dependant.security_requirements:
|
||||
for security_dependency in flat_dependant._security_dependencies:
|
||||
security_definition = jsonable_encoder(
|
||||
security_requirement.security_scheme.model,
|
||||
security_dependency._security_scheme.model,
|
||||
by_alias=True,
|
||||
exclude_none=True,
|
||||
)
|
||||
security_name = security_requirement.security_scheme.scheme_name
|
||||
security_name = security_dependency._security_scheme.scheme_name
|
||||
security_definitions[security_name] = security_definition
|
||||
# Merge scopes for the same security scheme
|
||||
if security_name not in operation_security_dict:
|
||||
operation_security_dict[security_name] = []
|
||||
for scope in security_requirement.scopes or []:
|
||||
for scope in security_dependency.oauth_scopes or []:
|
||||
if scope not in operation_security_dict[security_name]:
|
||||
operation_security_dict[security_name].append(scope)
|
||||
operation_security = [
|
||||
|
|
|
|||
Loading…
Reference in New Issue