Merge branch 'master' into xsn/convert_gguf_qwen2vl

Xuan Son Nguyen 2025-04-30 17:14:21 +02:00
commit f48f51d185
5 changed files with 76 additions and 45 deletions

View File

@@ -1948,6 +1948,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.grammar = json_schema_to_grammar(json::parse(value));
         }
     ).set_sparam());
+    add_opt(common_arg(
+        {"-jf", "--json-schema-file"}, "FILE",
+        "File containing a JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead",
+        [](common_params & params, const std::string & value) {
+            std::ifstream file(value);
+            if (!file) {
+                throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
+            }
+            std::string schema;
+            std::copy(
+                std::istreambuf_iterator<char>(file),
+                std::istreambuf_iterator<char>(),
+                std::back_inserter(schema)
+            );
+            params.sampling.grammar = json_schema_to_grammar(json::parse(schema));
+        }
+    ).set_sparam());
     add_opt(common_arg(
         {"--pooling"}, "{none,mean,cls,last,rank}",
         "pooling type for embeddings, use model default if unspecified",

View File

@@ -16,6 +16,7 @@ from pathlib import Path
 from hashlib import sha256
 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
 from itertools import chain
+from transformers import AutoConfig
 
 import math
 import numpy as np
@@ -66,8 +67,6 @@ class ModelBase:
     part_names: list[str]
     is_safetensors: bool
     hparams: dict[str, Any]
-    block_count: int
-    tensor_map: gguf.TensorNameMap
     tensor_names: set[str] | None
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
@@ -78,6 +77,10 @@ class ModelBase:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
+    # subclasses should initialize this!
+    block_count: int
+    tensor_map: gguf.TensorNameMap
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
@@ -113,8 +116,6 @@ class ModelBase:
         if not self.is_safetensors:
             self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
-        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
-        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
         self.metadata_override = metadata_override
         self.model_name = model_name
@@ -417,15 +418,13 @@ class ModelBase:
     @staticmethod
     def load_hparams(dir_model: Path):
-        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
-            hparams = json.load(f)
-        architectures = hparams.get("architectures")
-        if "text_config" in hparams:
-            hparams = {**hparams, **hparams["text_config"]}
-        if architectures is not None:
-            # preserve "architectures" from root level config
-            hparams["architectures"] = architectures
-        return hparams
+        try:
+            return AutoConfig.from_pretrained(dir_model).to_dict()
+        except Exception as e:
+            logger.warning(f"Failed to load model config from {dir_model}: {e}")
+            logger.warning("Trying to load config.json instead")
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                return json.load(f)
 
     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
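
Note: AutoConfig.from_pretrained() builds a full config object, so fields that are absent from config.json come back filled with the transformers defaults; that is what allows the hard-coded fallbacks to be deleted further down, while the config.json path is kept for architectures transformers does not know. A standalone sketch of the same pattern, with a module-level logger in place of the script's own:

import json
import logging
from pathlib import Path

from transformers import AutoConfig

logger = logging.getLogger(__name__)

def load_hparams(dir_model: Path) -> dict:
    # prefer AutoConfig (fills in architecture defaults), fall back to raw config.json
    try:
        return AutoConfig.from_pretrained(dir_model).to_dict()
    except Exception as e:
        logger.warning("Failed to load model config from %s: %s", dir_model, e)
        logger.warning("Trying to load config.json instead")
        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
            return json.load(f)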
@ -454,6 +453,23 @@ class ModelBase:
class TextModel(ModelBase): class TextModel(ModelBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if "text_config" in self.hparams:
# move the text_config to the root level
self.hparams = {**self.hparams, **self.hparams["text_config"]}
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
@classmethod
def __init_subclass__(cls):
# can't use an abstract property, because overriding it without type errors
# would require using decorated functions instead of simply defining the property
if "model_arch" not in cls.__dict__:
raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")
def set_vocab(self): def set_vocab(self):
self._set_vocab_gpt2() self._set_vocab_gpt2()
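
Note: the cls.__dict__ check in __init_subclass__ forces every concrete text-model class to declare its own model_arch instead of silently inheriting one, and it fails at class-definition time rather than mid-conversion. A standalone reproduction of the guard (not llama.cpp code, just the pattern):

class Base:
    @classmethod
    def __init_subclass__(cls):
        # looking at cls.__dict__ (not getattr) rejects inherited attributes
        if "model_arch" not in cls.__dict__:
            raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")

class Good(Base):
    model_arch = "llama"   # defined directly on the subclass: accepted

try:
    class Bad(Good):       # inherits model_arch but does not define its own
        pass
except TypeError as e:
    print(e)               # Missing property 'model_arch' for 'Bad'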
@@ -1070,9 +1086,9 @@ class VisionModel(ModelBase):
         if self.model_arch != gguf.MODEL_ARCH.CLIP_VISION:
             raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")
 
-        # small hack to correct the number of layers
-        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 128)
-        self.n_embd_text = self.find_hparam(["hidden_size", "n_embd"])
+        # get n_embd of the text model
+        text_config = {**self.hparams, **self.hparams["text_config"]}
+        self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
         assert self.n_embd_text > 0, "n_embd not found in hparams"
 
         if "vision_config" not in self.hparams:
@@ -1081,6 +1097,9 @@ class VisionModel(ModelBase):
         self.global_config = self.hparams
         self.hparams = self.hparams["vision_config"]
 
+        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
+        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count)
+
         # load preprocessor config
         with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
             self.preprocessor_config = json.load(f)
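
Note: the vision block count now goes through the same candidate-key lookup as text models, with "depth" added because some vision configs (e.g. Qwen2-VL's vision tower) store the layer count under that key instead of num_hidden_layers. A minimal sketch of the lookup, assuming the real ModelBase.find_hparam behaves like this but reads self.hparams:

from typing import Any, Iterable

def find_hparam(hparams: dict[str, Any], keys: Iterable[str], optional: bool = False) -> Any:
    # return the value of the first candidate key present in the config
    for key in keys:
        if key in hparams:
            return hparams[key]
    if optional:
        return None
    raise KeyError(f"could not find any of: {list(keys)}")

# illustrative vision_config: the layer count only appears as "depth"
vision_config = {"depth": 32, "hidden_size": 1280}
block_count = find_hparam(vision_config, ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
assert block_count == 32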
@@ -1098,7 +1117,7 @@ class VisionModel(ModelBase):
         self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
         self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
         self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
-        self.gguf_writer.add_vision_block_count(self.find_hparam(["num_hidden_layers"]))
+        self.gguf_writer.add_vision_block_count(self.block_count)
         self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
 
         # preprocessor config
@@ -1719,23 +1738,12 @@ class StableLMModel(TextModel):
     "LlamaForCausalLM",
     "MistralForCausalLM",
     "MixtralForCausalLM",
-    "Idefics3ForConditionalGeneration",
-    "SmolVLMForConditionalGeneration",
+    "VLlama3ForCausalLM",
     "LlavaForConditionalGeneration")
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # fix for SmolVLM2, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
-        # fix for Pixtral, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "LlavaForConditionalGeneration" \
-                and self.hparams.get("model_type") == "mistral":
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
-
     def set_vocab(self):
         try:
             self._set_vocab_sentencepiece()
@@ -1898,11 +1906,7 @@ class LlavaVisionModel(VisionModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if self.hparams["model_type"] == "pixtral":
-            # fix missing config.json values
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
-            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 24)
-            self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 4096)
-            self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1024)
+            # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
             self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
             self.img_break_tok_id = 12 # see tokenizer_config.json
         else:
@@ -1913,7 +1917,6 @@ class LlavaVisionModel(VisionModel):
         hparams = self.hparams
         if hparams["model_type"] == "pixtral":
             self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
-            # default values below are taken from HF tranformers code
             self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
             self.gguf_writer.add_vision_use_silu(True)
@@ -1944,13 +1947,12 @@ class LlavaVisionModel(VisionModel):
 class SmolVLMModel(VisionModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # fix for SmolVLM2, missing some keys in config.json
-        # default values are taken from transformers code
         if self.hparams["model_type"] == "smolvlm_vision":
+            # fix for SmolVLM2, missing some keys in config.json
+            # default values are taken from transformers code
             self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1152)
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
             self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 3072)
-            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 12)
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -3581,6 +3583,8 @@ class RobertaModel(BertModel):
 @ModelBase.register("NomicBertModel")
 class NomicBertModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.BERT
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
         hparams = kwargs.pop("hparams", None)
         if hparams is None:
@@ -5925,6 +5929,19 @@ def split_str_to_n_bytes(split_str: str) -> int:
     return n
 
 
+def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any = None) -> str:
+    hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
+    text_config = hparams.get("text_config", {})
+    vision_config = hparams.get("vision_config", {})
+    arch = hparams["architectures"][0]
+    # if "architectures" is found in the sub-config, use that instead
+    if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
+        arch = text_config["architectures"][0]
+    elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
+        arch = vision_config["architectures"][0]
+    return arch
+
+
 def main() -> None:
     args = parse_args()
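
Note: get_model_architecture() is what lets a single multimodal checkpoint resolve to different converter classes depending on --mmproj: the text path prefers text_config["architectures"], the vision path prefers vision_config["architectures"], and both fall back to the root-level entry. A self-contained sketch with an invented SmolVLM2-style config (field values are illustrative, not copied from a real checkpoint):

from enum import Enum

class ModelType(Enum):
    TEXT = "text"
    VISION = "vision"

def resolve_arch(hparams: dict, model_type: ModelType) -> str:
    # same preference order as get_model_architecture(), minus the hparams loading
    sub_key = "text_config" if model_type == ModelType.TEXT else "vision_config"
    sub_config = hparams.get(sub_key, {})
    if sub_config.get("architectures") is not None:
        return sub_config["architectures"][0]
    return hparams["architectures"][0]

config = {
    "architectures": ["SmolVLMForConditionalGeneration"],       # root-level arch
    "text_config": {"architectures": ["VLlama3ForCausalLM"]},   # text sub-config arch
    "vision_config": {"model_type": "smolvlm_vision"},          # no arch here
}
assert resolve_arch(config, ModelType.TEXT) == "VLlama3ForCausalLM"
assert resolve_arch(config, ModelType.VISION) == "SmolVLMForConditionalGeneration"

This is also why the LlamaModel registration above swaps the Idefics3/SmolVLM names for "VLlama3ForCausalLM": once the text sub-config wins, that is the architecture name the converter sees.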
@@ -5977,16 +5994,15 @@ def main() -> None:
     logger.info(f"Loading model: {dir_model.name}")
 
-    hparams = ModelBase.load_hparams(dir_model)
-
     if args.mmproj:
         if "mmproj" not in fname_out.name:
             fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")
 
     with torch.inference_mode():
         output_type = ftype_map[args.outtype]
-        model_architecture = hparams["architectures"][0]
         model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
+        model_architecture = get_model_architecture(dir_model, model_type)
+        logger.info(f"Model architecture: {model_architecture}")
         try:
             model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
         except NotImplementedError:

View File

@@ -2,8 +2,6 @@
 #include "gguf.h"
 #include "clip.h"
 
-#include "clip.h"
-
 #include <climits>
 #include <cstdarg>
 #include <string>

View File

@@ -341,7 +341,7 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
 #define GGML_F32_EPR  4
 
 #define GGML_F32x4              vector float
-#define GGML_F32x4_ZERO         0.0f
+#define GGML_F32x4_ZERO         {0.0f}
 #define GGML_F32x4_SET1         vec_splats
 #define GGML_F32x4_LOAD(p)      vec_xl(0, p)
 #define GGML_F32x4_STORE(p, r)  vec_xst(r, 0, p)

View File

@@ -482,7 +482,7 @@ float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCo
     const uint ib8 = (idx & 0x18) >> 3; // 0..3
     const uint iqs = 8 * ib32 + ib8;
-    const uint8_t qs = bl.block.qs[iqs];
+    const uint qs = bl.block.qs[iqs];
     const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3]));
     const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28));