# Mirror of https://github.com/tiangolo/fastapi.git (202 lines, 6.1 KiB, Python).
from contextvars import ContextVar
|
|
from typing import Any, Dict, Tuple
|
|
|
|
import pytest
|
|
from fastapi import Depends, FastAPI, WebSocket
|
|
from fastapi.exceptions import FastAPIError
|
|
from fastapi.testclient import TestClient
|
|
from typing_extensions import Annotated
|
|
|
|
# Per-test scratch state. Tests that inspect teardown call
# ``global_context.set({})`` and keep a reference to that dict; the
# dependencies below mutate the dict returned by ``global_context.get()``
# after their ``yield``, so the test can assert on it once the connection ends.
# NOTE: the ``{}`` default is one shared dict object (hence the B039
# suppression); it would only be mutated if a dependency tore down without a
# preceding ``set``.
global_context: ContextVar[Dict[str, Any]] = ContextVar("global_context", default={}) # noqa: B039
|
|
|
|
|
|
class Session:
    """Fake resource whose ``open`` flag shows whether it was torn down."""

    open: bool

    def __init__(self) -> None:
        # Every session starts open; its owning dependency flips this on teardown.
        self.open = True
|
|
|
|
|
|
async def dep_session() -> Any:
    """Yield a fresh ``Session``; on teardown, close it and record that fact.

    The statements after ``yield`` form the dependency's teardown. They mark the
    session closed and set ``session_closed`` in the dict held by
    ``global_context`` so the calling test can verify teardown happened.
    """
    session = Session()
    yield session
    session.open = False
    global_context.get()["session_closed"] = True
|
|
|
|
|
|
# ``dep_session`` with an explicit "function" scope.
SessionFuncDep = Annotated[
    Session, Depends(dependency=dep_session, scope="function")
]

# ``dep_session`` with an explicit "request" scope.
SessionRequestDep = Annotated[
    Session, Depends(dependency=dep_session, scope="request")
]

# ``dep_session`` with no explicit scope; ``get_named_session`` expects this to
# resolve to the very same instance as ``SessionRequestDep``.
SessionDefaultDep = Annotated[Session, Depends(dependency=dep_session)]
|
|
|
|
|
|
class NamedSession:
    """Second fake resource, told apart by ``name``, with an ``open`` flag."""

    name: str
    open: bool

    def __init__(self, name: str = "default") -> None:
        self.open = True  # starts open; the owning dependency closes it
        self.name = name
|
|
|
|
|
|
def get_named_session(session: SessionRequestDep, session_b: SessionDefaultDep) -> Any:
    """Sub-dependency yielding ``(NamedSession, Session)``.

    Both parameters come from ``dep_session`` (one request-scoped, one with the
    default scope); the test expects FastAPI to hand back the same instance for
    both. Teardown closes the named session and sets ``named_session_closed``
    in the test's state dict.
    """
    # The two differently-declared session parameters must be one object.
    assert session is session_b
    named = NamedSession(name="named")
    yield named, session_b
    named.open = False
    global_context.get()["named_session_closed"] = True
|
|
|
|
|
|
# ``(NamedSession, Session)`` produced by ``get_named_session``; no explicit scope.
NamedSessionsDep = Annotated[
    Tuple[NamedSession, Session], Depends(dependency=get_named_session)
]
|
|
|
|
|
|
def get_named_func_session(session: SessionFuncDep) -> Any:
    """Generator dependency built on a function-scoped ``dep_session``.

    Yields ``(NamedSession, Session)``. Teardown closes the named session and
    sets ``named_func_session_closed`` in the test's state dict.
    """
    named = NamedSession(name="named")
    yield named, session
    named.open = False
    global_context.get()["named_func_session_closed"] = True
|
|
|
|
|
|
def get_named_regular_func_session(session: SessionFuncDep) -> Any:
    """Plain (non-generator) dependency pairing a new ``NamedSession`` with ``session``.

    It has no ``yield``, so it contributes no teardown step of its own.
    """
    return NamedSession(name="named"), session
|
|
|
|
|
|
# Generator dependency over a function-scoped session, declared without an
# explicit scope. ``test_broken_scope`` expects FastAPI to reject it, because
# its effective scope is "request" and it depends on a "function"-scoped one.
BrokenSessionsDep = Annotated[
    Tuple[NamedSession, Session],
    Depends(dependency=get_named_func_session),
]

# The same dependency declared "function"-scoped, which is accepted.
NamedSessionsFuncDep = Annotated[
    Tuple[NamedSession, Session],
    Depends(dependency=get_named_func_session, scope="function"),
]

# Non-generator dependency over a function-scoped session; no explicit scope.
RegularSessionsDep = Annotated[
    Tuple[NamedSession, Session],
    Depends(dependency=get_named_regular_func_session),
]
|
|
|
|
# Application under test: the WebSocket endpoints below register on it, and
# test_broken_scope attempts (and is expected to fail) to register one more.
app = FastAPI()
|
|
|
|
|
|
@app.websocket("/function-scope")
async def function_scope(websocket: WebSocket, session: SessionFuncDep) -> Any:
    """Report whether the function-scoped session is open while in use."""
    await websocket.accept()
    payload = {"is_open": session.open}
    await websocket.send_json(payload)
