Use weak refs for endpoint context cache

This commit is contained in:
Sagar Bhandari 2025-12-20 10:06:32 -05:00
parent 5c7dceb80f
commit be79af007a
2 changed files with 37 additions and 6 deletions

View File

@ -76,6 +76,7 @@ from starlette.routing import Mount as Mount # noqa
from starlette.types import AppType, ASGIApp, Lifespan, Receive, Scope, Send
from starlette.websockets import WebSocket
from typing_extensions import deprecated
from weakref import WeakKeyDictionary
# Copy of starlette.routing.request_response modified to include the
@ -212,15 +213,17 @@ def _merge_lifespan_context(
# Cache for endpoint context to avoid re-extracting on every request
_endpoint_context_cache: dict[int, EndpointContext] = {}
_endpoint_context_cache: WeakKeyDictionary[Any, EndpointContext] = WeakKeyDictionary()
def _extract_endpoint_context(func: Any) -> EndpointContext:
"""Extract endpoint context with caching to avoid repeated file I/O."""
func_id = id(func)
if func_id in _endpoint_context_cache:
return _endpoint_context_cache[func_id]
try:
cached = _endpoint_context_cache.get(func)
except TypeError:
cached = None
if cached is not None:
return cached
try:
ctx: EndpointContext = {}
@ -234,7 +237,10 @@ def _extract_endpoint_context(func: Any) -> EndpointContext:
except Exception:
ctx = EndpointContext()
_endpoint_context_cache[func_id] = ctx
try:
_endpoint_context_cache[func] = ctx
except TypeError:
pass
return ctx

View File

@ -0,0 +1,25 @@
import gc
import weakref
from fastapi.routing import _endpoint_context_cache, _extract_endpoint_context
def _make_endpoint():
def endpoint():
return None
return endpoint
def test_endpoint_context_cache_releases_endpoints():
endpoint = _make_endpoint()
_extract_endpoint_context(endpoint)
assert endpoint in _endpoint_context_cache
ref = weakref.ref(endpoint)
size_with_endpoint = len(_endpoint_context_cache)
del endpoint
gc.collect()
assert ref() is None
assert len(_endpoint_context_cache) <= size_with_endpoint - 1