🐛 Cache dependencies that don't use scopes and don't have sub-dependencies with scopes (#14419)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Sebastián Ramírez 2025-11-30 06:45:49 -08:00 committed by GitHub
parent 63d7a2b997
commit 7fbd30460f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 195 additions and 20 deletions

View File

@ -38,19 +38,43 @@ class Dependant:
response_param_name: Optional[str] = None response_param_name: Optional[str] = None
background_tasks_param_name: Optional[str] = None background_tasks_param_name: Optional[str] = None
security_scopes_param_name: Optional[str] = None security_scopes_param_name: Optional[str] = None
security_scopes: Optional[List[str]] = None own_oauth_scopes: Optional[List[str]] = None
parent_oauth_scopes: Optional[List[str]] = None
use_cache: bool = True use_cache: bool = True
path: Optional[str] = None path: Optional[str] = None
scope: Union[Literal["function", "request"], None] = None scope: Union[Literal["function", "request"], None] = None
@cached_property
def oauth_scopes(self) -> List[str]:
scopes = self.parent_oauth_scopes.copy() if self.parent_oauth_scopes else []
# This doesn't use a set to preserve order, just in case
for scope in self.own_oauth_scopes or []:
if scope not in scopes:
scopes.append(scope)
return scopes
@cached_property @cached_property
def cache_key(self) -> DependencyCacheKey: def cache_key(self) -> DependencyCacheKey:
scopes_for_cache = (
tuple(sorted(set(self.oauth_scopes or []))) if self._uses_scopes else ()
)
return ( return (
self.call, self.call,
tuple(sorted(set(self.security_scopes or []))), scopes_for_cache,
self.computed_scope or "", self.computed_scope or "",
) )
@cached_property
def _uses_scopes(self) -> bool:
if self.own_oauth_scopes:
return True
if self.security_scopes_param_name is not None:
return True
for sub_dep in self.dependencies:
if sub_dep._uses_scopes:
return True
return False
@cached_property @cached_property
def is_gen_callable(self) -> bool: def is_gen_callable(self) -> bool:
if inspect.isgeneratorfunction(self.call): if inspect.isgeneratorfunction(self.call):

View File

@ -58,8 +58,7 @@ from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.exceptions import DependencyScopeError from fastapi.exceptions import DependencyScopeError
from fastapi.logger import logger from fastapi.logger import logger
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes from fastapi.security.oauth2 import SecurityScopes
from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.types import DependencyCacheKey from fastapi.types import DependencyCacheKey
from fastapi.utils import create_model_field, get_path_param_names from fastapi.utils import create_model_field, get_path_param_names
from pydantic import BaseModel from pydantic import BaseModel
@ -126,14 +125,14 @@ def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> De
assert callable(depends.dependency), ( assert callable(depends.dependency), (
"A parameter-less dependency must have a callable dependency" "A parameter-less dependency must have a callable dependency"
) )
use_security_scopes: List[str] = [] own_oauth_scopes: List[str] = []
if isinstance(depends, params.Security) and depends.scopes: if isinstance(depends, params.Security) and depends.scopes:
use_security_scopes.extend(depends.scopes) own_oauth_scopes.extend(depends.scopes)
return get_dependant( return get_dependant(
path=path, path=path,
call=depends.dependency, call=depends.dependency,
scope=depends.scope, scope=depends.scope,
security_scopes=use_security_scopes, own_oauth_scopes=own_oauth_scopes,
) )
@ -232,7 +231,8 @@ def get_dependant(
path: str, path: str,
call: Callable[..., Any], call: Callable[..., Any],
name: Optional[str] = None, name: Optional[str] = None,
security_scopes: Optional[List[str]] = None, own_oauth_scopes: Optional[List[str]] = None,
parent_oauth_scopes: Optional[List[str]] = None,
use_cache: bool = True, use_cache: bool = True,
scope: Union[Literal["function", "request"], None] = None, scope: Union[Literal["function", "request"], None] = None,
) -> Dependant: ) -> Dependant:
@ -240,19 +240,18 @@ def get_dependant(
call=call, call=call,
name=name, name=name,
path=path, path=path,
security_scopes=security_scopes,
use_cache=use_cache, use_cache=use_cache,
scope=scope, scope=scope,
own_oauth_scopes=own_oauth_scopes,
parent_oauth_scopes=parent_oauth_scopes,
) )
current_scopes = (parent_oauth_scopes or []) + (own_oauth_scopes or [])
path_param_names = get_path_param_names(path) path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call) endpoint_signature = get_typed_signature(call)
signature_params = endpoint_signature.parameters signature_params = endpoint_signature.parameters
if isinstance(call, SecurityBase): if isinstance(call, SecurityBase):
use_scopes: List[str] = []
if isinstance(call, (OAuth2, OpenIdConnect)):
use_scopes = security_scopes or use_scopes
security_requirement = SecurityRequirement( security_requirement = SecurityRequirement(
security_scheme=call, scopes=use_scopes security_scheme=call, scopes=current_scopes
) )
dependant.security_requirements.append(security_requirement) dependant.security_requirements.append(security_requirement)
for param_name, param in signature_params.items(): for param_name, param in signature_params.items():
@ -275,17 +274,16 @@ def get_dependant(
f'The dependency "{dependant.call.__name__}" has a scope of ' f'The dependency "{dependant.call.__name__}" has a scope of '
'"request", it cannot depend on dependencies with scope "function".' '"request", it cannot depend on dependencies with scope "function".'
) )
use_security_scopes = security_scopes or [] sub_own_oauth_scopes: List[str] = []
if isinstance(param_details.depends, params.Security): if isinstance(param_details.depends, params.Security):
if param_details.depends.scopes: if param_details.depends.scopes:
use_security_scopes = use_security_scopes + list( sub_own_oauth_scopes = list(param_details.depends.scopes)
param_details.depends.scopes
)
sub_dependant = get_dependant( sub_dependant = get_dependant(
path=path, path=path,
call=param_details.depends.dependency, call=param_details.depends.dependency,
name=param_name, name=param_name,
security_scopes=use_security_scopes, own_oauth_scopes=sub_own_oauth_scopes,
parent_oauth_scopes=current_scopes,
use_cache=param_details.depends.use_cache, use_cache=param_details.depends.use_cache,
scope=param_details.depends.scope, scope=param_details.depends.scope,
) )
@ -611,7 +609,7 @@ async def solve_dependencies(
path=use_path, path=use_path,
call=call, call=call,
name=sub_dependant.name, name=sub_dependant.name,
security_scopes=sub_dependant.security_scopes, parent_oauth_scopes=sub_dependant.oauth_scopes,
scope=sub_dependant.scope, scope=sub_dependant.scope,
) )
@ -693,7 +691,7 @@ async def solve_dependencies(
values[dependant.response_param_name] = response values[dependant.response_param_name] = response
if dependant.security_scopes_param_name: if dependant.security_scopes_param_name:
values[dependant.security_scopes_param_name] = SecurityScopes( values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.security_scopes scopes=dependant.oauth_scopes
) )
return SolvedDependency( return SolvedDependency(
values=values, values=values,

View File

@ -0,0 +1,46 @@
from typing import Dict
import pytest
from fastapi import Depends, FastAPI, Security
from fastapi.testclient import TestClient
from typing_extensions import Annotated
@pytest.fixture(name="call_counter")
def call_counter_fixture():
return {"count": 0}
@pytest.fixture(name="app")
def app_fixture(call_counter: Dict[str, int]):
def get_db():
call_counter["count"] += 1
return f"db_{call_counter['count']}"
def get_user(db: Annotated[str, Depends(get_db)]):
return "user"
app = FastAPI()
@app.get("/")
def endpoint(
db: Annotated[str, Depends(get_db)],
user: Annotated[str, Security(get_user, scopes=["read"])],
):
return {"db": db}
return app
@pytest.fixture(name="client")
def client_fixture(app: FastAPI):
return TestClient(app)
def test_security_scopes_dependency_called_once(
client: TestClient, call_counter: Dict[str, int]
):
response = client.get("/")
assert response.status_code == 200
assert call_counter["count"] == 1

View File

@ -0,0 +1,107 @@
# Ref: https://github.com/fastapi/fastapi/discussions/6024#discussioncomment-8541913
from typing import Dict
import pytest
from fastapi import Depends, FastAPI, Security
from fastapi.security import SecurityScopes
from fastapi.testclient import TestClient
from typing_extensions import Annotated
@pytest.fixture(name="call_counts")
def call_counts_fixture():
return {
"get_db_session": 0,
"get_current_user": 0,
"get_user_me": 0,
"get_user_items": 0,
}
@pytest.fixture(name="app")
def app_fixture(call_counts: Dict[str, int]):
def get_db_session():
call_counts["get_db_session"] += 1
return f"db_session_{call_counts['get_db_session']}"
def get_current_user(
security_scopes: SecurityScopes,
db_session: Annotated[str, Depends(get_db_session)],
):
call_counts["get_current_user"] += 1
return {
"user": f"user_{call_counts['get_current_user']}",
"scopes": security_scopes.scopes,
"db_session": db_session,
}
def get_user_me(
current_user: Annotated[dict, Security(get_current_user, scopes=["me"])],
):
call_counts["get_user_me"] += 1
return {
"user_me": f"user_me_{call_counts['get_user_me']}",
"current_user": current_user,
}
def get_user_items(
user_me: Annotated[dict, Depends(get_user_me)],
):
call_counts["get_user_items"] += 1
return {
"user_items": f"user_items_{call_counts['get_user_items']}",
"user_me": user_me,
}
app = FastAPI()
@app.get("/")
def path_operation(
user_me: Annotated[dict, Depends(get_user_me)],
user_items: Annotated[dict, Security(get_user_items, scopes=["items"])],
):
return {
"user_me": user_me,
"user_items": user_items,
}
return app
@pytest.fixture(name="client")
def client_fixture(app: FastAPI):
return TestClient(app)
def test_security_scopes_sub_dependency_caching(
client: TestClient, call_counts: Dict[str, int]
):
response = client.get("/")
assert response.status_code == 200
assert call_counts["get_db_session"] == 1
assert call_counts["get_current_user"] == 2
assert call_counts["get_user_me"] == 2
assert call_counts["get_user_items"] == 1
assert response.json() == {
"user_me": {
"user_me": "user_me_1",
"current_user": {
"user": "user_1",
"scopes": ["me"],
"db_session": "db_session_1",
},
},
"user_items": {
"user_items": "user_items_1",
"user_me": {
"user_me": "user_me_2",
"current_user": {
"user": "user_2",
"scopes": ["items", "me"],
"db_session": "db_session_1",
},
},
},
}