|
|
|
|
|
|
@app.websocket("/request-scope")
async def request_scope(websocket: WebSocket, session: SessionRequestDep) -> Any:
    """Report whether the request-scoped session is open while in use."""
    await websocket.accept()
    payload = {"is_open": session.open}
    await websocket.send_json(payload)
|
|
|
|
|
|
@app.websocket("/two-scopes")
async def get_stream_session(
    websocket: WebSocket,
    function_session: SessionFuncDep,
    request_session: SessionRequestDep,
) -> Any:
    """Report open flags for a function-scoped and a request-scoped session."""
    await websocket.accept()
    status = dict(
        func_is_open=function_session.open,
        req_is_open=request_session.open,
    )
    await websocket.send_json(status)
|
|
|
|
|
|
@app.websocket("/sub")
async def get_sub(websocket: WebSocket, sessions: NamedSessionsDep) -> Any:
    """Report open flags for the pair yielded by ``get_named_session``."""
    await websocket.accept()
    named_session, plain_session = sessions[0], sessions[1]
    await websocket.send_json(
        {
            "named_session_open": named_session.open,
            "session_open": plain_session.open,
        }
    )
|
|
|
|
|
|
@app.websocket("/named-function-scope")
async def get_named_function_scope(
    websocket: WebSocket, sessions: NamedSessionsFuncDep
) -> Any:
    """Report open flags for the pair from the function-scoped generator dep."""
    await websocket.accept()
    named_session, plain_session = sessions[0], sessions[1]
    await websocket.send_json(
        {
            "named_session_open": named_session.open,
            "session_open": plain_session.open,
        }
    )
|
|
|
|
|
|
@app.websocket("/regular-function-scope")
async def get_regular_function_scope(
    websocket: WebSocket, sessions: RegularSessionsDep
) -> Any:
    """Report open flags for the pair from the non-generator dependency."""
    await websocket.accept()
    named_session, plain_session = sessions[0], sessions[1]
    await websocket.send_json(
        {
            "named_session_open": named_session.open,
            "session_open": plain_session.open,
        }
    )
|
|
|
|
|
|
# Shared test client that the WebSocket tests below connect through.
client = TestClient(app)
|
|
|
|
|
|
def test_function_scope() -> None:
    """The function-scoped session is open in the endpoint and closed afterwards."""
    state: Dict[str, Any] = {}
    global_context.set(state)
    with client.websocket_connect("/function-scope") as ws:
        assert ws.receive_json()["is_open"] is True
    # Teardown has run once the connection is closed.
    assert state["session_closed"] is True
|
|
|
|
|
|
def test_request_scope() -> None:
    """The request-scoped session is open in the endpoint and closed afterwards."""
    state: Dict[str, Any] = {}
    global_context.set(state)
    with client.websocket_connect("/request-scope") as ws:
        assert ws.receive_json()["is_open"] is True
    # Teardown has run once the connection is closed.
    assert state["session_closed"] is True
|
|
|
|
|
|
def test_two_scopes() -> None:
    """Function- and request-scoped sessions coexist and are both open in use."""
    state: Dict[str, Any] = {}
    global_context.set(state)
    with client.websocket_connect("/two-scopes") as ws:
        data = ws.receive_json()
        for key in ("func_is_open", "req_is_open"):
            assert data[key] is True
    assert state["session_closed"] is True
|
|
|
|
|
|
def test_sub() -> None:
    """The sub-dependency and its session are open in use and torn down after."""
    state: Dict[str, Any] = {}
    global_context.set(state)
    with client.websocket_connect("/sub") as ws:
        data = ws.receive_json()
        for key in ("named_session_open", "session_open"):
            assert data[key] is True
    for flag in ("session_closed", "named_session_closed"):
        assert state[flag] is True
|
|
|
|
|
|
def test_broken_scope() -> None:
    """Registering an endpoint that uses ``BrokenSessionsDep`` must raise immediately."""
    expected = (
        'The dependency "get_named_func_session" has a scope of "request", '
        'it cannot depend on dependencies with scope "function"'
    )
    with pytest.raises(FastAPIError, match=expected):

        @app.websocket("/broken-scope")
        async def get_broken(
            websocket: WebSocket, sessions: BrokenSessionsDep
        ) -> Any:  # pragma: no cover
            pass
|
|
|
|
|
|
def test_named_function_scope() -> None:
    """A function-scoped generator dep over a function-scoped session works end to end."""
    state: Dict[str, Any] = {}
    global_context.set(state)
    with client.websocket_connect("/named-function-scope") as ws:
        data = ws.receive_json()
        for key in ("named_session_open", "session_open"):
            assert data[key] is True
    for flag in ("session_closed", "named_func_session_closed"):
        assert state[flag] is True
|
|
|
|
|
|
def test_regular_function_scope() -> None:
    """A non-generator dep over a function-scoped session works end to end."""
    state: Dict[str, Any] = {}
    global_context.set(state)
    with client.websocket_connect("/regular-function-scope") as ws:
        data = ws.receive_json()
        for key in ("named_session_open", "session_open"):
            assert data[key] is True
    assert state["session_closed"] is True
|