mirror of https://github.com/tiangolo/fastapi.git
✨ Refactor, update code, several features
This commit is contained in:
parent
b9d912c638
commit
addfa89b0f
|
|
@ -8,7 +8,8 @@ from starlette.responses import JSONResponse
|
||||||
|
|
||||||
|
|
||||||
from fastapi import routing
|
from fastapi import routing
|
||||||
from fastapi.openapi.utils import get_swagger_ui_html, get_openapi, get_redoc_html
|
from fastapi.openapi.utils import get_openapi
|
||||||
|
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||||
|
|
||||||
|
|
||||||
async def http_exception(request, exc: HTTPException):
|
async def http_exception(request, exc: HTTPException):
|
||||||
|
|
@ -154,9 +155,11 @@ class FastAPI(Starlette):
|
||||||
response_wrapper=response_wrapper,
|
response_wrapper=response_wrapper,
|
||||||
)
|
)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
def include_router(self, router: "APIRouter", *, prefix=""):
|
||||||
|
self.router.include_router(router, prefix=prefix)
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,78 @@
|
||||||
|
from starlette.responses import HTMLResponse
|
||||||
|
|
||||||
|
def get_swagger_ui_html(*, openapi_url: str, title: str):
|
||||||
|
return HTMLResponse(
|
||||||
|
"""
|
||||||
|
<! doctype html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css">
|
||||||
|
<title>
|
||||||
|
"""
|
||||||
|
+ title
|
||||||
|
+ """
|
||||||
|
</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div id="swagger-ui">
|
||||||
|
</div>
|
||||||
|
<script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
|
||||||
|
<!-- `SwaggerUIBundle` is now available on the page -->
|
||||||
|
<script>
|
||||||
|
|
||||||
|
const ui = SwaggerUIBundle({
|
||||||
|
url: '"""
|
||||||
|
+ openapi_url
|
||||||
|
+ """',
|
||||||
|
dom_id: '#swagger-ui',
|
||||||
|
presets: [
|
||||||
|
SwaggerUIBundle.presets.apis,
|
||||||
|
SwaggerUIBundle.SwaggerUIStandalonePreset
|
||||||
|
],
|
||||||
|
layout: "BaseLayout"
|
||||||
|
|
||||||
|
})
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
""",
|
||||||
|
media_type="text/html",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_redoc_html(*, openapi_url: str, title: str):
|
||||||
|
return HTMLResponse(
|
||||||
|
"""
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>
|
||||||
|
"""
|
||||||
|
+ title
|
||||||
|
+ """
|
||||||
|
</title>
|
||||||
|
<!-- needed for adaptive design -->
|
||||||
|
<meta charset="utf-8"/>
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||||
|
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
|
||||||
|
|
||||||
|
<!--
|
||||||
|
ReDoc doesn't change outer page styles
|
||||||
|
-->
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<redoc spec-url='"""
|
||||||
|
+ openapi_url
|
||||||
|
+ """'></redoc>
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
""",
|
||||||
|
media_type="text/html",
|
||||||
|
)
|
||||||
|
|
@ -1,4 +1,8 @@
|
||||||
from typing import Any, Dict, Sequence, Type
|
from typing import Any, Dict, Sequence, Type, List
|
||||||
|
|
||||||
|
from pydantic.fields import Field
|
||||||
|
from pydantic.schema import field_schema, get_model_name_map
|
||||||
|
from pydantic.utils import lenient_issubclass
|
||||||
|
|
||||||
from starlette.responses import HTMLResponse, JSONResponse
|
from starlette.responses import HTMLResponse, JSONResponse
|
||||||
from starlette.routing import BaseRoute
|
from starlette.routing import BaseRoute
|
||||||
|
|
@ -12,9 +16,7 @@ from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY
|
||||||
from fastapi.openapi.models import OpenAPI
|
from fastapi.openapi.models import OpenAPI
|
||||||
from fastapi.params import Body
|
from fastapi.params import Body
|
||||||
from fastapi.utils import get_flat_models_from_routes, get_model_definitions
|
from fastapi.utils import get_flat_models_from_routes, get_model_definitions
|
||||||
from pydantic.fields import Field
|
|
||||||
from pydantic.schema import field_schema, get_model_name_map
|
|
||||||
from pydantic.utils import lenient_issubclass
|
|
||||||
|
|
||||||
validation_error_definition = {
|
validation_error_definition = {
|
||||||
"title": "ValidationError",
|
"title": "ValidationError",
|
||||||
|
|
@ -49,91 +51,126 @@ def get_openapi_params(dependant: Dependant):
|
||||||
+ flat_dependant.cookie_params
|
+ flat_dependant.cookie_params
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_openapi_security_definitions(flat_dependant: Dependant):
|
||||||
|
security_definitions = {}
|
||||||
|
operation_security = []
|
||||||
|
for security_requirement in flat_dependant.security_requirements:
|
||||||
|
security_definition = jsonable_encoder(
|
||||||
|
security_requirement.security_scheme.model,
|
||||||
|
by_alias=True,
|
||||||
|
include_none=False,
|
||||||
|
)
|
||||||
|
security_name = (
|
||||||
|
security_requirement.security_scheme.scheme_name
|
||||||
|
|
||||||
|
)
|
||||||
|
security_definitions[security_name] = security_definition
|
||||||
|
operation_security.append({security_name: security_requirement.scopes})
|
||||||
|
return security_definitions, operation_security
|
||||||
|
|
||||||
|
|
||||||
|
def get_openapi_operation_parameters(all_route_params: List[Field]):
|
||||||
|
definitions: Dict[str, Dict] = {}
|
||||||
|
parameters = []
|
||||||
|
for param in all_route_params:
|
||||||
|
if "ValidationError" not in definitions:
|
||||||
|
definitions["ValidationError"] = validation_error_definition
|
||||||
|
definitions["HTTPValidationError"] = validation_error_response_definition
|
||||||
|
parameter = {
|
||||||
|
"name": param.alias,
|
||||||
|
"in": param.schema.in_.value,
|
||||||
|
"required": param.required,
|
||||||
|
"schema": field_schema(param, model_name_map={})[0],
|
||||||
|
}
|
||||||
|
if param.schema.description:
|
||||||
|
parameter["description"] = param.schema.description
|
||||||
|
if param.schema.deprecated:
|
||||||
|
parameter["deprecated"] = param.schema.deprecated
|
||||||
|
parameters.append(parameter)
|
||||||
|
return definitions, parameters
|
||||||
|
|
||||||
|
|
||||||
|
def get_openapi_operation_request_body(
|
||||||
|
*, body_field: Field, model_name_map: Dict[Type, str]
|
||||||
|
):
|
||||||
|
if not body_field:
|
||||||
|
return None
|
||||||
|
assert isinstance(body_field, Field)
|
||||||
|
body_schema, _ = field_schema(
|
||||||
|
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
||||||
|
)
|
||||||
|
if isinstance(body_field.schema, Body):
|
||||||
|
request_media_type = body_field.schema.media_type
|
||||||
|
else:
|
||||||
|
# Includes not declared media types (Schema)
|
||||||
|
request_media_type = "application/json"
|
||||||
|
required = body_field.required
|
||||||
|
request_body_oai = {}
|
||||||
|
if required:
|
||||||
|
request_body_oai["required"] = required
|
||||||
|
request_body_oai["content"] = {request_media_type: {"schema": body_schema}}
|
||||||
|
return request_body_oai
|
||||||
|
|
||||||
|
|
||||||
|
def generate_operation_id(*, route: routing.APIRoute, method: str):
|
||||||
|
if route.operation_id:
|
||||||
|
return route.operation_id
|
||||||
|
path: str = route.path
|
||||||
|
operation_id = route.name + path
|
||||||
|
operation_id = operation_id.replace("{", "_").replace("}", "_").replace("/", "_")
|
||||||
|
operation_id = operation_id + "_" + method.lower()
|
||||||
|
return operation_id
|
||||||
|
|
||||||
|
|
||||||
|
def generate_operation_summary(*, route: routing.APIRoute, method: str):
|
||||||
|
if route.summary:
|
||||||
|
return route.summary
|
||||||
|
return method.title() + " " + route.name.replace("_", " ").title()
|
||||||
|
|
||||||
|
def get_openapi_operation_metadata(*, route: BaseRoute, method: str):
|
||||||
|
operation: Dict[str, Any] = {}
|
||||||
|
if route.tags:
|
||||||
|
operation["tags"] = route.tags
|
||||||
|
operation["summary"] = generate_operation_summary(route=route, method=method)
|
||||||
|
if route.description:
|
||||||
|
operation["description"] = route.description
|
||||||
|
operation["operationId"] = generate_operation_id(route=route, method=method)
|
||||||
|
if route.deprecated:
|
||||||
|
operation["deprecated"] = route.deprecated
|
||||||
|
return operation
|
||||||
|
|
||||||
|
|
||||||
def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
|
def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
|
||||||
if not (route.include_in_schema and isinstance(route, routing.APIRoute)):
|
if not (route.include_in_schema and isinstance(route, routing.APIRoute)):
|
||||||
return None
|
return None
|
||||||
path = {}
|
path = {}
|
||||||
security_schemes = {}
|
security_schemes: Dict[str, Any] = {}
|
||||||
definitions = {}
|
definitions: Dict[str, Any] = {}
|
||||||
for method in route.methods:
|
for method in route.methods:
|
||||||
operation: Dict[str, Any] = {}
|
operation = get_openapi_operation_metadata(route=route, method=method)
|
||||||
if route.tags:
|
parameters: List[Dict] = []
|
||||||
operation["tags"] = route.tags
|
|
||||||
if route.summary:
|
|
||||||
operation["summary"] = route.summary
|
|
||||||
if route.description:
|
|
||||||
operation["description"] = route.description
|
|
||||||
if route.operation_id:
|
|
||||||
operation["operationId"] = route.operation_id
|
|
||||||
else:
|
|
||||||
operation["operationId"] = route.name
|
|
||||||
if route.deprecated:
|
|
||||||
operation["deprecated"] = route.deprecated
|
|
||||||
parameters = []
|
|
||||||
flat_dependant = get_flat_dependant(route.dependant)
|
flat_dependant = get_flat_dependant(route.dependant)
|
||||||
security_definitions = {}
|
security_definitions, operation_security = get_openapi_security_definitions(
|
||||||
for security_requirement in flat_dependant.security_requirements:
|
flat_dependant=flat_dependant
|
||||||
security_definition = jsonable_encoder(
|
)
|
||||||
security_requirement.security_scheme,
|
if operation_security:
|
||||||
exclude={"scheme_name"},
|
operation.setdefault("security", []).extend(operation_security)
|
||||||
by_alias=True,
|
|
||||||
include_none=False,
|
|
||||||
)
|
|
||||||
security_name = (
|
|
||||||
getattr(
|
|
||||||
security_requirement.security_scheme, "scheme_name", None
|
|
||||||
)
|
|
||||||
or security_requirement.security_scheme.__class__.__name__
|
|
||||||
)
|
|
||||||
security_definitions[security_name] = security_definition
|
|
||||||
operation.setdefault("security", []).append(
|
|
||||||
{security_name: security_requirement.scopes}
|
|
||||||
)
|
|
||||||
if security_definitions:
|
if security_definitions:
|
||||||
security_schemes.update(
|
security_schemes.update(security_definitions)
|
||||||
security_definitions
|
|
||||||
)
|
|
||||||
all_route_params = get_openapi_params(route.dependant)
|
all_route_params = get_openapi_params(route.dependant)
|
||||||
for param in all_route_params:
|
validation_definitions, operation_parameters = get_openapi_operation_parameters(
|
||||||
if "ValidationError" not in definitions:
|
all_route_params=all_route_params
|
||||||
definitions["ValidationError"] = validation_error_definition
|
)
|
||||||
definitions[
|
definitions.update(validation_definitions)
|
||||||
"HTTPValidationError"
|
parameters.extend(operation_parameters)
|
||||||
] = validation_error_response_definition
|
|
||||||
parameter = {
|
|
||||||
"name": param.alias,
|
|
||||||
"in": param.schema.in_.value,
|
|
||||||
"required": param.required,
|
|
||||||
"schema": field_schema(param, model_name_map={})[0],
|
|
||||||
}
|
|
||||||
if param.schema.description:
|
|
||||||
parameter["description"] = param.schema.description
|
|
||||||
if param.schema.deprecated:
|
|
||||||
parameter["deprecated"] = param.schema.deprecated
|
|
||||||
parameters.append(parameter)
|
|
||||||
if parameters:
|
if parameters:
|
||||||
operation["parameters"] = parameters
|
operation["parameters"] = parameters
|
||||||
if method in METHODS_WITH_BODY:
|
if method in METHODS_WITH_BODY:
|
||||||
body_field = route.body_field
|
request_body_oai = get_openapi_operation_request_body(
|
||||||
if body_field:
|
body_field=route.body_field, model_name_map=model_name_map
|
||||||
assert isinstance(body_field, Field)
|
)
|
||||||
body_schema, _ = field_schema(
|
if request_body_oai:
|
||||||
body_field,
|
|
||||||
model_name_map=model_name_map,
|
|
||||||
ref_prefix=REF_PREFIX,
|
|
||||||
)
|
|
||||||
if isinstance(body_field.schema, Body):
|
|
||||||
request_media_type = body_field.schema.media_type
|
|
||||||
else:
|
|
||||||
# Includes not declared media types (Schema)
|
|
||||||
request_media_type = "application/json"
|
|
||||||
required = body_field.required
|
|
||||||
request_body_oai = {}
|
|
||||||
if required:
|
|
||||||
request_body_oai["required"] = required
|
|
||||||
request_body_oai["content"] = {
|
|
||||||
request_media_type: {"schema": body_schema}
|
|
||||||
}
|
|
||||||
operation["requestBody"] = request_body_oai
|
operation["requestBody"] = request_body_oai
|
||||||
response_code = str(route.response_code)
|
response_code = str(route.response_code)
|
||||||
response_schema = {"type": "string"}
|
response_schema = {"type": "string"}
|
||||||
|
|
@ -206,75 +243,3 @@ def get_openapi(
|
||||||
output["components"] = components
|
output["components"] = components
|
||||||
output["paths"] = paths
|
output["paths"] = paths
|
||||||
return jsonable_encoder(OpenAPI(**output), by_alias=True, include_none=False)
|
return jsonable_encoder(OpenAPI(**output), by_alias=True, include_none=False)
|
||||||
|
|
||||||
|
|
||||||
def get_swagger_ui_html(*, openapi_url: str, title: str):
|
|
||||||
return HTMLResponse(
|
|
||||||
"""
|
|
||||||
<! doctype html>
|
|
||||||
<html>
|
|
||||||
<head>
|
|
||||||
<link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css">
|
|
||||||
<title>
|
|
||||||
""" + title + """
|
|
||||||
</title>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<div id="swagger-ui">
|
|
||||||
</div>
|
|
||||||
<script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
|
|
||||||
<!-- `SwaggerUIBundle` is now available on the page -->
|
|
||||||
<script>
|
|
||||||
|
|
||||||
const ui = SwaggerUIBundle({
|
|
||||||
url: '"""
|
|
||||||
+ openapi_url
|
|
||||||
+ """',
|
|
||||||
dom_id: '#swagger-ui',
|
|
||||||
presets: [
|
|
||||||
SwaggerUIBundle.presets.apis,
|
|
||||||
SwaggerUIBundle.SwaggerUIStandalonePreset
|
|
||||||
],
|
|
||||||
layout: "BaseLayout"
|
|
||||||
|
|
||||||
})
|
|
||||||
</script>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
""",
|
|
||||||
media_type="text/html",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_redoc_html(*, openapi_url: str, title: str):
|
|
||||||
return HTMLResponse(
|
|
||||||
"""
|
|
||||||
<!DOCTYPE html>
|
|
||||||
<html>
|
|
||||||
<head>
|
|
||||||
<title>
|
|
||||||
""" + title + """
|
|
||||||
</title>
|
|
||||||
<!-- needed for adaptive design -->
|
|
||||||
<meta charset="utf-8"/>
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
|
||||||
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
|
|
||||||
|
|
||||||
<!--
|
|
||||||
ReDoc doesn't change outer page styles
|
|
||||||
-->
|
|
||||||
<style>
|
|
||||||
body {
|
|
||||||
margin: 0;
|
|
||||||
padding: 0;
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<redoc spec-url='""" + openapi_url + """'></redoc>
|
|
||||||
<script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
""",
|
|
||||||
media_type="text/html",
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,11 @@ import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Callable, List, Type
|
from typing import Callable, List, Type
|
||||||
|
|
||||||
|
from pydantic import BaseConfig, BaseModel, Schema
|
||||||
|
from pydantic.error_wrappers import ErrorWrapper, ValidationError
|
||||||
|
from pydantic.fields import Field
|
||||||
|
from pydantic.utils import lenient_issubclass
|
||||||
|
|
||||||
from starlette import routing
|
from starlette import routing
|
||||||
from starlette.concurrency import run_in_threadpool
|
from starlette.concurrency import run_in_threadpool
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
|
|
@ -15,10 +20,6 @@ from fastapi import params
|
||||||
from fastapi.dependencies.models import Dependant
|
from fastapi.dependencies.models import Dependant
|
||||||
from fastapi.dependencies.utils import get_body_field, get_dependant, solve_dependencies
|
from fastapi.dependencies.utils import get_body_field, get_dependant, solve_dependencies
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from pydantic import BaseConfig, BaseModel, Schema
|
|
||||||
from pydantic.error_wrappers import ErrorWrapper, ValidationError
|
|
||||||
from pydantic.fields import Field
|
|
||||||
from pydantic.utils import lenient_issubclass
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_response(*, field: Field = None, response):
|
def serialize_response(*, field: Field = None, response):
|
||||||
|
|
@ -44,11 +45,12 @@ def get_app(
|
||||||
response_field: Type[Field] = None,
|
response_field: Type[Field] = None,
|
||||||
):
|
):
|
||||||
is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call)
|
is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call)
|
||||||
|
is_body_form = body_field and isinstance(body_field.schema, params.Form)
|
||||||
|
|
||||||
async def app(request: Request) -> Response:
|
async def app(request: Request) -> Response:
|
||||||
body = None
|
body = None
|
||||||
if body_field:
|
if body_field:
|
||||||
if isinstance(body_field.schema, params.Form):
|
if is_body_form:
|
||||||
raw_body = await request.form()
|
raw_body = await request.form()
|
||||||
body = {}
|
body = {}
|
||||||
for field, value in raw_body.items():
|
for field, value in raw_body.items():
|
||||||
|
|
@ -127,12 +129,7 @@ class APIRoute(routing.Route):
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
) -> None:
|
) -> None:
|
||||||
# TODO define how to read and provide security params, and how to have them globally too
|
assert path.startswith("/"), "Routed paths must always start with '/'"
|
||||||
# TODO implement dependencies and injection
|
|
||||||
# TODO refactor code structure
|
|
||||||
# TODO create testing
|
|
||||||
# TODO testing coverage
|
|
||||||
assert path.startswith("/"), "Routed paths must always start '/'"
|
|
||||||
self.path = path
|
self.path = path
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.name = get_name(endpoint) if name is None else name
|
self.name = get_name(endpoint) if name is None else name
|
||||||
|
|
@ -260,6 +257,39 @@ class APIRouter(routing.Router):
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
def include_router(self, router: "APIRouter", *, prefix=""):
|
||||||
|
if prefix:
|
||||||
|
assert prefix.startswith("/"), "A path prefix must start with '/'"
|
||||||
|
assert not prefix.endswith(
|
||||||
|
"/"
|
||||||
|
), "A path prefix must not end with '/', as the routes will start with '/'"
|
||||||
|
for route in router.routes:
|
||||||
|
if isinstance(route, APIRoute):
|
||||||
|
self.add_api_route(
|
||||||
|
prefix + route.path,
|
||||||
|
route.endpoint,
|
||||||
|
methods=route.methods,
|
||||||
|
name=route.name,
|
||||||
|
include_in_schema=route.include_in_schema,
|
||||||
|
tags=route.tags,
|
||||||
|
summary=route.summary,
|
||||||
|
description=route.description,
|
||||||
|
operation_id=route.operation_id,
|
||||||
|
deprecated=route.deprecated,
|
||||||
|
response_type=route.response_type,
|
||||||
|
response_description=route.response_description,
|
||||||
|
response_code=route.response_code,
|
||||||
|
response_wrapper=route.response_wrapper,
|
||||||
|
)
|
||||||
|
elif isinstance(route, routing.Route):
|
||||||
|
self.add_route(
|
||||||
|
prefix + route.path,
|
||||||
|
route.endpoint,
|
||||||
|
methods=route.methods,
|
||||||
|
name=route.name,
|
||||||
|
include_in_schema=route.include_in_schema,
|
||||||
|
)
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
|
|
|
||||||
|
|
@ -1,39 +1,34 @@
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from pydantic import Schema
|
|
||||||
|
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from .base import SecurityBase, Types
|
from .base import SecurityBase
|
||||||
|
from fastapi.openapi.models import APIKeyIn, APIKey
|
||||||
class APIKeyIn(Enum):
|
|
||||||
query = "query"
|
|
||||||
header = "header"
|
|
||||||
cookie = "cookie"
|
|
||||||
|
|
||||||
|
|
||||||
class APIKeyBase(SecurityBase):
|
class APIKeyBase(SecurityBase):
|
||||||
type_ = Schema(Types.apiKey, alias="type")
|
pass
|
||||||
in_: str = Schema(..., alias="in")
|
|
||||||
name: str
|
|
||||||
|
|
||||||
|
|
||||||
class APIKeyQuery(APIKeyBase):
|
class APIKeyQuery(APIKeyBase):
|
||||||
in_ = Schema(APIKeyIn.query, alias="in")
|
|
||||||
|
def __init__(self, *, name: str, scheme_name: str = None):
|
||||||
|
self.model = APIKey(in_=APIKeyIn.query, name=name)
|
||||||
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
|
|
||||||
async def __call__(self, requests: Request):
|
async def __call__(self, requests: Request):
|
||||||
return requests.query_params.get(self.name)
|
return requests.query_params.get(self.model.name)
|
||||||
|
|
||||||
|
|
||||||
class APIKeyHeader(APIKeyBase):
|
class APIKeyHeader(APIKeyBase):
|
||||||
in_ = Schema(APIKeyIn.header, alias="in")
|
def __init__(self, *, name: str, scheme_name: str = None):
|
||||||
|
self.model = APIKey(in_=APIKeyIn.header, name=name)
|
||||||
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
|
|
||||||
async def __call__(self, requests: Request):
|
async def __call__(self, requests: Request):
|
||||||
return requests.headers.get(self.name)
|
return requests.headers.get(self.model.name)
|
||||||
|
|
||||||
|
|
||||||
class APIKeyCookie(APIKeyBase):
|
class APIKeyCookie(APIKeyBase):
|
||||||
in_ = Schema(APIKeyIn.cookie, alias="in")
|
def __init__(self, *, name: str, scheme_name: str = None):
|
||||||
|
self.model = APIKey(in_=APIKeyIn.cookie, name=name)
|
||||||
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
|
|
||||||
async def __call__(self, requests: Request):
|
async def __call__(self, requests: Request):
|
||||||
return requests.cookies.get(self.name)
|
return requests.cookies.get(self.model.name)
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,6 @@
|
||||||
from enum import Enum
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from pydantic import BaseModel, Schema
|
from fastapi.openapi.models import SecurityBase as SecurityBaseModel
|
||||||
|
|
||||||
|
class SecurityBase:
|
||||||
class Types(Enum):
|
pass
|
||||||
apiKey = "apiKey"
|
|
||||||
http = "http"
|
|
||||||
oauth2 = "oauth2"
|
|
||||||
openIdConnect = "openIdConnect"
|
|
||||||
|
|
||||||
|
|
||||||
class SecurityBase(BaseModel):
|
|
||||||
scheme_name: str = None
|
|
||||||
type_: Types = Schema(..., alias="type")
|
|
||||||
description: str = None
|
|
||||||
|
|
|
||||||
|
|
@ -1,26 +1,40 @@
|
||||||
from pydantic import Schema
|
|
||||||
|
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from .base import SecurityBase, Types
|
from .base import SecurityBase
|
||||||
|
from fastapi.openapi.models import HTTPBase as HTTPBaseModel, HTTPBearer as HTTPBearerModel
|
||||||
|
|
||||||
|
|
||||||
class HTTPBase(SecurityBase):
|
class HTTPBase(SecurityBase):
|
||||||
type_ = Schema(Types.http, alias="type")
|
def __init__(self, *, scheme: str, scheme_name: str = None):
|
||||||
scheme: str
|
self.model = HTTPBaseModel(scheme=scheme)
|
||||||
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
|
|
||||||
async def __call__(self, request: Request):
|
async def __call__(self, request: Request):
|
||||||
return request.headers.get("Authorization")
|
return request.headers.get("Authorization")
|
||||||
|
|
||||||
|
|
||||||
class HTTPBasic(HTTPBase):
|
class HTTPBasic(HTTPBase):
|
||||||
scheme = "basic"
|
def __init__(self, *, scheme_name: str = None):
|
||||||
|
self.model = HTTPBaseModel(scheme="basic")
|
||||||
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
|
|
||||||
|
async def __call__(self, request: Request):
|
||||||
|
return request.headers.get("Authorization")
|
||||||
|
|
||||||
|
|
||||||
class HTTPBearer(HTTPBase):
|
class HTTPBearer(HTTPBase):
|
||||||
scheme = "bearer"
|
def __init__(self, *, bearerFormat: str = None, scheme_name: str = None):
|
||||||
bearerFormat: str = None
|
self.model = HTTPBearerModel(bearerFormat=bearerFormat)
|
||||||
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
|
|
||||||
|
async def __call__(self, request: Request):
|
||||||
|
return request.headers.get("Authorization")
|
||||||
|
|
||||||
|
|
||||||
class HTTPDigest(HTTPBase):
|
class HTTPDigest(HTTPBase):
|
||||||
scheme = "digest"
|
def __init__(self, *, scheme_name: str = None):
|
||||||
|
self.model = HTTPBaseModel(scheme="digest")
|
||||||
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
|
|
||||||
|
async def __call__(self, request: Request):
|
||||||
|
return request.headers.get("Authorization")
|
||||||
|
|
|
||||||
|
|
@ -1,43 +1,13 @@
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Schema
|
|
||||||
|
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from .base import SecurityBase, Types
|
from .base import SecurityBase
|
||||||
|
from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel
|
||||||
class OAuthFlow(BaseModel):
|
|
||||||
refreshUrl: str = None
|
|
||||||
scopes: Dict[str, str] = {}
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthFlowImplicit(OAuthFlow):
|
|
||||||
authorizationUrl: str
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthFlowPassword(OAuthFlow):
|
|
||||||
tokenUrl: str
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthFlowClientCredentials(OAuthFlow):
|
|
||||||
tokenUrl: str
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthFlowAuthorizationCode(OAuthFlow):
|
|
||||||
authorizationUrl: str
|
|
||||||
tokenUrl: str
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthFlows(BaseModel):
|
|
||||||
implicit: OAuthFlowImplicit = None
|
|
||||||
password: OAuthFlowPassword = None
|
|
||||||
clientCredentials: OAuthFlowClientCredentials = None
|
|
||||||
authorizationCode: OAuthFlowAuthorizationCode = None
|
|
||||||
|
|
||||||
|
|
||||||
class OAuth2(SecurityBase):
|
class OAuth2(SecurityBase):
|
||||||
type_ = Schema(Types.oauth2, alias="type")
|
def __init__(self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None):
|
||||||
flows: OAuthFlows
|
self.model = OAuth2Model(flows=flows)
|
||||||
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
|
|
||||||
async def __call__(self, request: Request):
|
async def __call__(self, request: Request):
|
||||||
return request.headers.get("Authorization")
|
return request.headers.get("Authorization")
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,13 @@
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from .base import SecurityBase, Types
|
from .base import SecurityBase
|
||||||
|
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
|
||||||
|
|
||||||
|
|
||||||
class OpenIdConnect(SecurityBase):
|
class OpenIdConnect(SecurityBase):
|
||||||
type_ = Types.openIdConnect
|
def __init__(self, *, openIdConnectUrl: str, scheme_name: str = None):
|
||||||
openIdConnectUrl: str
|
self.model = OpenIdConnectModel(openIdConnectUrl=openIdConnectUrl)
|
||||||
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
|
|
||||||
async def __call__(self, request: Request):
|
async def __call__(self, request: Request):
|
||||||
return request.headers.get("Authorization")
|
return request.headers.get("Authorization")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue