mirror of https://github.com/tiangolo/fastapi.git
Use weak refs for endpoint context cache
This commit is contained in:
parent
5c7dceb80f
commit
be79af007a
|
|
@ -76,6 +76,7 @@ from starlette.routing import Mount as Mount # noqa
|
|||
from starlette.types import AppType, ASGIApp, Lifespan, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
from typing_extensions import deprecated
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
|
||||
# Copy of starlette.routing.request_response modified to include the
|
||||
|
|
@ -212,15 +213,17 @@ def _merge_lifespan_context(
|
|||
|
||||
|
||||
# Cache for endpoint context to avoid re-extracting on every request
|
||||
_endpoint_context_cache: dict[int, EndpointContext] = {}
|
||||
_endpoint_context_cache: WeakKeyDictionary[Any, EndpointContext] = WeakKeyDictionary()
|
||||
|
||||
|
||||
def _extract_endpoint_context(func: Any) -> EndpointContext:
|
||||
"""Extract endpoint context with caching to avoid repeated file I/O."""
|
||||
func_id = id(func)
|
||||
|
||||
if func_id in _endpoint_context_cache:
|
||||
return _endpoint_context_cache[func_id]
|
||||
try:
|
||||
cached = _endpoint_context_cache.get(func)
|
||||
except TypeError:
|
||||
cached = None
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
try:
|
||||
ctx: EndpointContext = {}
|
||||
|
|
@ -234,7 +237,10 @@ def _extract_endpoint_context(func: Any) -> EndpointContext:
|
|||
except Exception:
|
||||
ctx = EndpointContext()
|
||||
|
||||
_endpoint_context_cache[func_id] = ctx
|
||||
try:
|
||||
_endpoint_context_cache[func] = ctx
|
||||
except TypeError:
|
||||
pass
|
||||
return ctx
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,25 @@
|
|||
import gc
|
||||
import weakref
|
||||
|
||||
from fastapi.routing import _endpoint_context_cache, _extract_endpoint_context
|
||||
|
||||
|
||||
def _make_endpoint():
|
||||
def endpoint():
|
||||
return None
|
||||
|
||||
return endpoint
|
||||
|
||||
|
||||
def test_endpoint_context_cache_releases_endpoints():
|
||||
endpoint = _make_endpoint()
|
||||
_extract_endpoint_context(endpoint)
|
||||
assert endpoint in _endpoint_context_cache
|
||||
|
||||
ref = weakref.ref(endpoint)
|
||||
size_with_endpoint = len(_endpoint_context_cache)
|
||||
del endpoint
|
||||
gc.collect()
|
||||
|
||||
assert ref() is None
|
||||
assert len(_endpoint_context_cache) <= size_with_endpoint - 1
|
||||
Loading…
Reference in New Issue