mirror of https://github.com/tiangolo/fastapi.git
✨ Update internal `AsyncExitStack` to fix context for dependencies with `yield` (#4575)
This commit is contained in:
parent
78b07cb809
commit
9d56a3cb59
|
|
@ -99,7 +99,7 @@ You saw that you can use dependencies with `yield` and have `try` blocks that ca
|
|||
|
||||
It might be tempting to raise an `HTTPException` or similar in the exit code, after the `yield`. But **it won't work**.
|
||||
|
||||
The exit code in dependencies with `yield` is executed *after* [Exception Handlers](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}. There's nothing catching exceptions thrown by your dependencies in the exit code (after the `yield`).
|
||||
The exit code in dependencies with `yield` is executed *after* the response is sent, so [Exception Handlers](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank} will have already run. There's nothing catching exceptions thrown by your dependencies in the exit code (after the `yield`).
|
||||
|
||||
So, if you raise an `HTTPException` after the `yield`, the default (or any custom) exception handler that catches `HTTPException`s and returns an HTTP 400 response won't be there to catch that exception anymore.
|
||||
|
||||
|
|
@ -138,9 +138,11 @@ participant tasks as Background tasks
|
|||
end
|
||||
dep ->> operation: Run dependency, e.g. DB session
|
||||
opt raise
|
||||
operation -->> handler: Raise HTTPException
|
||||
operation -->> dep: Raise HTTPException
|
||||
dep -->> handler: Auto forward exception
|
||||
handler -->> client: HTTP error response
|
||||
operation -->> dep: Raise other exception
|
||||
dep -->> handler: Auto forward exception
|
||||
end
|
||||
operation ->> client: Return response to client
|
||||
Note over client,operation: Response is already sent, can't change it anymore
|
||||
|
|
@ -162,9 +164,9 @@ participant tasks as Background tasks
|
|||
After one of those responses is sent, no other response can be sent.
|
||||
|
||||
!!! tip
|
||||
This diagram shows `HTTPException`, but you could also raise any other exception for which you create a [Custom Exception Handler](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}. And that exception would be handled by that custom exception handler instead of the dependency exit code.
|
||||
This diagram shows `HTTPException`, but you could also raise any other exception for which you create a [Custom Exception Handler](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}.
|
||||
|
||||
But if you raise an exception that is not handled by the exception handlers, it will be handled by the exit code of the dependency.
|
||||
If you raise any exception, it will be passed to the dependencies with yield, including `HTTPException`, and then **again** to the exception handlers. If there's no exception handler for that exception, it will then be handled by the default internal `ServerErrorMiddleware`, returning a 500 HTTP status code, to let the client know that there was an error in the server.
|
||||
|
||||
## Context Managers
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ from enum import Enum
|
|||
from typing import Any, Callable, Coroutine, Dict, List, Optional, Sequence, Type, Union
|
||||
|
||||
from fastapi import routing
|
||||
from fastapi.concurrency import AsyncExitStack
|
||||
from fastapi.datastructures import Default, DefaultPlaceholder
|
||||
from fastapi.encoders import DictIntStrAny, SetIntStr
|
||||
from fastapi.exception_handlers import (
|
||||
|
|
@ -11,6 +10,7 @@ from fastapi.exception_handlers import (
|
|||
)
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.logger import logger
|
||||
from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
|
||||
from fastapi.openapi.docs import (
|
||||
get_redoc_html,
|
||||
get_swagger_ui_html,
|
||||
|
|
@ -21,8 +21,9 @@ from fastapi.params import Depends
|
|||
from fastapi.types import DecoratedCallable
|
||||
from starlette.applications import Starlette
|
||||
from starlette.datastructures import State
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.exceptions import ExceptionMiddleware, HTTPException
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.errors import ServerErrorMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse, JSONResponse, Response
|
||||
from starlette.routing import BaseRoute
|
||||
|
|
@ -134,6 +135,55 @@ class FastAPI(Starlette):
|
|||
self.openapi_schema: Optional[Dict[str, Any]] = None
|
||||
self.setup()
|
||||
|
||||
def build_middleware_stack(self) -> ASGIApp:
|
||||
# Duplicate/override from Starlette to add AsyncExitStackMiddleware
|
||||
# inside of ExceptionMiddleware, inside of custom user middlewares
|
||||
debug = self.debug
|
||||
error_handler = None
|
||||
exception_handlers = {}
|
||||
|
||||
for key, value in self.exception_handlers.items():
|
||||
if key in (500, Exception):
|
||||
error_handler = value
|
||||
else:
|
||||
exception_handlers[key] = value
|
||||
|
||||
middleware = (
|
||||
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
|
||||
+ self.user_middleware
|
||||
+ [
|
||||
Middleware(
|
||||
ExceptionMiddleware, handlers=exception_handlers, debug=debug
|
||||
),
|
||||
# Add FastAPI-specific AsyncExitStackMiddleware for dependencies with
|
||||
# contextvars.
|
||||
# This needs to happen after user middlewares because those create a
|
||||
# new contextvars context copy by using a new AnyIO task group.
|
||||
# The initial part of dependencies with yield is executed in the
|
||||
# FastAPI code, inside all the middlewares, but the teardown part
|
||||
# (after yield) is executed in the AsyncExitStack in this middleware,
|
||||
# if the AsyncExitStack lived outside of the custom middlewares and
|
||||
# contextvars were set in a dependency with yield in that internal
|
||||
# contextvars context, the values would not be available in the
|
||||
# outside context of the AsyncExitStack.
|
||||
# By putting the middleware and the AsyncExitStack here, inside all
|
||||
# user middlewares, the code before and after yield in dependencies
|
||||
# with yield is executed in the same contextvars context, so all values
|
||||
# set in contextvars before yield is still available after yield as
|
||||
# would be expected.
|
||||
# Additionally, by having this AsyncExitStack here, after the
|
||||
# ExceptionMiddleware, now dependencies can catch handled exceptions,
|
||||
# e.g. HTTPException, to customize the teardown code (e.g. DB session
|
||||
# rollback).
|
||||
Middleware(AsyncExitStackMiddleware),
|
||||
]
|
||||
)
|
||||
|
||||
app = self.router
|
||||
for cls, options in reversed(middleware):
|
||||
app = cls(app=app, **options)
|
||||
return app
|
||||
|
||||
def openapi(self) -> Dict[str, Any]:
|
||||
if not self.openapi_schema:
|
||||
self.openapi_schema = get_openapi(
|
||||
|
|
@ -206,12 +256,7 @@ class FastAPI(Starlette):
|
|||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if self.root_path:
|
||||
scope["root_path"] = self.root_path
|
||||
if AsyncExitStack:
|
||||
async with AsyncExitStack() as stack:
|
||||
scope["fastapi_astack"] = stack
|
||||
await super().__call__(scope, receive, send)
|
||||
else:
|
||||
await super().__call__(scope, receive, send) # pragma: no cover
|
||||
|
||||
def add_api_route(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,28 @@
|
|||
from typing import Optional
|
||||
|
||||
from fastapi.concurrency import AsyncExitStack
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
|
||||
class AsyncExitStackMiddleware:
|
||||
def __init__(self, app: ASGIApp, context_name: str = "fastapi_astack") -> None:
|
||||
self.app = app
|
||||
self.context_name = context_name
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if AsyncExitStack:
|
||||
dependency_exception: Optional[Exception] = None
|
||||
async with AsyncExitStack() as stack:
|
||||
scope[self.context_name] = stack
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
except Exception as e:
|
||||
dependency_exception = e
|
||||
raise e
|
||||
if dependency_exception:
|
||||
# This exception was possibly handled by the dependency but it should
|
||||
# still bubble up so that the ServerErrorMiddleware can return a 500
|
||||
# or the ExceptionMiddleware can catch and handle any other exceptions
|
||||
raise dependency_exception
|
||||
else:
|
||||
await self.app(scope, receive, send) # pragma: no cover
|
||||
|
|
@ -235,7 +235,16 @@ def test_sync_raise_other():
|
|||
assert "/sync_raise" not in errors
|
||||
|
||||
|
||||
def test_async_raise():
|
||||
def test_async_raise_raises():
|
||||
with pytest.raises(AsyncDependencyError):
|
||||
client.get("/async_raise")
|
||||
assert state["/async_raise"] == "asyncgen raise finalized"
|
||||
assert "/async_raise" in errors
|
||||
errors.clear()
|
||||
|
||||
|
||||
def test_async_raise_server_error():
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/async_raise")
|
||||
assert response.status_code == 500, response.text
|
||||
assert state["/async_raise"] == "asyncgen raise finalized"
|
||||
|
|
@ -270,7 +279,16 @@ def test_background_tasks():
|
|||
assert state["bg"] == "bg set - b: started b - a: started a"
|
||||
|
||||
|
||||
def test_sync_raise():
|
||||
def test_sync_raise_raises():
|
||||
with pytest.raises(SyncDependencyError):
|
||||
client.get("/sync_raise")
|
||||
assert state["/sync_raise"] == "generator raise finalized"
|
||||
assert "/sync_raise" in errors
|
||||
errors.clear()
|
||||
|
||||
|
||||
def test_sync_raise_server_error():
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/sync_raise")
|
||||
assert response.status_code == 500, response.text
|
||||
assert state["/sync_raise"] == "generator raise finalized"
|
||||
|
|
@ -306,7 +324,16 @@ def test_sync_sync_raise_other():
|
|||
assert "/sync_raise" not in errors
|
||||
|
||||
|
||||
def test_sync_async_raise():
|
||||
def test_sync_async_raise_raises():
|
||||
with pytest.raises(AsyncDependencyError):
|
||||
client.get("/sync_async_raise")
|
||||
assert state["/async_raise"] == "asyncgen raise finalized"
|
||||
assert "/async_raise" in errors
|
||||
errors.clear()
|
||||
|
||||
|
||||
def test_sync_async_raise_server_error():
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/sync_async_raise")
|
||||
assert response.status_code == 500, response.text
|
||||
assert state["/async_raise"] == "asyncgen raise finalized"
|
||||
|
|
@ -314,7 +341,16 @@ def test_sync_async_raise():
|
|||
errors.clear()
|
||||
|
||||
|
||||
def test_sync_sync_raise():
|
||||
def test_sync_sync_raise_raises():
|
||||
with pytest.raises(SyncDependencyError):
|
||||
client.get("/sync_sync_raise")
|
||||
assert state["/sync_raise"] == "generator raise finalized"
|
||||
assert "/sync_raise" in errors
|
||||
errors.clear()
|
||||
|
||||
|
||||
def test_sync_sync_raise_server_error():
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/sync_sync_raise")
|
||||
assert response.status_code == 500, response.text
|
||||
assert state["/sync_raise"] == "generator raise finalized"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,51 @@
|
|||
from contextvars import ContextVar
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||
|
||||
from fastapi import Depends, FastAPI, Request, Response
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
legacy_request_state_context_var: ContextVar[Optional[Dict[str, Any]]] = ContextVar(
|
||||
"legacy_request_state_context_var", default=None
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
async def set_up_request_state_dependency():
|
||||
request_state = {"user": "deadpond"}
|
||||
contextvar_token = legacy_request_state_context_var.set(request_state)
|
||||
yield request_state
|
||||
legacy_request_state_context_var.reset(contextvar_token)
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def custom_middleware(
|
||||
request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||
):
|
||||
response = await call_next(request)
|
||||
response.headers["custom"] = "foo"
|
||||
return response
|
||||
|
||||
|
||||
@app.get("/user", dependencies=[Depends(set_up_request_state_dependency)])
|
||||
def get_user():
|
||||
request_state = legacy_request_state_context_var.get()
|
||||
assert request_state
|
||||
return request_state["user"]
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_dependency_contextvars():
|
||||
"""
|
||||
Check that custom middlewares don't affect the contextvar context for dependencies.
|
||||
|
||||
The code before yield and the code after yield should be run in the same contextvar
|
||||
context, so that request_state_context_var.reset(contextvar_token).
|
||||
|
||||
If they are run in a different context, that raises an error.
|
||||
"""
|
||||
response = client.get("/user")
|
||||
assert response.json() == "deadpond"
|
||||
assert response.headers["custom"] == "foo"
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
import pytest
|
||||
from fastapi import Body, Depends, FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
initial_fake_database = {"rick": "Rick Sanchez"}
|
||||
|
||||
fake_database = initial_fake_database.copy()
|
||||
|
||||
initial_state = {"except": False, "finally": False}
|
||||
|
||||
state = initial_state.copy()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
async def get_database():
|
||||
temp_database = fake_database.copy()
|
||||
try:
|
||||
yield temp_database
|
||||
fake_database.update(temp_database)
|
||||
except HTTPException:
|
||||
state["except"] = True
|
||||
finally:
|
||||
state["finally"] = True
|
||||
|
||||
|
||||
@app.put("/invalid-user/{user_id}")
|
||||
def put_invalid_user(
|
||||
user_id: str, name: str = Body(...), db: dict = Depends(get_database)
|
||||
):
|
||||
db[user_id] = name
|
||||
raise HTTPException(status_code=400, detail="Invalid user")
|
||||
|
||||
|
||||
@app.put("/user/{user_id}")
|
||||
def put_user(user_id: str, name: str = Body(...), db: dict = Depends(get_database)):
|
||||
db[user_id] = name
|
||||
return {"message": "OK"}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state_and_db():
|
||||
global fake_database
|
||||
global state
|
||||
fake_database = initial_fake_database.copy()
|
||||
state = initial_state.copy()
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_dependency_gets_exception():
|
||||
assert state["except"] is False
|
||||
assert state["finally"] is False
|
||||
response = client.put("/invalid-user/rick", json="Morty")
|
||||
assert response.status_code == 400, response.text
|
||||
assert response.json() == {"detail": "Invalid user"}
|
||||
assert state["except"] is True
|
||||
assert state["finally"] is True
|
||||
assert fake_database["rick"] == "Rick Sanchez"
|
||||
|
||||
|
||||
def test_dependency_no_exception():
|
||||
assert state["except"] is False
|
||||
assert state["finally"] is False
|
||||
response = client.put("/user/rick", json="Morty")
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"message": "OK"}
|
||||
assert state["except"] is False
|
||||
assert state["finally"] is True
|
||||
assert fake_database["rick"] == "Morty"
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
import pytest
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.testclient import TestClient
|
||||
|
|
@ -12,10 +13,15 @@ def request_validation_exception_handler(request, exception):
|
|||
return JSONResponse({"exception": "request-validation"})
|
||||
|
||||
|
||||
def server_error_exception_handler(request, exception):
|
||||
return JSONResponse(status_code=500, content={"exception": "server-error"})
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
exception_handlers={
|
||||
HTTPException: http_exception_handler,
|
||||
RequestValidationError: request_validation_exception_handler,
|
||||
Exception: server_error_exception_handler,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -32,6 +38,11 @@ def route_with_request_validation_exception(param: int):
|
|||
pass # pragma: no cover
|
||||
|
||||
|
||||
@app.get("/server-error")
|
||||
def route_with_server_error():
|
||||
raise RuntimeError("Oops!")
|
||||
|
||||
|
||||
def test_override_http_exception():
|
||||
response = client.get("/http-exception")
|
||||
assert response.status_code == 200
|
||||
|
|
@ -42,3 +53,15 @@ def test_override_request_validation_exception():
|
|||
response = client.get("/request-validation/invalid")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"exception": "request-validation"}
|
||||
|
||||
|
||||
def test_override_server_error_exception_raises():
|
||||
with pytest.raises(RuntimeError):
|
||||
client.get("/server-error")
|
||||
|
||||
|
||||
def test_override_server_error_exception_response():
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/server-error")
|
||||
assert response.status_code == 500
|
||||
assert response.json() == {"exception": "server-error"}
|
||||
|
|
|
|||
Loading…
Reference in New Issue