diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index e2e7bcb9e6..94825dc862 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2313,6 +2313,12 @@ private: slot.n_prompt_tokens_processed = 0; slot.prompt.tokens.keep_first(n_past); + + // send initial 0% progress update if needed + // this is to signal the client that the request has started processing + if (slot.task->params.stream && slot.task->params.return_progress) { + send_partial_response(slot, {}, true); + } } if (!slot.can_split()) { diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 64f3158b98..5f5de415cf 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -434,8 +434,8 @@ def test_context_size_exceeded_stream(): @pytest.mark.parametrize( "n_batch,batch_count,reuse_cache", [ - (64, 3, False), - (64, 1, True), + (64, 4, False), + (64, 2, True), ] ) def test_return_progress(n_batch, batch_count, reuse_cache): @@ -462,10 +462,18 @@ def test_return_progress(n_batch, batch_count, reuse_cache): res = make_cmpl_request() last_progress = None total_batch_count = 0 + for data in res: cur_progress = data.get("prompt_progress", None) if cur_progress is None: continue + if total_batch_count == 0: + # first progress report must have n_cache == n_processed + assert cur_progress["total"] > 0 + assert cur_progress["cache"] == cur_progress["processed"] + if reuse_cache: + # when reusing cache, we expect some cached tokens + assert cur_progress["cache"] > 0 if last_progress is not None: assert cur_progress["total"] == last_progress["total"] assert cur_progress["cache"] == last_progress["cache"] @@ -473,6 +481,7 @@ def test_return_progress(n_batch, batch_count, reuse_cache): total_batch_count += 1 last_progress = cur_progress + # last progress should indicate completion (all tokens processed) assert last_progress is not None assert last_progress["total"] > 0 assert last_progress["processed"] == last_progress["total"]