# fastapi/tests/test_dependency_yield_scope...
#
# 202 lines
# 6.1 KiB
# Python

from contextvars import ContextVar
from typing import Any, Dict, Tuple
import pytest
from fastapi import Depends, FastAPI, WebSocket
from fastapi.exceptions import FastAPIError
from fastapi.testclient import TestClient
from typing_extensions import Annotated
# Shared per-test state. Each test installs a fresh dict with ``set()``, and
# dependency teardown code writes "*_closed" flags into whatever dict ``get()``
# returns, so the test can inspect them afterwards.
# NOTE(review): the ``{}`` default is a single dict object shared by every
# context that never called ``set()`` (the pattern B039 warns about, hence the
# noqa); the tests rely on calling ``set()`` first rather than on the default.
global_context: ContextVar[Dict[str, Any]] = ContextVar("global_context", default={})  # noqa: B039
class Session:
    """Fake resource whose ``open`` flag records whether teardown has run."""

    open: bool

    def __init__(self) -> None:
        # Every new session starts out open; teardown code flips it to False.
        self.open = True
async def dep_session() -> Any:
    """Provide a fresh ``Session`` to the endpoint that depends on it.

    The statements after ``yield`` form the teardown step. They mark the
    session closed and set ``session_closed`` in the dict currently held by
    ``global_context``, so a test can confirm that teardown happened.
    """
    session = Session()
    yield session
    # Teardown only runs when the generator is resumed past the yield.
    # NOTE(review): there is no try/finally, so an exception thrown in at the
    # yield would skip these lines.
    session.open = False
    global_context.get()["session_closed"] = True
# Three aliases for the same ``dep_session`` generator. They differ only in the
# scope passed to ``Depends``: explicit "function", explicit "request", and none.
SessionFuncDep = Annotated[
    Session,
    Depends(dep_session, scope="function"),
]
SessionRequestDep = Annotated[
    Session,
    Depends(dep_session, scope="request"),
]
SessionDefaultDep = Annotated[
    Session,
    Depends(dep_session),
]
class NamedSession:
    """Session-like object that carries a ``name`` next to its ``open`` flag."""

    name: str
    open: bool

    def __init__(self, name: str = "default") -> None:
        # Store the label first, then mark the object as open.
        self.name, self.open = name, True
def get_named_session(
    session: SessionRequestDep,
    session_b: SessionDefaultDep,
) -> Any:
    """Yield a ``(NamedSession, Session)`` pair built on the request session.

    ``session`` asks for ``dep_session`` with ``scope="request"`` and
    ``session_b`` asks for it with no explicit scope. The ``assert`` checks
    that both parameters received the identical object (presumably via
    FastAPI's per-request dependency caching; verify if that changes).
    Teardown closes the named session and sets ``named_session_closed`` in the
    shared state dict.
    """
    assert session is session_b
    named = NamedSession(name="named")
    yield named, session_b
    named.open = False
    global_context.get()["named_session_closed"] = True
# Alias for the pair yielded by ``get_named_session``. No explicit scope is
# passed to ``Depends``.
NamedSessionsDep = Annotated[
    Tuple[NamedSession, Session],
    Depends(get_named_session),
]
def get_named_func_session(session: SessionFuncDep) -> Any:
    """Yield a ``(NamedSession, Session)`` pair built on a function-scoped session.

    Teardown closes the named session and sets
    ``named_func_session_closed`` in the shared state dict.
    """
    named = NamedSession(name="named")
    yield named, session
    named.open = False
    global_context.get()["named_func_session_closed"] = True
def get_named_regular_func_session(session: SessionFuncDep) -> Any:
    """Return (rather than yield) a ``(NamedSession, Session)`` pair.

    This is a plain-function dependency on a function-scoped session. It has
    no teardown step of its own.
    """
    return NamedSession(name="named"), session
# ``get_named_func_session`` with no explicit scope. ``test_broken_scope``
# expects a route using this alias to be rejected: the dependency ends up with
# scope "request" while depending on a "function"-scoped session.
BrokenSessionsDep = Annotated[
    Tuple[NamedSession, Session],
    Depends(get_named_func_session),
]
# The same generator, explicitly function-scoped, which the tests treat as valid.
NamedSessionsFuncDep = Annotated[
    Tuple[NamedSession, Session],
    Depends(get_named_func_session, scope="function"),
]
# A plain-return (non-generator) dependency on a function-scoped session.
RegularSessionsDep = Annotated[
    Tuple[NamedSession, Session],
    Depends(get_named_regular_func_session),
]
# Application under test; the websocket endpoints in this module are registered on it.
app = FastAPI()
@app.websocket("/function-scope")
async def function_scope(websocket: WebSocket, session: SessionFuncDep) -> Any:
    """Send whether the function-scoped session is open while the handler runs."""
    await websocket.accept()
    payload = {"is_open": session.open}
    await websocket.send_json(payload)
