🐛 Cache dependencies that don't use scopes and don't have sub-dependencies with scopes (#14419)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Sebastián Ramírez 2025-11-30 06:45:49 -08:00 committed by GitHub
parent 63d7a2b997
commit 7fbd30460f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 195 additions and 20 deletions

View File

@ -38,19 +38,43 @@ class Dependant:
response_param_name: Optional[str] = None
background_tasks_param_name: Optional[str] = None
security_scopes_param_name: Optional[str] = None
security_scopes: Optional[List[str]] = None
own_oauth_scopes: Optional[List[str]] = None
parent_oauth_scopes: Optional[List[str]] = None
use_cache: bool = True
path: Optional[str] = None
scope: Union[Literal["function", "request"], None] = None
@cached_property
def oauth_scopes(self) -> List[str]:
scopes = self.parent_oauth_scopes.copy() if self.parent_oauth_scopes else []
# This doesn't use a set to preserve order, just in case
for scope in self.own_oauth_scopes or []:
if scope not in scopes:
scopes.append(scope)
return scopes
@cached_property
def cache_key(self) -> DependencyCacheKey:
scopes_for_cache = (
tuple(sorted(set(self.oauth_scopes or []))) if self._uses_scopes else ()
)
return (
self.call,
tuple(sorted(set(self.security_scopes or []))),
scopes_for_cache,
self.computed_scope or "",
)
@cached_property
def _uses_scopes(self) -> bool:
if self.own_oauth_scopes:
return True
if self.security_scopes_param_name is not None:
return True
for sub_dep in self.dependencies:
if sub_dep._uses_scopes:
return True
return False
@cached_property
def is_gen_callable(self) -> bool:
if inspect.isgeneratorfunction(self.call):

View File

@ -58,8 +58,7 @@ from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.exceptions import DependencyScopeError
from fastapi.logger import logger
from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes
from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.security.oauth2 import SecurityScopes
from fastapi.types import DependencyCacheKey
from fastapi.utils import create_model_field, get_path_param_names
from pydantic import BaseModel
@ -126,14 +125,14 @@ def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> De
assert callable(depends.dependency), (
"A parameter-less dependency must have a callable dependency"
)
use_security_scopes: List[str] = []
own_oauth_scopes: List[str] = []
if isinstance(depends, params.Security) and depends.scopes:
use_security_scopes.extend(depends.scopes)
own_oauth_scopes.extend(depends.scopes)
return get_dependant(
path=path,
call=depends.dependency,
scope=depends.scope,
security_scopes=use_security_scopes,
own_oauth_scopes=own_oauth_scopes,
)
@ -232,7 +231,8 @@ def get_dependant(
path: str,
call: Callable[..., Any],
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
own_oauth_scopes: Optional[List[str]] = None,
parent_oauth_scopes: Optional[List[str]] = None,
use_cache: bool = True,
scope: Union[Literal["function", "request"], None] = None,
) -> Dependant:
@ -240,19 +240,18 @@ def get_dependant(
call=call,
name=name,
path=path,
security_scopes=security_scopes,
use_cache=use_cache,
scope=scope,
own_oauth_scopes=own_oauth_scopes,
parent_oauth_scopes=parent_oauth_scopes,
)
current_scopes = (parent_oauth_scopes or []) + (own_oauth_scopes or [])
path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call)
signature_params = endpoint_signature.parameters
if isinstance(call, SecurityBase):
use_scopes: List[str] = []
if isinstance(call, (OAuth2, OpenIdConnect)):
use_scopes = security_scopes or use_scopes
security_requirement = SecurityRequirement(
security_scheme=call, scopes=use_scopes
security_scheme=call, scopes=current_scopes
)
dependant.security_requirements.append(security_requirement)
for param_name, param in signature_params.items():
@ -275,17 +274,16 @@ def get_dependant(
f'The dependency "{dependant.call.__name__}" has a scope of '
'"request", it cannot depend on dependencies with scope "function".'
)
use_security_scopes = security_scopes or []
sub_own_oauth_scopes: List[str] = []
if isinstance(param_details.depends, params.Security):
if param_details.depends.scopes:
use_security_scopes = use_security_scopes + list(
param_details.depends.scopes
)
sub_own_oauth_scopes = list(param_details.depends.scopes)
sub_dependant = get_dependant(
path=path,
call=param_details.depends.dependency,
name=param_name,
security_scopes=use_security_scopes,
own_oauth_scopes=sub_own_oauth_scopes,
parent_oauth_scopes=current_scopes,
use_cache=param_details.depends.use_cache,
scope=param_details.depends.scope,
)
@ -611,7 +609,7 @@ async def solve_dependencies(
path=use_path,
call=call,
name=sub_dependant.name,
security_scopes=sub_dependant.security_scopes,
parent_oauth_scopes=sub_dependant.oauth_scopes,
scope=sub_dependant.scope,
)
@ -693,7 +691,7 @@ async def solve_dependencies(
values[dependant.response_param_name] = response
if dependant.security_scopes_param_name:
values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.security_scopes
scopes=dependant.oauth_scopes
)
return SolvedDependency(
values=values,

View File

@ -0,0 +1,46 @@
from typing import Dict
import pytest
from fastapi import Depends, FastAPI, Security
from fastapi.testclient import TestClient
from typing_extensions import Annotated
@pytest.fixture(name="call_counter")
def call_counter_fixture():
return {"count": 0}
@pytest.fixture(name="app")
def app_fixture(call_counter: Dict[str, int]):
def get_db():
call_counter["count"] += 1
return f"db_{call_counter['count']}"
def get_user(db: Annotated[str, Depends(get_db)]):
return "user"
app = FastAPI()
@app.get("/")
def endpoint(
db: Annotated[str, Depends(get_db)],
user: Annotated[str, Security(get_user, scopes=["read"])],
):
return {"db": db}
return app
@pytest.fixture(name="client")
def client_fixture(app: FastAPI):
return TestClient(app)
def test_security_scopes_dependency_called_once(
client: TestClient, call_counter: Dict[str, int]
):
response = client.get("/")
assert response.status_code == 200
assert call_counter["count"] == 1

View File

@ -0,0 +1,107 @@
# Ref: https://github.com/fastapi/fastapi/discussions/6024#discussioncomment-8541913
from typing import Dict
import pytest
from fastapi import Depends, FastAPI, Security
from fastapi.security import SecurityScopes
from fastapi.testclient import TestClient
from typing_extensions import Annotated
@pytest.fixture(name="call_counts")
def call_counts_fixture():
return {
"get_db_session": 0,
"get_current_user": 0,
"get_user_me": 0,
"get_user_items": 0,
}
@pytest.fixture(name="app")
def app_fixture(call_counts: Dict[str, int]):
def get_db_session():
call_counts["get_db_session"] += 1
return f"db_session_{call_counts['get_db_session']}"
def get_current_user(
security_scopes: SecurityScopes,
db_session: Annotated[str, Depends(get_db_session)],
):
call_counts["get_current_user"] += 1
return {
"user": f"user_{call_counts['get_current_user']}",
"scopes": security_scopes.scopes,
"db_session": db_session,
}
def get_user_me(
current_user: Annotated[dict, Security(get_current_user, scopes=["me"])],
):
call_counts["get_user_me"] += 1
return {
"user_me": f"user_me_{call_counts['get_user_me']}",
"current_user": current_user,
}
def get_user_items(
user_me: Annotated[dict, Depends(get_user_me)],
):
call_counts["get_user_items"] += 1
return {
"user_items": f"user_items_{call_counts['get_user_items']}",
"user_me": user_me,
}
app = FastAPI()
@app.get("/")
def path_operation(
user_me: Annotated[dict, Depends(get_user_me)],
user_items: Annotated[dict, Security(get_user_items, scopes=["items"])],
):
return {
"user_me": user_me,
"user_items": user_items,
}
return app
@pytest.fixture(name="client")
def client_fixture(app: FastAPI):
return TestClient(app)
def test_security_scopes_sub_dependency_caching(
client: TestClient, call_counts: Dict[str, int]
):
response = client.get("/")
assert response.status_code == 200
assert call_counts["get_db_session"] == 1
assert call_counts["get_current_user"] == 2
assert call_counts["get_user_me"] == 2
assert call_counts["get_user_items"] == 1
assert response.json() == {
"user_me": {
"user_me": "user_me_1",
"current_user": {
"user": "user_1",
"scopes": ["me"],
"db_session": "db_session_1",
},
},
"user_items": {
"user_items": "user_items_1",
"user_me": {
"user_me": "user_me_2",
"current_user": {
"user": "user_2",
"scopes": ["items", "me"],
"db_session": "db_session_1",
},
},
},
}