mirror of https://github.com/tiangolo/fastapi.git
🐛 Cache dependencies that don't use scopes and don't have sub-dependencies with scopes (#14419)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
63d7a2b997
commit
7fbd30460f
|
|
@ -38,19 +38,43 @@ class Dependant:
|
||||||
response_param_name: Optional[str] = None
|
response_param_name: Optional[str] = None
|
||||||
background_tasks_param_name: Optional[str] = None
|
background_tasks_param_name: Optional[str] = None
|
||||||
security_scopes_param_name: Optional[str] = None
|
security_scopes_param_name: Optional[str] = None
|
||||||
security_scopes: Optional[List[str]] = None
|
own_oauth_scopes: Optional[List[str]] = None
|
||||||
|
parent_oauth_scopes: Optional[List[str]] = None
|
||||||
use_cache: bool = True
|
use_cache: bool = True
|
||||||
path: Optional[str] = None
|
path: Optional[str] = None
|
||||||
scope: Union[Literal["function", "request"], None] = None
|
scope: Union[Literal["function", "request"], None] = None
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def oauth_scopes(self) -> List[str]:
|
||||||
|
scopes = self.parent_oauth_scopes.copy() if self.parent_oauth_scopes else []
|
||||||
|
# This doesn't use a set to preserve order, just in case
|
||||||
|
for scope in self.own_oauth_scopes or []:
|
||||||
|
if scope not in scopes:
|
||||||
|
scopes.append(scope)
|
||||||
|
return scopes
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def cache_key(self) -> DependencyCacheKey:
|
def cache_key(self) -> DependencyCacheKey:
|
||||||
|
scopes_for_cache = (
|
||||||
|
tuple(sorted(set(self.oauth_scopes or []))) if self._uses_scopes else ()
|
||||||
|
)
|
||||||
return (
|
return (
|
||||||
self.call,
|
self.call,
|
||||||
tuple(sorted(set(self.security_scopes or []))),
|
scopes_for_cache,
|
||||||
self.computed_scope or "",
|
self.computed_scope or "",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _uses_scopes(self) -> bool:
|
||||||
|
if self.own_oauth_scopes:
|
||||||
|
return True
|
||||||
|
if self.security_scopes_param_name is not None:
|
||||||
|
return True
|
||||||
|
for sub_dep in self.dependencies:
|
||||||
|
if sub_dep._uses_scopes:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def is_gen_callable(self) -> bool:
|
def is_gen_callable(self) -> bool:
|
||||||
if inspect.isgeneratorfunction(self.call):
|
if inspect.isgeneratorfunction(self.call):
|
||||||
|
|
|
||||||
|
|
@ -58,8 +58,7 @@ from fastapi.dependencies.models import Dependant, SecurityRequirement
|
||||||
from fastapi.exceptions import DependencyScopeError
|
from fastapi.exceptions import DependencyScopeError
|
||||||
from fastapi.logger import logger
|
from fastapi.logger import logger
|
||||||
from fastapi.security.base import SecurityBase
|
from fastapi.security.base import SecurityBase
|
||||||
from fastapi.security.oauth2 import OAuth2, SecurityScopes
|
from fastapi.security.oauth2 import SecurityScopes
|
||||||
from fastapi.security.open_id_connect_url import OpenIdConnect
|
|
||||||
from fastapi.types import DependencyCacheKey
|
from fastapi.types import DependencyCacheKey
|
||||||
from fastapi.utils import create_model_field, get_path_param_names
|
from fastapi.utils import create_model_field, get_path_param_names
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
@ -126,14 +125,14 @@ def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> De
|
||||||
assert callable(depends.dependency), (
|
assert callable(depends.dependency), (
|
||||||
"A parameter-less dependency must have a callable dependency"
|
"A parameter-less dependency must have a callable dependency"
|
||||||
)
|
)
|
||||||
use_security_scopes: List[str] = []
|
own_oauth_scopes: List[str] = []
|
||||||
if isinstance(depends, params.Security) and depends.scopes:
|
if isinstance(depends, params.Security) and depends.scopes:
|
||||||
use_security_scopes.extend(depends.scopes)
|
own_oauth_scopes.extend(depends.scopes)
|
||||||
return get_dependant(
|
return get_dependant(
|
||||||
path=path,
|
path=path,
|
||||||
call=depends.dependency,
|
call=depends.dependency,
|
||||||
scope=depends.scope,
|
scope=depends.scope,
|
||||||
security_scopes=use_security_scopes,
|
own_oauth_scopes=own_oauth_scopes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -232,7 +231,8 @@ def get_dependant(
|
||||||
path: str,
|
path: str,
|
||||||
call: Callable[..., Any],
|
call: Callable[..., Any],
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
security_scopes: Optional[List[str]] = None,
|
own_oauth_scopes: Optional[List[str]] = None,
|
||||||
|
parent_oauth_scopes: Optional[List[str]] = None,
|
||||||
use_cache: bool = True,
|
use_cache: bool = True,
|
||||||
scope: Union[Literal["function", "request"], None] = None,
|
scope: Union[Literal["function", "request"], None] = None,
|
||||||
) -> Dependant:
|
) -> Dependant:
|
||||||
|
|
@ -240,19 +240,18 @@ def get_dependant(
|
||||||
call=call,
|
call=call,
|
||||||
name=name,
|
name=name,
|
||||||
path=path,
|
path=path,
|
||||||
security_scopes=security_scopes,
|
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
scope=scope,
|
scope=scope,
|
||||||
|
own_oauth_scopes=own_oauth_scopes,
|
||||||
|
parent_oauth_scopes=parent_oauth_scopes,
|
||||||
)
|
)
|
||||||
|
current_scopes = (parent_oauth_scopes or []) + (own_oauth_scopes or [])
|
||||||
path_param_names = get_path_param_names(path)
|
path_param_names = get_path_param_names(path)
|
||||||
endpoint_signature = get_typed_signature(call)
|
endpoint_signature = get_typed_signature(call)
|
||||||
signature_params = endpoint_signature.parameters
|
signature_params = endpoint_signature.parameters
|
||||||
if isinstance(call, SecurityBase):
|
if isinstance(call, SecurityBase):
|
||||||
use_scopes: List[str] = []
|
|
||||||
if isinstance(call, (OAuth2, OpenIdConnect)):
|
|
||||||
use_scopes = security_scopes or use_scopes
|
|
||||||
security_requirement = SecurityRequirement(
|
security_requirement = SecurityRequirement(
|
||||||
security_scheme=call, scopes=use_scopes
|
security_scheme=call, scopes=current_scopes
|
||||||
)
|
)
|
||||||
dependant.security_requirements.append(security_requirement)
|
dependant.security_requirements.append(security_requirement)
|
||||||
for param_name, param in signature_params.items():
|
for param_name, param in signature_params.items():
|
||||||
|
|
@ -275,17 +274,16 @@ def get_dependant(
|
||||||
f'The dependency "{dependant.call.__name__}" has a scope of '
|
f'The dependency "{dependant.call.__name__}" has a scope of '
|
||||||
'"request", it cannot depend on dependencies with scope "function".'
|
'"request", it cannot depend on dependencies with scope "function".'
|
||||||
)
|
)
|
||||||
use_security_scopes = security_scopes or []
|
sub_own_oauth_scopes: List[str] = []
|
||||||
if isinstance(param_details.depends, params.Security):
|
if isinstance(param_details.depends, params.Security):
|
||||||
if param_details.depends.scopes:
|
if param_details.depends.scopes:
|
||||||
use_security_scopes = use_security_scopes + list(
|
sub_own_oauth_scopes = list(param_details.depends.scopes)
|
||||||
param_details.depends.scopes
|
|
||||||
)
|
|
||||||
sub_dependant = get_dependant(
|
sub_dependant = get_dependant(
|
||||||
path=path,
|
path=path,
|
||||||
call=param_details.depends.dependency,
|
call=param_details.depends.dependency,
|
||||||
name=param_name,
|
name=param_name,
|
||||||
security_scopes=use_security_scopes,
|
own_oauth_scopes=sub_own_oauth_scopes,
|
||||||
|
parent_oauth_scopes=current_scopes,
|
||||||
use_cache=param_details.depends.use_cache,
|
use_cache=param_details.depends.use_cache,
|
||||||
scope=param_details.depends.scope,
|
scope=param_details.depends.scope,
|
||||||
)
|
)
|
||||||
|
|
@ -611,7 +609,7 @@ async def solve_dependencies(
|
||||||
path=use_path,
|
path=use_path,
|
||||||
call=call,
|
call=call,
|
||||||
name=sub_dependant.name,
|
name=sub_dependant.name,
|
||||||
security_scopes=sub_dependant.security_scopes,
|
parent_oauth_scopes=sub_dependant.oauth_scopes,
|
||||||
scope=sub_dependant.scope,
|
scope=sub_dependant.scope,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -693,7 +691,7 @@ async def solve_dependencies(
|
||||||
values[dependant.response_param_name] = response
|
values[dependant.response_param_name] = response
|
||||||
if dependant.security_scopes_param_name:
|
if dependant.security_scopes_param_name:
|
||||||
values[dependant.security_scopes_param_name] = SecurityScopes(
|
values[dependant.security_scopes_param_name] = SecurityScopes(
|
||||||
scopes=dependant.security_scopes
|
scopes=dependant.oauth_scopes
|
||||||
)
|
)
|
||||||
return SolvedDependency(
|
return SolvedDependency(
|
||||||
values=values,
|
values=values,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,46 @@
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import Depends, FastAPI, Security
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="call_counter")
|
||||||
|
def call_counter_fixture():
|
||||||
|
return {"count": 0}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="app")
|
||||||
|
def app_fixture(call_counter: Dict[str, int]):
|
||||||
|
def get_db():
|
||||||
|
call_counter["count"] += 1
|
||||||
|
return f"db_{call_counter['count']}"
|
||||||
|
|
||||||
|
def get_user(db: Annotated[str, Depends(get_db)]):
|
||||||
|
return "user"
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
def endpoint(
|
||||||
|
db: Annotated[str, Depends(get_db)],
|
||||||
|
user: Annotated[str, Security(get_user, scopes=["read"])],
|
||||||
|
):
|
||||||
|
return {"db": db}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="client")
|
||||||
|
def client_fixture(app: FastAPI):
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_security_scopes_dependency_called_once(
|
||||||
|
client: TestClient, call_counter: Dict[str, int]
|
||||||
|
):
|
||||||
|
response = client.get("/")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert call_counter["count"] == 1
|
||||||
|
|
@ -0,0 +1,107 @@
|
||||||
|
# Ref: https://github.com/fastapi/fastapi/discussions/6024#discussioncomment-8541913
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import Depends, FastAPI, Security
|
||||||
|
from fastapi.security import SecurityScopes
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="call_counts")
|
||||||
|
def call_counts_fixture():
|
||||||
|
return {
|
||||||
|
"get_db_session": 0,
|
||||||
|
"get_current_user": 0,
|
||||||
|
"get_user_me": 0,
|
||||||
|
"get_user_items": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="app")
|
||||||
|
def app_fixture(call_counts: Dict[str, int]):
|
||||||
|
def get_db_session():
|
||||||
|
call_counts["get_db_session"] += 1
|
||||||
|
return f"db_session_{call_counts['get_db_session']}"
|
||||||
|
|
||||||
|
def get_current_user(
|
||||||
|
security_scopes: SecurityScopes,
|
||||||
|
db_session: Annotated[str, Depends(get_db_session)],
|
||||||
|
):
|
||||||
|
call_counts["get_current_user"] += 1
|
||||||
|
return {
|
||||||
|
"user": f"user_{call_counts['get_current_user']}",
|
||||||
|
"scopes": security_scopes.scopes,
|
||||||
|
"db_session": db_session,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_user_me(
|
||||||
|
current_user: Annotated[dict, Security(get_current_user, scopes=["me"])],
|
||||||
|
):
|
||||||
|
call_counts["get_user_me"] += 1
|
||||||
|
return {
|
||||||
|
"user_me": f"user_me_{call_counts['get_user_me']}",
|
||||||
|
"current_user": current_user,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_user_items(
|
||||||
|
user_me: Annotated[dict, Depends(get_user_me)],
|
||||||
|
):
|
||||||
|
call_counts["get_user_items"] += 1
|
||||||
|
return {
|
||||||
|
"user_items": f"user_items_{call_counts['get_user_items']}",
|
||||||
|
"user_me": user_me,
|
||||||
|
}
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
def path_operation(
|
||||||
|
user_me: Annotated[dict, Depends(get_user_me)],
|
||||||
|
user_items: Annotated[dict, Security(get_user_items, scopes=["items"])],
|
||||||
|
):
|
||||||
|
return {
|
||||||
|
"user_me": user_me,
|
||||||
|
"user_items": user_items,
|
||||||
|
}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="client")
|
||||||
|
def client_fixture(app: FastAPI):
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_security_scopes_sub_dependency_caching(
|
||||||
|
client: TestClient, call_counts: Dict[str, int]
|
||||||
|
):
|
||||||
|
response = client.get("/")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert call_counts["get_db_session"] == 1
|
||||||
|
assert call_counts["get_current_user"] == 2
|
||||||
|
assert call_counts["get_user_me"] == 2
|
||||||
|
assert call_counts["get_user_items"] == 1
|
||||||
|
assert response.json() == {
|
||||||
|
"user_me": {
|
||||||
|
"user_me": "user_me_1",
|
||||||
|
"current_user": {
|
||||||
|
"user": "user_1",
|
||||||
|
"scopes": ["me"],
|
||||||
|
"db_session": "db_session_1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"user_items": {
|
||||||
|
"user_items": "user_items_1",
|
||||||
|
"user_me": {
|
||||||
|
"user_me": "user_me_2",
|
||||||
|
"current_user": {
|
||||||
|
"user": "user_2",
|
||||||
|
"scopes": ["items", "me"],
|
||||||
|
"db_session": "db_session_1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue