mirror of https://github.com/tiangolo/fastapi.git
✨ Refactor, update code, several features
This commit is contained in:
parent
b9d912c638
commit
addfa89b0f
|
|
@ -8,7 +8,8 @@ from starlette.responses import JSONResponse
|
|||
|
||||
|
||||
from fastapi import routing
|
||||
from fastapi.openapi.utils import get_swagger_ui_html, get_openapi, get_redoc_html
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
|
||||
|
||||
async def http_exception(request, exc: HTTPException):
|
||||
|
|
@ -154,9 +155,11 @@ class FastAPI(Starlette):
|
|||
response_wrapper=response_wrapper,
|
||||
)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def include_router(self, router: "APIRouter", *, prefix=""):
|
||||
self.router.include_router(router, prefix=prefix)
|
||||
|
||||
def get(
|
||||
self,
|
||||
path: str,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,78 @@
|
|||
from starlette.responses import HTMLResponse
|
||||
|
||||
def get_swagger_ui_html(*, openapi_url: str, title: str):
|
||||
return HTMLResponse(
|
||||
"""
|
||||
<! doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css">
|
||||
<title>
|
||||
"""
|
||||
+ title
|
||||
+ """
|
||||
</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="swagger-ui">
|
||||
</div>
|
||||
<script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
|
||||
<!-- `SwaggerUIBundle` is now available on the page -->
|
||||
<script>
|
||||
|
||||
const ui = SwaggerUIBundle({
|
||||
url: '"""
|
||||
+ openapi_url
|
||||
+ """',
|
||||
dom_id: '#swagger-ui',
|
||||
presets: [
|
||||
SwaggerUIBundle.presets.apis,
|
||||
SwaggerUIBundle.SwaggerUIStandalonePreset
|
||||
],
|
||||
layout: "BaseLayout"
|
||||
|
||||
})
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
""",
|
||||
media_type="text/html",
|
||||
)
|
||||
|
||||
|
||||
def get_redoc_html(*, openapi_url: str, title: str):
|
||||
return HTMLResponse(
|
||||
"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>
|
||||
"""
|
||||
+ title
|
||||
+ """
|
||||
</title>
|
||||
<!-- needed for adaptive design -->
|
||||
<meta charset="utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
|
||||
|
||||
<!--
|
||||
ReDoc doesn't change outer page styles
|
||||
-->
|
||||
<style>
|
||||
body {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<redoc spec-url='"""
|
||||
+ openapi_url
|
||||
+ """'></redoc>
|
||||
<script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script>
|
||||
</body>
|
||||
</html>
|
||||
""",
|
||||
media_type="text/html",
|
||||
)
|
||||
|
|
@ -1,4 +1,8 @@
|
|||
from typing import Any, Dict, Sequence, Type
|
||||
from typing import Any, Dict, Sequence, Type, List
|
||||
|
||||
from pydantic.fields import Field
|
||||
from pydantic.schema import field_schema, get_model_name_map
|
||||
from pydantic.utils import lenient_issubclass
|
||||
|
||||
from starlette.responses import HTMLResponse, JSONResponse
|
||||
from starlette.routing import BaseRoute
|
||||
|
|
@ -12,9 +16,7 @@ from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY
|
|||
from fastapi.openapi.models import OpenAPI
|
||||
from fastapi.params import Body
|
||||
from fastapi.utils import get_flat_models_from_routes, get_model_definitions
|
||||
from pydantic.fields import Field
|
||||
from pydantic.schema import field_schema, get_model_name_map
|
||||
from pydantic.utils import lenient_issubclass
|
||||
|
||||
|
||||
validation_error_definition = {
|
||||
"title": "ValidationError",
|
||||
|
|
@ -49,57 +51,32 @@ def get_openapi_params(dependant: Dependant):
|
|||
+ flat_dependant.cookie_params
|
||||
)
|
||||
|
||||
def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
|
||||
if not (route.include_in_schema and isinstance(route, routing.APIRoute)):
|
||||
return None
|
||||
path = {}
|
||||
security_schemes = {}
|
||||
definitions = {}
|
||||
for method in route.methods:
|
||||
operation: Dict[str, Any] = {}
|
||||
if route.tags:
|
||||
operation["tags"] = route.tags
|
||||
if route.summary:
|
||||
operation["summary"] = route.summary
|
||||
if route.description:
|
||||
operation["description"] = route.description
|
||||
if route.operation_id:
|
||||
operation["operationId"] = route.operation_id
|
||||
else:
|
||||
operation["operationId"] = route.name
|
||||
if route.deprecated:
|
||||
operation["deprecated"] = route.deprecated
|
||||
parameters = []
|
||||
flat_dependant = get_flat_dependant(route.dependant)
|
||||
|
||||
def get_openapi_security_definitions(flat_dependant: Dependant):
|
||||
security_definitions = {}
|
||||
operation_security = []
|
||||
for security_requirement in flat_dependant.security_requirements:
|
||||
security_definition = jsonable_encoder(
|
||||
security_requirement.security_scheme,
|
||||
exclude={"scheme_name"},
|
||||
security_requirement.security_scheme.model,
|
||||
by_alias=True,
|
||||
include_none=False,
|
||||
)
|
||||
security_name = (
|
||||
getattr(
|
||||
security_requirement.security_scheme, "scheme_name", None
|
||||
)
|
||||
or security_requirement.security_scheme.__class__.__name__
|
||||
security_requirement.security_scheme.scheme_name
|
||||
|
||||
)
|
||||
security_definitions[security_name] = security_definition
|
||||
operation.setdefault("security", []).append(
|
||||
{security_name: security_requirement.scopes}
|
||||
)
|
||||
if security_definitions:
|
||||
security_schemes.update(
|
||||
security_definitions
|
||||
)
|
||||
all_route_params = get_openapi_params(route.dependant)
|
||||
operation_security.append({security_name: security_requirement.scopes})
|
||||
return security_definitions, operation_security
|
||||
|
||||
|
||||
def get_openapi_operation_parameters(all_route_params: List[Field]):
|
||||
definitions: Dict[str, Dict] = {}
|
||||
parameters = []
|
||||
for param in all_route_params:
|
||||
if "ValidationError" not in definitions:
|
||||
definitions["ValidationError"] = validation_error_definition
|
||||
definitions[
|
||||
"HTTPValidationError"
|
||||
] = validation_error_response_definition
|
||||
definitions["HTTPValidationError"] = validation_error_response_definition
|
||||
parameter = {
|
||||
"name": param.alias,
|
||||
"in": param.schema.in_.value,
|
||||
|
|
@ -111,16 +88,17 @@ def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
|
|||
if param.schema.deprecated:
|
||||
parameter["deprecated"] = param.schema.deprecated
|
||||
parameters.append(parameter)
|
||||
if parameters:
|
||||
operation["parameters"] = parameters
|
||||
if method in METHODS_WITH_BODY:
|
||||
body_field = route.body_field
|
||||
if body_field:
|
||||
return definitions, parameters
|
||||
|
||||
|
||||
def get_openapi_operation_request_body(
|
||||
*, body_field: Field, model_name_map: Dict[Type, str]
|
||||
):
|
||||
if not body_field:
|
||||
return None
|
||||
assert isinstance(body_field, Field)
|
||||
body_schema, _ = field_schema(
|
||||
body_field,
|
||||
model_name_map=model_name_map,
|
||||
ref_prefix=REF_PREFIX,
|
||||
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
||||
)
|
||||
if isinstance(body_field.schema, Body):
|
||||
request_media_type = body_field.schema.media_type
|
||||
|
|
@ -131,9 +109,68 @@ def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
|
|||
request_body_oai = {}
|
||||
if required:
|
||||
request_body_oai["required"] = required
|
||||
request_body_oai["content"] = {
|
||||
request_media_type: {"schema": body_schema}
|
||||
}
|
||||
request_body_oai["content"] = {request_media_type: {"schema": body_schema}}
|
||||
return request_body_oai
|
||||
|
||||
|
||||
def generate_operation_id(*, route: routing.APIRoute, method: str):
|
||||
if route.operation_id:
|
||||
return route.operation_id
|
||||
path: str = route.path
|
||||
operation_id = route.name + path
|
||||
operation_id = operation_id.replace("{", "_").replace("}", "_").replace("/", "_")
|
||||
operation_id = operation_id + "_" + method.lower()
|
||||
return operation_id
|
||||
|
||||
|
||||
def generate_operation_summary(*, route: routing.APIRoute, method: str):
|
||||
if route.summary:
|
||||
return route.summary
|
||||
return method.title() + " " + route.name.replace("_", " ").title()
|
||||
|
||||
def get_openapi_operation_metadata(*, route: BaseRoute, method: str):
|
||||
operation: Dict[str, Any] = {}
|
||||
if route.tags:
|
||||
operation["tags"] = route.tags
|
||||
operation["summary"] = generate_operation_summary(route=route, method=method)
|
||||
if route.description:
|
||||
operation["description"] = route.description
|
||||
operation["operationId"] = generate_operation_id(route=route, method=method)
|
||||
if route.deprecated:
|
||||
operation["deprecated"] = route.deprecated
|
||||
return operation
|
||||
|
||||
|
||||
def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
|
||||
if not (route.include_in_schema and isinstance(route, routing.APIRoute)):
|
||||
return None
|
||||
path = {}
|
||||
security_schemes: Dict[str, Any] = {}
|
||||
definitions: Dict[str, Any] = {}
|
||||
for method in route.methods:
|
||||
operation = get_openapi_operation_metadata(route=route, method=method)
|
||||
parameters: List[Dict] = []
|
||||
flat_dependant = get_flat_dependant(route.dependant)
|
||||
security_definitions, operation_security = get_openapi_security_definitions(
|
||||
flat_dependant=flat_dependant
|
||||
)
|
||||
if operation_security:
|
||||
operation.setdefault("security", []).extend(operation_security)
|
||||
if security_definitions:
|
||||
security_schemes.update(security_definitions)
|
||||
all_route_params = get_openapi_params(route.dependant)
|
||||
validation_definitions, operation_parameters = get_openapi_operation_parameters(
|
||||
all_route_params=all_route_params
|
||||
)
|
||||
definitions.update(validation_definitions)
|
||||
parameters.extend(operation_parameters)
|
||||
if parameters:
|
||||
operation["parameters"] = parameters
|
||||
if method in METHODS_WITH_BODY:
|
||||
request_body_oai = get_openapi_operation_request_body(
|
||||
body_field=route.body_field, model_name_map=model_name_map
|
||||
)
|
||||
if request_body_oai:
|
||||
operation["requestBody"] = request_body_oai
|
||||
response_code = str(route.response_code)
|
||||
response_schema = {"type": "string"}
|
||||
|
|
@ -206,75 +243,3 @@ def get_openapi(
|
|||
output["components"] = components
|
||||
output["paths"] = paths
|
||||
return jsonable_encoder(OpenAPI(**output), by_alias=True, include_none=False)
|
||||
|
||||
|
||||
def get_swagger_ui_html(*, openapi_url: str, title: str):
|
||||
return HTMLResponse(
|
||||
"""
|
||||
<! doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css">
|
||||
<title>
|
||||
""" + title + """
|
||||
</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="swagger-ui">
|
||||
</div>
|
||||
<script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
|
||||
<!-- `SwaggerUIBundle` is now available on the page -->
|
||||
<script>
|
||||
|
||||
const ui = SwaggerUIBundle({
|
||||
url: '"""
|
||||
+ openapi_url
|
||||
+ """',
|
||||
dom_id: '#swagger-ui',
|
||||
presets: [
|
||||
SwaggerUIBundle.presets.apis,
|
||||
SwaggerUIBundle.SwaggerUIStandalonePreset
|
||||
],
|
||||
layout: "BaseLayout"
|
||||
|
||||
})
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
""",
|
||||
media_type="text/html",
|
||||
)
|
||||
|
||||
|
||||
def get_redoc_html(*, openapi_url: str, title: str):
|
||||
return HTMLResponse(
|
||||
"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>
|
||||
""" + title + """
|
||||
</title>
|
||||
<!-- needed for adaptive design -->
|
||||
<meta charset="utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
|
||||
|
||||
<!--
|
||||
ReDoc doesn't change outer page styles
|
||||
-->
|
||||
<style>
|
||||
body {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<redoc spec-url='""" + openapi_url + """'></redoc>
|
||||
<script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script>
|
||||
</body>
|
||||
</html>
|
||||
""",
|
||||
media_type="text/html",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,11 @@ import asyncio
|
|||
import inspect
|
||||
from typing import Callable, List, Type
|
||||
|
||||
from pydantic import BaseConfig, BaseModel, Schema
|
||||
from pydantic.error_wrappers import ErrorWrapper, ValidationError
|
||||
from pydantic.fields import Field
|
||||
from pydantic.utils import lenient_issubclass
|
||||
|
||||
from starlette import routing
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.exceptions import HTTPException
|
||||
|
|
@ -15,10 +20,6 @@ from fastapi import params
|
|||
from fastapi.dependencies.models import Dependant
|
||||
from fastapi.dependencies.utils import get_body_field, get_dependant, solve_dependencies
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseConfig, BaseModel, Schema
|
||||
from pydantic.error_wrappers import ErrorWrapper, ValidationError
|
||||
from pydantic.fields import Field
|
||||
from pydantic.utils import lenient_issubclass
|
||||
|
||||
|
||||
def serialize_response(*, field: Field = None, response):
|
||||
|
|
@ -44,11 +45,12 @@ def get_app(
|
|||
response_field: Type[Field] = None,
|
||||
):
|
||||
is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call)
|
||||
is_body_form = body_field and isinstance(body_field.schema, params.Form)
|
||||
|
||||
async def app(request: Request) -> Response:
|
||||
body = None
|
||||
if body_field:
|
||||
if isinstance(body_field.schema, params.Form):
|
||||
if is_body_form:
|
||||
raw_body = await request.form()
|
||||
body = {}
|
||||
for field, value in raw_body.items():
|
||||
|
|
@ -127,12 +129,7 @@ class APIRoute(routing.Route):
|
|||
response_code=200,
|
||||
response_wrapper=JSONResponse,
|
||||
) -> None:
|
||||
# TODO define how to read and provide security params, and how to have them globally too
|
||||
# TODO implement dependencies and injection
|
||||
# TODO refactor code structure
|
||||
# TODO create testing
|
||||
# TODO testing coverage
|
||||
assert path.startswith("/"), "Routed paths must always start '/'"
|
||||
assert path.startswith("/"), "Routed paths must always start with '/'"
|
||||
self.path = path
|
||||
self.endpoint = endpoint
|
||||
self.name = get_name(endpoint) if name is None else name
|
||||
|
|
@ -260,6 +257,39 @@ class APIRouter(routing.Router):
|
|||
|
||||
return decorator
|
||||
|
||||
def include_router(self, router: "APIRouter", *, prefix=""):
|
||||
if prefix:
|
||||
assert prefix.startswith("/"), "A path prefix must start with '/'"
|
||||
assert not prefix.endswith(
|
||||
"/"
|
||||
), "A path prefix must not end with '/', as the routes will start with '/'"
|
||||
for route in router.routes:
|
||||
if isinstance(route, APIRoute):
|
||||
self.add_api_route(
|
||||
prefix + route.path,
|
||||
route.endpoint,
|
||||
methods=route.methods,
|
||||
name=route.name,
|
||||
include_in_schema=route.include_in_schema,
|
||||
tags=route.tags,
|
||||
summary=route.summary,
|
||||
description=route.description,
|
||||
operation_id=route.operation_id,
|
||||
deprecated=route.deprecated,
|
||||
response_type=route.response_type,
|
||||
response_description=route.response_description,
|
||||
response_code=route.response_code,
|
||||
response_wrapper=route.response_wrapper,
|
||||
)
|
||||
elif isinstance(route, routing.Route):
|
||||
self.add_route(
|
||||
prefix + route.path,
|
||||
route.endpoint,
|
||||
methods=route.methods,
|
||||
name=route.name,
|
||||
include_in_schema=route.include_in_schema,
|
||||
)
|
||||
|
||||
def get(
|
||||
self,
|
||||
path: str,
|
||||
|
|
|
|||
|
|
@ -1,39 +1,34 @@
|
|||
from enum import Enum
|
||||
|
||||
from pydantic import Schema
|
||||
|
||||
from starlette.requests import Request
|
||||
|
||||
from .base import SecurityBase, Types
|
||||
|
||||
class APIKeyIn(Enum):
|
||||
query = "query"
|
||||
header = "header"
|
||||
cookie = "cookie"
|
||||
|
||||
from .base import SecurityBase
|
||||
from fastapi.openapi.models import APIKeyIn, APIKey
|
||||
|
||||
class APIKeyBase(SecurityBase):
|
||||
type_ = Schema(Types.apiKey, alias="type")
|
||||
in_: str = Schema(..., alias="in")
|
||||
name: str
|
||||
|
||||
pass
|
||||
|
||||
class APIKeyQuery(APIKeyBase):
|
||||
in_ = Schema(APIKeyIn.query, alias="in")
|
||||
|
||||
def __init__(self, *, name: str, scheme_name: str = None):
|
||||
self.model = APIKey(in_=APIKeyIn.query, name=name)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
|
||||
async def __call__(self, requests: Request):
|
||||
return requests.query_params.get(self.name)
|
||||
return requests.query_params.get(self.model.name)
|
||||
|
||||
|
||||
class APIKeyHeader(APIKeyBase):
|
||||
in_ = Schema(APIKeyIn.header, alias="in")
|
||||
def __init__(self, *, name: str, scheme_name: str = None):
|
||||
self.model = APIKey(in_=APIKeyIn.header, name=name)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
|
||||
async def __call__(self, requests: Request):
|
||||
return requests.headers.get(self.name)
|
||||
return requests.headers.get(self.model.name)
|
||||
|
||||
|
||||
class APIKeyCookie(APIKeyBase):
|
||||
in_ = Schema(APIKeyIn.cookie, alias="in")
|
||||
def __init__(self, *, name: str, scheme_name: str = None):
|
||||
self.model = APIKey(in_=APIKeyIn.cookie, name=name)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
|
||||
async def __call__(self, requests: Request):
|
||||
return requests.cookies.get(self.name)
|
||||
return requests.cookies.get(self.model.name)
|
||||
|
|
|
|||
|
|
@ -1,16 +1,6 @@
|
|||
from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pydantic import BaseModel, Schema
|
||||
from fastapi.openapi.models import SecurityBase as SecurityBaseModel
|
||||
|
||||
|
||||
class Types(Enum):
|
||||
apiKey = "apiKey"
|
||||
http = "http"
|
||||
oauth2 = "oauth2"
|
||||
openIdConnect = "openIdConnect"
|
||||
|
||||
|
||||
class SecurityBase(BaseModel):
|
||||
scheme_name: str = None
|
||||
type_: Types = Schema(..., alias="type")
|
||||
description: str = None
|
||||
class SecurityBase:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,26 +1,40 @@
|
|||
from pydantic import Schema
|
||||
|
||||
from starlette.requests import Request
|
||||
|
||||
from .base import SecurityBase, Types
|
||||
from .base import SecurityBase
|
||||
from fastapi.openapi.models import HTTPBase as HTTPBaseModel, HTTPBearer as HTTPBearerModel
|
||||
|
||||
|
||||
class HTTPBase(SecurityBase):
|
||||
type_ = Schema(Types.http, alias="type")
|
||||
scheme: str
|
||||
def __init__(self, *, scheme: str, scheme_name: str = None):
|
||||
self.model = HTTPBaseModel(scheme=scheme)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
|
||||
async def __call__(self, request: Request):
|
||||
return request.headers.get("Authorization")
|
||||
|
||||
|
||||
class HTTPBasic(HTTPBase):
|
||||
scheme = "basic"
|
||||
def __init__(self, *, scheme_name: str = None):
|
||||
self.model = HTTPBaseModel(scheme="basic")
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
|
||||
async def __call__(self, request: Request):
|
||||
return request.headers.get("Authorization")
|
||||
|
||||
|
||||
class HTTPBearer(HTTPBase):
|
||||
scheme = "bearer"
|
||||
bearerFormat: str = None
|
||||
def __init__(self, *, bearerFormat: str = None, scheme_name: str = None):
|
||||
self.model = HTTPBearerModel(bearerFormat=bearerFormat)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
|
||||
async def __call__(self, request: Request):
|
||||
return request.headers.get("Authorization")
|
||||
|
||||
|
||||
class HTTPDigest(HTTPBase):
|
||||
scheme = "digest"
|
||||
def __init__(self, *, scheme_name: str = None):
|
||||
self.model = HTTPBaseModel(scheme="digest")
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
|
||||
async def __call__(self, request: Request):
|
||||
return request.headers.get("Authorization")
|
||||
|
|
|
|||
|
|
@ -1,43 +1,13 @@
|
|||
from typing import Dict
|
||||
|
||||
from pydantic import BaseModel, Schema
|
||||
|
||||
from starlette.requests import Request
|
||||
|
||||
from .base import SecurityBase, Types
|
||||
|
||||
class OAuthFlow(BaseModel):
|
||||
refreshUrl: str = None
|
||||
scopes: Dict[str, str] = {}
|
||||
|
||||
|
||||
class OAuthFlowImplicit(OAuthFlow):
|
||||
authorizationUrl: str
|
||||
|
||||
|
||||
class OAuthFlowPassword(OAuthFlow):
|
||||
tokenUrl: str
|
||||
|
||||
|
||||
class OAuthFlowClientCredentials(OAuthFlow):
|
||||
tokenUrl: str
|
||||
|
||||
|
||||
class OAuthFlowAuthorizationCode(OAuthFlow):
|
||||
authorizationUrl: str
|
||||
tokenUrl: str
|
||||
|
||||
|
||||
class OAuthFlows(BaseModel):
|
||||
implicit: OAuthFlowImplicit = None
|
||||
password: OAuthFlowPassword = None
|
||||
clientCredentials: OAuthFlowClientCredentials = None
|
||||
authorizationCode: OAuthFlowAuthorizationCode = None
|
||||
from .base import SecurityBase
|
||||
from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel
|
||||
|
||||
|
||||
class OAuth2(SecurityBase):
|
||||
type_ = Schema(Types.oauth2, alias="type")
|
||||
flows: OAuthFlows
|
||||
def __init__(self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None):
|
||||
self.model = OAuth2Model(flows=flows)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
|
||||
async def __call__(self, request: Request):
|
||||
return request.headers.get("Authorization")
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
from starlette.requests import Request
|
||||
|
||||
from .base import SecurityBase, Types
|
||||
from .base import SecurityBase
|
||||
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
|
||||
|
||||
|
||||
class OpenIdConnect(SecurityBase):
|
||||
type_ = Types.openIdConnect
|
||||
openIdConnectUrl: str
|
||||
def __init__(self, *, openIdConnectUrl: str, scheme_name: str = None):
|
||||
self.model = OpenIdConnectModel(openIdConnectUrl=openIdConnectUrl)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
|
||||
async def __call__(self, request: Request):
|
||||
return request.headers.get("Authorization")
|
||||
|
|
|
|||
Loading…
Reference in New Issue