diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py index 42674b476c..d497ee732c 100644 --- a/fastapi/security/oauth2.py +++ b/fastapi/security/oauth2.py @@ -398,6 +398,13 @@ class OAuth2(SecurityBase): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error + def _set_flows( + self, + *, + flows: OAuthFlowsModel | dict[str, dict[str, Any]], + ) -> None: + cast(OAuth2Model, self.model).flows = cast(OAuthFlowsModel, flows) + def make_not_authenticated_error(self) -> HTTPException: """ The OAuth 2 specification doesn't define the challenge that should be used, @@ -452,7 +459,7 @@ class OAuth2PasswordBearer(OAuth2): [FastAPI docs for Simple OAuth2 with Password and Bearer](https://fastapi.tiangolo.com/tutorial/security/simple-oauth2/). """ ), - ], + ] = "", scheme_name: Annotated[ str | None, Doc( @@ -514,6 +521,21 @@ class OAuth2PasswordBearer(OAuth2): ), ] = None, ): + super().__init__( + flows=OAuthFlowsModel(), + scheme_name=scheme_name, + description=description, + auto_error=auto_error, + ) + self.initialize(tokenUrl=tokenUrl, refreshUrl=refreshUrl, scopes=scopes) + + def initialize( + self, + *, + tokenUrl: str, + refreshUrl: str | None = None, + scopes: dict[str, str] | None = None, + ) -> None: if not scopes: scopes = {} flows = OAuthFlowsModel( @@ -526,12 +548,7 @@ class OAuth2PasswordBearer(OAuth2): }, ) ) - super().__init__( - flows=flows, - scheme_name=scheme_name, - description=description, - auto_error=auto_error, - ) + super()._set_flows(flows=flows) async def __call__(self, request: Request) -> str | None: authorization = request.headers.get("Authorization") @@ -552,7 +569,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2): def __init__( self, - authorizationUrl: str, + authorizationUrl: str = "", tokenUrl: Annotated[ str, Doc( @@ -560,7 +577,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2): The URL to obtain the OAuth2 token. """ ), - ], + ] = "", refreshUrl: Annotated[ str | None, Doc( @@ -619,6 +636,27 @@ class OAuth2AuthorizationCodeBearer(OAuth2): ), ] = True, ): + super().__init__( + flows=OAuthFlowsModel(), + scheme_name=scheme_name, + description=description, + auto_error=auto_error, + ) + self.initialize( + authorizationUrl=authorizationUrl, + tokenUrl=tokenUrl, + refreshUrl=refreshUrl, + scopes=scopes, + ) + + def initialize( + self, + *, + authorizationUrl: str, + tokenUrl: str, + refreshUrl: str | None = None, + scopes: dict[str, str] | None = None, + ) -> None: if not scopes: scopes = {} flows = OAuthFlowsModel( @@ -632,12 +670,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2): }, ) ) - super().__init__( - flows=flows, - scheme_name=scheme_name, - description=description, - auto_error=auto_error, - ) + super()._set_flows(flows=flows) async def __call__(self, request: Request) -> str | None: authorization = request.headers.get("Authorization") diff --git a/tests/test_security_oauth2_lazy_initialization.py b/tests/test_security_oauth2_lazy_initialization.py new file mode 100644 index 0000000000..0c0d8fa72d --- /dev/null +++ b/tests/test_security_oauth2_lazy_initialization.py @@ -0,0 +1,86 @@ +# Ref: https://github.com/fastapi/fastapi/issues/3317 + +from fastapi import APIRouter, FastAPI, Security +from fastapi.security import OAuth2AuthorizationCodeBearer, OAuth2PasswordBearer +from fastapi.testclient import TestClient + +auth_code_scheme = OAuth2AuthorizationCodeBearer() +auth_code_router = APIRouter() + + +@auth_code_router.get("/private-route") +async def private_route( + token: str | None = Security(auth_code_scheme, scopes=["admin"]), +): + return {"token": token} + + +def create_auth_code_app() -> FastAPI: + app = FastAPI() + app.include_router(auth_code_router) + auth_code_scheme.initialize( + authorizationUrl="https://example.com/authorize", + tokenUrl="https://example.com/oauth/token", + scopes={"admin": "Admin access"}, + ) + return app + + +def test_oauth2_authorization_code_bearer_lazy_initialize(): + app = create_auth_code_app() + client = TestClient(app) + + response = client.get( + "/private-route", headers={"Authorization": "Bearer testtoken"} + ) + assert response.status_code == 200, response.text + assert response.json() == {"token": "testtoken"} + + openapi = client.get("/openapi.json") + assert openapi.status_code == 200, openapi.text + authorization_code_flow = openapi.json()["components"]["securitySchemes"][ + "OAuth2AuthorizationCodeBearer" + ]["flows"]["authorizationCode"] + assert ( + authorization_code_flow["authorizationUrl"] == "https://example.com/authorize" + ) + assert authorization_code_flow["tokenUrl"] == "https://example.com/oauth/token" + assert authorization_code_flow["scopes"] == {"admin": "Admin access"} + + +password_scheme = OAuth2PasswordBearer() +password_router = APIRouter() + + +@password_router.get("/password-route") +async def password_route(token: str | None = Security(password_scheme)): + return {"token": token} + + +def create_password_app() -> FastAPI: + app = FastAPI() + app.include_router(password_router) + password_scheme.initialize( + tokenUrl="https://example.com/oauth/token", + scopes={"read": "Read access"}, + ) + return app + + +def test_oauth2_password_bearer_lazy_initialize(): + app = create_password_app() + client = TestClient(app) + + response = client.get( + "/password-route", headers={"Authorization": "Bearer testtoken"} + ) + assert response.status_code == 200, response.text + assert response.json() == {"token": "testtoken"} + + openapi = client.get("/openapi.json") + assert openapi.status_code == 200, openapi.text + password_flow = openapi.json()["components"]["securitySchemes"][ + "OAuth2PasswordBearer" + ]["flows"]["password"] + assert password_flow["tokenUrl"] == "https://example.com/oauth/token" + assert password_flow["scopes"] == {"read": "Read access"}