model : qwen3vl reranker text support (#20332)
* model : fix qwen3vl reranker support

* Remove CLS_OUT

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
parent
10e5b148b0
commit
4d99d45084
|
|
@ -4390,15 +4390,31 @@ class Qwen3Model(Qwen2Model):
|
||||||
hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
|
hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
|
||||||
self.origin_hf_arch = hparams.get('architectures', [None])[0]
|
self.origin_hf_arch = hparams.get('architectures', [None])[0]
|
||||||
|
|
||||||
# a bit hacky, but currently the only way to detect if this is a rerank model
|
if self._is_qwen3_reranker():
|
||||||
# ref: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
|
self._find_rerank_config()
|
||||||
|
|
||||||
|
def _is_qwen3_reranker(self) -> bool:
|
||||||
readme_path = self.dir_model / "README.md"
|
readme_path = self.dir_model / "README.md"
|
||||||
readme_text = ""
|
readme_text = ""
|
||||||
if readme_path.exists():
|
if readme_path.exists():
|
||||||
with readme_path.open("r", encoding="utf-8") as f:
|
with readme_path.open("r", encoding="utf-8") as f:
|
||||||
readme_text = f.read()
|
readme_text = f.read()
|
||||||
if "# Qwen3-Reranker" in readme_text:
|
|
||||||
self._find_rerank_config()
|
name_hints = [
|
||||||
|
str(self.dir_model.name),
|
||||||
|
str(self.hparams.get("_name_or_path", "")),
|
||||||
|
str(self.hparams.get("model_type", "")),
|
||||||
|
str(self.origin_hf_arch or ""),
|
||||||
|
]
|
||||||
|
name_hints = [hint.lower() for hint in name_hints if hint]
|
||||||
|
|
||||||
|
if "# qwen3-reranker" in readme_text.lower() or "# qwen3-vl-reranker" in readme_text.lower():
|
||||||
|
return True
|
||||||
|
|
||||||
|
if any("qwen3-reranker" in hint or "qwen3-vl-reranker" in hint for hint in name_hints):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return "sequenceclassification" in (self.origin_hf_arch or "").lower()
|
||||||
|
|
||||||
def set_vocab(self):
|
def set_vocab(self):
|
||||||
# deal with intern-s1-mini
|
# deal with intern-s1-mini
|
||||||
|
|
|
||||||
|
|
@ -1087,6 +1087,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
||||||
LLM_TENSOR_TOKEN_EMBD,
|
LLM_TENSOR_TOKEN_EMBD,
|
||||||
LLM_TENSOR_OUTPUT_NORM,
|
LLM_TENSOR_OUTPUT_NORM,
|
||||||
LLM_TENSOR_OUTPUT,
|
LLM_TENSOR_OUTPUT,
|
||||||
|
LLM_TENSOR_CLS_OUT,
|
||||||
LLM_TENSOR_ATTN_NORM,
|
LLM_TENSOR_ATTN_NORM,
|
||||||
LLM_TENSOR_ATTN_Q,
|
LLM_TENSOR_ATTN_Q,
|
||||||
LLM_TENSOR_ATTN_Q_NORM,
|
LLM_TENSOR_ATTN_Q_NORM,
|
||||||
|
|
|
||||||
|
|
@ -250,7 +250,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
||||||
|
|
||||||
const bool last = (
|
const bool last = (
|
||||||
cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
|
cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
|
||||||
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
|
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL)) // qwen3 reranking & embedding models use last token
|
||||||
);
|
);
|
||||||
|
|
||||||
for (int i = 0; i < n_tokens; ++i) {
|
for (int i = 0; i < n_tokens; ++i) {
|
||||||
|
|
@ -2552,7 +2552,7 @@ void llm_graph_context::build_pooling(
|
||||||
}
|
}
|
||||||
|
|
||||||
// softmax for qwen3 reranker
|
// softmax for qwen3 reranker
|
||||||
if (arch == LLM_ARCH_QWEN3) {
|
if (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL) {
|
||||||
cur = ggml_soft_max(ctx0, cur);
|
cur = ggml_soft_max(ctx0, cur);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue