fastapi/docs_src/security/tutorial008_an_py39.py

139 lines
4.7 KiB
Python

from typing import Annotated, Any, Dict, Optional

import httpx
from cachetools import TTLCache
from fastapi import Depends, FastAPI, HTTPException, Security
from fastapi.security import (
    HTTPAuthorizationCredentials,
    HTTPBearer,
    OpenIdConnect,
    SecurityScopes,
)
from jose import JWTError, jwt
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from starlette.requests import Request
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_403_FORBIDDEN
class AccessTokenCredentials(HTTPAuthorizationCredentials):
    """Bearer credentials extended with the decoded claims of the access token."""

    # Claims of the JWT after it has been validated, keyed by claim name.
    token: dict[str, Any]
class AccessTokenValidator(HTTPBearer):
    """HTTPBearer dependency that validates JWT access tokens against a JWKS.

    The key set is downloaded from ``jwks_url`` and held in a TTL cache so the
    authorization server is not contacted on every request. When the dependency
    is declared with ``Security(..., scopes=[...])`` the requested scopes are
    additionally checked against the ``roles_claim`` of the token.
    """

    def __init__(
        self,
        *,
        jwks_url: str,
        audience: str,
        issuer: str,
        expire_seconds: int = 3600,
        roles_claim: str = "groups",
        scheme_name: Optional[str] = None,
        description: Optional[str] = None,
        algorithms: Optional[list[str]] = None,
    ):
        """Configure the validator.

        Args:
            jwks_url: URL the JSON Web Key Set is fetched from.
            audience: Expected ``aud`` value, forwarded to ``jwt.decode``.
            issuer: Expected ``iss`` value, forwarded to ``jwt.decode``.
            expire_seconds: Time-to-live of the cached key set, in seconds.
            roles_claim: Name of the claim holding the roles granted to the
                token.
            scheme_name: Passed through to ``HTTPBearer``.
            description: Passed through to ``HTTPBearer``.
            algorithms: Optional allow-list of signature algorithms forwarded
                to ``jwt.decode``. The default ``None`` keeps the previous
                behaviour of passing no restriction; pinning the expected
                algorithm (for example ``["RS256"]``) is recommended.
        """
        super().__init__(scheme_name=scheme_name, description=description)
        self.uri = jwks_url
        self.audience = audience
        self.issuer = issuer
        self.roles_claim = roles_claim
        self.algorithms = algorithms
        self.keyset_cache: TTLCache[str, str] = TTLCache(16, expire_seconds)

    async def get_jwt_keyset(self) -> str:
        """Return the JWKS document, fetching it when expired or not cached yet.

        Returns:
            The raw response body of ``jwks_url``.

        Raises:
            httpx.HTTPStatusError: The JWKS endpoint answered with a 4xx/5xx
                status. Nothing is cached in that case, so the next request
                tries again.
            httpx.HTTPError: Any other transport failure raised by httpx.
        """
        result: Optional[str] = self.keyset_cache.get(self.uri)
        if result is None:
            async with httpx.AsyncClient() as client:
                response = await client.get(self.uri)
                # Never cache an error page: it would be used as key material
                # and make every token fail until the cache entry expires.
                response.raise_for_status()
                result = self.keyset_cache[self.uri] = response.text
        return result

    async def __call__(
        self, request: Request, security_scopes: SecurityScopes
    ) -> AccessTokenCredentials:  # type: ignore
        """Validate the JWT access token of the request.

        If security scopes are given, they are validated against the
        ``roles_claim`` in the access token.

        Args:
            request: Incoming request carrying the ``Authorization`` header.
            security_scopes: Scopes required by the endpoint being called.

        Returns:
            The bearer credentials together with the verified token claims.

        Raises:
            HTTPException: 400 when the token is missing, fails validation, or
                lacks the roles claim while scopes are required; 403 when a
                required scope is not granted.
        """
        # 1. Unpack bearer token
        unverified_token = await super().__call__(request)
        if not unverified_token:
            raise HTTPException(HTTP_400_BAD_REQUEST, "Invalid Access Token")
        access_token = unverified_token.credentials

        # 2. Get keyset from authorization server so that we can validate the
        #    JWT Access Token. Kept outside the ``try`` below: a fetch failure
        #    is a server-side problem, not a bad token.
        keyset = await self.get_jwt_keyset()

        # 3. Perform validation
        try:
            verified_token = jwt.decode(
                token=access_token,
                key=keyset,
                algorithms=self.algorithms,
                audience=self.audience,
                issuer=self.issuer,
            )
        except JWTError:
            raise HTTPException(
                status_code=HTTP_400_BAD_REQUEST,
                detail="Unsupported authorization code",
            ) from None

        # 4. If security scopes are present, validate them
        if security_scopes and security_scopes.scopes:
            # 4.1 the roles_claim must be present in the access token
            granted = verified_token.get(self.roles_claim)
            if granted is None:
                raise HTTPException(
                    status_code=HTTP_400_BAD_REQUEST, detail="Unsupported Access Token"
                )
            # A claim given as one space-delimited string must be split into
            # words; set() on the raw string would yield single characters.
            if isinstance(granted, str):
                granted = granted.split()
            # 4.2 all required roles in the roles_claim must be present
            if not set(security_scopes.scopes).issubset(set(granted)):
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN, detail="Not Authorized"
                )

        # Report the scheme taken from the Authorization header rather than
        # the OpenAPI scheme name of this dependency.
        return AccessTokenCredentials(
            scheme=unverified_token.scheme,
            credentials=access_token,
            token=verified_token,
        )
class Settings(BaseSettings):
    """Application settings; values will be read from a ``.env`` file.

    All three fields are required (``Field(default=...)``), so constructing
    ``Settings()`` fails validation when one of them is not provided.
    """

    # pydantic-settings configuration; replaces the deprecated inner
    # ``class Config`` of pydantic v1.
    model_config = SettingsConfigDict(env_file=".env")

    # Issuer of the access tokens; also used as base for the OIDC URLs.
    issuer: str = Field(default=...)
    # Audience the access tokens must be issued for.
    audience: str = Field(default=...)
    # OAuth client id handed to the Swagger UI.
    client_id: str = Field(default=...)
# Configuration is loaded once, when the module is imported.
settings = Settings()

# OIDC endpoints, both derived from the issuer URL.
# NOTE(review): the JWKS path is hard-coded here; the discovery document's
# ``jwks_uri`` is the authoritative location - confirm it matches your provider.
oidc_url = settings.issuer + "/.well-known/openid-configuration"
jwks_url = settings.issuer + "/v1/keys"

openid_connect = OpenIdConnect(openIdConnectUrl=oidc_url)

# OAuth settings handed to the Swagger UI.
swagger_ui_init_oauth = {
    "clientId": settings.client_id,
    # fill in additional scopes when necessary
    "scopes": ["openid"],
    "appName": "Test Application",
    "usePkceWithAuthorizationCodeGrant": True,
}

# The openid_connect scheme is an app-wide dependency so that you can
# authenticate using the Swagger UI.
app = FastAPI(
    dependencies=[Depends(openid_connect)],
    swagger_ui_init_oauth=swagger_ui_init_oauth,
)

# The token validator is used for all endpoints that need to be authorized.
oauth2 = AccessTokenValidator(
    issuer=settings.issuer,
    audience=settings.audience,
    jwks_url=jwks_url,
)
@app.get("/hello")
async def hello(
    token: Annotated[AccessTokenCredentials, Security(oauth2, scopes=["Foo"])],
) -> str:
    """Return a fixed greeting; access requires a token validated with scope "Foo"."""
    # ``token`` is unused here: declaring it is what triggers the validation.
    greeting = "Hi!"
    return greeting