Merge branch 'ggml-org:master' into power-law-sampler
This commit is contained in:
commit
295d1d89dd
|
|
@ -1212,6 +1212,9 @@ class TextModel(ModelBase):
|
|||
if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
|
||||
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
|
||||
res = "mellum"
|
||||
if chkhsh == "a0b64b4385f123663873756336c085744376d015ff328bb1d901598f63c44152":
|
||||
# ref: https://huggingface.co/answerdotai/ModernBERT-base
|
||||
res = "modern-bert"
|
||||
if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df":
|
||||
# ref: https://huggingface.co/arcee-ai/Trinity-Tokenizer
|
||||
res = "afmoe"
|
||||
|
|
@ -9999,6 +10002,36 @@ class SmallThinkerModel(TextModel):
|
|||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("ModernBertModel", "ModernBertForMaskedLM", "ModernBertForSequenceClassification")
|
||||
class ModernBertModel(BertModel):
|
||||
model_arch = gguf.MODEL_ARCH.MODERN_BERT
|
||||
|
||||
def set_vocab(self):
|
||||
self.gguf_writer.add_add_bos_token(True)
|
||||
self.gguf_writer.add_add_eos_token(True)
|
||||
self.gguf_writer.add_add_sep_token(True)
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_sliding_window(self.hparams["local_attention"])
|
||||
if (sliding_window_pattern := self.hparams.get("global_attn_every_n_layers")) is not None:
|
||||
self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)
|
||||
self.gguf_writer.add_rope_freq_base_swa(self.rope_parameters.get("sliding_attention", {"rope_theta": self.hparams.get("local_rope_theta")})["rope_theta"])
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
|
||||
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# these layers act as MLM head, so we don't need them
|
||||
if name.startswith("decoder."):
|
||||
return []
|
||||
|
||||
if name.startswith("model."):
|
||||
name = name[6:]
|
||||
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("ApertusForCausalLM")
|
||||
class ApertusModel(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.APERTUS
|
||||
|
|
|
|||
|
|
@ -139,6 +139,7 @@ models = [
|
|||
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
|
||||
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
|
||||
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
|
||||
{"name": "modern-bert", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/answerdotai/ModernBERT-base", },
|
||||
{"name": "afmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/Trinity-Tokenizer", },
|
||||
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
|
||||
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ Here are some examples of running various llama.cpp tools via ADB.
|
|||
Simple question for Llama-3.2-1B
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ M=Llama-3.2-1B-Instruct-Q4_0.gguf D=HTP0 ./scripts/snapdragon/adb/run-cli.sh -no-cnv -p "what is the most popular cookie in the world?"
|
||||
~/src/llama.cpp$ M=Llama-3.2-1B-Instruct-Q4_0.gguf D=HTP0 ./scripts/snapdragon/adb/run-completion.sh -p "what is the most popular cookie in the world?"
|
||||
...
|
||||
ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev 1
|
||||
ggml-hex: Hexagon Arch version v79
|
||||
|
|
@ -136,7 +136,7 @@ llama_memory_breakdown_print: | - HTP0-REPACK | 504 =
|
|||
Summary request for OLMoE-1B-7B. This is a large model that requires two HTP sessions/devices
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ M=OLMoE-1B-7B-0125-Instruct-Q4_0.gguf NDEV=2 D=HTP0,HTP1 ./scripts/snapdragon/adb/run-cli.sh -f surfing.txt -no-cnv
|
||||
~/src/llama.cpp$ M=OLMoE-1B-7B-0125-Instruct-Q4_0.gguf NDEV=2 D=HTP0,HTP1 ./scripts/snapdragon/adb/run-completion.sh -f surfing.txt
|
||||
...
|
||||
ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev 1
|
||||
ggml-hex: Hexagon Arch version v81
|
||||
|
|
@ -234,6 +234,6 @@ build: 6a8cf8914 (6733)
|
|||
|
||||
Examples:
|
||||
|
||||
`GGML_HEXAGON_OPMASK=0x1 llama-cli ...` - Ops are enqueued but NPU-side processing is stubbed out
|
||||
`GGML_HEXAGON_OPMASK=0x3 llama-cli ...` - NPU performs dynamic quantization and skips the rest
|
||||
`GGML_HEXAGON_OPMASK=0x7 llama-cli ...` - Full queuing and processing of Ops (default)
|
||||
`GGML_HEXAGON_OPMASK=0x1 llama-completion ...` - Ops are enqueued but NPU-side processing is stubbed out
|
||||
`GGML_HEXAGON_OPMASK=0x3 llama-completion ...` - NPU performs dynamic quantization and skips the rest
|
||||
`GGML_HEXAGON_OPMASK=0x7 llama-completion ...` - Full queuing and processing of Ops (default)
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ Each Hexagon device behaves like a GPU from the offload and model splitting pers
|
|||
Here is an example of running GPT-OSS-20B model on a newer Snapdragon device with 16GB of DDR.
|
||||
|
||||
```
|
||||
M=gpt-oss-20b-Q4_0.gguf NDEV=4 D=HTP0,HTP1,HTP2,HTP3 P=surfing.txt scripts/snapdragon/adb/run-cli.sh -no-cnv -f surfing.txt -n 32
|
||||
M=gpt-oss-20b-Q4_0.gguf NDEV=4 D=HTP0,HTP1,HTP2,HTP3 P=surfing.txt scripts/snapdragon/adb/run-completion.sh -f surfing.txt -n 32
|
||||
...
|
||||
LD_LIBRARY_PATH=/data/local/tmp/llama.cpp/lib
|
||||
ADSP_LIBRARY_PATH=/data/local/tmp/llama.cpp/lib
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ define quantize_model
|
|||
@echo "Export the quantized model path to $(2) variable in your environment"
|
||||
endef
|
||||
|
||||
DEVICE ?= auto
|
||||
|
||||
###
|
||||
### Casual Model targets/recipes
|
||||
###
|
||||
|
|
@ -53,7 +55,7 @@ causal-convert-mm-model:
|
|||
|
||||
causal-run-original-model:
|
||||
$(call validate_model_path,causal-run-original-model)
|
||||
@MODEL_PATH="$(MODEL_PATH)" ./scripts/causal/run-org-model.py
|
||||
@MODEL_PATH="$(MODEL_PATH)" ./scripts/causal/run-org-model.py --device "$(DEVICE)"
|
||||
|
||||
causal-run-converted-model:
|
||||
@CONVERTED_MODEL="$(CONVERTED_MODEL)" ./scripts/causal/run-converted-model.sh
|
||||
|
|
|
|||
|
|
@ -4,149 +4,179 @@ import argparse
|
|||
import os
|
||||
import sys
|
||||
import importlib
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from pathlib import Path
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig
|
||||
import torch
|
||||
import numpy as np
|
||||
from utils.common import debug_hook
|
||||
|
||||
parser = argparse.ArgumentParser(description="Process model with specified path")
|
||||
parser.add_argument("--model-path", "-m", help="Path to the model")
|
||||
parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False)
|
||||
parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose debug output")
|
||||
args = parser.parse_args()
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="Process model with specified path")
|
||||
parser.add_argument("--model-path", "-m", help="Path to the model")
|
||||
parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False)
|
||||
parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose debug output")
|
||||
parser.add_argument("--device", "-d", help="Device to use (cpu, cuda, mps, auto)", default="auto")
|
||||
return parser.parse_args()
|
||||
|
||||
model_path = os.environ.get("MODEL_PATH", args.model_path)
|
||||
if model_path is None:
|
||||
parser.error(
|
||||
"Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
|
||||
)
|
||||
def load_model_and_tokenizer(model_path, device="auto"):
|
||||
print("Loading model and tokenizer using AutoTokenizer:", model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||
multimodal = False
|
||||
full_config = config
|
||||
|
||||
### If you want to dump RoPE activations, uncomment the following lines:
|
||||
### === START ROPE DEBUG ===
|
||||
# from utils.common import setup_rope_debug
|
||||
# setup_rope_debug("transformers.models.apertus.modeling_apertus")
|
||||
### == END ROPE DEBUG ===
|
||||
|
||||
|
||||
print("Loading model and tokenizer using AutoTokenizer:", model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||
multimodal = False
|
||||
full_config = config
|
||||
|
||||
print("Model type: ", config.model_type)
|
||||
if "vocab_size" not in config and "text_config" in config:
|
||||
config = config.text_config
|
||||
multimodal = True
|
||||
print("Vocab size: ", config.vocab_size)
|
||||
print("Hidden size: ", config.hidden_size)
|
||||
print("Number of layers: ", config.num_hidden_layers)
|
||||
print("BOS token id: ", config.bos_token_id)
|
||||
print("EOS token id: ", config.eos_token_id)
|
||||
|
||||
unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
|
||||
if unreleased_model_name:
|
||||
model_name_lower = unreleased_model_name.lower()
|
||||
unreleased_module_path = (
|
||||
f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
|
||||
)
|
||||
class_name = f"{unreleased_model_name}ForCausalLM"
|
||||
print(f"Importing unreleased model module: {unreleased_module_path}")
|
||||
|
||||
try:
|
||||
model_class = getattr(
|
||||
importlib.import_module(unreleased_module_path), class_name
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
model_path
|
||||
) # Note: from_pretrained, not fromPretrained
|
||||
except (ImportError, AttributeError) as e:
|
||||
print(f"Failed to import or load model: {e}")
|
||||
exit(1)
|
||||
else:
|
||||
if multimodal:
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=full_config
|
||||
)
|
||||
# Determine device_map based on device argument
|
||||
if device == "cpu":
|
||||
device_map = {"": "cpu"}
|
||||
print("Forcing CPU usage")
|
||||
elif device == "auto":
|
||||
device_map = "auto"
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
|
||||
device_map = {"": device}
|
||||
|
||||
print("Model type: ", config.model_type)
|
||||
if "vocab_size" not in config and "text_config" in config:
|
||||
config = config.text_config
|
||||
multimodal = True
|
||||
|
||||
print("Vocab size: ", config.vocab_size)
|
||||
print("Hidden size: ", config.hidden_size)
|
||||
print("Number of layers: ", config.num_hidden_layers)
|
||||
print("BOS token id: ", config.bos_token_id)
|
||||
print("EOS token id: ", config.eos_token_id)
|
||||
|
||||
unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
|
||||
if unreleased_model_name:
|
||||
model_name_lower = unreleased_model_name.lower()
|
||||
unreleased_module_path = (
|
||||
f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
|
||||
)
|
||||
class_name = f"{unreleased_model_name}ForCausalLM"
|
||||
print(f"Importing unreleased model module: {unreleased_module_path}")
|
||||
|
||||
if args.verbose:
|
||||
for name, module in model.named_modules():
|
||||
if len(list(module.children())) == 0: # only leaf modules
|
||||
module.register_forward_hook(debug_hook(name))
|
||||
try:
|
||||
model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
|
||||
model = model_class.from_pretrained(
|
||||
model_path,
|
||||
device_map=device_map,
|
||||
offload_folder="offload",
|
||||
trust_remote_code=True,
|
||||
config=config
|
||||
)
|
||||
except (ImportError, AttributeError) as e:
|
||||
print(f"Failed to import or load model: {e}")
|
||||
exit(1)
|
||||
else:
|
||||
if multimodal:
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
model_path,
|
||||
device_map=device_map,
|
||||
offload_folder="offload",
|
||||
trust_remote_code=True,
|
||||
config=full_config
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
device_map=device_map,
|
||||
offload_folder="offload",
|
||||
trust_remote_code=True,
|
||||
config=config
|
||||
)
|
||||
|
||||
model_name = os.path.basename(model_path)
|
||||
# Printing the Model class to allow for easier debugging. This can be useful
|
||||
# when working with models that have not been publicly released yet and this
|
||||
# migth require that the concrete class is imported and used directly instead
|
||||
# of using AutoModelForCausalLM.
|
||||
print(f"Model class: {model.__class__.__name__}")
|
||||
print(f"Model class: {model.__class__.__name__}")
|
||||
|
||||
device = next(model.parameters()).device
|
||||
if args.prompt_file:
|
||||
with open(args.prompt_file, encoding='utf-8') as f:
|
||||
prompt = f.read()
|
||||
elif os.getenv("MODEL_TESTING_PROMPT"):
|
||||
prompt = os.getenv("MODEL_TESTING_PROMPT")
|
||||
else:
|
||||
prompt = "Hello, my name is"
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
||||
return model, tokenizer, config
|
||||
|
||||
print(f"Input tokens: {input_ids}")
|
||||
print(f"Input text: {repr(prompt)}")
|
||||
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
|
||||
def enable_torch_debugging(model):
|
||||
for name, module in model.named_modules():
|
||||
if len(list(module.children())) == 0: # only leaf modules
|
||||
module.register_forward_hook(debug_hook(name))
|
||||
|
||||
batch_size = 512
|
||||
def get_prompt(args):
|
||||
if args.prompt_file:
|
||||
with open(args.prompt_file, encoding='utf-8') as f:
|
||||
return f.read()
|
||||
elif os.getenv("MODEL_TESTING_PROMPT"):
|
||||
return os.getenv("MODEL_TESTING_PROMPT")
|
||||
else:
|
||||
return "Hello, my name is"
|
||||
|
||||
with torch.no_grad():
|
||||
past = None
|
||||
outputs = None
|
||||
for i in range(0, input_ids.size(1), batch_size):
|
||||
print(f"Processing chunk with tokens {i} to {i + batch_size}")
|
||||
chunk = input_ids[:, i:i + batch_size]
|
||||
outputs = model(chunk.to(model.device), past_key_values=past, use_cache=True)
|
||||
past = outputs.past_key_values
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
model_path = os.environ.get("MODEL_PATH", args.model_path)
|
||||
if model_path is None:
|
||||
print("Error: Model path must be specified either via --model-path argument or MODEL_PATH environment variable")
|
||||
sys.exit(1)
|
||||
|
||||
logits = outputs.logits # type: ignore
|
||||
|
||||
# Extract logits for the last token (next token prediction)
|
||||
last_logits = logits[0, -1, :].float().cpu().numpy()
|
||||
model, tokenizer, config = load_model_and_tokenizer(model_path, args.device)
|
||||
|
||||
print(f"Logits shape: {logits.shape}")
|
||||
print(f"Last token logits shape: {last_logits.shape}")
|
||||
print(f"Vocab size: {len(last_logits)}")
|
||||
if args.verbose:
|
||||
enable_torch_debugging(model)
|
||||
|
||||
data_dir = Path("data")
|
||||
data_dir.mkdir(exist_ok=True)
|
||||
bin_filename = data_dir / f"pytorch-{model_name}.bin"
|
||||
txt_filename = data_dir / f"pytorch-{model_name}.txt"
|
||||
model_name = os.path.basename(model_path)
|
||||
|
||||
# Save to file for comparison
|
||||
last_logits.astype(np.float32).tofile(bin_filename)
|
||||
# Iterate over the model parameters (the tensors) and get the first one
|
||||
# and use it to get the device the model is on.
|
||||
device = next(model.parameters()).device
|
||||
prompt = get_prompt(args)
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
||||
|
||||
# Also save as text file for easy inspection
|
||||
with open(txt_filename, "w") as f:
|
||||
for i, logit in enumerate(last_logits):
|
||||
f.write(f"{i}: {logit:.6f}\n")
|
||||
print(f"Input tokens: {input_ids}")
|
||||
print(f"Input text: {repr(prompt)}")
|
||||
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
|
||||
|
||||
# Print some sample logits for quick verification
|
||||
print(f"First 10 logits: {last_logits[:10]}")
|
||||
print(f"Last 10 logits: {last_logits[-10:]}")
|
||||
batch_size = 512
|
||||
|
||||
# Show top 5 predicted tokens
|
||||
top_indices = np.argsort(last_logits)[-5:][::-1]
|
||||
print("Top 5 predictions:")
|
||||
for idx in top_indices:
|
||||
token = tokenizer.decode([idx])
|
||||
print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
|
||||
with torch.no_grad():
|
||||
past = None
|
||||
outputs = None
|
||||
for i in range(0, input_ids.size(1), batch_size):
|
||||
print(f"Processing chunk with tokens {i} to {i + batch_size}")
|
||||
chunk = input_ids[:, i:i + batch_size]
|
||||
outputs = model(chunk.to(model.device), past_key_values=past, use_cache=True)
|
||||
past = outputs.past_key_values
|
||||
|
||||
print(f"Saved bin logits to: {bin_filename}")
|
||||
print(f"Saved txt logist to: {txt_filename}")
|
||||
logits = outputs.logits # type: ignore
|
||||
|
||||
# Extract logits for the last token (next token prediction)
|
||||
last_logits = logits[0, -1, :].float().cpu().numpy()
|
||||
|
||||
print(f"Logits shape: {logits.shape}")
|
||||
print(f"Last token logits shape: {last_logits.shape}")
|
||||
print(f"Vocab size: {len(last_logits)}")
|
||||
|
||||
data_dir = Path("data")
|
||||
data_dir.mkdir(exist_ok=True)
|
||||
bin_filename = data_dir / f"pytorch-{model_name}.bin"
|
||||
txt_filename = data_dir / f"pytorch-{model_name}.txt"
|
||||
|
||||
# Save to file for comparison
|
||||
last_logits.astype(np.float32).tofile(bin_filename)
|
||||
|
||||
# Also save as text file for easy inspection
|
||||
with open(txt_filename, "w") as f:
|
||||
for i, logit in enumerate(last_logits):
|
||||
f.write(f"{i}: {logit:.6f}\n")
|
||||
|
||||
# Print some sample logits for quick verification
|
||||
print(f"First 10 logits: {last_logits[:10]}")
|
||||
print(f"Last 10 logits: {last_logits[-10:]}")
|
||||
|
||||
# Show top 5 predicted tokens
|
||||
top_indices = np.argsort(last_logits)[-5:][::-1]
|
||||
print("Top 5 predictions:")
|
||||
for idx in top_indices:
|
||||
token = tokenizer.decode([idx])
|
||||
print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
|
||||
|
||||
print(f"Saved bin logits to: {bin_filename}")
|
||||
print(f"Saved txt logist to: {txt_filename}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ if use_sentence_transformers:
|
|||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
# This can be used to override the sliding window size for manual testing. This
|
||||
# can be useful to verify the sliding window attention mask in the original model
|
||||
|
|
@ -64,12 +64,12 @@ else:
|
|||
|
||||
try:
|
||||
model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
|
||||
model = model_class.from_pretrained(model_path, config=config)
|
||||
model = model_class.from_pretrained(model_path, config=config, trust_remote_code=True)
|
||||
except (ImportError, AttributeError) as e:
|
||||
print(f"Failed to import or load model: {e}")
|
||||
exit(1)
|
||||
else:
|
||||
model = AutoModel.from_pretrained(model_path, config=config)
|
||||
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)
|
||||
print(f"Model class: {type(model)}")
|
||||
print(f"Model file: {type(model).__module__}")
|
||||
|
||||
|
|
@ -123,7 +123,7 @@ with torch.no_grad():
|
|||
outputs = model(**encoded)
|
||||
hidden_states = outputs.last_hidden_state # Shape: [batch_size, seq_len, hidden_size]
|
||||
|
||||
all_embeddings = hidden_states[0].cpu().numpy() # Shape: [seq_len, hidden_size]
|
||||
all_embeddings = hidden_states[0].float().cpu().numpy() # Shape: [seq_len, hidden_size]
|
||||
|
||||
print(f"Hidden states shape: {hidden_states.shape}")
|
||||
print(f"All embeddings shape: {all_embeddings.shape}")
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ def main():
|
|||
# Load the python model to get configuration information and also to load the tokenizer.
|
||||
print("Loading model and tokenizer using AutoTokenizer:", args.model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
|
||||
config = AutoConfig.from_pretrained(args.model_path)
|
||||
config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
|
||||
|
||||
if unreleased_model_name:
|
||||
model_name_lower = unreleased_model_name.lower()
|
||||
|
|
@ -186,9 +186,9 @@ def main():
|
|||
exit(1)
|
||||
else:
|
||||
if args.causal:
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True)
|
||||
else:
|
||||
model = AutoModel.from_pretrained(args.model_path)
|
||||
model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True)
|
||||
|
||||
encoded = tokenizer(prompt, return_tensors="pt")
|
||||
tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])
|
||||
|
|
|
|||
|
|
@ -22,9 +22,9 @@ if [ $# -gt 0 ]; then
|
|||
GGML_SYCL_DEVICE=$1
|
||||
echo "use $GGML_SYCL_DEVICE as main GPU"
|
||||
#use signle GPU only
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
|
||||
|
||||
else
|
||||
#use multiple GPUs with same max compute units
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT}
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT}
|
||||
fi
|
||||
|
|
|
|||
|
|
@ -24,8 +24,8 @@ export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
|
|||
if [ $# -gt 0 ]; then
|
||||
GGML_SYCL_DEVICE=$1
|
||||
echo "Using $GGML_SYCL_DEVICE as the main GPU"
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
|
||||
else
|
||||
#use multiple GPUs with same max compute units
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT}
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT}
|
||||
fi
|
||||
|
|
|
|||
|
|
@ -8,4 +8,4 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
|
|||
:: support malloc device memory more than 4GB.
|
||||
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
|
||||
|
||||
.\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 99 -s 0
|
||||
.\build\bin\llama-completion.exe -m models\llama-2-7b.Q4_0.gguf -no-cnv -p %INPUT2% -n 400 -e -ngl 99 -s 0
|
||||
|
|
|
|||
|
|
@ -8,4 +8,4 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
|
|||
:: support malloc device memory more than 4GB.
|
||||
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
|
||||
|
||||
.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -s 0 -e -ngl 99
|
||||
.\build\bin\llama-completion.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -no-cnv -p %INPUT2% -n 400 -s 0 -e -ngl 99
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -8,6 +8,7 @@ extern "C" {
|
|||
#include <AEEStdErr.h>
|
||||
#include <inttypes.h>
|
||||
#include <remote.h>
|
||||
#include <rpcmem.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
/* Offset to differentiate HLOS and Hexagon error codes.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,153 @@
|
|||
#ifndef OP_DESC_H
|
||||
#define OP_DESC_H
|
||||
|
||||
#define GGML_COMMON_IMPL_CPP
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-common.h"
|
||||
|
||||
#include <string>
|
||||
#include <stdio.h>
|
||||
|
||||
struct op_desc {
|
||||
char strides[64 * GGML_MAX_SRC];
|
||||
char dims[64 * GGML_MAX_SRC];
|
||||
char types[16 * GGML_MAX_SRC];
|
||||
char buffs[64 * GGML_MAX_SRC];
|
||||
char names[64 * GGML_MAX_SRC];
|
||||
|
||||
int format_tensor_dims(char * str, const struct ggml_tensor * t) {
|
||||
if (t->ne[2] == 1 && t->ne[3] == 1) {
|
||||
return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]);
|
||||
} else {
|
||||
return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
|
||||
}
|
||||
}
|
||||
|
||||
void format_op_dims(char * str, const struct ggml_tensor * t) {
|
||||
char * p = str;
|
||||
|
||||
// append src0 and src1 (if any)
|
||||
if (t->src[0]) {
|
||||
p += format_tensor_dims(p, t->src[0]);
|
||||
|
||||
for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += format_tensor_dims(p, t->src[i]);
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
}
|
||||
|
||||
// format self dims separately for better visual alignment
|
||||
char self[64];
|
||||
format_tensor_dims(self, t);
|
||||
|
||||
p += sprintf(p, "%s", self);
|
||||
}
|
||||
|
||||
int format_tensor_strides(char * str, const struct ggml_tensor * t) {
|
||||
const char * c = ggml_is_contiguous(t) ? "" : "!";
|
||||
|
||||
if (t->ne[2] == 1 && t->ne[3] == 1) {
|
||||
return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c);
|
||||
} else {
|
||||
return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c);
|
||||
}
|
||||
}
|
||||
|
||||
void format_op_strides(char * str, const struct ggml_tensor * t) {
|
||||
char * p = str;
|
||||
|
||||
// append src0 and src1 (if any)
|
||||
if (t->src[0]) {
|
||||
p += format_tensor_strides(p, t->src[0]);
|
||||
|
||||
for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += format_tensor_strides(p, t->src[i]);
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
}
|
||||
|
||||
// format self dims separately for better visual alignment
|
||||
char self[64];
|
||||
format_tensor_strides(self, t);
|
||||
|
||||
p += sprintf(p, "%s", self);
|
||||
}
|
||||
|
||||
void format_op_types(char * str, const struct ggml_tensor * t) {
|
||||
char * p = str;
|
||||
|
||||
// append src0 and src1 (if any)
|
||||
if (t->src[0]) {
|
||||
p += sprintf(p, "%s", ggml_type_name(t->src[0]->type));
|
||||
|
||||
for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += sprintf(p, "%s", ggml_type_name(t->src[i]->type));
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
}
|
||||
|
||||
p += sprintf(p, "%s", ggml_type_name(t->type));
|
||||
}
|
||||
|
||||
const char * tensor_buff_name(const struct ggml_tensor * t) {
|
||||
if (t->buffer) {
|
||||
return ggml_backend_buffer_name(t->buffer);
|
||||
}
|
||||
return "NONE";
|
||||
}
|
||||
|
||||
void format_op_buffs(char * str, const struct ggml_tensor * t) {
|
||||
char * p = str;
|
||||
|
||||
// append src0 and src1 (if any)
|
||||
if (t->src[0]) {
|
||||
p += sprintf(p, "%s", tensor_buff_name(t->src[0]));
|
||||
|
||||
for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += sprintf(p, "%s", tensor_buff_name(t->src[i]));
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
}
|
||||
|
||||
p += sprintf(p, "%s", tensor_buff_name(t));
|
||||
}
|
||||
|
||||
void format_op_names(char * str, const struct ggml_tensor * t) {
|
||||
char * p = str;
|
||||
|
||||
// append src0 and src1 (if any)
|
||||
if (t->src[0]) {
|
||||
p += sprintf(p, "%s", t->src[0]->name);
|
||||
|
||||
for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += sprintf(p, "%s", t->src[i]->name);
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
}
|
||||
|
||||
p += sprintf(p, "%s", t->name);
|
||||
}
|
||||
|
||||
void format(const ggml_tensor * op) {
|
||||
format_op_dims(dims, op);
|
||||
format_op_strides(strides, op);
|
||||
format_op_types(types, op);
|
||||
format_op_buffs(buffs, op);
|
||||
format_op_names(names, op);
|
||||
}
|
||||
|
||||
op_desc() {}
|
||||
op_desc(const ggml_tensor * op) { format(op); }
|
||||
};
|
||||
|
||||
#endif // OP_DESC_H
|
||||
|
|
@ -571,6 +571,10 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
|
|||
return ctx->base_ptr;
|
||||
}
|
||||
|
||||
static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
|
||||
return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
|
||||
}
|
||||
|
||||
static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
|
||||
rpc_tensor result;
|
||||
if (!tensor) {
|
||||
|
|
@ -580,7 +584,7 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
|
|||
|
||||
result.id = reinterpret_cast<uint64_t>(tensor);
|
||||
result.type = tensor->type;
|
||||
if (tensor->buffer) {
|
||||
if (tensor->buffer && ggml_backend_buffer_is_rpc(tensor->buffer)) {
|
||||
ggml_backend_buffer_t buffer = tensor->buffer;
|
||||
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
||||
result.buffer = ctx != nullptr ? ctx->remote_ptr : 0;
|
||||
|
|
@ -664,10 +668,6 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
|
|||
RPC_STATUS_ASSERT(status);
|
||||
}
|
||||
|
||||
static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
|
||||
return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
|
||||
}
|
||||
|
||||
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
if (ggml_backend_buffer_is_rpc(src->buffer)) {
|
||||
// check if src and dst are on the same server
|
||||
|
|
|
|||
|
|
@ -181,6 +181,7 @@ class Keys:
|
|||
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
|
||||
FREQ_BASE = "{arch}.rope.freq_base"
|
||||
FREQ_BASE_SWA = "{arch}.rope.freq_base_swa"
|
||||
SCALING_TYPE = "{arch}.rope.scaling.type"
|
||||
SCALING_FACTOR = "{arch}.rope.scaling.factor"
|
||||
SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
|
||||
|
|
@ -354,6 +355,7 @@ class MODEL_ARCH(IntEnum):
|
|||
STARCODER = auto()
|
||||
REFACT = auto()
|
||||
BERT = auto()
|
||||
MODERN_BERT = auto()
|
||||
NOMIC_BERT = auto()
|
||||
NOMIC_BERT_MOE = auto()
|
||||
NEO_BERT = auto()
|
||||
|
|
@ -747,6 +749,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.STARCODER: "starcoder",
|
||||
MODEL_ARCH.REFACT: "refact",
|
||||
MODEL_ARCH.BERT: "bert",
|
||||
MODEL_ARCH.MODERN_BERT: "modern-bert",
|
||||
MODEL_ARCH.NOMIC_BERT: "nomic-bert",
|
||||
MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
|
||||
MODEL_ARCH.NEO_BERT: "neo-bert",
|
||||
|
|
@ -1367,6 +1370,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.CLS,
|
||||
MODEL_TENSOR.CLS_OUT,
|
||||
],
|
||||
MODEL_ARCH.MODERN_BERT: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_QKV,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.CLS,
|
||||
MODEL_TENSOR.CLS_OUT,
|
||||
],
|
||||
MODEL_ARCH.NOMIC_BERT: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM,
|
||||
|
|
|
|||
|
|
@ -774,8 +774,12 @@ class GGUFWriter:
|
|||
def add_shared_kv_layers(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)
|
||||
|
||||
def add_sliding_window_pattern(self, value: Sequence[bool]) -> None:
|
||||
self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value)
|
||||
def add_sliding_window_pattern(self, value: int | Sequence[bool]) -> None:
|
||||
key = Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch)
|
||||
if isinstance(value, int):
|
||||
self.add_uint32(key, value)
|
||||
else:
|
||||
self.add_array(key, value)
|
||||
|
||||
def add_dense_features_dims(self, dense:str, in_f:int, out_f:int) -> None:
|
||||
self.add_uint32(Keys.LLM.DENSE_FEAT_IN_SIZE.format(arch=self.arch, dense=dense), in_f)
|
||||
|
|
@ -886,6 +890,9 @@ class GGUFWriter:
|
|||
def add_value_residual_mix_lora_rank(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length)
|
||||
|
||||
def add_rope_freq_base_swa(self, value: float) -> None:
|
||||
self.add_float32(Keys.Rope.FREQ_BASE_SWA.format(arch=self.arch), value)
|
||||
|
||||
def add_gate_lora_rank(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length)
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ class TensorNameMap:
|
|||
"embed_tokens", # embeddinggemma
|
||||
"tok_embeddings", # llama-pth
|
||||
"embeddings.word_embeddings", # bert nomic-bert
|
||||
"embeddings.tok_embeddings", # modern-bert
|
||||
"language_model.embedding.word_embeddings", # persimmon
|
||||
"wte", # gpt2
|
||||
"transformer.embd.wte", # phi2
|
||||
|
|
@ -46,6 +47,7 @@ class TensorNameMap:
|
|||
MODEL_TENSOR.TOKEN_EMBD_NORM: (
|
||||
"word_embeddings_layernorm", # bloom
|
||||
"embeddings.LayerNorm", # bert
|
||||
"embeddings.norm", # modern-bert
|
||||
"emb_ln", # nomic-bert
|
||||
"transformer.norm", # openelm
|
||||
"rwkv.blocks.0.pre_ln", # rwkv
|
||||
|
|
@ -75,6 +77,7 @@ class TensorNameMap:
|
|||
"head.out", # wavtokenizer
|
||||
"lm_head", # llama4
|
||||
"model.transformer.ff_out", # llada
|
||||
"head.decoder", # modern-bert
|
||||
),
|
||||
MODEL_TENSOR.DENSE_2_OUT: (
|
||||
"dense_2_out", # embeddinggemma
|
||||
|
|
@ -104,6 +107,7 @@ class TensorNameMap:
|
|||
"backbone.final_layer_norm", # wavtokenizer
|
||||
"model.norm", # llama4
|
||||
"model.transformer.ln_f", # llada
|
||||
"final_norm", # modern-bert
|
||||
"model.norm", # cogvlm
|
||||
),
|
||||
|
||||
|
|
@ -151,6 +155,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.input_layernorm", # llama4
|
||||
"layers.{bid}.input_layernorm", # embeddinggemma
|
||||
"transformer_encoder.{bid}.attention_norm", # neobert
|
||||
"layers.{bid}.attn_norm", # modern-bert
|
||||
"model.layers.{bid}.operator_norm", # lfm2
|
||||
"model.transformer.blocks.{bid}.attn_norm", # llada
|
||||
"layers.{bid}.input_layernorm", # qwen3-embedding
|
||||
|
|
@ -187,6 +192,7 @@ class TensorNameMap:
|
|||
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
|
||||
"transformer.layers.{bid}.attn.qkv_proj", # openelm
|
||||
"transformer_encoder.{bid}.qkv", # neobert
|
||||
"layers.{bid}.attn.Wqkv", # modern-bert
|
||||
"model.layers.{bid}.self_attn.language_expert_query_key_value", # cogvlm
|
||||
),
|
||||
|
||||
|
|
@ -261,6 +267,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.self_attn.linear_attn", # deci
|
||||
"layers.{bid}.attention.wo", # llama-pth
|
||||
"encoder.layer.{bid}.attention.output.dense", # bert
|
||||
"layers.{bid}.attn.Wo", # modern-bert
|
||||
"transformer.layer.{bid}.attention.out_lin", # distillbert
|
||||
"transformer.h.{bid}.attn.out_proj", # gpt-j
|
||||
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
|
||||
|
|
@ -344,6 +351,7 @@ class TensorNameMap:
|
|||
"layers.{bid}.post_attention_layernorm", # qwen3-embedding
|
||||
"model.layers.{bid}.feedforward_layernorm", # apertus
|
||||
"model.layers.{bid}.pre_mlp_layernorm", # kormo
|
||||
"layers.{bid}.mlp_norm" # modern-bert
|
||||
),
|
||||
|
||||
# Pre feed-forward norm
|
||||
|
|
@ -407,6 +415,7 @@ class TensorNameMap:
|
|||
"layers.{bid}.mlp.up_proj", # embeddinggemma
|
||||
"layers.{bid}.feed_forward.w3", # llama-pth
|
||||
"encoder.layer.{bid}.intermediate.dense", # bert
|
||||
"layers.{bid}.mlp.Wi", # modern-bert
|
||||
"transformer.layer.{bid}.ffn.lin1", # distillbert
|
||||
"transformer.h.{bid}.mlp.fc_in", # gpt-j
|
||||
"transformer.h.{bid}.mlp.linear_3", # refact
|
||||
|
|
@ -521,6 +530,7 @@ class TensorNameMap:
|
|||
"layers.{bid}.mlp.down_proj", # embeddinggemma
|
||||
"layers.{bid}.feed_forward.w2", # llama-pth
|
||||
"encoder.layer.{bid}.output.dense", # bert
|
||||
"layers.{bid}.mlp.Wo", # modern-bert
|
||||
"transformer.layer.{bid}.ffn.lin2", # distillbert
|
||||
"transformer.h.{bid}.mlp.fc_out", # gpt-j
|
||||
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
|
||||
|
|
@ -1122,6 +1132,7 @@ class TensorNameMap:
|
|||
"classifier.dense", # roberta
|
||||
"pre_classifier", # distillbert
|
||||
"dense", # neobert
|
||||
"head.dense", # modern-bert
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CLS_OUT: (
|
||||
|
|
|
|||
|
|
@ -18,17 +18,17 @@ model="Llama-3.2-3B-Instruct-Q4_0.gguf"
|
|||
device="HTP0"
|
||||
[ "$D" != "" ] && device="$D"
|
||||
|
||||
verbose=
|
||||
[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V"
|
||||
|
||||
experimental=
|
||||
[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E"
|
||||
|
||||
verbose=
|
||||
[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" cli_opts="$cli_opts -v"
|
||||
|
||||
sched=
|
||||
[ "$SCHED" != "" ] && sched="GGML_SCHED_DEBUG=2" cli_opts="$cli_opts -v"
|
||||
|
||||
profile=
|
||||
[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1"
|
||||
[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" cli_opts="$cli_opts -v"
|
||||
|
||||
opmask=
|
||||
[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK"
|
||||
|
|
@ -45,9 +45,9 @@ adb $adbserial shell " \
|
|||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $ndev \
|
||||
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on \
|
||||
-ngl 99 --device $device $cli_opts $@ \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $ndev \
|
||||
./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --batch-size 128 -fa on \
|
||||
-ngl 99 --device $device $cli_opts $@ \
|
||||
"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,53 @@
|
|||
#!/bin/sh
|
||||
#
|
||||
|
||||
# Basedir on device
|
||||
basedir=/data/local/tmp/llama.cpp
|
||||
|
||||
cli_opts=
|
||||
|
||||
branch=.
|
||||
[ "$B" != "" ] && branch=$B
|
||||
|
||||
adbserial=
|
||||
[ "$S" != "" ] && adbserial="-s $S"
|
||||
|
||||
model="Llama-3.2-3B-Instruct-Q4_0.gguf"
|
||||
[ "$M" != "" ] && model="$M"
|
||||
|
||||
device="HTP0"
|
||||
[ "$D" != "" ] && device="$D"
|
||||
|
||||
experimental=
|
||||
[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E"
|
||||
|
||||
verbose=
|
||||
[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" cli_opts="$cli_opts -v"
|
||||
|
||||
sched=
|
||||
[ "$SCHED" != "" ] && sched="GGML_SCHED_DEBUG=2" cli_opts="$cli_opts -v"
|
||||
|
||||
profile=
|
||||
[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" cli_opts="$cli_opts -v"
|
||||
|
||||
opmask=
|
||||
[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK"
|
||||
|
||||
nhvx=
|
||||
[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX"
|
||||
|
||||
ndev=
|
||||
[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV"
|
||||
|
||||
set -x
|
||||
|
||||
adb $adbserial shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $ndev \
|
||||
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --batch-size 128 -fa on \
|
||||
-ngl 99 -no-cnv --device $device $cli_opts $@ \
|
||||
"
|
||||
|
|
@ -90,6 +90,7 @@ add_library(llama
|
|||
models/mamba.cpp
|
||||
models/minicpm3.cpp
|
||||
models/minimax-m2.cpp
|
||||
models/modern-bert.cpp
|
||||
models/mpt.cpp
|
||||
models/nemotron-h.cpp
|
||||
models/nemotron.cpp
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_STARCODER, "starcoder" },
|
||||
{ LLM_ARCH_REFACT, "refact" },
|
||||
{ LLM_ARCH_BERT, "bert" },
|
||||
{ LLM_ARCH_MODERN_BERT, "modern-bert" },
|
||||
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
|
||||
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
|
||||
{ LLM_ARCH_NEO_BERT, "neo-bert" },
|
||||
|
|
@ -204,6 +205,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
|||
{ LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
|
||||
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
|
||||
{ LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, "%s.attention.sliding_window_pattern" },
|
||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
{ LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
|
||||
{ LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
|
||||
|
|
@ -214,6 +216,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
|||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
|
||||
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
|
||||
{ LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" },
|
||||
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
|
||||
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
|
||||
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
|
||||
|
|
@ -778,6 +781,20 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
|||
LLM_TENSOR_CLS,
|
||||
LLM_TENSOR_CLS_OUT,
|
||||
};
|
||||
case LLM_ARCH_MODERN_BERT:
|
||||
return {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
LLM_TENSOR_TOKEN_EMBD_NORM,
|
||||
LLM_TENSOR_OUTPUT_NORM,
|
||||
LLM_TENSOR_ATTN_NORM,
|
||||
LLM_TENSOR_ATTN_OUT,
|
||||
LLM_TENSOR_ATTN_QKV,
|
||||
LLM_TENSOR_FFN_DOWN,
|
||||
LLM_TENSOR_FFN_UP,
|
||||
LLM_TENSOR_FFN_NORM,
|
||||
LLM_TENSOR_CLS,
|
||||
LLM_TENSOR_CLS_OUT,
|
||||
};
|
||||
case LLM_ARCH_JINA_BERT_V2:
|
||||
return {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ enum llm_arch {
|
|||
LLM_ARCH_STARCODER,
|
||||
LLM_ARCH_REFACT,
|
||||
LLM_ARCH_BERT,
|
||||
LLM_ARCH_MODERN_BERT,
|
||||
LLM_ARCH_NOMIC_BERT,
|
||||
LLM_ARCH_NOMIC_BERT_MOE,
|
||||
LLM_ARCH_NEO_BERT,
|
||||
|
|
@ -208,6 +209,7 @@ enum llm_kv {
|
|||
LLM_KV_ATTENTION_GATE_LORA_RANK,
|
||||
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
|
||||
LLM_KV_ATTENTION_SLIDING_WINDOW,
|
||||
LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN,
|
||||
LLM_KV_ATTENTION_SCALE,
|
||||
LLM_KV_ATTENTION_OUTPUT_SCALE,
|
||||
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
|
||||
|
|
@ -218,6 +220,7 @@ enum llm_kv {
|
|||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_DIMENSION_SECTIONS,
|
||||
LLM_KV_ROPE_FREQ_BASE,
|
||||
LLM_KV_ROPE_FREQ_BASE_SWA,
|
||||
LLM_KV_ROPE_SCALE_LINEAR,
|
||||
LLM_KV_ROPE_SCALING_TYPE,
|
||||
LLM_KV_ROPE_SCALING_FACTOR,
|
||||
|
|
|
|||
|
|
@ -462,6 +462,29 @@ namespace GGUFMeta {
|
|||
return get_key_or_arr(llm_kv(kid), result, n, required);
|
||||
}
|
||||
|
||||
bool llama_model_loader::get_key_or_arr(enum llm_kv kid, uint32_t & result, bool required) {
|
||||
const std::string key = llm_kv(kid);
|
||||
|
||||
const int id = gguf_find_key(meta.get(), key.c_str());
|
||||
|
||||
if (id < 0) {
|
||||
if (required) {
|
||||
throw std::runtime_error(format("key not found in model: %s", key.c_str()));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// throw and error if type is an array
|
||||
if (gguf_get_kv_type(meta.get(), id) == GGUF_TYPE_ARRAY) {
|
||||
if (required) {
|
||||
throw std::runtime_error(format("expected scalar, found array for key: %s", key.c_str()));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return get_key(key, result, required);
|
||||
}
|
||||
|
||||
// TODO: this is not very clever - figure out something better
|
||||
template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
|
||||
template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
|
||||
|
|
|
|||
|
|
@ -131,6 +131,8 @@ struct llama_model_loader {
|
|||
template<typename T>
|
||||
bool get_key_or_arr(enum llm_kv kid, T & result, uint32_t n, bool required = true);
|
||||
|
||||
bool get_key_or_arr(enum llm_kv kid, uint32_t & result, bool required = true);
|
||||
|
||||
std::string get_arch_name() const;
|
||||
|
||||
enum llm_arch get_arch() const;
|
||||
|
|
|
|||
|
|
@ -31,12 +31,14 @@ const char * llm_type_name(llm_type type) {
|
|||
case LLM_TYPE_17M: return "17M";
|
||||
case LLM_TYPE_22M: return "22M";
|
||||
case LLM_TYPE_33M: return "33M";
|
||||
case LLM_TYPE_47M: return "47M";
|
||||
case LLM_TYPE_60M: return "60M";
|
||||
case LLM_TYPE_70M: return "70M";
|
||||
case LLM_TYPE_80M: return "80M";
|
||||
case LLM_TYPE_109M: return "109M";
|
||||
case LLM_TYPE_137M: return "137M";
|
||||
case LLM_TYPE_140M: return "140M";
|
||||
case LLM_TYPE_149M: return "149M";
|
||||
case LLM_TYPE_160M: return "160M";
|
||||
case LLM_TYPE_190M: return "190M";
|
||||
case LLM_TYPE_220M: return "220M";
|
||||
|
|
@ -46,6 +48,7 @@ const char * llm_type_name(llm_type type) {
|
|||
case LLM_TYPE_335M: return "335M";
|
||||
case LLM_TYPE_350M: return "350M";
|
||||
case LLM_TYPE_360M: return "360M";
|
||||
case LLM_TYPE_395M: return "395M";
|
||||
case LLM_TYPE_410M: return "410M";
|
||||
case LLM_TYPE_450M: return "450M";
|
||||
case LLM_TYPE_475M: return "475M";
|
||||
|
|
@ -875,6 +878,34 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_MODERN_BERT:
|
||||
{
|
||||
const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
|
||||
if (found_swa && hparams.n_swa > 0) {
|
||||
uint32_t swa_period = 3;
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC;
|
||||
|
||||
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa);
|
||||
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false);
|
||||
hparams.set_swa_pattern(swa_period);
|
||||
} else {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
}
|
||||
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
|
||||
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 12:
|
||||
type = LLM_TYPE_47M; break; // granite-embedding-small
|
||||
case 22:
|
||||
type = LLM_TYPE_149M; break; // modern-bert-base
|
||||
case 28:
|
||||
type = LLM_TYPE_395M; break; // modern-bert-large
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_JINA_BERT_V2:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
|
|
@ -3155,6 +3186,37 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_MODERN_BERT:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
|
||||
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
|
||||
for(int i = 0; i < n_layer; ++i) {
|
||||
auto& layer = layers[i];
|
||||
|
||||
if ( i != 0 ) {
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
} else{
|
||||
// layer 0 uses identity
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
|
||||
|
||||
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3 * n_embd }, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
||||
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, 2 * n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
}
|
||||
|
||||
cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
|
||||
cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
|
||||
cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
} break;
|
||||
case LLM_ARCH_NEO_BERT:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
|
@ -5181,9 +5243,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
const int64_t n_group = hparams.ssm_n_group;
|
||||
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head;
|
||||
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
|
||||
const int64_t n_ff_shexp = hparams.n_ff_shexp;
|
||||
|
||||
// embeddings
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
|
|
@ -5235,6 +5294,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||
} else {
|
||||
if (n_expert != 0) {
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
|
||||
const int64_t n_ff_shexp = hparams.n_ff_shexp;
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0);
|
||||
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0);
|
||||
|
||||
|
|
@ -7089,6 +7151,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
case LLM_ARCH_NOMIC_BERT_MOE:
|
||||
case LLM_ARCH_NEO_BERT:
|
||||
case LLM_ARCH_WAVTOKENIZER_DEC:
|
||||
case LLM_ARCH_MODERN_BERT:
|
||||
case LLM_ARCH_GEMMA_EMBEDDING:
|
||||
case LLM_ARCH_DREAM:
|
||||
case LLM_ARCH_LLADA:
|
||||
|
|
@ -7248,6 +7311,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
{
|
||||
llm = std::make_unique<llm_build_bert>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_MODERN_BERT:
|
||||
{
|
||||
llm = std::make_unique<llm_build_modern_bert<true>>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_NEO_BERT:
|
||||
{
|
||||
llm = std::make_unique<llm_build_neo_bert>(*this, params);
|
||||
|
|
@ -7816,6 +7883,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|||
case LLM_ARCH_DBRX:
|
||||
case LLM_ARCH_BERT:
|
||||
case LLM_ARCH_JINA_BERT_V3:
|
||||
case LLM_ARCH_MODERN_BERT:
|
||||
case LLM_ARCH_NOMIC_BERT:
|
||||
case LLM_ARCH_NOMIC_BERT_MOE:
|
||||
case LLM_ARCH_STABLELM:
|
||||
|
|
|
|||
|
|
@ -24,12 +24,14 @@ enum llm_type {
|
|||
LLM_TYPE_17M,
|
||||
LLM_TYPE_22M,
|
||||
LLM_TYPE_33M,
|
||||
LLM_TYPE_47M,
|
||||
LLM_TYPE_60M,
|
||||
LLM_TYPE_70M,
|
||||
LLM_TYPE_80M,
|
||||
LLM_TYPE_109M,
|
||||
LLM_TYPE_137M,
|
||||
LLM_TYPE_140M,
|
||||
LLM_TYPE_149M,
|
||||
LLM_TYPE_160M,
|
||||
LLM_TYPE_190M,
|
||||
LLM_TYPE_220M,
|
||||
|
|
@ -39,6 +41,7 @@ enum llm_type {
|
|||
LLM_TYPE_335M,
|
||||
LLM_TYPE_350M,
|
||||
LLM_TYPE_360M,
|
||||
LLM_TYPE_395M,
|
||||
LLM_TYPE_410M,
|
||||
LLM_TYPE_450M,
|
||||
LLM_TYPE_475M,
|
||||
|
|
|
|||
|
|
@ -1878,7 +1878,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||
tokenizer_pre == "jina-v2-es" ||
|
||||
tokenizer_pre == "jina-v2-de" ||
|
||||
tokenizer_pre == "a.x-4.0" ||
|
||||
tokenizer_pre == "mellum") {
|
||||
tokenizer_pre == "mellum" ||
|
||||
tokenizer_pre == "modern-bert" ) {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
|
||||
} else if (
|
||||
tokenizer_pre == "jina-v1-en" ||
|
||||
|
|
@ -2528,6 +2529,13 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||
for (const auto * token : {"<unk>", "<s>", "<|endoftext|>"}) {
|
||||
_set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
|
||||
}
|
||||
} else if (_contains_any(model_name, {"modern-bert"})) {
|
||||
if (token_to_id.count("[MASK]") == 0 ) {
|
||||
LLAMA_LOG_WARN("%s: Mask token missing in vocab!\n", __func__);
|
||||
}
|
||||
else {
|
||||
_set_token_attr("[MASK]", LLAMA_TOKEN_ATTR_LSTRIP, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -327,6 +327,11 @@ struct llm_build_mistral3 : public llm_graph_context {
|
|||
llm_build_mistral3(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
template <bool iswa>
|
||||
struct llm_build_modern_bert : public llm_graph_context {
|
||||
llm_build_modern_bert(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
struct llm_build_mpt : public llm_graph_context {
|
||||
llm_build_mpt(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -0,0 +1,126 @@
|
|||
#include "models.h"
|
||||
|
||||
template <bool iswa>
|
||||
llm_build_modern_bert<iswa>::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// construct input embeddings (token, type, position)
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
cb(inpL, "inp_embd", -1);
|
||||
|
||||
// embed layer norm
|
||||
inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1);
|
||||
cb(inpL, "inp_norm", -1);
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
auto * inp_attn = build_attn_inp_no_cache();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
float freq_base_l = 0.0f;
|
||||
|
||||
if constexpr (iswa) {
|
||||
freq_base_l = model.get_rope_freq_base(cparams, il);
|
||||
} else {
|
||||
freq_base_l = freq_base;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
// attention layer norm
|
||||
if (model.layers[il].attn_norm) {
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
}
|
||||
|
||||
// self attention
|
||||
cur = build_lora_mm(model.layers[il].wqkv, cur);
|
||||
cb(cur, "wqkv", il);
|
||||
|
||||
const size_t type_size = ggml_type_size(cur->type);
|
||||
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*type_size, cur->nb[1], 0*type_size*(n_embd));
|
||||
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd));
|
||||
ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd + n_embd_gqa));
|
||||
|
||||
// RoPE
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, nullptr,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
||||
}
|
||||
|
||||
// re-add the layer input
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// attention layer norm
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
NULL, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_GEGLU, LLM_FFN_SEQ, il);
|
||||
|
||||
// attentions bypass the intermediate layer
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM, -1);
|
||||
cb(cur, "final_norm_out", -1);
|
||||
|
||||
if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
|
||||
// extracting cls token
|
||||
cur = ggml_view_1d(ctx0, cur, hparams.n_embd, 0);
|
||||
cb(cur, "cls_pooled_embd", -1);
|
||||
}
|
||||
|
||||
cb(cur, "res_embd", -1);
|
||||
res->t_embd = cur;
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
// Explicit template instantiations
|
||||
template struct llm_build_modern_bert<false>;
|
||||
template struct llm_build_modern_bert<true>;
|
||||
Binary file not shown.
|
|
@ -2313,6 +2313,12 @@ private:
|
|||
slot.n_prompt_tokens_processed = 0;
|
||||
|
||||
slot.prompt.tokens.keep_first(n_past);
|
||||
|
||||
// send initial 0% progress update if needed
|
||||
// this is to signal the client that the request has started processing
|
||||
if (slot.task->params.stream && slot.task->params.return_progress) {
|
||||
send_partial_response(slot, {}, true);
|
||||
}
|
||||
}
|
||||
|
||||
if (!slot.can_split()) {
|
||||
|
|
@ -2784,6 +2790,12 @@ server_response_reader server_context::get_response_reader() {
|
|||
|
||||
server_context_meta server_context::get_meta() const {
|
||||
auto tool_use_src = common_chat_templates_source(impl->chat_templates.get(), "tool_use");
|
||||
|
||||
auto bos_id = llama_vocab_bos(impl->vocab);
|
||||
auto eos_id = llama_vocab_eos(impl->vocab);
|
||||
auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, bos_id, true) : "";
|
||||
auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, eos_id, true) : "";
|
||||
|
||||
return server_context_meta {
|
||||
/* build_info */ build_info,
|
||||
/* model_name */ impl->model_name,
|
||||
|
|
@ -2798,8 +2810,8 @@ server_context_meta server_context::get_meta() const {
|
|||
/* chat_template */ common_chat_templates_source(impl->chat_templates.get()),
|
||||
/* chat_template_tool_use */ tool_use_src ? tool_use_src : "",
|
||||
|
||||
/* bos_token_str */ common_token_to_piece(impl->ctx, llama_vocab_bos(impl->vocab), true),
|
||||
/* eos_token_str */ common_token_to_piece(impl->ctx, llama_vocab_eos(impl->vocab), true),
|
||||
/* bos_token_str */ bos_token_str,
|
||||
/* eos_token_str */ eos_token_str,
|
||||
/* fim_pre_token */ llama_vocab_fim_pre(impl->vocab),
|
||||
/* fim_sub_token */ llama_vocab_fim_suf(impl->vocab),
|
||||
/* fim_mid_token */ llama_vocab_fim_mid(impl->vocab),
|
||||
|
|
|
|||
|
|
@ -434,8 +434,8 @@ def test_context_size_exceeded_stream():
|
|||
@pytest.mark.parametrize(
|
||||
"n_batch,batch_count,reuse_cache",
|
||||
[
|
||||
(64, 3, False),
|
||||
(64, 1, True),
|
||||
(64, 4, False),
|
||||
(64, 2, True),
|
||||
]
|
||||
)
|
||||
def test_return_progress(n_batch, batch_count, reuse_cache):
|
||||
|
|
@ -462,10 +462,18 @@ def test_return_progress(n_batch, batch_count, reuse_cache):
|
|||
res = make_cmpl_request()
|
||||
last_progress = None
|
||||
total_batch_count = 0
|
||||
|
||||
for data in res:
|
||||
cur_progress = data.get("prompt_progress", None)
|
||||
if cur_progress is None:
|
||||
continue
|
||||
if total_batch_count == 0:
|
||||
# first progress report must have n_cache == n_processed
|
||||
assert cur_progress["total"] > 0
|
||||
assert cur_progress["cache"] == cur_progress["processed"]
|
||||
if reuse_cache:
|
||||
# when reusing cache, we expect some cached tokens
|
||||
assert cur_progress["cache"] > 0
|
||||
if last_progress is not None:
|
||||
assert cur_progress["total"] == last_progress["total"]
|
||||
assert cur_progress["cache"] == last_progress["cache"]
|
||||
|
|
@ -473,6 +481,7 @@ def test_return_progress(n_batch, batch_count, reuse_cache):
|
|||
total_batch_count += 1
|
||||
last_progress = cur_progress
|
||||
|
||||
# last progress should indicate completion (all tokens processed)
|
||||
assert last_progress is not None
|
||||
assert last_progress["total"] > 0
|
||||
assert last_progress["processed"] == last_progress["total"]
|
||||
|
|
|
|||
|
|
@ -294,15 +294,14 @@ class SettingsStore {
|
|||
* This sets up the default values from /props endpoint
|
||||
*/
|
||||
syncWithServerDefaults(): void {
|
||||
const serverParams = serverStore.defaultParams;
|
||||
if (!serverParams) {
|
||||
console.warn('No server parameters available for initialization');
|
||||
const propsDefaults = this.getServerDefaults();
|
||||
|
||||
if (Object.keys(propsDefaults).length === 0) {
|
||||
console.warn('No server defaults available for initialization');
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const propsDefaults = this.getServerDefaults();
|
||||
|
||||
for (const [key, propsValue] of Object.entries(propsDefaults)) {
|
||||
const currentValue = getConfigValue(this.config, key);
|
||||
|
||||
|
|
|
|||
|
|
@ -119,7 +119,7 @@
|
|||
$effect(() => {
|
||||
const serverProps = serverStore.props;
|
||||
|
||||
if (serverProps?.default_generation_settings?.params) {
|
||||
if (serverProps) {
|
||||
settingsStore.syncWithServerDefaults();
|
||||
}
|
||||
});
|
||||
|
|
|
|||
Loading…
Reference in New Issue