mirror of https://github.com/tiangolo/fastapi.git
✨ Improve type annotations, add support for mypy --strict, internally and for external packages (#2547)
This commit is contained in:
parent
4fdcdf341c
commit
fdb6c9ccc5
|
|
@ -26,7 +26,7 @@ invoices_callback_router = APIRouter()
|
|||
|
||||
|
||||
@invoices_callback_router.post(
|
||||
"{$callback_url}/invoices/{$request.body.id}", response_model=InvoiceEventReceived,
|
||||
"{$callback_url}/invoices/{$request.body.id}", response_model=InvoiceEventReceived
|
||||
)
|
||||
def invoice_notification(body: InvoiceEvent):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -2,24 +2,23 @@
|
|||
|
||||
__version__ = "0.62.0"
|
||||
|
||||
from starlette import status
|
||||
from starlette import status as status
|
||||
|
||||
from .applications import FastAPI
|
||||
from .background import BackgroundTasks
|
||||
from .datastructures import UploadFile
|
||||
from .exceptions import HTTPException
|
||||
from .param_functions import (
|
||||
Body,
|
||||
Cookie,
|
||||
Depends,
|
||||
File,
|
||||
Form,
|
||||
Header,
|
||||
Path,
|
||||
Query,
|
||||
Security,
|
||||
)
|
||||
from .requests import Request
|
||||
from .responses import Response
|
||||
from .routing import APIRouter
|
||||
from .websockets import WebSocket, WebSocketDisconnect
|
||||
from .applications import FastAPI as FastAPI
|
||||
from .background import BackgroundTasks as BackgroundTasks
|
||||
from .datastructures import UploadFile as UploadFile
|
||||
from .exceptions import HTTPException as HTTPException
|
||||
from .param_functions import Body as Body
|
||||
from .param_functions import Cookie as Cookie
|
||||
from .param_functions import Depends as Depends
|
||||
from .param_functions import File as File
|
||||
from .param_functions import Form as Form
|
||||
from .param_functions import Header as Header
|
||||
from .param_functions import Path as Path
|
||||
from .param_functions import Query as Query
|
||||
from .param_functions import Security as Security
|
||||
from .requests import Request as Request
|
||||
from .responses import Response as Response
|
||||
from .routing import APIRouter as APIRouter
|
||||
from .websockets import WebSocket as WebSocket
|
||||
from .websockets import WebSocketDisconnect as WebSocketDisconnect
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Optional, Sequence, Type, Union
|
||||
|
||||
from fastapi import routing
|
||||
from fastapi.concurrency import AsyncExitStack
|
||||
|
|
@ -17,6 +17,7 @@ from fastapi.openapi.docs import (
|
|||
)
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.params import Depends
|
||||
from fastapi.types import DecoratedCallable
|
||||
from starlette.applications import Starlette
|
||||
from starlette.datastructures import State
|
||||
from starlette.exceptions import HTTPException
|
||||
|
|
@ -24,7 +25,7 @@ from starlette.middleware import Middleware
|
|||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse, JSONResponse, Response
|
||||
from starlette.routing import BaseRoute
|
||||
from starlette.types import Receive, Scope, Send
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
|
||||
class FastAPI(Starlette):
|
||||
|
|
@ -44,24 +45,27 @@ class FastAPI(Starlette):
|
|||
docs_url: Optional[str] = "/docs",
|
||||
redoc_url: Optional[str] = "/redoc",
|
||||
swagger_ui_oauth2_redirect_url: Optional[str] = "/docs/oauth2-redirect",
|
||||
swagger_ui_init_oauth: Optional[dict] = None,
|
||||
swagger_ui_init_oauth: Optional[Dict[str, Any]] = None,
|
||||
middleware: Optional[Sequence[Middleware]] = None,
|
||||
exception_handlers: Optional[
|
||||
Dict[Union[int, Type[Exception]], Callable]
|
||||
Dict[
|
||||
Union[int, Type[Exception]],
|
||||
Callable[[Request, Any], Coroutine[Any, Any, Response]],
|
||||
]
|
||||
] = None,
|
||||
on_startup: Optional[Sequence[Callable]] = None,
|
||||
on_shutdown: Optional[Sequence[Callable]] = None,
|
||||
on_startup: Optional[Sequence[Callable[[], Any]]] = None,
|
||||
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
|
||||
openapi_prefix: str = "",
|
||||
root_path: str = "",
|
||||
root_path_in_servers: bool = True,
|
||||
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
|
||||
callbacks: Optional[List[routing.APIRoute]] = None,
|
||||
deprecated: bool = None,
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
include_in_schema: bool = True,
|
||||
**extra: Any,
|
||||
) -> None:
|
||||
self._debug = debug
|
||||
self.state = State()
|
||||
self._debug: bool = debug
|
||||
self.state: State = State()
|
||||
self.router: routing.APIRouter = routing.APIRouter(
|
||||
routes=routes,
|
||||
dependency_overrides_provider=self,
|
||||
|
|
@ -74,7 +78,10 @@ class FastAPI(Starlette):
|
|||
include_in_schema=include_in_schema,
|
||||
responses=responses,
|
||||
)
|
||||
self.exception_handlers = (
|
||||
self.exception_handlers: Dict[
|
||||
Union[int, Type[Exception]],
|
||||
Callable[[Request, Any], Coroutine[Any, Any, Response]],
|
||||
] = (
|
||||
{} if exception_handlers is None else dict(exception_handlers)
|
||||
)
|
||||
self.exception_handlers.setdefault(HTTPException, http_exception_handler)
|
||||
|
|
@ -82,8 +89,10 @@ class FastAPI(Starlette):
|
|||
RequestValidationError, request_validation_exception_handler
|
||||
)
|
||||
|
||||
self.user_middleware = [] if middleware is None else list(middleware)
|
||||
self.middleware_stack = self.build_middleware_stack()
|
||||
self.user_middleware: List[Middleware] = (
|
||||
[] if middleware is None else list(middleware)
|
||||
)
|
||||
self.middleware_stack: ASGIApp = self.build_middleware_stack()
|
||||
|
||||
self.title = title
|
||||
self.description = description
|
||||
|
|
@ -106,7 +115,7 @@ class FastAPI(Starlette):
|
|||
self.swagger_ui_oauth2_redirect_url = swagger_ui_oauth2_redirect_url
|
||||
self.swagger_ui_init_oauth = swagger_ui_init_oauth
|
||||
self.extra = extra
|
||||
self.dependency_overrides: Dict[Callable, Callable] = {}
|
||||
self.dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]] = {}
|
||||
|
||||
self.openapi_version = "3.0.2"
|
||||
|
||||
|
|
@ -116,7 +125,7 @@ class FastAPI(Starlette):
|
|||
self.openapi_schema: Optional[Dict[str, Any]] = None
|
||||
self.setup()
|
||||
|
||||
def openapi(self) -> Dict:
|
||||
def openapi(self) -> Dict[str, Any]:
|
||||
if not self.openapi_schema:
|
||||
self.openapi_schema = get_openapi(
|
||||
title=self.title,
|
||||
|
|
@ -194,7 +203,7 @@ class FastAPI(Starlette):
|
|||
def add_api_route(
|
||||
self,
|
||||
path: str,
|
||||
endpoint: Callable,
|
||||
endpoint: Callable[..., Coroutine[Any, Any, Response]],
|
||||
*,
|
||||
response_model: Optional[Type[Any]] = None,
|
||||
status_code: int = 200,
|
||||
|
|
@ -268,8 +277,8 @@ class FastAPI(Starlette):
|
|||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
) -> Callable:
|
||||
def decorator(func: Callable) -> Callable:
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
def decorator(func: DecoratedCallable) -> DecoratedCallable:
|
||||
self.router.add_api_route(
|
||||
path,
|
||||
func,
|
||||
|
|
@ -299,12 +308,14 @@ class FastAPI(Starlette):
|
|||
return decorator
|
||||
|
||||
def add_api_websocket_route(
|
||||
self, path: str, endpoint: Callable, name: Optional[str] = None
|
||||
self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
|
||||
) -> None:
|
||||
self.router.add_api_websocket_route(path, endpoint, name=name)
|
||||
|
||||
def websocket(self, path: str, name: Optional[str] = None) -> Callable:
|
||||
def decorator(func: Callable) -> Callable:
|
||||
def websocket(
|
||||
self, path: str, name: Optional[str] = None
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
def decorator(func: DecoratedCallable) -> DecoratedCallable:
|
||||
self.add_api_websocket_route(path, func, name=name)
|
||||
return func
|
||||
|
||||
|
|
@ -318,10 +329,10 @@ class FastAPI(Starlette):
|
|||
tags: Optional[List[str]] = None,
|
||||
dependencies: Optional[Sequence[Depends]] = None,
|
||||
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
|
||||
deprecated: bool = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
include_in_schema: bool = True,
|
||||
default_response_class: Type[Response] = Default(JSONResponse),
|
||||
callbacks: Optional[List[routing.APIRoute]] = None,
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> None:
|
||||
self.router.include_router(
|
||||
router,
|
||||
|
|
@ -358,8 +369,8 @@ class FastAPI(Starlette):
|
|||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[routing.APIRoute]] = None,
|
||||
) -> Callable:
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.router.get(
|
||||
path,
|
||||
response_model=response_model,
|
||||
|
|
@ -407,8 +418,8 @@ class FastAPI(Starlette):
|
|||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[routing.APIRoute]] = None,
|
||||
) -> Callable:
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.router.put(
|
||||
path,
|
||||
response_model=response_model,
|
||||
|
|
@ -456,8 +467,8 @@ class FastAPI(Starlette):
|
|||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[routing.APIRoute]] = None,
|
||||
) -> Callable:
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.router.post(
|
||||
path,
|
||||
response_model=response_model,
|
||||
|
|
@ -505,8 +516,8 @@ class FastAPI(Starlette):
|
|||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[routing.APIRoute]] = None,
|
||||
) -> Callable:
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.router.delete(
|
||||
path,
|
||||
response_model=response_model,
|
||||
|
|
@ -554,8 +565,8 @@ class FastAPI(Starlette):
|
|||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[routing.APIRoute]] = None,
|
||||
) -> Callable:
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.router.options(
|
||||
path,
|
||||
response_model=response_model,
|
||||
|
|
@ -603,8 +614,8 @@ class FastAPI(Starlette):
|
|||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[routing.APIRoute]] = None,
|
||||
) -> Callable:
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.router.head(
|
||||
path,
|
||||
response_model=response_model,
|
||||
|
|
@ -652,8 +663,8 @@ class FastAPI(Starlette):
|
|||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[routing.APIRoute]] = None,
|
||||
) -> Callable:
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.router.patch(
|
||||
path,
|
||||
response_model=response_model,
|
||||
|
|
@ -701,8 +712,8 @@ class FastAPI(Starlette):
|
|||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[routing.APIRoute]] = None,
|
||||
) -> Callable:
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.router.trace(
|
||||
path,
|
||||
response_model=response_model,
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from starlette.background import BackgroundTasks # noqa
|
||||
from starlette.background import BackgroundTasks as BackgroundTasks # noqa
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from typing import Any, Callable
|
||||
|
||||
from starlette.concurrency import iterate_in_threadpool # noqa
|
||||
from starlette.concurrency import run_in_threadpool # noqa
|
||||
from starlette.concurrency import run_until_first_complete # noqa
|
||||
from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa
|
||||
from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa
|
||||
from starlette.concurrency import ( # noqa
|
||||
run_until_first_complete as run_until_first_complete,
|
||||
)
|
||||
|
||||
asynccontextmanager_error_message = """
|
||||
FastAPI's contextmanager_in_threadpool require Python 3.7 or above,
|
||||
|
|
@ -11,7 +13,7 @@ or the backport for Python 3.6, installed with:
|
|||
"""
|
||||
|
||||
|
||||
def _fake_asynccontextmanager(func: Callable) -> Callable:
|
||||
def _fake_asynccontextmanager(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def raiser(*args: Any, **kwargs: Any) -> Any:
|
||||
raise RuntimeError(asynccontextmanager_error_message)
|
||||
|
||||
|
|
@ -19,23 +21,25 @@ def _fake_asynccontextmanager(func: Callable) -> Callable:
|
|||
|
||||
|
||||
try:
|
||||
from contextlib import asynccontextmanager # type: ignore
|
||||
from contextlib import asynccontextmanager as asynccontextmanager # type: ignore
|
||||
except ImportError:
|
||||
try:
|
||||
from async_generator import asynccontextmanager # type: ignore
|
||||
from async_generator import ( # type: ignore # isort: skip
|
||||
asynccontextmanager as asynccontextmanager,
|
||||
)
|
||||
except ImportError: # pragma: no cover
|
||||
asynccontextmanager = _fake_asynccontextmanager
|
||||
|
||||
try:
|
||||
from contextlib import AsyncExitStack # type: ignore
|
||||
from contextlib import AsyncExitStack as AsyncExitStack # type: ignore
|
||||
except ImportError:
|
||||
try:
|
||||
from async_exit_stack import AsyncExitStack # type: ignore
|
||||
from async_exit_stack import AsyncExitStack as AsyncExitStack # type: ignore
|
||||
except ImportError: # pragma: no cover
|
||||
AsyncExitStack = None # type: ignore
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@asynccontextmanager # type: ignore
|
||||
async def contextmanager_in_threadpool(cm: Any) -> Any:
|
||||
try:
|
||||
yield await run_in_threadpool(cm.__enter__)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
from typing import Any, Callable, Iterable, Type, TypeVar
|
||||
|
||||
from starlette.datastructures import State as State # noqa: F401
|
||||
from starlette.datastructures import UploadFile as StarletteUploadFile
|
||||
|
||||
|
||||
class UploadFile(StarletteUploadFile):
|
||||
@classmethod
|
||||
def __get_validators__(cls: Type["UploadFile"]) -> Iterable[Callable]:
|
||||
def __get_validators__(cls: Type["UploadFile"]) -> Iterable[Callable[..., Any]]:
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Callable, List, Optional, Sequence
|
||||
from typing import Any, Callable, List, Optional, Sequence
|
||||
|
||||
from fastapi.security.base import SecurityBase
|
||||
from pydantic.fields import ModelField
|
||||
|
|
@ -24,7 +24,7 @@ class Dependant:
|
|||
dependencies: Optional[List["Dependant"]] = None,
|
||||
security_schemes: Optional[List[SecurityRequirement]] = None,
|
||||
name: Optional[str] = None,
|
||||
call: Optional[Callable] = None,
|
||||
call: Optional[Callable[..., Any]] = None,
|
||||
request_param_name: Optional[str] = None,
|
||||
websocket_param_name: Optional[str] = None,
|
||||
http_connection_param_name: Optional[str] = None,
|
||||
|
|
|
|||
|
|
@ -90,12 +90,12 @@ def check_file_field(field: ModelField) -> None:
|
|||
if isinstance(field_info, params.Form):
|
||||
try:
|
||||
# __version__ is available in both multiparts, and can be mocked
|
||||
from multipart import __version__
|
||||
from multipart import __version__ # type: ignore
|
||||
|
||||
assert __version__
|
||||
try:
|
||||
# parse_options_header is only available in the right multipart
|
||||
from multipart.multipart import parse_options_header
|
||||
from multipart.multipart import parse_options_header # type: ignore
|
||||
|
||||
assert parse_options_header
|
||||
except ImportError:
|
||||
|
|
@ -133,7 +133,7 @@ def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> De
|
|||
def get_sub_dependant(
|
||||
*,
|
||||
depends: params.Depends,
|
||||
dependency: Callable,
|
||||
dependency: Callable[..., Any],
|
||||
path: str,
|
||||
name: Optional[str] = None,
|
||||
security_scopes: Optional[List[str]] = None,
|
||||
|
|
@ -163,7 +163,7 @@ def get_sub_dependant(
|
|||
return sub_dependant
|
||||
|
||||
|
||||
CacheKey = Tuple[Optional[Callable], Tuple[str, ...]]
|
||||
CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
|
||||
|
||||
|
||||
def get_flat_dependant(
|
||||
|
|
@ -240,7 +240,7 @@ def is_scalar_sequence_field(field: ModelField) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def get_typed_signature(call: Callable) -> inspect.Signature:
|
||||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
signature = inspect.signature(call)
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
typed_params = [
|
||||
|
|
@ -259,9 +259,7 @@ def get_typed_signature(call: Callable) -> inspect.Signature:
|
|||
def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any:
|
||||
annotation = param.annotation
|
||||
if isinstance(annotation, str):
|
||||
# Temporary ignore type
|
||||
# Ref: https://github.com/samuelcolvin/pydantic/issues/1738
|
||||
annotation = ForwardRef(annotation) # type: ignore
|
||||
annotation = ForwardRef(annotation)
|
||||
annotation = evaluate_forwardref(annotation, globalns, globalns)
|
||||
return annotation
|
||||
|
||||
|
|
@ -281,7 +279,7 @@ def check_dependency_contextmanagers() -> None:
|
|||
def get_dependant(
|
||||
*,
|
||||
path: str,
|
||||
call: Callable,
|
||||
call: Callable[..., Any],
|
||||
name: Optional[str] = None,
|
||||
security_scopes: Optional[List[str]] = None,
|
||||
use_cache: bool = True,
|
||||
|
|
@ -423,7 +421,7 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
|
|||
dependant.cookie_params.append(field)
|
||||
|
||||
|
||||
def is_coroutine_callable(call: Callable) -> bool:
|
||||
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
|
||||
if inspect.isroutine(call):
|
||||
return inspect.iscoroutinefunction(call)
|
||||
if inspect.isclass(call):
|
||||
|
|
@ -432,14 +430,14 @@ def is_coroutine_callable(call: Callable) -> bool:
|
|||
return inspect.iscoroutinefunction(call)
|
||||
|
||||
|
||||
def is_async_gen_callable(call: Callable) -> bool:
|
||||
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
|
||||
if inspect.isasyncgenfunction(call):
|
||||
return True
|
||||
call = getattr(call, "__call__", None)
|
||||
return inspect.isasyncgenfunction(call)
|
||||
|
||||
|
||||
def is_gen_callable(call: Callable) -> bool:
|
||||
def is_gen_callable(call: Callable[..., Any]) -> bool:
|
||||
if inspect.isgeneratorfunction(call):
|
||||
return True
|
||||
call = getattr(call, "__call__", None)
|
||||
|
|
@ -447,7 +445,7 @@ def is_gen_callable(call: Callable) -> bool:
|
|||
|
||||
|
||||
async def solve_generator(
|
||||
*, call: Callable, stack: AsyncExitStack, sub_values: Dict[str, Any]
|
||||
*, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
|
||||
) -> Any:
|
||||
if is_gen_callable(call):
|
||||
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
|
||||
|
|
@ -472,29 +470,29 @@ async def solve_dependencies(
|
|||
background_tasks: Optional[BackgroundTasks] = None,
|
||||
response: Optional[Response] = None,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
dependency_cache: Optional[Dict[Tuple[Callable, Tuple[str]], Any]] = None,
|
||||
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
|
||||
) -> Tuple[
|
||||
Dict[str, Any],
|
||||
List[ErrorWrapper],
|
||||
Optional[BackgroundTasks],
|
||||
Response,
|
||||
Dict[Tuple[Callable, Tuple[str]], Any],
|
||||
Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
|
||||
]:
|
||||
values: Dict[str, Any] = {}
|
||||
errors: List[ErrorWrapper] = []
|
||||
response = response or Response(
|
||||
content=None,
|
||||
status_code=None, # type: ignore
|
||||
headers=None,
|
||||
media_type=None,
|
||||
background=None,
|
||||
headers=None, # type: ignore # in Starlette
|
||||
media_type=None, # type: ignore # in Starlette
|
||||
background=None, # type: ignore # in Starlette
|
||||
)
|
||||
dependency_cache = dependency_cache or {}
|
||||
sub_dependant: Dependant
|
||||
for sub_dependant in dependant.dependencies:
|
||||
sub_dependant.call = cast(Callable, sub_dependant.call)
|
||||
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
|
||||
sub_dependant.cache_key = cast(
|
||||
Tuple[Callable, Tuple[str]], sub_dependant.cache_key
|
||||
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
|
||||
)
|
||||
call = sub_dependant.call
|
||||
use_sub_dependant = sub_dependant
|
||||
|
|
|
|||
|
|
@ -12,9 +12,11 @@ DictIntStrAny = Dict[Union[int, str], Any]
|
|||
|
||||
|
||||
def generate_encoders_by_class_tuples(
|
||||
type_encoder_map: Dict[Any, Callable]
|
||||
) -> Dict[Callable, Tuple]:
|
||||
encoders_by_class_tuples: Dict[Callable, Tuple] = defaultdict(tuple)
|
||||
type_encoder_map: Dict[Any, Callable[[Any], Any]]
|
||||
) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]:
|
||||
encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(
|
||||
tuple
|
||||
)
|
||||
for type_, encoder in type_encoder_map.items():
|
||||
encoders_by_class_tuples[encoder] += (type_,)
|
||||
return encoders_by_class_tuples
|
||||
|
|
@ -31,7 +33,7 @@ def jsonable_encoder(
|
|||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
custom_encoder: dict = {},
|
||||
custom_encoder: Dict[Any, Callable[[Any], Any]] = {},
|
||||
sqlalchemy_safe: bool = True,
|
||||
) -> Any:
|
||||
if include is not None and not isinstance(include, set):
|
||||
|
|
@ -43,8 +45,8 @@ def jsonable_encoder(
|
|||
if custom_encoder:
|
||||
encoder.update(custom_encoder)
|
||||
obj_dict = obj.dict(
|
||||
include=include,
|
||||
exclude=exclude,
|
||||
include=include, # type: ignore # in Pydantic
|
||||
exclude=exclude, # type: ignore # in Pydantic
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_none=exclude_none,
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from starlette.middleware import Middleware
|
||||
from starlette.middleware import Middleware as Middleware
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from starlette.middleware.cors import CORSMiddleware # noqa
|
||||
from starlette.middleware.cors import CORSMiddleware as CORSMiddleware # noqa
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from starlette.middleware.gzip import GZipMiddleware # noqa
|
||||
from starlette.middleware.gzip import GZipMiddleware as GZipMiddleware # noqa
|
||||
|
|
|
|||
|
|
@ -1 +1,3 @@
|
|||
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware # noqa
|
||||
from starlette.middleware.httpsredirect import ( # noqa
|
||||
HTTPSRedirectMiddleware as HTTPSRedirectMiddleware,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1 +1,3 @@
|
|||
from starlette.middleware.trustedhost import TrustedHostMiddleware # noqa
|
||||
from starlette.middleware.trustedhost import ( # noqa
|
||||
TrustedHostMiddleware as TrustedHostMiddleware,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from starlette.middleware.wsgi import WSGIMiddleware # noqa
|
||||
from starlette.middleware.wsgi import WSGIMiddleware as WSGIMiddleware # noqa
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import json
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from starlette.responses import HTMLResponse
|
||||
|
|
@ -13,7 +13,7 @@ def get_swagger_ui_html(
|
|||
swagger_css_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css",
|
||||
swagger_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png",
|
||||
oauth2_redirect_url: Optional[str] = None,
|
||||
init_oauth: Optional[dict] = None,
|
||||
init_oauth: Optional[Dict[str, Any]] = None,
|
||||
) -> HTMLResponse:
|
||||
|
||||
html = f"""
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from fastapi.logger import logger
|
|||
from pydantic import AnyUrl, BaseModel, Field
|
||||
|
||||
try:
|
||||
import email_validator
|
||||
import email_validator # type: ignore
|
||||
|
||||
assert email_validator # make autoflake ignore the unused import
|
||||
from pydantic import EmailStr
|
||||
|
|
@ -13,7 +13,7 @@ except ImportError: # pragma: no cover
|
|||
|
||||
class EmailStr(str): # type: ignore
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Iterable[Callable]:
|
||||
def __get_validators__(cls) -> Iterable[Callable[..., Any]]:
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from fastapi.openapi.constants import (
|
|||
)
|
||||
from fastapi.openapi.models import OpenAPI
|
||||
from fastapi.params import Body, Param
|
||||
from fastapi.responses import Response
|
||||
from fastapi.utils import (
|
||||
deep_dict_update,
|
||||
generate_operation_id_for_path,
|
||||
|
|
@ -64,7 +65,9 @@ status_code_ranges: Dict[str, str] = {
|
|||
}
|
||||
|
||||
|
||||
def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]:
|
||||
def get_openapi_security_definitions(
|
||||
flat_dependant: Dependant,
|
||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
security_definitions = {}
|
||||
operation_security = []
|
||||
for security_requirement in flat_dependant.security_requirements:
|
||||
|
|
@ -88,13 +91,12 @@ def get_openapi_operation_parameters(
|
|||
for param in all_route_params:
|
||||
field_info = param.field_info
|
||||
field_info = cast(Param, field_info)
|
||||
# ignore mypy error until enum schemas are released
|
||||
parameter = {
|
||||
"name": param.alias,
|
||||
"in": field_info.in_.value,
|
||||
"required": param.required,
|
||||
"schema": field_schema(
|
||||
param, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
|
||||
param, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
||||
)[0],
|
||||
}
|
||||
if field_info.description:
|
||||
|
|
@ -109,13 +111,12 @@ def get_openapi_operation_request_body(
|
|||
*,
|
||||
body_field: Optional[ModelField],
|
||||
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
|
||||
) -> Optional[Dict]:
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if not body_field:
|
||||
return None
|
||||
assert isinstance(body_field, ModelField)
|
||||
# ignore mypy error until enum schemas are released
|
||||
body_schema, _, _ = field_schema(
|
||||
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
|
||||
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
||||
)
|
||||
field_info = cast(Body, body_field.field_info)
|
||||
request_media_type = field_info.media_type
|
||||
|
|
@ -140,7 +141,9 @@ def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
|
|||
return route.name.replace("_", " ").title()
|
||||
|
||||
|
||||
def get_openapi_operation_metadata(*, route: routing.APIRoute, method: str) -> Dict:
|
||||
def get_openapi_operation_metadata(
|
||||
*, route: routing.APIRoute, method: str
|
||||
) -> Dict[str, Any]:
|
||||
operation: Dict[str, Any] = {}
|
||||
if route.tags:
|
||||
operation["tags"] = route.tags
|
||||
|
|
@ -154,14 +157,14 @@ def get_openapi_operation_metadata(*, route: routing.APIRoute, method: str) -> D
|
|||
|
||||
|
||||
def get_openapi_path(
|
||||
*, route: routing.APIRoute, model_name_map: Dict[Type, str]
|
||||
) -> Tuple[Dict, Dict, Dict]:
|
||||
*, route: routing.APIRoute, model_name_map: Dict[type, str]
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
path = {}
|
||||
security_schemes: Dict[str, Any] = {}
|
||||
definitions: Dict[str, Any] = {}
|
||||
assert route.methods is not None, "Methods must be a list"
|
||||
if isinstance(route.response_class, DefaultPlaceholder):
|
||||
current_response_class: Type[routing.Response] = route.response_class.value
|
||||
current_response_class: Type[Response] = route.response_class.value
|
||||
else:
|
||||
current_response_class = route.response_class
|
||||
assert current_response_class, "A response class is needed to generate OpenAPI"
|
||||
|
|
@ -169,7 +172,7 @@ def get_openapi_path(
|
|||
if route.include_in_schema:
|
||||
for method in route.methods:
|
||||
operation = get_openapi_operation_metadata(route=route, method=method)
|
||||
parameters: List[Dict] = []
|
||||
parameters: List[Dict[str, Any]] = []
|
||||
flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
|
||||
security_definitions, operation_security = get_openapi_security_definitions(
|
||||
flat_dependant=flat_dependant
|
||||
|
|
@ -196,10 +199,15 @@ def get_openapi_path(
|
|||
if route.callbacks:
|
||||
callbacks = {}
|
||||
for callback in route.callbacks:
|
||||
cb_path, cb_security_schemes, cb_definitions, = get_openapi_path(
|
||||
route=callback, model_name_map=model_name_map
|
||||
)
|
||||
callbacks[callback.name] = {callback.path: cb_path}
|
||||
if isinstance(callback, routing.APIRoute):
|
||||
(
|
||||
cb_path,
|
||||
cb_security_schemes,
|
||||
cb_definitions,
|
||||
) = get_openapi_path(
|
||||
route=callback, model_name_map=model_name_map
|
||||
)
|
||||
callbacks[callback.name] = {callback.path: cb_path}
|
||||
operation["callbacks"] = callbacks
|
||||
status_code = str(route.status_code)
|
||||
operation.setdefault("responses", {}).setdefault(status_code, {})[
|
||||
|
|
@ -332,21 +340,19 @@ def get_openapi(
|
|||
routes: Sequence[BaseRoute],
|
||||
tags: Optional[List[Dict[str, Any]]] = None,
|
||||
servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
|
||||
) -> Dict:
|
||||
) -> Dict[str, Any]:
|
||||
info = {"title": title, "version": version}
|
||||
if description:
|
||||
info["description"] = description
|
||||
output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
|
||||
if servers:
|
||||
output["servers"] = servers
|
||||
components: Dict[str, Dict] = {}
|
||||
paths: Dict[str, Dict] = {}
|
||||
components: Dict[str, Dict[str, Any]] = {}
|
||||
paths: Dict[str, Dict[str, Any]] = {}
|
||||
flat_models = get_flat_models_from_routes(routes)
|
||||
# ignore mypy error until enum schemas are released
|
||||
model_name_map = get_model_name_map(flat_models) # type: ignore
|
||||
# ignore mypy error until enum schemas are released
|
||||
model_name_map = get_model_name_map(flat_models)
|
||||
definitions = get_model_definitions(
|
||||
flat_models=flat_models, model_name_map=model_name_map # type: ignore
|
||||
flat_models=flat_models, model_name_map=model_name_map
|
||||
)
|
||||
for route in routes:
|
||||
if isinstance(route, routing.APIRoute):
|
||||
|
|
@ -368,4 +374,4 @@ def get_openapi(
|
|||
output["paths"] = paths
|
||||
if tags:
|
||||
output["tags"] = tags
|
||||
return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True)
|
||||
return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) # type: ignore
|
||||
|
|
|
|||
|
|
@ -239,13 +239,13 @@ def File( # noqa: N802
|
|||
|
||||
|
||||
def Depends( # noqa: N802
|
||||
dependency: Optional[Callable] = None, *, use_cache: bool = True
|
||||
dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True
|
||||
) -> Any:
|
||||
return params.Depends(dependency=dependency, use_cache=use_cache)
|
||||
|
||||
|
||||
def Security( # noqa: N802
|
||||
dependency: Optional[Callable] = None,
|
||||
dependency: Optional[Callable[..., Any]] = None,
|
||||
*,
|
||||
scopes: Optional[Sequence[str]] = None,
|
||||
use_cache: bool = True,
|
||||
|
|
|
|||
|
|
@ -315,7 +315,7 @@ class File(Form):
|
|||
|
||||
class Depends:
|
||||
def __init__(
|
||||
self, dependency: Optional[Callable] = None, *, use_cache: bool = True
|
||||
self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True
|
||||
):
|
||||
self.dependency = dependency
|
||||
self.use_cache = use_cache
|
||||
|
|
@ -329,7 +329,7 @@ class Depends:
|
|||
class Security(Depends):
|
||||
def __init__(
|
||||
self,
|
||||
dependency: Optional[Callable] = None,
|
||||
dependency: Optional[Callable[..., Any]] = None,
|
||||
*,
|
||||
scopes: Optional[Sequence[str]] = None,
|
||||
use_cache: bool = True,
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
from typing import Any
|
||||
|
||||
from starlette.responses import FileResponse # noqa
|
||||
from starlette.responses import HTMLResponse # noqa
|
||||
from starlette.responses import JSONResponse # noqa
|
||||
from starlette.responses import PlainTextResponse # noqa
|
||||
from starlette.responses import RedirectResponse # noqa
|
||||
from starlette.responses import Response # noqa
|
||||
from starlette.responses import StreamingResponse # noqa
|
||||
from starlette.responses import UJSONResponse # noqa
|
||||
from starlette.responses import FileResponse as FileResponse # noqa
|
||||
from starlette.responses import HTMLResponse as HTMLResponse # noqa
|
||||
from starlette.responses import JSONResponse as JSONResponse # noqa
|
||||
from starlette.responses import PlainTextResponse as PlainTextResponse # noqa
|
||||
from starlette.responses import RedirectResponse as RedirectResponse # noqa
|
||||
from starlette.responses import Response as Response # noqa
|
||||
from starlette.responses import StreamingResponse as StreamingResponse # noqa
|
||||
from starlette.responses import UJSONResponse as UJSONResponse # noqa
|
||||
|
||||
try:
|
||||
import orjson
|
||||
|
|
|
|||
|
|
@ -2,7 +2,18 @@ import asyncio
|
|||
import enum
|
||||
import inspect
|
||||
import json
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Type, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from fastapi import params
|
||||
from fastapi.datastructures import Default, DefaultPlaceholder
|
||||
|
|
@ -16,6 +27,7 @@ from fastapi.dependencies.utils import (
|
|||
from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder
|
||||
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
|
||||
from fastapi.openapi.constants import STATUS_CODES_WITH_NO_BODY
|
||||
from fastapi.types import DecoratedCallable
|
||||
from fastapi.utils import (
|
||||
create_cloned_field,
|
||||
create_response_field,
|
||||
|
|
@ -30,7 +42,8 @@ from starlette.concurrency import run_in_threadpool
|
|||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.routing import Mount # noqa
|
||||
from starlette.routing import BaseRoute
|
||||
from starlette.routing import Mount as Mount # noqa
|
||||
from starlette.routing import (
|
||||
compile_path,
|
||||
get_name,
|
||||
|
|
@ -150,7 +163,7 @@ def get_request_handler(
|
|||
response_model_exclude_defaults: bool = False,
|
||||
response_model_exclude_none: bool = False,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
) -> Callable:
|
||||
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||
assert dependant.call is not None, "dependant.call must be a function"
|
||||
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
|
||||
is_body_form = body_field and isinstance(body_field.field_info, params.Form)
|
||||
|
|
@ -207,7 +220,7 @@ def get_request_handler(
|
|||
response = actual_response_class(
|
||||
content=response_data,
|
||||
status_code=status_code,
|
||||
background=background_tasks,
|
||||
background=background_tasks, # type: ignore # in Starlette
|
||||
)
|
||||
response.headers.raw.extend(sub_response.headers.raw)
|
||||
if sub_response.status_code:
|
||||
|
|
@ -219,7 +232,7 @@ def get_request_handler(
|
|||
|
||||
def get_websocket_app(
|
||||
dependant: Dependant, dependency_overrides_provider: Optional[Any] = None
|
||||
) -> Callable:
|
||||
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
|
||||
async def app(websocket: WebSocket) -> None:
|
||||
solved_result = await solve_dependencies(
|
||||
request=websocket,
|
||||
|
|
@ -240,7 +253,7 @@ class APIWebSocketRoute(routing.WebSocketRoute):
|
|||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
endpoint: Callable,
|
||||
endpoint: Callable[..., Any],
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
|
|
@ -262,7 +275,7 @@ class APIRoute(routing.Route):
|
|||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
endpoint: Callable,
|
||||
endpoint: Callable[..., Any],
|
||||
*,
|
||||
response_model: Optional[Type[Any]] = None,
|
||||
status_code: int = 200,
|
||||
|
|
@ -287,7 +300,7 @@ class APIRoute(routing.Route):
|
|||
JSONResponse
|
||||
),
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
callbacks: Optional[List["APIRoute"]] = None,
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> None:
|
||||
# normalise enums e.g. http.HTTPStatus
|
||||
if isinstance(status_code, enum.IntEnum):
|
||||
|
|
@ -298,7 +311,7 @@ class APIRoute(routing.Route):
|
|||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
if methods is None:
|
||||
methods = ["GET"]
|
||||
self.methods = set([method.upper() for method in methods])
|
||||
self.methods: Set[str] = set([method.upper() for method in methods])
|
||||
self.unique_id = generate_operation_id_for_path(
|
||||
name=self.name, path=self.path_format, method=list(methods)[0]
|
||||
)
|
||||
|
|
@ -375,7 +388,7 @@ class APIRoute(routing.Route):
|
|||
self.callbacks = callbacks
|
||||
self.app = request_response(self.get_route_handler())
|
||||
|
||||
def get_route_handler(self) -> Callable:
|
||||
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||
return get_request_handler(
|
||||
dependant=self.dependant,
|
||||
body_field=self.body_field,
|
||||
|
|
@ -401,23 +414,23 @@ class APIRouter(routing.Router):
|
|||
dependencies: Optional[Sequence[params.Depends]] = None,
|
||||
default_response_class: Type[Response] = Default(JSONResponse),
|
||||
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
|
||||
callbacks: Optional[List[APIRoute]] = None,
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
routes: Optional[List[routing.BaseRoute]] = None,
|
||||
redirect_slashes: bool = True,
|
||||
default: Optional[ASGIApp] = None,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
route_class: Type[APIRoute] = APIRoute,
|
||||
on_startup: Optional[Sequence[Callable]] = None,
|
||||
on_shutdown: Optional[Sequence[Callable]] = None,
|
||||
deprecated: bool = None,
|
||||
on_startup: Optional[Sequence[Callable[[], Any]]] = None,
|
||||
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
routes=routes,
|
||||
routes=routes, # type: ignore # in Starlette
|
||||
redirect_slashes=redirect_slashes,
|
||||
default=default,
|
||||
on_startup=on_startup,
|
||||
on_shutdown=on_shutdown,
|
||||
default=default, # type: ignore # in Starlette
|
||||
on_startup=on_startup, # type: ignore # in Starlette
|
||||
on_shutdown=on_shutdown, # type: ignore # in Starlette
|
||||
)
|
||||
if prefix:
|
||||
assert prefix.startswith("/"), "A path prefix must start with '/'"
|
||||
|
|
@ -438,7 +451,7 @@ class APIRouter(routing.Router):
|
|||
def add_api_route(
|
||||
self,
|
||||
path: str,
|
||||
endpoint: Callable,
|
||||
endpoint: Callable[..., Any],
|
||||
*,
|
||||
response_model: Optional[Type[Any]] = None,
|
||||
status_code: int = 200,
|
||||
|
|
@ -463,7 +476,7 @@ class APIRouter(routing.Router):
|
|||
),
|
||||
name: Optional[str] = None,
|
||||
route_class_override: Optional[Type[APIRoute]] = None,
|
||||
callbacks: Optional[List[APIRoute]] = None,
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> None:
|
||||
route_class = route_class_override or self.route_class
|
||||
responses = responses or {}
|
||||
|
|
@ -532,9 +545,9 @@ class APIRouter(routing.Router):
|
|||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[APIRoute]] = None,
|
||||
) -> Callable:
|
||||
def decorator(func: Callable) -> Callable:
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
def decorator(func: DecoratedCallable) -> DecoratedCallable:
|
||||
self.add_api_route(
|
||||
path,
|
||||
func,
|
||||
|
|
@ -565,7 +578,7 @@ class APIRouter(routing.Router):
|
|||
return decorator
|
||||
|
||||
def add_api_websocket_route(
|
||||
self, path: str, endpoint: Callable, name: Optional[str] = None
|
||||
self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
|
||||
) -> None:
|
||||
route = APIWebSocketRoute(
|
||||
path,
|
||||
|
|
@ -575,8 +588,10 @@ class APIRouter(routing.Router):
|
|||
)
|
||||
self.routes.append(route)
|
||||
|
||||
def websocket(self, path: str, name: Optional[str] = None) -> Callable:
|
||||
def decorator(func: Callable) -> Callable:
|
||||
def websocket(
|
||||
self, path: str, name: Optional[str] = None
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
def decorator(func: DecoratedCallable) -> DecoratedCallable:
|
||||
self.add_api_websocket_route(path, func, name=name)
|
||||
return func
|
||||
|
||||
|
|
@ -591,8 +606,8 @@ class APIRouter(routing.Router):
|
|||
dependencies: Optional[Sequence[params.Depends]] = None,
|
||||
default_response_class: Type[Response] = Default(JSONResponse),
|
||||
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
|
||||
callbacks: Optional[List[APIRoute]] = None,
|
||||
deprecated: bool = None,
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
deprecated: Optional[bool] = None,
|
||||
include_in_schema: bool = True,
|
||||
) -> None:
|
||||
if prefix:
|
||||
|
|
@ -663,10 +678,11 @@ class APIRouter(routing.Router):
|
|||
callbacks=current_callbacks,
|
||||
)
|
||||
elif isinstance(route, routing.Route):
|
||||
methods = list(route.methods or []) # type: ignore # in Starlette
|
||||
self.add_route(
|
||||
prefix + route.path,
|
||||
route.endpoint,
|
||||
methods=list(route.methods or []),
|
||||
methods=methods,
|
||||
include_in_schema=route.include_in_schema,
|
||||
name=route.name,
|
||||
)
|
||||
|
|
@ -706,8 +722,8 @@ class APIRouter(routing.Router):
|
|||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[APIRoute]] = None,
|
||||
) -> Callable:
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.api_route(
|
||||
path=path,
|
||||
response_model=response_model,
|
||||
|
|
@ -756,8 +772,8 @@ class APIRouter(routing.Router):
|
|||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[APIRoute]] = None,
|
||||
) -> Callable:
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.api_route(
|
||||
path=path,
|
||||
response_model=response_model,
|
||||
|
|
@ -806,8 +822,8 @@ class APIRouter(routing.Router):
|
|||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[APIRoute]] = None,
|
||||
) -> Callable:
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.api_route(
|
||||
path=path,
|
||||
response_model=response_model,
|
||||
|
|
@ -856,8 +872,8 @@ class APIRouter(routing.Router):
|
|||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[APIRoute]] = None,
|
||||
) -> Callable:
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.api_route(
|
||||
path=path,
|
||||
response_model=response_model,
|
||||
|
|
@ -906,8 +922,8 @@ class APIRouter(routing.Router):
|
|||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[APIRoute]] = None,
|
||||
) -> Callable:
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.api_route(
|
||||
path=path,
|
||||
response_model=response_model,
|
||||
|
|
@ -956,8 +972,8 @@ class APIRouter(routing.Router):
|
|||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[APIRoute]] = None,
|
||||
) -> Callable:
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.api_route(
|
||||
path=path,
|
||||
response_model=response_model,
|
||||
|
|
@ -1006,8 +1022,8 @@ class APIRouter(routing.Router):
|
|||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[APIRoute]] = None,
|
||||
) -> Callable:
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
return self.api_route(
|
||||
path=path,
|
||||
response_model=response_model,
|
||||
|
|
@ -1056,8 +1072,8 @@ class APIRouter(routing.Router):
|
|||
include_in_schema: bool = True,
|
||||
response_class: Type[Response] = Default(JSONResponse),
|
||||
name: Optional[str] = None,
|
||||
callbacks: Optional[List[APIRoute]] = None,
|
||||
) -> Callable:
|
||||
callbacks: Optional[List[BaseRoute]] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
|
||||
return self.api_route(
|
||||
path=path,
|
||||
|
|
|
|||
|
|
@ -1,17 +1,15 @@
|
|||
from .api_key import APIKeyCookie, APIKeyHeader, APIKeyQuery
|
||||
from .http import (
|
||||
HTTPAuthorizationCredentials,
|
||||
HTTPBasic,
|
||||
HTTPBasicCredentials,
|
||||
HTTPBearer,
|
||||
HTTPDigest,
|
||||
)
|
||||
from .oauth2 import (
|
||||
OAuth2,
|
||||
OAuth2AuthorizationCodeBearer,
|
||||
OAuth2PasswordBearer,
|
||||
OAuth2PasswordRequestForm,
|
||||
OAuth2PasswordRequestFormStrict,
|
||||
SecurityScopes,
|
||||
)
|
||||
from .open_id_connect_url import OpenIdConnect
|
||||
from .api_key import APIKeyCookie as APIKeyCookie
|
||||
from .api_key import APIKeyHeader as APIKeyHeader
|
||||
from .api_key import APIKeyQuery as APIKeyQuery
|
||||
from .http import HTTPAuthorizationCredentials as HTTPAuthorizationCredentials
|
||||
from .http import HTTPBasic as HTTPBasic
|
||||
from .http import HTTPBasicCredentials as HTTPBasicCredentials
|
||||
from .http import HTTPBearer as HTTPBearer
|
||||
from .http import HTTPDigest as HTTPDigest
|
||||
from .oauth2 import OAuth2 as OAuth2
|
||||
from .oauth2 import OAuth2AuthorizationCodeBearer as OAuth2AuthorizationCodeBearer
|
||||
from .oauth2 import OAuth2PasswordBearer as OAuth2PasswordBearer
|
||||
from .oauth2 import OAuth2PasswordRequestForm as OAuth2PasswordRequestForm
|
||||
from .oauth2 import OAuth2PasswordRequestFormStrict as OAuth2PasswordRequestFormStrict
|
||||
from .oauth2 import SecurityScopes as SecurityScopes
|
||||
from .open_id_connect_url import OpenIdConnect as OpenIdConnect
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.openapi.models import OAuth2 as OAuth2Model
|
||||
|
|
@ -116,7 +116,7 @@ class OAuth2(SecurityBase):
|
|||
def __init__(
|
||||
self,
|
||||
*,
|
||||
flows: OAuthFlowsModel = OAuthFlowsModel(),
|
||||
flows: Union[OAuthFlowsModel, Dict[str, Dict[str, Any]]] = OAuthFlowsModel(),
|
||||
scheme_name: Optional[str] = None,
|
||||
auto_error: Optional[bool] = True
|
||||
):
|
||||
|
|
@ -141,7 +141,7 @@ class OAuth2PasswordBearer(OAuth2):
|
|||
self,
|
||||
tokenUrl: str,
|
||||
scheme_name: Optional[str] = None,
|
||||
scopes: Optional[dict] = None,
|
||||
scopes: Optional[Dict[str, str]] = None,
|
||||
auto_error: bool = True,
|
||||
):
|
||||
if not scopes:
|
||||
|
|
@ -171,7 +171,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
|
|||
tokenUrl: str,
|
||||
refreshUrl: Optional[str] = None,
|
||||
scheme_name: Optional[str] = None,
|
||||
scopes: Optional[dict] = None,
|
||||
scopes: Optional[Dict[str, str]] = None,
|
||||
auto_error: bool = True,
|
||||
):
|
||||
if not scopes:
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from starlette.staticfiles import StaticFiles # noqa
|
||||
from starlette.staticfiles import StaticFiles as StaticFiles # noqa
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from starlette.templating import Jinja2Templates # noqa
|
||||
from starlette.templating import Jinja2Templates as Jinja2Templates # noqa
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from starlette.testclient import TestClient # noqa
|
||||
from starlette.testclient import TestClient as TestClient # noqa
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
from typing import Any, Callable, TypeVar
|
||||
|
||||
DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any])
|
||||
|
|
@ -19,11 +19,10 @@ def get_model_definitions(
|
|||
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
|
||||
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
|
||||
) -> Dict[str, Any]:
|
||||
definitions: Dict[str, Dict] = {}
|
||||
definitions: Dict[str, Dict[str, Any]] = {}
|
||||
for model in flat_models:
|
||||
# ignore mypy error until enum schemas are released
|
||||
m_schema, m_definitions, m_nested_models = model_process_schema(
|
||||
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
|
||||
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
||||
)
|
||||
definitions.update(m_definitions)
|
||||
model_name = model_name_map[model]
|
||||
|
|
@ -80,7 +79,7 @@ def create_cloned_field(
|
|||
cloned_types = dict()
|
||||
original_type = field.type_
|
||||
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
|
||||
original_type = original_type.__pydantic_model__ # type: ignore
|
||||
original_type = original_type.__pydantic_model__
|
||||
use_type = original_type
|
||||
if lenient_issubclass(original_type, BaseModel):
|
||||
original_type = cast(Type[BaseModel], original_type)
|
||||
|
|
@ -127,7 +126,7 @@ def generate_operation_id_for_path(*, name: str, path: str, method: str) -> str:
|
|||
return operation_id
|
||||
|
||||
|
||||
def deep_dict_update(main_dict: dict, update_dict: dict) -> None:
|
||||
def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None:
|
||||
for key in update_dict:
|
||||
if (
|
||||
key in main_dict
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
from starlette.websockets import WebSocket # noqa
|
||||
from starlette.websockets import WebSocketDisconnect # noqa
|
||||
from starlette.websockets import WebSocket as WebSocket # noqa
|
||||
from starlette.websockets import WebSocketDisconnect as WebSocketDisconnect # noqa
|
||||
|
|
|
|||
22
mypy.ini
22
mypy.ini
|
|
@ -1,3 +1,25 @@
|
|||
[mypy]
|
||||
|
||||
# --strict
|
||||
disallow_any_generics = True
|
||||
disallow_subclassing_any = True
|
||||
disallow_untyped_calls = True
|
||||
disallow_untyped_defs = True
|
||||
disallow_incomplete_defs = True
|
||||
check_untyped_defs = True
|
||||
disallow_untyped_decorators = True
|
||||
no_implicit_optional = True
|
||||
warn_redundant_casts = True
|
||||
warn_unused_ignores = True
|
||||
warn_return_any = True
|
||||
implicit_reexport = False
|
||||
strict_equality = True
|
||||
# --strict end
|
||||
|
||||
[mypy-fastapi.concurrency]
|
||||
warn_unused_ignores = False
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-fastapi.tests.*]
|
||||
ignore_missing_imports = True
|
||||
check_untyped_defs = True
|
||||
|
|
|
|||
|
|
@ -46,9 +46,9 @@ test = [
|
|||
"pytest ==5.4.3",
|
||||
"pytest-cov ==2.10.0",
|
||||
"pytest-asyncio >=0.14.0,<0.15.0",
|
||||
"mypy ==0.782",
|
||||
"mypy ==0.790",
|
||||
"flake8 >=3.8.3,<4.0.0",
|
||||
"black ==19.10b0",
|
||||
"black ==20.8b1",
|
||||
"isort >=5.0.6,<6.0.0",
|
||||
"requests >=2.24.0,<3.0.0",
|
||||
"httpx >=0.14.0,<0.15.0",
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import pytest
|
|||
from fastapi import APIRouter, FastAPI
|
||||
from fastapi.routing import APIRoute
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.routing import Route
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
|
@ -106,9 +107,9 @@ def test_get_path(path, expected_status, expected_response):
|
|||
|
||||
def test_route_classes():
|
||||
routes = {}
|
||||
r: APIRoute
|
||||
for r in app.router.routes:
|
||||
assert isinstance(r, Route)
|
||||
routes[r.path] = r
|
||||
assert routes["/a/"].x_type == "A"
|
||||
assert routes["/a/b/"].x_type == "B"
|
||||
assert routes["/a/b/c/"].x_type == "C"
|
||||
assert getattr(routes["/a/"], "x_type") == "A"
|
||||
assert getattr(routes["/a/b/"], "x_type") == "B"
|
||||
assert getattr(routes["/a/b/c/"], "x_type") == "C"
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ app = FastAPI()
|
|||
|
||||
class Product(BaseModel):
|
||||
name: str
|
||||
description: str = None
|
||||
description: str = None # type: ignore
|
||||
price: float
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -175,7 +175,7 @@ async def path3_override_router2_override(level3: str):
|
|||
return level3
|
||||
|
||||
|
||||
@router2_override.get("/default3",)
|
||||
@router2_override.get("/default3")
|
||||
async def path3_default_router2_override(level3: str):
|
||||
return level3
|
||||
|
||||
|
|
@ -217,7 +217,9 @@ async def path5_override_router4_override(level5: str):
|
|||
return level5
|
||||
|
||||
|
||||
@router4_override.get("/default5",)
|
||||
@router4_override.get(
|
||||
"/default5",
|
||||
)
|
||||
async def path5_default_router4_override(level5: str):
|
||||
return level5
|
||||
|
||||
|
|
@ -238,7 +240,9 @@ async def path5_override_router4_default(level5: str):
|
|||
return level5
|
||||
|
||||
|
||||
@router4_default.get("/default5",)
|
||||
@router4_default.get(
|
||||
"/default5",
|
||||
)
|
||||
async def path5_default_router4_default(level5: str):
|
||||
return level5
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ class MyUuid:
|
|||
def __str__(self):
|
||||
return self.uuid
|
||||
|
||||
@property
|
||||
@property # type: ignore
|
||||
def __class__(self):
|
||||
return uuid.UUID
|
||||
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ class ModelWithAlias(BaseModel):
|
|||
|
||||
|
||||
class ModelWithDefault(BaseModel):
|
||||
foo: str = ...
|
||||
foo: str = ... # type: ignore
|
||||
bar: str = "bar"
|
||||
bla: str = "bla"
|
||||
|
||||
|
|
@ -88,7 +88,7 @@ def fixture_model_with_path(request):
|
|||
arbitrary_types_allowed = True
|
||||
|
||||
ModelWithPath = create_model(
|
||||
"ModelWithPath", path=(request.param, ...), __config__=Config
|
||||
"ModelWithPath", path=(request.param, ...), __config__=Config # type: ignore
|
||||
)
|
||||
return ModelWithPath(path=request.param("/foo", "bar"))
|
||||
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@ from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
|||
|
||||
def test_strings_in_generated_swagger():
|
||||
sig = inspect.signature(get_swagger_ui_html)
|
||||
swagger_js_url = sig.parameters.get("swagger_js_url").default
|
||||
swagger_css_url = sig.parameters.get("swagger_css_url").default
|
||||
swagger_favicon_url = sig.parameters.get("swagger_favicon_url").default
|
||||
swagger_js_url = sig.parameters.get("swagger_js_url").default # type: ignore
|
||||
swagger_css_url = sig.parameters.get("swagger_css_url").default # type: ignore
|
||||
swagger_favicon_url = sig.parameters.get("swagger_favicon_url").default # type: ignore
|
||||
html = get_swagger_ui_html(openapi_url="/docs", title="title")
|
||||
body_content = html.body.decode()
|
||||
assert swagger_js_url in body_content
|
||||
|
|
@ -34,8 +34,8 @@ def test_strings_in_custom_swagger():
|
|||
|
||||
def test_strings_in_generated_redoc():
|
||||
sig = inspect.signature(get_redoc_html)
|
||||
redoc_js_url = sig.parameters.get("redoc_js_url").default
|
||||
redoc_favicon_url = sig.parameters.get("redoc_favicon_url").default
|
||||
redoc_js_url = sig.parameters.get("redoc_js_url").default # type: ignore
|
||||
redoc_favicon_url = sig.parameters.get("redoc_favicon_url").default # type: ignore
|
||||
html = get_redoc_html(openapi_url="/docs", title="title")
|
||||
body_content = html.body.decode()
|
||||
assert redoc_js_url in body_content
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ app = FastAPI()
|
|||
|
||||
class Item(BaseModel):
|
||||
name: str
|
||||
age: condecimal(gt=Decimal(0.0))
|
||||
age: condecimal(gt=Decimal(0.0)) # type: ignore
|
||||
|
||||
|
||||
@app.post("/items/")
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ app = FastAPI()
|
|||
|
||||
|
||||
@app.get("/items/")
|
||||
def read_items(q: Optional[str] = Param(None)):
|
||||
def read_items(q: Optional[str] = Param(None)): # type: ignore
|
||||
return {"q": q}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from typing import Any, List
|
||||
|
||||
import pytest
|
||||
from fastapi.params import Body, Cookie, Depends, Header, Param, Path, Query
|
||||
|
||||
test_data = ["teststr", None, ..., 1, []]
|
||||
test_data: List[Any] = ["teststr", None, ..., 1, []]
|
||||
|
||||
|
||||
def get_user():
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ def test_route_converters_int():
|
|||
response = client.get("/int/5")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"int": 5}
|
||||
assert app.url_path_for("int_convertor", param=5) == "/int/5"
|
||||
assert app.url_path_for("int_convertor", param=5) == "/int/5" # type: ignore
|
||||
|
||||
|
||||
def test_route_converters_float():
|
||||
|
|
@ -35,7 +35,7 @@ def test_route_converters_float():
|
|||
response = client.get("/float/25.5")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"float": 25.5}
|
||||
assert app.url_path_for("float_convertor", param=25.5) == "/float/25.5"
|
||||
assert app.url_path_for("float_convertor", param=25.5) == "/float/25.5" # type: ignore
|
||||
|
||||
|
||||
def test_route_converters_path():
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ invoices_callback_router = APIRouter()
|
|||
|
||||
|
||||
@invoices_callback_router.post(
|
||||
"{$callback_url}/invoices/{$request.body.id}", response_model=InvoiceEventReceived,
|
||||
"{$callback_url}/invoices/{$request.body.id}", response_model=InvoiceEventReceived
|
||||
)
|
||||
def invoice_notification(body: InvoiceEvent):
|
||||
pass # pragma: nocover
|
||||
|
|
|
|||
Loading…
Reference in New Issue