@app.websocket("/request-scope")
async def request_scope(websocket: WebSocket, session: SessionRequestDep) -> Any:
    """Send whether the request-scoped session is open while the handler runs."""
    await websocket.accept()
    payload = {"is_open": session.open}
    await websocket.send_json(payload)
@app.websocket("/two-scopes")
async def get_stream_session(
    websocket: WebSocket,
    function_session: SessionFuncDep,
    request_session: SessionRequestDep,
) -> Any:
    """Send the ``open`` flags of a function-scoped and a request-scoped session."""
    await websocket.accept()
    status = {
        "func_is_open": function_session.open,
        "req_is_open": request_session.open,
    }
    await websocket.send_json(status)
@app.websocket("/sub")
async def get_sub(websocket: WebSocket, sessions: NamedSessionsDep) -> Any:
    """Send the ``open`` flags of the pair yielded by ``get_named_session``."""
    await websocket.accept()
    named_session, session = sessions
    await websocket.send_json(
        {"named_session_open": named_session.open, "session_open": session.open}
    )
@app.websocket("/named-function-scope")
async def get_named_function_scope(
    websocket: WebSocket, sessions: NamedSessionsFuncDep
) -> Any:
    """Send the ``open`` flags of the function-scoped named-session pair."""
    await websocket.accept()
    named_session, session = sessions
    await websocket.send_json(
        {"named_session_open": named_session.open, "session_open": session.open}
    )
@app.websocket("/regular-function-scope")
async def get_regular_function_scope(
    websocket: WebSocket, sessions: RegularSessionsDep
) -> Any:
    """Send the ``open`` flags of the pair returned by the plain dependency."""
    await websocket.accept()
    named_session, session = sessions
    await websocket.send_json(
        {"named_session_open": named_session.open, "session_open": session.open}
    )
# Module-level test client wrapping ``app``; shared by the tests in this module.
client = TestClient(app)
def test_function_scope() -> None:
    """The function-scoped session is open in the handler and closed afterwards."""
    state: Dict[str, Any] = {}
    global_context.set(state)
    with client.websocket_connect("/function-scope") as ws:
        payload = ws.receive_json()
    # Checked after the connection context has exited, so teardown has had a
    # chance to run.
    assert payload["is_open"] is True
    assert state["session_closed"] is True
def test_request_scope() -> None:
    """The request-scoped session is open in the handler and closed afterwards."""
    state: Dict[str, Any] = {}
    global_context.set(state)
    with client.websocket_connect("/request-scope") as ws:
        payload = ws.receive_json()
    assert payload["is_open"] is True
    assert state["session_closed"] is True
def test_two_scopes() -> None:
    """Function- and request-scoped sessions can coexist in one endpoint."""
    state: Dict[str, Any] = {}
    global_context.set(state)
    with client.websocket_connect("/two-scopes") as ws:
        payload = ws.receive_json()
    assert payload["func_is_open"] is True
    assert payload["req_is_open"] is True
    # NOTE(review): both sessions come from ``dep_session``, which writes the
    # same "session_closed" key, so this only proves at least one teardown ran.
    assert state["session_closed"] is True
def test_sub() -> None:
    """A yield sub-dependency and its session are both torn down after use."""
    state: Dict[str, Any] = {}
    global_context.set(state)
    with client.websocket_connect("/sub") as ws:
        payload = ws.receive_json()
    assert payload["named_session_open"] is True
    assert payload["session_open"] is True
    assert state["session_closed"] is True
    assert state["named_session_closed"] is True
def test_broken_scope() -> None:
    """Reject a request-scoped dependency that needs a function-scoped one.

    Declaring the route must raise ``FastAPIError`` at registration time.
    The route is declared on a throwaway ``FastAPI()`` instance rather than on
    the shared module-level ``app``. If the error were ever not raised, the
    broken route then cannot leak into the application used by the other tests.
    """
    broken_app = FastAPI()
    with pytest.raises(
        FastAPIError,
        match='The dependency "get_named_func_session" has a scope of "request", it cannot depend on dependencies with scope "function"',
    ):

        @broken_app.websocket("/broken-scope")
        async def get_broken(
            websocket: WebSocket, sessions: BrokenSessionsDep
        ) -> Any:  # pragma: no cover
            pass
def test_named_function_scope() -> None:
    """A function-scoped yield dependency on a function-scoped session cleans up both."""
    state: Dict[str, Any] = {}
    global_context.set(state)
    with client.websocket_connect("/named-function-scope") as ws:
        payload = ws.receive_json()
    assert payload["named_session_open"] is True
    assert payload["session_open"] is True
    assert state["session_closed"] is True
    assert state["named_func_session_closed"] is True
def test_regular_function_scope() -> None:
    """A plain-return dependency on a function-scoped session still closes it."""
    state: Dict[str, Any] = {}
    global_context.set(state)
    with client.websocket_connect("/regular-function-scope") as ws:
        payload = ws.receive_json()
    assert payload["named_session_open"] is True
    assert payload["session_open"] is True
    # The plain dependency has no teardown, so only the session flag is expected.
    assert state["session_closed"] is True