mirror of https://github.com/tiangolo/fastapi.git
Add middleware parameter
This commit is contained in:
parent
d7c588d693
commit
c5cfae8791
|
|
@ -56,6 +56,7 @@ from pydantic import BaseModel
|
|||
from starlette import routing
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.routing import (
|
||||
|
|
@ -369,6 +370,7 @@ class APIWebSocketRoute(routing.WebSocketRoute):
|
|||
endpoint: Callable[..., Any],
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
middleware: Optional[Sequence[Middleware]] = None,
|
||||
dependencies: Optional[Sequence[params.Depends]] = None,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
) -> None:
|
||||
|
|
@ -390,6 +392,9 @@ class APIWebSocketRoute(routing.WebSocketRoute):
|
|||
dependency_overrides_provider=dependency_overrides_provider,
|
||||
)
|
||||
)
|
||||
if middleware is not None:
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
self.app = cls(app=self.app, *args, **kwargs)
|
||||
|
||||
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
|
||||
match, child_scope = super().matches(scope)
|
||||
|
|
@ -432,6 +437,7 @@ class APIRoute(routing.Route):
|
|||
generate_unique_id_function: Union[
|
||||
Callable[["APIRoute"], str], DefaultPlaceholder
|
||||
] = Default(generate_unique_id),
|
||||
middleware: Optional[Sequence[Middleware]] = None,
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.endpoint = endpoint
|
||||
|
|
@ -462,6 +468,9 @@ class APIRoute(routing.Route):
|
|||
self.responses = responses or {}
|
||||
self.name = get_name(endpoint) if name is None else name
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
if middleware is not None:
|
||||
for cls, args, kwargs in reversed(middleware):
|
||||
self.app = cls(app=self.app, *args, **kwargs)
|
||||
if methods is None:
|
||||
methods = ["GET"]
|
||||
self.methods: Set[str] = {method.upper() for method in methods}
|
||||
|
|
@ -795,6 +804,7 @@ class APIRouter(routing.Router):
|
|||
"""
|
||||
),
|
||||
] = Default(generate_unique_id),
|
||||
middleware: Optional[Sequence[Middleware]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
routes=routes,
|
||||
|
|
@ -803,6 +813,7 @@ class APIRouter(routing.Router):
|
|||
on_startup=on_startup,
|
||||
on_shutdown=on_shutdown,
|
||||
lifespan=lifespan,
|
||||
middleware=middleware,
|
||||
)
|
||||
if prefix:
|
||||
assert prefix.startswith("/"), "A path prefix must start with '/'"
|
||||
|
|
@ -873,6 +884,7 @@ class APIRouter(routing.Router):
|
|||
generate_unique_id_function: Union[
|
||||
Callable[[APIRoute], str], DefaultPlaceholder
|
||||
] = Default(generate_unique_id),
|
||||
middleware: Optional[Sequence[Middleware]] = None,
|
||||
) -> None:
|
||||
route_class = route_class_override or self.route_class
|
||||
responses = responses or {}
|
||||
|
|
@ -919,6 +931,7 @@ class APIRouter(routing.Router):
|
|||
callbacks=current_callbacks,
|
||||
openapi_extra=openapi_extra,
|
||||
generate_unique_id_function=current_generate_unique_id,
|
||||
middleware=middleware,
|
||||
)
|
||||
self.routes.append(route)
|
||||
|
||||
|
|
@ -951,6 +964,7 @@ class APIRouter(routing.Router):
|
|||
generate_unique_id_function: Callable[[APIRoute], str] = Default(
|
||||
generate_unique_id
|
||||
),
|
||||
middleware: Optional[Sequence[Middleware]] = None,
|
||||
) -> Callable[[DecoratedCallable], DecoratedCallable]:
|
||||
def decorator(func: DecoratedCallable) -> DecoratedCallable:
|
||||
self.add_api_route(
|
||||
|
|
@ -979,6 +993,7 @@ class APIRouter(routing.Router):
|
|||
callbacks=callbacks,
|
||||
openapi_extra=openapi_extra,
|
||||
generate_unique_id_function=generate_unique_id_function,
|
||||
middleware=middleware,
|
||||
)
|
||||
return func
|
||||
|
||||
|
|
@ -990,6 +1005,7 @@ class APIRouter(routing.Router):
|
|||
endpoint: Callable[..., Any],
|
||||
name: Optional[str] = None,
|
||||
*,
|
||||
middleware: Optional[Sequence[Middleware]] = None,
|
||||
dependencies: Optional[Sequence[params.Depends]] = None,
|
||||
) -> None:
|
||||
current_dependencies = self.dependencies.copy()
|
||||
|
|
@ -1000,6 +1016,7 @@ class APIRouter(routing.Router):
|
|||
self.prefix + path,
|
||||
endpoint=endpoint,
|
||||
name=name,
|
||||
middleware=middleware,
|
||||
dependencies=current_dependencies,
|
||||
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||
)
|
||||
|
|
@ -1024,6 +1041,14 @@ class APIRouter(routing.Router):
|
|||
),
|
||||
] = None,
|
||||
*,
|
||||
middleware: Annotated[
|
||||
Optional[Sequence[Middleware]],
|
||||
Doc(
|
||||
"""
|
||||
A list of middleware to apply to the WebSocket.
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
dependencies: Annotated[
|
||||
Optional[Sequence[params.Depends]],
|
||||
Doc(
|
||||
|
|
@ -1066,7 +1091,7 @@ class APIRouter(routing.Router):
|
|||
|
||||
def decorator(func: DecoratedCallable) -> DecoratedCallable:
|
||||
self.add_api_websocket_route(
|
||||
path, func, name=name, dependencies=dependencies
|
||||
path, func, name=name, middleware=middleware, dependencies=dependencies
|
||||
)
|
||||
return func
|
||||
|
||||
|
|
@ -1192,6 +1217,15 @@ class APIRouter(routing.Router):
|
|||
"""
|
||||
),
|
||||
] = Default(generate_unique_id),
|
||||
middleware: Annotated[
|
||||
Optional[Sequence[Middleware]],
|
||||
Doc(
|
||||
"""
|
||||
A list of middleware to apply to all the *path operations* in this
|
||||
router.
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Include another `APIRouter` in the same current `APIRouter`.
|
||||
|
|
@ -1290,6 +1324,7 @@ class APIRouter(routing.Router):
|
|||
callbacks=current_callbacks,
|
||||
openapi_extra=route.openapi_extra,
|
||||
generate_unique_id_function=current_generate_unique_id,
|
||||
middleware=middleware,
|
||||
)
|
||||
elif isinstance(route, routing.Route):
|
||||
methods = list(route.methods or [])
|
||||
|
|
|
|||
Loading…
Reference in New Issue