mirror of https://github.com/tiangolo/fastapi.git
📌 Upgrade Starlette version (#1057)
This commit is contained in:
parent
fbbed6fe81
commit
4e8080f290
|
|
@ -18,8 +18,8 @@ from fastapi.params import Depends
|
||||||
from fastapi.utils import warning_response_model_skip_defaults_deprecated
|
from fastapi.utils import warning_response_model_skip_defaults_deprecated
|
||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
from starlette.datastructures import State
|
from starlette.datastructures import State
|
||||||
from starlette.exceptions import ExceptionMiddleware, HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
from starlette.middleware.errors import ServerErrorMiddleware
|
from starlette.middleware import Middleware
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import HTMLResponse, JSONResponse, Response
|
from starlette.responses import HTMLResponse, JSONResponse, Response
|
||||||
from starlette.routing import BaseRoute
|
from starlette.routing import BaseRoute
|
||||||
|
|
@ -29,9 +29,9 @@ from starlette.types import Receive, Scope, Send
|
||||||
class FastAPI(Starlette):
|
class FastAPI(Starlette):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
*,
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
routes: List[BaseRoute] = None,
|
routes: List[BaseRoute] = None,
|
||||||
template_directory: str = None,
|
|
||||||
title: str = "FastAPI",
|
title: str = "FastAPI",
|
||||||
description: str = "",
|
description: str = "",
|
||||||
version: str = "0.1.0",
|
version: str = "0.1.0",
|
||||||
|
|
@ -42,19 +42,28 @@ class FastAPI(Starlette):
|
||||||
redoc_url: Optional[str] = "/redoc",
|
redoc_url: Optional[str] = "/redoc",
|
||||||
swagger_ui_oauth2_redirect_url: Optional[str] = "/docs/oauth2-redirect",
|
swagger_ui_oauth2_redirect_url: Optional[str] = "/docs/oauth2-redirect",
|
||||||
swagger_ui_init_oauth: Optional[dict] = None,
|
swagger_ui_init_oauth: Optional[dict] = None,
|
||||||
|
middleware: Sequence[Middleware] = None,
|
||||||
|
exception_handlers: Dict[Union[int, Type[Exception]], Callable] = None,
|
||||||
|
on_startup: Sequence[Callable] = None,
|
||||||
|
on_shutdown: Sequence[Callable] = None,
|
||||||
**extra: Dict[str, Any],
|
**extra: Dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.default_response_class = default_response_class
|
self.default_response_class = default_response_class
|
||||||
self._debug = debug
|
self._debug = debug
|
||||||
self.state = State()
|
self.state = State()
|
||||||
self.router: routing.APIRouter = routing.APIRouter(
|
self.router: routing.APIRouter = routing.APIRouter(
|
||||||
routes, dependency_overrides_provider=self
|
routes,
|
||||||
|
dependency_overrides_provider=self,
|
||||||
|
on_startup=on_startup,
|
||||||
|
on_shutdown=on_shutdown,
|
||||||
)
|
)
|
||||||
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
|
self.exception_handlers = (
|
||||||
self.error_middleware = ServerErrorMiddleware(
|
{} if exception_handlers is None else dict(exception_handlers)
|
||||||
self.exception_middleware, debug=debug
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.user_middleware = [] if middleware is None else list(middleware)
|
||||||
|
self.middleware_stack = self.build_middleware_stack()
|
||||||
|
|
||||||
self.title = title
|
self.title = title
|
||||||
self.description = description
|
self.description = description
|
||||||
self.version = version
|
self.version = version
|
||||||
|
|
|
||||||
|
|
@ -346,9 +346,15 @@ class APIRouter(routing.Router):
|
||||||
dependency_overrides_provider: Any = None,
|
dependency_overrides_provider: Any = None,
|
||||||
route_class: Type[APIRoute] = APIRoute,
|
route_class: Type[APIRoute] = APIRoute,
|
||||||
default_response_class: Type[Response] = None,
|
default_response_class: Type[Response] = None,
|
||||||
|
on_startup: Sequence[Callable] = None,
|
||||||
|
on_shutdown: Sequence[Callable] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
routes=routes, redirect_slashes=redirect_slashes, default=default
|
routes=routes,
|
||||||
|
redirect_slashes=redirect_slashes,
|
||||||
|
default=default,
|
||||||
|
on_startup=on_startup,
|
||||||
|
on_shutdown=on_shutdown,
|
||||||
)
|
)
|
||||||
self.dependency_overrides_provider = dependency_overrides_provider
|
self.dependency_overrides_provider = dependency_overrides_provider
|
||||||
self.route_class = route_class
|
self.route_class = route_class
|
||||||
|
|
@ -552,6 +558,10 @@ class APIRouter(routing.Router):
|
||||||
self.add_websocket_route(
|
self.add_websocket_route(
|
||||||
prefix + route.path, route.endpoint, name=route.name
|
prefix + route.path, route.endpoint, name=route.name
|
||||||
)
|
)
|
||||||
|
for handler in router.on_startup:
|
||||||
|
self.add_event_handler("startup", handler)
|
||||||
|
for handler in router.on_shutdown:
|
||||||
|
self.add_event_handler("shutdown", handler)
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ classifiers = [
|
||||||
"Topic :: Internet :: WWW/HTTP",
|
"Topic :: Internet :: WWW/HTTP",
|
||||||
]
|
]
|
||||||
requires = [
|
requires = [
|
||||||
"starlette >=0.12.9,<=0.12.9",
|
"starlette ==0.13.2",
|
||||||
"pydantic >=0.32.2,<2.0.0"
|
"pydantic >=0.32.2,<2.0.0"
|
||||||
]
|
]
|
||||||
description-file = "README.md"
|
description-file = "README.md"
|
||||||
|
|
|
||||||
|
|
@ -21,10 +21,12 @@ client = TestClient(app)
|
||||||
def test_use_empty():
|
def test_use_empty():
|
||||||
with client:
|
with client:
|
||||||
response = client.get("/prefix")
|
response = client.get("/prefix")
|
||||||
|
assert response.status_code == 200
|
||||||
assert response.json() == ["OK"]
|
assert response.json() == ["OK"]
|
||||||
|
|
||||||
response = client.get("/prefix/")
|
response = client.get("/prefix/")
|
||||||
assert response.status_code == 404
|
assert response.status_code == 200
|
||||||
|
assert response.json() == ["OK"]
|
||||||
|
|
||||||
|
|
||||||
def test_include_empty():
|
def test_include_empty():
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,87 @@
|
||||||
|
from fastapi import APIRouter, FastAPI
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
|
class State(BaseModel):
|
||||||
|
app_startup: bool = False
|
||||||
|
app_shutdown: bool = False
|
||||||
|
router_startup: bool = False
|
||||||
|
router_shutdown: bool = False
|
||||||
|
sub_router_startup: bool = False
|
||||||
|
sub_router_shutdown: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
state = State()
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
def app_startup():
|
||||||
|
state.app_startup = True
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
def app_shutdown():
|
||||||
|
state.app_shutdown = True
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.on_event("startup")
|
||||||
|
def router_startup():
|
||||||
|
state.router_startup = True
|
||||||
|
|
||||||
|
|
||||||
|
@router.on_event("shutdown")
|
||||||
|
def router_shutdown():
|
||||||
|
state.router_shutdown = True
|
||||||
|
|
||||||
|
|
||||||
|
sub_router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@sub_router.on_event("startup")
|
||||||
|
def sub_router_startup():
|
||||||
|
state.sub_router_startup = True
|
||||||
|
|
||||||
|
|
||||||
|
@sub_router.on_event("shutdown")
|
||||||
|
def sub_router_shutdown():
|
||||||
|
state.sub_router_shutdown = True
|
||||||
|
|
||||||
|
|
||||||
|
@sub_router.get("/")
|
||||||
|
def main():
|
||||||
|
return {"message": "Hello World"}
|
||||||
|
|
||||||
|
|
||||||
|
router.include_router(sub_router)
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
|
||||||
|
def test_router_events():
|
||||||
|
assert state.app_startup is False
|
||||||
|
assert state.router_startup is False
|
||||||
|
assert state.sub_router_startup is False
|
||||||
|
assert state.app_shutdown is False
|
||||||
|
assert state.router_shutdown is False
|
||||||
|
assert state.sub_router_shutdown is False
|
||||||
|
with TestClient(app) as client:
|
||||||
|
assert state.app_startup is True
|
||||||
|
assert state.router_startup is True
|
||||||
|
assert state.sub_router_startup is True
|
||||||
|
assert state.app_shutdown is False
|
||||||
|
assert state.router_shutdown is False
|
||||||
|
assert state.sub_router_shutdown is False
|
||||||
|
response = client.get("/")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"message": "Hello World"}
|
||||||
|
assert state.app_startup is True
|
||||||
|
assert state.router_startup is True
|
||||||
|
assert state.sub_router_startup is True
|
||||||
|
assert state.app_shutdown is True
|
||||||
|
assert state.router_shutdown is True
|
||||||
|
assert state.sub_router_shutdown is True
|
||||||
Loading…
Reference in New Issue