diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 9b545e4e5..af168a177 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -2,7 +2,7 @@ import inspect import sys from dataclasses import dataclass, field from functools import cached_property, partial -from typing import Any, Callable, List, Optional, Sequence, Union +from typing import Any, Callable, List, Optional, Union from fastapi._compat import ModelField from fastapi.security.base import SecurityBase @@ -28,12 +28,6 @@ def _impartial(func: Callable[..., Any]) -> Callable[..., Any]: return func -@dataclass -class SecurityRequirement: - security_scheme: SecurityBase - scopes: Optional[Sequence[str]] = None - - @dataclass class Dependant: path_params: List[ModelField] = field(default_factory=list) @@ -42,7 +36,6 @@ class Dependant: cookie_params: List[ModelField] = field(default_factory=list) body_params: List[ModelField] = field(default_factory=list) dependencies: List["Dependant"] = field(default_factory=list) - security_requirements: List[SecurityRequirement] = field(default_factory=list) name: Optional[str] = None call: Optional[Callable[..., Any]] = None request_param_name: Optional[str] = None @@ -83,11 +76,32 @@ class Dependant: return True if self.security_scopes_param_name is not None: return True + if self._is_security_scheme: + return True for sub_dep in self.dependencies: if sub_dep._uses_scopes: return True return False + @cached_property + def _is_security_scheme(self) -> bool: + if self.call is None: + return False # pragma: no cover + unwrapped = _unwrapped_call(self.call) + return isinstance(unwrapped, SecurityBase) + + # Mainly to get the type of SecurityBase, but it's the same self.call + @cached_property + def _security_scheme(self) -> SecurityBase: + unwrapped = _unwrapped_call(self.call) + assert isinstance(unwrapped, SecurityBase) + return unwrapped + + @cached_property + def _security_dependencies(self) -> List["Dependant"]: + security_deps = [dep for dep in self.dependencies if dep._is_security_scheme] + return security_deps + @cached_property def is_gen_callable(self) -> bool: if self.call is None: diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 1ff35f648..23bca6f2a 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -55,10 +55,9 @@ from fastapi.concurrency import ( asynccontextmanager, contextmanager_in_threadpool, ) -from fastapi.dependencies.models import Dependant, SecurityRequirement +from fastapi.dependencies.models import Dependant from fastapi.exceptions import DependencyScopeError from fastapi.logger import logger -from fastapi.security.base import SecurityBase from fastapi.security.oauth2 import SecurityScopes from fastapi.types import DependencyCacheKey from fastapi.utils import create_model_field, get_path_param_names @@ -142,10 +141,14 @@ def get_flat_dependant( *, skip_repeats: bool = False, visited: Optional[List[DependencyCacheKey]] = None, + parent_oauth_scopes: Optional[List[str]] = None, ) -> Dependant: if visited is None: visited = [] visited.append(dependant.cache_key) + use_parent_oauth_scopes = (parent_oauth_scopes or []) + ( + dependant.oauth_scopes or [] + ) flat_dependant = Dependant( path_params=dependant.path_params.copy(), @@ -153,22 +156,37 @@ def get_flat_dependant( header_params=dependant.header_params.copy(), cookie_params=dependant.cookie_params.copy(), body_params=dependant.body_params.copy(), - security_requirements=dependant.security_requirements.copy(), + name=dependant.name, + call=dependant.call, + request_param_name=dependant.request_param_name, + websocket_param_name=dependant.websocket_param_name, + http_connection_param_name=dependant.http_connection_param_name, + response_param_name=dependant.response_param_name, + background_tasks_param_name=dependant.background_tasks_param_name, + security_scopes_param_name=dependant.security_scopes_param_name, + own_oauth_scopes=dependant.own_oauth_scopes, + parent_oauth_scopes=use_parent_oauth_scopes, use_cache=dependant.use_cache, path=dependant.path, + scope=dependant.scope, ) for sub_dependant in dependant.dependencies: if skip_repeats and sub_dependant.cache_key in visited: continue flat_sub = get_flat_dependant( - sub_dependant, skip_repeats=skip_repeats, visited=visited + sub_dependant, + skip_repeats=skip_repeats, + visited=visited, + parent_oauth_scopes=flat_dependant.oauth_scopes, ) + flat_dependant.dependencies.append(flat_sub) flat_dependant.path_params.extend(flat_sub.path_params) flat_dependant.query_params.extend(flat_sub.query_params) flat_dependant.header_params.extend(flat_sub.header_params) flat_dependant.cookie_params.extend(flat_sub.cookie_params) flat_dependant.body_params.extend(flat_sub.body_params) - flat_dependant.security_requirements.extend(flat_sub.security_requirements) + flat_dependant.dependencies.extend(flat_sub.dependencies) + return flat_dependant @@ -258,11 +276,6 @@ def get_dependant( path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) signature_params = endpoint_signature.parameters - if isinstance(call, SecurityBase): - security_requirement = SecurityRequirement( - security_scheme=call, scopes=current_scopes - ) - dependant.security_requirements.append(security_requirement) for param_name, param in signature_params.items(): is_path_param = param_name in path_param_names param_details = analyze_param( diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index e7e6da2f7..06c14861a 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -81,18 +81,18 @@ def get_openapi_security_definitions( security_definitions = {} # Use a dict to merge scopes for same security scheme operation_security_dict: Dict[str, List[str]] = {} - for security_requirement in flat_dependant.security_requirements: + for security_dependency in flat_dependant._security_dependencies: security_definition = jsonable_encoder( - security_requirement.security_scheme.model, + security_dependency._security_scheme.model, by_alias=True, exclude_none=True, ) - security_name = security_requirement.security_scheme.scheme_name + security_name = security_dependency._security_scheme.scheme_name security_definitions[security_name] = security_definition # Merge scopes for the same security scheme if security_name not in operation_security_dict: operation_security_dict[security_name] = [] - for scope in security_requirement.scopes or []: + for scope in security_dependency.oauth_scopes or []: if scope not in operation_security_dict[security_name]: operation_security_dict[security_name].append(scope) operation_security = [ diff --git a/tests/test_security_oauth2_authorization_code_bearer_scopes_openapi.py b/tests/test_security_oauth2_authorization_code_bearer_scopes_openapi.py index 644df8de6..d41f1dc1f 100644 --- a/tests/test_security_oauth2_authorization_code_bearer_scopes_openapi.py +++ b/tests/test_security_oauth2_authorization_code_bearer_scopes_openapi.py @@ -2,10 +2,11 @@ from typing import Optional -from fastapi import APIRouter, FastAPI, Security +from fastapi import APIRouter, Depends, FastAPI, Security from fastapi.security import OAuth2AuthorizationCodeBearer from fastapi.testclient import TestClient from inline_snapshot import snapshot +from typing_extensions import Annotated oauth2_scheme = OAuth2AuthorizationCodeBearer( authorizationUrl="authorize", @@ -14,7 +15,12 @@ oauth2_scheme = OAuth2AuthorizationCodeBearer( scopes={"read": "Read access", "write": "Write access"}, ) -app = FastAPI(dependencies=[Security(oauth2_scheme)]) + +async def get_token(token: Annotated[str, Depends(oauth2_scheme)]) -> str: + return token + + +app = FastAPI(dependencies=[Depends(get_token)]) @app.get("/") @@ -22,11 +28,26 @@ async def root(): return {"message": "Hello World"} +@app.get( + "/with-oauth2-scheme", + dependencies=[Security(oauth2_scheme, scopes=["read", "write"])], +) +async def read_with_oauth2_scheme(): + return {"message": "Admin Access"} + + +@app.get( + "/with-get-token", dependencies=[Security(get_token, scopes=["read", "write"])] +) +async def read_with_get_token(): + return {"message": "Admin Access"} + + router = APIRouter(dependencies=[Security(oauth2_scheme, scopes=["read"])]) @router.get("/items/") -async def read_items(token: Optional[str] = Security(oauth2_scheme)): +async def read_items(token: Optional[str] = Depends(oauth2_scheme)): return {"token": token} @@ -48,6 +69,22 @@ def test_root(): assert response.json() == {"message": "Hello World"} +def test_read_with_oauth2_scheme(): + response = client.get( + "/with-oauth2-scheme", headers={"Authorization": "Bearer testtoken"} + ) + assert response.status_code == 200, response.text + assert response.json() == {"message": "Admin Access"} + + +def test_read_with_get_token(): + response = client.get( + "/with-get-token", headers={"Authorization": "Bearer testtoken"} + ) + assert response.status_code == 200, response.text + assert response.json() == {"message": "Admin Access"} + + def test_read_token(): response = client.get("/items/", headers={"Authorization": "Bearer testtoken"}) assert response.status_code == 200, response.text @@ -81,6 +118,36 @@ def test_openapi_schema(): "security": [{"OAuth2AuthorizationCodeBearer": []}], } }, + "/with-oauth2-scheme": { + "get": { + "summary": "Read With Oauth2 Scheme", + "operationId": "read_with_oauth2_scheme_with_oauth2_scheme_get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "security": [ + {"OAuth2AuthorizationCodeBearer": ["read", "write"]} + ], + } + }, + "/with-get-token": { + "get": { + "summary": "Read With Get Token", + "operationId": "read_with_get_token_with_get_token_get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "security": [ + {"OAuth2AuthorizationCodeBearer": ["read", "write"]} + ], + } + }, "/items/": { "get": { "summary": "Read Items", diff --git a/tests/test_security_oauth2_authorization_code_bearer_scopes_openapi_simple.py b/tests/test_security_oauth2_authorization_code_bearer_scopes_openapi_simple.py new file mode 100644 index 000000000..ff866d4fc --- /dev/null +++ b/tests/test_security_oauth2_authorization_code_bearer_scopes_openapi_simple.py @@ -0,0 +1,79 @@ +# Ref: https://github.com/fastapi/fastapi/issues/14454 + +from fastapi import Depends, FastAPI, Security +from fastapi.security import OAuth2AuthorizationCodeBearer +from fastapi.testclient import TestClient +from inline_snapshot import snapshot +from typing_extensions import Annotated + +oauth2_scheme = OAuth2AuthorizationCodeBearer( + authorizationUrl="api/oauth/authorize", + tokenUrl="/api/oauth/token", + scopes={"read": "Read access", "write": "Write access"}, +) + + +async def get_token(token: Annotated[str, Depends(oauth2_scheme)]) -> str: + return token + + +app = FastAPI(dependencies=[Depends(get_token)]) + + +@app.get("/admin", dependencies=[Security(get_token, scopes=["read", "write"])]) +async def read_admin(): + return {"message": "Admin Access"} + + +client = TestClient(app) + + +def test_read_admin(): + response = client.get("/admin", headers={"Authorization": "Bearer faketoken"}) + assert response.status_code == 200, response.text + assert response.json() == {"message": "Admin Access"} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == snapshot( + { + "openapi": "3.1.0", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/admin": { + "get": { + "summary": "Read Admin", + "operationId": "read_admin_admin_get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + "security": [ + {"OAuth2AuthorizationCodeBearer": ["read", "write"]} + ], + } + } + }, + "components": { + "securitySchemes": { + "OAuth2AuthorizationCodeBearer": { + "type": "oauth2", + "flows": { + "authorizationCode": { + "scopes": { + "read": "Read access", + "write": "Write access", + }, + "authorizationUrl": "api/oauth/authorize", + "tokenUrl": "/api/oauth/token", + } + }, + } + } + }, + } + )