model: add EmbeddingGemma support for SentenceTransformers Dense modules (#16367)
* model: EmbeddingGemma sentence-transformers dense linear projections support

* model: add support for EmbeddingGemma SentenceTransformers dense linear projections

  Adds support for the Dense modules used in EmbeddingGemma models.
  EmbeddingGemma is a SentenceTransformers model with additional modules
  beyond the base Transformer backbone. See:
  https://developers.googleblog.com/en/gemma-explained-embeddinggemma-architecture-and-recipe/

* model: add support for EmbeddingGemma SentenceTransformers dense linear projections

  - converting a model with dense layers is optional
  - introduced dense config params

* Update convert_hf_to_gguf.py

  Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com>

* fix formatting issues

* Update src/llama-graph.cpp

  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* remove pooling_type_opt, always allow overriding pooling_type; add asserts checking dense feature dims

* fix python lint

* fix ubuntu gcc build warning

* fix thread-safety test; move asserts to load_hparams

* tidy up code; simplify graph context to expect both dense weights

* minor: add TODO

---------

Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Parent: 12bbc3fa50
Commit: e08db42595
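For readers trying this out: conversion of the Dense modules is opt-in via the new `--sentence-transformers-dense-modules` flag added below. A typical invocation might look like `python convert_hf_to_gguf.py google/embeddinggemma-300m --remote --sentence-transformers-dense-modules --outfile embeddinggemma-300m-f16.gguf` (output name illustrative). The converter discovers the Dense modules by reading the checkpoint's `modules.json`; a sketch of the module list for an EmbeddingGemma-style checkpoint is shown below. Only the `path` and `type` fields are consumed by this commit; the `idx`/`name` values are assumptions.

```python
# Sketch of a SentenceTransformers modules.json for an EmbeddingGemma-style
# checkpoint (idx/name values are assumptions; only "path" and "type" are read).
# The converter keeps the Dense entries whose directory ships a model.safetensors.
modules = [
    {"idx": 0, "name": "0", "path": "",            "type": "sentence_transformers.models.Transformer"},
    {"idx": 1, "name": "1", "path": "1_Pooling",   "type": "sentence_transformers.models.Pooling"},
    {"idx": 2, "name": "2", "path": "2_Dense",     "type": "sentence_transformers.models.Dense"},
    {"idx": 3, "name": "3", "path": "3_Dense",     "type": "sentence_transformers.models.Dense"},
    {"idx": 4, "name": "4", "path": "4_Normalize", "type": "sentence_transformers.models.Normalize"},
]
```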
@@ -93,13 +93,15 @@ class ModelBase:
     # Mistral format specifics
     is_mistral_format: bool = False
     disable_mistral_community_chat_template: bool = False
+    sentence_transformers_dense_modules: bool = False

     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
                  small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
-                 disable_mistral_community_chat_template: bool = False):
+                 disable_mistral_community_chat_template: bool = False,
+                 sentence_transformers_dense_modules: bool = False):
         if type(self) is ModelBase or \
                 type(self) is TextModel or \
                 type(self) is MmprojModel:
@@ -114,6 +116,7 @@ class ModelBase:
         self.lazy = not eager or (remote_hf_model_id is not None)
         self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
+        self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
         if remote_hf_model_id is not None:
             self.is_safetensors = True
@@ -5269,6 +5272,53 @@ class Gemma3Model(TextModel):
 @ModelBase.register("Gemma3TextModel")
 class EmbeddingGemma(Gemma3Model):
     model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING
+    module_paths = []
+    dense_features_dims = {}
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if self.sentence_transformers_dense_modules:
+            # read modules.json to determine if model has Dense layers
+            modules_file = self.dir_model / "modules.json"
+            if modules_file.is_file():
+                with open(modules_file, encoding="utf-8") as modules_json_file:
+                    mods = json.load(modules_json_file)
+                    for mod in mods:
+                        if mod["type"] == "sentence_transformers.models.Dense":
+                            mod_path = mod["path"]
+                            # check if model.safetensors file for Dense layer exists
+                            model_tensors_file = self.dir_model / mod_path / "model.safetensors"
+                            if model_tensors_file.is_file():
+                                self.module_paths.append(mod_path)
+                                # read config.json of the Dense layer to get in/out features
+                                mod_conf_file = self.dir_model / mod_path / "config.json"
+                                if mod_conf_file.is_file():
+                                    with open(mod_conf_file, encoding="utf-8") as mod_conf_json_file:
+                                        mod_conf = json.load(mod_conf_json_file)
+                                        # hparams dense_2_feat_out and dense_3_feat_in are required when loading model's dense weights
+                                        prefix = self._get_dense_prefix(mod_path)
+                                        if mod_conf["in_features"] is not None and mod_conf["out_features"] is not None:
+                                            self.dense_features_dims[prefix] = (mod_conf["in_features"], mod_conf["out_features"])
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        from safetensors.torch import load_file
+        module_paths = list(self.module_paths)
+        for i, module_path in enumerate(module_paths):
+            tensors_file = self.dir_model / module_path / "model.safetensors"
+            local_tensors = load_file(tensors_file)
+            tensor_name = self._get_dense_prefix(module_path)
+            for name, local_tensor in local_tensors.items():
+                if not name.endswith(".weight"):
+                    continue
+                orig_name = name.replace("linear", tensor_name)
+                name = self.map_tensor_name(orig_name)
+                yield name, local_tensor.clone()
+
+    @staticmethod
+    def _get_dense_prefix(module_path) -> str:
+        """Get the tensor name prefix for the Dense layer from module path."""
+        tensor_name = "dense_2" if module_path == "2_Dense" else "dense_3"
+        return tensor_name
+
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
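A note on the conversion above: each SentenceTransformers Dense module stores its projection in `<path>/model.safetensors` under a `linear.*` key, and `generate_extra_tensors` rewrites that prefix to `dense_2`/`dense_3` before running the regular tensor-name mapping. A minimal sketch of the rename, assuming the usual single `linear.weight` tensor per module (non-`.weight` tensors, e.g. a bias, are skipped by the filter above):

```python
# Minimal sketch of the rename done in generate_extra_tensors (assumes the
# standard SentenceTransformers Dense layout: one "linear.weight" per module).
for module_path in ("2_Dense", "3_Dense"):
    prefix = "dense_2" if module_path == "2_Dense" else "dense_3"  # _get_dense_prefix
    renamed = "linear.weight".replace("linear", prefix)
    # map_tensor_name() then resolves this to the final GGUF tensor name
    print(f"{module_path}/model.safetensors: linear.weight -> {renamed}")
```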
@@ -5285,6 +5335,10 @@ class EmbeddingGemma(Gemma3Model):
             logger.info(f"Using original sliding_window from config: {orig_sliding_window} "
                         f"instead of {self.hparams['sliding_window']}")
             self.gguf_writer.add_sliding_window(orig_sliding_window)
+        if self.sentence_transformers_dense_modules:
+            for dense, dims in self.dense_features_dims.items():
+                logger.info(f"Setting dense layer {dense} in/out features to {dims}")
+                self.gguf_writer.add_dense_features_dims(dense, dims[0], dims[1])

         self._try_set_pooling_type()
@@ -9335,6 +9389,13 @@ def parse_args() -> argparse.Namespace:
         )
     )

+    parser.add_argument(
+        "--sentence-transformers-dense-modules", action="store_true",
+        help=("Whether to include sentence-transformers dense modules. "
+              "It can be used for sentence-transformers models, like google/embeddinggemma-300m. "
+              "By default these modules are not included.")
+    )
+
     args = parser.parse_args()
     if not args.print_supported_models and args.model is None:
         parser.error("the following arguments are required: model")
@@ -9397,9 +9458,13 @@ def main() -> None:
     if args.remote:
         hf_repo_id = args.model
         from huggingface_hub import snapshot_download
+        allowed_patterns = ["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]
+        if args.sentence_transformers_dense_modules:
+            # include sentence-transformers dense modules safetensors files
+            allowed_patterns.append("*.safetensors")
         local_dir = snapshot_download(
             repo_id=hf_repo_id,
-            allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
+            allow_patterns=allowed_patterns)
         dir_model = Path(local_dir)
         logger.info(f"Downloaded config and tokenizer to {local_dir}")
     else:
@@ -9467,7 +9532,8 @@ def main() -> None:
             split_max_tensors=args.split_max_tensors,
             split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
             small_first_shard=args.no_tensor_first_split,
-            remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template
+            remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
+            sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
         )

         if args.vocab_only:
@@ -128,6 +128,8 @@ class Keys:
         ALTUP_ACTIVE_IDX          = "{arch}.altup.active_idx"
         ALTUP_NUM_INPUTS          = "{arch}.altup.num_inputs"
         EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input"
+        DENSE_FEAT_IN_SIZE        = "{arch}.{dense}_feat_in"
+        DENSE_FEAT_OUT_SIZE       = "{arch}.{dense}_feat_out"

     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"
@@ -433,6 +435,8 @@ class MODEL_TENSOR(IntEnum):
     TOKEN_TYPES       = auto()
     POS_EMBD          = auto()
     OUTPUT            = auto()
+    DENSE_2_OUT       = auto()  # embeddinggemma 2_Dense
+    DENSE_3_OUT       = auto()  # embeddinggemma 3_Dense
     OUTPUT_NORM       = auto()
     ROPE_FREQS        = auto()
     ROPE_FACTORS_LONG = auto()
@@ -777,6 +781,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.POS_EMBD:           "position_embd",
    MODEL_TENSOR.OUTPUT_NORM:        "output_norm",
    MODEL_TENSOR.OUTPUT:             "output",
+    MODEL_TENSOR.DENSE_2_OUT:        "dense_2",  # embeddinggemma 2_Dense
+    MODEL_TENSOR.DENSE_3_OUT:        "dense_3",  # embeddinggemma 3_Dense
    MODEL_TENSOR.ROPE_FREQS:         "rope_freqs",
    MODEL_TENSOR.ROPE_FACTORS_LONG:  "rope_factors_long",
    MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
@@ -1759,6 +1765,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
     MODEL_ARCH.GEMMA_EMBEDDING: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.DENSE_2_OUT,
+        MODEL_TENSOR.DENSE_3_OUT,
         MODEL_TENSOR.OUTPUT_NORM,
         MODEL_TENSOR.ATTN_Q,
         MODEL_TENSOR.ATTN_Q_NORM,
@@ -730,6 +730,10 @@ class GGUFWriter:
     def add_sliding_window_pattern(self, value: Sequence[bool]) -> None:
         self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value)

+    def add_dense_features_dims(self, dense: str, in_f: int, out_f: int) -> None:
+        self.add_uint32(Keys.LLM.DENSE_FEAT_IN_SIZE.format(arch=self.arch, dense=dense), in_f)
+        self.add_uint32(Keys.LLM.DENSE_FEAT_OUT_SIZE.format(arch=self.arch, dense=dense), out_f)
+
     def add_logit_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)
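The `{dense}` placeholder in the two key templates is filled with the converter-side prefix, so the writer emits four scalar keys per model. A sketch of the expansion follows; the `gemma-embedding` arch string and the 768/3072 dimensions are assumptions modeled on embeddinggemma-300m, while the `_feat_in`/`_feat_out` suffixes match the C++ key names further down.

```python
# Sketch: GGUF keys produced by add_dense_features_dims (arch string and dims assumed).
DENSE_FEAT_IN_SIZE  = "{arch}.{dense}_feat_in"
DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out"
for dense, (in_f, out_f) in {"dense_2": (768, 3072), "dense_3": (3072, 768)}.items():
    print(DENSE_FEAT_IN_SIZE.format(arch="gemma-embedding", dense=dense), "=", in_f)
    print(DENSE_FEAT_OUT_SIZE.format(arch="gemma-embedding", dense=dense), "=", out_f)
```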
@@ -76,7 +76,12 @@ class TensorNameMap:
             "lm_head",                  # llama4
             "model.transformer.ff_out", # llada
         ),
+        MODEL_TENSOR.DENSE_2_OUT: (
+            "dense_2_out", # embeddinggemma
+        ),
+        MODEL_TENSOR.DENSE_3_OUT: (
+            "dense_3_out", # embeddinggemma
+        ),

         # Output norm
         MODEL_TENSOR.OUTPUT_NORM: (
             "gpt_neox.final_layer_norm", # gptneox
@@ -219,6 +219,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },

    { LLM_KV_SHORTCONV_L_CACHE,        "%s.shortconv.l_cache" },
+    // sentence-transformers dense modules feature dims
+    { LLM_KV_DENSE_2_FEAT_IN,          "%s.dense_2_feat_in"  },
+    { LLM_KV_DENSE_2_FEAT_OUT,         "%s.dense_2_feat_out" },
+    { LLM_KV_DENSE_3_FEAT_IN,          "%s.dense_3_feat_in"  },
+    { LLM_KV_DENSE_3_FEAT_OUT,         "%s.dense_3_feat_out" },

    { LLM_KV_TOKENIZER_MODEL,          "tokenizer.ggml.model" },
    { LLM_KV_TOKENIZER_PRE,            "tokenizer.ggml.pre" },
@@ -1071,6 +1076,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
        { LLM_TENSOR_TOKEN_EMBD,  "token_embd" },
        { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
        { LLM_TENSOR_OUTPUT,      "output" },
+        { LLM_TENSOR_DENSE_2_OUT, "dense_2" },
+        { LLM_TENSOR_DENSE_3_OUT, "dense_3" },
        { LLM_TENSOR_ATTN_NORM,   "blk.%d.attn_norm" },
        { LLM_TENSOR_ATTN_Q,      "blk.%d.attn_q" },
        { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
@@ -2281,6 +2288,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
    {LLM_TENSOR_OUTPUT,          {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_CLS,             {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_CLS_OUT,         {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_DENSE_2_OUT,     {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
+    {LLM_TENSOR_DENSE_3_OUT,     {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
    {LLM_TENSOR_OUTPUT_NORM,     {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
    {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
    {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
@@ -271,6 +271,12 @@ enum llm_kv {
     LLM_KV_TOKENIZER_PREFIX_ID,
     LLM_KV_TOKENIZER_SUFFIX_ID,
     LLM_KV_TOKENIZER_MIDDLE_ID,
+
+    // sentence-transformers dense layers in and out features
+    LLM_KV_DENSE_2_FEAT_IN,
+    LLM_KV_DENSE_2_FEAT_OUT,
+    LLM_KV_DENSE_3_FEAT_IN,
+    LLM_KV_DENSE_3_FEAT_OUT,
 };

 enum llm_tensor {
@@ -278,6 +284,8 @@ enum llm_tensor {
     LLM_TENSOR_TOKEN_EMBD_NORM,
     LLM_TENSOR_TOKEN_TYPES,
     LLM_TENSOR_POS_EMBD,
+    LLM_TENSOR_DENSE_2_OUT,
+    LLM_TENSOR_DENSE_3_OUT,
     LLM_TENSOR_OUTPUT,
     LLM_TENSOR_OUTPUT_NORM,
     LLM_TENSOR_ROPE_FREQS,
@@ -2346,6 +2346,12 @@ llama_context * llama_init_from_model(
         return nullptr;
     }

+    if (params.pooling_type != model->hparams.pooling_type) {
+        // user-specified pooling-type is different from the model default
+        LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__,
+                       model->hparams.pooling_type, params.pooling_type);
+    }
+
     try {
         auto * ctx = new llama_context(*model, params);
         return ctx;
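This warning ties in with the commit-log note about removing `pooling_type_opt`: a user-specified `pooling_type` is now always honored, and a mismatch with the model default is surfaced as a log line instead of being silently overridden.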
@@ -1853,6 +1853,23 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
 }

+void llm_graph_context::build_dense_out(
+        ggml_tensor * dense_2,
+        ggml_tensor * dense_3) const {
+    if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) {
+        return;
+    }
+    ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
+    GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
+
+    cur = ggml_mul_mat(ctx0, dense_2, cur);
+    cur = ggml_mul_mat(ctx0, dense_3, cur);
+    cb(cur, "result_embd_pooled", -1);
+    res->t_embd_pooled = cur;
+    ggml_build_forward_expand(gf, cur);
+}
+
 void llm_graph_context::build_pooling(
         ggml_tensor * cls,
         ggml_tensor * cls_b,
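The new `build_dense_out` is easiest to read as plain linear algebra: the pooled sentence embedding passes through two bias-free linear maps with no activation in between, and the result replaces `res->t_embd_pooled`. A numpy sketch under assumed embeddinggemma-300m-like dimensions (768 -> 3072 -> 768):

```python
# Numpy sketch of build_dense_out for one pooled embedding; dimensions are
# assumptions (embeddinggemma-300m-like). W2/W3 stand in for the dense_2/dense_3
# GGUF tensors; each ggml_mul_mat(W, x) corresponds to W @ x here.
import numpy as np

n_embd, hidden = 768, 3072
pooled = np.random.randn(n_embd).astype(np.float32)      # res->t_embd_pooled
W2 = np.random.randn(hidden, n_embd).astype(np.float32)  # dense_2: 768 -> 3072
W3 = np.random.randn(n_embd, hidden).astype(np.float32)  # dense_3: 3072 -> 768

out = W3 @ (W2 @ pooled)       # two stacked projections, no bias, no activation
assert out.shape == (n_embd,)  # output lands back in the model's embedding size
```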
@@ -814,6 +814,14 @@ struct llm_graph_context {
             ggml_tensor * cls_b,
             ggml_tensor * cls_out,
             ggml_tensor * cls_out_b) const;
+
+    //
+    // dense (out)
+    //
+
+    void build_dense_out(
+            ggml_tensor * dense_2,
+            ggml_tensor * dense_3) const;
 };

 // TODO: better name
@@ -169,6 +169,12 @@ struct llama_hparams {
     uint32_t laurel_rank  = 64;
     uint32_t n_embd_altup = 256;

+    // needed for sentence-transformers dense layers
+    uint32_t dense_2_feat_in  = 0; // in_features of the 2_Dense
+    uint32_t dense_2_feat_out = 0; // out_features of the 2_Dense
+    uint32_t dense_3_feat_in  = 0; // in_features of the 3_Dense
+    uint32_t dense_3_feat_out = 0; // out_features of the 3_Dense
+
     // xIELU
     std::array<float, LLAMA_MAX_LAYERS> xielu_alpha_n;
     std::array<float, LLAMA_MAX_LAYERS> xielu_alpha_p;
@@ -1225,6 +1225,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);

+                // applied only if model converted with --sentence-transformers-dense-modules
+                ml.get_key(LLM_KV_DENSE_2_FEAT_IN,  hparams.dense_2_feat_in,  false);
+                ml.get_key(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out, false);
+                ml.get_key(LLM_KV_DENSE_3_FEAT_IN,  hparams.dense_3_feat_in,  false);
+                ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false);
+
+                GGML_ASSERT((hparams.dense_2_feat_in  == 0 || hparams.dense_2_feat_in  == hparams.n_embd) && "dense_2_feat_in must be equal to n_embd");
+                GGML_ASSERT((hparams.dense_3_feat_out == 0 || hparams.dense_3_feat_out == hparams.n_embd) && "dense_3_feat_out must be equal to n_embd");
+
                 switch (hparams.n_layer) {
                     case 24: type = LLM_TYPE_0_3B; break;
                     default: type = LLM_TYPE_UNKNOWN;
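The two asserts encode the shape contract with `build_dense_out`: the pooled embedding (size `n_embd`) feeds `2_Dense`, and the `3_Dense` output is written back into `t_embd_pooled`, so both must match `n_embd`, while the intermediate `dense_2_feat_out`/`dense_3_feat_in` dimension is unconstrained. Passing `false` to `ml.get_key` makes each key optional, so GGUF files converted without `--sentence-transformers-dense-modules` simply leave all four dims at 0.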
@@ -3686,6 +3695,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                     }

+                    // Dense linear weights
+                    dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED);
+                    dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED);
+
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
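Both weights are created with `TENSOR_NOT_REQUIRED`, so GGUF files converted without the dense modules still load; `build_dense_out` then sees null pointers and returns early. Note the ggml dimension order in `create_tensor`: the first entry is the input dimension, hence `{n_embd, dense_2_feat_out}` for `dense_2` and `{dense_3_feat_in, n_embd}` for `dense_3`, matching the asserts in `load_hparams`.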
@@ -19893,6 +19907,12 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
     // add on pooling layer
     llm->build_pooling(cls, cls_b, cls_out, cls_out_b);

+    // if the gguf model was converted with --sentence-transformers-dense-modules
+    // there will be two additional dense projection layers
+    // dense linear projections are applied after pooling
+    // TODO: move reranking logic here and generalize
+    llm->build_dense_out(dense_2_out_layers, dense_3_out_layers);
+
     return llm->res->get_gf();
 }
@@ -438,6 +438,12 @@ struct llama_model {

     std::vector<llama_layer> layers;

+    // Dense linear projections for SentenceTransformers models like embeddinggemma
+    // For the structure of SentenceTransformers models see
+    // https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models
+    struct ggml_tensor * dense_2_out_layers = nullptr;
+    struct ggml_tensor * dense_3_out_layers = nullptr;
+
     llama_model_params params;

     // gguf metadata