model : qwen3vl reranker text support (#20332)
* model : fix qwen3vl reranker support
* Remove CLS_OUT

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
parent
10e5b148b0
commit
4d99d45084
|
|
@ -4390,15 +4390,31 @@ class Qwen3Model(Qwen2Model):
|
|||
hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
|
||||
self.origin_hf_arch = hparams.get('architectures', [None])[0]
|
||||
|
||||
# a bit hacky, but currently the only way to detect if this is a rerank model
|
||||
# ref: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
|
||||
if self._is_qwen3_reranker():
|
||||
self._find_rerank_config()
|
||||
|
||||
def _is_qwen3_reranker(self) -> bool:
|
||||
readme_path = self.dir_model / "README.md"
|
||||
readme_text = ""
|
||||
if readme_path.exists():
|
||||
with readme_path.open("r", encoding="utf-8") as f:
|
||||
readme_text = f.read()
|
||||
if "# Qwen3-Reranker" in readme_text:
|
||||
self._find_rerank_config()
|
||||
|
||||
name_hints = [
|
||||
str(self.dir_model.name),
|
||||
str(self.hparams.get("_name_or_path", "")),
|
||||
str(self.hparams.get("model_type", "")),
|
||||
str(self.origin_hf_arch or ""),
|
||||
]
|
||||
name_hints = [hint.lower() for hint in name_hints if hint]
|
||||
|
||||
if "# qwen3-reranker" in readme_text.lower() or "# qwen3-vl-reranker" in readme_text.lower():
|
||||
return True
|
||||
|
||||
if any("qwen3-reranker" in hint or "qwen3-vl-reranker" in hint for hint in name_hints):
|
||||
return True
|
||||
|
||||
return "sequenceclassification" in (self.origin_hf_arch or "").lower()
|
||||
|
||||
def set_vocab(self):
|
||||
# deal with intern-s1-mini
|
||||
|
|
|
|||
|
|
@ -1087,6 +1087,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
|||
LLM_TENSOR_TOKEN_EMBD,
|
||||
LLM_TENSOR_OUTPUT_NORM,
|
||||
LLM_TENSOR_OUTPUT,
|
||||
LLM_TENSOR_CLS_OUT,
|
||||
LLM_TENSOR_ATTN_NORM,
|
||||
LLM_TENSOR_ATTN_Q,
|
||||
LLM_TENSOR_ATTN_Q_NORM,
|
||||
|
|
|
|||
|
|
@ -250,7 +250,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
|||
|
||||
const bool last = (
|
||||
cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
|
||||
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
|
||||
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL)) // qwen3 reranking & embedding models use last token
|
||||
);
|
||||
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
|
|
@ -2552,7 +2552,7 @@ void llm_graph_context::build_pooling(
|
|||
}
|
||||
|
||||
// softmax for qwen3 reranker
|
||||
if (arch == LLM_ARCH_QWEN3) {
|
||||
if (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL) {
|
||||
cur = ggml_soft_max(ctx0, cur);
|
||||
}
|
||||
} break;
|
||||
|
|
|
|||
Loading…
Reference in New Issue