mirror of https://github.com/tiangolo/fastapi.git
✨ Update internal `AsyncExitStack` to fix context for dependencies with `yield` (#4575)
This commit is contained in:
parent
78b07cb809
commit
9d56a3cb59
|
|
@ -99,7 +99,7 @@ You saw that you can use dependencies with `yield` and have `try` blocks that ca
|
||||||
|
|
||||||
It might be tempting to raise an `HTTPException` or similar in the exit code, after the `yield`. But **it won't work**.
|
It might be tempting to raise an `HTTPException` or similar in the exit code, after the `yield`. But **it won't work**.
|
||||||
|
|
||||||
The exit code in dependencies with `yield` is executed *after* [Exception Handlers](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}. There's nothing catching exceptions thrown by your dependencies in the exit code (after the `yield`).
|
The exit code in dependencies with `yield` is executed *after* the response is sent, so [Exception Handlers](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank} will have already run. There's nothing catching exceptions thrown by your dependencies in the exit code (after the `yield`).
|
||||||
|
|
||||||
So, if you raise an `HTTPException` after the `yield`, the default (or any custom) exception handler that catches `HTTPException`s and returns an HTTP 400 response won't be there to catch that exception anymore.
|
So, if you raise an `HTTPException` after the `yield`, the default (or any custom) exception handler that catches `HTTPException`s and returns an HTTP 400 response won't be there to catch that exception anymore.
|
||||||
|
|
||||||
|
|
@ -138,9 +138,11 @@ participant tasks as Background tasks
|
||||||
end
|
end
|
||||||
dep ->> operation: Run dependency, e.g. DB session
|
dep ->> operation: Run dependency, e.g. DB session
|
||||||
opt raise
|
opt raise
|
||||||
operation -->> handler: Raise HTTPException
|
operation -->> dep: Raise HTTPException
|
||||||
|
dep -->> handler: Auto forward exception
|
||||||
handler -->> client: HTTP error response
|
handler -->> client: HTTP error response
|
||||||
operation -->> dep: Raise other exception
|
operation -->> dep: Raise other exception
|
||||||
|
dep -->> handler: Auto forward exception
|
||||||
end
|
end
|
||||||
operation ->> client: Return response to client
|
operation ->> client: Return response to client
|
||||||
Note over client,operation: Response is already sent, can't change it anymore
|
Note over client,operation: Response is already sent, can't change it anymore
|
||||||
|
|
@ -162,9 +164,9 @@ participant tasks as Background tasks
|
||||||
After one of those responses is sent, no other response can be sent.
|
After one of those responses is sent, no other response can be sent.
|
||||||
|
|
||||||
!!! tip
|
!!! tip
|
||||||
This diagram shows `HTTPException`, but you could also raise any other exception for which you create a [Custom Exception Handler](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}. And that exception would be handled by that custom exception handler instead of the dependency exit code.
|
This diagram shows `HTTPException`, but you could also raise any other exception for which you create a [Custom Exception Handler](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}.
|
||||||
|
|
||||||
But if you raise an exception that is not handled by the exception handlers, it will be handled by the exit code of the dependency.
|
If you raise any exception, it will be passed to the dependencies with yield, including `HTTPException`, and then **again** to the exception handlers. If there's no exception handler for that exception, it will then be handled by the default internal `ServerErrorMiddleware`, returning a 500 HTTP status code, to let the client know that there was an error in the server.
|
||||||
|
|
||||||
## Context Managers
|
## Context Managers
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ from enum import Enum
|
||||||
from typing import Any, Callable, Coroutine, Dict, List, Optional, Sequence, Type, Union
|
from typing import Any, Callable, Coroutine, Dict, List, Optional, Sequence, Type, Union
|
||||||
|
|
||||||
from fastapi import routing
|
from fastapi import routing
|
||||||
from fastapi.concurrency import AsyncExitStack
|
|
||||||
from fastapi.datastructures import Default, DefaultPlaceholder
|
from fastapi.datastructures import Default, DefaultPlaceholder
|
||||||
from fastapi.encoders import DictIntStrAny, SetIntStr
|
from fastapi.encoders import DictIntStrAny, SetIntStr
|
||||||
from fastapi.exception_handlers import (
|
from fastapi.exception_handlers import (
|
||||||
|
|
@ -11,6 +10,7 @@ from fastapi.exception_handlers import (
|
||||||
)
|
)
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.logger import logger
|
from fastapi.logger import logger
|
||||||
|
from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
|
||||||
from fastapi.openapi.docs import (
|
from fastapi.openapi.docs import (
|
||||||
get_redoc_html,
|
get_redoc_html,
|
||||||
get_swagger_ui_html,
|
get_swagger_ui_html,
|
||||||
|
|
@ -21,8 +21,9 @@ from fastapi.params import Depends
|
||||||
from fastapi.types import DecoratedCallable
|
from fastapi.types import DecoratedCallable
|
||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
from starlette.datastructures import State
|
from starlette.datastructures import State
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import ExceptionMiddleware, HTTPException
|
||||||
from starlette.middleware import Middleware
|
from starlette.middleware import Middleware
|
||||||
|
from starlette.middleware.errors import ServerErrorMiddleware
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import HTMLResponse, JSONResponse, Response
|
from starlette.responses import HTMLResponse, JSONResponse, Response
|
||||||
from starlette.routing import BaseRoute
|
from starlette.routing import BaseRoute
|
||||||
|
|
@ -134,6 +135,55 @@ class FastAPI(Starlette):
|
||||||
self.openapi_schema: Optional[Dict[str, Any]] = None
|
self.openapi_schema: Optional[Dict[str, Any]] = None
|
||||||
self.setup()
|
self.setup()
|
||||||
|
|
||||||
|
def build_middleware_stack(self) -> ASGIApp:
|
||||||
|
# Duplicate/override from Starlette to add AsyncExitStackMiddleware
|
||||||
|
# inside of ExceptionMiddleware, inside of custom user middlewares
|
||||||
|
debug = self.debug
|
||||||
|
error_handler = None
|
||||||
|
exception_handlers = {}
|
||||||
|
|
||||||
|
for key, value in self.exception_handlers.items():
|
||||||
|
if key in (500, Exception):
|
||||||
|
error_handler = value
|
||||||
|
else:
|
||||||
|
exception_handlers[key] = value
|
||||||
|
|
||||||
|
middleware = (
|
||||||
|
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
|
||||||
|
+ self.user_middleware
|
||||||
|
+ [
|
||||||
|
Middleware(
|
||||||
|
ExceptionMiddleware, handlers=exception_handlers, debug=debug
|
||||||
|
),
|
||||||
|
# Add FastAPI-specific AsyncExitStackMiddleware for dependencies with
|
||||||
|
# contextvars.
|
||||||
|
# This needs to happen after user middlewares because those create a
|
||||||
|
# new contextvars context copy by using a new AnyIO task group.
|
||||||
|
# The initial part of dependencies with yield is executed in the
|
||||||
|
# FastAPI code, inside all the middlewares, but the teardown part
|
||||||
|
# (after yield) is executed in the AsyncExitStack in this middleware,
|
||||||
|
# if the AsyncExitStack lived outside of the custom middlewares and
|
||||||
|
# contextvars were set in a dependency with yield in that internal
|
||||||
|
# contextvars context, the values would not be available in the
|
||||||
|
# outside context of the AsyncExitStack.
|
||||||
|
# By putting the middleware and the AsyncExitStack here, inside all
|
||||||
|
# user middlewares, the code before and after yield in dependencies
|
||||||
|
# with yield is executed in the same contextvars context, so all values
|
||||||
|
# set in contextvars before yield is still available after yield as
|
||||||
|
# would be expected.
|
||||||
|
# Additionally, by having this AsyncExitStack here, after the
|
||||||
|
# ExceptionMiddleware, now dependencies can catch handled exceptions,
|
||||||
|
# e.g. HTTPException, to customize the teardown code (e.g. DB session
|
||||||
|
# rollback).
|
||||||
|
Middleware(AsyncExitStackMiddleware),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
app = self.router
|
||||||
|
for cls, options in reversed(middleware):
|
||||||
|
app = cls(app=app, **options)
|
||||||
|
return app
|
||||||
|
|
||||||
def openapi(self) -> Dict[str, Any]:
|
def openapi(self) -> Dict[str, Any]:
|
||||||
if not self.openapi_schema:
|
if not self.openapi_schema:
|
||||||
self.openapi_schema = get_openapi(
|
self.openapi_schema = get_openapi(
|
||||||
|
|
@ -206,12 +256,7 @@ class FastAPI(Starlette):
|
||||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
if self.root_path:
|
if self.root_path:
|
||||||
scope["root_path"] = self.root_path
|
scope["root_path"] = self.root_path
|
||||||
if AsyncExitStack:
|
|
||||||
async with AsyncExitStack() as stack:
|
|
||||||
scope["fastapi_astack"] = stack
|
|
||||||
await super().__call__(scope, receive, send)
|
await super().__call__(scope, receive, send)
|
||||||
else:
|
|
||||||
await super().__call__(scope, receive, send) # pragma: no cover
|
|
||||||
|
|
||||||
def add_api_route(
|
def add_api_route(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,28 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi.concurrency import AsyncExitStack
|
||||||
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncExitStackMiddleware:
|
||||||
|
def __init__(self, app: ASGIApp, context_name: str = "fastapi_astack") -> None:
|
||||||
|
self.app = app
|
||||||
|
self.context_name = context_name
|
||||||
|
|
||||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
|
if AsyncExitStack:
|
||||||
|
dependency_exception: Optional[Exception] = None
|
||||||
|
async with AsyncExitStack() as stack:
|
||||||
|
scope[self.context_name] = stack
|
||||||
|
try:
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
except Exception as e:
|
||||||
|
dependency_exception = e
|
||||||
|
raise e
|
||||||
|
if dependency_exception:
|
||||||
|
# This exception was possibly handled by the dependency but it should
|
||||||
|
# still bubble up so that the ServerErrorMiddleware can return a 500
|
||||||
|
# or the ExceptionMiddleware can catch and handle any other exceptions
|
||||||
|
raise dependency_exception
|
||||||
|
else:
|
||||||
|
await self.app(scope, receive, send) # pragma: no cover
|
||||||
|
|
@ -235,7 +235,16 @@ def test_sync_raise_other():
|
||||||
assert "/sync_raise" not in errors
|
assert "/sync_raise" not in errors
|
||||||
|
|
||||||
|
|
||||||
def test_async_raise():
|
def test_async_raise_raises():
|
||||||
|
with pytest.raises(AsyncDependencyError):
|
||||||
|
client.get("/async_raise")
|
||||||
|
assert state["/async_raise"] == "asyncgen raise finalized"
|
||||||
|
assert "/async_raise" in errors
|
||||||
|
errors.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_raise_server_error():
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
response = client.get("/async_raise")
|
response = client.get("/async_raise")
|
||||||
assert response.status_code == 500, response.text
|
assert response.status_code == 500, response.text
|
||||||
assert state["/async_raise"] == "asyncgen raise finalized"
|
assert state["/async_raise"] == "asyncgen raise finalized"
|
||||||
|
|
@ -270,7 +279,16 @@ def test_background_tasks():
|
||||||
assert state["bg"] == "bg set - b: started b - a: started a"
|
assert state["bg"] == "bg set - b: started b - a: started a"
|
||||||
|
|
||||||
|
|
||||||
def test_sync_raise():
|
def test_sync_raise_raises():
|
||||||
|
with pytest.raises(SyncDependencyError):
|
||||||
|
client.get("/sync_raise")
|
||||||
|
assert state["/sync_raise"] == "generator raise finalized"
|
||||||
|
assert "/sync_raise" in errors
|
||||||
|
errors.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_raise_server_error():
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
response = client.get("/sync_raise")
|
response = client.get("/sync_raise")
|
||||||
assert response.status_code == 500, response.text
|
assert response.status_code == 500, response.text
|
||||||
assert state["/sync_raise"] == "generator raise finalized"
|
assert state["/sync_raise"] == "generator raise finalized"
|
||||||
|
|
@ -306,7 +324,16 @@ def test_sync_sync_raise_other():
|
||||||
assert "/sync_raise" not in errors
|
assert "/sync_raise" not in errors
|
||||||
|
|
||||||
|
|
||||||
def test_sync_async_raise():
|
def test_sync_async_raise_raises():
|
||||||
|
with pytest.raises(AsyncDependencyError):
|
||||||
|
client.get("/sync_async_raise")
|
||||||
|
assert state["/async_raise"] == "asyncgen raise finalized"
|
||||||
|
assert "/async_raise" in errors
|
||||||
|
errors.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_async_raise_server_error():
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
response = client.get("/sync_async_raise")
|
response = client.get("/sync_async_raise")
|
||||||
assert response.status_code == 500, response.text
|
assert response.status_code == 500, response.text
|
||||||
assert state["/async_raise"] == "asyncgen raise finalized"
|
assert state["/async_raise"] == "asyncgen raise finalized"
|
||||||
|
|
@ -314,7 +341,16 @@ def test_sync_async_raise():
|
||||||
errors.clear()
|
errors.clear()
|
||||||
|
|
||||||
|
|
||||||
def test_sync_sync_raise():
|
def test_sync_sync_raise_raises():
|
||||||
|
with pytest.raises(SyncDependencyError):
|
||||||
|
client.get("/sync_sync_raise")
|
||||||
|
assert state["/sync_raise"] == "generator raise finalized"
|
||||||
|
assert "/sync_raise" in errors
|
||||||
|
errors.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_sync_raise_server_error():
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
response = client.get("/sync_sync_raise")
|
response = client.get("/sync_sync_raise")
|
||||||
assert response.status_code == 500, response.text
|
assert response.status_code == 500, response.text
|
||||||
assert state["/sync_raise"] == "generator raise finalized"
|
assert state["/sync_raise"] == "generator raise finalized"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,51 @@
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||||
|
|
||||||
|
from fastapi import Depends, FastAPI, Request, Response
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
legacy_request_state_context_var: ContextVar[Optional[Dict[str, Any]]] = ContextVar(
|
||||||
|
"legacy_request_state_context_var", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
async def set_up_request_state_dependency():
|
||||||
|
request_state = {"user": "deadpond"}
|
||||||
|
contextvar_token = legacy_request_state_context_var.set(request_state)
|
||||||
|
yield request_state
|
||||||
|
legacy_request_state_context_var.reset(contextvar_token)
|
||||||
|
|
||||||
|
|
||||||
|
@app.middleware("http")
|
||||||
|
async def custom_middleware(
|
||||||
|
request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||||
|
):
|
||||||
|
response = await call_next(request)
|
||||||
|
response.headers["custom"] = "foo"
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/user", dependencies=[Depends(set_up_request_state_dependency)])
|
||||||
|
def get_user():
|
||||||
|
request_state = legacy_request_state_context_var.get()
|
||||||
|
assert request_state
|
||||||
|
return request_state["user"]
|
||||||
|
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dependency_contextvars():
|
||||||
|
"""
|
||||||
|
Check that custom middlewares don't affect the contextvar context for dependencies.
|
||||||
|
|
||||||
|
The code before yield and the code after yield should be run in the same contextvar
|
||||||
|
context, so that request_state_context_var.reset(contextvar_token).
|
||||||
|
|
||||||
|
If they are run in a different context, that raises an error.
|
||||||
|
"""
|
||||||
|
response = client.get("/user")
|
||||||
|
assert response.json() == "deadpond"
|
||||||
|
assert response.headers["custom"] == "foo"
|
||||||
|
|
@ -0,0 +1,71 @@
|
||||||
|
import pytest
|
||||||
|
from fastapi import Body, Depends, FastAPI, HTTPException
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
initial_fake_database = {"rick": "Rick Sanchez"}
|
||||||
|
|
||||||
|
fake_database = initial_fake_database.copy()
|
||||||
|
|
||||||
|
initial_state = {"except": False, "finally": False}
|
||||||
|
|
||||||
|
state = initial_state.copy()
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_database():
|
||||||
|
temp_database = fake_database.copy()
|
||||||
|
try:
|
||||||
|
yield temp_database
|
||||||
|
fake_database.update(temp_database)
|
||||||
|
except HTTPException:
|
||||||
|
state["except"] = True
|
||||||
|
finally:
|
||||||
|
state["finally"] = True
|
||||||
|
|
||||||
|
|
||||||
|
@app.put("/invalid-user/{user_id}")
|
||||||
|
def put_invalid_user(
|
||||||
|
user_id: str, name: str = Body(...), db: dict = Depends(get_database)
|
||||||
|
):
|
||||||
|
db[user_id] = name
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid user")
|
||||||
|
|
||||||
|
|
||||||
|
@app.put("/user/{user_id}")
|
||||||
|
def put_user(user_id: str, name: str = Body(...), db: dict = Depends(get_database)):
|
||||||
|
db[user_id] = name
|
||||||
|
return {"message": "OK"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reset_state_and_db():
|
||||||
|
global fake_database
|
||||||
|
global state
|
||||||
|
fake_database = initial_fake_database.copy()
|
||||||
|
state = initial_state.copy()
|
||||||
|
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dependency_gets_exception():
|
||||||
|
assert state["except"] is False
|
||||||
|
assert state["finally"] is False
|
||||||
|
response = client.put("/invalid-user/rick", json="Morty")
|
||||||
|
assert response.status_code == 400, response.text
|
||||||
|
assert response.json() == {"detail": "Invalid user"}
|
||||||
|
assert state["except"] is True
|
||||||
|
assert state["finally"] is True
|
||||||
|
assert fake_database["rick"] == "Rick Sanchez"
|
||||||
|
|
||||||
|
|
||||||
|
def test_dependency_no_exception():
|
||||||
|
assert state["except"] is False
|
||||||
|
assert state["finally"] is False
|
||||||
|
response = client.put("/user/rick", json="Morty")
|
||||||
|
assert response.status_code == 200, response.text
|
||||||
|
assert response.json() == {"message": "OK"}
|
||||||
|
assert state["except"] is False
|
||||||
|
assert state["finally"] is True
|
||||||
|
assert fake_database["rick"] == "Morty"
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import pytest
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
@ -12,10 +13,15 @@ def request_validation_exception_handler(request, exception):
|
||||||
return JSONResponse({"exception": "request-validation"})
|
return JSONResponse({"exception": "request-validation"})
|
||||||
|
|
||||||
|
|
||||||
|
def server_error_exception_handler(request, exception):
|
||||||
|
return JSONResponse(status_code=500, content={"exception": "server-error"})
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
exception_handlers={
|
exception_handlers={
|
||||||
HTTPException: http_exception_handler,
|
HTTPException: http_exception_handler,
|
||||||
RequestValidationError: request_validation_exception_handler,
|
RequestValidationError: request_validation_exception_handler,
|
||||||
|
Exception: server_error_exception_handler,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -32,6 +38,11 @@ def route_with_request_validation_exception(param: int):
|
||||||
pass # pragma: no cover
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/server-error")
|
||||||
|
def route_with_server_error():
|
||||||
|
raise RuntimeError("Oops!")
|
||||||
|
|
||||||
|
|
||||||
def test_override_http_exception():
|
def test_override_http_exception():
|
||||||
response = client.get("/http-exception")
|
response = client.get("/http-exception")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
@ -42,3 +53,15 @@ def test_override_request_validation_exception():
|
||||||
response = client.get("/request-validation/invalid")
|
response = client.get("/request-validation/invalid")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {"exception": "request-validation"}
|
assert response.json() == {"exception": "request-validation"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_override_server_error_exception_raises():
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
client.get("/server-error")
|
||||||
|
|
||||||
|
|
||||||
|
def test_override_server_error_exception_response():
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
|
response = client.get("/server-error")
|
||||||
|
assert response.status_code == 500
|
||||||
|
assert response.json() == {"exception": "server-error"}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue