diff --git a/scripts/fetch_server_test_models.py b/scripts/fetch_server_test_models.py index ac483ef5d7..f43d1f63cd 100755 --- a/scripts/fetch_server_test_models.py +++ b/scripts/fetch_server_test_models.py @@ -95,9 +95,9 @@ if __name__ == '__main__': '-p', 'Hey', '--no-warmup', '--log-disable', - '-no-cnv'] + '-st'] if m.hf_file != 'tinyllamas/stories260K.gguf' and 'Mistral-Nemo' not in m.hf_repo: - cmd.append('-fa') + cmd += ('-fa', 'on') try: subprocess.check_call(cmd) except subprocess.CalledProcessError: diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index d99d2293ce..2e43bebd6f 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1432,9 +1432,10 @@ private: res->tokens = { tkn.tok }; } - res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.task->n_tokens(); - res->post_sampling_probs = slot.task->params.post_sampling_probs; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.task->n_tokens(); + res->n_prompt_tokens_cache = slot.n_prompt_tokens_cache; + res->post_sampling_probs = slot.task->params.post_sampling_probs; res->verbose = slot.task->params.verbose; res->res_type = slot.task->params.res_type; @@ -1479,14 +1480,15 @@ private: res->prompt = slot.task->tokens.detokenize(ctx, true); res->response_fields = std::move(slot.task->params.response_fields); - res->truncated = slot.truncated; - res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.task->n_tokens(); - res->n_tokens_cached = slot.prompt.n_tokens(); - res->has_new_line = slot.has_new_line; - res->stopping_word = slot.stopping_word; - res->stop = slot.stop; - res->post_sampling_probs = slot.task->params.post_sampling_probs; + res->truncated = slot.truncated; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.task->n_tokens(); + res->n_prompt_tokens_cache = slot.n_prompt_tokens_cache; + res->n_tokens_cached = slot.prompt.n_tokens(); + res->has_new_line = slot.has_new_line; + res->stopping_word = slot.stopping_word; + res->stop = slot.stop; + res->post_sampling_probs = slot.task->params.post_sampling_probs; res->verbose = slot.task->params.verbose; res->stream = slot.task->params.stream; diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 1d511cfa2d..39d232c2e4 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -755,6 +755,15 @@ json server_task_result_cmpl_final::to_json_non_oaicompat() { return response_fields.empty() ? res : json_get_nested_values(response_fields, res); } +json server_task_result_cmpl_final::usage_json_oaicompat() { + return json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}, + {"prompt_tokens_details", json { {"cached_tokens", n_prompt_tokens_cache} }}, + }; +} + json server_task_result_cmpl_final::to_json_oaicompat() { std::time_t t = std::time(0); json logprobs = json(nullptr); // OAI default to null @@ -780,11 +789,7 @@ json server_task_result_cmpl_final::to_json_oaicompat() { {"model", oaicompat_model}, {"system_fingerprint", build_info}, {"object", "text_completion"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens} - }}, + {"usage", usage_json_oaicompat()}, {"id", oaicompat_cmpl_id} }; @@ -832,11 +837,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat() { {"model", oaicompat_model}, {"system_fingerprint", build_info}, {"object", "chat.completion"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens} - }}, + {"usage", usage_json_oaicompat()}, {"id", oaicompat_cmpl_id} }; @@ -901,11 +902,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() { {"model", oaicompat_model}, {"system_fingerprint", build_info}, {"object", "chat.completion.chunk"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens}, - }}, + {"usage", usage_json_oaicompat()}, }); } @@ -984,6 +981,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp() { {"input_tokens", n_prompt_tokens}, {"output_tokens", n_decoded}, {"total_tokens", n_decoded + n_prompt_tokens}, + {"input_tokens_details", json { {"cached_tokens", n_prompt_tokens_cache} }}, }}, }; @@ -1092,7 +1090,8 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() { {"usage", json { {"input_tokens", n_prompt_tokens}, {"output_tokens", n_decoded}, - {"total_tokens", n_decoded + n_prompt_tokens} + {"total_tokens", n_decoded + n_prompt_tokens}, + {"input_tokens_details", json { {"cached_tokens", n_prompt_tokens_cache} }}, }} }}, }} @@ -1158,7 +1157,8 @@ json server_task_result_cmpl_final::to_json_anthropic() { {"stop_reason", stop_reason}, {"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)}, {"usage", { - {"input_tokens", n_prompt_tokens}, + {"cache_read_input_tokens", n_prompt_tokens_cache}, + {"input_tokens", n_prompt_tokens - n_prompt_tokens_cache}, {"output_tokens", n_decoded} }} }; @@ -1668,7 +1668,8 @@ json server_task_result_cmpl_partial::to_json_anthropic() { {"stop_reason", nullptr}, {"stop_sequence", nullptr}, {"usage", { - {"input_tokens", n_prompt_tokens}, + {"cache_read_input_tokens", n_prompt_tokens_cache}, + {"input_tokens", n_prompt_tokens - n_prompt_tokens_cache}, {"output_tokens", 0} }} }} diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 1e342531d8..a49ddb594b 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -344,6 +344,7 @@ struct server_task_result_cmpl_final : server_task_result { bool truncated; int32_t n_decoded; int32_t n_prompt_tokens; + int32_t n_prompt_tokens_cache; int32_t n_tokens_cached; bool has_new_line; std::string stopping_word; @@ -387,6 +388,8 @@ struct server_task_result_cmpl_final : server_task_result { json to_json_non_oaicompat(); + json usage_json_oaicompat(); + json to_json_oaicompat(); json to_json_oaicompat_chat(); @@ -408,6 +411,7 @@ struct server_task_result_cmpl_partial : server_task_result { int32_t n_decoded; int32_t n_prompt_tokens; + int32_t n_prompt_tokens_cache; bool post_sampling_probs; bool is_progress = false; diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 1bcffd91b6..edef0a93b4 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -51,6 +51,27 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte assert choice["finish_reason"] == finish_reason +def test_chat_completion_cached_tokens(): + global server + server.n_slots = 1 + server.start() + seq = [ + ("1 2 3 4 5 6", 77, 0), + ("1 2 3 4 5 6", 77, 76), + ("1 2 3 4 5 9", 77, 51), + ("1 2 3 9 9 9", 77, 47), + ] + for user_prompt, n_prompt, n_cache in seq: + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 8, + "messages": [ + {"role": "system", "content": "Test"}, + {"role": "user", "content": user_prompt}, + ], + }) + assert res.body["usage"]["prompt_tokens"] == n_prompt + assert res.body["usage"]["prompt_tokens_details"]["cached_tokens"] == n_cache + @pytest.mark.parametrize( "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason", [ diff --git a/tools/server/tests/unit/test_compat_anthropic.py b/tools/server/tests/unit/test_compat_anthropic.py index 93ff03be6b..ef1948d4a5 100644 --- a/tools/server/tests/unit/test_compat_anthropic.py +++ b/tools/server/tests/unit/test_compat_anthropic.py @@ -63,8 +63,10 @@ def test_anthropic_messages_basic(): assert "text" in res.body["content"][0], "Text content block missing 'text' field" assert res.body["stop_reason"] in ["end_turn", "max_tokens"], f"Invalid stop_reason: {res.body.get('stop_reason')}" assert "usage" in res.body, "Missing 'usage' field" + assert "cache_read_input_tokens" in res.body["usage"], "Missing usage.cache_read_input_tokens" assert "input_tokens" in res.body["usage"], "Missing usage.input_tokens" assert "output_tokens" in res.body["usage"], "Missing usage.output_tokens" + assert isinstance(res.body["usage"]["cache_read_input_tokens"], int), "cache_read_input_tokens should be integer" assert isinstance(res.body["usage"]["input_tokens"], int), "input_tokens should be integer" assert isinstance(res.body["usage"]["output_tokens"], int), "output_tokens should be integer" assert res.body["usage"]["output_tokens"] > 0, "Should have generated some tokens"