mirror of https://github.com/tiangolo/fastapi.git
♻️ Refactor, fix and update code
This commit is contained in:
parent
406c092a3b
commit
b9d912c638
|
|
@ -1,3 +1,3 @@
|
||||||
"""Fast API framework, fast high performance, fast to learn, fast to code"""
|
"""Fast API framework, fast high performance, fast to learn, fast to code"""
|
||||||
|
|
||||||
__version__ = '0.1'
|
__version__ = "0.1"
|
||||||
|
|
|
||||||
|
|
@ -1,61 +1,19 @@
|
||||||
import typing
|
from typing import Any, Callable, Dict, List, Type
|
||||||
import inspect
|
|
||||||
|
|
||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
from starlette.middleware.lifespan import LifespanMiddleware
|
from starlette.exceptions import ExceptionMiddleware, HTTPException
|
||||||
from starlette.middleware.errors import ServerErrorMiddleware
|
from starlette.middleware.errors import ServerErrorMiddleware
|
||||||
from starlette.exceptions import ExceptionMiddleware
|
from starlette.middleware.lifespan import LifespanMiddleware
|
||||||
from starlette.responses import JSONResponse, HTMLResponse, PlainTextResponse
|
from starlette.responses import JSONResponse
|
||||||
from starlette.requests import Request
|
|
||||||
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
|
|
||||||
|
|
||||||
from pydantic import BaseModel, BaseConfig, Schema
|
|
||||||
from pydantic.utils import lenient_issubclass
|
|
||||||
from pydantic.fields import Field
|
|
||||||
from pydantic.schema import (
|
|
||||||
field_schema,
|
|
||||||
get_flat_models_from_models,
|
|
||||||
get_flat_models_from_fields,
|
|
||||||
get_model_name_map,
|
|
||||||
schema,
|
|
||||||
model_process_schema,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .routing import APIRouter, APIRoute, get_openapi_params, get_flat_dependant
|
|
||||||
from .pydantic_utils import jsonable_encoder
|
|
||||||
|
|
||||||
|
|
||||||
def docs(openapi_url):
|
from fastapi import routing
|
||||||
return HTMLResponse(
|
from fastapi.openapi.utils import get_swagger_ui_html, get_openapi, get_redoc_html
|
||||||
"""
|
|
||||||
<! doctype html>
|
|
||||||
<html>
|
|
||||||
<head>
|
|
||||||
<link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css">
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<div id="swagger-ui">
|
|
||||||
</div>
|
|
||||||
<script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
|
|
||||||
<!-- `SwaggerUIBundle` is now available on the page -->
|
|
||||||
<script>
|
|
||||||
|
|
||||||
const ui = SwaggerUIBundle({
|
|
||||||
url: '""" + openapi_url + """',
|
|
||||||
dom_id: '#swagger-ui',
|
|
||||||
presets: [
|
|
||||||
SwaggerUIBundle.presets.apis,
|
|
||||||
SwaggerUIBundle.SwaggerUIStandalonePreset
|
|
||||||
],
|
|
||||||
layout: "BaseLayout"
|
|
||||||
|
|
||||||
})
|
async def http_exception(request, exc: HTTPException):
|
||||||
</script>
|
print(exc)
|
||||||
</body>
|
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
|
||||||
</html>
|
|
||||||
""",
|
|
||||||
media_type="text/html",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FastAPI(Starlette):
|
class FastAPI(Starlette):
|
||||||
|
|
@ -67,24 +25,26 @@ class FastAPI(Starlette):
|
||||||
description: str = "",
|
description: str = "",
|
||||||
version: str = "0.1.0",
|
version: str = "0.1.0",
|
||||||
openapi_url: str = "/openapi.json",
|
openapi_url: str = "/openapi.json",
|
||||||
docs_url: str = "/docs",
|
swagger_ui_url: str = "/docs",
|
||||||
**extra: typing.Dict[str, typing.Any],
|
redoc_url: str = "/redoc",
|
||||||
|
**extra: Dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
self._debug = debug
|
self._debug = debug
|
||||||
self.router = APIRouter()
|
self.router = routing.APIRouter()
|
||||||
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
|
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
|
||||||
self.error_middleware = ServerErrorMiddleware(
|
self.error_middleware = ServerErrorMiddleware(
|
||||||
self.exception_middleware, debug=debug
|
self.exception_middleware, debug=debug
|
||||||
)
|
)
|
||||||
self.lifespan_middleware = LifespanMiddleware(self.error_middleware)
|
self.lifespan_middleware = LifespanMiddleware(self.error_middleware)
|
||||||
self.schema_generator = None # type: typing.Optional[BaseSchemaGenerator]
|
self.schema_generator = None
|
||||||
self.template_env = self.load_template_env(template_directory)
|
self.template_env = self.load_template_env(template_directory)
|
||||||
|
|
||||||
self.title = title
|
self.title = title
|
||||||
self.description = description
|
self.description = description
|
||||||
self.version = version
|
self.version = version
|
||||||
self.openapi_url = openapi_url
|
self.openapi_url = openapi_url
|
||||||
self.docs_url = docs_url
|
self.swagger_ui_url = swagger_ui_url
|
||||||
|
self.redoc_url = redoc_url
|
||||||
self.extra = extra
|
self.extra = extra
|
||||||
|
|
||||||
self.openapi_version = "3.0.2"
|
self.openapi_version = "3.0.2"
|
||||||
|
|
@ -93,29 +53,52 @@ class FastAPI(Starlette):
|
||||||
assert self.title, "A title must be provided for OpenAPI, e.g.: 'My API'"
|
assert self.title, "A title must be provided for OpenAPI, e.g.: 'My API'"
|
||||||
assert self.version, "A version must be provided for OpenAPI, e.g.: '2.1.0'"
|
assert self.version, "A version must be provided for OpenAPI, e.g.: '2.1.0'"
|
||||||
|
|
||||||
if self.docs_url:
|
if self.swagger_ui_url or self.redoc_url:
|
||||||
assert self.openapi_url, "The openapi_url is required for the docs"
|
assert self.openapi_url, "The openapi_url is required for the docs"
|
||||||
|
self.setup()
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
if self.openapi_url:
|
||||||
self.add_route(
|
self.add_route(
|
||||||
self.openapi_url,
|
self.openapi_url,
|
||||||
lambda req: JSONResponse(self.openapi()),
|
lambda req: JSONResponse(
|
||||||
|
get_openapi(
|
||||||
|
title=self.title,
|
||||||
|
version=self.version,
|
||||||
|
openapi_version=self.openapi_version,
|
||||||
|
description=self.description,
|
||||||
|
routes=self.routes,
|
||||||
|
)
|
||||||
|
),
|
||||||
include_in_schema=False,
|
include_in_schema=False,
|
||||||
)
|
)
|
||||||
self.add_route(self.docs_url, lambda r: docs(self.openapi_url), include_in_schema=False)
|
if self.swagger_ui_url:
|
||||||
|
self.add_route(
|
||||||
|
self.swagger_ui_url,
|
||||||
|
lambda r: get_swagger_ui_html(openapi_url=self.openapi_url, title=self.title + " - Swagger UI"),
|
||||||
|
include_in_schema=False,
|
||||||
|
)
|
||||||
|
if self.redoc_url:
|
||||||
|
self.add_route(
|
||||||
|
self.redoc_url,
|
||||||
|
lambda r: get_redoc_html(openapi_url=self.openapi_url, title=self.title + " - ReDoc"),
|
||||||
|
include_in_schema=False,
|
||||||
|
)
|
||||||
|
self.add_exception_handler(HTTPException, http_exception)
|
||||||
|
|
||||||
def add_api_route(
|
def add_api_route(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
endpoint: typing.Callable,
|
endpoint: Callable,
|
||||||
methods: typing.List[str] = None,
|
methods: List[str] = None,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
|
|
@ -126,7 +109,7 @@ class FastAPI(Starlette):
|
||||||
methods=methods,
|
methods=methods,
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
@ -140,27 +123,27 @@ class FastAPI(Starlette):
|
||||||
def api_route(
|
def api_route(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
methods: typing.List[str] = None,
|
methods: List[str] = None,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
) -> typing.Callable:
|
) -> Callable:
|
||||||
def decorator(func: typing.Callable) -> typing.Callable:
|
def decorator(func: Callable) -> Callable:
|
||||||
self.router.add_api_route(
|
self.router.add_api_route(
|
||||||
path,
|
path,
|
||||||
func,
|
func,
|
||||||
methods=methods,
|
methods=methods,
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
@ -179,12 +162,12 @@ class FastAPI(Starlette):
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
|
|
@ -193,7 +176,7 @@ class FastAPI(Starlette):
|
||||||
path=path,
|
path=path,
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
@ -209,12 +192,12 @@ class FastAPI(Starlette):
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
|
|
@ -223,7 +206,7 @@ class FastAPI(Starlette):
|
||||||
path=path,
|
path=path,
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
@ -239,12 +222,12 @@ class FastAPI(Starlette):
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
|
|
@ -253,7 +236,7 @@ class FastAPI(Starlette):
|
||||||
path=path,
|
path=path,
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
@ -269,12 +252,12 @@ class FastAPI(Starlette):
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
|
|
@ -283,7 +266,7 @@ class FastAPI(Starlette):
|
||||||
path=path,
|
path=path,
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
@ -299,12 +282,12 @@ class FastAPI(Starlette):
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
|
|
@ -313,7 +296,7 @@ class FastAPI(Starlette):
|
||||||
path=path,
|
path=path,
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
@ -329,12 +312,12 @@ class FastAPI(Starlette):
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
|
|
@ -343,7 +326,7 @@ class FastAPI(Starlette):
|
||||||
path=path,
|
path=path,
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
@ -359,12 +342,12 @@ class FastAPI(Starlette):
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
|
|
@ -373,7 +356,7 @@ class FastAPI(Starlette):
|
||||||
path=path,
|
path=path,
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
@ -389,12 +372,12 @@ class FastAPI(Starlette):
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
|
|
@ -403,7 +386,7 @@ class FastAPI(Starlette):
|
||||||
path=path,
|
path=path,
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
@ -413,169 +396,3 @@ class FastAPI(Starlette):
|
||||||
response_code=response_code,
|
response_code=response_code,
|
||||||
response_wrapper=response_wrapper,
|
response_wrapper=response_wrapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
def openapi(self):
|
|
||||||
info = {"title": self.title, "version": self.version}
|
|
||||||
if self.description:
|
|
||||||
info["description"] = self.description
|
|
||||||
output = {"openapi": self.openapi_version, "info": info}
|
|
||||||
components = {}
|
|
||||||
paths = {}
|
|
||||||
methods_with_body = set(("POST", "PUT"))
|
|
||||||
body_fields_from_routes = []
|
|
||||||
responses_from_routes = []
|
|
||||||
ref_prefix = "#/components/schemas/"
|
|
||||||
for route in self.routes:
|
|
||||||
route: APIRoute
|
|
||||||
if route.include_in_schema and isinstance(route, APIRoute):
|
|
||||||
if route.request_body:
|
|
||||||
assert isinstance(
|
|
||||||
route.request_body, Field
|
|
||||||
), "A request body must be a Pydantic BaseModel or Field"
|
|
||||||
body_fields_from_routes.append(route.request_body)
|
|
||||||
if route.response_field:
|
|
||||||
responses_from_routes.append(route.response_field)
|
|
||||||
flat_models = get_flat_models_from_fields(
|
|
||||||
body_fields_from_routes + responses_from_routes
|
|
||||||
)
|
|
||||||
model_name_map = get_model_name_map(flat_models)
|
|
||||||
definitions = {}
|
|
||||||
for model in flat_models:
|
|
||||||
m_schema, m_definitions = model_process_schema(
|
|
||||||
model, model_name_map=model_name_map, ref_prefix=ref_prefix
|
|
||||||
)
|
|
||||||
definitions.update(m_definitions)
|
|
||||||
model_name = model_name_map[model]
|
|
||||||
definitions[model_name] = m_schema
|
|
||||||
validation_error_definition = {
|
|
||||||
"title": "ValidationError",
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"loc": {
|
|
||||||
"title": "Location",
|
|
||||||
"type": "array",
|
|
||||||
"items": {"type": "string"},
|
|
||||||
},
|
|
||||||
"msg": {"title": "Message", "type": "string"},
|
|
||||||
"type": {"title": "Error Type", "type": "string"},
|
|
||||||
},
|
|
||||||
"required": ["loc", "msg", "type"],
|
|
||||||
}
|
|
||||||
validation_error_response_definition = {
|
|
||||||
"title": "HTTPValidationError",
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"detail": {
|
|
||||||
"title": "Detail",
|
|
||||||
"type": "array",
|
|
||||||
"items": {"$ref": ref_prefix + "ValidationError"},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for route in self.routes:
|
|
||||||
route: APIRoute
|
|
||||||
if route.include_in_schema and isinstance(route, APIRoute):
|
|
||||||
path = paths.get(route.path, {})
|
|
||||||
for method in route.methods:
|
|
||||||
operation = {}
|
|
||||||
if route.tags:
|
|
||||||
operation["tags"] = route.tags
|
|
||||||
if route.summary:
|
|
||||||
operation["summary"] = route.summary
|
|
||||||
if route.description:
|
|
||||||
operation["description"] = route.description
|
|
||||||
if route.operation_id:
|
|
||||||
operation["operationId"] = route.operation_id
|
|
||||||
else:
|
|
||||||
operation["operationId"] = route.name
|
|
||||||
if route.deprecated:
|
|
||||||
operation["deprecated"] = route.deprecated
|
|
||||||
parameters = []
|
|
||||||
flat_dependant = get_flat_dependant(route.dependant)
|
|
||||||
security_definitions = {}
|
|
||||||
for security_scheme in flat_dependant.security_schemes:
|
|
||||||
security_definition = jsonable_encoder(security_scheme, exclude=("scheme_name",), by_alias=True, include_none=False)
|
|
||||||
security_name = getattr(security_scheme, "scheme_name", None) or security_scheme.__class__.__name__
|
|
||||||
security_definitions[security_name] = security_definition
|
|
||||||
if security_definitions:
|
|
||||||
components.setdefault("securitySchemes", {}).update(security_definitions)
|
|
||||||
operation["security"] = [{name: []} for name in security_definitions]
|
|
||||||
all_route_params = get_openapi_params(route.dependant)
|
|
||||||
for param in all_route_params:
|
|
||||||
if "ValidationError" not in definitions:
|
|
||||||
definitions["ValidationError"] = validation_error_definition
|
|
||||||
definitions[
|
|
||||||
"HTTPValidationError"
|
|
||||||
] = validation_error_response_definition
|
|
||||||
parameter = {
|
|
||||||
"name": param.alias,
|
|
||||||
"in": param.schema.in_.value,
|
|
||||||
"required": param.required,
|
|
||||||
"schema": field_schema(param, model_name_map={})[0],
|
|
||||||
}
|
|
||||||
if param.schema.description:
|
|
||||||
parameter["description"] = param.schema.description
|
|
||||||
if param.schema.deprecated:
|
|
||||||
parameter["deprecated"] = param.schema.deprecated
|
|
||||||
parameters.append(parameter)
|
|
||||||
if parameters:
|
|
||||||
operation["parameters"] = parameters
|
|
||||||
if method in methods_with_body:
|
|
||||||
request_body = getattr(route, "request_body", None)
|
|
||||||
if request_body:
|
|
||||||
assert isinstance(request_body, Field)
|
|
||||||
body_schema, _ = field_schema(
|
|
||||||
request_body,
|
|
||||||
model_name_map=model_name_map,
|
|
||||||
ref_prefix=ref_prefix,
|
|
||||||
)
|
|
||||||
required = request_body.required
|
|
||||||
request_body_oai = {}
|
|
||||||
if required:
|
|
||||||
request_body_oai["required"] = required
|
|
||||||
request_body_oai["content"] = {
|
|
||||||
"application/json": {"schema": body_schema}
|
|
||||||
}
|
|
||||||
operation["requestBody"] = request_body_oai
|
|
||||||
response_code = str(route.response_code)
|
|
||||||
response_schema = {"type": "string"}
|
|
||||||
if lenient_issubclass(route.response_wrapper, JSONResponse):
|
|
||||||
response_media_type = "application/json"
|
|
||||||
if route.response_field:
|
|
||||||
response_schema, _ = field_schema(
|
|
||||||
route.response_field,
|
|
||||||
model_name_map=model_name_map,
|
|
||||||
ref_prefix=ref_prefix,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
response_schema = {}
|
|
||||||
elif lenient_issubclass(route.response_wrapper, HTMLResponse):
|
|
||||||
response_media_type = "text/html"
|
|
||||||
else:
|
|
||||||
response_media_type = "text/plain"
|
|
||||||
content = {response_media_type: {"schema": response_schema}}
|
|
||||||
operation["responses"] = {
|
|
||||||
response_code: {
|
|
||||||
"description": route.response_description,
|
|
||||||
"content": content,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if all_route_params or getattr(route, "request_body", None):
|
|
||||||
operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
|
|
||||||
"description": "Validation Error",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"$ref": ref_prefix + "HTTPValidationError"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
path[method.lower()] = operation
|
|
||||||
paths[route.path] = path
|
|
||||||
if definitions:
|
|
||||||
components.setdefault("schemas", {}).update(definitions)
|
|
||||||
if components:
|
|
||||||
output["components"] = components
|
|
||||||
output["paths"] = paths
|
|
||||||
return output
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,46 @@
|
||||||
|
from typing import Any, Callable, Dict, List, Sequence, Tuple
|
||||||
|
|
||||||
|
from starlette.concurrency import run_in_threadpool
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
from fastapi.security.base import SecurityBase
|
||||||
|
from pydantic import BaseConfig, Schema
|
||||||
|
from pydantic.error_wrappers import ErrorWrapper
|
||||||
|
from pydantic.errors import MissingError
|
||||||
|
from pydantic.fields import Field, Required
|
||||||
|
from pydantic.schema import get_annotation_from_schema
|
||||||
|
|
||||||
|
param_supported_types = (str, int, float, bool)
|
||||||
|
|
||||||
|
|
||||||
|
class SecurityRequirement:
|
||||||
|
def __init__(self, security_scheme: SecurityBase, scopes: Sequence[str] = None):
|
||||||
|
self.security_scheme = security_scheme
|
||||||
|
self.scopes = scopes
|
||||||
|
|
||||||
|
|
||||||
|
class Dependant:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
path_params: List[Field] = None,
|
||||||
|
query_params: List[Field] = None,
|
||||||
|
header_params: List[Field] = None,
|
||||||
|
cookie_params: List[Field] = None,
|
||||||
|
body_params: List[Field] = None,
|
||||||
|
dependencies: List["Dependant"] = None,
|
||||||
|
security_schemes: List[SecurityRequirement] = None,
|
||||||
|
name: str = None,
|
||||||
|
call: Callable = None,
|
||||||
|
request_param_name: str = None,
|
||||||
|
) -> None:
|
||||||
|
self.path_params = path_params or []
|
||||||
|
self.query_params = query_params or []
|
||||||
|
self.header_params = header_params or []
|
||||||
|
self.cookie_params = cookie_params or []
|
||||||
|
self.body_params = body_params or []
|
||||||
|
self.dependencies = dependencies or []
|
||||||
|
self.security_requirements = security_schemes or []
|
||||||
|
self.request_param_name = request_param_name
|
||||||
|
self.name = name
|
||||||
|
self.call = call
|
||||||
|
|
@ -0,0 +1,327 @@
|
||||||
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any, Callable, Dict, List, Tuple
|
||||||
|
|
||||||
|
from starlette.concurrency import run_in_threadpool
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
from fastapi import params
|
||||||
|
from fastapi.dependencies.models import Dependant, SecurityRequirement
|
||||||
|
from fastapi.security.base import SecurityBase
|
||||||
|
from fastapi.utils import get_path_param_names
|
||||||
|
from pydantic import BaseConfig, Schema, create_model
|
||||||
|
from pydantic.error_wrappers import ErrorWrapper
|
||||||
|
from pydantic.errors import MissingError
|
||||||
|
from pydantic.fields import Field, Required
|
||||||
|
from pydantic.schema import get_annotation_from_schema
|
||||||
|
from pydantic.utils import lenient_issubclass
|
||||||
|
|
||||||
|
param_supported_types = (str, int, float, bool)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sub_dependant(*, param: inspect.Parameter, path: str):
|
||||||
|
depends: params.Depends = param.default
|
||||||
|
if depends.dependency:
|
||||||
|
dependency = depends.dependency
|
||||||
|
else:
|
||||||
|
dependency = param.annotation
|
||||||
|
assert callable(dependency)
|
||||||
|
sub_dependant = get_dependant(path=path, call=dependency, name=param.name)
|
||||||
|
if isinstance(depends, params.Security) and isinstance(dependency, SecurityBase):
|
||||||
|
security_requirement = SecurityRequirement(
|
||||||
|
security_scheme=dependency, scopes=depends.scopes
|
||||||
|
)
|
||||||
|
sub_dependant.security_requirements.append(security_requirement)
|
||||||
|
return sub_dependant
|
||||||
|
|
||||||
|
|
||||||
|
def get_flat_dependant(dependant: Dependant):
|
||||||
|
flat_dependant = Dependant(
|
||||||
|
path_params=dependant.path_params.copy(),
|
||||||
|
query_params=dependant.query_params.copy(),
|
||||||
|
header_params=dependant.header_params.copy(),
|
||||||
|
cookie_params=dependant.cookie_params.copy(),
|
||||||
|
body_params=dependant.body_params.copy(),
|
||||||
|
security_schemes=dependant.security_requirements.copy(),
|
||||||
|
)
|
||||||
|
for sub_dependant in dependant.dependencies:
|
||||||
|
if sub_dependant is dependant:
|
||||||
|
raise ValueError("recursion", dependant.dependencies)
|
||||||
|
flat_sub = get_flat_dependant(sub_dependant)
|
||||||
|
flat_dependant.path_params.extend(flat_sub.path_params)
|
||||||
|
flat_dependant.query_params.extend(flat_sub.query_params)
|
||||||
|
flat_dependant.header_params.extend(flat_sub.header_params)
|
||||||
|
flat_dependant.cookie_params.extend(flat_sub.cookie_params)
|
||||||
|
flat_dependant.body_params.extend(flat_sub.body_params)
|
||||||
|
flat_dependant.security_requirements.extend(flat_sub.security_requirements)
|
||||||
|
return flat_dependant
|
||||||
|
|
||||||
|
|
||||||
|
def get_dependant(*, path: str, call: Callable, name: str = None):
|
||||||
|
path_param_names = get_path_param_names(path)
|
||||||
|
endpoint_signature = inspect.signature(call)
|
||||||
|
signature_params = endpoint_signature.parameters
|
||||||
|
dependant = Dependant(call=call, name=name)
|
||||||
|
for param_name in signature_params:
|
||||||
|
param = signature_params[param_name]
|
||||||
|
if isinstance(param.default, params.Depends):
|
||||||
|
sub_dependant = get_sub_dependant(param=param, path=path)
|
||||||
|
dependant.dependencies.append(sub_dependant)
|
||||||
|
for param_name in signature_params:
|
||||||
|
param = signature_params[param_name]
|
||||||
|
if (
|
||||||
|
(param.default == param.empty) or isinstance(param.default, params.Path)
|
||||||
|
) and (param_name in path_param_names):
|
||||||
|
assert lenient_issubclass(
|
||||||
|
param.annotation, param_supported_types
|
||||||
|
) or param.annotation == param.empty, f"Path params must be of type str, int, float or boot: {param}"
|
||||||
|
param = signature_params[param_name]
|
||||||
|
add_param_to_fields(
|
||||||
|
param=param,
|
||||||
|
dependant=dependant,
|
||||||
|
default_schema=params.Path,
|
||||||
|
force_type=params.ParamTypes.path,
|
||||||
|
)
|
||||||
|
elif (param.default == param.empty or param.default is None) and (
|
||||||
|
param.annotation == param.empty
|
||||||
|
or lenient_issubclass(param.annotation, param_supported_types)
|
||||||
|
):
|
||||||
|
add_param_to_fields(
|
||||||
|
param=param, dependant=dependant, default_schema=params.Query
|
||||||
|
)
|
||||||
|
elif isinstance(param.default, params.Param):
|
||||||
|
if param.annotation != param.empty:
|
||||||
|
assert lenient_issubclass(
|
||||||
|
param.annotation, param_supported_types
|
||||||
|
), f"Parameters for Path, Query, Header and Cookies must be of type str, int, float or bool: {param}"
|
||||||
|
add_param_to_fields(
|
||||||
|
param=param, dependant=dependant, default_schema=params.Query
|
||||||
|
)
|
||||||
|
elif lenient_issubclass(param.annotation, Request):
|
||||||
|
dependant.request_param_name = param_name
|
||||||
|
elif not isinstance(param.default, params.Depends):
|
||||||
|
add_param_to_body_fields(param=param, dependant=dependant)
|
||||||
|
return dependant
|
||||||
|
|
||||||
|
|
||||||
|
def add_param_to_fields(
|
||||||
|
*,
|
||||||
|
param: inspect.Parameter,
|
||||||
|
dependant: Dependant,
|
||||||
|
default_schema=params.Param,
|
||||||
|
force_type: params.ParamTypes = None,
|
||||||
|
):
|
||||||
|
default_value = Required
|
||||||
|
if not param.default == param.empty:
|
||||||
|
default_value = param.default
|
||||||
|
if isinstance(default_value, params.Param):
|
||||||
|
schema = default_value
|
||||||
|
default_value = schema.default
|
||||||
|
if schema.in_ is None:
|
||||||
|
schema.in_ = default_schema.in_
|
||||||
|
if force_type:
|
||||||
|
schema.in_ = force_type
|
||||||
|
else:
|
||||||
|
schema = default_schema(default_value)
|
||||||
|
required = default_value == Required
|
||||||
|
annotation = Any
|
||||||
|
if not param.annotation == param.empty:
|
||||||
|
annotation = param.annotation
|
||||||
|
annotation = get_annotation_from_schema(annotation, schema)
|
||||||
|
field = Field(
|
||||||
|
name=param.name,
|
||||||
|
type_=annotation,
|
||||||
|
default=None if required else default_value,
|
||||||
|
alias=schema.alias or param.name,
|
||||||
|
required=required,
|
||||||
|
model_config=BaseConfig(),
|
||||||
|
class_validators=[],
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
if schema.in_ == params.ParamTypes.path:
|
||||||
|
dependant.path_params.append(field)
|
||||||
|
elif schema.in_ == params.ParamTypes.query:
|
||||||
|
dependant.query_params.append(field)
|
||||||
|
elif schema.in_ == params.ParamTypes.header:
|
||||||
|
dependant.header_params.append(field)
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
schema.in_ == params.ParamTypes.cookie
|
||||||
|
), f"non-body parameters must be in path, query, header or cookie: {param.name}"
|
||||||
|
dependant.cookie_params.append(field)
|
||||||
|
|
||||||
|
|
||||||
|
def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant):
|
||||||
|
default_value = Required
|
||||||
|
if not param.default == param.empty:
|
||||||
|
default_value = param.default
|
||||||
|
if isinstance(default_value, Schema):
|
||||||
|
schema = default_value
|
||||||
|
default_value = schema.default
|
||||||
|
else:
|
||||||
|
schema = Schema(default_value)
|
||||||
|
required = default_value == Required
|
||||||
|
annotation = get_annotation_from_schema(param.annotation, schema)
|
||||||
|
field = Field(
|
||||||
|
name=param.name,
|
||||||
|
type_=annotation,
|
||||||
|
default=None if required else default_value,
|
||||||
|
alias=schema.alias or param.name,
|
||||||
|
required=required,
|
||||||
|
model_config=BaseConfig,
|
||||||
|
class_validators=[],
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
dependant.body_params.append(field)
|
||||||
|
|
||||||
|
|
||||||
|
def is_coroutine_callable(call: Callable = None):
|
||||||
|
if not call:
|
||||||
|
return False
|
||||||
|
if inspect.isfunction(call):
|
||||||
|
return asyncio.iscoroutinefunction(call)
|
||||||
|
if inspect.isclass(call):
|
||||||
|
return False
|
||||||
|
call = getattr(call, "__call__", None)
|
||||||
|
if not call:
|
||||||
|
return False
|
||||||
|
return asyncio.iscoroutinefunction(call)
|
||||||
|
|
||||||
|
|
||||||
|
async def solve_dependencies(
|
||||||
|
*, request: Request, dependant: Dependant, body: Dict[str, Any] = None
|
||||||
|
):
|
||||||
|
values: Dict[str, Any] = {}
|
||||||
|
errors: List[ErrorWrapper] = []
|
||||||
|
for sub_dependant in dependant.dependencies:
|
||||||
|
sub_values, sub_errors = await solve_dependencies(
|
||||||
|
request=request, dependant=sub_dependant, body=body
|
||||||
|
)
|
||||||
|
if sub_errors:
|
||||||
|
return {}, errors
|
||||||
|
if sub_dependant.call and is_coroutine_callable(sub_dependant.call):
|
||||||
|
solved = await sub_dependant.call(**sub_values)
|
||||||
|
else:
|
||||||
|
solved = await run_in_threadpool(sub_dependant.call, **sub_values)
|
||||||
|
values[
|
||||||
|
sub_dependant.name
|
||||||
|
] = solved # type: ignore # Sub-dependants always have a name
|
||||||
|
path_values, path_errors = request_params_to_args(
|
||||||
|
dependant.path_params, request.path_params
|
||||||
|
)
|
||||||
|
query_values, query_errors = request_params_to_args(
|
||||||
|
dependant.query_params, request.query_params
|
||||||
|
)
|
||||||
|
header_values, header_errors = request_params_to_args(
|
||||||
|
dependant.header_params, request.headers
|
||||||
|
)
|
||||||
|
cookie_values, cookie_errors = request_params_to_args(
|
||||||
|
dependant.cookie_params, request.cookies
|
||||||
|
)
|
||||||
|
values.update(path_values)
|
||||||
|
values.update(query_values)
|
||||||
|
values.update(header_values)
|
||||||
|
values.update(cookie_values)
|
||||||
|
errors = path_errors + query_errors + header_errors + cookie_errors
|
||||||
|
if dependant.body_params:
|
||||||
|
body_values, body_errors = await request_body_to_args( # type: ignore # body_params checked above
|
||||||
|
dependant.body_params, body
|
||||||
|
)
|
||||||
|
values.update(body_values)
|
||||||
|
errors.extend(body_errors)
|
||||||
|
if dependant.request_param_name:
|
||||||
|
values[dependant.request_param_name] = request
|
||||||
|
return values, errors
|
||||||
|
|
||||||
|
|
||||||
|
def request_params_to_args(
|
||||||
|
required_params: List[Field], received_params: Dict[str, Any]
|
||||||
|
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
|
||||||
|
values = {}
|
||||||
|
errors = []
|
||||||
|
for field in required_params:
|
||||||
|
value = received_params.get(field.alias)
|
||||||
|
if value is None:
|
||||||
|
if field.required:
|
||||||
|
errors.append(
|
||||||
|
ErrorWrapper(MissingError(), loc=field.alias, config=BaseConfig)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
values[field.name] = deepcopy(field.default)
|
||||||
|
continue
|
||||||
|
v_, errors_ = field.validate(
|
||||||
|
value, values, loc=(field.schema.in_.value, field.alias)
|
||||||
|
)
|
||||||
|
if isinstance(errors_, ErrorWrapper):
|
||||||
|
errors.append(errors_)
|
||||||
|
elif isinstance(errors_, list):
|
||||||
|
errors.extend(errors_)
|
||||||
|
else:
|
||||||
|
values[field.name] = v_
|
||||||
|
return values, errors
|
||||||
|
|
||||||
|
|
||||||
|
async def request_body_to_args(
|
||||||
|
required_params: List[Field], received_body: Dict[str, Any]
|
||||||
|
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
|
||||||
|
values = {}
|
||||||
|
errors = []
|
||||||
|
if required_params:
|
||||||
|
field = required_params[0]
|
||||||
|
embed = getattr(field.schema, "embed", None)
|
||||||
|
if len(required_params) == 1 and not embed:
|
||||||
|
received_body = {field.alias: received_body}
|
||||||
|
for field in required_params:
|
||||||
|
value = received_body.get(field.alias)
|
||||||
|
if value is None:
|
||||||
|
if field.required:
|
||||||
|
errors.append(
|
||||||
|
ErrorWrapper(
|
||||||
|
MissingError(), loc=("body", field.alias), config=BaseConfig
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
values[field.name] = deepcopy(field.default)
|
||||||
|
continue
|
||||||
|
v_, errors_ = field.validate(value, values, loc=("body", field.alias))
|
||||||
|
if isinstance(errors_, ErrorWrapper):
|
||||||
|
errors.append(errors_)
|
||||||
|
elif isinstance(errors_, list):
|
||||||
|
errors.extend(errors_)
|
||||||
|
else:
|
||||||
|
values[field.name] = v_
|
||||||
|
return values, errors
|
||||||
|
|
||||||
|
|
||||||
|
def get_body_field(*, dependant: Dependant, name: str):
|
||||||
|
flat_dependant = get_flat_dependant(dependant)
|
||||||
|
if not flat_dependant.body_params:
|
||||||
|
return None
|
||||||
|
first_param = flat_dependant.body_params[0]
|
||||||
|
embed = getattr(first_param.schema, "embed", None)
|
||||||
|
if len(flat_dependant.body_params) == 1 and not embed:
|
||||||
|
return first_param
|
||||||
|
model_name = "Body_" + name
|
||||||
|
BodyModel = create_model(model_name)
|
||||||
|
for f in flat_dependant.body_params:
|
||||||
|
BodyModel.__fields__[f.name] = f
|
||||||
|
required = any(True for f in flat_dependant.body_params if f.required)
|
||||||
|
if any(isinstance(f.schema, params.File) for f in flat_dependant.body_params):
|
||||||
|
BodySchema = params.File
|
||||||
|
elif any(isinstance(f.schema, params.Form) for f in flat_dependant.body_params):
|
||||||
|
BodySchema = params.Form
|
||||||
|
else:
|
||||||
|
BodySchema = params.Body
|
||||||
|
|
||||||
|
field = Field(
|
||||||
|
name="body",
|
||||||
|
type_=BodyModel,
|
||||||
|
default=None,
|
||||||
|
required=required,
|
||||||
|
model_config=BaseConfig,
|
||||||
|
class_validators=[],
|
||||||
|
alias="body",
|
||||||
|
schema=BodySchema(None),
|
||||||
|
)
|
||||||
|
return field
|
||||||
|
|
@ -1,33 +1,44 @@
|
||||||
|
from enum import Enum
|
||||||
from types import GeneratorType
|
from types import GeneratorType
|
||||||
from typing import Set
|
from typing import Set
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from enum import Enum
|
|
||||||
from pydantic.json import pydantic_encoder
|
from pydantic.json import pydantic_encoder
|
||||||
|
|
||||||
|
|
||||||
def jsonable_encoder(
|
def jsonable_encoder(
|
||||||
obj, include: Set[str] = None, exclude: Set[str] = set(), by_alias: bool = False, include_none=True,
|
obj,
|
||||||
|
include: Set[str] = None,
|
||||||
|
exclude: Set[str] = set(),
|
||||||
|
by_alias: bool = False,
|
||||||
|
include_none=True,
|
||||||
):
|
):
|
||||||
if isinstance(obj, BaseModel):
|
if isinstance(obj, BaseModel):
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
obj.dict(include=include, exclude=exclude, by_alias=by_alias), include_none=include_none
|
obj.dict(include=include, exclude=exclude, by_alias=by_alias),
|
||||||
|
include_none=include_none,
|
||||||
)
|
)
|
||||||
elif isinstance(obj, Enum):
|
if isinstance(obj, Enum):
|
||||||
return obj.value
|
return obj.value
|
||||||
if isinstance(obj, (str, int, float, type(None))):
|
if isinstance(obj, (str, int, float, type(None))):
|
||||||
return obj
|
return obj
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
return {
|
return {
|
||||||
jsonable_encoder(
|
jsonable_encoder(
|
||||||
key, by_alias=by_alias, include_none=include_none,
|
key, by_alias=by_alias, include_none=include_none
|
||||||
): jsonable_encoder(
|
): jsonable_encoder(value, by_alias=by_alias, include_none=include_none)
|
||||||
value, by_alias=by_alias, include_none=include_none,
|
for key, value in obj.items()
|
||||||
)
|
if value is not None or include_none
|
||||||
for key, value in obj.items() if value is not None or include_none
|
|
||||||
}
|
}
|
||||||
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
|
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
|
||||||
return [
|
return [
|
||||||
jsonable_encoder(item, include=include, exclude=exclude, by_alias=by_alias, include_none=include_none)
|
jsonable_encoder(
|
||||||
|
item,
|
||||||
|
include=include,
|
||||||
|
exclude=exclude,
|
||||||
|
by_alias=by_alias,
|
||||||
|
include_none=include_none,
|
||||||
|
)
|
||||||
for item in obj
|
for item in obj
|
||||||
]
|
]
|
||||||
return pydantic_encoder(obj)
|
return pydantic_encoder(obj)
|
||||||
|
|
@ -0,0 +1,2 @@
|
||||||
|
METHODS_WITH_BODY = set(("POST", "PUT"))
|
||||||
|
REF_PREFIX = "#/components/schemas/"
|
||||||
|
|
@ -0,0 +1,347 @@
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Schema as PSchema
|
||||||
|
from pydantic.types import UrlStr
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pydantic.types.EmailStr
|
||||||
|
from pydantic.types import EmailStr
|
||||||
|
except ImportError:
|
||||||
|
logging.warning(
|
||||||
|
"email-validator not installed, email fields will be treated as str"
|
||||||
|
)
|
||||||
|
|
||||||
|
class EmailStr(str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Contact(BaseModel):
|
||||||
|
name: Optional[str] = None
|
||||||
|
url: Optional[UrlStr] = None
|
||||||
|
email: Optional[EmailStr] = None
|
||||||
|
|
||||||
|
|
||||||
|
class License(BaseModel):
|
||||||
|
name: str
|
||||||
|
url: Optional[UrlStr] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Info(BaseModel):
|
||||||
|
title: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
termsOfService: Optional[str] = None
|
||||||
|
contact: Optional[Contact] = None
|
||||||
|
license: Optional[License] = None
|
||||||
|
version: str
|
||||||
|
|
||||||
|
|
||||||
|
class ServerVariable(BaseModel):
|
||||||
|
enum: Optional[List[str]] = None
|
||||||
|
default: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Server(BaseModel):
|
||||||
|
url: UrlStr
|
||||||
|
description: Optional[str] = None
|
||||||
|
variables: Optional[Dict[str, ServerVariable]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Reference(BaseModel):
|
||||||
|
ref: str = PSchema(..., alias="$ref")
|
||||||
|
|
||||||
|
|
||||||
|
class Discriminator(BaseModel):
|
||||||
|
propertyName: str
|
||||||
|
mapping: Optional[Dict[str, str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class XML(BaseModel):
|
||||||
|
name: Optional[str] = None
|
||||||
|
namespace: Optional[str] = None
|
||||||
|
prefix: Optional[str] = None
|
||||||
|
attribute: Optional[bool] = None
|
||||||
|
wrapped: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ExternalDocumentation(BaseModel):
|
||||||
|
description: Optional[str] = None
|
||||||
|
url: UrlStr
|
||||||
|
|
||||||
|
|
||||||
|
class SchemaBase(BaseModel):
|
||||||
|
ref: Optional[str] = PSchema(None, alias="$ref")
|
||||||
|
title: Optional[str] = None
|
||||||
|
multipleOf: Optional[float] = None
|
||||||
|
maximum: Optional[float] = None
|
||||||
|
exclusiveMaximum: Optional[float] = None
|
||||||
|
minimum: Optional[float] = None
|
||||||
|
exclusiveMinimum: Optional[float] = None
|
||||||
|
maxLength: Optional[int] = PSchema(None, gte=0)
|
||||||
|
minLength: Optional[int] = PSchema(None, gte=0)
|
||||||
|
pattern: Optional[str] = None
|
||||||
|
maxItems: Optional[int] = PSchema(None, gte=0)
|
||||||
|
minItems: Optional[int] = PSchema(None, gte=0)
|
||||||
|
uniqueItems: Optional[bool] = None
|
||||||
|
maxProperties: Optional[int] = PSchema(None, gte=0)
|
||||||
|
minProperties: Optional[int] = PSchema(None, gte=0)
|
||||||
|
required: Optional[List[str]] = None
|
||||||
|
enum: Optional[List[str]] = None
|
||||||
|
type: Optional[str] = None
|
||||||
|
allOf: Optional[List[Any]] = None
|
||||||
|
oneOf: Optional[List[Any]] = None
|
||||||
|
anyOf: Optional[List[Any]] = None
|
||||||
|
not_: Optional[List[Any]] = PSchema(None, alias="not")
|
||||||
|
items: Optional[Any] = None
|
||||||
|
properties: Optional[Dict[str, Any]] = None
|
||||||
|
additionalProperties: Optional[Union[bool, Any]] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
format: Optional[str] = None
|
||||||
|
default: Optional[Any] = None
|
||||||
|
nullable: Optional[bool] = None
|
||||||
|
discriminator: Optional[Discriminator] = None
|
||||||
|
readOnly: Optional[bool] = None
|
||||||
|
writeOnly: Optional[bool] = None
|
||||||
|
xml: Optional[XML] = None
|
||||||
|
externalDocs: Optional[ExternalDocumentation] = None
|
||||||
|
example: Optional[Any] = None
|
||||||
|
deprecated: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Schema(SchemaBase):
|
||||||
|
allOf: Optional[List[SchemaBase]] = None
|
||||||
|
oneOf: Optional[List[SchemaBase]] = None
|
||||||
|
anyOf: Optional[List[SchemaBase]] = None
|
||||||
|
not_: Optional[List[SchemaBase]] = PSchema(None, alias="not")
|
||||||
|
items: Optional[SchemaBase] = None
|
||||||
|
properties: Optional[Dict[str, SchemaBase]] = None
|
||||||
|
additionalProperties: Optional[Union[bool, SchemaBase]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Example(BaseModel):
|
||||||
|
summary: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
value: Optional[Any] = None
|
||||||
|
externalValue: Optional[UrlStr] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ParameterInType(Enum):
|
||||||
|
query = "query"
|
||||||
|
header = "header"
|
||||||
|
path = "path"
|
||||||
|
cookie = "cookie"
|
||||||
|
|
||||||
|
|
||||||
|
class Encoding(BaseModel):
|
||||||
|
contentType: Optional[str] = None
|
||||||
|
# Workaround OpenAPI recursive reference, using Any
|
||||||
|
headers: Optional[Dict[str, Union[Any, Reference]]] = None
|
||||||
|
style: Optional[str] = None
|
||||||
|
explode: Optional[bool] = None
|
||||||
|
allowReserved: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
|
class MediaType(BaseModel):
|
||||||
|
schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema")
|
||||||
|
example: Optional[Any] = None
|
||||||
|
examples: Optional[Dict[str, Union[Example, Reference]]] = None
|
||||||
|
encoding: Optional[Dict[str, Encoding]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ParameterBase(BaseModel):
|
||||||
|
description: Optional[str] = None
|
||||||
|
required: Optional[bool] = None
|
||||||
|
deprecated: Optional[bool] = None
|
||||||
|
# Serialization rules for simple scenarios
|
||||||
|
style: Optional[str] = None
|
||||||
|
explode: Optional[bool] = None
|
||||||
|
allowReserved: Optional[bool] = None
|
||||||
|
schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema")
|
||||||
|
example: Optional[Any] = None
|
||||||
|
examples: Optional[Dict[str, Union[Example, Reference]]] = None
|
||||||
|
# Serialization rules for more complex scenarios
|
||||||
|
content: Optional[Dict[str, MediaType]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Parameter(ParameterBase):
|
||||||
|
name: str
|
||||||
|
in_: ParameterInType = PSchema(..., alias="in")
|
||||||
|
|
||||||
|
|
||||||
|
class Header(ParameterBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Workaround OpenAPI recursive reference
|
||||||
|
class EncodingWithHeaders(Encoding):
|
||||||
|
headers: Optional[Dict[str, Union[Header, Reference]]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class RequestBody(BaseModel):
|
||||||
|
description: Optional[str] = None
|
||||||
|
content: Dict[str, MediaType]
|
||||||
|
required: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Link(BaseModel):
|
||||||
|
operationRef: Optional[str] = None
|
||||||
|
operationId: Optional[str] = None
|
||||||
|
parameters: Optional[Dict[str, Union[Any, str]]] = None
|
||||||
|
requestBody: Optional[Union[Any, str]] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
server: Optional[Server] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Response(BaseModel):
|
||||||
|
description: str
|
||||||
|
headers: Optional[Dict[str, Union[Header, Reference]]] = None
|
||||||
|
content: Optional[Dict[str, MediaType]] = None
|
||||||
|
links: Optional[Dict[str, Union[Link, Reference]]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Responses(BaseModel):
|
||||||
|
default: Response
|
||||||
|
|
||||||
|
|
||||||
|
class Operation(BaseModel):
|
||||||
|
tags: Optional[List[str]] = None
|
||||||
|
summary: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
externalDocs: Optional[ExternalDocumentation] = None
|
||||||
|
operationId: Optional[str] = None
|
||||||
|
parameters: Optional[List[Union[Parameter, Reference]]] = None
|
||||||
|
requestBody: Optional[Union[RequestBody, Reference]] = None
|
||||||
|
responses: Union[Responses, Dict[Union[str], Response]]
|
||||||
|
# Workaround OpenAPI recursive reference
|
||||||
|
callbacks: Optional[Dict[str, Union[Dict[str, Any], Reference]]] = None
|
||||||
|
deprecated: Optional[bool] = None
|
||||||
|
security: Optional[List[Dict[str, List[str]]]] = None
|
||||||
|
servers: Optional[List[Server]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class PathItem(BaseModel):
|
||||||
|
ref: Optional[str] = PSchema(None, alias="$ref")
|
||||||
|
summary: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
get: Optional[Operation] = None
|
||||||
|
put: Optional[Operation] = None
|
||||||
|
post: Optional[Operation] = None
|
||||||
|
delete: Optional[Operation] = None
|
||||||
|
options: Optional[Operation] = None
|
||||||
|
head: Optional[Operation] = None
|
||||||
|
patch: Optional[Operation] = None
|
||||||
|
trace: Optional[Operation] = None
|
||||||
|
servers: Optional[List[Server]] = None
|
||||||
|
parameters: Optional[List[Union[Parameter, Reference]]] = None
|
||||||
|
|
||||||
|
|
||||||
|
# Workaround OpenAPI recursive reference
|
||||||
|
class OperationWithCallbacks(BaseModel):
|
||||||
|
callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference]]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SecuritySchemeType(Enum):
|
||||||
|
apiKey = "apiKey"
|
||||||
|
http = "http"
|
||||||
|
oauth2 = "oauth2"
|
||||||
|
openIdConnect = "openIdConnect"
|
||||||
|
|
||||||
|
|
||||||
|
class SecurityBase(BaseModel):
|
||||||
|
type_: SecuritySchemeType = PSchema(..., alias="type")
|
||||||
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class APIKeyIn(Enum):
|
||||||
|
query = "query"
|
||||||
|
header = "header"
|
||||||
|
cookie = "cookie"
|
||||||
|
|
||||||
|
|
||||||
|
class APIKey(SecurityBase):
|
||||||
|
type_ = PSchema(SecuritySchemeType.apiKey, alias="type")
|
||||||
|
in_: APIKeyIn = PSchema(..., alias="in")
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPBase(SecurityBase):
|
||||||
|
type_ = PSchema(SecuritySchemeType.http, alias="type")
|
||||||
|
scheme: str
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPBearer(HTTPBase):
|
||||||
|
scheme = "bearer"
|
||||||
|
bearerFormat: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthFlow(BaseModel):
|
||||||
|
refreshUrl: Optional[str] = None
|
||||||
|
scopes: Dict[str, str] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthFlowImplicit(OAuthFlow):
|
||||||
|
authorizationUrl: str
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthFlowPassword(OAuthFlow):
|
||||||
|
tokenUrl: str
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthFlowClientCredentials(OAuthFlow):
|
||||||
|
tokenUrl: str
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthFlowAuthorizationCode(OAuthFlow):
|
||||||
|
authorizationUrl: str
|
||||||
|
tokenUrl: str
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthFlows(BaseModel):
|
||||||
|
implicit: Optional[OAuthFlowImplicit] = None
|
||||||
|
password: Optional[OAuthFlowPassword] = None
|
||||||
|
clientCredentials: Optional[OAuthFlowClientCredentials] = None
|
||||||
|
authorizationCode: Optional[OAuthFlowAuthorizationCode] = None
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2(SecurityBase):
|
||||||
|
type_ = PSchema(SecuritySchemeType.oauth2, alias="type")
|
||||||
|
flows: OAuthFlows
|
||||||
|
|
||||||
|
|
||||||
|
class OpenIdConnect(SecurityBase):
|
||||||
|
type_ = PSchema(SecuritySchemeType.openIdConnect, alias="type")
|
||||||
|
openIdConnectUrl: str
|
||||||
|
|
||||||
|
|
||||||
|
SecurityScheme = Union[APIKey, HTTPBase, HTTPBearer, OAuth2, OpenIdConnect]
|
||||||
|
|
||||||
|
|
||||||
|
class Components(BaseModel):
|
||||||
|
schemas: Optional[Dict[str, Union[Schema, Reference]]] = None
|
||||||
|
responses: Optional[Dict[str, Union[Response, Reference]]] = None
|
||||||
|
parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None
|
||||||
|
examples: Optional[Dict[str, Union[Example, Reference]]] = None
|
||||||
|
requestBodies: Optional[Dict[str, Union[RequestBody, Reference]]] = None
|
||||||
|
headers: Optional[Dict[str, Union[Header, Reference]]] = None
|
||||||
|
securitySchemes: Optional[Dict[str, Union[SecurityScheme, Reference]]] = None
|
||||||
|
links: Optional[Dict[str, Union[Link, Reference]]] = None
|
||||||
|
callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference]]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Tag(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
externalDocs: Optional[ExternalDocumentation] = None
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAPI(BaseModel):
|
||||||
|
openapi: str
|
||||||
|
info: Info
|
||||||
|
servers: Optional[List[Server]] = None
|
||||||
|
paths: Dict[str, PathItem]
|
||||||
|
components: Optional[Components] = None
|
||||||
|
security: Optional[List[Dict[str, List[str]]]] = None
|
||||||
|
tags: Optional[List[Tag]] = None
|
||||||
|
externalDocs: Optional[ExternalDocumentation] = None
|
||||||
|
|
@ -0,0 +1,280 @@
|
||||||
|
from typing import Any, Dict, Sequence, Type
|
||||||
|
|
||||||
|
from starlette.responses import HTMLResponse, JSONResponse
|
||||||
|
from starlette.routing import BaseRoute
|
||||||
|
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
|
from fastapi import routing
|
||||||
|
from fastapi.dependencies.models import Dependant
|
||||||
|
from fastapi.dependencies.utils import get_flat_dependant
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY
|
||||||
|
from fastapi.openapi.models import OpenAPI
|
||||||
|
from fastapi.params import Body
|
||||||
|
from fastapi.utils import get_flat_models_from_routes, get_model_definitions
|
||||||
|
from pydantic.fields import Field
|
||||||
|
from pydantic.schema import field_schema, get_model_name_map
|
||||||
|
from pydantic.utils import lenient_issubclass
|
||||||
|
|
||||||
|
validation_error_definition = {
|
||||||
|
"title": "ValidationError",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"loc": {"title": "Location", "type": "array", "items": {"type": "string"}},
|
||||||
|
"msg": {"title": "Message", "type": "string"},
|
||||||
|
"type": {"title": "Error Type", "type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["loc", "msg", "type"],
|
||||||
|
}
|
||||||
|
|
||||||
|
validation_error_response_definition = {
|
||||||
|
"title": "HTTPValidationError",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"detail": {
|
||||||
|
"title": "Detail",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"$ref": REF_PREFIX + "ValidationError"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_openapi_params(dependant: Dependant):
|
||||||
|
flat_dependant = get_flat_dependant(dependant)
|
||||||
|
return (
|
||||||
|
flat_dependant.path_params
|
||||||
|
+ flat_dependant.query_params
|
||||||
|
+ flat_dependant.header_params
|
||||||
|
+ flat_dependant.cookie_params
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
|
||||||
|
if not (route.include_in_schema and isinstance(route, routing.APIRoute)):
|
||||||
|
return None
|
||||||
|
path = {}
|
||||||
|
security_schemes = {}
|
||||||
|
definitions = {}
|
||||||
|
for method in route.methods:
|
||||||
|
operation: Dict[str, Any] = {}
|
||||||
|
if route.tags:
|
||||||
|
operation["tags"] = route.tags
|
||||||
|
if route.summary:
|
||||||
|
operation["summary"] = route.summary
|
||||||
|
if route.description:
|
||||||
|
operation["description"] = route.description
|
||||||
|
if route.operation_id:
|
||||||
|
operation["operationId"] = route.operation_id
|
||||||
|
else:
|
||||||
|
operation["operationId"] = route.name
|
||||||
|
if route.deprecated:
|
||||||
|
operation["deprecated"] = route.deprecated
|
||||||
|
parameters = []
|
||||||
|
flat_dependant = get_flat_dependant(route.dependant)
|
||||||
|
security_definitions = {}
|
||||||
|
for security_requirement in flat_dependant.security_requirements:
|
||||||
|
security_definition = jsonable_encoder(
|
||||||
|
security_requirement.security_scheme,
|
||||||
|
exclude={"scheme_name"},
|
||||||
|
by_alias=True,
|
||||||
|
include_none=False,
|
||||||
|
)
|
||||||
|
security_name = (
|
||||||
|
getattr(
|
||||||
|
security_requirement.security_scheme, "scheme_name", None
|
||||||
|
)
|
||||||
|
or security_requirement.security_scheme.__class__.__name__
|
||||||
|
)
|
||||||
|
security_definitions[security_name] = security_definition
|
||||||
|
operation.setdefault("security", []).append(
|
||||||
|
{security_name: security_requirement.scopes}
|
||||||
|
)
|
||||||
|
if security_definitions:
|
||||||
|
security_schemes.update(
|
||||||
|
security_definitions
|
||||||
|
)
|
||||||
|
all_route_params = get_openapi_params(route.dependant)
|
||||||
|
for param in all_route_params:
|
||||||
|
if "ValidationError" not in definitions:
|
||||||
|
definitions["ValidationError"] = validation_error_definition
|
||||||
|
definitions[
|
||||||
|
"HTTPValidationError"
|
||||||
|
] = validation_error_response_definition
|
||||||
|
parameter = {
|
||||||
|
"name": param.alias,
|
||||||
|
"in": param.schema.in_.value,
|
||||||
|
"required": param.required,
|
||||||
|
"schema": field_schema(param, model_name_map={})[0],
|
||||||
|
}
|
||||||
|
if param.schema.description:
|
||||||
|
parameter["description"] = param.schema.description
|
||||||
|
if param.schema.deprecated:
|
||||||
|
parameter["deprecated"] = param.schema.deprecated
|
||||||
|
parameters.append(parameter)
|
||||||
|
if parameters:
|
||||||
|
operation["parameters"] = parameters
|
||||||
|
if method in METHODS_WITH_BODY:
|
||||||
|
body_field = route.body_field
|
||||||
|
if body_field:
|
||||||
|
assert isinstance(body_field, Field)
|
||||||
|
body_schema, _ = field_schema(
|
||||||
|
body_field,
|
||||||
|
model_name_map=model_name_map,
|
||||||
|
ref_prefix=REF_PREFIX,
|
||||||
|
)
|
||||||
|
if isinstance(body_field.schema, Body):
|
||||||
|
request_media_type = body_field.schema.media_type
|
||||||
|
else:
|
||||||
|
# Includes not declared media types (Schema)
|
||||||
|
request_media_type = "application/json"
|
||||||
|
required = body_field.required
|
||||||
|
request_body_oai = {}
|
||||||
|
if required:
|
||||||
|
request_body_oai["required"] = required
|
||||||
|
request_body_oai["content"] = {
|
||||||
|
request_media_type: {"schema": body_schema}
|
||||||
|
}
|
||||||
|
operation["requestBody"] = request_body_oai
|
||||||
|
response_code = str(route.response_code)
|
||||||
|
response_schema = {"type": "string"}
|
||||||
|
if lenient_issubclass(route.response_wrapper, JSONResponse):
|
||||||
|
response_media_type = "application/json"
|
||||||
|
if route.response_field:
|
||||||
|
response_schema, _ = field_schema(
|
||||||
|
route.response_field,
|
||||||
|
model_name_map=model_name_map,
|
||||||
|
ref_prefix=REF_PREFIX,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response_schema = {}
|
||||||
|
elif lenient_issubclass(route.response_wrapper, HTMLResponse):
|
||||||
|
response_media_type = "text/html"
|
||||||
|
else:
|
||||||
|
response_media_type = "text/plain"
|
||||||
|
content = {response_media_type: {"schema": response_schema}}
|
||||||
|
operation["responses"] = {
|
||||||
|
response_code: {
|
||||||
|
"description": route.response_description,
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if all_route_params or route.body_field:
|
||||||
|
operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
|
||||||
|
"description": "Validation Error",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
path[method.lower()] = operation
|
||||||
|
return path, security_schemes, definitions
|
||||||
|
|
||||||
|
|
||||||
|
def get_openapi(
|
||||||
|
*,
|
||||||
|
title: str,
|
||||||
|
version: str,
|
||||||
|
openapi_version: str = "3.0.2",
|
||||||
|
description: str = None,
|
||||||
|
routes: Sequence[BaseRoute]
|
||||||
|
):
|
||||||
|
info = {"title": title, "version": version}
|
||||||
|
if description:
|
||||||
|
info["description"] = description
|
||||||
|
output = {"openapi": openapi_version, "info": info}
|
||||||
|
components: Dict[str, Dict] = {}
|
||||||
|
paths: Dict[str, Dict] = {}
|
||||||
|
flat_models = get_flat_models_from_routes(routes)
|
||||||
|
model_name_map = get_model_name_map(flat_models)
|
||||||
|
definitions = get_model_definitions(
|
||||||
|
flat_models=flat_models, model_name_map=model_name_map
|
||||||
|
)
|
||||||
|
for route in routes:
|
||||||
|
result = get_openapi_path(route=route, model_name_map=model_name_map)
|
||||||
|
if result:
|
||||||
|
path, security_schemes, path_definitions = result
|
||||||
|
if path:
|
||||||
|
paths.setdefault(route.path, {}).update(path)
|
||||||
|
if security_schemes:
|
||||||
|
components.setdefault("securitySchemes", {}).update(security_schemes)
|
||||||
|
if path_definitions:
|
||||||
|
definitions.update(path_definitions)
|
||||||
|
if definitions:
|
||||||
|
components.setdefault("schemas", {}).update(definitions)
|
||||||
|
if components:
|
||||||
|
output["components"] = components
|
||||||
|
output["paths"] = paths
|
||||||
|
return jsonable_encoder(OpenAPI(**output), by_alias=True, include_none=False)
|
||||||
|
|
||||||
|
|
||||||
|
def get_swagger_ui_html(*, openapi_url: str, title: str):
|
||||||
|
return HTMLResponse(
|
||||||
|
"""
|
||||||
|
<! doctype html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css">
|
||||||
|
<title>
|
||||||
|
""" + title + """
|
||||||
|
</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div id="swagger-ui">
|
||||||
|
</div>
|
||||||
|
<script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
|
||||||
|
<!-- `SwaggerUIBundle` is now available on the page -->
|
||||||
|
<script>
|
||||||
|
|
||||||
|
const ui = SwaggerUIBundle({
|
||||||
|
url: '"""
|
||||||
|
+ openapi_url
|
||||||
|
+ """',
|
||||||
|
dom_id: '#swagger-ui',
|
||||||
|
presets: [
|
||||||
|
SwaggerUIBundle.presets.apis,
|
||||||
|
SwaggerUIBundle.SwaggerUIStandalonePreset
|
||||||
|
],
|
||||||
|
layout: "BaseLayout"
|
||||||
|
|
||||||
|
})
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
""",
|
||||||
|
media_type="text/html",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_redoc_html(*, openapi_url: str, title: str):
|
||||||
|
return HTMLResponse(
|
||||||
|
"""
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>
|
||||||
|
""" + title + """
|
||||||
|
</title>
|
||||||
|
<!-- needed for adaptive design -->
|
||||||
|
<meta charset="utf-8"/>
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||||
|
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
|
||||||
|
|
||||||
|
<!--
|
||||||
|
ReDoc doesn't change outer page styles
|
||||||
|
-->
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<redoc spec-url='""" + openapi_url + """'></redoc>
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
""",
|
||||||
|
media_type="text/html",
|
||||||
|
)
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from typing import Sequence
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Sequence, Any, Dict
|
||||||
|
|
||||||
from pydantic import Schema
|
from pydantic import Schema
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -12,6 +13,7 @@ class ParamTypes(Enum):
|
||||||
|
|
||||||
class Param(Schema):
|
class Param(Schema):
|
||||||
in_: ParamTypes
|
in_: ParamTypes
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
default,
|
default,
|
||||||
|
|
@ -27,7 +29,7 @@ class Param(Schema):
|
||||||
min_length: int = None,
|
min_length: int = None,
|
||||||
max_length: int = None,
|
max_length: int = None,
|
||||||
regex: str = None,
|
regex: str = None,
|
||||||
**extra: object,
|
**extra: Dict[str, Any],
|
||||||
):
|
):
|
||||||
self.deprecated = deprecated
|
self.deprecated = deprecated
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
|
@ -64,7 +66,7 @@ class Path(Param):
|
||||||
min_length: int = None,
|
min_length: int = None,
|
||||||
max_length: int = None,
|
max_length: int = None,
|
||||||
regex: str = None,
|
regex: str = None,
|
||||||
**extra: object,
|
**extra: Dict[str, Any],
|
||||||
):
|
):
|
||||||
self.description = description
|
self.description = description
|
||||||
self.deprecated = deprecated
|
self.deprecated = deprecated
|
||||||
|
|
@ -103,7 +105,7 @@ class Query(Param):
|
||||||
min_length: int = None,
|
min_length: int = None,
|
||||||
max_length: int = None,
|
max_length: int = None,
|
||||||
regex: str = None,
|
regex: str = None,
|
||||||
**extra: object,
|
**extra: Dict[str, Any],
|
||||||
):
|
):
|
||||||
self.description = description
|
self.description = description
|
||||||
self.deprecated = deprecated
|
self.deprecated = deprecated
|
||||||
|
|
@ -141,7 +143,7 @@ class Header(Param):
|
||||||
min_length: int = None,
|
min_length: int = None,
|
||||||
max_length: int = None,
|
max_length: int = None,
|
||||||
regex: str = None,
|
regex: str = None,
|
||||||
**extra: object,
|
**extra: Dict[str, Any],
|
||||||
):
|
):
|
||||||
self.description = description
|
self.description = description
|
||||||
self.deprecated = deprecated
|
self.deprecated = deprecated
|
||||||
|
|
@ -179,7 +181,7 @@ class Cookie(Param):
|
||||||
min_length: int = None,
|
min_length: int = None,
|
||||||
max_length: int = None,
|
max_length: int = None,
|
||||||
regex: str = None,
|
regex: str = None,
|
||||||
**extra: object,
|
**extra: Dict[str, Any],
|
||||||
):
|
):
|
||||||
self.description = description
|
self.description = description
|
||||||
self.deprecated = deprecated
|
self.deprecated = deprecated
|
||||||
|
|
@ -204,7 +206,8 @@ class Body(Schema):
|
||||||
self,
|
self,
|
||||||
default,
|
default,
|
||||||
*,
|
*,
|
||||||
sub_key=False,
|
embed=False,
|
||||||
|
media_type: str = "application/json",
|
||||||
alias: str = None,
|
alias: str = None,
|
||||||
title: str = None,
|
title: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
|
|
@ -215,9 +218,10 @@ class Body(Schema):
|
||||||
min_length: int = None,
|
min_length: int = None,
|
||||||
max_length: int = None,
|
max_length: int = None,
|
||||||
regex: str = None,
|
regex: str = None,
|
||||||
**extra: object,
|
**extra: Dict[str, Any],
|
||||||
):
|
):
|
||||||
self.sub_key = sub_key
|
self.embed = embed
|
||||||
|
self.media_type = media_type
|
||||||
super().__init__(
|
super().__init__(
|
||||||
default,
|
default,
|
||||||
alias=alias,
|
alias=alias,
|
||||||
|
|
@ -234,13 +238,86 @@ class Body(Schema):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Form(Body):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
default,
|
||||||
|
*,
|
||||||
|
sub_key=False,
|
||||||
|
media_type: str = "application/x-www-form-urlencoded",
|
||||||
|
alias: str = None,
|
||||||
|
title: str = None,
|
||||||
|
description: str = None,
|
||||||
|
gt: float = None,
|
||||||
|
ge: float = None,
|
||||||
|
lt: float = None,
|
||||||
|
le: float = None,
|
||||||
|
min_length: int = None,
|
||||||
|
max_length: int = None,
|
||||||
|
regex: str = None,
|
||||||
|
**extra: Dict[str, Any],
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
default,
|
||||||
|
embed=sub_key,
|
||||||
|
media_type=media_type,
|
||||||
|
alias=alias,
|
||||||
|
title=title,
|
||||||
|
description=description,
|
||||||
|
gt=gt,
|
||||||
|
ge=ge,
|
||||||
|
lt=lt,
|
||||||
|
le=le,
|
||||||
|
min_length=min_length,
|
||||||
|
max_length=max_length,
|
||||||
|
regex=regex,
|
||||||
|
**extra,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class File(Form):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
default,
|
||||||
|
*,
|
||||||
|
sub_key=False,
|
||||||
|
media_type: str = "multipart/form-data",
|
||||||
|
alias: str = None,
|
||||||
|
title: str = None,
|
||||||
|
description: str = None,
|
||||||
|
gt: float = None,
|
||||||
|
ge: float = None,
|
||||||
|
lt: float = None,
|
||||||
|
le: float = None,
|
||||||
|
min_length: int = None,
|
||||||
|
max_length: int = None,
|
||||||
|
regex: str = None,
|
||||||
|
**extra: Dict[str, Any],
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
default,
|
||||||
|
embed=sub_key,
|
||||||
|
media_type=media_type,
|
||||||
|
alias=alias,
|
||||||
|
title=title,
|
||||||
|
description=description,
|
||||||
|
gt=gt,
|
||||||
|
ge=ge,
|
||||||
|
lt=lt,
|
||||||
|
le=le,
|
||||||
|
min_length=min_length,
|
||||||
|
max_length=max_length,
|
||||||
|
regex=regex,
|
||||||
|
**extra,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Depends:
|
class Depends:
|
||||||
def __init__(self, dependency = None):
|
def __init__(self, dependency=None):
|
||||||
self.dependency = dependency
|
self.dependency = dependency
|
||||||
|
|
||||||
|
|
||||||
class Security:
|
class Security(Depends):
|
||||||
def __init__(self, security_scheme = None, scopes: Sequence[str] = None):
|
def __init__(self, dependency=None, scopes: Sequence[str] = None):
|
||||||
self.security_scheme = security_scheme
|
self.scopes = scopes or []
|
||||||
self.scopes = scopes
|
super().__init__(dependency=dependency)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,341 +1,66 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import re
|
from typing import Callable, List, Type
|
||||||
import typing
|
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
from starlette import routing
|
from starlette import routing
|
||||||
from starlette.routing import get_name, request_response
|
|
||||||
from starlette.requests import Request
|
|
||||||
from starlette.responses import Response, JSONResponse
|
|
||||||
from starlette.concurrency import run_in_threadpool
|
from starlette.concurrency import run_in_threadpool
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
|
from starlette.formparsers import UploadFile
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse, Response
|
||||||
|
from starlette.routing import get_name, request_response
|
||||||
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
|
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
|
|
||||||
from pydantic.fields import Field, Required
|
|
||||||
from pydantic.schema import get_annotation_from_schema
|
|
||||||
from pydantic import BaseConfig, BaseModel, create_model, Schema
|
|
||||||
from pydantic.error_wrappers import ErrorWrapper, ValidationError
|
|
||||||
from pydantic.errors import MissingError
|
|
||||||
from pydantic.utils import lenient_issubclass
|
|
||||||
from .pydantic_utils import jsonable_encoder
|
|
||||||
|
|
||||||
from fastapi import params
|
from fastapi import params
|
||||||
from fastapi.security.base import SecurityBase
|
from fastapi.dependencies.models import Dependant
|
||||||
|
from fastapi.dependencies.utils import get_body_field, get_dependant, solve_dependencies
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from pydantic import BaseConfig, BaseModel, Schema
|
||||||
|
from pydantic.error_wrappers import ErrorWrapper, ValidationError
|
||||||
|
from pydantic.fields import Field
|
||||||
|
from pydantic.utils import lenient_issubclass
|
||||||
|
|
||||||
|
|
||||||
param_supported_types = (str, int, float, bool)
|
def serialize_response(*, field: Field = None, response):
|
||||||
|
if field:
|
||||||
|
|
||||||
class Dependant:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
path_params: typing.List[Field] = None,
|
|
||||||
query_params: typing.List[Field] = None,
|
|
||||||
header_params: typing.List[Field] = None,
|
|
||||||
cookie_params: typing.List[Field] = None,
|
|
||||||
body_params: typing.List[Field] = None,
|
|
||||||
dependencies: typing.List["Dependant"] = None,
|
|
||||||
security_schemes: typing.List[Field] = None,
|
|
||||||
name: str = None,
|
|
||||||
call: typing.Callable = None,
|
|
||||||
request_param_name: str = None,
|
|
||||||
) -> None:
|
|
||||||
self.path_params: typing.List[Field] = path_params or []
|
|
||||||
self.query_params: typing.List[Field] = query_params or []
|
|
||||||
self.header_params: typing.List[Field] = header_params or []
|
|
||||||
self.cookie_params: typing.List[Field] = cookie_params or []
|
|
||||||
self.body_params: typing.List[Field] = body_params or []
|
|
||||||
self.dependencies: typing.List[Dependant] = dependencies or []
|
|
||||||
self.security_schemes: typing.List[Field] = security_schemes or []
|
|
||||||
self.request_param_name = request_param_name
|
|
||||||
self.name = name
|
|
||||||
self.call: typing.Callable = call
|
|
||||||
|
|
||||||
|
|
||||||
def request_params_to_args(
|
|
||||||
required_params: typing.List[Field], received_params: typing.Dict[str, typing.Any]
|
|
||||||
) -> typing.Tuple[typing.Dict[str, typing.Any], typing.List[ErrorWrapper]]:
|
|
||||||
values = {}
|
|
||||||
errors = []
|
errors = []
|
||||||
for field in required_params:
|
value, errors_ = field.validate(response, {}, loc=("response",))
|
||||||
value = received_params.get(field.alias)
|
|
||||||
if value is None:
|
|
||||||
if field.required:
|
|
||||||
errors.append(
|
|
||||||
ErrorWrapper(MissingError(), loc=field.alias, config=BaseConfig)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
values[field.name] = deepcopy(field.default)
|
|
||||||
continue
|
|
||||||
v_, errors_ = field.validate(
|
|
||||||
value, values, loc=(field.schema.in_.value, field.alias)
|
|
||||||
)
|
|
||||||
if isinstance(errors_, ErrorWrapper):
|
if isinstance(errors_, ErrorWrapper):
|
||||||
errors_: ErrorWrapper
|
|
||||||
errors.append(errors_)
|
errors.append(errors_)
|
||||||
elif isinstance(errors_, list):
|
elif isinstance(errors_, list):
|
||||||
errors.extend(errors_)
|
errors.extend(errors_)
|
||||||
|
if errors:
|
||||||
|
raise ValidationError(errors)
|
||||||
|
return jsonable_encoder(value)
|
||||||
else:
|
else:
|
||||||
values[field.name] = v_
|
return jsonable_encoder(response)
|
||||||
return values, errors
|
|
||||||
|
|
||||||
|
|
||||||
def request_body_to_args(
|
def get_app(
|
||||||
required_params: typing.List[Field], received_body: typing.Dict[str, typing.Any]
|
|
||||||
) -> typing.Tuple[typing.Dict[str, typing.Any], typing.List[ErrorWrapper]]:
|
|
||||||
values = {}
|
|
||||||
errors = []
|
|
||||||
if required_params:
|
|
||||||
field = required_params[0]
|
|
||||||
sub_key = getattr(field.schema, "sub_key", None)
|
|
||||||
if len(required_params) == 1 and not sub_key:
|
|
||||||
received_body = {field.alias: received_body}
|
|
||||||
for field in required_params:
|
|
||||||
value = received_body.get(field.alias)
|
|
||||||
if value is None:
|
|
||||||
if field.required:
|
|
||||||
errors.append(
|
|
||||||
ErrorWrapper(
|
|
||||||
MissingError(), loc=("body", field.alias), config=BaseConfig
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
values[field.name] = deepcopy(field.default)
|
|
||||||
continue
|
|
||||||
|
|
||||||
v_, errors_ = field.validate(value, values, loc=("body", field.alias))
|
|
||||||
if isinstance(errors_, ErrorWrapper):
|
|
||||||
errors_: ErrorWrapper
|
|
||||||
errors.append(errors_)
|
|
||||||
elif isinstance(errors_, list):
|
|
||||||
errors.extend(errors_)
|
|
||||||
else:
|
|
||||||
values[field.name] = v_
|
|
||||||
return values, errors
|
|
||||||
|
|
||||||
|
|
||||||
def add_param_to_fields(
|
|
||||||
*,
|
|
||||||
param: inspect.Parameter,
|
|
||||||
dependant: Dependant,
|
dependant: Dependant,
|
||||||
default_schema=params.Param,
|
body_field: Field = None,
|
||||||
force_type: params.ParamTypes = None,
|
response_code: str = 200,
|
||||||
|
response_wrapper: Type[Response] = JSONResponse,
|
||||||
|
response_field: Type[Field] = None,
|
||||||
):
|
):
|
||||||
default_value = Required
|
is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call)
|
||||||
if not param.default == param.empty:
|
|
||||||
default_value = param.default
|
|
||||||
if isinstance(default_value, params.Param):
|
|
||||||
schema = default_value
|
|
||||||
default_value = schema.default
|
|
||||||
if schema.in_ is None:
|
|
||||||
schema.in_ = default_schema.in_
|
|
||||||
if force_type:
|
|
||||||
schema.in_ = force_type
|
|
||||||
else:
|
|
||||||
schema = default_schema(default_value)
|
|
||||||
required = default_value == Required
|
|
||||||
annotation = typing.Any
|
|
||||||
if not param.annotation == param.empty:
|
|
||||||
annotation = param.annotation
|
|
||||||
annotation = get_annotation_from_schema(annotation, schema)
|
|
||||||
Config = BaseConfig
|
|
||||||
field = Field(
|
|
||||||
name=param.name,
|
|
||||||
type_=annotation,
|
|
||||||
default=None if required else default_value,
|
|
||||||
alias=schema.alias or param.name,
|
|
||||||
required=required,
|
|
||||||
model_config=Config,
|
|
||||||
class_validators=[],
|
|
||||||
schema=schema,
|
|
||||||
)
|
|
||||||
if schema.in_ == params.ParamTypes.path:
|
|
||||||
dependant.path_params.append(field)
|
|
||||||
elif schema.in_ == params.ParamTypes.query:
|
|
||||||
dependant.query_params.append(field)
|
|
||||||
elif schema.in_ == params.ParamTypes.header:
|
|
||||||
dependant.header_params.append(field)
|
|
||||||
else:
|
|
||||||
assert (
|
|
||||||
schema.in_ == params.ParamTypes.cookie
|
|
||||||
), f"non-body parameters must be in path, query, header or cookie: {param.name}"
|
|
||||||
dependant.cookie_params.append(field)
|
|
||||||
|
|
||||||
|
|
||||||
def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant):
|
|
||||||
default_value = Required
|
|
||||||
if not param.default == param.empty:
|
|
||||||
default_value = param.default
|
|
||||||
if isinstance(default_value, Schema):
|
|
||||||
schema = default_value
|
|
||||||
default_value = schema.default
|
|
||||||
else:
|
|
||||||
schema = Schema(default_value)
|
|
||||||
required = default_value == Required
|
|
||||||
annotation = get_annotation_from_schema(param.annotation, schema)
|
|
||||||
field = Field(
|
|
||||||
name=param.name,
|
|
||||||
type_=annotation,
|
|
||||||
default=None if required else default_value,
|
|
||||||
alias=schema.alias or param.name,
|
|
||||||
required=required,
|
|
||||||
model_config=BaseConfig,
|
|
||||||
class_validators=[],
|
|
||||||
schema=schema,
|
|
||||||
)
|
|
||||||
dependant.body_params.append(field)
|
|
||||||
|
|
||||||
|
|
||||||
def get_sub_dependant(
|
|
||||||
*, param: inspect.Parameter, path: str
|
|
||||||
):
|
|
||||||
depends: params.Depends = param.default
|
|
||||||
if depends.dependency:
|
|
||||||
dependency = depends.dependency
|
|
||||||
else:
|
|
||||||
dependency = param.annotation
|
|
||||||
assert callable(dependency)
|
|
||||||
sub_dependant = get_dependant(path=path, call=dependency, name=param.name)
|
|
||||||
if isinstance(dependency, SecurityBase):
|
|
||||||
sub_dependant.security_schemes.append(dependency)
|
|
||||||
return sub_dependant
|
|
||||||
|
|
||||||
|
|
||||||
def get_flat_dependant(dependant: Dependant):
|
|
||||||
flat_dependant = Dependant(
|
|
||||||
path_params=dependant.path_params.copy(),
|
|
||||||
query_params=dependant.query_params.copy(),
|
|
||||||
header_params=dependant.header_params.copy(),
|
|
||||||
cookie_params=dependant.cookie_params.copy(),
|
|
||||||
body_params=dependant.body_params.copy(),
|
|
||||||
security_schemes=dependant.security_schemes.copy(),
|
|
||||||
)
|
|
||||||
for sub_dependant in dependant.dependencies:
|
|
||||||
if sub_dependant is dependant:
|
|
||||||
raise ValueError("recursion", dependant.dependencies)
|
|
||||||
flat_sub = get_flat_dependant(sub_dependant)
|
|
||||||
flat_dependant.path_params.extend(flat_sub.path_params)
|
|
||||||
flat_dependant.query_params.extend(flat_sub.query_params)
|
|
||||||
flat_dependant.header_params.extend(flat_sub.header_params)
|
|
||||||
flat_dependant.cookie_params.extend(flat_sub.cookie_params)
|
|
||||||
flat_dependant.body_params.extend(flat_sub.body_params)
|
|
||||||
flat_dependant.security_schemes.extend(flat_sub.security_schemes)
|
|
||||||
return flat_dependant
|
|
||||||
|
|
||||||
|
|
||||||
def get_path_param_names(path: str):
|
|
||||||
return {item.strip("{}") for item in re.findall("{[^}]*}", path)}
|
|
||||||
|
|
||||||
|
|
||||||
def get_dependant(*, path: str, call: typing.Callable, name: str = None):
|
|
||||||
path_param_names = get_path_param_names(path)
|
|
||||||
endpoint_signature = inspect.signature(call)
|
|
||||||
signature_params = endpoint_signature.parameters
|
|
||||||
dependant = Dependant(call=call, name=name)
|
|
||||||
for param_name in signature_params:
|
|
||||||
param = signature_params[param_name]
|
|
||||||
if isinstance(param.default, params.Depends):
|
|
||||||
sub_dependant = get_sub_dependant(param=param, path=path)
|
|
||||||
dependant.dependencies.append(sub_dependant)
|
|
||||||
for param_name in signature_params:
|
|
||||||
param = signature_params[param_name]
|
|
||||||
if (
|
|
||||||
(param.default == param.empty) or isinstance(param.default, params.Path)
|
|
||||||
) and (param_name in path_param_names):
|
|
||||||
assert lenient_issubclass(
|
|
||||||
param.annotation, param_supported_types
|
|
||||||
), f"Path params must be of type str, int, float or boot: {param}"
|
|
||||||
param = signature_params[param_name]
|
|
||||||
add_param_to_fields(
|
|
||||||
param=param,
|
|
||||||
dependant=dependant,
|
|
||||||
default_schema=params.Path,
|
|
||||||
force_type=params.ParamTypes.path,
|
|
||||||
)
|
|
||||||
elif (param.default == param.empty or param.default is None) and (
|
|
||||||
param.annotation == param.empty
|
|
||||||
or lenient_issubclass(param.annotation, param_supported_types)
|
|
||||||
):
|
|
||||||
add_param_to_fields(
|
|
||||||
param=param, dependant=dependant, default_schema=params.Query
|
|
||||||
)
|
|
||||||
elif isinstance(param.default, params.Param):
|
|
||||||
if param.annotation != param.empty:
|
|
||||||
assert lenient_issubclass(
|
|
||||||
param.annotation, param_supported_types
|
|
||||||
), f"Parameters for Path, Query, Header and Cookies must be of type str, int, float or bool: {param}"
|
|
||||||
add_param_to_fields(
|
|
||||||
param=param, dependant=dependant, default_schema=params.Query
|
|
||||||
)
|
|
||||||
elif lenient_issubclass(param.annotation, Request):
|
|
||||||
dependant.request_param_name = param_name
|
|
||||||
elif not isinstance(param.default, params.Depends):
|
|
||||||
add_param_to_body_fields(param=param, dependant=dependant)
|
|
||||||
return dependant
|
|
||||||
|
|
||||||
|
|
||||||
def is_coroutine_callable(call: typing.Callable):
|
|
||||||
if inspect.isfunction(call):
|
|
||||||
return asyncio.iscoroutinefunction(call)
|
|
||||||
elif inspect.isclass(call):
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
call = getattr(call, "__call__", None)
|
|
||||||
if not call:
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
return asyncio.iscoroutinefunction(call)
|
|
||||||
|
|
||||||
|
|
||||||
async def solve_dependencies(*, request: Request, dependant: Dependant):
|
|
||||||
values = {}
|
|
||||||
errors = []
|
|
||||||
for sub_dependant in dependant.dependencies:
|
|
||||||
sub_values, sub_errors = await solve_dependencies(
|
|
||||||
request=request, dependant=sub_dependant
|
|
||||||
)
|
|
||||||
if sub_errors:
|
|
||||||
return {}, errors
|
|
||||||
if is_coroutine_callable(sub_dependant.call):
|
|
||||||
solved = await sub_dependant.call(**sub_values)
|
|
||||||
else:
|
|
||||||
solved = await run_in_threadpool(sub_dependant.call, **sub_values)
|
|
||||||
values[sub_dependant.name] = solved
|
|
||||||
path_values, path_errors = request_params_to_args(
|
|
||||||
dependant.path_params, request.path_params
|
|
||||||
)
|
|
||||||
query_values, query_errors = request_params_to_args(
|
|
||||||
dependant.query_params, request.query_params
|
|
||||||
)
|
|
||||||
header_values, header_errors = request_params_to_args(
|
|
||||||
dependant.header_params, request.headers
|
|
||||||
)
|
|
||||||
cookie_values, cookie_errors = request_params_to_args(
|
|
||||||
dependant.cookie_params, request.cookies
|
|
||||||
)
|
|
||||||
values.update(path_values)
|
|
||||||
values.update(query_values)
|
|
||||||
values.update(header_values)
|
|
||||||
values.update(cookie_values)
|
|
||||||
errors = path_errors + query_errors + header_errors + cookie_errors
|
|
||||||
if dependant.body_params:
|
|
||||||
body = await request.json()
|
|
||||||
body_values, body_errors = request_body_to_args(dependant.body_params, body)
|
|
||||||
values.update(body_values)
|
|
||||||
errors.extend(body_errors)
|
|
||||||
if dependant.request_param_name:
|
|
||||||
values[dependant.request_param_name] = request
|
|
||||||
return values, errors
|
|
||||||
|
|
||||||
|
|
||||||
def get_app(dependant: Dependant):
|
|
||||||
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
|
|
||||||
|
|
||||||
async def app(request: Request) -> Response:
|
async def app(request: Request) -> Response:
|
||||||
values, errors = await solve_dependencies(request=request, dependant=dependant)
|
body = None
|
||||||
|
if body_field:
|
||||||
|
if isinstance(body_field.schema, params.Form):
|
||||||
|
raw_body = await request.form()
|
||||||
|
body = {}
|
||||||
|
for field, value in raw_body.items():
|
||||||
|
if isinstance(value, UploadFile):
|
||||||
|
body[field] = await value.read()
|
||||||
|
else:
|
||||||
|
body[field] = value
|
||||||
|
else:
|
||||||
|
body = await request.json()
|
||||||
|
values, errors = await solve_dependencies(
|
||||||
|
request=request, dependant=dependant, body=body
|
||||||
|
)
|
||||||
if errors:
|
if errors:
|
||||||
errors_out = ValidationError(errors)
|
errors_out = ValidationError(errors)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -348,36 +73,56 @@ def get_app(dependant: Dependant):
|
||||||
raw_response = await run_in_threadpool(dependant.call, **values)
|
raw_response = await run_in_threadpool(dependant.call, **values)
|
||||||
if isinstance(raw_response, Response):
|
if isinstance(raw_response, Response):
|
||||||
return raw_response
|
return raw_response
|
||||||
else:
|
if isinstance(raw_response, BaseModel):
|
||||||
return JSONResponse(content=jsonable_encoder(raw_response))
|
return response_wrapper(
|
||||||
return app
|
content=jsonable_encoder(raw_response), status_code=response_code
|
||||||
|
|
||||||
|
|
||||||
def get_openapi_params(dependant: Dependant):
|
|
||||||
flat_dependant = get_flat_dependant(dependant)
|
|
||||||
return (
|
|
||||||
flat_dependant.path_params
|
|
||||||
+ flat_dependant.query_params
|
|
||||||
+ flat_dependant.header_params
|
|
||||||
+ flat_dependant.cookie_params
|
|
||||||
)
|
)
|
||||||
|
errors = []
|
||||||
|
try:
|
||||||
|
return response_wrapper(
|
||||||
|
content=serialize_response(
|
||||||
|
field=response_field, response=raw_response
|
||||||
|
),
|
||||||
|
status_code=response_code,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(e)
|
||||||
|
try:
|
||||||
|
response = dict(raw_response)
|
||||||
|
return response_wrapper(
|
||||||
|
content=serialize_response(field=response_field, response=response),
|
||||||
|
status_code=response_code,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(e)
|
||||||
|
try:
|
||||||
|
response = vars(raw_response)
|
||||||
|
return response_wrapper(
|
||||||
|
content=serialize_response(field=response_field, response=response),
|
||||||
|
status_code=response_code,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(e)
|
||||||
|
raise ValueError(errors)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
class APIRoute(routing.Route):
|
class APIRoute(routing.Route):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
endpoint: typing.Callable,
|
endpoint: Callable,
|
||||||
*,
|
*,
|
||||||
methods: typing.List[str] = None,
|
methods: List[str] = None,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
|
|
@ -392,12 +137,12 @@ class APIRoute(routing.Route):
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.name = get_name(endpoint) if name is None else name
|
self.name = get_name(endpoint) if name is None else name
|
||||||
self.include_in_schema = include_in_schema
|
self.include_in_schema = include_in_schema
|
||||||
self.tags = tags
|
self.tags = tags or []
|
||||||
self.summary = summary
|
self.summary = summary
|
||||||
self.description = description
|
self.description = description or self.endpoint.__doc__
|
||||||
self.operation_id = operation_id
|
self.operation_id = operation_id
|
||||||
self.deprecated = deprecated
|
self.deprecated = deprecated
|
||||||
self.request_body: typing.Union[BaseModel, Field, None] = None
|
self.body_field: Field = None
|
||||||
self.response_description = response_description
|
self.response_description = response_description
|
||||||
self.response_code = response_code
|
self.response_code = response_code
|
||||||
self.response_wrapper = response_wrapper
|
self.response_wrapper = response_wrapper
|
||||||
|
|
@ -430,53 +175,32 @@ class APIRoute(routing.Route):
|
||||||
), f"An endpoint must be a function or method"
|
), f"An endpoint must be a function or method"
|
||||||
|
|
||||||
self.dependant = get_dependant(path=path, call=self.endpoint)
|
self.dependant = get_dependant(path=path, call=self.endpoint)
|
||||||
# flat_dependant = get_flat_dependant(self.dependant)
|
self.body_field = get_body_field(dependant=self.dependant, name=self.name)
|
||||||
# path_param_names = get_path_param_names(path)
|
self.app = request_response(
|
||||||
# for path_param in path_param_names:
|
get_app(
|
||||||
# assert path_param in {
|
dependant=self.dependant,
|
||||||
# f.alias for f in flat_dependant.path_params
|
body_field=self.body_field,
|
||||||
# }, f"Path parameter must be defined as a function parameter or be defined by a dependency: {path_param}"
|
response_code=self.response_code,
|
||||||
|
response_wrapper=self.response_wrapper,
|
||||||
if self.dependant.body_params:
|
response_field=self.response_field,
|
||||||
first_param = self.dependant.body_params[0]
|
)
|
||||||
sub_key = getattr(first_param.schema, "sub_key", None)
|
|
||||||
if len(self.dependant.body_params) == 1 and not sub_key:
|
|
||||||
self.request_body = first_param
|
|
||||||
else:
|
|
||||||
model_name = "Body_" + self.name
|
|
||||||
BodyModel = create_model(model_name)
|
|
||||||
for f in self.dependant.body_params:
|
|
||||||
BodyModel.__fields__[f.name] = f
|
|
||||||
required = any(True for f in self.dependant.body_params if f.required)
|
|
||||||
field = Field(
|
|
||||||
name="body",
|
|
||||||
type_=BodyModel,
|
|
||||||
default=None,
|
|
||||||
required=required,
|
|
||||||
model_config=BaseConfig,
|
|
||||||
class_validators=[],
|
|
||||||
alias="body",
|
|
||||||
schema=Schema(None),
|
|
||||||
)
|
)
|
||||||
self.request_body = field
|
|
||||||
|
|
||||||
self.app = request_response(get_app(dependant=self.dependant))
|
|
||||||
|
|
||||||
|
|
||||||
class APIRouter(routing.Router):
|
class APIRouter(routing.Router):
|
||||||
def add_api_route(
|
def add_api_route(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
endpoint: typing.Callable,
|
endpoint: Callable,
|
||||||
methods: typing.List[str] = None,
|
methods: List[str] = None,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
|
|
@ -487,7 +211,7 @@ class APIRouter(routing.Router):
|
||||||
methods=methods,
|
methods=methods,
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
@ -502,27 +226,27 @@ class APIRouter(routing.Router):
|
||||||
def api_route(
|
def api_route(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
methods: typing.List[str] = None,
|
methods: List[str] = None,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
) -> typing.Callable:
|
) -> Callable:
|
||||||
def decorator(func: typing.Callable) -> typing.Callable:
|
def decorator(func: Callable) -> Callable:
|
||||||
self.add_api_route(
|
self.add_api_route(
|
||||||
path,
|
path,
|
||||||
func,
|
func,
|
||||||
methods=methods,
|
methods=methods,
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
@ -541,12 +265,12 @@ class APIRouter(routing.Router):
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
|
|
@ -556,7 +280,7 @@ class APIRouter(routing.Router):
|
||||||
methods=["GET"],
|
methods=["GET"],
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
@ -572,12 +296,12 @@ class APIRouter(routing.Router):
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
|
|
@ -587,7 +311,7 @@ class APIRouter(routing.Router):
|
||||||
methods=["PUT"],
|
methods=["PUT"],
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
@ -603,12 +327,12 @@ class APIRouter(routing.Router):
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
|
|
@ -618,7 +342,7 @@ class APIRouter(routing.Router):
|
||||||
methods=["POST"],
|
methods=["POST"],
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
@ -634,12 +358,12 @@ class APIRouter(routing.Router):
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
|
|
@ -649,7 +373,7 @@ class APIRouter(routing.Router):
|
||||||
methods=["DELETE"],
|
methods=["DELETE"],
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
@ -665,12 +389,12 @@ class APIRouter(routing.Router):
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
|
|
@ -680,7 +404,7 @@ class APIRouter(routing.Router):
|
||||||
methods=["OPTIONS"],
|
methods=["OPTIONS"],
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
@ -696,12 +420,12 @@ class APIRouter(routing.Router):
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
|
|
@ -711,7 +435,7 @@ class APIRouter(routing.Router):
|
||||||
methods=["HEAD"],
|
methods=["HEAD"],
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
@ -727,12 +451,12 @@ class APIRouter(routing.Router):
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
|
|
@ -742,7 +466,7 @@ class APIRouter(routing.Router):
|
||||||
methods=["PATCH"],
|
methods=["PATCH"],
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
@ -758,12 +482,12 @@ class APIRouter(routing.Router):
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
include_in_schema: bool = True,
|
include_in_schema: bool = True,
|
||||||
tags: typing.List[str] = [],
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
operation_id: str = None,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
response_type: typing.Type = None,
|
response_type: Type = None,
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
response_code=200,
|
||||||
response_wrapper=JSONResponse,
|
response_wrapper=JSONResponse,
|
||||||
|
|
@ -773,7 +497,7 @@ class APIRouter(routing.Router):
|
||||||
methods=["TRACE"],
|
methods=["TRACE"],
|
||||||
name=name,
|
name=name,
|
||||||
include_in_schema=include_in_schema,
|
include_in_schema=include_in_schema,
|
||||||
tags=tags,
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,11 @@
|
||||||
from starlette.requests import Request
|
from enum import Enum
|
||||||
|
|
||||||
from pydantic import Schema
|
from pydantic import Schema
|
||||||
from enum import Enum
|
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
from .base import SecurityBase, Types
|
from .base import SecurityBase, Types
|
||||||
|
|
||||||
__all__ = ["APIKeyIn", "APIKeyBase", "APIKeyQuery", "APIKeyHeader", "APIKeyCookie"]
|
|
||||||
|
|
||||||
|
|
||||||
class APIKeyIn(Enum):
|
class APIKeyIn(Enum):
|
||||||
query = "query"
|
query = "query"
|
||||||
header = "header"
|
header = "header"
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pydantic import BaseModel, Schema
|
|
||||||
|
|
||||||
__all__ = ["Types", "SecurityBase"]
|
from pydantic import BaseModel, Schema
|
||||||
|
|
||||||
|
|
||||||
class Types(Enum):
|
class Types(Enum):
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,8 @@
|
||||||
|
from pydantic import Schema
|
||||||
|
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from pydantic import Schema
|
|
||||||
from .base import SecurityBase, Types
|
|
||||||
|
|
||||||
__all__ = ["HTTPBase", "HTTPBasic", "HTTPBearer", "HTTPDigest"]
|
from .base import SecurityBase, Types
|
||||||
|
|
||||||
|
|
||||||
class HTTPBase(SecurityBase):
|
class HTTPBase(SecurityBase):
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,9 @@ from typing import Dict
|
||||||
from pydantic import BaseModel, Schema
|
from pydantic import BaseModel, Schema
|
||||||
|
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from .base import SecurityBase, Types
|
from .base import SecurityBase, Types
|
||||||
|
|
||||||
# __all__ = ["HTTPBase", "HTTPBasic", "HTTPBearer", "HTTPDigest"]
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthFlow(BaseModel):
|
class OAuthFlow(BaseModel):
|
||||||
refreshUrl: str = None
|
refreshUrl: str = None
|
||||||
scopes: Dict[str, str] = {}
|
scopes: Dict[str, str] = {}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ from starlette.requests import Request
|
||||||
|
|
||||||
from .base import SecurityBase, Types
|
from .base import SecurityBase, Types
|
||||||
|
|
||||||
|
|
||||||
class OpenIdConnect(SecurityBase):
|
class OpenIdConnect(SecurityBase):
|
||||||
type_ = Types.openIdConnect
|
type_ = Types.openIdConnect
|
||||||
openIdConnectUrl: str
|
openIdConnectUrl: str
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,46 @@
|
||||||
|
import re
|
||||||
|
from typing import Dict, Sequence, Set, Type
|
||||||
|
|
||||||
|
from starlette.routing import BaseRoute
|
||||||
|
|
||||||
|
from fastapi import routing
|
||||||
|
from fastapi.openapi.constants import REF_PREFIX
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from pydantic.fields import Field
|
||||||
|
from pydantic.schema import get_flat_models_from_fields, model_process_schema
|
||||||
|
|
||||||
|
|
||||||
|
def get_flat_models_from_routes(routes: Sequence[BaseRoute]):
|
||||||
|
body_fields_from_routes = []
|
||||||
|
responses_from_routes = []
|
||||||
|
for route in routes:
|
||||||
|
if route.include_in_schema and isinstance(route, routing.APIRoute):
|
||||||
|
if route.body_field:
|
||||||
|
assert isinstance(
|
||||||
|
route.body_field, Field
|
||||||
|
), "A request body must be a Pydantic Field"
|
||||||
|
body_fields_from_routes.append(route.body_field)
|
||||||
|
if route.response_field:
|
||||||
|
responses_from_routes.append(route.response_field)
|
||||||
|
flat_models = get_flat_models_from_fields(
|
||||||
|
body_fields_from_routes + responses_from_routes
|
||||||
|
)
|
||||||
|
return flat_models
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_definitions(
|
||||||
|
*, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str]
|
||||||
|
):
|
||||||
|
definitions: Dict[str, Dict] = {}
|
||||||
|
for model in flat_models:
|
||||||
|
m_schema, m_definitions = model_process_schema(
|
||||||
|
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
||||||
|
)
|
||||||
|
definitions.update(m_definitions)
|
||||||
|
model_name = model_name_map[model]
|
||||||
|
definitions[model_name] = m_schema
|
||||||
|
return definitions
|
||||||
|
|
||||||
|
|
||||||
|
def get_path_param_names(path: str):
|
||||||
|
return {item.strip("{}") for item in re.findall("{[^}]*}", path)}
|
||||||
Loading…
Reference in New Issue