From 88879d5bebaf037e063c2450039ea23a928152b2 Mon Sep 17 00:00:00 2001 From: Savannah Ostrowski Date: Thu, 6 Nov 2025 09:58:13 -0800 Subject: [PATCH] Extract endpoint context --- fastapi/exceptions.py | 74 +++++++++++++++++++++++++++++++++++-------- fastapi/routing.py | 49 ++++++++++++++++++++++++++-- 2 files changed, 107 insertions(+), 16 deletions(-) diff --git a/fastapi/exceptions.py b/fastapi/exceptions.py index 44d4ada86..22c151a7b 100644 --- a/fastapi/exceptions.py +++ b/fastapi/exceptions.py @@ -1,10 +1,15 @@ -from typing import Any, Dict, Optional, Sequence, Type, Union +from typing import Any, Dict, Optional, Sequence, Type, TypedDict, Union from pydantic import BaseModel, create_model from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import WebSocketException as StarletteWebSocketException from typing_extensions import Annotated, Doc +class EndpointContext(TypedDict, total=False): + function: str + path: str + file: str + line: int class HTTPException(StarletteHTTPException): """ @@ -147,30 +152,73 @@ class FastAPIError(RuntimeError): class ValidationException(Exception): - def __init__(self, errors: Sequence[Any]) -> None: + def __init__( + self, + errors: Sequence[Any], + *, + endpoint_ctx: Optional[EndpointContext] = None, + ) -> None: self._errors = errors + self.endpoint_ctx = endpoint_ctx + + ctx = endpoint_ctx or {} + self.endpoint_function = ctx.get("function") + self.endpoint_path = ctx.get("path") + self.endpoint_file = ctx.get("file") + self.endpoint_line = ctx.get("line") def errors(self) -> Sequence[Any]: return self._errors + def _format_endpoint_context(self) -> str: + """Format endpoint context in native Python traceback format.""" + if not (self.endpoint_file and self.endpoint_line and self.endpoint_function): + if self.endpoint_path: + return f"\n Endpoint: {self.endpoint_path}" + return "" + + context = f'\n File "{self.endpoint_file}", line {self.endpoint_line}, in {self.endpoint_function}' + if self.endpoint_path: + context += f"\n {self.endpoint_path}" + return context + + def __str__(self) -> str: + message = f"{len(self._errors)} validation error{'s' if len(self._errors) != 1 else ''}:\n" + for err in self._errors: + message += f" {err}\n" + message += self._format_endpoint_context() + return message.rstrip() + class RequestValidationError(ValidationException): - def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: - super().__init__(errors) + def __init__( + self, + errors: Sequence[Any], + *, + body: Any = None, + endpoint_ctx: Optional[EndpointContext] = None, + ) -> None: + super().__init__(errors, endpoint_ctx=endpoint_ctx) self.body = body class WebSocketRequestValidationError(ValidationException): - pass + def __init__( + self, + errors: Sequence[Any], + *, + endpoint_ctx: Optional[EndpointContext] = None, + ) -> None: + super().__init__(errors, endpoint_ctx=endpoint_ctx) class ResponseValidationError(ValidationException): - def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: - super().__init__(errors) + def __init__( + self, + errors: Sequence[Any], + *, + body: Any = None, + endpoint_ctx: Optional[EndpointContext] = None, + ) -> None: + super().__init__(errors, endpoint_ctx=endpoint_ctx) self.body = body - - def __str__(self) -> str: - message = f"{len(self._errors)} validation errors:\n" - for err in self._errors: - message += f" {err}\n" - return message diff --git a/fastapi/routing.py b/fastapi/routing.py index fe25d7dec..546845701 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -46,6 +46,7 @@ from fastapi.dependencies.utils import ( ) from fastapi.encoders import jsonable_encoder from fastapi.exceptions import ( + EndpointContext, FastAPIError, RequestValidationError, ResponseValidationError, @@ -216,6 +217,32 @@ def _merge_lifespan_context( return merged_lifespan # type: ignore[return-value] +# Cache for endpoint context to avoid re-extracting on every request +_endpoint_context_cache: Dict[int, EndpointContext] = {} + + +def _extract_endpoint_context(func: Any) -> EndpointContext: + """Extract endpoint context with caching to avoid repeated file I/O.""" + func_id = id(func) + + if func_id in _endpoint_context_cache: + return _endpoint_context_cache[func_id] + + try: + ctx: EndpointContext = {} + + if (source_file := inspect.getsourcefile(func)) is not None: + ctx["file"] = source_file + if (line_number := inspect.getsourcelines(func)[1]) is not None: + ctx["line"] = line_number + if (func_name := getattr(func, "__name__", None)) is not None: + ctx["function"] = func_name + except Exception: + ctx = EndpointContext() + + _endpoint_context_cache[func_id] = ctx + return ctx + async def serialize_response( *, field: Optional[ModelField] = None, @@ -227,6 +254,7 @@ async def serialize_response( exclude_defaults: bool = False, exclude_none: bool = False, is_coroutine: bool = True, + endpoint_ctx: Optional[EndpointContext] = None, ) -> Any: if field: errors = [] @@ -249,8 +277,11 @@ async def serialize_response( elif errors_: errors.append(errors_) if errors: + ctx = endpoint_ctx or EndpointContext() raise ResponseValidationError( - errors=_normalize_errors(errors), body=response_content + errors=_normalize_errors(errors), + body=response_content, + endpoint_ctx=ctx, ) if hasattr(field, "serialize"): @@ -322,6 +353,11 @@ def get_request_handler( "fastapi_middleware_astack not found in request scope" ) + # Extract endpoint context for error messages + endpoint_ctx = _extract_endpoint_context(dependant.call) if dependant.call else EndpointContext() + if dependant.path: + endpoint_ctx["path"] = f"{request.method} {dependant.path}" + # Read body and auto-close files try: body: Any = None @@ -359,6 +395,7 @@ def get_request_handler( } ], body=e.doc, + endpoint_ctx=endpoint_ctx, ) raise validation_error from e except HTTPException: @@ -418,6 +455,7 @@ def get_request_handler( exclude_defaults=response_model_exclude_defaults, exclude_none=response_model_exclude_none, is_coroutine=is_coroutine, + endpoint_ctx=endpoint_ctx, ) response = actual_response_class(content, **response_args) if not is_body_allowed_for_status_code(response.status_code): @@ -425,7 +463,7 @@ def get_request_handler( response.headers.raw.extend(solved_result.response.headers.raw) if errors: validation_error = RequestValidationError( - _normalize_errors(errors), body=body + _normalize_errors(errors), body=body, endpoint_ctx=endpoint_ctx ) raise validation_error @@ -442,6 +480,10 @@ def get_websocket_app( embed_body_fields: bool = False, ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]: async def app(websocket: WebSocket) -> None: + endpoint_ctx = _extract_endpoint_context(dependant.call) if dependant.call else EndpointContext() + if dependant.path: + endpoint_ctx["path"] = f"WS {dependant.path}" + async_exit_stack = websocket.scope.get("fastapi_inner_astack") assert isinstance(async_exit_stack, AsyncExitStack), ( "fastapi_inner_astack not found in request scope" @@ -455,7 +497,8 @@ def get_websocket_app( ) if solved_result.errors: raise WebSocketRequestValidationError( - _normalize_errors(solved_result.errors) + _normalize_errors(solved_result.errors), + endpoint_ctx=endpoint_ctx, ) assert dependant.call is not None, "dependant.call must be a function" await dependant.call(**solved_result.values)