mirror of https://github.com/tiangolo/fastapi.git
✨ Add exception handler for `WebSocketRequestValidationError` (which also allows to override it) (#6030)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
parent
f5e2dd8025
commit
ab03f22635
|
|
@ -19,8 +19,9 @@ from fastapi.encoders import DictIntStrAny, SetIntStr
|
||||||
from fastapi.exception_handlers import (
|
from fastapi.exception_handlers import (
|
||||||
http_exception_handler,
|
http_exception_handler,
|
||||||
request_validation_exception_handler,
|
request_validation_exception_handler,
|
||||||
|
websocket_request_validation_exception_handler,
|
||||||
)
|
)
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
|
||||||
from fastapi.logger import logger
|
from fastapi.logger import logger
|
||||||
from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
|
from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
|
||||||
from fastapi.openapi.docs import (
|
from fastapi.openapi.docs import (
|
||||||
|
|
@ -145,6 +146,11 @@ class FastAPI(Starlette):
|
||||||
self.exception_handlers.setdefault(
|
self.exception_handlers.setdefault(
|
||||||
RequestValidationError, request_validation_exception_handler
|
RequestValidationError, request_validation_exception_handler
|
||||||
)
|
)
|
||||||
|
self.exception_handlers.setdefault(
|
||||||
|
WebSocketRequestValidationError,
|
||||||
|
# Starlette still has incorrect type specification for the handlers
|
||||||
|
websocket_request_validation_exception_handler, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
self.user_middleware: List[Middleware] = (
|
self.user_middleware: List[Middleware] = (
|
||||||
[] if middleware is None else list(middleware)
|
[] if middleware is None else list(middleware)
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
|
||||||
from fastapi.utils import is_body_allowed_for_status_code
|
from fastapi.utils import is_body_allowed_for_status_code
|
||||||
|
from fastapi.websockets import WebSocket
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import JSONResponse, Response
|
from starlette.responses import JSONResponse, Response
|
||||||
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
|
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION
|
||||||
|
|
||||||
|
|
||||||
async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
|
async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
|
||||||
|
|
@ -23,3 +24,11 @@ async def request_validation_exception_handler(
|
||||||
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
|
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
content={"detail": jsonable_encoder(exc.errors())},
|
content={"detail": jsonable_encoder(exc.errors())},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def websocket_request_validation_exception_handler(
|
||||||
|
websocket: WebSocket, exc: WebSocketRequestValidationError
|
||||||
|
) -> None:
|
||||||
|
await websocket.close(
|
||||||
|
code=WS_1008_POLICY_VIOLATION, reason=jsonable_encoder(exc.errors())
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,6 @@ from starlette.routing import (
|
||||||
request_response,
|
request_response,
|
||||||
websocket_session,
|
websocket_session,
|
||||||
)
|
)
|
||||||
from starlette.status import WS_1008_POLICY_VIOLATION
|
|
||||||
from starlette.types import ASGIApp, Lifespan, Scope
|
from starlette.types import ASGIApp, Lifespan, Scope
|
||||||
from starlette.websockets import WebSocket
|
from starlette.websockets import WebSocket
|
||||||
|
|
||||||
|
|
@ -283,7 +282,6 @@ def get_websocket_app(
|
||||||
)
|
)
|
||||||
values, errors, _, _2, _3 = solved_result
|
values, errors, _, _2, _3 = solved_result
|
||||||
if errors:
|
if errors:
|
||||||
await websocket.close(code=WS_1008_POLICY_VIOLATION)
|
|
||||||
raise WebSocketRequestValidationError(errors)
|
raise WebSocketRequestValidationError(errors)
|
||||||
assert dependant.call is not None, "dependant.call must be a function"
|
assert dependant.call is not None, "dependant.call must be a function"
|
||||||
await dependant.call(**values)
|
await dependant.call(**values)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,16 @@
|
||||||
from fastapi import APIRouter, Depends, FastAPI, WebSocket
|
import functools
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import (
|
||||||
|
APIRouter,
|
||||||
|
Depends,
|
||||||
|
FastAPI,
|
||||||
|
Header,
|
||||||
|
WebSocket,
|
||||||
|
WebSocketDisconnect,
|
||||||
|
status,
|
||||||
|
)
|
||||||
|
from fastapi.middleware import Middleware
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
@ -63,9 +75,44 @@ async def router_native_prefix_ws(websocket: WebSocket):
|
||||||
await websocket.close()
|
await websocket.close()
|
||||||
|
|
||||||
|
|
||||||
app.include_router(router)
|
async def ws_dependency_err():
|
||||||
app.include_router(prefix_router, prefix="/prefix")
|
raise NotImplementedError()
|
||||||
app.include_router(native_prefix_route)
|
|
||||||
|
|
||||||
|
@router.websocket("/depends-err/")
|
||||||
|
async def router_ws_depends_err(websocket: WebSocket, data=Depends(ws_dependency_err)):
|
||||||
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
async def ws_dependency_validate(x_missing: str = Header()):
|
||||||
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/depends-validate/")
|
||||||
|
async def router_ws_depends_validate(
|
||||||
|
websocket: WebSocket, data=Depends(ws_dependency_validate)
|
||||||
|
):
|
||||||
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
class CustomError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/custom_error/")
|
||||||
|
async def router_ws_custom_error(websocket: WebSocket):
|
||||||
|
raise CustomError()
|
||||||
|
|
||||||
|
|
||||||
|
def make_app(app=None, **kwargs):
|
||||||
|
app = app or FastAPI(**kwargs)
|
||||||
|
app.include_router(router)
|
||||||
|
app.include_router(prefix_router, prefix="/prefix")
|
||||||
|
app.include_router(native_prefix_route)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
app = make_app(app)
|
||||||
|
|
||||||
|
|
||||||
def test_app():
|
def test_app():
|
||||||
|
|
@ -125,3 +172,100 @@ def test_router_with_params():
|
||||||
assert data == "path/to/file"
|
assert data == "path/to/file"
|
||||||
data = websocket.receive_text()
|
data = websocket.receive_text()
|
||||||
assert data == "a_query_param"
|
assert data == "a_query_param"
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrong_uri():
|
||||||
|
"""
|
||||||
|
Verify that a websocket connection to a non-existent endpoing returns in a shutdown
|
||||||
|
"""
|
||||||
|
client = TestClient(app)
|
||||||
|
with pytest.raises(WebSocketDisconnect) as e:
|
||||||
|
with client.websocket_connect("/no-router/"):
|
||||||
|
pass # pragma: no cover
|
||||||
|
assert e.value.code == status.WS_1000_NORMAL_CLOSURE
|
||||||
|
|
||||||
|
|
||||||
|
def websocket_middleware(middleware_func):
|
||||||
|
"""
|
||||||
|
Helper to create a Starlette pure websocket middleware
|
||||||
|
"""
|
||||||
|
|
||||||
|
def middleware_constructor(app):
|
||||||
|
@functools.wraps(app)
|
||||||
|
async def wrapped_app(scope, receive, send):
|
||||||
|
if scope["type"] != "websocket":
|
||||||
|
return await app(scope, receive, send) # pragma: no cover
|
||||||
|
|
||||||
|
async def call_next():
|
||||||
|
return await app(scope, receive, send)
|
||||||
|
|
||||||
|
websocket = WebSocket(scope, receive=receive, send=send)
|
||||||
|
return await middleware_func(websocket, call_next)
|
||||||
|
|
||||||
|
return wrapped_app
|
||||||
|
|
||||||
|
return middleware_constructor
|
||||||
|
|
||||||
|
|
||||||
|
def test_depend_validation():
|
||||||
|
"""
|
||||||
|
Verify that a validation in a dependency invokes the correct exception handler
|
||||||
|
"""
|
||||||
|
caught = []
|
||||||
|
|
||||||
|
@websocket_middleware
|
||||||
|
async def catcher(websocket, call_next):
|
||||||
|
try:
|
||||||
|
return await call_next()
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
caught.append(e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
myapp = make_app(middleware=[Middleware(catcher)])
|
||||||
|
|
||||||
|
client = TestClient(myapp)
|
||||||
|
with pytest.raises(WebSocketDisconnect) as e:
|
||||||
|
with client.websocket_connect("/depends-validate/"):
|
||||||
|
pass # pragma: no cover
|
||||||
|
# the validation error does produce a close message
|
||||||
|
assert e.value.code == status.WS_1008_POLICY_VIOLATION
|
||||||
|
# and no error is leaked
|
||||||
|
assert caught == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_depend_err_middleware():
|
||||||
|
"""
|
||||||
|
Verify that it is possible to write custom WebSocket middleware to catch errors
|
||||||
|
"""
|
||||||
|
|
||||||
|
@websocket_middleware
|
||||||
|
async def errorhandler(websocket: WebSocket, call_next):
|
||||||
|
try:
|
||||||
|
return await call_next()
|
||||||
|
except Exception as e:
|
||||||
|
await websocket.close(code=status.WS_1006_ABNORMAL_CLOSURE, reason=repr(e))
|
||||||
|
|
||||||
|
myapp = make_app(middleware=[Middleware(errorhandler)])
|
||||||
|
client = TestClient(myapp)
|
||||||
|
with pytest.raises(WebSocketDisconnect) as e:
|
||||||
|
with client.websocket_connect("/depends-err/"):
|
||||||
|
pass # pragma: no cover
|
||||||
|
assert e.value.code == status.WS_1006_ABNORMAL_CLOSURE
|
||||||
|
assert "NotImplementedError" in e.value.reason
|
||||||
|
|
||||||
|
|
||||||
|
def test_depend_err_handler():
|
||||||
|
"""
|
||||||
|
Verify that it is possible to write custom WebSocket middleware to catch errors
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def custom_handler(websocket: WebSocket, exc: CustomError) -> None:
|
||||||
|
await websocket.close(1002, "foo")
|
||||||
|
|
||||||
|
myapp = make_app(exception_handlers={CustomError: custom_handler})
|
||||||
|
client = TestClient(myapp)
|
||||||
|
with pytest.raises(WebSocketDisconnect) as e:
|
||||||
|
with client.websocket_connect("/custom_error/"):
|
||||||
|
pass # pragma: no cover
|
||||||
|
assert e.value.code == 1002
|
||||||
|
assert "foo" in e.value.reason
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue