diff --git a/tests/test_dependency_overrides_background_tasks.py b/tests/test_dependency_overrides_background_tasks.py index 2825575d4..39ca2d2e6 100644 --- a/tests/test_dependency_overrides_background_tasks.py +++ b/tests/test_dependency_overrides_background_tasks.py @@ -18,17 +18,14 @@ def background_task_fixture(): return background_task +class BackgroundTasksOverride(BackgroundTasks): + pass + + app = FastAPI() -router = APIRouter() - -@app.get("/app") -def app_background_tasks(background_tasks: BackgroundTasks): - background_tasks.add_task(background_task, type(background_tasks)) - - -@app.get("/overrides") +@app.get("/endpoint") def app_overrides(background_tasks: BackgroundTasks): background_tasks.add_task(background_task, type(background_tasks)) @@ -42,6 +39,14 @@ def nested_dependency_override(nested: str = Depends(nested_dependency)): pass +@app.get("/specify-background-tasks-dependency-with-subclass-of-BackgroundTasks") +def explicit_background_tasks_dependency(background_tasks: BackgroundTasksOverride): + background_tasks.add_task(background_task, type(background_tasks)) + + +router = APIRouter() + + @router.get("/router-endpoint") def router_endpoint(background_tasks: BackgroundTasks): background_tasks.add_task(background_task, type(background_tasks)) @@ -60,18 +65,21 @@ def override_background_tasks(app, override_with): def test_normal_app_uses_standard_background_tasks(): - client.get("/app") + client.get("/endpoint") background_task.assert_called_once_with(BackgroundTasks) @pytest.mark.parametrize( - "url", ["/overrides", "/nested-dependency-override", "/router-endpoint"] + "url", + [ + "/endpoint", + "/nested-dependency-override", + "/router-endpoint", + "/specify-background-tasks-dependency-with-subclass-of-BackgroundTasks", + ], ) def test_app_uses_background_task_override(url): - class BackgroundTasksOverride(BackgroundTasks): - pass - background_task.reset_mock() with override_background_tasks(app, BackgroundTasksOverride):