llama : add support for qwen3 reranker (#15824)

Author: Douglas Hanley (committed by GitHub), 2025-09-25 03:53:09 -05:00
Commit: b5bd037832 (parent: dfcd53f7ec)
9 changed files with 166 additions and 78 deletions


@@ -961,15 +961,13 @@ struct common_init_result common_init_from_params(common_params & params) {
             bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
             bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
+            bool has_rerank_prompt = llama_model_chat_template(model, "rerank") != NULL;

-            if (!has_eos && !has_sep) {
-                LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
+            if (!has_eos && !has_sep && !has_rerank_prompt) {
+                LOG_WRN("%s: warning: vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n", __func__);
                 ok = false;
             } else if (!has_eos) {
                 LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
-            } else if (!has_sep) {
-                LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
-                ok = false;
             }

             if (!ok) {


@@ -3717,11 +3717,29 @@ class Qwen2MoeModel(TextModel):
 class Qwen3Model(Qwen2Model):
     model_arch = gguf.MODEL_ARCH.QWEN3

+    # extra logic for rerank models
+    is_rerank: bool = False
+    is_tied_embeddings: bool = False
+    token_false_id: int | None = None
+    token_true_id: int | None = None
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         # track for intern-s1-mini
         hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
         self.origin_hf_arch = hparams.get('architectures', [None])[0]

+        # a bit hacky, but currently the only way to detect if this is a rerank model
+        # ref: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
+        readme_path = self.dir_model / "README.md"
+        readme_text = ""
+        if readme_path.exists():
+            with readme_path.open("r", encoding="utf-8") as f:
+                readme_text = f.read()
+        if "# Qwen3-Reranker" in readme_text:
+            self._find_rerank_config()
+
     def set_vocab(self):
         # deal with intern-s1-mini
         if self.origin_hf_arch == 'InternS1ForConditionalGeneration':
@@ -3730,6 +3748,53 @@ class Qwen3Model(Qwen2Model):
         super().set_vocab()

+    def _find_rerank_config(self):
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+
+        self.is_rerank = True
+        self.is_tied_embeddings = self.hparams.get("tie_word_embeddings", False)
+        self.token_false_id = tokenizer.convert_tokens_to_ids("no")
+        self.token_true_id = tokenizer.convert_tokens_to_ids("yes")
+        self.sep_token_id = tokenizer.convert_tokens_to_ids("|")
+
+        assert self.token_false_id is not None and self.token_true_id is not None
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if self.is_rerank:
+            self.gguf_writer.add_pooling_type(gguf.PoolingType.RANK)
+            self.gguf_writer.add_classifier_output_labels(["yes", "no"])
+            self.gguf_writer.add_chat_template([{
+                "name": "rerank",
+                "template": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n"
+                            "<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {query}\n<Document>: {document}<|im_end|>\n"
+                            "<|im_start|>assistant\n<think>\n\n</think>\n\n"
+            }])
+
+    def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor:
+        # extract "yes" and "no" tokens from the output lm_head tensor
+        false_row = data_torch[self.token_false_id]
+        true_row = data_torch[self.token_true_id]
+        return torch.stack([true_row, false_row], dim=0)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if self.is_rerank:
+            is_tied_head = self.is_tied_embeddings and "embed_tokens" in name
+            is_real_head = not self.is_tied_embeddings and "lm_head" in name
+            if is_tied_head or is_real_head:
+                cls_out_head = (
+                    gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.CLS_OUT] + ".weight",
+                    self._get_cls_out_tensor(data_torch),
+                )
+                if is_tied_head:
+                    embed = (self.map_tensor_name(name), data_torch)
+                    return [cls_out_head, embed]
+                if is_real_head:
+                    return [cls_out_head]
+
+        return super().modify_tensors(data_torch, name, bid)
+

 @ModelBase.register("Qwen3MoeForCausalLM")
 class Qwen3MoeModel(Qwen2MoeModel):
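
For reference, a minimal PyTorch sketch of what the converter does with the output head (toy shapes and token ids, not the real model's): the cls.output tensor is just the "yes" and "no" rows of lm_head, stacked so that row 0 scores "yes" and row 1 scores "no", matching the classifier output labels written above.

import torch

# toy stand-ins for the real shapes: an (n_vocab, n_embd) lm_head weight
n_vocab, n_embd = 8, 4
lm_head = torch.randn(n_vocab, n_embd)

# illustrative token ids; in the converter they come from the HF tokenizer
token_true_id  = 5  # "yes"
token_false_id = 2  # "no"

# mirrors Qwen3Model._get_cls_out_tensor: "yes" row first, then "no",
# matching the classifier output labels ["yes", "no"] written to the GGUF
cls_out = torch.stack([lm_head[token_true_id], lm_head[token_false_id]], dim=0)
print(cls_out.shape)  # torch.Size([2, 4]) -> stored as cls.output.weight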


@@ -95,8 +95,13 @@ int main(int argc, char ** argv) {
         params.n_batch = params.n_ctx;
     }

-    // For non-causal models, batch size must be equal to ubatch size
+    // for non-causal models, batch size must be equal to ubatch size
+    if (params.attention_type != LLAMA_ATTENTION_TYPE_CAUSAL) {
         params.n_ubatch = params.n_batch;
+    }
+
+    // get max number of sequences per batch
+    const int n_seq_max = llama_max_parallel_sequences();

     llama_backend_init();
     llama_numa_init(params.numa);
@@ -144,6 +149,7 @@ int main(int argc, char ** argv) {
     // get added sep and eos token, if any
     const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";
     const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
+    const char * rerank_prompt = llama_model_chat_template(model, "rerank");

     // tokenize the prompts and trim
     std::vector<std::vector<int32_t>> inputs;
@@ -153,8 +159,15 @@ int main(int argc, char ** argv) {
         // split classification pairs and insert expected separator tokens
         if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) {
             std::vector<std::string> pairs = split_lines(prompt, params.cls_sep);
+            if (rerank_prompt != nullptr) {
+                const std::string query = pairs[0];
+                const std::string doc   = pairs[1];
+                std::string final_prompt = rerank_prompt;
+                string_replace_all(final_prompt, "{query}"   , query);
+                string_replace_all(final_prompt, "{document}", doc  );
+                inp = common_tokenize(vocab, final_prompt, true, true);
+            } else {
                 std::string final_prompt;
                 for (size_t i = 0; i < pairs.size(); i++) {
                     final_prompt += pairs[i];
                     if (i != pairs.size() - 1) {
@@ -166,8 +179,8 @@ int main(int argc, char ** argv) {
                         }
                     }
                 }
-            inp = common_tokenize(ctx, final_prompt, true, true);
+                inp = common_tokenize(ctx, final_prompt, true, true);
+            }
         } else {
             inp = common_tokenize(ctx, prompt, true, true);
         }
@@ -229,7 +242,7 @@ int main(int argc, char ** argv) {
             const uint64_t n_toks = inp.size();

             // encode if at capacity
-            if (batch.n_tokens + n_toks > n_batch) {
+            if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) {
                 float * out = emb + e * n_embd;
                 batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
                 e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
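
A small Python sketch of the prompt the embedding example now builds for Qwen3 rerankers, using the "rerank" template registered by the converter above; the tab separator here is only a stand-in for params.cls_sep:

# the template string is the one written to the GGUF by the converter
rerank_template = (
    "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query "
    "and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n"
    "<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n"
    "<Query>: {query}\n<Document>: {document}<|im_end|>\n"
    "<|im_start|>assistant\n<think>\n\n</think>\n\n"
)

def build_rerank_prompt(pair: str, sep: str = "\t") -> str:
    # same splitting as split_lines(prompt, params.cls_sep), then the two
    # string_replace_all calls filling the placeholders verbatim
    query, document = pair.split(sep, 1)
    return rerank_template.replace("{query}", query).replace("{document}", document)

print(build_rerank_prompt("what is panda?\tThe giant panda is a bear species endemic to China."))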


@@ -721,6 +721,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
         { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
         { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
         { LLM_TENSOR_OUTPUT,          "output" },
+        { LLM_TENSOR_CLS_OUT,         "cls.output" },
        { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
         { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
         { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },


@@ -204,7 +204,10 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
         std::vector<int> target_pos(n_seqs_unq, -1);
         std::vector<int> target_row(n_seqs_unq, -1);

-        bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST;
+        const bool last = (
+            cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
+            (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
+        );

         for (int i = 0; i < n_tokens; ++i) {
             const llama_pos pos = ubatch->pos[i];
@@ -1177,7 +1180,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
 }

 ggml_tensor * llm_graph_context::build_inp_cls() const {
-    auto inp = std::make_unique<llm_graph_input_cls>(cparams);
+    auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);

     auto & cur = inp->cls;
@@ -1877,34 +1880,32 @@ void llm_graph_context::build_pooling(
         case LLAMA_POOLING_TYPE_RANK:
             {
                 ggml_tensor * inp_cls = build_inp_cls();
-                inp = ggml_get_rows(ctx0, inp, inp_cls);
+                cur = ggml_get_rows(ctx0, inp, inp_cls);

+                // classification head
+                // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
                 if (cls) {
-                    // classification head
-                    // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
-                    cur = ggml_mul_mat(ctx0, cls, inp);
+                    cur = ggml_mul_mat(ctx0, cls, cur);
                     if (cls_b) {
                         cur = ggml_add(ctx0, cur, cls_b);
                     }
                     cur = ggml_tanh(ctx0, cur);
+                }

-                    // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
-                    // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
-                    if (cls_out) {
-                        cur = ggml_mul_mat(ctx0, cls_out, cur);
-                        if (cls_out_b) {
-                            cur = ggml_add(ctx0, cur, cls_out_b);
-                        }
-                    }
-                } else if (cls_out) {
-                    // Single layer classification head (direct projection)
-                    // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
-                    cur = ggml_mul_mat(ctx0, cls_out, inp);
-                    if (cls_out_b) {
-                        cur = ggml_add(ctx0, cur, cls_out_b);
-                    }
-                } else {
-                    GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
+                // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+                // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
+                // Single layer classification head (direct projection)
+                // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
+                if (cls_out) {
+                    cur = ggml_mul_mat(ctx0, cls_out, cur);
+                    if (cls_out_b) {
+                        cur = ggml_add(ctx0, cur, cls_out_b);
+                    }
+                }
+
+                // softmax for qwen3 reranker
+                if (arch == LLM_ARCH_QWEN3) {
+                    cur = ggml_soft_max(ctx0, cur);
                 }
             } break;
         default:
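
Putting the graph changes together, the RANK path for Qwen3 now takes the last token's hidden state, projects it through the 2-row cls.output head, and softmaxes the two logits. A small NumPy sketch (toy shapes, not ggml) of the resulting score, assuming row 0 of cls.output is the "yes" row written by the converter:

import numpy as np

def qwen3_rank_score(h_last: np.ndarray, cls_out: np.ndarray) -> float:
    # h_last  : hidden state of the last token, shape (n_embd,)
    # cls_out : the (2, n_embd) cls.output matrix, row 0 = "yes", row 1 = "no"
    logits = cls_out @ h_last            # ggml_mul_mat(cls_out, cur)
    probs  = np.exp(logits - logits.max())
    probs /= probs.sum()                 # ggml_soft_max(cur)
    return float(probs[0])               # P("yes"), treated here as the relevance score

rng = np.random.default_rng(0)
n_embd = 8
print(qwen3_rank_score(rng.normal(size=n_embd), rng.normal(size=(2, n_embd))))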


@@ -206,7 +206,7 @@ public:
 class llm_graph_input_cls : public llm_graph_input_i {
 public:
-    llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
+    llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {}
     virtual ~llm_graph_input_cls() = default;

     void set_input(const llama_ubatch * ubatch) override;
@@ -214,6 +214,7 @@ public:
     ggml_tensor * cls; // I32 [n_batch]

     const llama_cparams cparams;
+    const llm_arch arch;
 };

 class llm_graph_input_rs : public llm_graph_input_i {
class llm_graph_input_rs : public llm_graph_input_i { class llm_graph_input_rs : public llm_graph_input_i {


@@ -3167,6 +3167,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                 }

+                // output rerank head
+                cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
+
                 for (int i = 0; i < n_layer; ++i) {
                     auto & layer = layers[i];
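
Because the head is created with TENSOR_NOT_REQUIRED, a converted Qwen3 GGUF may or may not carry it. A quick check with the gguf Python package (the file name is only an example) shows whether a given file has the rerank head:

from gguf import GGUFReader

reader = GGUFReader("Qwen3-Reranker-0.6B-Q8_0.gguf")  # example path
head = [t for t in reader.tensors if t.name == "cls.output.weight"]
if head:
    # one dimension should be 2: the "yes"/"no" output rows
    print("rerank head found, shape:", head[0].shape)
else:
    print("no cls.output tensor - plain Qwen3, not a reranker")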


@@ -5093,21 +5093,15 @@ int main(int argc, char ** argv) {
             return;
         }

-        std::vector<server_tokens> tokenized_queries = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, query, /* add_special */ false, true);
-        if (tokenized_queries.size() != 1) {
-            res_error(res, format_error_response("\"query\" must contain only a single prompt", ERROR_TYPE_INVALID_REQUEST));
-        }
-
         // create and queue the task
         json responses = json::array();
         bool error = false;
         std::unordered_set<int> task_ids;
         {
             std::vector<server_task> tasks;
-            auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, documents, /* add_special */ false, true);
-            tasks.reserve(tokenized_docs.size());
-            for (size_t i = 0; i < tokenized_docs.size(); i++) {
-                auto tmp = format_rerank(ctx_server.vocab, tokenized_queries[0], tokenized_docs[i]);
+            tasks.reserve(documents.size());
+            for (size_t i = 0; i < documents.size(); i++) {
+                auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
                 server_task task = server_task(SERVER_TASK_TYPE_RERANK);
                 task.id = ctx_server.queue_tasks.get_new_id();
                 task.index = i;
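
With format_rerank now taking the raw query and document strings, a request against the server's rerank endpoint is unchanged from the caller's point of view. An example (host, port and documents are placeholders; the JSON fields follow the server's existing /v1/rerank API):

import requests

resp = requests.post(
    "http://localhost:8080/v1/rerank",
    json={
        "query": "what is panda?",
        "documents": [
            "The giant panda is a bear species endemic to China.",
            "Paris is the capital of France.",
        ],
    },
)
for r in sorted(resp.json()["results"], key=lambda r: r["relevance_score"], reverse=True):
    print(r["index"], r["relevance_score"])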


@@ -1368,34 +1368,6 @@ static std::string fnv_hash(const uint8_t * data, size_t len) {
     return std::to_string(hash);
 }

-// format rerank task: [BOS]query[EOS][SEP]doc[EOS].
-static server_tokens format_rerank(const struct llama_vocab * vocab, server_tokens & query, server_tokens & doc) {
-    server_tokens result = {};
-
-    // Get EOS token - use SEP token as fallback if EOS is not available
-    llama_token eos_token = llama_vocab_eos(vocab);
-    if (eos_token == LLAMA_TOKEN_NULL) {
-        eos_token = llama_vocab_sep(vocab);
-    }
-    if (llama_vocab_get_add_bos(vocab)) {
-        result.push_back(llama_vocab_bos(vocab));
-    }
-    result.push_back(query);
-    if (llama_vocab_get_add_eos(vocab)) {
-        result.push_back(eos_token);
-    }
-    if (llama_vocab_get_add_sep(vocab)) {
-        result.push_back(llama_vocab_sep(vocab));
-    }
-    result.push_back(doc);
-    if (llama_vocab_get_add_eos(vocab)) {
-        result.push_back(eos_token);
-    }
-    return result;
-}
-
 static server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files) {
     mtmd::bitmaps bitmaps;
     for (auto & file : files) {
@@ -1501,3 +1473,43 @@ static std::vector<server_tokens> tokenize_input_prompts(const llama_vocab * voc
     }
     return result;
 }
+
+// format rerank task: [BOS]query[EOS][SEP]doc[EOS].
+static server_tokens format_rerank(const struct llama_model * model, const struct llama_vocab * vocab, mtmd_context * mctx, const std::string & query, const std::string & doc) {
+    server_tokens result = {};
+
+    const char * rerank_prompt = llama_model_chat_template(model, "rerank");
+    if (rerank_prompt != nullptr) {
+        std::string prompt = rerank_prompt;
+        string_replace_all(prompt, "{query}"   , query);
+        string_replace_all(prompt, "{document}", doc  );
+        server_tokens tokens = tokenize_input_subprompt(vocab, mctx, prompt, false, true);
+        result.push_back(tokens);
+    } else {
+        // Get EOS token - use SEP token as fallback if EOS is not available
+        server_tokens query_tokens = tokenize_input_subprompt(vocab, mctx, query, false, false);
+        server_tokens doc_tokens   = tokenize_input_subprompt(vocab, mctx, doc,   false, false);
+        llama_token eos_token = llama_vocab_eos(vocab);
+        if (eos_token == LLAMA_TOKEN_NULL) {
+            eos_token = llama_vocab_sep(vocab);
+        }
+        if (llama_vocab_get_add_bos(vocab)) {
+            result.push_back(llama_vocab_bos(vocab));
+        }
+        result.push_back(query_tokens);
+        if (llama_vocab_get_add_eos(vocab)) {
+            result.push_back(eos_token);
+        }
+        if (llama_vocab_get_add_sep(vocab)) {
+            result.push_back(llama_vocab_sep(vocab));
+        }
+        result.push_back(doc_tokens);
+        if (llama_vocab_get_add_eos(vocab)) {
+            result.push_back(eos_token);
+        }
+    }
+    return result;
+}
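
For models that do not ship a "rerank" template, the function falls back to the token-level [BOS]query[EOS][SEP]doc[EOS] layout. A Python sketch of that fallback, with placeholder token ids and a toy tokenizer standing in for the llama.cpp vocab API:

from typing import Callable, Optional

def format_rerank_fallback(tokenize: Callable[[str], list[int]], query: str, doc: str,
                           bos_id: Optional[int], eos_id: Optional[int], sep_id: Optional[int],
                           add_bos: bool, add_eos: bool, add_sep: bool) -> list[int]:
    # use SEP as a stand-in for EOS when the vocab has no EOS token
    eos = eos_id if eos_id is not None else sep_id
    out: list[int] = []
    if add_bos and bos_id is not None:
        out.append(bos_id)
    out += tokenize(query)
    if add_eos and eos is not None:
        out.append(eos)
    if add_sep and sep_id is not None:
        out.append(sep_id)
    out += tokenize(doc)
    if add_eos and eos is not None:
        out.append(eos)
    return out

# toy usage with a whitespace "tokenizer"
print(format_rerank_fallback(lambda s: [len(w) for w in s.split()],
                             "what is panda?", "The giant panda is a bear species.",
                             bos_id=1, eos_id=2, sep_id=3,
                             add_bos=True, add_eos=True, add_sep=True))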