diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index eec0ea14e3..c26a5fd8b2 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4594,12 +4594,20 @@ class Qwen3Model(Qwen2Model): if self.is_rerank: self.gguf_writer.add_pooling_type(gguf.PoolingType.RANK) self.gguf_writer.add_classifier_output_labels(["yes", "no"]) - self.gguf_writer.add_chat_template([{ - "name": "rerank", - "template": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n" - "<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {query}\n<Document>: {document}<|im_end|>\n" - "<|im_start|>assistant\n<think>\n\n</think>\n\n" - }]) + self.gguf_writer.add_chat_template([ + { + "name": "rerank", + "template": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n" + "<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {query}\n<Document>: {document}<|im_end|>\n" + "<|im_start|>assistant\n<think>\n\n</think>\n\n", + }, + { + "name": "rerank_instruct", + "template": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n" + "<|im_start|>user\n<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}<|im_end|>\n" + "<|im_start|>assistant\n<think>\n\n</think>\n\n", + }, + ]) def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor: # extract "yes" and "no" tokens from the output lm_head tensor diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index bd203228cc..44f0f160ce 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -2021,13 +2021,22 @@ server_tokens format_prompt_rerank( const struct llama_model * model, const struct llama_vocab * vocab, mtmd_context * mctx, + const std::string & instruction, const std::string & query, const std::string & doc) { server_tokens result = {}; const char * rerank_prompt = llama_model_chat_template(model, "rerank"); + const char * rerank_prompt_instruct = llama_model_chat_template(model, "rerank_instruct"); - if (rerank_prompt != nullptr) { + if ( (rerank_prompt_instruct != nullptr) && !instruction.empty() ) { + std::string prompt = rerank_prompt_instruct; + string_replace_all(prompt, "{instruction}", instruction); + string_replace_all(prompt, "{query}" , query); + string_replace_all(prompt, "{document}" , doc ); + server_tokens tokens = tokenize_input_subprompt(vocab, mctx, prompt, false, true); + result.push_back(tokens); + } else if (rerank_prompt != nullptr) { std::string prompt = rerank_prompt; string_replace_all(prompt, "{query}" , query); string_replace_all(prompt, "{document}", doc ); diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 3e56b3d856..371403ec41 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -369,5 +369,6 @@ server_tokens format_prompt_rerank( const struct llama_model * model, const struct llama_vocab * vocab, mtmd_context * mctx, + const std::string & instruction, const std::string & query, const std::string & doc); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index c47ad876cb..f372f2eb3d 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -3876,6 +3876,15 @@ void server_routes::init_routes() { res->error(format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); return res; } + + json instruction = ""; + if (body.count("instruction") == 1) { + instruction = body.at("instruction"); + if (!instruction.is_string()) { + res->error(format_error_response("\"instruction\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } std::vector<std::string> documents = json_value(body, "documents", json_value(body, "texts", std::vector<std::string>())); @@ -3893,7 +3902,7 @@ void server_routes::init_routes() { std::vector<server_task> tasks; tasks.reserve(documents.size()); for (size_t i = 0; i < documents.size(); i++) { - auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); + auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, instruction, query, documents[i]); server_task task = server_task(SERVER_TASK_TYPE_RERANK); task.id = rd.get_new_id(); task.tokens = std::move(tmp);