mirror of https://github.com/tiangolo/fastapi.git
🚸 Improve tracebacks by adding endpoint metadata (#14306)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
parent
08b09e5236
commit
e1117f7550
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Dict, Optional, Sequence, Type, Union
|
from typing import Any, Dict, Optional, Sequence, Type, TypedDict, Union
|
||||||
|
|
||||||
from annotated_doc import Doc
|
from annotated_doc import Doc
|
||||||
from pydantic import BaseModel, create_model
|
from pydantic import BaseModel, create_model
|
||||||
|
|
@ -7,6 +7,13 @@ from starlette.exceptions import WebSocketException as StarletteWebSocketExcepti
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
|
||||||
|
class EndpointContext(TypedDict, total=False):
|
||||||
|
function: str
|
||||||
|
path: str
|
||||||
|
file: str
|
||||||
|
line: int
|
||||||
|
|
||||||
|
|
||||||
class HTTPException(StarletteHTTPException):
|
class HTTPException(StarletteHTTPException):
|
||||||
"""
|
"""
|
||||||
An HTTP exception you can raise in your own code to show errors to the client.
|
An HTTP exception you can raise in your own code to show errors to the client.
|
||||||
|
|
@ -155,30 +162,72 @@ class DependencyScopeError(FastAPIError):
|
||||||
|
|
||||||
|
|
||||||
class ValidationException(Exception):
|
class ValidationException(Exception):
|
||||||
def __init__(self, errors: Sequence[Any]) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
errors: Sequence[Any],
|
||||||
|
*,
|
||||||
|
endpoint_ctx: Optional[EndpointContext] = None,
|
||||||
|
) -> None:
|
||||||
self._errors = errors
|
self._errors = errors
|
||||||
|
self.endpoint_ctx = endpoint_ctx
|
||||||
|
|
||||||
|
ctx = endpoint_ctx or {}
|
||||||
|
self.endpoint_function = ctx.get("function")
|
||||||
|
self.endpoint_path = ctx.get("path")
|
||||||
|
self.endpoint_file = ctx.get("file")
|
||||||
|
self.endpoint_line = ctx.get("line")
|
||||||
|
|
||||||
def errors(self) -> Sequence[Any]:
|
def errors(self) -> Sequence[Any]:
|
||||||
return self._errors
|
return self._errors
|
||||||
|
|
||||||
|
def _format_endpoint_context(self) -> str:
|
||||||
|
if not (self.endpoint_file and self.endpoint_line and self.endpoint_function):
|
||||||
|
if self.endpoint_path:
|
||||||
|
return f"\n Endpoint: {self.endpoint_path}"
|
||||||
|
return ""
|
||||||
|
|
||||||
|
context = f'\n File "{self.endpoint_file}", line {self.endpoint_line}, in {self.endpoint_function}'
|
||||||
|
if self.endpoint_path:
|
||||||
|
context += f"\n {self.endpoint_path}"
|
||||||
|
return context
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
message = f"{len(self._errors)} validation error{'s' if len(self._errors) != 1 else ''}:\n"
|
||||||
|
for err in self._errors:
|
||||||
|
message += f" {err}\n"
|
||||||
|
message += self._format_endpoint_context()
|
||||||
|
return message.rstrip()
|
||||||
|
|
||||||
|
|
||||||
class RequestValidationError(ValidationException):
|
class RequestValidationError(ValidationException):
|
||||||
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
|
def __init__(
|
||||||
super().__init__(errors)
|
self,
|
||||||
|
errors: Sequence[Any],
|
||||||
|
*,
|
||||||
|
body: Any = None,
|
||||||
|
endpoint_ctx: Optional[EndpointContext] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(errors, endpoint_ctx=endpoint_ctx)
|
||||||
self.body = body
|
self.body = body
|
||||||
|
|
||||||
|
|
||||||
class WebSocketRequestValidationError(ValidationException):
|
class WebSocketRequestValidationError(ValidationException):
|
||||||
pass
|
def __init__(
|
||||||
|
self,
|
||||||
|
errors: Sequence[Any],
|
||||||
|
*,
|
||||||
|
endpoint_ctx: Optional[EndpointContext] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(errors, endpoint_ctx=endpoint_ctx)
|
||||||
|
|
||||||
|
|
||||||
class ResponseValidationError(ValidationException):
|
class ResponseValidationError(ValidationException):
|
||||||
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
|
def __init__(
|
||||||
super().__init__(errors)
|
self,
|
||||||
|
errors: Sequence[Any],
|
||||||
|
*,
|
||||||
|
body: Any = None,
|
||||||
|
endpoint_ctx: Optional[EndpointContext] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(errors, endpoint_ctx=endpoint_ctx)
|
||||||
self.body = body
|
self.body = body
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
message = f"{len(self._errors)} validation errors:\n"
|
|
||||||
for err in self._errors:
|
|
||||||
message += f" {err}\n"
|
|
||||||
return message
|
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,7 @@ from fastapi.dependencies.utils import (
|
||||||
)
|
)
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.exceptions import (
|
from fastapi.exceptions import (
|
||||||
|
EndpointContext,
|
||||||
FastAPIError,
|
FastAPIError,
|
||||||
RequestValidationError,
|
RequestValidationError,
|
||||||
ResponseValidationError,
|
ResponseValidationError,
|
||||||
|
|
@ -212,6 +213,33 @@ def _merge_lifespan_context(
|
||||||
return merged_lifespan # type: ignore[return-value]
|
return merged_lifespan # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
|
# Cache for endpoint context to avoid re-extracting on every request
|
||||||
|
_endpoint_context_cache: Dict[int, EndpointContext] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_endpoint_context(func: Any) -> EndpointContext:
|
||||||
|
"""Extract endpoint context with caching to avoid repeated file I/O."""
|
||||||
|
func_id = id(func)
|
||||||
|
|
||||||
|
if func_id in _endpoint_context_cache:
|
||||||
|
return _endpoint_context_cache[func_id]
|
||||||
|
|
||||||
|
try:
|
||||||
|
ctx: EndpointContext = {}
|
||||||
|
|
||||||
|
if (source_file := inspect.getsourcefile(func)) is not None:
|
||||||
|
ctx["file"] = source_file
|
||||||
|
if (line_number := inspect.getsourcelines(func)[1]) is not None:
|
||||||
|
ctx["line"] = line_number
|
||||||
|
if (func_name := getattr(func, "__name__", None)) is not None:
|
||||||
|
ctx["function"] = func_name
|
||||||
|
except Exception:
|
||||||
|
ctx = EndpointContext()
|
||||||
|
|
||||||
|
_endpoint_context_cache[func_id] = ctx
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
|
||||||
async def serialize_response(
|
async def serialize_response(
|
||||||
*,
|
*,
|
||||||
field: Optional[ModelField] = None,
|
field: Optional[ModelField] = None,
|
||||||
|
|
@ -223,6 +251,7 @@ async def serialize_response(
|
||||||
exclude_defaults: bool = False,
|
exclude_defaults: bool = False,
|
||||||
exclude_none: bool = False,
|
exclude_none: bool = False,
|
||||||
is_coroutine: bool = True,
|
is_coroutine: bool = True,
|
||||||
|
endpoint_ctx: Optional[EndpointContext] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if field:
|
if field:
|
||||||
errors = []
|
errors = []
|
||||||
|
|
@ -245,8 +274,11 @@ async def serialize_response(
|
||||||
elif errors_:
|
elif errors_:
|
||||||
errors.append(errors_)
|
errors.append(errors_)
|
||||||
if errors:
|
if errors:
|
||||||
|
ctx = endpoint_ctx or EndpointContext()
|
||||||
raise ResponseValidationError(
|
raise ResponseValidationError(
|
||||||
errors=_normalize_errors(errors), body=response_content
|
errors=_normalize_errors(errors),
|
||||||
|
body=response_content,
|
||||||
|
endpoint_ctx=ctx,
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(field, "serialize"):
|
if hasattr(field, "serialize"):
|
||||||
|
|
@ -318,6 +350,18 @@ def get_request_handler(
|
||||||
"fastapi_middleware_astack not found in request scope"
|
"fastapi_middleware_astack not found in request scope"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Extract endpoint context for error messages
|
||||||
|
endpoint_ctx = (
|
||||||
|
_extract_endpoint_context(dependant.call)
|
||||||
|
if dependant.call
|
||||||
|
else EndpointContext()
|
||||||
|
)
|
||||||
|
|
||||||
|
if dependant.path:
|
||||||
|
# For mounted sub-apps, include the mount path prefix
|
||||||
|
mount_path = request.scope.get("root_path", "").rstrip("/")
|
||||||
|
endpoint_ctx["path"] = f"{request.method} {mount_path}{dependant.path}"
|
||||||
|
|
||||||
# Read body and auto-close files
|
# Read body and auto-close files
|
||||||
try:
|
try:
|
||||||
body: Any = None
|
body: Any = None
|
||||||
|
|
@ -355,6 +399,7 @@ def get_request_handler(
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
body=e.doc,
|
body=e.doc,
|
||||||
|
endpoint_ctx=endpoint_ctx,
|
||||||
)
|
)
|
||||||
raise validation_error from e
|
raise validation_error from e
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|
@ -414,6 +459,7 @@ def get_request_handler(
|
||||||
exclude_defaults=response_model_exclude_defaults,
|
exclude_defaults=response_model_exclude_defaults,
|
||||||
exclude_none=response_model_exclude_none,
|
exclude_none=response_model_exclude_none,
|
||||||
is_coroutine=is_coroutine,
|
is_coroutine=is_coroutine,
|
||||||
|
endpoint_ctx=endpoint_ctx,
|
||||||
)
|
)
|
||||||
response = actual_response_class(content, **response_args)
|
response = actual_response_class(content, **response_args)
|
||||||
if not is_body_allowed_for_status_code(response.status_code):
|
if not is_body_allowed_for_status_code(response.status_code):
|
||||||
|
|
@ -421,7 +467,7 @@ def get_request_handler(
|
||||||
response.headers.raw.extend(solved_result.response.headers.raw)
|
response.headers.raw.extend(solved_result.response.headers.raw)
|
||||||
if errors:
|
if errors:
|
||||||
validation_error = RequestValidationError(
|
validation_error = RequestValidationError(
|
||||||
_normalize_errors(errors), body=body
|
_normalize_errors(errors), body=body, endpoint_ctx=endpoint_ctx
|
||||||
)
|
)
|
||||||
raise validation_error
|
raise validation_error
|
||||||
|
|
||||||
|
|
@ -438,6 +484,15 @@ def get_websocket_app(
|
||||||
embed_body_fields: bool = False,
|
embed_body_fields: bool = False,
|
||||||
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
|
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
|
||||||
async def app(websocket: WebSocket) -> None:
|
async def app(websocket: WebSocket) -> None:
|
||||||
|
endpoint_ctx = (
|
||||||
|
_extract_endpoint_context(dependant.call)
|
||||||
|
if dependant.call
|
||||||
|
else EndpointContext()
|
||||||
|
)
|
||||||
|
if dependant.path:
|
||||||
|
# For mounted sub-apps, include the mount path prefix
|
||||||
|
mount_path = websocket.scope.get("root_path", "").rstrip("/")
|
||||||
|
endpoint_ctx["path"] = f"WS {mount_path}{dependant.path}"
|
||||||
async_exit_stack = websocket.scope.get("fastapi_inner_astack")
|
async_exit_stack = websocket.scope.get("fastapi_inner_astack")
|
||||||
assert isinstance(async_exit_stack, AsyncExitStack), (
|
assert isinstance(async_exit_stack, AsyncExitStack), (
|
||||||
"fastapi_inner_astack not found in request scope"
|
"fastapi_inner_astack not found in request scope"
|
||||||
|
|
@ -451,7 +506,8 @@ def get_websocket_app(
|
||||||
)
|
)
|
||||||
if solved_result.errors:
|
if solved_result.errors:
|
||||||
raise WebSocketRequestValidationError(
|
raise WebSocketRequestValidationError(
|
||||||
_normalize_errors(solved_result.errors)
|
_normalize_errors(solved_result.errors),
|
||||||
|
endpoint_ctx=endpoint_ctx,
|
||||||
)
|
)
|
||||||
assert dependant.call is not None, "dependant.call must be a function"
|
assert dependant.call is not None, "dependant.call must be a function"
|
||||||
await dependant.call(**solved_result.values)
|
await dependant.call(**solved_result.values)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,168 @@
|
||||||
|
from fastapi import FastAPI, Request, WebSocket
|
||||||
|
from fastapi.exceptions import (
|
||||||
|
RequestValidationError,
|
||||||
|
ResponseValidationError,
|
||||||
|
WebSocketRequestValidationError,
|
||||||
|
)
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class Item(BaseModel):
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class ExceptionCapture:
|
||||||
|
def __init__(self):
|
||||||
|
self.exception = None
|
||||||
|
|
||||||
|
def capture(self, exc):
|
||||||
|
self.exception = exc
|
||||||
|
return exc
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
sub_app = FastAPI()
|
||||||
|
captured_exception = ExceptionCapture()
|
||||||
|
|
||||||
|
app.mount(path="/sub", app=sub_app)
|
||||||
|
|
||||||
|
|
||||||
|
@app.exception_handler(RequestValidationError)
|
||||||
|
@sub_app.exception_handler(RequestValidationError)
|
||||||
|
async def request_validation_handler(request: Request, exc: RequestValidationError):
|
||||||
|
captured_exception.capture(exc)
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
|
||||||
|
@app.exception_handler(ResponseValidationError)
|
||||||
|
@sub_app.exception_handler(ResponseValidationError)
|
||||||
|
async def response_validation_handler(_: Request, exc: ResponseValidationError):
|
||||||
|
captured_exception.capture(exc)
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
|
||||||
|
@app.exception_handler(WebSocketRequestValidationError)
|
||||||
|
@sub_app.exception_handler(WebSocketRequestValidationError)
|
||||||
|
async def websocket_validation_handler(
|
||||||
|
websocket: WebSocket, exc: WebSocketRequestValidationError
|
||||||
|
):
|
||||||
|
captured_exception.capture(exc)
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/users/{user_id}")
|
||||||
|
def get_user(user_id: int):
|
||||||
|
return {"user_id": user_id} # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/items/", response_model=Item)
|
||||||
|
def get_item():
|
||||||
|
return {"name": "Widget"}
|
||||||
|
|
||||||
|
|
||||||
|
@sub_app.get("/items/", response_model=Item)
|
||||||
|
def get_sub_item():
|
||||||
|
return {"name": "Widget"} # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/ws/{item_id}")
|
||||||
|
async def websocket_endpoint(websocket: WebSocket, item_id: int):
|
||||||
|
await websocket.accept() # pragma: no cover
|
||||||
|
await websocket.send_text(f"Item: {item_id}") # pragma: no cover
|
||||||
|
await websocket.close() # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
@sub_app.websocket("/ws/{item_id}")
|
||||||
|
async def subapp_websocket_endpoint(websocket: WebSocket, item_id: int):
|
||||||
|
await websocket.accept() # pragma: no cover
|
||||||
|
await websocket.send_text(f"Item: {item_id}") # pragma: no cover
|
||||||
|
await websocket.close() # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_validation_error_includes_endpoint_context():
|
||||||
|
captured_exception.exception = None
|
||||||
|
try:
|
||||||
|
client.get("/users/invalid")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert captured_exception.exception is not None
|
||||||
|
error_str = str(captured_exception.exception)
|
||||||
|
assert "get_user" in error_str
|
||||||
|
assert "/users/" in error_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_response_validation_error_includes_endpoint_context():
|
||||||
|
captured_exception.exception = None
|
||||||
|
try:
|
||||||
|
client.get("/items/")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert captured_exception.exception is not None
|
||||||
|
error_str = str(captured_exception.exception)
|
||||||
|
assert "get_item" in error_str
|
||||||
|
assert "/items/" in error_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket_validation_error_includes_endpoint_context():
|
||||||
|
captured_exception.exception = None
|
||||||
|
try:
|
||||||
|
with client.websocket_connect("/ws/invalid"):
|
||||||
|
pass # pragma: no cover
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert captured_exception.exception is not None
|
||||||
|
error_str = str(captured_exception.exception)
|
||||||
|
assert "websocket_endpoint" in error_str
|
||||||
|
assert "/ws/" in error_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_subapp_request_validation_error_includes_endpoint_context():
|
||||||
|
captured_exception.exception = None
|
||||||
|
try:
|
||||||
|
client.get("/sub/items/")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert captured_exception.exception is not None
|
||||||
|
error_str = str(captured_exception.exception)
|
||||||
|
assert "get_sub_item" in error_str
|
||||||
|
assert "/sub/items/" in error_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_subapp_websocket_validation_error_includes_endpoint_context():
|
||||||
|
captured_exception.exception = None
|
||||||
|
try:
|
||||||
|
with client.websocket_connect("/sub/ws/invalid"):
|
||||||
|
pass # pragma: no cover
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert captured_exception.exception is not None
|
||||||
|
error_str = str(captured_exception.exception)
|
||||||
|
assert "subapp_websocket_endpoint" in error_str
|
||||||
|
assert "/sub/ws/" in error_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_validation_error_with_only_path():
|
||||||
|
errors = [{"type": "missing", "loc": ("body", "name"), "msg": "Field required"}]
|
||||||
|
exc = RequestValidationError(errors, endpoint_ctx={"path": "GET /api/test"})
|
||||||
|
error_str = str(exc)
|
||||||
|
assert "Endpoint: GET /api/test" in error_str
|
||||||
|
assert 'File "' not in error_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_validation_error_with_no_context():
|
||||||
|
errors = [{"type": "missing", "loc": ("body", "name"), "msg": "Field required"}]
|
||||||
|
exc = RequestValidationError(errors, endpoint_ctx={})
|
||||||
|
error_str = str(exc)
|
||||||
|
assert "1 validation error:" in error_str
|
||||||
|
assert "Endpoint" not in error_str
|
||||||
|
assert 'File "' not in error_str
|
||||||
Loading…
Reference in New Issue