Merge b47e0cf21a into 9e2e2198b0
This commit is contained in:
commit
23f66aef0d
|
|
@ -4594,12 +4594,20 @@ class Qwen3Model(Qwen2Model):
|
|||
if self.is_rerank:
|
||||
self.gguf_writer.add_pooling_type(gguf.PoolingType.RANK)
|
||||
self.gguf_writer.add_classifier_output_labels(["yes", "no"])
|
||||
self.gguf_writer.add_chat_template([{
|
||||
"name": "rerank",
|
||||
"template": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n"
|
||||
"<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {query}\n<Document>: {document}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
}])
|
||||
self.gguf_writer.add_chat_template([
|
||||
{
|
||||
"name": "rerank",
|
||||
"template": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n"
|
||||
"<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {query}\n<Document>: {document}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n<think>\n\n</think>\n\n",
|
||||
},
|
||||
{
|
||||
"name": "rerank_instruct",
|
||||
"template": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n"
|
||||
"<|im_start|>user\n<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n<think>\n\n</think>\n\n",
|
||||
},
|
||||
])
|
||||
|
||||
def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor:
|
||||
# extract "yes" and "no" tokens from the output lm_head tensor
|
||||
|
|
|
|||
|
|
@ -2021,13 +2021,22 @@ server_tokens format_prompt_rerank(
|
|||
const struct llama_model * model,
|
||||
const struct llama_vocab * vocab,
|
||||
mtmd_context * mctx,
|
||||
const std::string & instruction,
|
||||
const std::string & query,
|
||||
const std::string & doc) {
|
||||
server_tokens result = {};
|
||||
|
||||
const char * rerank_prompt = llama_model_chat_template(model, "rerank");
|
||||
const char * rerank_prompt_instruct = llama_model_chat_template(model, "rerank_instruct");
|
||||
|
||||
if (rerank_prompt != nullptr) {
|
||||
if ( (rerank_prompt_instruct != nullptr) && !instruction.empty() ) {
|
||||
std::string prompt = rerank_prompt_instruct;
|
||||
string_replace_all(prompt, "{instruction}", instruction);
|
||||
string_replace_all(prompt, "{query}" , query);
|
||||
string_replace_all(prompt, "{document}" , doc );
|
||||
server_tokens tokens = tokenize_input_subprompt(vocab, mctx, prompt, false, true);
|
||||
result.push_back(tokens);
|
||||
} else if (rerank_prompt != nullptr) {
|
||||
std::string prompt = rerank_prompt;
|
||||
string_replace_all(prompt, "{query}" , query);
|
||||
string_replace_all(prompt, "{document}", doc );
|
||||
|
|
|
|||
|
|
@ -369,5 +369,6 @@ server_tokens format_prompt_rerank(
|
|||
const struct llama_model * model,
|
||||
const struct llama_vocab * vocab,
|
||||
mtmd_context * mctx,
|
||||
const std::string & instruction,
|
||||
const std::string & query,
|
||||
const std::string & doc);
|
||||
|
|
|
|||
|
|
@ -3876,6 +3876,15 @@ void server_routes::init_routes() {
|
|||
res->error(format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
return res;
|
||||
}
|
||||
|
||||
json instruction = "";
|
||||
if (body.count("instruction") == 1) {
|
||||
instruction = body.at("instruction");
|
||||
if (!instruction.is_string()) {
|
||||
res->error(format_error_response("\"instruction\" must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> documents = json_value(body, "documents",
|
||||
json_value(body, "texts", std::vector<std::string>()));
|
||||
|
|
@ -3893,7 +3902,7 @@ void server_routes::init_routes() {
|
|||
std::vector<server_task> tasks;
|
||||
tasks.reserve(documents.size());
|
||||
for (size_t i = 0; i < documents.size(); i++) {
|
||||
auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
|
||||
auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, instruction, query, documents[i]);
|
||||
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
|
||||
task.id = rd.get_new_id();
|
||||
task.tokens = std::move(tmp);
|
||||
|
|
|
|||
Loading…
Reference in New Issue