diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index d6359c0f5..fbb666a7d 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -38,19 +38,43 @@ class Dependant: response_param_name: Optional[str] = None background_tasks_param_name: Optional[str] = None security_scopes_param_name: Optional[str] = None - security_scopes: Optional[List[str]] = None + own_oauth_scopes: Optional[List[str]] = None + parent_oauth_scopes: Optional[List[str]] = None use_cache: bool = True path: Optional[str] = None scope: Union[Literal["function", "request"], None] = None + @cached_property + def oauth_scopes(self) -> List[str]: + scopes = self.parent_oauth_scopes.copy() if self.parent_oauth_scopes else [] + # This doesn't use a set to preserve order, just in case + for scope in self.own_oauth_scopes or []: + if scope not in scopes: + scopes.append(scope) + return scopes + @cached_property def cache_key(self) -> DependencyCacheKey: + scopes_for_cache = ( + tuple(sorted(set(self.oauth_scopes or []))) if self._uses_scopes else () + ) return ( self.call, - tuple(sorted(set(self.security_scopes or []))), + scopes_for_cache, self.computed_scope or "", ) + @cached_property + def _uses_scopes(self) -> bool: + if self.own_oauth_scopes: + return True + if self.security_scopes_param_name is not None: + return True + for sub_dep in self.dependencies: + if sub_dep._uses_scopes: + return True + return False + @cached_property def is_gen_callable(self) -> bool: if inspect.isgeneratorfunction(self.call): diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 45353835b..d43fa8a51 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -58,8 +58,7 @@ from fastapi.dependencies.models import Dependant, SecurityRequirement from fastapi.exceptions import DependencyScopeError from fastapi.logger import logger from fastapi.security.base import SecurityBase -from fastapi.security.oauth2 import OAuth2, SecurityScopes -from fastapi.security.open_id_connect_url import OpenIdConnect +from fastapi.security.oauth2 import SecurityScopes from fastapi.types import DependencyCacheKey from fastapi.utils import create_model_field, get_path_param_names from pydantic import BaseModel @@ -126,14 +125,14 @@ def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> De assert callable(depends.dependency), ( "A parameter-less dependency must have a callable dependency" ) - use_security_scopes: List[str] = [] + own_oauth_scopes: List[str] = [] if isinstance(depends, params.Security) and depends.scopes: - use_security_scopes.extend(depends.scopes) + own_oauth_scopes.extend(depends.scopes) return get_dependant( path=path, call=depends.dependency, scope=depends.scope, - security_scopes=use_security_scopes, + own_oauth_scopes=own_oauth_scopes, ) @@ -232,7 +231,8 @@ def get_dependant( path: str, call: Callable[..., Any], name: Optional[str] = None, - security_scopes: Optional[List[str]] = None, + own_oauth_scopes: Optional[List[str]] = None, + parent_oauth_scopes: Optional[List[str]] = None, use_cache: bool = True, scope: Union[Literal["function", "request"], None] = None, ) -> Dependant: @@ -240,19 +240,18 @@ def get_dependant( call=call, name=name, path=path, - security_scopes=security_scopes, use_cache=use_cache, scope=scope, + own_oauth_scopes=own_oauth_scopes, + parent_oauth_scopes=parent_oauth_scopes, ) + current_scopes = (parent_oauth_scopes or []) + (own_oauth_scopes or []) path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) signature_params = endpoint_signature.parameters if isinstance(call, SecurityBase): - use_scopes: List[str] = [] - if isinstance(call, (OAuth2, OpenIdConnect)): - use_scopes = security_scopes or use_scopes security_requirement = SecurityRequirement( - security_scheme=call, scopes=use_scopes + security_scheme=call, scopes=current_scopes ) dependant.security_requirements.append(security_requirement) for param_name, param in signature_params.items(): @@ -275,17 +274,16 @@ def get_dependant( f'The dependency "{dependant.call.__name__}" has a scope of ' '"request", it cannot depend on dependencies with scope "function".' ) - use_security_scopes = security_scopes or [] + sub_own_oauth_scopes: List[str] = [] if isinstance(param_details.depends, params.Security): if param_details.depends.scopes: - use_security_scopes = use_security_scopes + list( - param_details.depends.scopes - ) + sub_own_oauth_scopes = list(param_details.depends.scopes) sub_dependant = get_dependant( path=path, call=param_details.depends.dependency, name=param_name, - security_scopes=use_security_scopes, + own_oauth_scopes=sub_own_oauth_scopes, + parent_oauth_scopes=current_scopes, use_cache=param_details.depends.use_cache, scope=param_details.depends.scope, ) @@ -611,7 +609,7 @@ async def solve_dependencies( path=use_path, call=call, name=sub_dependant.name, - security_scopes=sub_dependant.security_scopes, + parent_oauth_scopes=sub_dependant.oauth_scopes, scope=sub_dependant.scope, ) @@ -693,7 +691,7 @@ async def solve_dependencies( values[dependant.response_param_name] = response if dependant.security_scopes_param_name: values[dependant.security_scopes_param_name] = SecurityScopes( - scopes=dependant.security_scopes + scopes=dependant.oauth_scopes ) return SolvedDependency( values=values, diff --git a/tests/test_security_scopes.py b/tests/test_security_scopes.py new file mode 100644 index 000000000..248fd2bcc --- /dev/null +++ b/tests/test_security_scopes.py @@ -0,0 +1,46 @@ +from typing import Dict + +import pytest +from fastapi import Depends, FastAPI, Security +from fastapi.testclient import TestClient +from typing_extensions import Annotated + + +@pytest.fixture(name="call_counter") +def call_counter_fixture(): + return {"count": 0} + + +@pytest.fixture(name="app") +def app_fixture(call_counter: Dict[str, int]): + def get_db(): + call_counter["count"] += 1 + return f"db_{call_counter['count']}" + + def get_user(db: Annotated[str, Depends(get_db)]): + return "user" + + app = FastAPI() + + @app.get("/") + def endpoint( + db: Annotated[str, Depends(get_db)], + user: Annotated[str, Security(get_user, scopes=["read"])], + ): + return {"db": db} + + return app + + +@pytest.fixture(name="client") +def client_fixture(app: FastAPI): + return TestClient(app) + + +def test_security_scopes_dependency_called_once( + client: TestClient, call_counter: Dict[str, int] +): + response = client.get("/") + + assert response.status_code == 200 + assert call_counter["count"] == 1 diff --git a/tests/test_security_scopes_sub_dependency.py b/tests/test_security_scopes_sub_dependency.py new file mode 100644 index 000000000..9cc668d8e --- /dev/null +++ b/tests/test_security_scopes_sub_dependency.py @@ -0,0 +1,107 @@ +# Ref: https://github.com/fastapi/fastapi/discussions/6024#discussioncomment-8541913 + +from typing import Dict + +import pytest +from fastapi import Depends, FastAPI, Security +from fastapi.security import SecurityScopes +from fastapi.testclient import TestClient +from typing_extensions import Annotated + + +@pytest.fixture(name="call_counts") +def call_counts_fixture(): + return { + "get_db_session": 0, + "get_current_user": 0, + "get_user_me": 0, + "get_user_items": 0, + } + + +@pytest.fixture(name="app") +def app_fixture(call_counts: Dict[str, int]): + def get_db_session(): + call_counts["get_db_session"] += 1 + return f"db_session_{call_counts['get_db_session']}" + + def get_current_user( + security_scopes: SecurityScopes, + db_session: Annotated[str, Depends(get_db_session)], + ): + call_counts["get_current_user"] += 1 + return { + "user": f"user_{call_counts['get_current_user']}", + "scopes": security_scopes.scopes, + "db_session": db_session, + } + + def get_user_me( + current_user: Annotated[dict, Security(get_current_user, scopes=["me"])], + ): + call_counts["get_user_me"] += 1 + return { + "user_me": f"user_me_{call_counts['get_user_me']}", + "current_user": current_user, + } + + def get_user_items( + user_me: Annotated[dict, Depends(get_user_me)], + ): + call_counts["get_user_items"] += 1 + return { + "user_items": f"user_items_{call_counts['get_user_items']}", + "user_me": user_me, + } + + app = FastAPI() + + @app.get("/") + def path_operation( + user_me: Annotated[dict, Depends(get_user_me)], + user_items: Annotated[dict, Security(get_user_items, scopes=["items"])], + ): + return { + "user_me": user_me, + "user_items": user_items, + } + + return app + + +@pytest.fixture(name="client") +def client_fixture(app: FastAPI): + return TestClient(app) + + +def test_security_scopes_sub_dependency_caching( + client: TestClient, call_counts: Dict[str, int] +): + response = client.get("/") + + assert response.status_code == 200 + assert call_counts["get_db_session"] == 1 + assert call_counts["get_current_user"] == 2 + assert call_counts["get_user_me"] == 2 + assert call_counts["get_user_items"] == 1 + assert response.json() == { + "user_me": { + "user_me": "user_me_1", + "current_user": { + "user": "user_1", + "scopes": ["me"], + "db_session": "db_session_1", + }, + }, + "user_items": { + "user_items": "user_items_1", + "user_me": { + "user_me": "user_me_2", + "current_user": { + "user": "user_2", + "scopes": ["items", "me"], + "db_session": "db_session_1", + }, + }, + }, + }