mirror of https://github.com/tiangolo/fastapi.git
Merge b64518fe02 into 0127069d47
This commit is contained in:
commit
1e676233aa
|
|
@ -398,6 +398,13 @@ class OAuth2(SecurityBase):
|
|||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
|
||||
def _set_flows(
|
||||
self,
|
||||
*,
|
||||
flows: OAuthFlowsModel | dict[str, dict[str, Any]],
|
||||
) -> None:
|
||||
cast(OAuth2Model, self.model).flows = cast(OAuthFlowsModel, flows)
|
||||
|
||||
def make_not_authenticated_error(self) -> HTTPException:
|
||||
"""
|
||||
The OAuth 2 specification doesn't define the challenge that should be used,
|
||||
|
|
@ -452,7 +459,7 @@ class OAuth2PasswordBearer(OAuth2):
|
|||
[FastAPI docs for Simple OAuth2 with Password and Bearer](https://fastapi.tiangolo.com/tutorial/security/simple-oauth2/).
|
||||
"""
|
||||
),
|
||||
],
|
||||
] = "",
|
||||
scheme_name: Annotated[
|
||||
str | None,
|
||||
Doc(
|
||||
|
|
@ -514,6 +521,21 @@ class OAuth2PasswordBearer(OAuth2):
|
|||
),
|
||||
] = None,
|
||||
):
|
||||
super().__init__(
|
||||
flows=OAuthFlowsModel(),
|
||||
scheme_name=scheme_name,
|
||||
description=description,
|
||||
auto_error=auto_error,
|
||||
)
|
||||
self.initialize(tokenUrl=tokenUrl, refreshUrl=refreshUrl, scopes=scopes)
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
*,
|
||||
tokenUrl: str,
|
||||
refreshUrl: str | None = None,
|
||||
scopes: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
if not scopes:
|
||||
scopes = {}
|
||||
flows = OAuthFlowsModel(
|
||||
|
|
@ -526,12 +548,7 @@ class OAuth2PasswordBearer(OAuth2):
|
|||
},
|
||||
)
|
||||
)
|
||||
super().__init__(
|
||||
flows=flows,
|
||||
scheme_name=scheme_name,
|
||||
description=description,
|
||||
auto_error=auto_error,
|
||||
)
|
||||
super()._set_flows(flows=flows)
|
||||
|
||||
async def __call__(self, request: Request) -> str | None:
|
||||
authorization = request.headers.get("Authorization")
|
||||
|
|
@ -552,7 +569,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
authorizationUrl: str,
|
||||
authorizationUrl: str = "",
|
||||
tokenUrl: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
|
|
@ -560,7 +577,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
|
|||
The URL to obtain the OAuth2 token.
|
||||
"""
|
||||
),
|
||||
],
|
||||
] = "",
|
||||
refreshUrl: Annotated[
|
||||
str | None,
|
||||
Doc(
|
||||
|
|
@ -619,6 +636,27 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
|
|||
),
|
||||
] = True,
|
||||
):
|
||||
super().__init__(
|
||||
flows=OAuthFlowsModel(),
|
||||
scheme_name=scheme_name,
|
||||
description=description,
|
||||
auto_error=auto_error,
|
||||
)
|
||||
self.initialize(
|
||||
authorizationUrl=authorizationUrl,
|
||||
tokenUrl=tokenUrl,
|
||||
refreshUrl=refreshUrl,
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
*,
|
||||
authorizationUrl: str,
|
||||
tokenUrl: str,
|
||||
refreshUrl: str | None = None,
|
||||
scopes: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
if not scopes:
|
||||
scopes = {}
|
||||
flows = OAuthFlowsModel(
|
||||
|
|
@ -632,12 +670,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
|
|||
},
|
||||
)
|
||||
)
|
||||
super().__init__(
|
||||
flows=flows,
|
||||
scheme_name=scheme_name,
|
||||
description=description,
|
||||
auto_error=auto_error,
|
||||
)
|
||||
super()._set_flows(flows=flows)
|
||||
|
||||
async def __call__(self, request: Request) -> str | None:
|
||||
authorization = request.headers.get("Authorization")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,86 @@
|
|||
# Ref: https://github.com/fastapi/fastapi/issues/3317
|
||||
|
||||
from fastapi import APIRouter, FastAPI, Security
|
||||
from fastapi.security import OAuth2AuthorizationCodeBearer, OAuth2PasswordBearer
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
auth_code_scheme = OAuth2AuthorizationCodeBearer()
|
||||
auth_code_router = APIRouter()
|
||||
|
||||
|
||||
@auth_code_router.get("/private-route")
|
||||
async def private_route(
|
||||
token: str | None = Security(auth_code_scheme, scopes=["admin"]),
|
||||
):
|
||||
return {"token": token}
|
||||
|
||||
|
||||
def create_auth_code_app() -> FastAPI:
|
||||
app = FastAPI()
|
||||
app.include_router(auth_code_router)
|
||||
auth_code_scheme.initialize(
|
||||
authorizationUrl="https://example.com/authorize",
|
||||
tokenUrl="https://example.com/oauth/token",
|
||||
scopes={"admin": "Admin access"},
|
||||
)
|
||||
return app
|
||||
|
||||
|
||||
def test_oauth2_authorization_code_bearer_lazy_initialize():
|
||||
app = create_auth_code_app()
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get(
|
||||
"/private-route", headers={"Authorization": "Bearer testtoken"}
|
||||
)
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"token": "testtoken"}
|
||||
|
||||
openapi = client.get("/openapi.json")
|
||||
assert openapi.status_code == 200, openapi.text
|
||||
authorization_code_flow = openapi.json()["components"]["securitySchemes"][
|
||||
"OAuth2AuthorizationCodeBearer"
|
||||
]["flows"]["authorizationCode"]
|
||||
assert (
|
||||
authorization_code_flow["authorizationUrl"] == "https://example.com/authorize"
|
||||
)
|
||||
assert authorization_code_flow["tokenUrl"] == "https://example.com/oauth/token"
|
||||
assert authorization_code_flow["scopes"] == {"admin": "Admin access"}
|
||||
|
||||
|
||||
password_scheme = OAuth2PasswordBearer()
|
||||
password_router = APIRouter()
|
||||
|
||||
|
||||
@password_router.get("/password-route")
|
||||
async def password_route(token: str | None = Security(password_scheme)):
|
||||
return {"token": token}
|
||||
|
||||
|
||||
def create_password_app() -> FastAPI:
|
||||
app = FastAPI()
|
||||
app.include_router(password_router)
|
||||
password_scheme.initialize(
|
||||
tokenUrl="https://example.com/oauth/token",
|
||||
scopes={"read": "Read access"},
|
||||
)
|
||||
return app
|
||||
|
||||
|
||||
def test_oauth2_password_bearer_lazy_initialize():
|
||||
app = create_password_app()
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get(
|
||||
"/password-route", headers={"Authorization": "Bearer testtoken"}
|
||||
)
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"token": "testtoken"}
|
||||
|
||||
openapi = client.get("/openapi.json")
|
||||
assert openapi.status_code == 200, openapi.text
|
||||
password_flow = openapi.json()["components"]["securitySchemes"][
|
||||
"OAuth2PasswordBearer"
|
||||
]["flows"]["password"]
|
||||
assert password_flow["tokenUrl"] == "https://example.com/oauth/token"
|
||||
assert password_flow["scopes"] == {"read": "Read access"}
|
||||
Loading…
Reference in New Issue