This commit is contained in:
Kai G. Schwebke 2026-03-16 02:38:43 +02:00 committed by GitHub
commit 23f66aef0d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 35 additions and 8 deletions

View File

@@ -4594,12 +4594,20 @@ class Qwen3Model(Qwen2Model):
if self.is_rerank:
self.gguf_writer.add_pooling_type(gguf.PoolingType.RANK)
self.gguf_writer.add_classifier_output_labels(["yes", "no"])
self.gguf_writer.add_chat_template([{
"name": "rerank",
"template": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n"
"<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {query}\n<Document>: {document}<|im_end|>\n"
"<|im_start|>assistant\n<think>\n\n</think>\n\n"
}])
self.gguf_writer.add_chat_template([
{
"name": "rerank",
"template": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n"
"<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {query}\n<Document>: {document}<|im_end|>\n"
"<|im_start|>assistant\n<think>\n\n</think>\n\n",
},
{
"name": "rerank_instruct",
"template": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n"
"<|im_start|>user\n<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}<|im_end|>\n"
"<|im_start|>assistant\n<think>\n\n</think>\n\n",
},
])
def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor:
# extract "yes" and "no" tokens from the output lm_head tensor

View File

@@ -2021,13 +2021,22 @@ server_tokens format_prompt_rerank(
const struct llama_model * model,
const struct llama_vocab * vocab,
mtmd_context * mctx,
const std::string & instruction,
const std::string & query,
const std::string & doc) {
server_tokens result = {};
const char * rerank_prompt = llama_model_chat_template(model, "rerank");
const char * rerank_prompt_instruct = llama_model_chat_template(model, "rerank_instruct");
if (rerank_prompt != nullptr) {
if ( (rerank_prompt_instruct != nullptr) && !instruction.empty() ) {
std::string prompt = rerank_prompt_instruct;
string_replace_all(prompt, "{instruction}", instruction);
string_replace_all(prompt, "{query}" , query);
string_replace_all(prompt, "{document}" , doc );
server_tokens tokens = tokenize_input_subprompt(vocab, mctx, prompt, false, true);
result.push_back(tokens);
} else if (rerank_prompt != nullptr) {
std::string prompt = rerank_prompt;
string_replace_all(prompt, "{query}" , query);
string_replace_all(prompt, "{document}", doc );

View File

@@ -369,5 +369,6 @@ server_tokens format_prompt_rerank(
const struct llama_model * model,
const struct llama_vocab * vocab,
mtmd_context * mctx,
const std::string & instruction,
const std::string & query,
const std::string & doc);

View File

@ -3876,6 +3876,15 @@ void server_routes::init_routes() {
res->error(format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
return res;
}
json instruction = "";
if (body.count("instruction") == 1) {
instruction = body.at("instruction");
if (!instruction.is_string()) {
res->error(format_error_response("\"instruction\" must be a string", ERROR_TYPE_INVALID_REQUEST));
return res;
}
}
std::vector<std::string> documents = json_value(body, "documents",
json_value(body, "texts", std::vector<std::string>()));
@@ -3893,7 +3902,7 @@ void server_routes::init_routes() {
std::vector<server_task> tasks;
tasks.reserve(documents.size());
for (size_t i = 0; i < documents.size(); i++) {
auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, instruction, query, documents[i]);
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
task.id = rd.get_new_id();
task.tokens = std::move(tmp);