mirror of https://github.com/tiangolo/fastapi.git
✨ Update parameter names and order
fix mypy types, refactor, lint
This commit is contained in:
parent
addfa89b0f
commit
0e19c24014
|
|
@ -1,19 +1,19 @@
|
||||||
from typing import Any, Callable, Dict, List, Type
|
from typing import Any, Callable, Dict, List, Optional, Type
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
from starlette.exceptions import ExceptionMiddleware, HTTPException
|
from starlette.exceptions import ExceptionMiddleware, HTTPException
|
||||||
from starlette.middleware.errors import ServerErrorMiddleware
|
from starlette.middleware.errors import ServerErrorMiddleware
|
||||||
from starlette.middleware.lifespan import LifespanMiddleware
|
from starlette.middleware.lifespan import LifespanMiddleware
|
||||||
from starlette.responses import JSONResponse
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse, Response
|
||||||
|
|
||||||
from fastapi import routing
|
from fastapi import routing
|
||||||
from fastapi.openapi.utils import get_openapi
|
|
||||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||||
|
from fastapi.openapi.utils import get_openapi
|
||||||
|
|
||||||
|
|
||||||
async def http_exception(request, exc: HTTPException):
|
async def http_exception(request: Request, exc: HTTPException) -> JSONResponse:
|
||||||
print(exc)
|
|
||||||
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
|
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -31,7 +31,7 @@ class FastAPI(Starlette):
|
||||||
**extra: Dict[str, Any],
|
**extra: Dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
self._debug = debug
|
self._debug = debug
|
||||||
self.router = routing.APIRouter()
|
self.router: routing.APIRouter = routing.APIRouter()
|
||||||
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
|
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
|
||||||
self.error_middleware = ServerErrorMiddleware(
|
self.error_middleware = ServerErrorMiddleware(
|
||||||
self.exception_middleware, debug=debug
|
self.exception_middleware, debug=debug
|
||||||
|
|
@ -56,33 +56,41 @@ class FastAPI(Starlette):
|
||||||
|
|
||||||
if self.swagger_ui_url or self.redoc_url:
|
if self.swagger_ui_url or self.redoc_url:
|
||||||
assert self.openapi_url, "The openapi_url is required for the docs"
|
assert self.openapi_url, "The openapi_url is required for the docs"
|
||||||
|
self.openapi_schema: Optional[Dict[str, Any]] = None
|
||||||
self.setup()
|
self.setup()
|
||||||
|
|
||||||
def setup(self):
|
def openapi(self) -> Dict:
|
||||||
|
if not self.openapi_schema:
|
||||||
|
self.openapi_schema = get_openapi(
|
||||||
|
title=self.title,
|
||||||
|
version=self.version,
|
||||||
|
openapi_version=self.openapi_version,
|
||||||
|
description=self.description,
|
||||||
|
routes=self.routes,
|
||||||
|
)
|
||||||
|
return self.openapi_schema
|
||||||
|
|
||||||
|
def setup(self) -> None:
|
||||||
if self.openapi_url:
|
if self.openapi_url:
|
||||||
self.add_route(
|
self.add_route(
|
||||||
self.openapi_url,
|
self.openapi_url,
|
||||||
lambda req: JSONResponse(
|
lambda req: JSONResponse(self.openapi()),
|
||||||
get_openapi(
|
|
||||||
title=self.title,
|
|
||||||
version=self.version,
|
|
||||||
openapi_version=self.openapi_version,
|
|
||||||
description=self.description,
|
|
||||||
routes=self.routes,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
include_in_schema=False,
|
include_in_schema=False,
|
||||||
)
|
)
|
||||||
if self.swagger_ui_url:
|
if self.swagger_ui_url:
|
||||||
self.add_route(
|
self.add_route(
|
||||||
self.swagger_ui_url,
|
self.swagger_ui_url,
|
||||||
lambda r: get_swagger_ui_html(openapi_url=self.openapi_url, title=self.title + " - Swagger UI"),
|
lambda r: get_swagger_ui_html(
|
||||||
|
openapi_url=self.openapi_url, title=self.title + " - Swagger UI"
|
||||||
|
),
|
||||||
include_in_schema=False,
|
include_in_schema=False,
|
||||||
)
|
)
|
||||||
if self.redoc_url:
|
if self.redoc_url:
|
||||||
self.add_route(
|
self.add_route(
|
||||||
self.redoc_url,
|
self.redoc_url,
|
||||||
lambda r: get_redoc_html(openapi_url=self.openapi_url, title=self.title + " - ReDoc"),
|
lambda r: get_redoc_html(
|
||||||
|
openapi_url=self.openapi_url, title=self.title + " - ReDoc"
|
||||||
|
),
|
||||||
include_in_schema=False,
|
include_in_schema=False,
|
||||||
)
|
)
|
||||||
self.add_exception_handler(HTTPException, http_exception)
|
self.add_exception_handler(HTTPException, http_exception)
|
||||||
|
|
@ -91,311 +99,322 @@ class FastAPI(Starlette):
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
endpoint: Callable,
|
endpoint: Callable,
|
||||||
methods: List[str] = None,
|
*,
|
||||||
name: str = None,
|
response_model: Type[BaseModel] = None,
|
||||||
include_in_schema: bool = True,
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
|
methods: List[str] = None,
|
||||||
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.router.add_api_route(
|
self.router.add_api_route(
|
||||||
path,
|
path,
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
methods=methods,
|
response_model=response_model,
|
||||||
name=name,
|
status_code=status_code,
|
||||||
include_in_schema=include_in_schema,
|
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
methods=methods,
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def api_route(
|
def api_route(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
methods: List[str] = None,
|
*,
|
||||||
name: str = None,
|
response_model: Type[BaseModel] = None,
|
||||||
include_in_schema: bool = True,
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
|
methods: List[str] = None,
|
||||||
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
def decorator(func: Callable) -> Callable:
|
def decorator(func: Callable) -> Callable:
|
||||||
self.router.add_api_route(
|
self.router.add_api_route(
|
||||||
path,
|
path,
|
||||||
func,
|
func,
|
||||||
methods=methods,
|
response_model=response_model,
|
||||||
name=name,
|
status_code=status_code,
|
||||||
include_in_schema=include_in_schema,
|
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
methods=methods,
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def include_router(self, router: "APIRouter", *, prefix=""):
|
def include_router(self, router: routing.APIRouter, *, prefix: str = "") -> None:
|
||||||
self.router.include_router(router, prefix=prefix)
|
self.router.include_router(router, prefix=prefix)
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
*,
|
||||||
include_in_schema: bool = True,
|
response_model: Type[BaseModel] = None,
|
||||||
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
):
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
|
) -> Callable:
|
||||||
return self.router.get(
|
return self.router.get(
|
||||||
path=path,
|
path,
|
||||||
name=name,
|
response_model=response_model,
|
||||||
include_in_schema=include_in_schema,
|
status_code=status_code,
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def put(
|
def put(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
*,
|
||||||
include_in_schema: bool = True,
|
response_model: Type[BaseModel] = None,
|
||||||
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
):
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
|
) -> Callable:
|
||||||
return self.router.put(
|
return self.router.put(
|
||||||
path=path,
|
path,
|
||||||
name=name,
|
response_model=response_model,
|
||||||
include_in_schema=include_in_schema,
|
status_code=status_code,
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def post(
|
def post(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
*,
|
||||||
include_in_schema: bool = True,
|
response_model: Type[BaseModel] = None,
|
||||||
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
):
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
|
) -> Callable:
|
||||||
return self.router.post(
|
return self.router.post(
|
||||||
path=path,
|
path,
|
||||||
name=name,
|
response_model=response_model,
|
||||||
include_in_schema=include_in_schema,
|
status_code=status_code,
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete(
|
def delete(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
*,
|
||||||
include_in_schema: bool = True,
|
response_model: Type[BaseModel] = None,
|
||||||
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
):
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
|
) -> Callable:
|
||||||
return self.router.delete(
|
return self.router.delete(
|
||||||
path=path,
|
path,
|
||||||
name=name,
|
response_model=response_model,
|
||||||
include_in_schema=include_in_schema,
|
status_code=status_code,
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def options(
|
def options(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
*,
|
||||||
include_in_schema: bool = True,
|
response_model: Type[BaseModel] = None,
|
||||||
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
):
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
|
) -> Callable:
|
||||||
return self.router.options(
|
return self.router.options(
|
||||||
path=path,
|
path,
|
||||||
name=name,
|
response_model=response_model,
|
||||||
include_in_schema=include_in_schema,
|
status_code=status_code,
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def head(
|
def head(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
*,
|
||||||
include_in_schema: bool = True,
|
response_model: Type[BaseModel] = None,
|
||||||
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
):
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
|
) -> Callable:
|
||||||
return self.router.head(
|
return self.router.head(
|
||||||
path=path,
|
path,
|
||||||
name=name,
|
response_model=response_model,
|
||||||
include_in_schema=include_in_schema,
|
status_code=status_code,
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def patch(
|
def patch(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
*,
|
||||||
include_in_schema: bool = True,
|
response_model: Type[BaseModel] = None,
|
||||||
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
):
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
|
) -> Callable:
|
||||||
return self.router.patch(
|
return self.router.patch(
|
||||||
path=path,
|
path,
|
||||||
name=name,
|
response_model=response_model,
|
||||||
include_in_schema=include_in_schema,
|
status_code=status_code,
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def trace(
|
def trace(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
*,
|
||||||
include_in_schema: bool = True,
|
response_model: Type[BaseModel] = None,
|
||||||
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
):
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
|
) -> Callable:
|
||||||
return self.router.trace(
|
return self.router.trace(
|
||||||
path=path,
|
path,
|
||||||
name=name,
|
response_model=response_model,
|
||||||
include_in_schema=include_in_schema,
|
status_code=status_code,
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,14 @@
|
||||||
from typing import Any, Callable, Dict, List, Sequence, Tuple
|
from typing import Any, Callable, Dict, List, Sequence, Tuple
|
||||||
|
|
||||||
from starlette.concurrency import run_in_threadpool
|
|
||||||
from starlette.requests import Request
|
|
||||||
|
|
||||||
from fastapi.security.base import SecurityBase
|
|
||||||
from pydantic import BaseConfig, Schema
|
from pydantic import BaseConfig, Schema
|
||||||
from pydantic.error_wrappers import ErrorWrapper
|
from pydantic.error_wrappers import ErrorWrapper
|
||||||
from pydantic.errors import MissingError
|
from pydantic.errors import MissingError
|
||||||
from pydantic.fields import Field, Required
|
from pydantic.fields import Field, Required
|
||||||
from pydantic.schema import get_annotation_from_schema
|
from pydantic.schema import get_annotation_from_schema
|
||||||
|
from starlette.concurrency import run_in_threadpool
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
from fastapi.security.base import SecurityBase
|
||||||
|
|
||||||
param_supported_types = (str, int, float, bool)
|
param_supported_types = (str, int, float, bool)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,14 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Callable, Dict, List, Tuple
|
from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Type
|
||||||
|
|
||||||
|
from pydantic import BaseConfig, Schema, create_model
|
||||||
|
from pydantic.error_wrappers import ErrorWrapper
|
||||||
|
from pydantic.errors import MissingError
|
||||||
|
from pydantic.fields import Field, Required
|
||||||
|
from pydantic.schema import get_annotation_from_schema
|
||||||
|
from pydantic.utils import lenient_issubclass
|
||||||
from starlette.concurrency import run_in_threadpool
|
from starlette.concurrency import run_in_threadpool
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
|
@ -10,17 +16,11 @@ from fastapi import params
|
||||||
from fastapi.dependencies.models import Dependant, SecurityRequirement
|
from fastapi.dependencies.models import Dependant, SecurityRequirement
|
||||||
from fastapi.security.base import SecurityBase
|
from fastapi.security.base import SecurityBase
|
||||||
from fastapi.utils import get_path_param_names
|
from fastapi.utils import get_path_param_names
|
||||||
from pydantic import BaseConfig, Schema, create_model
|
|
||||||
from pydantic.error_wrappers import ErrorWrapper
|
|
||||||
from pydantic.errors import MissingError
|
|
||||||
from pydantic.fields import Field, Required
|
|
||||||
from pydantic.schema import get_annotation_from_schema
|
|
||||||
from pydantic.utils import lenient_issubclass
|
|
||||||
|
|
||||||
param_supported_types = (str, int, float, bool)
|
param_supported_types = (str, int, float, bool)
|
||||||
|
|
||||||
|
|
||||||
def get_sub_dependant(*, param: inspect.Parameter, path: str):
|
def get_sub_dependant(*, param: inspect.Parameter, path: str) -> Dependant:
|
||||||
depends: params.Depends = param.default
|
depends: params.Depends = param.default
|
||||||
if depends.dependency:
|
if depends.dependency:
|
||||||
dependency = depends.dependency
|
dependency = depends.dependency
|
||||||
|
|
@ -36,7 +36,7 @@ def get_sub_dependant(*, param: inspect.Parameter, path: str):
|
||||||
return sub_dependant
|
return sub_dependant
|
||||||
|
|
||||||
|
|
||||||
def get_flat_dependant(dependant: Dependant):
|
def get_flat_dependant(dependant: Dependant) -> Dependant:
|
||||||
flat_dependant = Dependant(
|
flat_dependant = Dependant(
|
||||||
path_params=dependant.path_params.copy(),
|
path_params=dependant.path_params.copy(),
|
||||||
query_params=dependant.query_params.copy(),
|
query_params=dependant.query_params.copy(),
|
||||||
|
|
@ -58,7 +58,7 @@ def get_flat_dependant(dependant: Dependant):
|
||||||
return flat_dependant
|
return flat_dependant
|
||||||
|
|
||||||
|
|
||||||
def get_dependant(*, path: str, call: Callable, name: str = None):
|
def get_dependant(*, path: str, call: Callable, name: str = None) -> Dependant:
|
||||||
path_param_names = get_path_param_names(path)
|
path_param_names = get_path_param_names(path)
|
||||||
endpoint_signature = inspect.signature(call)
|
endpoint_signature = inspect.signature(call)
|
||||||
signature_params = endpoint_signature.parameters
|
signature_params = endpoint_signature.parameters
|
||||||
|
|
@ -73,9 +73,10 @@ def get_dependant(*, path: str, call: Callable, name: str = None):
|
||||||
if (
|
if (
|
||||||
(param.default == param.empty) or isinstance(param.default, params.Path)
|
(param.default == param.empty) or isinstance(param.default, params.Path)
|
||||||
) and (param_name in path_param_names):
|
) and (param_name in path_param_names):
|
||||||
assert lenient_issubclass(
|
assert (
|
||||||
param.annotation, param_supported_types
|
lenient_issubclass(param.annotation, param_supported_types)
|
||||||
) or param.annotation == param.empty, f"Path params must be of type str, int, float or boot: {param}"
|
or param.annotation == param.empty
|
||||||
|
), f"Path params must be of type str, int, float or boot: {param}"
|
||||||
param = signature_params[param_name]
|
param = signature_params[param_name]
|
||||||
add_param_to_fields(
|
add_param_to_fields(
|
||||||
param=param,
|
param=param,
|
||||||
|
|
@ -109,9 +110,9 @@ def add_param_to_fields(
|
||||||
*,
|
*,
|
||||||
param: inspect.Parameter,
|
param: inspect.Parameter,
|
||||||
dependant: Dependant,
|
dependant: Dependant,
|
||||||
default_schema=params.Param,
|
default_schema: Type[Schema] = params.Param,
|
||||||
force_type: params.ParamTypes = None,
|
force_type: params.ParamTypes = None,
|
||||||
):
|
) -> None:
|
||||||
default_value = Required
|
default_value = Required
|
||||||
if not param.default == param.empty:
|
if not param.default == param.empty:
|
||||||
default_value = param.default
|
default_value = param.default
|
||||||
|
|
@ -125,15 +126,19 @@ def add_param_to_fields(
|
||||||
else:
|
else:
|
||||||
schema = default_schema(default_value)
|
schema = default_schema(default_value)
|
||||||
required = default_value == Required
|
required = default_value == Required
|
||||||
annotation = Any
|
annotation: Type = Type[Any]
|
||||||
if not param.annotation == param.empty:
|
if not param.annotation == param.empty:
|
||||||
annotation = param.annotation
|
annotation = param.annotation
|
||||||
annotation = get_annotation_from_schema(annotation, schema)
|
annotation = get_annotation_from_schema(annotation, schema)
|
||||||
|
if not schema.alias and getattr(schema, "alias_underscore_to_hyphen", None):
|
||||||
|
alias = param.name.replace("_", "-")
|
||||||
|
else:
|
||||||
|
alias = schema.alias or param.name
|
||||||
field = Field(
|
field = Field(
|
||||||
name=param.name,
|
name=param.name,
|
||||||
type_=annotation,
|
type_=annotation,
|
||||||
default=None if required else default_value,
|
default=None if required else default_value,
|
||||||
alias=schema.alias or param.name,
|
alias=alias,
|
||||||
required=required,
|
required=required,
|
||||||
model_config=BaseConfig(),
|
model_config=BaseConfig(),
|
||||||
class_validators=[],
|
class_validators=[],
|
||||||
|
|
@ -152,7 +157,7 @@ def add_param_to_fields(
|
||||||
dependant.cookie_params.append(field)
|
dependant.cookie_params.append(field)
|
||||||
|
|
||||||
|
|
||||||
def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant):
|
def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant) -> None:
|
||||||
default_value = Required
|
default_value = Required
|
||||||
if not param.default == param.empty:
|
if not param.default == param.empty:
|
||||||
default_value = param.default
|
default_value = param.default
|
||||||
|
|
@ -176,7 +181,7 @@ def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant):
|
||||||
dependant.body_params.append(field)
|
dependant.body_params.append(field)
|
||||||
|
|
||||||
|
|
||||||
def is_coroutine_callable(call: Callable = None):
|
def is_coroutine_callable(call: Callable = None) -> bool:
|
||||||
if not call:
|
if not call:
|
||||||
return False
|
return False
|
||||||
if inspect.isfunction(call):
|
if inspect.isfunction(call):
|
||||||
|
|
@ -191,7 +196,7 @@ def is_coroutine_callable(call: Callable = None):
|
||||||
|
|
||||||
async def solve_dependencies(
|
async def solve_dependencies(
|
||||||
*, request: Request, dependant: Dependant, body: Dict[str, Any] = None
|
*, request: Request, dependant: Dependant, body: Dict[str, Any] = None
|
||||||
):
|
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
|
||||||
values: Dict[str, Any] = {}
|
values: Dict[str, Any] = {}
|
||||||
errors: List[ErrorWrapper] = []
|
errors: List[ErrorWrapper] = []
|
||||||
for sub_dependant in dependant.dependencies:
|
for sub_dependant in dependant.dependencies:
|
||||||
|
|
@ -200,13 +205,13 @@ async def solve_dependencies(
|
||||||
)
|
)
|
||||||
if sub_errors:
|
if sub_errors:
|
||||||
return {}, errors
|
return {}, errors
|
||||||
if sub_dependant.call and is_coroutine_callable(sub_dependant.call):
|
assert sub_dependant.call is not None, "sub_dependant.call must be a function"
|
||||||
|
if is_coroutine_callable(sub_dependant.call):
|
||||||
solved = await sub_dependant.call(**sub_values)
|
solved = await sub_dependant.call(**sub_values)
|
||||||
else:
|
else:
|
||||||
solved = await run_in_threadpool(sub_dependant.call, **sub_values)
|
solved = await run_in_threadpool(sub_dependant.call, **sub_values)
|
||||||
values[
|
assert sub_dependant.name is not None, "Subdependants always have a name"
|
||||||
sub_dependant.name
|
values[sub_dependant.name] = solved
|
||||||
] = solved # type: ignore # Sub-dependants always have a name
|
|
||||||
path_values, path_errors = request_params_to_args(
|
path_values, path_errors = request_params_to_args(
|
||||||
dependant.path_params, request.path_params
|
dependant.path_params, request.path_params
|
||||||
)
|
)
|
||||||
|
|
@ -236,7 +241,7 @@ async def solve_dependencies(
|
||||||
|
|
||||||
|
|
||||||
def request_params_to_args(
|
def request_params_to_args(
|
||||||
required_params: List[Field], received_params: Dict[str, Any]
|
required_params: Sequence[Field], received_params: Mapping[str, Any]
|
||||||
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
|
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
|
||||||
values = {}
|
values = {}
|
||||||
errors = []
|
errors = []
|
||||||
|
|
@ -250,9 +255,9 @@ def request_params_to_args(
|
||||||
else:
|
else:
|
||||||
values[field.name] = deepcopy(field.default)
|
values[field.name] = deepcopy(field.default)
|
||||||
continue
|
continue
|
||||||
v_, errors_ = field.validate(
|
schema: params.Param = field.schema
|
||||||
value, values, loc=(field.schema.in_.value, field.alias)
|
assert isinstance(schema, params.Param), "Params must be subclasses of Param"
|
||||||
)
|
v_, errors_ = field.validate(value, values, loc=(schema.in_.value, field.alias))
|
||||||
if isinstance(errors_, ErrorWrapper):
|
if isinstance(errors_, ErrorWrapper):
|
||||||
errors.append(errors_)
|
errors.append(errors_)
|
||||||
elif isinstance(errors_, list):
|
elif isinstance(errors_, list):
|
||||||
|
|
@ -294,7 +299,7 @@ async def request_body_to_args(
|
||||||
return values, errors
|
return values, errors
|
||||||
|
|
||||||
|
|
||||||
def get_body_field(*, dependant: Dependant, name: str):
|
def get_body_field(*, dependant: Dependant, name: str) -> Field:
|
||||||
flat_dependant = get_flat_dependant(dependant)
|
flat_dependant = get_flat_dependant(dependant)
|
||||||
if not flat_dependant.body_params:
|
if not flat_dependant.body_params:
|
||||||
return None
|
return None
|
||||||
|
|
@ -308,7 +313,7 @@ def get_body_field(*, dependant: Dependant, name: str):
|
||||||
BodyModel.__fields__[f.name] = f
|
BodyModel.__fields__[f.name] = f
|
||||||
required = any(True for f in flat_dependant.body_params if f.required)
|
required = any(True for f in flat_dependant.body_params if f.required)
|
||||||
if any(isinstance(f.schema, params.File) for f in flat_dependant.body_params):
|
if any(isinstance(f.schema, params.File) for f in flat_dependant.body_params):
|
||||||
BodySchema = params.File
|
BodySchema: Type[params.Body] = params.File
|
||||||
elif any(isinstance(f.schema, params.Form) for f in flat_dependant.body_params):
|
elif any(isinstance(f.schema, params.Form) for f in flat_dependant.body_params):
|
||||||
BodySchema = params.Form
|
BodySchema = params.Form
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,18 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from types import GeneratorType
|
from types import GeneratorType
|
||||||
from typing import Set
|
from typing import Any, Set
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.json import pydantic_encoder
|
from pydantic.json import pydantic_encoder
|
||||||
|
|
||||||
|
|
||||||
def jsonable_encoder(
|
def jsonable_encoder(
|
||||||
obj,
|
obj: Any,
|
||||||
include: Set[str] = None,
|
include: Set[str] = None,
|
||||||
exclude: Set[str] = set(),
|
exclude: Set[str] = set(),
|
||||||
by_alias: bool = False,
|
by_alias: bool = False,
|
||||||
include_none=True,
|
include_none: bool = True,
|
||||||
):
|
) -> Any:
|
||||||
if isinstance(obj, BaseModel):
|
if isinstance(obj, BaseModel):
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
obj.dict(include=include, exclude=exclude, by_alias=by_alias),
|
obj.dict(include=include, exclude=exclude, by_alias=by_alias),
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from starlette.responses import HTMLResponse
|
from starlette.responses import HTMLResponse
|
||||||
|
|
||||||
def get_swagger_ui_html(*, openapi_url: str, title: str):
|
|
||||||
|
def get_swagger_ui_html(*, openapi_url: str, title: str) -> HTMLResponse:
|
||||||
return HTMLResponse(
|
return HTMLResponse(
|
||||||
"""
|
"""
|
||||||
<! doctype html>
|
<! doctype html>
|
||||||
|
|
@ -35,12 +36,11 @@ def get_swagger_ui_html(*, openapi_url: str, title: str):
|
||||||
</script>
|
</script>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
""",
|
"""
|
||||||
media_type="text/html",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_redoc_html(*, openapi_url: str, title: str):
|
def get_redoc_html(*, openapi_url: str, title: str) -> HTMLResponse:
|
||||||
return HTMLResponse(
|
return HTMLResponse(
|
||||||
"""
|
"""
|
||||||
<!DOCTYPE html>
|
<!DOCTYPE html>
|
||||||
|
|
@ -73,6 +73,5 @@ def get_redoc_html(*, openapi_url: str, title: str):
|
||||||
<script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script>
|
<script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
""",
|
"""
|
||||||
media_type="text/html",
|
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -7,13 +7,13 @@ from pydantic.types import UrlStr
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import pydantic.types.EmailStr
|
import pydantic.types.EmailStr
|
||||||
from pydantic.types import EmailStr
|
from pydantic.types import EmailStr # type: ignore
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
"email-validator not installed, email fields will be treated as str"
|
"email-validator not installed, email fields will be treated as str"
|
||||||
)
|
)
|
||||||
|
|
||||||
class EmailStr(str):
|
class EmailStr(str): # type: ignore
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -50,7 +50,7 @@ class Server(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class Reference(BaseModel):
|
class Reference(BaseModel):
|
||||||
ref: str = PSchema(..., alias="$ref")
|
ref: str = PSchema(..., alias="$ref") # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class Discriminator(BaseModel):
|
class Discriminator(BaseModel):
|
||||||
|
|
@ -72,28 +72,28 @@ class ExternalDocumentation(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class SchemaBase(BaseModel):
|
class SchemaBase(BaseModel):
|
||||||
ref: Optional[str] = PSchema(None, alias="$ref")
|
ref: Optional[str] = PSchema(None, alias="$ref") # type: ignore
|
||||||
title: Optional[str] = None
|
title: Optional[str] = None
|
||||||
multipleOf: Optional[float] = None
|
multipleOf: Optional[float] = None
|
||||||
maximum: Optional[float] = None
|
maximum: Optional[float] = None
|
||||||
exclusiveMaximum: Optional[float] = None
|
exclusiveMaximum: Optional[float] = None
|
||||||
minimum: Optional[float] = None
|
minimum: Optional[float] = None
|
||||||
exclusiveMinimum: Optional[float] = None
|
exclusiveMinimum: Optional[float] = None
|
||||||
maxLength: Optional[int] = PSchema(None, gte=0)
|
maxLength: Optional[int] = PSchema(None, gte=0) # type: ignore
|
||||||
minLength: Optional[int] = PSchema(None, gte=0)
|
minLength: Optional[int] = PSchema(None, gte=0) # type: ignore
|
||||||
pattern: Optional[str] = None
|
pattern: Optional[str] = None
|
||||||
maxItems: Optional[int] = PSchema(None, gte=0)
|
maxItems: Optional[int] = PSchema(None, gte=0) # type: ignore
|
||||||
minItems: Optional[int] = PSchema(None, gte=0)
|
minItems: Optional[int] = PSchema(None, gte=0) # type: ignore
|
||||||
uniqueItems: Optional[bool] = None
|
uniqueItems: Optional[bool] = None
|
||||||
maxProperties: Optional[int] = PSchema(None, gte=0)
|
maxProperties: Optional[int] = PSchema(None, gte=0) # type: ignore
|
||||||
minProperties: Optional[int] = PSchema(None, gte=0)
|
minProperties: Optional[int] = PSchema(None, gte=0) # type: ignore
|
||||||
required: Optional[List[str]] = None
|
required: Optional[List[str]] = None
|
||||||
enum: Optional[List[str]] = None
|
enum: Optional[List[str]] = None
|
||||||
type: Optional[str] = None
|
type: Optional[str] = None
|
||||||
allOf: Optional[List[Any]] = None
|
allOf: Optional[List[Any]] = None
|
||||||
oneOf: Optional[List[Any]] = None
|
oneOf: Optional[List[Any]] = None
|
||||||
anyOf: Optional[List[Any]] = None
|
anyOf: Optional[List[Any]] = None
|
||||||
not_: Optional[List[Any]] = PSchema(None, alias="not")
|
not_: Optional[List[Any]] = PSchema(None, alias="not") # type: ignore
|
||||||
items: Optional[Any] = None
|
items: Optional[Any] = None
|
||||||
properties: Optional[Dict[str, Any]] = None
|
properties: Optional[Dict[str, Any]] = None
|
||||||
additionalProperties: Optional[Union[bool, Any]] = None
|
additionalProperties: Optional[Union[bool, Any]] = None
|
||||||
|
|
@ -114,7 +114,7 @@ class Schema(SchemaBase):
|
||||||
allOf: Optional[List[SchemaBase]] = None
|
allOf: Optional[List[SchemaBase]] = None
|
||||||
oneOf: Optional[List[SchemaBase]] = None
|
oneOf: Optional[List[SchemaBase]] = None
|
||||||
anyOf: Optional[List[SchemaBase]] = None
|
anyOf: Optional[List[SchemaBase]] = None
|
||||||
not_: Optional[List[SchemaBase]] = PSchema(None, alias="not")
|
not_: Optional[List[SchemaBase]] = PSchema(None, alias="not") # type: ignore
|
||||||
items: Optional[SchemaBase] = None
|
items: Optional[SchemaBase] = None
|
||||||
properties: Optional[Dict[str, SchemaBase]] = None
|
properties: Optional[Dict[str, SchemaBase]] = None
|
||||||
additionalProperties: Optional[Union[bool, SchemaBase]] = None
|
additionalProperties: Optional[Union[bool, SchemaBase]] = None
|
||||||
|
|
@ -144,7 +144,9 @@ class Encoding(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class MediaType(BaseModel):
|
class MediaType(BaseModel):
|
||||||
schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema")
|
schema_: Optional[Union[Schema, Reference]] = PSchema(
|
||||||
|
None, alias="schema"
|
||||||
|
) # type: ignore
|
||||||
example: Optional[Any] = None
|
example: Optional[Any] = None
|
||||||
examples: Optional[Dict[str, Union[Example, Reference]]] = None
|
examples: Optional[Dict[str, Union[Example, Reference]]] = None
|
||||||
encoding: Optional[Dict[str, Encoding]] = None
|
encoding: Optional[Dict[str, Encoding]] = None
|
||||||
|
|
@ -158,7 +160,9 @@ class ParameterBase(BaseModel):
|
||||||
style: Optional[str] = None
|
style: Optional[str] = None
|
||||||
explode: Optional[bool] = None
|
explode: Optional[bool] = None
|
||||||
allowReserved: Optional[bool] = None
|
allowReserved: Optional[bool] = None
|
||||||
schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema")
|
schema_: Optional[Union[Schema, Reference]] = PSchema(
|
||||||
|
None, alias="schema"
|
||||||
|
) # type: ignore
|
||||||
example: Optional[Any] = None
|
example: Optional[Any] = None
|
||||||
examples: Optional[Dict[str, Union[Example, Reference]]] = None
|
examples: Optional[Dict[str, Union[Example, Reference]]] = None
|
||||||
# Serialization rules for more complex scenarios
|
# Serialization rules for more complex scenarios
|
||||||
|
|
@ -167,7 +171,7 @@ class ParameterBase(BaseModel):
|
||||||
|
|
||||||
class Parameter(ParameterBase):
|
class Parameter(ParameterBase):
|
||||||
name: str
|
name: str
|
||||||
in_: ParameterInType = PSchema(..., alias="in")
|
in_: ParameterInType = PSchema(..., alias="in") # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class Header(ParameterBase):
|
class Header(ParameterBase):
|
||||||
|
|
@ -222,7 +226,7 @@ class Operation(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class PathItem(BaseModel):
|
class PathItem(BaseModel):
|
||||||
ref: Optional[str] = PSchema(None, alias="$ref")
|
ref: Optional[str] = PSchema(None, alias="$ref") # type: ignore
|
||||||
summary: Optional[str] = None
|
summary: Optional[str] = None
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
get: Optional[Operation] = None
|
get: Optional[Operation] = None
|
||||||
|
|
@ -250,7 +254,7 @@ class SecuritySchemeType(Enum):
|
||||||
|
|
||||||
|
|
||||||
class SecurityBase(BaseModel):
|
class SecurityBase(BaseModel):
|
||||||
type_: SecuritySchemeType = PSchema(..., alias="type")
|
type_: SecuritySchemeType = PSchema(..., alias="type") # type: ignore
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -261,13 +265,13 @@ class APIKeyIn(Enum):
|
||||||
|
|
||||||
|
|
||||||
class APIKey(SecurityBase):
|
class APIKey(SecurityBase):
|
||||||
type_ = PSchema(SecuritySchemeType.apiKey, alias="type")
|
type_ = PSchema(SecuritySchemeType.apiKey, alias="type") # type: ignore
|
||||||
in_: APIKeyIn = PSchema(..., alias="in")
|
in_: APIKeyIn = PSchema(..., alias="in") # type: ignore
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
class HTTPBase(SecurityBase):
|
class HTTPBase(SecurityBase):
|
||||||
type_ = PSchema(SecuritySchemeType.http, alias="type")
|
type_ = PSchema(SecuritySchemeType.http, alias="type") # type: ignore
|
||||||
scheme: str
|
scheme: str
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -306,12 +310,12 @@ class OAuthFlows(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class OAuth2(SecurityBase):
|
class OAuth2(SecurityBase):
|
||||||
type_ = PSchema(SecuritySchemeType.oauth2, alias="type")
|
type_ = PSchema(SecuritySchemeType.oauth2, alias="type") # type: ignore
|
||||||
flows: OAuthFlows
|
flows: OAuthFlows
|
||||||
|
|
||||||
|
|
||||||
class OpenIdConnect(SecurityBase):
|
class OpenIdConnect(SecurityBase):
|
||||||
type_ = PSchema(SecuritySchemeType.openIdConnect, alias="type")
|
type_ = PSchema(SecuritySchemeType.openIdConnect, alias="type") # type: ignore
|
||||||
openIdConnectUrl: str
|
openIdConnectUrl: str
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,21 @@
|
||||||
from typing import Any, Dict, Sequence, Type, List
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
|
||||||
|
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
from pydantic.schema import field_schema, get_model_name_map
|
from pydantic.schema import Schema, field_schema, get_model_name_map
|
||||||
from pydantic.utils import lenient_issubclass
|
from pydantic.utils import lenient_issubclass
|
||||||
|
|
||||||
from starlette.responses import HTMLResponse, JSONResponse
|
from starlette.responses import HTMLResponse, JSONResponse
|
||||||
from starlette.routing import BaseRoute
|
from starlette.routing import BaseRoute, Route
|
||||||
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
|
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
from fastapi import routing
|
from fastapi import routing
|
||||||
from fastapi.dependencies.models import Dependant
|
from fastapi.dependencies.models import Dependant
|
||||||
from fastapi.dependencies.utils import get_flat_dependant
|
from fastapi.dependencies.utils import get_flat_dependant
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY
|
from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX
|
||||||
from fastapi.openapi.models import OpenAPI
|
from fastapi.openapi.models import OpenAPI
|
||||||
from fastapi.params import Body
|
from fastapi.params import Body, Param
|
||||||
from fastapi.utils import get_flat_models_from_routes, get_model_definitions
|
from fastapi.utils import get_flat_models_from_routes, get_model_definitions
|
||||||
|
|
||||||
|
|
||||||
validation_error_definition = {
|
validation_error_definition = {
|
||||||
"title": "ValidationError",
|
"title": "ValidationError",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
|
@ -42,7 +40,7 @@ validation_error_response_definition = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_openapi_params(dependant: Dependant):
|
def get_openapi_params(dependant: Dependant) -> List[Field]:
|
||||||
flat_dependant = get_flat_dependant(dependant)
|
flat_dependant = get_flat_dependant(dependant)
|
||||||
return (
|
return (
|
||||||
flat_dependant.path_params
|
flat_dependant.path_params
|
||||||
|
|
@ -52,7 +50,7 @@ def get_openapi_params(dependant: Dependant):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_openapi_security_definitions(flat_dependant: Dependant):
|
def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]:
|
||||||
security_definitions = {}
|
security_definitions = {}
|
||||||
operation_security = []
|
operation_security = []
|
||||||
for security_requirement in flat_dependant.security_requirements:
|
for security_requirement in flat_dependant.security_requirements:
|
||||||
|
|
@ -61,59 +59,60 @@ def get_openapi_security_definitions(flat_dependant: Dependant):
|
||||||
by_alias=True,
|
by_alias=True,
|
||||||
include_none=False,
|
include_none=False,
|
||||||
)
|
)
|
||||||
security_name = (
|
security_name = security_requirement.security_scheme.scheme_name
|
||||||
security_requirement.security_scheme.scheme_name
|
|
||||||
|
|
||||||
)
|
|
||||||
security_definitions[security_name] = security_definition
|
security_definitions[security_name] = security_definition
|
||||||
operation_security.append({security_name: security_requirement.scopes})
|
operation_security.append({security_name: security_requirement.scopes})
|
||||||
return security_definitions, operation_security
|
return security_definitions, operation_security
|
||||||
|
|
||||||
|
|
||||||
def get_openapi_operation_parameters(all_route_params: List[Field]):
|
def get_openapi_operation_parameters(
|
||||||
|
all_route_params: Sequence[Field]
|
||||||
|
) -> Tuple[Dict[str, Dict], List[Dict[str, Any]]]:
|
||||||
definitions: Dict[str, Dict] = {}
|
definitions: Dict[str, Dict] = {}
|
||||||
parameters = []
|
parameters = []
|
||||||
for param in all_route_params:
|
for param in all_route_params:
|
||||||
|
schema: Param = param.schema
|
||||||
if "ValidationError" not in definitions:
|
if "ValidationError" not in definitions:
|
||||||
definitions["ValidationError"] = validation_error_definition
|
definitions["ValidationError"] = validation_error_definition
|
||||||
definitions["HTTPValidationError"] = validation_error_response_definition
|
definitions["HTTPValidationError"] = validation_error_response_definition
|
||||||
parameter = {
|
parameter = {
|
||||||
"name": param.alias,
|
"name": param.alias,
|
||||||
"in": param.schema.in_.value,
|
"in": schema.in_.value,
|
||||||
"required": param.required,
|
"required": param.required,
|
||||||
"schema": field_schema(param, model_name_map={})[0],
|
"schema": field_schema(param, model_name_map={})[0],
|
||||||
}
|
}
|
||||||
if param.schema.description:
|
if schema.description:
|
||||||
parameter["description"] = param.schema.description
|
parameter["description"] = schema.description
|
||||||
if param.schema.deprecated:
|
if schema.deprecated:
|
||||||
parameter["deprecated"] = param.schema.deprecated
|
parameter["deprecated"] = schema.deprecated
|
||||||
parameters.append(parameter)
|
parameters.append(parameter)
|
||||||
return definitions, parameters
|
return definitions, parameters
|
||||||
|
|
||||||
|
|
||||||
def get_openapi_operation_request_body(
|
def get_openapi_operation_request_body(
|
||||||
*, body_field: Field, model_name_map: Dict[Type, str]
|
*, body_field: Field, model_name_map: Dict[Type, str]
|
||||||
):
|
) -> Optional[Dict]:
|
||||||
if not body_field:
|
if not body_field:
|
||||||
return None
|
return None
|
||||||
assert isinstance(body_field, Field)
|
assert isinstance(body_field, Field)
|
||||||
body_schema, _ = field_schema(
|
body_schema, _ = field_schema(
|
||||||
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
||||||
)
|
)
|
||||||
if isinstance(body_field.schema, Body):
|
schema: Schema = body_field.schema
|
||||||
request_media_type = body_field.schema.media_type
|
if isinstance(schema, Body):
|
||||||
|
request_media_type = schema.media_type
|
||||||
else:
|
else:
|
||||||
# Includes not declared media types (Schema)
|
# Includes not declared media types (Schema)
|
||||||
request_media_type = "application/json"
|
request_media_type = "application/json"
|
||||||
required = body_field.required
|
required = body_field.required
|
||||||
request_body_oai = {}
|
request_body_oai: Dict[str, Any] = {}
|
||||||
if required:
|
if required:
|
||||||
request_body_oai["required"] = required
|
request_body_oai["required"] = required
|
||||||
request_body_oai["content"] = {request_media_type: {"schema": body_schema}}
|
request_body_oai["content"] = {request_media_type: {"schema": body_schema}}
|
||||||
return request_body_oai
|
return request_body_oai
|
||||||
|
|
||||||
|
|
||||||
def generate_operation_id(*, route: routing.APIRoute, method: str):
|
def generate_operation_id(*, route: routing.APIRoute, method: str) -> str:
|
||||||
if route.operation_id:
|
if route.operation_id:
|
||||||
return route.operation_id
|
return route.operation_id
|
||||||
path: str = route.path
|
path: str = route.path
|
||||||
|
|
@ -123,12 +122,13 @@ def generate_operation_id(*, route: routing.APIRoute, method: str):
|
||||||
return operation_id
|
return operation_id
|
||||||
|
|
||||||
|
|
||||||
def generate_operation_summary(*, route: routing.APIRoute, method: str):
|
def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
|
||||||
if route.summary:
|
if route.summary:
|
||||||
return route.summary
|
return route.summary
|
||||||
return method.title() + " " + route.name.replace("_", " ").title()
|
return method.title() + " " + route.name.replace("_", " ").title()
|
||||||
|
|
||||||
def get_openapi_operation_metadata(*, route: BaseRoute, method: str):
|
|
||||||
|
def get_openapi_operation_metadata(*, route: routing.APIRoute, method: str) -> Dict:
|
||||||
operation: Dict[str, Any] = {}
|
operation: Dict[str, Any] = {}
|
||||||
if route.tags:
|
if route.tags:
|
||||||
operation["tags"] = route.tags
|
operation["tags"] = route.tags
|
||||||
|
|
@ -141,12 +141,13 @@ def get_openapi_operation_metadata(*, route: BaseRoute, method: str):
|
||||||
return operation
|
return operation
|
||||||
|
|
||||||
|
|
||||||
def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
|
def get_openapi_path(
|
||||||
if not (route.include_in_schema and isinstance(route, routing.APIRoute)):
|
*, route: routing.APIRoute, model_name_map: Dict[Type, str]
|
||||||
return None
|
) -> Tuple[Dict, Dict, Dict]:
|
||||||
path = {}
|
path = {}
|
||||||
security_schemes: Dict[str, Any] = {}
|
security_schemes: Dict[str, Any] = {}
|
||||||
definitions: Dict[str, Any] = {}
|
definitions: Dict[str, Any] = {}
|
||||||
|
assert route.methods is not None, "Methods must be a list"
|
||||||
for method in route.methods:
|
for method in route.methods:
|
||||||
operation = get_openapi_operation_metadata(route=route, method=method)
|
operation = get_openapi_operation_metadata(route=route, method=method)
|
||||||
parameters: List[Dict] = []
|
parameters: List[Dict] = []
|
||||||
|
|
@ -172,10 +173,9 @@ def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
|
||||||
)
|
)
|
||||||
if request_body_oai:
|
if request_body_oai:
|
||||||
operation["requestBody"] = request_body_oai
|
operation["requestBody"] = request_body_oai
|
||||||
response_code = str(route.response_code)
|
status_code = str(route.status_code)
|
||||||
response_schema = {"type": "string"}
|
response_schema = {"type": "string"}
|
||||||
if lenient_issubclass(route.response_wrapper, JSONResponse):
|
if lenient_issubclass(route.content_type, JSONResponse):
|
||||||
response_media_type = "application/json"
|
|
||||||
if route.response_field:
|
if route.response_field:
|
||||||
response_schema, _ = field_schema(
|
response_schema, _ = field_schema(
|
||||||
route.response_field,
|
route.response_field,
|
||||||
|
|
@ -184,16 +184,9 @@ def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response_schema = {}
|
response_schema = {}
|
||||||
elif lenient_issubclass(route.response_wrapper, HTMLResponse):
|
content = {route.content_type.media_type: {"schema": response_schema}}
|
||||||
response_media_type = "text/html"
|
|
||||||
else:
|
|
||||||
response_media_type = "text/plain"
|
|
||||||
content = {response_media_type: {"schema": response_schema}}
|
|
||||||
operation["responses"] = {
|
operation["responses"] = {
|
||||||
response_code: {
|
status_code: {"description": route.response_description, "content": content}
|
||||||
"description": route.response_description,
|
|
||||||
"content": content,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if all_route_params or route.body_field:
|
if all_route_params or route.body_field:
|
||||||
operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
|
operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
|
||||||
|
|
@ -215,7 +208,7 @@ def get_openapi(
|
||||||
openapi_version: str = "3.0.2",
|
openapi_version: str = "3.0.2",
|
||||||
description: str = None,
|
description: str = None,
|
||||||
routes: Sequence[BaseRoute]
|
routes: Sequence[BaseRoute]
|
||||||
):
|
) -> Dict:
|
||||||
info = {"title": title, "version": version}
|
info = {"title": title, "version": version}
|
||||||
if description:
|
if description:
|
||||||
info["description"] = description
|
info["description"] = description
|
||||||
|
|
@ -228,15 +221,18 @@ def get_openapi(
|
||||||
flat_models=flat_models, model_name_map=model_name_map
|
flat_models=flat_models, model_name_map=model_name_map
|
||||||
)
|
)
|
||||||
for route in routes:
|
for route in routes:
|
||||||
result = get_openapi_path(route=route, model_name_map=model_name_map)
|
if isinstance(route, routing.APIRoute):
|
||||||
if result:
|
result = get_openapi_path(route=route, model_name_map=model_name_map)
|
||||||
path, security_schemes, path_definitions = result
|
if result:
|
||||||
if path:
|
path, security_schemes, path_definitions = result
|
||||||
paths.setdefault(route.path, {}).update(path)
|
if path:
|
||||||
if security_schemes:
|
paths.setdefault(route.path, {}).update(path)
|
||||||
components.setdefault("securitySchemes", {}).update(security_schemes)
|
if security_schemes:
|
||||||
if path_definitions:
|
components.setdefault("securitySchemes", {}).update(
|
||||||
definitions.update(path_definitions)
|
security_schemes
|
||||||
|
)
|
||||||
|
if path_definitions:
|
||||||
|
definitions.update(path_definitions)
|
||||||
if definitions:
|
if definitions:
|
||||||
components.setdefault("schemas", {}).update(definitions)
|
components.setdefault("schemas", {}).update(definitions)
|
||||||
if components:
|
if components:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Sequence, Any, Dict
|
from typing import Any, Callable, Sequence
|
||||||
|
|
||||||
from pydantic import Schema
|
from pydantic import Schema
|
||||||
|
|
||||||
|
|
@ -16,7 +16,7 @@ class Param(Schema):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
default,
|
default: Any,
|
||||||
*,
|
*,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
alias: str = None,
|
alias: str = None,
|
||||||
|
|
@ -29,7 +29,7 @@ class Param(Schema):
|
||||||
min_length: int = None,
|
min_length: int = None,
|
||||||
max_length: int = None,
|
max_length: int = None,
|
||||||
regex: str = None,
|
regex: str = None,
|
||||||
**extra: Dict[str, Any],
|
**extra: Any,
|
||||||
):
|
):
|
||||||
self.deprecated = deprecated
|
self.deprecated = deprecated
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
|
@ -53,7 +53,7 @@ class Path(Param):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
default,
|
default: Any,
|
||||||
*,
|
*,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
alias: str = None,
|
alias: str = None,
|
||||||
|
|
@ -66,7 +66,7 @@ class Path(Param):
|
||||||
min_length: int = None,
|
min_length: int = None,
|
||||||
max_length: int = None,
|
max_length: int = None,
|
||||||
regex: str = None,
|
regex: str = None,
|
||||||
**extra: Dict[str, Any],
|
**extra: Any,
|
||||||
):
|
):
|
||||||
self.description = description
|
self.description = description
|
||||||
self.deprecated = deprecated
|
self.deprecated = deprecated
|
||||||
|
|
@ -92,7 +92,7 @@ class Query(Param):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
default,
|
default: Any,
|
||||||
*,
|
*,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
alias: str = None,
|
alias: str = None,
|
||||||
|
|
@ -105,7 +105,7 @@ class Query(Param):
|
||||||
min_length: int = None,
|
min_length: int = None,
|
||||||
max_length: int = None,
|
max_length: int = None,
|
||||||
regex: str = None,
|
regex: str = None,
|
||||||
**extra: Dict[str, Any],
|
**extra: Any,
|
||||||
):
|
):
|
||||||
self.description = description
|
self.description = description
|
||||||
self.deprecated = deprecated
|
self.deprecated = deprecated
|
||||||
|
|
@ -130,10 +130,11 @@ class Header(Param):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
default,
|
default: Any,
|
||||||
*,
|
*,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
alias: str = None,
|
alias: str = None,
|
||||||
|
alias_underscore_to_hyphen: bool = True,
|
||||||
title: str = None,
|
title: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
gt: float = None,
|
gt: float = None,
|
||||||
|
|
@ -143,10 +144,11 @@ class Header(Param):
|
||||||
min_length: int = None,
|
min_length: int = None,
|
||||||
max_length: int = None,
|
max_length: int = None,
|
||||||
regex: str = None,
|
regex: str = None,
|
||||||
**extra: Dict[str, Any],
|
**extra: Any,
|
||||||
):
|
):
|
||||||
self.description = description
|
self.description = description
|
||||||
self.deprecated = deprecated
|
self.deprecated = deprecated
|
||||||
|
self.alias_underscore_to_hyphen = alias_underscore_to_hyphen
|
||||||
super().__init__(
|
super().__init__(
|
||||||
default,
|
default,
|
||||||
alias=alias,
|
alias=alias,
|
||||||
|
|
@ -168,7 +170,7 @@ class Cookie(Param):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
default,
|
default: Any,
|
||||||
*,
|
*,
|
||||||
deprecated: bool = None,
|
deprecated: bool = None,
|
||||||
alias: str = None,
|
alias: str = None,
|
||||||
|
|
@ -181,7 +183,7 @@ class Cookie(Param):
|
||||||
min_length: int = None,
|
min_length: int = None,
|
||||||
max_length: int = None,
|
max_length: int = None,
|
||||||
regex: str = None,
|
regex: str = None,
|
||||||
**extra: Dict[str, Any],
|
**extra: Any,
|
||||||
):
|
):
|
||||||
self.description = description
|
self.description = description
|
||||||
self.deprecated = deprecated
|
self.deprecated = deprecated
|
||||||
|
|
@ -204,9 +206,9 @@ class Cookie(Param):
|
||||||
class Body(Schema):
|
class Body(Schema):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
default,
|
default: Any,
|
||||||
*,
|
*,
|
||||||
embed=False,
|
embed: bool = False,
|
||||||
media_type: str = "application/json",
|
media_type: str = "application/json",
|
||||||
alias: str = None,
|
alias: str = None,
|
||||||
title: str = None,
|
title: str = None,
|
||||||
|
|
@ -218,7 +220,7 @@ class Body(Schema):
|
||||||
min_length: int = None,
|
min_length: int = None,
|
||||||
max_length: int = None,
|
max_length: int = None,
|
||||||
regex: str = None,
|
regex: str = None,
|
||||||
**extra: Dict[str, Any],
|
**extra: Any,
|
||||||
):
|
):
|
||||||
self.embed = embed
|
self.embed = embed
|
||||||
self.media_type = media_type
|
self.media_type = media_type
|
||||||
|
|
@ -241,9 +243,9 @@ class Body(Schema):
|
||||||
class Form(Body):
|
class Form(Body):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
default,
|
default: Any,
|
||||||
*,
|
*,
|
||||||
sub_key=False,
|
sub_key: bool = False,
|
||||||
media_type: str = "application/x-www-form-urlencoded",
|
media_type: str = "application/x-www-form-urlencoded",
|
||||||
alias: str = None,
|
alias: str = None,
|
||||||
title: str = None,
|
title: str = None,
|
||||||
|
|
@ -255,7 +257,7 @@ class Form(Body):
|
||||||
min_length: int = None,
|
min_length: int = None,
|
||||||
max_length: int = None,
|
max_length: int = None,
|
||||||
regex: str = None,
|
regex: str = None,
|
||||||
**extra: Dict[str, Any],
|
**extra: Any,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
default,
|
default,
|
||||||
|
|
@ -278,9 +280,9 @@ class Form(Body):
|
||||||
class File(Form):
|
class File(Form):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
default,
|
default: Any,
|
||||||
*,
|
*,
|
||||||
sub_key=False,
|
sub_key: bool = False,
|
||||||
media_type: str = "multipart/form-data",
|
media_type: str = "multipart/form-data",
|
||||||
alias: str = None,
|
alias: str = None,
|
||||||
title: str = None,
|
title: str = None,
|
||||||
|
|
@ -292,7 +294,7 @@ class File(Form):
|
||||||
min_length: int = None,
|
min_length: int = None,
|
||||||
max_length: int = None,
|
max_length: int = None,
|
||||||
regex: str = None,
|
regex: str = None,
|
||||||
**extra: Dict[str, Any],
|
**extra: Any,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
default,
|
default,
|
||||||
|
|
@ -313,11 +315,11 @@ class File(Form):
|
||||||
|
|
||||||
|
|
||||||
class Depends:
|
class Depends:
|
||||||
def __init__(self, dependency=None):
|
def __init__(self, dependency: Callable = None):
|
||||||
self.dependency = dependency
|
self.dependency = dependency
|
||||||
|
|
||||||
|
|
||||||
class Security(Depends):
|
class Security(Depends):
|
||||||
def __init__(self, dependency=None, scopes: Sequence[str] = None):
|
def __init__(self, dependency: Callable = None, scopes: Sequence[str] = None):
|
||||||
self.scopes = scopes or []
|
self.scopes = scopes or []
|
||||||
super().__init__(dependency=dependency)
|
super().__init__(dependency=dependency)
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,11 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Callable, List, Type
|
from typing import Any, Callable, List, Optional, Type
|
||||||
|
|
||||||
from pydantic import BaseConfig, BaseModel, Schema
|
from pydantic import BaseConfig, BaseModel, Schema
|
||||||
from pydantic.error_wrappers import ErrorWrapper, ValidationError
|
from pydantic.error_wrappers import ErrorWrapper, ValidationError
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
from pydantic.utils import lenient_issubclass
|
from pydantic.utils import lenient_issubclass
|
||||||
|
|
||||||
from starlette import routing
|
from starlette import routing
|
||||||
from starlette.concurrency import run_in_threadpool
|
from starlette.concurrency import run_in_threadpool
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
|
|
@ -22,7 +21,7 @@ from fastapi.dependencies.utils import get_body_field, get_dependant, solve_depe
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
|
||||||
|
|
||||||
def serialize_response(*, field: Field = None, response):
|
def serialize_response(*, field: Field = None, response: Response) -> Any:
|
||||||
if field:
|
if field:
|
||||||
errors = []
|
errors = []
|
||||||
value, errors_ = field.validate(response, {}, loc=("response",))
|
value, errors_ = field.validate(response, {}, loc=("response",))
|
||||||
|
|
@ -40,11 +39,12 @@ def serialize_response(*, field: Field = None, response):
|
||||||
def get_app(
|
def get_app(
|
||||||
dependant: Dependant,
|
dependant: Dependant,
|
||||||
body_field: Field = None,
|
body_field: Field = None,
|
||||||
response_code: str = 200,
|
status_code: int = 200,
|
||||||
response_wrapper: Type[Response] = JSONResponse,
|
content_type: Type[Response] = JSONResponse,
|
||||||
response_field: Type[Field] = None,
|
response_field: Field = None,
|
||||||
):
|
) -> Callable:
|
||||||
is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call)
|
assert dependant.call is not None, "dependant.call must me a function"
|
||||||
|
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
|
||||||
is_body_form = body_field and isinstance(body_field.schema, params.Form)
|
is_body_form = body_field and isinstance(body_field.schema, params.Form)
|
||||||
|
|
||||||
async def app(request: Request) -> Response:
|
async def app(request: Request) -> Response:
|
||||||
|
|
@ -69,6 +69,7 @@ def get_app(
|
||||||
status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors()
|
status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors()
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
assert dependant.call is not None, "dependant.call must me a function"
|
||||||
if is_coroutine:
|
if is_coroutine:
|
||||||
raw_response = await dependant.call(**values)
|
raw_response = await dependant.call(**values)
|
||||||
else:
|
else:
|
||||||
|
|
@ -76,32 +77,32 @@ def get_app(
|
||||||
if isinstance(raw_response, Response):
|
if isinstance(raw_response, Response):
|
||||||
return raw_response
|
return raw_response
|
||||||
if isinstance(raw_response, BaseModel):
|
if isinstance(raw_response, BaseModel):
|
||||||
return response_wrapper(
|
return content_type(
|
||||||
content=jsonable_encoder(raw_response), status_code=response_code
|
content=jsonable_encoder(raw_response), status_code=status_code
|
||||||
)
|
)
|
||||||
errors = []
|
errors = []
|
||||||
try:
|
try:
|
||||||
return response_wrapper(
|
return content_type(
|
||||||
content=serialize_response(
|
content=serialize_response(
|
||||||
field=response_field, response=raw_response
|
field=response_field, response=raw_response
|
||||||
),
|
),
|
||||||
status_code=response_code,
|
status_code=status_code,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors.append(e)
|
errors.append(e)
|
||||||
try:
|
try:
|
||||||
response = dict(raw_response)
|
response = dict(raw_response)
|
||||||
return response_wrapper(
|
return content_type(
|
||||||
content=serialize_response(field=response_field, response=response),
|
content=serialize_response(field=response_field, response=response),
|
||||||
status_code=response_code,
|
status_code=status_code,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors.append(e)
|
errors.append(e)
|
||||||
try:
|
try:
|
||||||
response = vars(raw_response)
|
response = vars(raw_response)
|
||||||
return response_wrapper(
|
return content_type(
|
||||||
content=serialize_response(field=response_field, response=response),
|
content=serialize_response(field=response_field, response=response),
|
||||||
status_code=response_code,
|
status_code=status_code,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors.append(e)
|
errors.append(e)
|
||||||
|
|
@ -116,43 +117,32 @@ class APIRoute(routing.Route):
|
||||||
path: str,
|
path: str,
|
||||||
endpoint: Callable,
|
endpoint: Callable,
|
||||||
*,
|
*,
|
||||||
methods: List[str] = None,
|
response_model: Type[BaseModel] = None,
|
||||||
name: str = None,
|
status_code: int = 200,
|
||||||
include_in_schema: bool = True,
|
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
|
methods: List[str] = None,
|
||||||
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert path.startswith("/"), "Routed paths must always start with '/'"
|
assert path.startswith("/"), "Routed paths must always start with '/'"
|
||||||
self.path = path
|
self.path = path
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.name = get_name(endpoint) if name is None else name
|
self.name = get_name(endpoint) if name is None else name
|
||||||
self.include_in_schema = include_in_schema
|
self.response_model = response_model
|
||||||
self.tags = tags or []
|
if self.response_model:
|
||||||
self.summary = summary
|
|
||||||
self.description = description or self.endpoint.__doc__
|
|
||||||
self.operation_id = operation_id
|
|
||||||
self.deprecated = deprecated
|
|
||||||
self.body_field: Field = None
|
|
||||||
self.response_description = response_description
|
|
||||||
self.response_code = response_code
|
|
||||||
self.response_wrapper = response_wrapper
|
|
||||||
self.response_field = None
|
|
||||||
if response_type:
|
|
||||||
assert lenient_issubclass(
|
assert lenient_issubclass(
|
||||||
response_wrapper, JSONResponse
|
content_type, JSONResponse
|
||||||
), "To declare a type the response must be a JSON response"
|
), "To declare a type the response must be a JSON response"
|
||||||
self.response_type = response_type
|
|
||||||
response_name = "Response_" + self.name
|
response_name = "Response_" + self.name
|
||||||
self.response_field = Field(
|
self.response_field: Optional[Field] = Field(
|
||||||
name=response_name,
|
name=response_name,
|
||||||
type_=self.response_type,
|
type_=self.response_model,
|
||||||
class_validators=[],
|
class_validators=[],
|
||||||
default=None,
|
default=None,
|
||||||
required=False,
|
required=False,
|
||||||
|
|
@ -160,25 +150,34 @@ class APIRoute(routing.Route):
|
||||||
schema=Schema(None),
|
schema=Schema(None),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.response_type = None
|
self.response_field = None
|
||||||
|
self.status_code = status_code
|
||||||
|
self.tags = tags or []
|
||||||
|
self.summary = summary
|
||||||
|
self.description = description or self.endpoint.__doc__
|
||||||
|
self.response_description = response_description
|
||||||
|
self.deprecated = deprecated
|
||||||
if methods is None:
|
if methods is None:
|
||||||
methods = ["GET"]
|
methods = ["GET"]
|
||||||
self.methods = methods
|
self.methods = methods
|
||||||
|
self.operation_id = operation_id
|
||||||
|
self.include_in_schema = include_in_schema
|
||||||
|
self.content_type = content_type
|
||||||
|
|
||||||
self.path_regex, self.path_format, self.param_convertors = self.compile_path(
|
self.path_regex, self.path_format, self.param_convertors = self.compile_path(
|
||||||
path
|
path
|
||||||
)
|
)
|
||||||
assert inspect.isfunction(endpoint) or inspect.ismethod(
|
assert inspect.isfunction(endpoint) or inspect.ismethod(
|
||||||
endpoint
|
endpoint
|
||||||
), f"An endpoint must be a function or method"
|
), f"An endpoint must be a function or method"
|
||||||
|
|
||||||
self.dependant = get_dependant(path=path, call=self.endpoint)
|
self.dependant = get_dependant(path=path, call=self.endpoint)
|
||||||
self.body_field = get_body_field(dependant=self.dependant, name=self.name)
|
self.body_field = get_body_field(dependant=self.dependant, name=self.name)
|
||||||
self.app = request_response(
|
self.app = request_response(
|
||||||
get_app(
|
get_app(
|
||||||
dependant=self.dependant,
|
dependant=self.dependant,
|
||||||
body_field=self.body_field,
|
body_field=self.body_field,
|
||||||
response_code=self.response_code,
|
status_code=self.status_code,
|
||||||
response_wrapper=self.response_wrapper,
|
content_type=self.content_type,
|
||||||
response_field=self.response_field,
|
response_field=self.response_field,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -189,75 +188,77 @@ class APIRouter(routing.Router):
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
endpoint: Callable,
|
endpoint: Callable,
|
||||||
methods: List[str] = None,
|
*,
|
||||||
name: str = None,
|
response_model: Type[BaseModel] = None,
|
||||||
include_in_schema: bool = True,
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
|
methods: List[str] = None,
|
||||||
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
) -> None:
|
) -> None:
|
||||||
route = APIRoute(
|
route = APIRoute(
|
||||||
path,
|
path,
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
methods=methods,
|
response_model=response_model,
|
||||||
name=name,
|
status_code=status_code,
|
||||||
include_in_schema=include_in_schema,
|
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
methods=methods,
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
self.routes.append(route)
|
self.routes.append(route)
|
||||||
|
|
||||||
def api_route(
|
def api_route(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
methods: List[str] = None,
|
*,
|
||||||
name: str = None,
|
response_model: Type[BaseModel] = None,
|
||||||
include_in_schema: bool = True,
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
|
methods: List[str] = None,
|
||||||
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
def decorator(func: Callable) -> Callable:
|
def decorator(func: Callable) -> Callable:
|
||||||
self.add_api_route(
|
self.add_api_route(
|
||||||
path,
|
path,
|
||||||
func,
|
func,
|
||||||
methods=methods,
|
response_model=response_model,
|
||||||
name=name,
|
status_code=status_code,
|
||||||
include_in_schema=include_in_schema,
|
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
methods=methods,
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def include_router(self, router: "APIRouter", *, prefix=""):
|
def include_router(self, router: "APIRouter", *, prefix: str = "") -> None:
|
||||||
if prefix:
|
if prefix:
|
||||||
assert prefix.startswith("/"), "A path prefix must start with '/'"
|
assert prefix.startswith("/"), "A path prefix must start with '/'"
|
||||||
assert not prefix.endswith(
|
assert not prefix.endswith(
|
||||||
|
|
@ -268,18 +269,18 @@ class APIRouter(routing.Router):
|
||||||
self.add_api_route(
|
self.add_api_route(
|
||||||
prefix + route.path,
|
prefix + route.path,
|
||||||
route.endpoint,
|
route.endpoint,
|
||||||
methods=route.methods,
|
response_model=route.response_model,
|
||||||
name=route.name,
|
status_code=route.status_code,
|
||||||
include_in_schema=route.include_in_schema,
|
tags=route.tags or [],
|
||||||
tags=route.tags,
|
|
||||||
summary=route.summary,
|
summary=route.summary,
|
||||||
description=route.description,
|
description=route.description,
|
||||||
operation_id=route.operation_id,
|
|
||||||
deprecated=route.deprecated,
|
|
||||||
response_type=route.response_type,
|
|
||||||
response_description=route.response_description,
|
response_description=route.response_description,
|
||||||
response_code=route.response_code,
|
deprecated=route.deprecated,
|
||||||
response_wrapper=route.response_wrapper,
|
name=route.name,
|
||||||
|
methods=route.methods,
|
||||||
|
operation_id=route.operation_id,
|
||||||
|
include_in_schema=route.include_in_schema,
|
||||||
|
content_type=route.content_type,
|
||||||
)
|
)
|
||||||
elif isinstance(route, routing.Route):
|
elif isinstance(route, routing.Route):
|
||||||
self.add_route(
|
self.add_route(
|
||||||
|
|
@ -293,247 +294,255 @@ class APIRouter(routing.Router):
|
||||||
def get(
|
def get(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
*,
|
||||||
include_in_schema: bool = True,
|
response_model: Type[BaseModel] = None,
|
||||||
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
):
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
|
) -> Callable:
|
||||||
return self.api_route(
|
return self.api_route(
|
||||||
path=path,
|
path=path,
|
||||||
methods=["GET"],
|
response_model=response_model,
|
||||||
name=name,
|
status_code=status_code,
|
||||||
include_in_schema=include_in_schema,
|
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
methods=["GET"],
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def put(
|
def put(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
*,
|
||||||
include_in_schema: bool = True,
|
response_model: Type[BaseModel] = None,
|
||||||
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
):
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
|
) -> Callable:
|
||||||
return self.api_route(
|
return self.api_route(
|
||||||
path=path,
|
path=path,
|
||||||
methods=["PUT"],
|
response_model=response_model,
|
||||||
name=name,
|
status_code=status_code,
|
||||||
include_in_schema=include_in_schema,
|
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
methods=["PUT"],
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def post(
|
def post(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
*,
|
||||||
include_in_schema: bool = True,
|
response_model: Type[BaseModel] = None,
|
||||||
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
):
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
|
) -> Callable:
|
||||||
return self.api_route(
|
return self.api_route(
|
||||||
path=path,
|
path=path,
|
||||||
methods=["POST"],
|
response_model=response_model,
|
||||||
name=name,
|
status_code=status_code,
|
||||||
include_in_schema=include_in_schema,
|
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
methods=["POST"],
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete(
|
def delete(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
*,
|
||||||
include_in_schema: bool = True,
|
response_model: Type[BaseModel] = None,
|
||||||
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
):
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
|
) -> Callable:
|
||||||
return self.api_route(
|
return self.api_route(
|
||||||
path=path,
|
path=path,
|
||||||
methods=["DELETE"],
|
response_model=response_model,
|
||||||
name=name,
|
status_code=status_code,
|
||||||
include_in_schema=include_in_schema,
|
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
methods=["DELETE"],
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def options(
|
def options(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
*,
|
||||||
include_in_schema: bool = True,
|
response_model: Type[BaseModel] = None,
|
||||||
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
):
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
|
) -> Callable:
|
||||||
return self.api_route(
|
return self.api_route(
|
||||||
path=path,
|
path=path,
|
||||||
methods=["OPTIONS"],
|
response_model=response_model,
|
||||||
name=name,
|
status_code=status_code,
|
||||||
include_in_schema=include_in_schema,
|
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
methods=["OPTIONS"],
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def head(
|
def head(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
*,
|
||||||
include_in_schema: bool = True,
|
response_model: Type[BaseModel] = None,
|
||||||
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
):
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
|
) -> Callable:
|
||||||
return self.api_route(
|
return self.api_route(
|
||||||
path=path,
|
path=path,
|
||||||
methods=["HEAD"],
|
response_model=response_model,
|
||||||
name=name,
|
status_code=status_code,
|
||||||
include_in_schema=include_in_schema,
|
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
methods=["HEAD"],
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def patch(
|
def patch(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
*,
|
||||||
include_in_schema: bool = True,
|
response_model: Type[BaseModel] = None,
|
||||||
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
):
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
|
) -> Callable:
|
||||||
return self.api_route(
|
return self.api_route(
|
||||||
path=path,
|
path=path,
|
||||||
methods=["PATCH"],
|
response_model=response_model,
|
||||||
name=name,
|
status_code=status_code,
|
||||||
include_in_schema=include_in_schema,
|
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
methods=["PATCH"],
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def trace(
|
def trace(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
name: str = None,
|
*,
|
||||||
include_in_schema: bool = True,
|
response_model: Type[BaseModel] = None,
|
||||||
|
status_code: int = 200,
|
||||||
tags: List[str] = None,
|
tags: List[str] = None,
|
||||||
summary: str = None,
|
summary: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
operation_id: str = None,
|
|
||||||
deprecated: bool = None,
|
|
||||||
response_type: Type = None,
|
|
||||||
response_description: str = "Successful Response",
|
response_description: str = "Successful Response",
|
||||||
response_code=200,
|
deprecated: bool = None,
|
||||||
response_wrapper=JSONResponse,
|
name: str = None,
|
||||||
):
|
operation_id: str = None,
|
||||||
|
include_in_schema: bool = True,
|
||||||
|
content_type: Type[Response] = JSONResponse,
|
||||||
|
) -> Callable:
|
||||||
return self.api_route(
|
return self.api_route(
|
||||||
path=path,
|
path=path,
|
||||||
methods=["TRACE"],
|
response_model=response_model,
|
||||||
name=name,
|
status_code=status_code,
|
||||||
include_in_schema=include_in_schema,
|
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
summary=summary,
|
summary=summary,
|
||||||
description=description,
|
description=description,
|
||||||
operation_id=operation_id,
|
|
||||||
deprecated=deprecated,
|
|
||||||
response_type=response_type,
|
|
||||||
response_description=response_description,
|
response_description=response_description,
|
||||||
response_code=response_code,
|
deprecated=deprecated,
|
||||||
response_wrapper=response_wrapper,
|
name=name,
|
||||||
|
methods=["TRACE"],
|
||||||
|
operation_id=operation_id,
|
||||||
|
include_in_schema=include_in_schema,
|
||||||
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,19 @@
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from .base import SecurityBase
|
from fastapi.openapi.models import APIKey, APIKeyIn
|
||||||
from fastapi.openapi.models import APIKeyIn, APIKey
|
from fastapi.security.base import SecurityBase
|
||||||
|
|
||||||
|
|
||||||
class APIKeyBase(SecurityBase):
|
class APIKeyBase(SecurityBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class APIKeyQuery(APIKeyBase):
|
|
||||||
|
|
||||||
|
class APIKeyQuery(APIKeyBase):
|
||||||
def __init__(self, *, name: str, scheme_name: str = None):
|
def __init__(self, *, name: str, scheme_name: str = None):
|
||||||
self.model = APIKey(in_=APIKeyIn.query, name=name)
|
self.model = APIKey(in_=APIKeyIn.query, name=name)
|
||||||
self.scheme_name = scheme_name or self.__class__.__name__
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
|
|
||||||
async def __call__(self, requests: Request):
|
async def __call__(self, requests: Request) -> str:
|
||||||
return requests.query_params.get(self.model.name)
|
return requests.query_params.get(self.model.name)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -21,7 +22,7 @@ class APIKeyHeader(APIKeyBase):
|
||||||
self.model = APIKey(in_=APIKeyIn.header, name=name)
|
self.model = APIKey(in_=APIKeyIn.header, name=name)
|
||||||
self.scheme_name = scheme_name or self.__class__.__name__
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
|
|
||||||
async def __call__(self, requests: Request):
|
async def __call__(self, requests: Request) -> str:
|
||||||
return requests.headers.get(self.model.name)
|
return requests.headers.get(self.model.name)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -30,5 +31,5 @@ class APIKeyCookie(APIKeyBase):
|
||||||
self.model = APIKey(in_=APIKeyIn.cookie, name=name)
|
self.model = APIKey(in_=APIKeyIn.cookie, name=name)
|
||||||
self.scheme_name = scheme_name or self.__class__.__name__
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
|
|
||||||
async def __call__(self, requests: Request):
|
async def __call__(self, requests: Request) -> str:
|
||||||
return requests.cookies.get(self.model.name)
|
return requests.cookies.get(self.model.name)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from fastapi.openapi.models import SecurityBase as SecurityBaseModel
|
from fastapi.openapi.models import SecurityBase as SecurityBaseModel
|
||||||
|
|
||||||
|
|
||||||
class SecurityBase:
|
class SecurityBase:
|
||||||
pass
|
model: SecurityBaseModel
|
||||||
|
scheme_name: str
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,10 @@
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from .base import SecurityBase
|
from fastapi.openapi.models import (
|
||||||
from fastapi.openapi.models import HTTPBase as HTTPBaseModel, HTTPBearer as HTTPBearerModel
|
HTTPBase as HTTPBaseModel,
|
||||||
|
HTTPBearer as HTTPBearerModel,
|
||||||
|
)
|
||||||
|
from fastapi.security.base import SecurityBase
|
||||||
|
|
||||||
|
|
||||||
class HTTPBase(SecurityBase):
|
class HTTPBase(SecurityBase):
|
||||||
|
|
@ -9,7 +12,7 @@ class HTTPBase(SecurityBase):
|
||||||
self.model = HTTPBaseModel(scheme=scheme)
|
self.model = HTTPBaseModel(scheme=scheme)
|
||||||
self.scheme_name = scheme_name or self.__class__.__name__
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
|
|
||||||
async def __call__(self, request: Request):
|
async def __call__(self, request: Request) -> str:
|
||||||
return request.headers.get("Authorization")
|
return request.headers.get("Authorization")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -17,8 +20,8 @@ class HTTPBasic(HTTPBase):
|
||||||
def __init__(self, *, scheme_name: str = None):
|
def __init__(self, *, scheme_name: str = None):
|
||||||
self.model = HTTPBaseModel(scheme="basic")
|
self.model = HTTPBaseModel(scheme="basic")
|
||||||
self.scheme_name = scheme_name or self.__class__.__name__
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
|
|
||||||
async def __call__(self, request: Request):
|
async def __call__(self, request: Request) -> str:
|
||||||
return request.headers.get("Authorization")
|
return request.headers.get("Authorization")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -26,8 +29,8 @@ class HTTPBearer(HTTPBase):
|
||||||
def __init__(self, *, bearerFormat: str = None, scheme_name: str = None):
|
def __init__(self, *, bearerFormat: str = None, scheme_name: str = None):
|
||||||
self.model = HTTPBearerModel(bearerFormat=bearerFormat)
|
self.model = HTTPBearerModel(bearerFormat=bearerFormat)
|
||||||
self.scheme_name = scheme_name or self.__class__.__name__
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
|
|
||||||
async def __call__(self, request: Request):
|
async def __call__(self, request: Request) -> str:
|
||||||
return request.headers.get("Authorization")
|
return request.headers.get("Authorization")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -35,6 +38,6 @@ class HTTPDigest(HTTPBase):
|
||||||
def __init__(self, *, scheme_name: str = None):
|
def __init__(self, *, scheme_name: str = None):
|
||||||
self.model = HTTPBaseModel(scheme="digest")
|
self.model = HTTPBaseModel(scheme="digest")
|
||||||
self.scheme_name = scheme_name or self.__class__.__name__
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
|
|
||||||
async def __call__(self, request: Request):
|
async def __call__(self, request: Request) -> str:
|
||||||
return request.headers.get("Authorization")
|
return request.headers.get("Authorization")
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,15 @@
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from .base import SecurityBase
|
|
||||||
from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel
|
from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel
|
||||||
|
from fastapi.security.base import SecurityBase
|
||||||
|
|
||||||
|
|
||||||
class OAuth2(SecurityBase):
|
class OAuth2(SecurityBase):
|
||||||
def __init__(self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None):
|
def __init__(
|
||||||
|
self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None
|
||||||
|
):
|
||||||
self.model = OAuth2Model(flows=flows)
|
self.model = OAuth2Model(flows=flows)
|
||||||
self.scheme_name = scheme_name or self.__class__.__name__
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
|
|
||||||
async def __call__(self, request: Request):
|
async def __call__(self, request: Request) -> str:
|
||||||
return request.headers.get("Authorization")
|
return request.headers.get("Authorization")
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,13 @@
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from .base import SecurityBase
|
|
||||||
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
|
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
|
||||||
|
from fastapi.security.base import SecurityBase
|
||||||
|
|
||||||
|
|
||||||
class OpenIdConnect(SecurityBase):
|
class OpenIdConnect(SecurityBase):
|
||||||
def __init__(self, *, openIdConnectUrl: str, scheme_name: str = None):
|
def __init__(self, *, openIdConnectUrl: str, scheme_name: str = None):
|
||||||
self.model = OpenIdConnectModel(openIdConnectUrl=openIdConnectUrl)
|
self.model = OpenIdConnectModel(openIdConnectUrl=openIdConnectUrl)
|
||||||
self.scheme_name = scheme_name or self.__class__.__name__
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
|
|
||||||
async def __call__(self, request: Request):
|
async def __call__(self, request: Request) -> str:
|
||||||
return request.headers.get("Authorization")
|
return request.headers.get("Authorization")
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,24 @@
|
||||||
import re
|
import re
|
||||||
from typing import Dict, Sequence, Set, Type
|
from typing import Any, Dict, List, Sequence, Set, Type
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from pydantic.fields import Field
|
||||||
|
from pydantic.schema import get_flat_models_from_fields, model_process_schema
|
||||||
from starlette.routing import BaseRoute
|
from starlette.routing import BaseRoute
|
||||||
|
|
||||||
from fastapi import routing
|
from fastapi import routing
|
||||||
from fastapi.openapi.constants import REF_PREFIX
|
from fastapi.openapi.constants import REF_PREFIX
|
||||||
from pydantic import BaseModel
|
|
||||||
from pydantic.fields import Field
|
|
||||||
from pydantic.schema import get_flat_models_from_fields, model_process_schema
|
|
||||||
|
|
||||||
|
|
||||||
def get_flat_models_from_routes(routes: Sequence[BaseRoute]):
|
def get_flat_models_from_routes(
|
||||||
body_fields_from_routes = []
|
routes: Sequence[Type[BaseRoute]]
|
||||||
responses_from_routes = []
|
) -> Set[Type[BaseModel]]:
|
||||||
|
body_fields_from_routes: List[Field] = []
|
||||||
|
responses_from_routes: List[Field] = []
|
||||||
for route in routes:
|
for route in routes:
|
||||||
if route.include_in_schema and isinstance(route, routing.APIRoute):
|
if getattr(route, "include_in_schema", None) and isinstance(
|
||||||
|
route, routing.APIRoute
|
||||||
|
):
|
||||||
if route.body_field:
|
if route.body_field:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
route.body_field, Field
|
route.body_field, Field
|
||||||
|
|
@ -30,7 +34,7 @@ def get_flat_models_from_routes(routes: Sequence[BaseRoute]):
|
||||||
|
|
||||||
def get_model_definitions(
|
def get_model_definitions(
|
||||||
*, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str]
|
*, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str]
|
||||||
):
|
) -> Dict[str, Any]:
|
||||||
definitions: Dict[str, Dict] = {}
|
definitions: Dict[str, Dict] = {}
|
||||||
for model in flat_models:
|
for model in flat_models:
|
||||||
m_schema, m_definitions = model_process_schema(
|
m_schema, m_definitions = model_process_schema(
|
||||||
|
|
@ -42,5 +46,5 @@ def get_model_definitions(
|
||||||
return definitions
|
return definitions
|
||||||
|
|
||||||
|
|
||||||
def get_path_param_names(path: str):
|
def get_path_param_names(path: str) -> Set[str]:
|
||||||
return {item.strip("{}") for item in re.findall("{[^}]*}", path)}
|
return {item.strip("{}") for item in re.findall("{[^}]*}", path)}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue